Skip to content

Commit 3391a5e

Browse files
committed
cleanup
1 parent 22036e6 commit 3391a5e

1 file changed

Lines changed: 1 addition & 18 deletions

File tree

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -971,7 +971,7 @@ def __call__(
971971
# 5. Run transformer blocks
972972
def scan_fn(carry, block):
973973
hidden_states, audio_hidden_states, rngs_carry = carry
974-
with jax.named_scope("Transformer Block i"):
974+
with jax.named_scope("Transformer Layer"):
975975
hidden_states_out, audio_hidden_states_out = block(
976976
hidden_states=hidden_states,
977977
audio_hidden_states=audio_hidden_states,
@@ -1010,25 +1010,8 @@ def scan_fn(carry, block):
10101010
transform_metadata={nnx.PARTITION_NAME: "layers"},
10111011
)(carry, self.transformer_blocks)
10121012
else:
1013-
mlp_rules = nn.logical_to_mesh_axes(("mlp", "tensor"))
1014-
tensor_rules = nn.logical_to_mesh_axes(("tensor",))
1015-
10161013
for i, block in enumerate(self.transformer_blocks):
10171014
with jax.named_scope(f"Transformer Block {i}"):
1018-
graphdef, state = nnx.split(block)
1019-
1020-
def _apply_weight_sharding(path, x):
1021-
path_str = "/".join(getattr(p, "name", getattr(p, "key", str(p))) for p in path)
1022-
if "kernel" in path_str:
1023-
if "ff" in path_str:
1024-
return jax.lax.with_sharding_constraint(x, mlp_rules)
1025-
elif "attn" in path_str:
1026-
return jax.lax.with_sharding_constraint(x, tensor_rules)
1027-
return x
1028-
1029-
state = jax.tree_util.tree_map_with_path(_apply_weight_sharding, state)
1030-
nnx.update(block, state)
1031-
10321015
hidden_states, audio_hidden_states = block(
10331016
hidden_states=hidden_states,
10341017
audio_hidden_states=audio_hidden_states,

0 commit comments

Comments
 (0)