Skip to content

Commit 513a0c3

Browse files
committed
trying block size fix
1 parent ada408b commit 513a0c3

2 files changed

Lines changed: 5 additions & 3 deletions

File tree

src/maxdiffusion/configs/ltx2_video.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#hardware
22
hardware: 'tpu'
33
skip_jax_distributed_system: False
4-
attention: 'dot_product'
4+
attention: 'flash'
55
attention_sharding_uniform: True
66
precision: 'bf16'
77
data_sharding: ['data', 'fsdp', 'context', 'tensor']

src/maxdiffusion/models/attention_flax.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -235,8 +235,10 @@ def _tpu_flash_attention(
235235
q_max_block_size = 1024 if dtype == jnp.bfloat16 else 512
236236
# This is the case for cross-attn.
237237
if key.shape[1] != query.shape[1]:
238-
assert key.shape[1] % 128 == 0
239-
kv_max_block_size = key.shape[1]
238+
if key.shape[1] % 128 != 0:
239+
kv_max_block_size = ((key.shape[1] + 127) // 128) * 128
240+
else:
241+
kv_max_block_size = key.shape[1]
240242
else:
241243
kv_max_block_size = q_max_block_size
242244
# ensure that for cross attention we override the block sizes.

0 commit comments

Comments (0)