Skip to content

Commit 587bc6a

Browse files
optimized flash block sizes for trillium.
1 parent df25e47 commit 587bc6a

2 files changed

Lines changed: 29 additions & 6 deletions

File tree

src/maxdiffusion/configs/base_flux_dev.yml

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,19 @@ precision: "DEFAULT"
5555
from_pt: True
5656
split_head_dim: True
5757
attention: 'flash' # Supported attention: dot_product, flash
58+
5859
flash_block_sizes: {}
60+
# Use the following flash_block_sizes on v6e (Trillium) due to larger vmem.
61+
# flash_block_sizes: {
62+
# "block_q" : 1536,
63+
# "block_kv_compute" : 1536,
64+
# "block_kv" : 1536,
65+
# "block_q_dkv" : 1536,
66+
# "block_kv_dkv" : 1536,
67+
# "block_kv_dkv_compute" : 1536,
68+
# "block_q_dq" : 1536,
69+
# "block_kv_dq" : 1536
70+
# }
5971
# GroupNorm groups
6072
norm_num_groups: 32
6173

@@ -137,8 +149,8 @@ data_sharding: [['data', 'fsdp', 'tensor']]
137149
dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
138150
dcn_fsdp_parallelism: -1
139151
dcn_tensor_parallelism: 1
140-
ici_data_parallelism: 1
141-
ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
152+
ici_data_parallelism: -1
153+
ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
142154
ici_tensor_parallelism: 1
143155

144156
# Dataset

src/maxdiffusion/configs/base_flux_schnell.yml

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,17 @@ from_pt: True
5555
split_head_dim: True
5656
attention: 'flash' # Supported attention: dot_product, flash
5757
flash_block_sizes: {}
58+
# Use the following flash_block_sizes on v6e (Trillium) due to larger vmem.
59+
# flash_block_sizes: {
60+
# "block_q" : 1536,
61+
# "block_kv_compute" : 1536,
62+
# "block_kv" : 1536,
63+
# "block_q_dkv" : 1536,
64+
# "block_kv_dkv" : 1536,
65+
# "block_kv_dkv_compute" : 1536,
66+
# "block_q_dq" : 1536,
67+
# "block_kv_dq" : 1536
68+
# }
5869
# GroupNorm groups
5970
norm_num_groups: 32
6071

@@ -133,11 +144,11 @@ data_sharding: [['data', 'fsdp', 'tensor']]
133144
# value to auto-shard based on available slices and devices.
134145
# By default, product of the DCN axes should equal number of slices
135146
# and product of the ICI axes should equal number of devices per slice.
136-
dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
137-
dcn_fsdp_parallelism: -1
147+
dcn_data_parallelism: -1 # recommended DCN axis to be auto-sharded
148+
dcn_fsdp_parallelism: 1
138149
dcn_tensor_parallelism: 1
139-
ici_data_parallelism: 1
140-
ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
150+
ici_data_parallelism: -1
151+
ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
141152
ici_tensor_parallelism: 1
142153

143154
# Dataset

0 commit comments

Comments
 (0)