Skip to content

Commit a2cd284

Browse files
committed
fix for scan layers = false case
1 parent bf161f9 commit a2cd284

1 file changed

Lines changed: 10 additions & 0 deletions

File tree

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1012,6 +1012,16 @@ def scan_fn(carry, block):
10121012
else:
10131013
for i, block in enumerate(self.transformer_blocks):
10141014
with jax.named_scope(f"Transformer Block {i}"):
1015+
graphdef, state = nnx.split(block)
1016+
1017+
def _apply_sharding(x):
1018+
if hasattr(x, "sharding") and x.sharding is not None:
1019+
return jax.lax.with_sharding_constraint(x, x.sharding)
1020+
return x
1021+
1022+
state = jax.tree_util.tree_map(_apply_sharding, state)
1023+
nnx.update(block, state)
1024+
10151025
hidden_states, audio_hidden_states = block(
10161026
hidden_states=hidden_states,
10171027
audio_hidden_states=audio_hidden_states,

0 commit comments

Comments
 (0)