File tree Expand file tree Collapse file tree
src/maxdiffusion/models/ltx2 Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -239,7 +239,7 @@ def __init__(
239239 eps = norm_eps ,
240240 dtype = dtype ,
241241 mesh = mesh ,
242- attention_kernel = "dot_product" ,
242+ attention_kernel = self . attention_kernel ,
243243 rope_type = rope_type ,
244244 flash_block_sizes = flash_block_sizes ,
245245 )
@@ -353,9 +353,6 @@ def __call__(
353353 axis_names_audio = nn .logical_to_mesh_axes (("activation_batch" , None , "activation_embed" ))
354354 audio_hidden_states = jax .lax .with_sharding_constraint (audio_hidden_states , axis_names_audio )
355355
356- jax .debug .print ("[LTX2] Video hidden states shape: {shape}" , shape = hidden_states .shape )
357- jax .debug .print ("[LTX2] Audio hidden states shape: {shape}" , shape = audio_hidden_states .shape )
358-
359356 if encoder_hidden_states is not None :
360357 encoder_hidden_states = jax .lax .with_sharding_constraint (encoder_hidden_states , axis_names )
361358 if audio_encoder_hidden_states is not None :
You can’t perform that action at this time.
0 commit comments