We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent bf161f9 commit a2cd284Copy full SHA for a2cd284
1 file changed
src/maxdiffusion/models/ltx2/transformer_ltx2.py
@@ -1012,6 +1012,16 @@ def scan_fn(carry, block):
1012
else:
1013
for i, block in enumerate(self.transformer_blocks):
1014
with jax.named_scope(f"Transformer Block {i}"):
1015
+ graphdef, state = nnx.split(block)
1016
+
1017
+ def _apply_sharding(x):
1018
+ if hasattr(x, "sharding") and x.sharding is not None:
1019
+ return jax.lax.with_sharding_constraint(x, x.sharding)
1020
+ return x
1021
1022
+ state = jax.tree_util.tree_map(_apply_sharding, state)
1023
+ nnx.update(block, state)
1024
1025
hidden_states, audio_hidden_states = block(
1026
hidden_states=hidden_states,
1027
audio_hidden_states=audio_hidden_states,
0 commit comments