Skip to content

Commit b77938d

Browse files
committed
added attention_kernel param
1 parent a25129b commit b77938d

2 files changed

Lines changed: 76 additions & 6 deletions

File tree

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -92,10 +92,12 @@ def __init__(
9292
precision: jax.lax.Precision = None,
9393
names_which_can_be_saved: list = [],
9494
names_which_can_be_offloaded: list = [],
95+
attention_kernel: str = "flash",
9596
):
9697
self.dim = dim
9798
self.norm_eps = norm_eps
9899
self.norm_elementwise_affine = norm_elementwise_affine
100+
self.attention_kernel = attention_kernel
99101

100102
# 1. Self-Attention (video and audio)
101103
self.norm1 = nnx.RMSNorm(self.dim, epsilon=self.norm_eps, use_scale=self.norm_elementwise_affine, rngs=rngs, dtype=dtype, param_dtype=weights_dtype)
@@ -109,7 +111,8 @@ def __init__(
109111
out_bias=attention_out_bias,
110112
eps=norm_eps,
111113
dtype=dtype,
112-
mesh=mesh
114+
mesh=mesh,
115+
attention_kernel=self.attention_kernel
113116
)
114117

115118
self.audio_norm1 = nnx.RMSNorm(audio_dim, epsilon=self.norm_eps, use_scale=self.norm_elementwise_affine, rngs=rngs, dtype=dtype, param_dtype=weights_dtype)
@@ -123,7 +126,8 @@ def __init__(
123126
out_bias=attention_out_bias,
124127
eps=norm_eps,
125128
dtype=dtype,
126-
mesh=mesh
129+
mesh=mesh,
130+
attention_kernel=self.attention_kernel
127131
)
128132

129133
# 2. Prompt Cross-Attention
@@ -139,7 +143,8 @@ def __init__(
139143
out_bias=attention_out_bias,
140144
eps=norm_eps,
141145
dtype=dtype,
142-
mesh=mesh
146+
mesh=mesh,
147+
attention_kernel=self.attention_kernel
143148
)
144149

145150
self.audio_norm2 = nnx.RMSNorm(audio_dim, epsilon=self.norm_eps, use_scale=self.norm_elementwise_affine, rngs=rngs, dtype=dtype, param_dtype=weights_dtype)
@@ -154,7 +159,8 @@ def __init__(
154159
out_bias=attention_out_bias,
155160
eps=norm_eps,
156161
dtype=dtype,
157-
mesh=mesh
162+
mesh=mesh,
163+
attention_kernel=self.attention_kernel
158164
)
159165

160166
# 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention
@@ -170,7 +176,8 @@ def __init__(
170176
out_bias=attention_out_bias,
171177
eps=norm_eps,
172178
dtype=dtype,
173-
mesh=mesh
179+
mesh=mesh,
180+
attention_kernel=self.attention_kernel
174181
)
175182

176183
self.video_to_audio_norm = nnx.RMSNorm(audio_dim, epsilon=self.norm_eps, use_scale=self.norm_elementwise_affine, rngs=rngs, dtype=dtype, param_dtype=weights_dtype)
@@ -185,7 +192,8 @@ def __init__(
185192
out_bias=attention_out_bias,
186193
eps=norm_eps,
187194
dtype=dtype,
188-
mesh=mesh
195+
mesh=mesh,
196+
attention_kernel=self.attention_kernel
189197
)
190198

191199
# 4. Feed Forward
@@ -523,6 +531,7 @@ def __init__(
523531
names_which_can_be_saved: list = [],
524532
names_which_can_be_offloaded: list = [],
525533
scan_layers: bool = True,
534+
attention_kernel: str = "flash",
526535
):
527536
self.in_channels = in_channels
528537
self.out_channels = out_channels
@@ -568,6 +577,7 @@ def __init__(
568577
self.names_which_can_be_saved = names_which_can_be_saved
569578
self.names_which_can_be_offloaded = names_which_can_be_offloaded
570579
self.scan_layers = scan_layers
580+
self.attention_kernel = attention_kernel
571581

572582
_out_channels = self.out_channels or self.in_channels
573583
_audio_out_channels = self.audio_out_channels or self.audio_in_channels
@@ -723,6 +733,7 @@ def init_block(rngs):
723733
precision=self.precision,
724734
names_which_can_be_saved=self.names_which_can_be_saved,
725735
names_which_can_be_offloaded=self.names_which_can_be_offloaded,
736+
attention_kernel=self.attention_kernel,
726737
)
727738

728739
if self.scan_layers:
@@ -754,6 +765,7 @@ def init_block(rngs):
754765
precision=self.precision,
755766
names_which_can_be_saved=self.names_which_can_be_saved,
756767
names_which_can_be_offloaded=self.names_which_can_be_offloaded,
768+
attention_kernel=self.attention_kernel,
757769
)
758770
blocks.append(block)
759771
self.transformer_blocks = nnx.List(blocks)

src/maxdiffusion/tests/ltx_2_transformer_test.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,64 @@ def test_transformer_3d_model_instantiation_and_forward(self):
258258
self.assertEqual(sample.shape, (self.batch_size, self.seq_len, self.out_channels))
259259
self.assertEqual(audio_sample.shape, (self.batch_size, 128, self.audio_in_channels))
260260

261+
def test_transformer_3d_model_dot_product_attention(self):
262+
"""Verifies LTX2VideoTransformer3DModel full instantiation and forward pass with dot_product attention."""
263+
264+
# 1. Instantiate Model with dot_product kernel
265+
model = LTX2VideoTransformer3DModel(
266+
rngs=nnx.Rngs(0),
267+
in_channels=self.in_channels,
268+
out_channels=self.out_channels,
269+
patch_size=self.patch_size,
270+
patch_size_t=self.patch_size_t,
271+
num_attention_heads=self.num_attention_heads,
272+
attention_head_dim=self.attention_head_dim,
273+
cross_attention_dim=self.cross_attention_dim,
274+
audio_in_channels=self.audio_in_channels,
275+
audio_out_channels=self.audio_out_channels,
276+
audio_patch_size=self.audio_patch_size,
277+
audio_patch_size_t=self.audio_patch_size_t,
278+
audio_num_attention_heads=self.audio_num_attention_heads,
279+
audio_attention_head_dim=self.audio_attention_head_dim,
280+
audio_cross_attention_dim=self.audio_cross_attention_dim,
281+
num_layers=1, # Reduced layers for speed
282+
config=self.config,
283+
scan_layers=False,
284+
mesh=self.mesh,
285+
attention_kernel="dot_product"
286+
)
287+
288+
# 2. Inputs
289+
hidden_states = jnp.ones((self.batch_size, self.seq_len, self.in_channels)) * 0.5
290+
audio_hidden_states = jnp.ones((self.batch_size, 128, self.audio_in_channels)) * 0.5
291+
timestep = jnp.array([1.0]) # (B,)
292+
293+
encoder_hidden_states = jnp.zeros((self.batch_size, 128, 32)) # (B, Lc, Dc)
294+
audio_encoder_hidden_states = jnp.zeros((self.batch_size, 128, 32))
295+
encoder_attention_mask = jnp.ones((self.batch_size, 128), dtype=jnp.float32)
296+
audio_encoder_attention_mask = jnp.ones((self.batch_size, 128), dtype=jnp.float32)
297+
298+
# Forward
299+
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
300+
output = model(
301+
hidden_states=hidden_states,
302+
audio_hidden_states=audio_hidden_states,
303+
encoder_hidden_states=encoder_hidden_states,
304+
audio_encoder_hidden_states=audio_encoder_hidden_states,
305+
timestep=timestep,
306+
num_frames=self.num_frames,
307+
height=self.height,
308+
width=self.width,
309+
audio_num_frames=128,
310+
fps=24.0,
311+
return_dict=True,
312+
encoder_attention_mask=encoder_attention_mask,
313+
audio_encoder_attention_mask=audio_encoder_attention_mask
314+
)
315+
316+
self.assertEqual(output.sample.shape, hidden_states.shape)
317+
self.assertEqual(output.audio_sample.shape, audio_hidden_states.shape)
318+
261319
def test_scan_remat_parity(self):
262320
"""
263321
Verifies that scan_layers=True produces identical output to scan_layers=False.

0 commit comments

Comments (0)