Commit 6760ab9

Switch WAN configs to tensor-parallel and fix activation constraints
Priority 1: Switch default parallelism from context-parallel to tensor-parallel

- Changed ici_tensor_parallelism from 1 to -1 (auto) and ici_context_parallelism from -1 to 1 across all WAN T2V configs (14B, 1.3B, 27B/A14B).
- Updated logical_axis_rules to match the TP strategy:
  - 'embed' axis: ['context', 'fsdp'] -> ['fsdp', 'tensor'], so QKV/FFN input dimensions are properly sharded across TP devices.
  - 'activation_self_attn_heads': ['context', 'tensor'] -> 'tensor' (pure TP head sharding for the self-attention splash kernel).
  - 'activation_cross_attn_q_length': ['context', 'tensor'] -> 'tensor' (Q sequence sharding for the cross-attention splash kernel).
  - 'activation_length': 'context' -> None (no sequence sharding in TP mode).
  - Conv axes updated to use 'tensor' instead of 'context'.
- This aligns with the reference torchax benchmark, which uses pure TP and achieves ~1.8x faster inference than context-parallel mode.

Priority 4: Fix activation constraints for TP compatibility

- WanTransformerBlock.__call__: changed the hidden_states constraint from ('activation_batch', 'activation_length', 'activation_heads') to ('activation_batch', 'activation_length', None).
- FlaxWanAttention.__call__: changed the constraint from (BATCH, LENGTH, HEAD) to (BATCH, LENGTH, None).

The model dim of hidden_states between blocks should remain replicated (not sharded on the tensor axis) because column-parallel QKV projections expect replicated input. The old constraint forced the model dim onto the tensor axis, which in TP mode caused an unnecessary all-gather (resharding back to replicated) before every attention block. In context-parallel mode, tensor=1 made this constraint a no-op, so the change is backwards-compatible.
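Why replicated input matters for column-parallel QKV, as a minimal numpy sketch (shapes are invented for illustration; this is not code from the repo):

# Column-parallel QKV with a replicated input needs no collective before the
# matmul: each TP shard multiplies the full activation by its own column block.
import numpy as np

d_model, n_heads, head_dim, tp = 16, 4, 4, 2
x = np.random.randn(3, d_model)                 # activations, replicated per device
w_qkv = np.random.randn(d_model, n_heads * head_dim)

shards = np.split(w_qkv, tp, axis=1)            # each device owns a column block
local_out = [x @ w for w in shards]             # purely local compute, no comms

# Concatenating the per-device results reproduces the unsharded matmul exactly.
assert np.allclose(np.concatenate(local_out, axis=1), x @ w_qkv)

If the model dim were instead sharded on entry, each device would first have to all-gather x before it could form its local product, which is the collective the new constraints avoid.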
1 parent ceca471 commit 6760ab9

5 files changed

Lines changed: 31 additions & 27 deletions

File tree

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 12 additions & 8 deletions
@@ -165,20 +165,24 @@ mesh_axes: ['data', 'fsdp', 'context', 'tensor']
 # keep_2 : conv.shape[1] weight
 # conv_in : conv.shape[2] weight
 # conv_out : conv.shape[-1] weight
+#
+# Default: Tensor Parallel (TP) mode — shards QKV/FFN weights across tensor axis.
+# For context-parallel mode (sequence sharding), set ici_context_parallelism: -1,
+# ici_tensor_parallelism: 1, and swap embed to ['context', 'fsdp'].
 logical_axis_rules: [
   ['batch', ['data', 'fsdp']],
   ['activation_batch', ['data', 'fsdp']],
-  ['activation_self_attn_heads', ['context', 'tensor']],
-  ['activation_cross_attn_q_length', ['context', 'tensor']],
-  ['activation_length', 'context'],
+  ['activation_self_attn_heads', 'tensor'],
+  ['activation_cross_attn_q_length', 'tensor'],
+  ['activation_length', None],
   ['activation_heads', 'tensor'],
   ['mlp','tensor'],
-  ['embed', ['context', 'fsdp']],
+  ['embed', ['fsdp', 'tensor']],
   ['heads', 'tensor'],
   ['norm', 'tensor'],
-  ['conv_batch', ['data', 'context', 'fsdp']],
+  ['conv_batch', ['data', 'fsdp']],
   ['out_channels', 'tensor'],
-  ['conv_out', 'context'],
+  ['conv_out', 'tensor'],
 ]
 data_sharding: [['data', 'fsdp', 'context', 'tensor']]

@@ -192,8 +196,8 @@ dcn_context_parallelism: -1
 dcn_tensor_parallelism: 1
 ici_data_parallelism: 1
 ici_fsdp_parallelism: 1
-ici_context_parallelism: -1 # recommended ICI axis to be auto-sharded
-ici_tensor_parallelism: 1
+ici_context_parallelism: 1
+ici_tensor_parallelism: -1 # recommended ICI axis to be auto-sharded

 allow_split_physical_axes: False
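For intuition on the -1 ("auto") value, here is a hypothetical helper (not MaxDiffusion's actual mesh-construction code) showing how an auto axis typically absorbs whatever device count the explicit axes leave over:

# Hypothetical sketch: resolve a single -1 axis so the mesh shape multiplies
# out to the total device count.
from math import prod

def resolve_mesh_shape(parallelism: dict, num_devices: int) -> dict:
    explicit = prod(v for v in parallelism.values() if v != -1)
    assert num_devices % explicit == 0, "explicit axes must divide device count"
    return {k: num_devices // explicit if v == -1 else v
            for k, v in parallelism.items()}

# With the new defaults on an 8-chip slice, the tensor axis takes all 8 devices:
print(resolve_mesh_shape({"data": 1, "fsdp": 1, "context": 1, "tensor": -1}, 8))
# -> {'data': 1, 'fsdp': 1, 'context': 1, 'tensor': 8}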

src/maxdiffusion/configs/base_wan_1_3b.yml

Lines changed: 9 additions & 9 deletions
@@ -1,4 +1,4 @@
-# Copyright 2023 Google LLC
+# Copyright 2023 Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -145,17 +145,17 @@ mesh_axes: ['data', 'fsdp', 'context', 'tensor']
 logical_axis_rules: [
   ['batch', ['data', 'fsdp']],
   ['activation_batch', ['data', 'fsdp']],
-  ['activation_self_attn_heads', ['context', 'tensor']],
-  ['activation_cross_attn_q_length', ['context', 'tensor']],
-  ['activation_length', 'context'],
+  ['activation_self_attn_heads', 'tensor'],
+  ['activation_cross_attn_q_length', 'tensor'],
+  ['activation_length', None],
   ['activation_heads', 'tensor'],
   ['mlp','tensor'],
-  ['embed', ['context', 'fsdp']],
+  ['embed', ['fsdp', 'tensor']],
   ['heads', 'tensor'],
   ['norm', 'tensor'],
-  ['conv_batch', ['data', 'context', 'fsdp']],
+  ['conv_batch', ['data', 'fsdp']],
   ['out_channels', 'tensor'],
-  ['conv_out', 'context'],
+  ['conv_out', 'tensor'],
 ]
 data_sharding: [['data', 'fsdp', 'context', 'tensor']]

@@ -169,8 +169,8 @@ dcn_context_parallelism: -1
 dcn_tensor_parallelism: 1
 ici_data_parallelism: 1
 ici_fsdp_parallelism: 1
-ici_context_parallelism: -1 # recommended ICI axis to be auto-sharded
-ici_tensor_parallelism: 1
+ici_context_parallelism: 1
+ici_tensor_parallelism: -1 # recommended ICI axis to be auto-sharded

 allow_split_physical_axes: False

src/maxdiffusion/configs/base_wan_27b.yml

Lines changed: 8 additions & 8 deletions
@@ -156,17 +156,17 @@ mesh_axes: ['data', 'fsdp', 'context', 'tensor']
 logical_axis_rules: [
   ['batch', ['data', 'fsdp']],
   ['activation_batch', ['data', 'fsdp']],
-  ['activation_self_attn_heads', ['context', 'tensor']],
-  ['activation_cross_attn_q_length', ['context', 'tensor']],
-  ['activation_length', 'context'],
+  ['activation_self_attn_heads', 'tensor'],
+  ['activation_cross_attn_q_length', 'tensor'],
+  ['activation_length', None],
   ['activation_heads', 'tensor'],
   ['mlp','tensor'],
-  ['embed', ['context', 'fsdp']],
+  ['embed', ['fsdp', 'tensor']],
   ['heads', 'tensor'],
   ['norm', 'tensor'],
-  ['conv_batch', ['data', 'context', 'fsdp']],
+  ['conv_batch', ['data', 'fsdp']],
   ['out_channels', 'tensor'],
-  ['conv_out', 'context'],
+  ['conv_out', 'tensor'],
 ]
 data_sharding: [['data', 'fsdp', 'context', 'tensor']]

@@ -180,8 +180,8 @@ dcn_context_parallelism: -1
 dcn_tensor_parallelism: 1
 ici_data_parallelism: 1
 ici_fsdp_parallelism: 1
-ici_context_parallelism: -1 # recommended ICI axis to be auto-sharded
-ici_tensor_parallelism: 1
+ici_context_parallelism: 1
+ici_tensor_parallelism: -1 # recommended ICI axis to be auto-sharded

 allow_split_physical_axes: False

src/maxdiffusion/models/attention_flax.py

Lines changed: 1 addition & 1 deletion
@@ -1133,7 +1133,7 @@ def __call__(
      deterministic: bool = True,
      rngs: nnx.Rngs = None,
  ) -> jax.Array:
-    axis_names = nn.logical_to_mesh_axes((BATCH, LENGTH, HEAD))
+    axis_names = nn.logical_to_mesh_axes((BATCH, LENGTH, None))
     hidden_states = jax.lax.with_sharding_constraint(hidden_states, axis_names)
     encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, axis_names)
     dtype = hidden_states.dtype
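A runnable sketch of what the old and new constraints resolve to, assuming the BATCH/LENGTH/HEAD constants stand for the activation_* logical names and using a subset of the updated rules:

# Sketch with flax's logical-axis API; rules below mirror part of the config.
import flax.linen as nn

rules = (("activation_batch", ("data", "fsdp")),
         ("activation_length", None),
         ("activation_heads", "tensor"))

with nn.logical_axis_rules(rules):
    # Old constraint: the model dim was forced onto the tensor axis.
    print(nn.logical_to_mesh_axes(
        ("activation_batch", "activation_length", "activation_heads")))
    # PartitionSpec(('data', 'fsdp'), None, 'tensor')

    # New constraint: a literal None bypasses the rules, leaving the model dim
    # unconstrained (replicated), which column-parallel QKV expects.
    print(nn.logical_to_mesh_axes(
        ("activation_batch", "activation_length", None)))
    # PartitionSpec(('data', 'fsdp'), None, None)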

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 1 addition & 1 deletion
@@ -379,7 +379,7 @@ def __call__(
     shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
         (self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1
     )
-    axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_heads"))
+    axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", None))
     hidden_states = jax.lax.with_sharding_constraint(hidden_states, axis_names)
     hidden_states = checkpoint_name(hidden_states, "hidden_states")
     axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_kv"))
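To see the replication physically, a small sketch on faked CPU devices (the mesh sizes and the 2x128x1536 array shape are invented for illustration, not taken from a real TPU slice):

# Run as a fresh script so the XLA flag takes effect before jax is imported.
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()).reshape(1, 1, 1, 8),
            ("data", "fsdp", "context", "tensor"))
x = jnp.ones((2, 128, 1536))

# Old behavior: model dim sharded on the tensor axis -> each device holds 1/8.
old = jax.device_put(x, NamedSharding(mesh, PartitionSpec(("data", "fsdp"), None, "tensor")))
print(old.sharding.shard_shape(x.shape))  # (2, 128, 192)

# New behavior: model dim unconstrained -> fully replicated on every device.
new = jax.device_put(x, NamedSharding(mesh, PartitionSpec(("data", "fsdp"), None, None)))
print(new.sharding.shard_shape(x.shape))  # (2, 128, 1536)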
