@@ -60,18 +60,27 @@ jit_initializers: True
 from_pt: True
 split_head_dim: True
 attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring
-flash_min_seq_length: 4096
+flash_min_seq_length: 0
+
+# If mask_padding_tokens is True, segment ids are passed to splash attention so that padding tokens are not attended to.
+# If it is False, segment ids are not passed; on VPU-bound hardware such as Trillium this is faster.
+# However, when padding tokens make up a significant share of the sequence, skipping the mask degrades quality, so it should be left True.
+mask_padding_tokens: True
+# MaxDiffusion supports 2 attention sharding strategies:
+# 1. attention_sharding_uniform = True: the same sequence-sharding rules are applied to q in both self and cross attention.
+# 2. attention_sharding_uniform = False: heads are sharded uniformly across devices for self attention, while the sequence is
+#    sharded for the cross-attention q.
+attention_sharding_uniform: True
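As a rough sketch of what the mask_padding_tokens toggle controls (illustrative only, not MaxDiffusion's exact code path; the helper name padding_segment_ids is made up and the SegmentIds import path may vary across JAX versions): when the flag is True, per-token segment ids are handed to the TPU splash-attention kernel so real tokens never attend to padding; when False, segment ids are simply omitted.

import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel as sak

def padding_segment_ids(padding_mask):
  # padding_mask: bool array, True for real tokens, False for padding tokens.
  # Real tokens share segment id 1 and padding gets 0; the kernel only lets
  # positions with equal segment ids attend to each other, so real tokens
  # never attend to padding.
  ids = padding_mask.astype(jnp.int32)
  return sak.SegmentIds(q=ids, kv=ids)

# segment_ids = padding_segment_ids(mask) if config.mask_padding_tokens else None
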
 dropout: 0.1
 
 flash_block_sizes: {
67- " block_q" : 1024,
68- " block_kv_compute" : 256,
69- " block_kv" : 1024,
70- " block_q_dkv" : 1024,
71- " block_kv_dkv" : 1024,
72- " block_kv_dkv_compute" : 256,
73- " block_q_dq" : 1024,
74- " block_kv_dq" : 1024
77+ " block_q" : 2048,
78+ " block_kv_compute" : 512,
79+ " block_kv" : 2048,
80+ " block_q_dkv" : 2048,
81+ " block_kv_dkv" : 2048,
82+ " block_kv_dkv_compute" : 512,
83+ " use_fused_bwd_kernel " : True
7584}
 # Use on v6e
 # flash_block_sizes: {
@@ -80,11 +89,22 @@ flash_block_sizes: {
8089# "block_kv" : 2048,
8190# "block_q_dkv" : 3024,
8291# "block_kv_dkv" : 2048,
83- # "block_kv_dkv_compute" : 2048 ,
92+ # "block_kv_dkv_compute" : 1024 ,
8493# "block_q_dq" : 3024,
8594# "block_kv_dq" : 2048,
8695# "use_fused_bwd_kernel": False,
8796# }
+# Use on v5p
+# flash_block_sizes: {
+# "block_q": 3024,
+# "block_kv_compute": 1024,
+# "block_kv": 2048,
+# "block_q_dkv": 1024,
+# "block_kv_dkv": 3072,
+# "block_kv_dkv_compute": 256,
+# "block_q_dq": 1024,
+# "block_kv_dq": 3072
+# }
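For reference, the keys in flash_block_sizes line up with the tile sizes of the Pallas splash-attention kernel's BlockSizes dataclass. Below is a minimal sketch of how such a dict could be turned into that object; the helper to_block_sizes is hypothetical and MaxDiffusion's actual wiring may differ.

from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel as sak

def to_block_sizes(cfg: dict) -> sak.BlockSizes:
  return sak.BlockSizes(
      block_q=cfg["block_q"],
      block_kv=cfg["block_kv"],
      block_kv_compute=cfg["block_kv_compute"],
      block_q_dkv=cfg["block_q_dkv"],
      block_kv_dkv=cfg["block_kv_dkv"],
      block_kv_dkv_compute=cfg["block_kv_dkv_compute"],
      # When use_fused_bwd_kernel is True the separate dq pass is fused into the
      # dkv pass, which is why the default config above omits block_q_dq / block_kv_dq.
      block_q_dq=cfg.get("block_q_dq"),
      block_kv_dq=cfg.get("block_kv_dq"),
      use_fused_bwd_kernel=cfg.get("use_fused_bwd_kernel", False),
  )

Roughly speaking, larger tiles improve compute utilization at the cost of VMEM pressure, which is why the suggested sizes above differ between TPU generations.
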
 # GroupNorm groups
 norm_num_groups: 32
 
@@ -145,8 +165,9 @@ mesh_axes: ['data', 'fsdp', 'tensor']
 logical_axis_rules: [
   ['batch', 'data'],
   ['activation_batch', 'data'],
+  ['activation_self_attn_heads', ['fsdp', 'tensor']],
+  ['activation_cross_attn_q_length', ['fsdp', 'tensor']],
   ['activation_length', 'fsdp'],
-
   ['activation_heads', 'tensor'],
   ['mlp','tensor'],
   ['embed','fsdp'],
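The two new rules let self-attention heads and the cross-attention query length be sharded over both the fsdp and tensor axes, matching the attention_sharding_uniform options above. A hedged sketch of how Flax resolves such logical names against the ('data', 'fsdp', 'tensor') mesh follows; the mesh shape is a placeholder and MaxDiffusion's actual call sites may differ.

import jax
import flax.linen as nn
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding

rules = (
    ("activation_batch", "data"),
    ("activation_self_attn_heads", ("fsdp", "tensor")),
    ("activation_cross_attn_q_length", ("fsdp", "tensor")),
    ("activation_length", "fsdp"),
    ("activation_heads", "tensor"),
)

# Placeholder mesh: all devices placed on the fsdp axis.
devices = mesh_utils.create_device_mesh((1, jax.device_count(), 1))
mesh = Mesh(devices, ("data", "fsdp", "tensor"))

# A self-attention activation with logical dims (batch, heads, head_dim):
# heads resolve to ('fsdp', 'tensor'); the last dim has no rule, so it stays unsharded.
spec = nn.logical_to_mesh_axes(
    ("activation_batch", "activation_self_attn_heads", "activation_kv"), rules)
sharding = NamedSharding(mesh, spec)  # PartitionSpec('data', ('fsdp', 'tensor'), None)
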
@@ -280,7 +301,7 @@ flow_shift: 3.0
 # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
 guidance_rescale: 0.0
 num_inference_steps: 30
-fps: 24
+fps: 16
 save_final_checkpoint: False
 
 # SDXL Lightning parameters