@@ -166,20 +166,25 @@ def _tpu_flash_attention(
     dtype: jnp.dtype = jnp.float32,
 ) -> jax.Array:
   """TPU Flash Attention"""
-
-  max_block_size = 1024 if dtype == jnp.bfloat16 else 512
+  q_max_block_size = 1024 if dtype == jnp.bfloat16 else 512
+  # Cross-attention where kv dims are much smaller due to encoder_hidden_states.
+  # If kv seq_len is padded too much, it causes issues in attention calculations.
+  if key.shape[1] != query.shape[1]:
+    kv_max_block_size = key.shape[1]
+  else:
+    kv_max_block_size = q_max_block_size
   if flash_block_sizes:
     block_sizes = flash_block_sizes
   else:
     block_sizes = splash_attention_kernel.BlockSizes(
-        block_q=min(max_block_size, query.shape[2]),
-        block_kv_compute=min(max_block_size, key.shape[2]),
-        block_kv=min(max_block_size, key.shape[2]),
-        block_q_dkv=min(max_block_size, query.shape[2]),
-        block_kv_dkv=min(max_block_size, key.shape[2]),
-        block_kv_dkv_compute=min(max_block_size, query.shape[2]),
-        block_q_dq=min(max_block_size, query.shape[2]),
-        block_kv_dq=min(max_block_size, query.shape[2]),
+        block_q=min(q_max_block_size, query.shape[2]),
+        block_kv_compute=min(kv_max_block_size, key.shape[2]),
+        block_kv=min(kv_max_block_size, key.shape[2]),
+        block_q_dkv=min(q_max_block_size, query.shape[2]),
+        block_kv_dkv=min(kv_max_block_size, key.shape[2]),
+        block_kv_dkv_compute=min(kv_max_block_size, query.shape[2]),
+        block_q_dq=min(q_max_block_size, query.shape[2]),
+        block_kv_dq=min(kv_max_block_size, query.shape[2]),
     )

   num_fsdp_shards = mesh.shape["fsdp"]
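
As a standalone illustration of the selection logic above, here is a minimal sketch (not part of the PR) that mirrors the new q/kv block-size caps. The helper name `pick_block_sizes` and the example shapes are hypothetical; axis conventions follow the diff, where `shape[1]` is compared to detect cross-attention and the resulting caps are later `min`-ed against `shape[2]` when building `splash_attention_kernel.BlockSizes`.

```python
import jax.numpy as jnp


def pick_block_sizes(query_shape, key_shape, dtype=jnp.float32):
  """Mirror of the diff's cap selection; returns (q_max_block_size, kv_max_block_size)."""
  q_max_block_size = 1024 if dtype == jnp.bfloat16 else 512
  if key_shape[1] != query_shape[1]:
    # Cross-attention: kv comes from encoder_hidden_states, so cap kv blocks
    # at the (short) kv sequence length instead of padding it up to 512/1024.
    kv_max_block_size = key_shape[1]
  else:
    # Self-attention: q and kv share the same cap.
    kv_max_block_size = q_max_block_size
  return q_max_block_size, kv_max_block_size


# Self-attention: q and kv lengths match, so both caps are 512 for float32.
print(pick_block_sizes((2, 4096, 24, 128), (2, 4096, 24, 128)))  # (512, 512)

# Cross-attention: e.g. 77 text tokens from the encoder, so kv blocks are
# capped at 77 while the query cap stays at 512.
print(pick_block_sizes((2, 4096, 24, 128), (2, 77, 24, 128)))    # (512, 77)
```

In the cross-attention case the kv sequence comes from `encoder_hidden_states` and is typically much shorter than the query sequence, so capping the kv block size at the kv length avoids padding the kv dimension up to the query block size, which the diff's comment identifies as a source of problems in the attention calculation.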