
Commit 9660fa0

Merge conflict error
Signed-off-by: Kunjan Patel <kunjanp@google.com>
1 parent d843dc0 · commit 9660fa0

1 file changed: 6 additions & 0 deletions

src/maxdiffusion/models/attention_flax.py
@@ -195,8 +195,14 @@ def _tpu_flash_attention(
       block_q_dkv=min(q_max_block_size, query.shape[2]),
       block_kv_dkv=min(kv_max_block_size, key.shape[2]),
       block_kv_dkv_compute=min(kv_max_block_size, query.shape[2]),
+<<<<<<< Updated upstream
       block_q_dq=min(q_max_block_size, query.shape[2]),
       block_kv_dq=min(kv_max_block_size, query.shape[2]),
+=======
+      block_q_dq=None if attention_kernel == "tokamax_flash" else min(q_max_block_size, query.shape[2]),
+      block_kv_dq=None if attention_kernel == "tokamax_flash" else min(kv_max_block_size, query.shape[2]),
+      use_fused_bwd_kernel=True if attention_kernel == "tokamax_flash" else False,
+>>>>>>> Stashed changes
   )
   num_fsdp_shards = mesh.shape["fsdp"]
   query = _reshape_data_for_flash(query, heads)
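
As committed, this hunk leaves literal Git conflict markers (<<<<<<<, =======, >>>>>>>) in attention_flax.py, so the module will no longer parse. Below is a minimal, hypothetical sketch of the block-size configuration with the conflict resolved in favor of the "Stashed changes" side, which the tokamax_flash additions suggest was intended. The BlockSizes dataclass is a stand-in for the real splash-attention block-size container, and the helper name, its signature, and the q_seq_len/kv_seq_len parameters (standing in for query.shape[2] and key.shape[2]) are assumptions for illustration, not the repository's actual API.

    # Hypothetical resolution keeping the "Stashed changes" side of the
    # conflict. BlockSizes here is a stand-in for the kernel's config object.
    from dataclasses import dataclass
    from typing import Optional


    @dataclass
    class BlockSizes:  # stand-in for the splash-attention block-size container
        block_q_dkv: Optional[int]
        block_kv_dkv: Optional[int]
        block_kv_dkv_compute: Optional[int]
        block_q_dq: Optional[int]
        block_kv_dq: Optional[int]
        use_fused_bwd_kernel: bool


    def make_block_sizes(attention_kernel: str,
                         q_max_block_size: int, kv_max_block_size: int,
                         q_seq_len: int, kv_seq_len: int) -> BlockSizes:
        # Hypothetical helper; argument names mirror the diff's variables.
        fused = attention_kernel == "tokamax_flash"
        return BlockSizes(
            block_q_dkv=min(q_max_block_size, q_seq_len),
            block_kv_dkv=min(kv_max_block_size, kv_seq_len),
            block_kv_dkv_compute=min(kv_max_block_size, q_seq_len),
            # Per the stashed side of the diff, the standalone dq block sizes
            # are unset when the fused backward kernel is enabled, presumably
            # because that kernel computes dq alongside dkv.
            block_q_dq=None if fused else min(q_max_block_size, q_seq_len),
            block_kv_dq=None if fused else min(kv_max_block_size, q_seq_len),
            use_fused_bwd_kernel=fused,
        )


    # Example: a tokamax_flash configuration disables the standalone dq blocks.
    print(make_block_sizes("tokamax_flash", 512, 1024, 4096, 4096))

One small simplification carried into the sketch: the diff's use_fused_bwd_kernel=True if attention_kernel == "tokamax_flash" else False reduces to the bare comparison, since the conditional expression already yields a bool.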
