File tree Expand file tree Collapse file tree
src/maxdiffusion/models/ltx2 Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -1010,17 +1010,12 @@ def scan_fn(carry, block):
10101010 transform_metadata = {nnx .PARTITION_NAME : "layers" },
10111011 )(carry , self .transformer_blocks )
10121012 else :
1013+ activation_axis_names = nn .logical_to_mesh_axes (("activation_batch" , "activation_length" , "activation_embed" ))
1014+
10131015 for i , block in enumerate (self .transformer_blocks ):
10141016 with jax .named_scope (f"Transformer Block { i } " ):
1015- graphdef , state = nnx .split (block )
1016-
1017- def _apply_sharding (x ):
1018- if hasattr (x , "sharding" ) and x .sharding is not None :
1019- return jax .lax .with_sharding_constraint (x , x .sharding )
1020- return x
1021-
1022- state = jax .tree_util .tree_map (_apply_sharding , state )
1023- nnx .update (block , state )
1017+ hidden_states = jax .lax .with_sharding_constraint (hidden_states , activation_axis_names )
1018+ audio_hidden_states = jax .lax .with_sharding_constraint (audio_hidden_states , activation_axis_names )
10241019
10251020 hidden_states , audio_hidden_states = block (
10261021 hidden_states = hidden_states ,
You can’t perform that action at this time.
0 commit comments