Skip to content

Commit 49f2a47

Browse files
committed
change in embeddings_connector_ltx2.py
1 parent daf33a0 commit 49f2a47

2 files changed

Lines changed: 11 additions & 14 deletions

File tree

src/maxdiffusion/models/ltx2/ltx2_utils.py

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -415,9 +415,6 @@ def rename_for_ltx2_connector(key):
415415
key = key.replace(".weight", ".scale")
416416
else:
417417
key = key.replace(".weight", ".kernel")
418-
419-
if "learnable_registers" in key and not key.endswith(".value"):
420-
key = key + ".value"
421418

422419
return key
423420

src/maxdiffusion/models/ltx2/text_encoders/embeddings_connector_ltx2.py

Lines changed: 11 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -132,9 +132,10 @@ def _replace_padded_with_learnable_registers(self, hidden_states: Array, attenti
132132

133133
num_duplications = t // self.num_learnable_registers
134134
registers = jnp.tile(self.learnable_registers[...], (num_duplications, 1))
135-
registers = jnp.expand_dims(registers, 0)
136-
137-
if attention_mask.ndim == 2:
135+
136+
if attention_mask.ndim == 4:
137+
mask = attention_mask.squeeze(1).squeeze(1)
138+
elif attention_mask.ndim == 2:
138139
mask = attention_mask
139140
else:
140141
mask = attention_mask.squeeze(-1) # [B, T]
@@ -154,16 +155,15 @@ def _replace_padded_with_learnable_registers(self, hidden_states: Array, attenti
154155
shifted_hidden_states = jnp.zeros_like(hidden_states)
155156
shifted_hidden_states = shifted_hidden_states.at[b_idx, target_indices, :].set(hidden_states)
156157

157-
# Shift mask
158-
shifted_mask = jnp.zeros_like(curr_mask)
159-
shifted_mask = shifted_mask.at[b_idx, target_indices].set(curr_mask)
160-
161158
# 2. Add Learnable Registers
162-
# Where shifted_mask is 1, keep valid tokens. Where 0, insert registers.
163-
output = jnp.where(shifted_mask[..., None] == 1, shifted_hidden_states, registers)
159+
# Where flipped_mask is 1, keep valid tokens. Where 0, insert registers.
160+
flipped_mask = jnp.flip(curr_mask, axis=[1])
161+
flipped_mask_expanded = flipped_mask[..., None]
162+
163+
output = jnp.where(flipped_mask_expanded == 1, shifted_hidden_states, registers)
164164

165-
# Overwrite attention_mask with all-ones since padding is now filled with registers.
166-
new_mask = jnp.ones_like(attention_mask)
165+
# Overwrite attention_mask with all-zeros since padding is now filled with registers.
166+
new_mask = jnp.zeros_like(attention_mask)
167167
return output, new_mask
168168

169169
def _compute_1d_rope(self, seq_len: int, dtype: DType) -> Tuple[Array, Array]:

0 commit comments

Comments
 (0)