Commit ab36ad8

Add explicit sharding constraints to fix the scan_layers=False issue
1 parent a2cd284 commit ab36ad8

1 file changed

Lines changed: 4 additions & 9 deletions

File tree

src/maxdiffusion/models/ltx2/transformer_ltx2.py

@@ -1010,17 +1010,12 @@ def scan_fn(carry, block):
           transform_metadata={nnx.PARTITION_NAME: "layers"},
       )(carry, self.transformer_blocks)
     else:
+      activation_axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_embed"))
+
       for i, block in enumerate(self.transformer_blocks):
         with jax.named_scope(f"Transformer Block {i}"):
-          graphdef, state = nnx.split(block)
-
-          def _apply_sharding(x):
-            if hasattr(x, "sharding") and x.sharding is not None:
-              return jax.lax.with_sharding_constraint(x, x.sharding)
-            return x
-
-          state = jax.tree_util.tree_map(_apply_sharding, state)
-          nnx.update(block, state)
+          hidden_states = jax.lax.with_sharding_constraint(hidden_states, activation_axis_names)
+          audio_hidden_states = jax.lax.with_sharding_constraint(audio_hidden_states, activation_axis_names)
 
           hidden_states, audio_hidden_states = block(
               hidden_states=hidden_states,
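
For context: the removed code resharded each block's parameters on every loop iteration (nnx.split / tree_map / nnx.update), while the new code instead constrains the intermediate activations to a fixed layout. Below is a minimal, self-contained sketch of that pattern; the mesh shape, axis names, and logical-axis rules are illustrative assumptions, not MaxDiffusion's actual configuration.

# Minimal sketch of the activation-sharding pattern used in the diff above.
# The mesh and logical_axis_rules here are assumptions for illustration;
# MaxDiffusion derives both from its own config.
import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn
from jax.sharding import Mesh, NamedSharding

# Hypothetical 1x1 mesh so the sketch runs on a single device.
mesh = Mesh(np.array(jax.devices()[:1]).reshape(1, 1), axis_names=("data", "model"))

# Assumed mapping from logical activation axis names to mesh axes.
rules = (("activation_batch", "data"),
         ("activation_length", None),
         ("activation_embed", "model"))

with nn.logical_axis_rules(rules):
  # Same call as in the diff: resolves logical names to a PartitionSpec.
  spec = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_embed"))

@jax.jit
def constrain(x):
  # Pin the activation layout so XLA keeps it stable across block calls
  # instead of choosing a fresh sharding on each loop iteration.
  return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, spec))

hidden_states = jnp.ones((2, 8, 16))
print(constrain(hidden_states).sharding)  # PartitionSpec('data', None, 'model')

Constraining activations once per block is cheaper and more robust than rewriting parameter shardings through split/update, which explains the 9-deletion, 4-addition shape of this commit.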
