@@ -56,6 +56,17 @@ split_head_dim: True
5656attention : 'flash' # Supported attention: dot_product, flash, cudnn_flash_te
5757
5858flash_block_sizes : {}
59+ # Use on v6e: uncomment the block below and remove the empty `flash_block_sizes : {}` above.
60+ # flash_block_sizes: {
61+ # "block_q" : 3024,
62+ # "block_kv_compute" : 1024,
63+ # "block_kv" : 2048,
64+ # "block_q_dkv" : 3024,
65+ # "block_kv_dkv" : 2048,
66+ # "block_kv_dkv_compute" : 2048,
67+ # "block_q_dq" : 3024,
68+ # "block_kv_dq" : 2048
69+ # }
5970# GroupNorm groups
6071norm_num_groups : 32
6172
@@ -115,17 +126,15 @@ mesh_axes: ['data', 'fsdp', 'tensor']
115126# conv_out : conv.shape[-1] weight
116127logical_axis_rules : [
117128 ['batch', 'data'],
118- ['activation_heads', 'fsdp'],
119- ['activation_batch', ['data','fsdp'] ],
120- ['activation_kv', 'tensor'],
129+ ['activation_length', 'fsdp'],
130+ ['activation_heads', 'tensor' ],
131+ ['activation_batch', 'data'],
121132 ['mlp','tensor'],
122133 ['embed','fsdp'],
123- ['heads', 'tensor'],
124- ['norm', 'fsdp'],
134+ ['norm', 'tensor'],
125135 ['conv_batch', ['data','fsdp']],
126136 ['out_channels', 'tensor'],
127- ['conv_out', 'fsdp'],
128- ['conv_in', 'fsdp']
137+ ['conv_in', 'fsdp'],
129138 ]
130139data_sharding : [['data', 'fsdp', 'tensor']]
131140
@@ -140,6 +149,8 @@ ici_data_parallelism: 1
140149ici_fsdp_parallelism : -1 # recommended ICI axis to be auto-sharded
141150ici_tensor_parallelism : 1
142151
152+ allow_split_physical_axes : False
153+
143154# Dataset
144155# Replace with dataset path or train_data_dir. One has to be set.
145156dataset_name : 'diffusers/pokemon-gpt4-captions'
0 commit comments