Skip to content

Commit 691c4b7

Browse files
committed
testing flash with audio_to_video attn
1 parent 449d78c commit 691c4b7

1 file changed

Lines changed: 1 addition & 4 deletions

File tree

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ def __init__(
239239
eps=norm_eps,
240240
dtype=dtype,
241241
mesh=mesh,
242-
attention_kernel="dot_product",
242+
attention_kernel=self.attention_kernel,
243243
rope_type=rope_type,
244244
flash_block_sizes=flash_block_sizes,
245245
)
@@ -353,9 +353,6 @@ def __call__(
353353
axis_names_audio = nn.logical_to_mesh_axes(("activation_batch", None, "activation_embed"))
354354
audio_hidden_states = jax.lax.with_sharding_constraint(audio_hidden_states, axis_names_audio)
355355

356-
jax.debug.print("[LTX2] Video hidden states shape: {shape}", shape=hidden_states.shape)
357-
jax.debug.print("[LTX2] Audio hidden states shape: {shape}", shape=audio_hidden_states.shape)
358-
359356
if encoder_hidden_states is not None:
360357
encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, axis_names)
361358
if audio_encoder_hidden_states is not None:

0 commit comments

Comments
 (0)