Switch WAN configs to tensor-parallel and fix activation constraints
Priority 1: Switch default parallelism from context-parallel to tensor-parallel
- Changed ici_tensor_parallelism from 1 to -1 (auto) and
ici_context_parallelism from -1 to 1 across all WAN T2V configs
(14B, 1.3B, 27B/A14B).
- Updated logical_axis_rules to match TP strategy:
- 'embed' axis: ['context', 'fsdp'] -> ['fsdp', 'tensor'] so QKV/FFN
input dimensions are properly sharded across TP devices.
- 'activation_self_attn_heads': ['context', 'tensor'] -> 'tensor'
(pure TP head sharding for self-attention splash kernel).
- 'activation_cross_attn_q_length': ['context', 'tensor'] -> 'tensor'
(Q sequence sharding for cross-attention splash kernel).
- 'activation_length': 'context' -> None (no sequence sharding in TP mode).
- Conv axes updated to use 'tensor' instead of 'context'.
This aligns with the reference torchax benchmark, which uses pure TP and
achieves ~1.8x faster inference than context-parallel mode.
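
As a rough illustration of the intent, the sketch below shows how the new
ici_* values and axis rules could be wired up with jax.sharding and flax
logical partitioning. The resolve() helper, the mesh axis order, and the
exact rule list are illustrative assumptions, not the actual maxdiffusion
config plumbing.

```python
# Hedged sketch: TP-first mesh plus the updated logical axis rules.
import jax
import numpy as np
from jax.sharding import Mesh

ici_tensor_parallelism = -1   # was 1; -1 means "auto" (use the remaining devices)
ici_context_parallelism = 1   # was -1

def resolve(parallelism: int, n_devices: int) -> int:
    # Simplified: assumes at most one parallelism value is set to -1.
    return n_devices if parallelism == -1 else parallelism

n_devices = jax.device_count()
tensor = resolve(ici_tensor_parallelism, n_devices)
context = resolve(ici_context_parallelism, n_devices)
fsdp = n_devices // (tensor * context)  # whatever is left over

devices = np.array(jax.devices()).reshape(fsdp, context, tensor)
mesh = Mesh(devices, axis_names=("fsdp", "context", "tensor"))

# Updated logical-to-mesh axis rules (flax.linen logical partitioning style).
logical_axis_rules = (
    ("embed", ("fsdp", "tensor")),                 # was ("context", "fsdp")
    ("activation_self_attn_heads", "tensor"),      # was ("context", "tensor")
    ("activation_cross_attn_q_length", "tensor"),  # was ("context", "tensor")
    ("activation_length", None),                   # was "context": no seq sharding in TP
    # conv-related logical axes likewise move from "context" to "tensor"
)
```

With the 'embed' rule spanning both the fsdp and tensor axes, the input
dimension of the QKV/FFN projection weights is sharded across TP devices
rather than replicated, matching the intent described above.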
Priority 4: Fix activation constraints for TP compatibility
- WanTransformerBlock.__call__: Changed hidden_states constraint from
('activation_batch', 'activation_length', 'activation_heads') to
('activation_batch', 'activation_length', None).
- FlaxWanAttention.__call__: Changed constraint from (BATCH, LENGTH, HEAD)
to (BATCH, LENGTH, None).
The model dim of the hidden_states passed between blocks should remain
replicated (not sharded on the tensor axis) because column-parallel QKV
projections expect replicated input. The old constraint forced the model dim
onto the tensor axis, which in TP mode inserted an unnecessary scatter and
all-gather of the activations before every attention block. In
context-parallel mode, tensor=1 made this a no-op, so the change is
backwards-compatible.
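
A minimal sketch of the new constraint, assuming standard flax.linen logical
partitioning; the module below is a hypothetical stand-in, not the actual
WanTransformerBlock / FlaxWanAttention code.

```python
import flax.linen as nn
import jax.numpy as jnp

class BlockSketch(nn.Module):
    """Stand-in transformer block illustrating the relaxed constraint."""

    @nn.compact
    def __call__(self, hidden_states: jnp.ndarray) -> jnp.ndarray:
        # Old: the model dim was constrained to 'activation_heads', which maps
        # to the tensor mesh axis, so the replicated output of the previous
        # block had to be resharded (and gathered back) at every block:
        #   hidden_states = nn.with_logical_constraint(
        #       hidden_states,
        #       ("activation_batch", "activation_length", "activation_heads"))
        #
        # New: leave the model dim unconstrained (None) so activations stay
        # replicated on the tensor axis, which is what the tensor-parallel
        # QKV projections expect as input.
        hidden_states = nn.with_logical_constraint(
            hidden_states,
            ("activation_batch", "activation_length", None),
        )
        # ... attention and FFN sublayers would follow here ...
        return hidden_states
```

Under context parallelism the tensor mesh axis has size 1, so the old and new
constraints lower to the same sharding there, which is the
backwards-compatibility argument above.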