Commit 073f703

force sharding on FFN intermediate activation

1 parent 2c361e4 commit 073f703

1 file changed: 2 additions & 0 deletions

File changed: src/maxdiffusion/models/wan/transformers/transformer_wan.py
@@ -250,6 +250,8 @@ def __call__(self, hidden_states: jax.Array, deterministic: bool = True, rngs: n
     jax.debug.print(f"MLP input shape: {{shape}}", shape=hidden_states.shape)
     jax.debug.inspect_array_sharding(hidden_states, callback=print)
     hidden_states = self.act_fn(hidden_states)  # Output is (4, 75600, 13824)
+    # Add logical constraint so the FFN intermediate activation is sharded along the mlp axis
+    hidden_states = nn.with_logical_constraint(hidden_states, ("activation_batch", "activation_length", "mlp"))
     jax.debug.print(f"MLP intermediate activation shape: {{shape}}", shape=hidden_states.shape)
     jax.debug.inspect_array_sharding(hidden_states, callback=print)
     hidden_states = checkpoint_name(hidden_states, "ffn_activation")
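For context, nn.with_logical_constraint only takes effect when the logical axis names ("activation_batch", "activation_length", "mlp") are mapped onto a physical device mesh via logical axis rules. Below is a minimal, self-contained sketch of that mechanism; the mesh layout, the rules, and the TinyFFN module are illustrative assumptions for this sketch, not the actual maxdiffusion configuration. Only the with_logical_constraint call mirrors the line added in this commit.

    # Minimal sketch: mapping Flax logical axis names to mesh axes.
    # The mesh, rules, and TinyFFN module here are hypothetical.
    import jax
    import jax.numpy as jnp
    import flax.linen as nn
    from jax.sharding import Mesh
    from jax.experimental import mesh_utils

    # Hypothetical 2-axis mesh: "data" for batch, "tensor" for model parallelism.
    devices = mesh_utils.create_device_mesh((1, jax.device_count()))
    mesh = Mesh(devices, axis_names=("data", "tensor"))

    # Hypothetical rules: batch on "data", the FFN hidden (mlp) axis on
    # "tensor", and the sequence axis replicated.
    rules = (
        ("activation_batch", "data"),
        ("activation_length", None),
        ("mlp", "tensor"),
    )

    class TinyFFN(nn.Module):
      hidden: int

      @nn.compact
      def __call__(self, x):
        h = nn.Dense(self.hidden)(x)
        h = nn.gelu(h)
        # Same constraint as the commit: force the intermediate activation
        # into the ("activation_batch", "activation_length", "mlp") layout.
        h = nn.with_logical_constraint(h, ("activation_batch", "activation_length", "mlp"))
        return nn.Dense(x.shape[-1])(h)

    x = jnp.ones((4, 16, 32))   # (batch, length, model_dim)
    model = TinyFFN(hidden=64)  # hidden must divide evenly over the "tensor" axis

    with mesh, nn.logical_axis_rules(rules):
      params = model.init(jax.random.key(0), x)
      y = jax.jit(model.apply)(params, x)
      print(y.sharding)  # shows how the compiler laid out the output

Under these assumed rules, the constraint tells the compiler to split the (4, 75600, 13824)-shaped intermediate along its last dimension across the "tensor" mesh axis instead of replicating it, which is what keeps the FFN activation from blowing up per-device memory.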
