File tree Expand file tree Collapse file tree
src/maxdiffusion/models/ltx2 Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -159,7 +159,7 @@ def __init__(
159159 eps = norm_eps ,
160160 dtype = dtype ,
161161 mesh = mesh ,
162- attention_kernel = self . attention_kernel ,
162+ attention_kernel = "dot_product" ,
163163 rope_type = rope_type ,
164164 flash_block_sizes = flash_block_sizes ,
165165 )
@@ -212,7 +212,7 @@ def __init__(
212212 eps = norm_eps ,
213213 dtype = dtype ,
214214 mesh = mesh ,
215- attention_kernel = self . attention_kernel ,
215+ attention_kernel = "dot_product" ,
216216 rope_type = rope_type ,
217217 flash_block_sizes = flash_block_sizes ,
218218 )
@@ -350,7 +350,8 @@ def __call__(
350350
351351 axis_names = nn .logical_to_mesh_axes (("activation_batch" , "activation_length" , "activation_embed" ))
352352 hidden_states = jax .lax .with_sharding_constraint (hidden_states , axis_names )
353- audio_hidden_states = jax .lax .with_sharding_constraint (audio_hidden_states , axis_names )
353+ axis_names_audio = nn .logical_to_mesh_axes (("activation_batch" , None , "activation_embed" ))
354+ audio_hidden_states = jax .lax .with_sharding_constraint (audio_hidden_states , axis_names_audio )
354355
355356 if encoder_hidden_states is not None :
356357 encoder_hidden_states = jax .lax .with_sharding_constraint (encoder_hidden_states , axis_names )
You can’t perform that action at this time.
0 commit comments