File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -195,8 +195,14 @@ def _tpu_flash_attention(
195195 block_q_dkv = min (q_max_block_size , query .shape [2 ]),
196196 block_kv_dkv = min (kv_max_block_size , key .shape [2 ]),
197197 block_kv_dkv_compute = min (kv_max_block_size , query .shape [2 ]),
198+ << << << < Updated upstream
198199 block_q_dq = min (q_max_block_size , query .shape [2 ]),
199200 block_kv_dq = min (kv_max_block_size , query .shape [2 ]),
201+ == == == =
202+ block_q_dq = None if attention_kernel == "tokamax_flash" else min (q_max_block_size , query .shape [2 ]),
203+ block_kv_dq = None if attention_kernel == "tokamax_flash" else min (kv_max_block_size , query .shape [2 ]),
204+ use_fused_bwd_kernel = True if attention_kernel == "tokamax_flash" else False ,
205+ > >> >> >> Stashed changes
200206 )
201207 num_fsdp_shards = mesh .shape ["fsdp" ]
202208 query = _reshape_data_for_flash (query , heads )
You can’t perform that action at this time.
0 commit comments