Skip to content

Commit 947d902

Browse files
committed
Add logical_axis_rules entries and attention_sharding_uniform flag in config files
1 parent bd86792 commit 947d902

2 files changed

Lines changed: 7 additions & 3 deletions

File tree

src/maxdiffusion/configs/base_wan_i2v_14b.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ dropout: 0.1
 # Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.
 # However, when padding tokens are significant, this will lead to worse quality and should be set to True.
 mask_padding_tokens: True
+attention_sharding_uniform: True

 flash_block_sizes: {
   "block_q" : 2048,
@@ -150,8 +151,9 @@ mesh_axes: ['data', 'fsdp', 'tensor']
 logical_axis_rules: [
   ['batch', 'data'],
   ['activation_batch', 'data'],
+  ['activation_self_attn_heads', ['fsdp', 'tensor']],
+  ['activation_cross_attn_q_length', ['fsdp', 'tensor']],
   ['activation_length', 'fsdp'],
-
   ['activation_heads', 'tensor'],
   ['mlp','tensor'],
   ['embed','fsdp'],

src/maxdiffusion/configs/base_wan_i2v_27b.yml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ dropout: 0.1
 # Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.
 # However, when padding tokens are significant, this will lead to worse quality and should be set to True.
 mask_padding_tokens: True
+attention_sharding_uniform: True

 flash_block_sizes: {
   "block_q" : 1024,
@@ -151,8 +152,9 @@ mesh_axes: ['data', 'fsdp', 'tensor']
 logical_axis_rules: [
   ['batch', 'data'],
   ['activation_batch', 'data'],
-  ['activation_length', 'fsdp'],
-
+  ['activation_self_attn_heads', ['fsdp', 'tensor']],
+  ['activation_cross_attn_q_length', ['fsdp', 'tensor']],
+  ['activation_length', 'fsdp'],
   ['activation_heads', 'tensor'],
   ['mlp','tensor'],
   ['embed','fsdp'],

0 commit comments

Comments
 (0)