Skip to content

Commit d71aa8d

Browse files
committed
Fix sharding of transformer-block weights when scan_layers is False
1 parent ab36ad8 commit d71aa8d

1 file changed

Lines changed: 15 additions & 3 deletions

File tree

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1010,12 +1010,24 @@ def scan_fn(carry, block):
10101010
transform_metadata={nnx.PARTITION_NAME: "layers"},
10111011
)(carry, self.transformer_blocks)
10121012
else:
1013-
activation_axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_embed"))
1013+
mlp_rules = nn.logical_to_mesh_axes(("mlp", "tensor"))
1014+
tensor_rules = nn.logical_to_mesh_axes(("tensor",))
10141015

10151016
for i, block in enumerate(self.transformer_blocks):
10161017
with jax.named_scope(f"Transformer Block {i}"):
1017-
hidden_states = jax.lax.with_sharding_constraint(hidden_states, activation_axis_names)
1018-
audio_hidden_states = jax.lax.with_sharding_constraint(audio_hidden_states, activation_axis_names)
1018+
graphdef, state = nnx.split(block)
1019+
1020+
def _apply_weight_sharding(path, x):
1021+
path_str = "/".join(getattr(p, "name", getattr(p, "key", str(p))) for p in path)
1022+
if "kernel" in path_str:
1023+
if "ff" in path_str:
1024+
return jax.lax.with_sharding_constraint(x, mlp_rules)
1025+
elif "attn" in path_str:
1026+
return jax.lax.with_sharding_constraint(x, tensor_rules)
1027+
return x
1028+
1029+
state = jax.tree_util.tree_map_with_path(_apply_weight_sharding, state)
1030+
nnx.update(block, state)
10191031

10201032
hidden_states, audio_hidden_states = block(
10211033
hidden_states=hidden_states,

0 commit comments

Comments
 (0)