
Commit 015ccc2 (parent: 8f1ffde)

add split self/cross attention sharding rules and configuration

12 files changed: 164 additions, 15 deletions

src/maxdiffusion/common_types.py (32 additions, 1 deletion)

@@ -33,7 +33,11 @@
 BlockSizes = splash_attention_kernel.BlockSizes

 AxisNames = tuple[str, ...]
-
+# Physical axis names for device meshes.
+DATA = "data"
+FSDP = "fsdp"
+TENSOR = "tensor"
+# Logical axis names for model parameters and activations.
 BATCH = "activation_batch"
 LENGTH = "activation_length"
 KV_LENGTH = "activation_kv_length"
@@ -44,4 +48,31 @@
 KEEP_2 = "activation_keep_2"
 CONV_OUT = "activation_conv_out_channels"

+# For setting self/cross attention independently in splash kernel
+SELF_ATTN_HEAD = "activation_self_attn_heads"
+SELF_ATTN_Q_LENGTH = "activation_self_attn_q_length"
+SELF_ATTN_KV_LENGTH = "activation_self_attn_kv_length"
+CROSS_ATTN_HEAD = "activation_cross_attn_heads"
+CROSS_ATTN_Q_LENGTH = "activation_cross_attn_q_length"
+CROSS_ATTN_KV_LENGTH = "activation_cross_attn_kv_length"
+
 WAN_MODEL = "Wan2.1"
+
+### Common axis rules for ring attention ###
+RING_ATTENTION_AXIS_RULES = [
+    [SELF_ATTN_HEAD, None],
+    [SELF_ATTN_Q_LENGTH, FSDP],
+    [SELF_ATTN_KV_LENGTH, FSDP],
+    [CROSS_ATTN_HEAD, None],
+    [CROSS_ATTN_Q_LENGTH, FSDP],
+    [CROSS_ATTN_KV_LENGTH, FSDP],
+]
+
+SEQUENCE_PARALLEL_AXIS_RULES = [
+    [SELF_ATTN_HEAD, None],
+    [SELF_ATTN_Q_LENGTH, FSDP],
+    [SELF_ATTN_KV_LENGTH, None],
+    [CROSS_ATTN_HEAD, None],
+    [CROSS_ATTN_Q_LENGTH, FSDP],
+    [CROSS_ATTN_KV_LENGTH, None],
+]
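
The two rule lists above map the new self/cross-attention logical axis names onto physical mesh axes (sharded over `fsdp` or left replicated). As a minimal sketch of how such rules are consumed, the snippet below resolves a logical spec for self-attention queries into a concrete sharding with Flax's logical-axis helpers; the tiny mesh, the query layout, and the trimmed-down rule list are assumptions for illustration, not MaxDiffusion's actual wiring.

```python
# Minimal sketch (assumed setup, not MaxDiffusion's own code): resolve logical
# axis names like those defined above into a device sharding via Flax rules.
import jax
import numpy as np
import flax.linen as nn
from jax.sharding import Mesh, PartitionSpec

DATA, FSDP, TENSOR = "data", "fsdp", "tensor"
BATCH = "activation_batch"
SELF_ATTN_HEAD = "activation_self_attn_heads"
SELF_ATTN_Q_LENGTH = "activation_self_attn_q_length"

# Subset of RING_ATTENTION_AXIS_RULES from the diff, plus a batch rule.
rules = [
    [BATCH, DATA],
    [SELF_ATTN_HEAD, None],        # heads replicated
    [SELF_ATTN_Q_LENGTH, FSDP],    # query sequence sharded over the fsdp axis
]

# Assumed 1 x N x 1 mesh over whatever devices are available.
devices = np.array(jax.devices()).reshape(1, -1, 1)
mesh = Mesh(devices, (DATA, FSDP, TENSOR))

# Logical layout of self-attention queries: (batch, heads, q_length).
logical_spec = PartitionSpec(BATCH, SELF_ATTN_HEAD, SELF_ATTN_Q_LENGTH)
sharding = nn.logical_to_mesh_sharding(logical_spec, mesh, rules)
print(sharding)  # NamedSharding: q_length split over 'fsdp', heads unsharded
```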

src/maxdiffusion/configs/base14.yml (9 additions)

@@ -50,6 +50,15 @@ jit_initializers: True
 from_pt: False
 split_head_dim: True
 attention: 'dot_product' # Supported attention: dot_product, flash
+# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
+# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.
+# However, when padding tokens are significant, this will lead to worse quality and should be set to True.
+mask_padding_tokens: True
+# Maxdiffusion has 2 types of attention sharding strategies:
+# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention)
+# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded
+#    in cross attention q.
+attention_sharding_uniform: True
 flash_block_sizes: {}
 # GroupNorm groups
 norm_num_groups: 32
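
The mask_padding_tokens comment refers to the splash-attention segment-id mechanism: queries and keys in different segments never attend to each other, so giving padding a segment id distinct from real tokens masks it out. A sketch of that idea for a single sequence follows; the helper name and shapes are illustrative assumptions rather than the repo's actual plumbing.

```python
# Illustrative sketch: build splash-attention SegmentIds that separate real
# tokens (segment 1) from padding (segment 0) for one sequence. This is an
# assumed helper, not MaxDiffusion's code path behind mask_padding_tokens.
import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel

def make_segment_ids(valid_length: int, seq_len: int):
  positions = jnp.arange(seq_len)
  ids = (positions < valid_length).astype(jnp.int32)  # 1 = real token, 0 = padding
  # q and kv share the same ids in self attention; cross-segment pairs are masked.
  return splash_attention_kernel.SegmentIds(q=ids, kv=ids)

segment_ids = make_segment_ids(valid_length=100, seq_len=128)
```

Skipping the segment ids entirely (mask_padding_tokens: False) saves that masking work, which is what makes it faster on VPU-bound hardware such as Trillium, at the cost of attending to padding tokens.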

src/maxdiffusion/configs/base21.yml (9 additions)

@@ -49,6 +49,15 @@ jit_initializers: True
 from_pt: False
 split_head_dim: True
 attention: 'dot_product' # Supported attention: dot_product, flash
+# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
+# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.
+# However, when padding tokens are significant, this will lead to worse quality and should be set to True.
+mask_padding_tokens: True
+# Maxdiffusion has 2 types of attention sharding strategies:
+# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention)
+# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded
+#    in cross attention q.
+attention_sharding_uniform: True
 flash_block_sizes: {}
 # GroupNorm groups
 norm_num_groups: 32

src/maxdiffusion/configs/base_2_base.yml (9 additions)

@@ -50,6 +50,15 @@ jit_initializers: True
 from_pt: True
 split_head_dim: True
 attention: 'flash' # Supported attention: dot_product, flash
+# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
+# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.
+# However, when padding tokens are significant, this will lead to worse quality and should be set to True.
+mask_padding_tokens: True
+# Maxdiffusion has 2 types of attention sharding strategies:
+# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention)
+# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded
+#    in cross attention q.
+attention_sharding_uniform: True
 flash_block_sizes: {}
 # to override default block sizes for flash attention
 # flash_block_sizes:

src/maxdiffusion/configs/base_flux_dev.yml (9 additions)

@@ -63,6 +63,15 @@ jit_initializers: True
 from_pt: True
 split_head_dim: True
 attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te
+# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
+# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.
+# However, when padding tokens are significant, this will lead to worse quality and should be set to True.
+mask_padding_tokens: True
+# Maxdiffusion has 2 types of attention sharding strategies:
+# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention)
+# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded
+#    in cross attention q.
+attention_sharding_uniform: True

 flash_block_sizes: {}
 # Use the following flash_block_sizes on v6e (Trillium) due to larger vmem.

src/maxdiffusion/configs/base_flux_dev_multi_res.yml (9 additions)

@@ -63,6 +63,15 @@ jit_initializers: True
 from_pt: True
 split_head_dim: True
 attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te
+# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
+# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.
+# However, when padding tokens are significant, this will lead to worse quality and should be set to True.
+mask_padding_tokens: True
+# Maxdiffusion has 2 types of attention sharding strategies:
+# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention)
+# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded
+#    in cross attention q.
+attention_sharding_uniform: True

 #flash_block_sizes: {}
 # Use the following flash_block_sizes on v6e (Trillium) due to larger vmem.

src/maxdiffusion/configs/base_flux_schnell.yml (9 additions)

@@ -62,6 +62,15 @@ jit_initializers: True
 from_pt: True
 split_head_dim: True
 attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te
+# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
+# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.
+# However, when padding tokens are significant, this will lead to worse quality and should be set to True.
+mask_padding_tokens: True
+# Maxdiffusion has 2 types of attention sharding strategies:
+# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention)
+# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded
+#    in cross attention q.
+attention_sharding_uniform: True
 flash_block_sizes: {
   "block_q" : 256,
   "block_kv_compute" : 256,

src/maxdiffusion/configs/base_wan_14b.yml (33 additions, 12 deletions)

@@ -60,18 +60,27 @@ jit_initializers: True
 from_pt: True
 split_head_dim: True
 attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring
-flash_min_seq_length: 4096
+flash_min_seq_length: 0
+
+# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
+# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.
+# However, when padding tokens are significant, this will lead to worse quality and should be set to True.
+mask_padding_tokens: True
+# Maxdiffusion has 2 types of attention sharding strategies:
+# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention)
+# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded
+#    in cross attention q.
+attention_sharding_uniform: True
 dropout: 0.1

 flash_block_sizes: {
-  "block_q" : 1024,
-  "block_kv_compute" : 256,
-  "block_kv" : 1024,
-  "block_q_dkv" : 1024,
-  "block_kv_dkv" : 1024,
-  "block_kv_dkv_compute" : 256,
-  "block_q_dq" : 1024,
-  "block_kv_dq" : 1024
+  "block_q" : 2048,
+  "block_kv_compute" : 512,
+  "block_kv" : 2048,
+  "block_q_dkv" : 2048,
+  "block_kv_dkv" : 2048,
+  "block_kv_dkv_compute" : 512,
+  "use_fused_bwd_kernel": True
 }
 # Use on v6e
 # flash_block_sizes: {
@@ -80,11 +89,22 @@ flash_block_sizes: {
 # "block_kv" : 2048,
 # "block_q_dkv" : 3024,
 # "block_kv_dkv" : 2048,
-# "block_kv_dkv_compute" : 2048,
+# "block_kv_dkv_compute" : 1024,
 # "block_q_dq" : 3024,
 # "block_kv_dq" : 2048,
 # "use_fused_bwd_kernel": False,
 # }
+# Use on v5p
+# flash_block_sizes: {
+# "block_q" : 3024,
+# "block_kv_compute" : 1024,
+# "block_kv" : 2048,
+# "block_q_dkv" : 1024,
+# "block_kv_dkv" : 3072,
+# "block_kv_dkv_compute" : 256,
+# "block_q_dq" : 1024,
+# "block_kv_dq" : 3072
+# }
 # GroupNorm groups
 norm_num_groups: 32

@@ -145,8 +165,9 @@ mesh_axes: ['data', 'fsdp', 'tensor']
 logical_axis_rules: [
   ['batch', 'data'],
   ['activation_batch', 'data'],
+  ['activation_self_attn_heads', ['fsdp', 'tensor']],
+  ['activation_cross_attn_q_length', ['fsdp', 'tensor']],
   ['activation_length', 'fsdp'],
-
   ['activation_heads', 'tensor'],
   ['mlp','tensor'],
   ['embed','fsdp'],
@@ -280,7 +301,7 @@ flow_shift: 3.0
 # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
 guidance_rescale: 0.0
 num_inference_steps: 30
-fps: 24
+fps: 16
 save_final_checkpoint: False

 # SDXL Lightning parameters
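
The flash_block_sizes keys above mirror the fields of the Pallas splash-attention BlockSizes dataclass that common_types.py re-exports. Below is a hedged sketch of how such a config mapping could be turned into that object, using the new default values from this diff; the unpacking shown is illustrative and not necessarily the repo's exact helper.

```python
# Sketch: build the Pallas splash-attention BlockSizes (aliased in
# common_types.py) from a flash_block_sizes-style mapping. Values are the new
# defaults above; the construction itself is an assumption for illustration.
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel

flash_block_sizes = {
    "block_q": 2048,
    "block_kv_compute": 512,
    "block_kv": 2048,
    "block_q_dkv": 2048,
    "block_kv_dkv": 2048,
    "block_kv_dkv_compute": 512,
    "use_fused_bwd_kernel": True,  # fused backward pass: no separate dq blocks
}

block_sizes = splash_attention_kernel.BlockSizes(**flash_block_sizes)
```

With use_fused_bwd_kernel enabled the separate block_q_dq/block_kv_dq sizes are not needed, which is presumably why this diff drops them from the default mapping while the commented v6e variant (use_fused_bwd_kernel: False) keeps them.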

src/maxdiffusion/configs/base_wan_27b.yml (9 additions)

@@ -61,6 +61,15 @@ from_pt: True
 split_head_dim: True
 attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring
 flash_min_seq_length: 4096
+# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
+# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.
+# However, when padding tokens are significant, this will lead to worse quality and should be set to True.
+mask_padding_tokens: True
+# Maxdiffusion has 2 types of attention sharding strategies:
+# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention)
+# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded
+#    in cross attention q.
+attention_sharding_uniform: True
 dropout: 0.1

 flash_block_sizes: {

src/maxdiffusion/configs/base_xl.yml (9 additions)

@@ -50,6 +50,15 @@ jit_initializers: True
 from_pt: False
 split_head_dim: True
 attention: 'dot_product' # Supported attention: dot_product, flash
+# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
+# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.
+# However, when padding tokens are significant, this will lead to worse quality and should be set to True.
+mask_padding_tokens: True
+# Maxdiffusion has 2 types of attention sharding strategies:
+# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention)
+# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded
+#    in cross attention q.
+attention_sharding_uniform: True
 flash_block_sizes: {}
 # GroupNorm groups
 norm_num_groups: 32
