Skip to content

Commit 755e1b2

Browse files
committed
fixes missing mesh issue and rope reshape error
1 parent 0dd14bb commit 755e1b2

2 files changed

Lines changed: 13 additions & 6 deletions

File tree

src/maxdiffusion/models/ltx2/attention_ltx2.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -230,6 +230,7 @@ def __call__(self, coords: Array) -> Tuple[Array, Array]:
230230
# We want [B, 3, N] / [3] (broadcasting over B, N)
231231
# JAX broadcasting: last dims match? No.
232232
# reshape max_positions to [1, 3, 1]
233+
max_positions = max_positions[:num_pos_dims]
233234
max_positions = max_positions.reshape(1, num_pos_dims, 1)
234235
grid = coords / max_positions
235236

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 12 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -108,7 +108,8 @@ def __init__(
108108
bias=attention_bias,
109109
out_bias=attention_out_bias,
110110
eps=norm_eps,
111-
dtype=dtype
111+
dtype=dtype,
112+
mesh=mesh
112113
)
113114

114115
self.audio_norm1 = nnx.RMSNorm(audio_dim, epsilon=self.norm_eps, use_scale=self.norm_elementwise_affine, rngs=rngs, dtype=dtype, param_dtype=weights_dtype)
@@ -121,7 +122,8 @@ def __init__(
121122
bias=attention_bias,
122123
out_bias=attention_out_bias,
123124
eps=norm_eps,
124-
dtype=dtype
125+
dtype=dtype,
126+
mesh=mesh
125127
)
126128

127129
# 2. Prompt Cross-Attention
@@ -136,7 +138,8 @@ def __init__(
136138
bias=attention_bias,
137139
out_bias=attention_out_bias,
138140
eps=norm_eps,
139-
dtype=dtype
141+
dtype=dtype,
142+
mesh=mesh
140143
)
141144

142145
self.audio_norm2 = nnx.RMSNorm(audio_dim, epsilon=self.norm_eps, use_scale=self.norm_elementwise_affine, rngs=rngs, dtype=dtype, param_dtype=weights_dtype)
@@ -150,7 +153,8 @@ def __init__(
150153
bias=attention_bias,
151154
out_bias=attention_out_bias,
152155
eps=norm_eps,
153-
dtype=dtype
156+
dtype=dtype,
157+
mesh=mesh
154158
)
155159

156160
# 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention
@@ -165,7 +169,8 @@ def __init__(
165169
bias=attention_bias,
166170
out_bias=attention_out_bias,
167171
eps=norm_eps,
168-
dtype=dtype
172+
dtype=dtype,
173+
mesh=mesh
169174
)
170175

171176
self.video_to_audio_norm = nnx.RMSNorm(audio_dim, epsilon=self.norm_eps, use_scale=self.norm_elementwise_affine, rngs=rngs, dtype=dtype, param_dtype=weights_dtype)
@@ -179,7 +184,8 @@ def __init__(
179184
bias=attention_bias,
180185
out_bias=attention_out_bias,
181186
eps=norm_eps,
182-
dtype=dtype
187+
dtype=dtype,
188+
mesh=mesh
183189
)
184190

185191
# 4. Feed Forward

0 commit comments

Comments
 (0)