
Commit 422b22e

cpersson-amdentrpn authored and committed
Add support for TransformerEngine flash attention in WAN (#299)
* add flash attn te support for wan
* add gpu optimized sharding parallelism
* sharding bugfixes
* generalize across sharding parallelisms
* fix issue with inference using fsdp + te flash attention
* revert fsdp_tpu name change
* update readme with wan2.1 gpu notes
* re-order parallelism axes and revert dynamic context parallel axes selection
* remove unused max_utils imports
* change mesh names to more accurately reflect sharding
* cleanup
* fix lint errors
* update configs for unit tests.

---------

Co-authored-by: Juan Acevedo <juancevedo@gmail.com>
1 parent 17507c4 commit 422b22e

28 files changed

Lines changed: 236 additions & 151 deletions

README.md

Lines changed: 3 additions & 0 deletions
@@ -255,6 +255,9 @@ After installation completes, run the training script.
 - In Wan2.1, the ici_fsdp_parallelism axis is used for sequence parallelism, the ici_tensor_parallelism axis is used for head parallelism.
 - You can enable both, keeping in mind that Wan2.1 has 40 heads and 40 must be evenly divisible by ici_tensor_parallelism.
 - For Sequence parallelism, the code pads the sequence length to evenly divide the sequence. Try out different ici_fsdp_parallelism numbers, but we find 2 and 4 to be the best right now.
+- For use on GPU it is recommended to enable the cudnn_te_flash attention kernel for optimal performance.
+- Best performance is achieved with the use of batch parallelism, which can be enabled by using the ici_fsdp_batch_parallelism axis. Note that this parallelism strategy does not support fractional batch sizes.
+- ici_fsdp_batch_parallelism and ici_fsdp_parallelism can be combined to allow for fractional batch sizes. However, padding is not currently supported for the cudnn_te_flash attention kernel and it is therefore required that the sequence length is divisible by the number of devices in the ici_fsdp_parallelism axis.

 You should eventually see a training run as:
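The divisibility rules in the README additions above are easy to trip over, so a quick pre-flight check can save a failed launch. A minimal sketch, assuming Wan2.1's 40 attention heads; the helper name and the sequence length are illustrative, not maxdiffusion API:

```python
# Sanity-check the WAN 2.1 sharding constraints described in the README notes
# before launching a run. Hypothetical helper, not part of maxdiffusion.

WAN21_NUM_HEADS = 40  # Wan2.1 attention head count, per the README note


def check_wan_sharding(seq_len: int,
                       ici_tensor_parallelism: int,
                       ici_fsdp_parallelism: int,
                       use_cudnn_te_flash: bool = False) -> None:
  # Head parallelism: 40 heads must divide evenly across the tensor axis.
  assert WAN21_NUM_HEADS % ici_tensor_parallelism == 0, (
      f"40 heads not divisible by ici_tensor_parallelism={ici_tensor_parallelism}")
  if use_cudnn_te_flash:
    # cudnn_te_flash does not pad, so the sequence itself must divide evenly
    # across the sequence-parallel devices.
    assert seq_len % ici_fsdp_parallelism == 0, (
        "cudnn_te_flash requires seq_len divisible by ici_fsdp_parallelism")


# Example: 8 GPUs split as 2-way sequence x 4-way head parallelism.
check_wan_sharding(seq_len=75600, ici_tensor_parallelism=4,
                   ici_fsdp_parallelism=2, use_cudnn_te_flash=True)
```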

src/maxdiffusion/common_types.py

Lines changed: 7 additions & 6 deletions
@@ -36,6 +36,7 @@
 # Physical axis names for device meshes.
 DATA = "data"
 FSDP = "fsdp"
+CONTEXT = "context"
 TENSOR = "tensor"
 # Logical axis names for model parameters and activations.
 BATCH = "activation_batch"
@@ -67,18 +68,18 @@
 ### Common axis rules for ring attention ###
 RING_ATTENTION_AXIS_RULES = [
     [SELF_ATTN_HEAD, None],
-    [SELF_ATTN_Q_LENGTH, FSDP],
-    [SELF_ATTN_KV_LENGTH, FSDP],
+    [SELF_ATTN_Q_LENGTH, CONTEXT],
+    [SELF_ATTN_KV_LENGTH, CONTEXT],
     [CROSS_ATTN_HEAD, None],
-    [CROSS_ATTN_Q_LENGTH, FSDP],
-    [CROSS_ATTN_KV_LENGTH, FSDP],
+    [CROSS_ATTN_Q_LENGTH, CONTEXT],
+    [CROSS_ATTN_KV_LENGTH, CONTEXT],
 ]

 SEQUENCE_PARALLEL_AXIS_RULES = [
     [SELF_ATTN_HEAD, None],
-    [SELF_ATTN_Q_LENGTH, FSDP],
+    [SELF_ATTN_Q_LENGTH, CONTEXT],
     [SELF_ATTN_KV_LENGTH, None],
     [CROSS_ATTN_HEAD, None],
-    [CROSS_ATTN_Q_LENGTH, FSDP],
+    [CROSS_ATTN_Q_LENGTH, CONTEXT],
     [CROSS_ATTN_KV_LENGTH, None],
 ]
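For readers unfamiliar with how these rule tables are consumed: each [logical_axis, mesh_axis] pair tells Flax's logical-sharding machinery which physical mesh axis a named activation dimension maps onto. A minimal sketch under stated assumptions: an 8-device mesh, and that the constants resolve to strings like 'activation_self_attn_q_length' (the exact values live elsewhere in common_types.py):

```python
# Minimal sketch, not maxdiffusion code: resolve a logical PartitionSpec to a
# physical sharding via flax.linen.logical_to_mesh_sharding. Assumes 8 JAX
# devices (e.g. XLA_FLAGS=--xla_force_host_platform_device_count=8 on CPU).
import flax.linen as nn
from jax.experimental import mesh_utils
from jax.sharding import Mesh, PartitionSpec

# Physical mesh including the new 'context' axis: (data, fsdp, context, tensor).
mesh = Mesh(mesh_utils.create_device_mesh((1, 1, 2, 4)),
            axis_names=("data", "fsdp", "context", "tensor"))

# The ring-attention rules above, written with the assumed string values.
rules = (("activation_self_attn_heads", None),
         ("activation_self_attn_q_length", "context"))

# A (heads, q_length) activation: heads replicated, q_length split on 'context'.
logical_spec = PartitionSpec("activation_self_attn_heads",
                             "activation_self_attn_q_length")
sharding = nn.logical_to_mesh_sharding(logical_spec, mesh, rules)
print(sharding.spec)  # PartitionSpec(None, 'context')
```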

src/maxdiffusion/configs/base14.yml

Lines changed: 4 additions & 2 deletions
@@ -106,7 +106,7 @@ skip_jax_distributed_system: False
 base_output_directory: ""

 # Parallelism
-mesh_axes: ['data', 'fsdp', 'tensor']
+mesh_axes: ['data', 'fsdp', 'context', 'tensor']

 # batch : batch dimension of data and activations
 # hidden :
@@ -131,17 +131,19 @@ logical_axis_rules: [
   ['out_channels', 'tensor'],
   ['conv_out', 'fsdp'],
 ]
-data_sharding: [['data', 'fsdp', 'tensor']]
+data_sharding: [['data', 'fsdp', 'context', 'tensor']]

 # One axis for each parallelism type may hold a placeholder (-1)
 # value to auto-shard based on available slices and devices.
 # By default, product of the DCN axes should equal number of slices
 # and product of the ICI axes should equal number of devices per slice.
 dcn_data_parallelism: -1 # recommended DCN axis to be auto-sharded
 dcn_fsdp_parallelism: 1
+dcn_context_parallelism: 1
 dcn_tensor_parallelism: 1
 ici_data_parallelism: -1 # recommended ICI axis to be auto-sharded for TPUv5e
 ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
+ici_context_parallelism: 1
 ici_tensor_parallelism: 1

 allow_split_physical_axes: False
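A small sketch of the placeholder rule the config comments describe: at most one axis per mesh may be -1, and it is auto-sized so the product of the axes matches the device (or slice) count. `resolve_mesh_axes` is a hypothetical helper for illustration, not maxdiffusion code:

```python
import math

# Hypothetical helper illustrating the -1 auto-sharding rule from the
# config comments; not part of maxdiffusion.
def resolve_mesh_axes(axes: dict, num_devices: int) -> dict:
  fixed = math.prod(v for v in axes.values() if v != -1)
  placeholders = [name for name, v in axes.items() if v == -1]
  assert len(placeholders) <= 1, "at most one -1 placeholder per mesh"
  if placeholders:
    assert num_devices % fixed == 0, "fixed axes must divide the device count"
    axes = {**axes, placeholders[0]: num_devices // fixed}
  assert math.prod(axes.values()) == num_devices
  return axes

# base14.yml ICI defaults on an 8-chip slice: 'data' absorbs the remainder.
print(resolve_mesh_axes({"data": -1, "fsdp": 1, "context": 1, "tensor": 1}, 8))
# {'data': 8, 'fsdp': 1, 'context': 1, 'tensor': 1}
```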

src/maxdiffusion/configs/base21.yml

Lines changed: 4 additions & 2 deletions
@@ -108,7 +108,7 @@ skip_jax_distributed_system: False
 base_output_directory: ""

 # Parallelism
-mesh_axes: ['data', 'fsdp', 'tensor']
+mesh_axes: ['data', 'fsdp', 'context', 'tensor']

 # batch : batch dimension of data and activations
 # hidden :
@@ -133,17 +133,19 @@ logical_axis_rules: [
   ['out_channels', 'tensor'],
   ['conv_out', 'fsdp'],
 ]
-data_sharding: [['data', 'fsdp', 'tensor']]
+data_sharding: [['data', 'fsdp', 'context', 'tensor']]

 # One axis for each parallelism type may hold a placeholder (-1)
 # value to auto-shard based on available slices and devices.
 # By default, product of the DCN axes should equal number of slices
 # and product of the ICI axes should equal number of devices per slice.
 dcn_data_parallelism: -1 # recommended DCN axis to be auto-sharded
 dcn_fsdp_parallelism: 1
+dcn_context_parallelism: 1
 dcn_tensor_parallelism: 1
 ici_data_parallelism: -1 # recommended ICI axis to be auto-sharded for TPUv5e
 ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
+ici_context_parallelism: 1
 ici_tensor_parallelism: 1

 allow_split_physical_axes: False

src/maxdiffusion/configs/base_2_base.yml

Lines changed: 4 additions & 2 deletions
@@ -121,7 +121,7 @@ skip_jax_distributed_system: False
 base_output_directory: ""

 # Parallelism
-mesh_axes: ['data', 'fsdp', 'tensor']
+mesh_axes: ['data', 'fsdp', 'context', 'tensor']

 # batch : batch dimension of data and activations
 # hidden :
@@ -146,17 +146,19 @@ logical_axis_rules: [
   ['out_channels', 'tensor'],
   ['conv_out', 'fsdp'],
 ]
-data_sharding: [['data', 'fsdp', 'tensor']]
+data_sharding: [['data', 'fsdp', 'context', 'tensor']]

 # One axis for each parallelism type may hold a placeholder (-1)
 # value to auto-shard based on available slices and devices.
 # By default, product of the DCN axes should equal number of slices
 # and product of the ICI axes should equal number of devices per slice.
 dcn_data_parallelism: -1 # recommended DCN axis to be auto-sharded
 dcn_fsdp_parallelism: 1
+dcn_context_parallelism: 1
 dcn_tensor_parallelism: 1
 ici_data_parallelism: -1 # recommended ICI axis to be auto-sharded for TPUv5e
 ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
+ici_context_parallelism: 1
 ici_tensor_parallelism: 1

 allow_split_physical_axes: False

src/maxdiffusion/configs/base_flux_dev.yml

Lines changed: 4 additions & 2 deletions
@@ -132,7 +132,7 @@ hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
 skip_jax_distributed_system: False

 # Parallelism
-mesh_axes: ['data', 'fsdp', 'tensor']
+mesh_axes: ['data', 'fsdp', 'context', 'tensor']

 # batch : batch dimension of data and activations
 # hidden :
@@ -158,17 +158,19 @@ logical_axis_rules: [
   ['out_channels', 'tensor'],
   ['conv_out', 'fsdp'],
 ]
-data_sharding: [['data', 'fsdp', 'tensor']]
+data_sharding: [['data', 'fsdp', 'context', 'tensor']]

 # One axis for each parallelism type may hold a placeholder (-1)
 # value to auto-shard based on available slices and devices.
 # By default, product of the DCN axes should equal number of slices
 # and product of the ICI axes should equal number of devices per slice.
 dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
 dcn_fsdp_parallelism: -1
+dcn_context_parallelism: 1
 dcn_tensor_parallelism: 1
 ici_data_parallelism: -1
 ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
+ici_context_parallelism: 1
 ici_tensor_parallelism: 1

 allow_split_physical_axes: False

src/maxdiffusion/configs/base_flux_dev_multi_res.yml

Lines changed: 4 additions & 2 deletions
@@ -132,7 +132,7 @@ hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
 skip_jax_distributed_system: False

 # Parallelism
-mesh_axes: ['data', 'fsdp', 'tensor']
+mesh_axes: ['data', 'fsdp', 'context', 'tensor']

 # batch : batch dimension of data and activations
 # hidden :
@@ -158,17 +158,19 @@ logical_axis_rules: [
   ['out_channels', 'tensor'],
   ['conv_out', 'fsdp'],
 ]
-data_sharding: [['data', 'fsdp', 'tensor']]
+data_sharding: [['data', 'fsdp', 'context', 'tensor']]

 # One axis for each parallelism type may hold a placeholder (-1)
 # value to auto-shard based on available slices and devices.
 # By default, product of the DCN axes should equal number of slices
 # and product of the ICI axes should equal number of devices per slice.
 dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
 dcn_fsdp_parallelism: -1
+dcn_context_parallelism: 1
 dcn_tensor_parallelism: 1
 ici_data_parallelism: -1
 ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
+ici_context_parallelism: 1
 ici_tensor_parallelism: 1

 allow_split_physical_axes: False

src/maxdiffusion/configs/base_flux_schnell.yml

Lines changed: 4 additions & 2 deletions
@@ -140,7 +140,7 @@ hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
 skip_jax_distributed_system: False

 # Parallelism
-mesh_axes: ['data', 'fsdp', 'tensor']
+mesh_axes: ['data', 'fsdp', 'context', 'tensor']

 # batch : batch dimension of data and activations
 # hidden :
@@ -166,17 +166,19 @@ logical_axis_rules: [
   ['out_channels', 'tensor'],
   ['conv_out', 'fsdp'],
 ]
-data_sharding: [['data', 'fsdp', 'tensor']]
+data_sharding: [['data', 'fsdp', 'context', 'tensor']]

 # One axis for each parallelism type may hold a placeholder (-1)
 # value to auto-shard based on available slices and devices.
 # By default, product of the DCN axes should equal number of slices
 # and product of the ICI axes should equal number of devices per slice.
 dcn_data_parallelism: -1 # recommended DCN axis to be auto-sharded
 dcn_fsdp_parallelism: 1
+dcn_context_parallelism: 1
 dcn_tensor_parallelism: 1
 ici_data_parallelism: -1
 ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
+ici_context_parallelism: 1
 ici_tensor_parallelism: 1

 allow_split_physical_axes: False

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 14 additions & 12 deletions
@@ -151,7 +151,7 @@ hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
 skip_jax_distributed_system: False

 # Parallelism
-mesh_axes: ['data', 'fsdp', 'tensor']
+mesh_axes: ['data', 'fsdp', 'context', 'tensor']

 # batch : batch dimension of data and activations
 # hidden :
@@ -166,31 +166,33 @@ mesh_axes: ['data', 'fsdp', 'tensor']
 # conv_in : conv.shape[2] weight
 # conv_out : conv.shape[-1] weight
 logical_axis_rules: [
-  ['batch', 'data'],
-  ['activation_batch', 'data'],
-  ['activation_self_attn_heads', ['fsdp', 'tensor']],
-  ['activation_cross_attn_q_length', ['fsdp', 'tensor']],
-  ['activation_length', 'fsdp'],
+  ['batch', ['data', 'fsdp']],
+  ['activation_batch', ['data', 'fsdp']],
+  ['activation_self_attn_heads', ['context', 'tensor']],
+  ['activation_cross_attn_q_length', ['context', 'tensor']],
+  ['activation_length', 'context'],
   ['activation_heads', 'tensor'],
   ['mlp','tensor'],
-  ['embed','fsdp'],
+  ['embed', ['context', 'fsdp']],
   ['heads', 'tensor'],
   ['norm', 'tensor'],
-  ['conv_batch', ['data','fsdp']],
+  ['conv_batch', ['data', 'context', 'fsdp']],
   ['out_channels', 'tensor'],
-  ['conv_out', 'fsdp'],
+  ['conv_out', 'context'],
 ]
-data_sharding: [['data', 'fsdp', 'tensor']]
+data_sharding: [['data', 'fsdp', 'context', 'tensor']]

 # One axis for each parallelism type may hold a placeholder (-1)
 # value to auto-shard based on available slices and devices.
 # By default, product of the DCN axes should equal number of slices
 # and product of the ICI axes should equal number of devices per slice.
 dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
-dcn_fsdp_parallelism: -1
+dcn_fsdp_parallelism: 1
+dcn_context_parallelism: -1
 dcn_tensor_parallelism: 1
 ici_data_parallelism: 1
-ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
+ici_fsdp_parallelism: 1
+ici_context_parallelism: -1 # recommended ICI axis to be auto-sharded
 ici_tensor_parallelism: 1

 allow_split_physical_axes: False
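To see how the re-ordered axes behave at runtime, here is a minimal sketch (not maxdiffusion code) of what the new WAN defaults resolve to on a single 8-device host, and how the updated data_sharding spec splits a batch across all four mesh axes at once. The array shapes are illustrative:

```python
# Sketch under stated assumptions: ici_context_parallelism: -1 auto-shards
# to 8 on an 8-device host, and data_sharding's
# [['data', 'fsdp', 'context', 'tensor']] groups all four mesh axes on the
# leading (batch) dimension. Assumes 8 JAX devices.
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(mesh_utils.create_device_mesh((1, 1, 8, 1)),
            axis_names=("data", "fsdp", "context", "tensor"))

# All four axes jointly shard dimension 0 of each input array.
data_sharding = NamedSharding(
    mesh, PartitionSpec(("data", "fsdp", "context", "tensor")))

batch = jnp.zeros((8, 256, 128))  # (batch, tokens, channels), illustrative
batch = jax.device_put(batch, data_sharding)
print(batch.sharding.spec)  # PartitionSpec(('data', 'fsdp', 'context', 'tensor'),)
```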

src/maxdiffusion/configs/base_wan_27b.yml

Lines changed: 14 additions & 11 deletions
@@ -139,7 +139,7 @@ hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
 skip_jax_distributed_system: False

 # Parallelism
-mesh_axes: ['data', 'fsdp', 'tensor']
+mesh_axes: ['data', 'fsdp', 'context', 'tensor']

 # batch : batch dimension of data and activations
 # hidden :
@@ -154,30 +154,33 @@ mesh_axes: ['data', 'fsdp', 'tensor']
 # conv_in : conv.shape[2] weight
 # conv_out : conv.shape[-1] weight
 logical_axis_rules: [
-  ['batch', 'data'],
-  ['activation_batch', 'data'],
-  ['activation_length', 'fsdp'],
-
+  ['batch', ['data', 'fsdp']],
+  ['activation_batch', ['data', 'fsdp']],
+  ['activation_self_attn_heads', ['context', 'tensor']],
+  ['activation_cross_attn_q_length', ['context', 'tensor']],
+  ['activation_length', 'context'],
   ['activation_heads', 'tensor'],
   ['mlp','tensor'],
-  ['embed','fsdp'],
+  ['embed', ['context', 'fsdp']],
   ['heads', 'tensor'],
   ['norm', 'tensor'],
-  ['conv_batch', ['data','fsdp']],
+  ['conv_batch', ['data', 'context', 'fsdp']],
   ['out_channels', 'tensor'],
-  ['conv_out', 'fsdp'],
+  ['conv_out', 'context'],
 ]
-data_sharding: [['data', 'fsdp', 'tensor']]
+data_sharding: [['data', 'fsdp', 'context', 'tensor']]

 # One axis for each parallelism type may hold a placeholder (-1)
 # value to auto-shard based on available slices and devices.
 # By default, product of the DCN axes should equal number of slices
 # and product of the ICI axes should equal number of devices per slice.
 dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
-dcn_fsdp_parallelism: -1
+dcn_fsdp_parallelism: 1
+dcn_context_parallelism: -1
 dcn_tensor_parallelism: 1
 ici_data_parallelism: 1
-ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
+ici_fsdp_parallelism: 1
+ici_context_parallelism: -1 # recommended ICI axis to be auto-sharded
 ici_tensor_parallelism: 1

 allow_split_physical_axes: False
