Commit 045dfac

added enable_profiler flag in yml file, fix in attention_flax.py for kv_max_block_size
1 parent 1f2e44b commit 045dfac

2 files changed: 2 additions & 2 deletions

src/maxdiffusion/configs/ltx2_video.yml
Lines changed: 1 addition & 0 deletions

@@ -65,6 +65,7 @@ ici_data_parallelism: 1
 ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
 ici_context_parallelism: 1
 ici_tensor_parallelism: 1
+enable_profiler: False

 replicate_vae: False

src/maxdiffusion/models/attention_flax.py
Lines changed: 1 addition & 2 deletions

@@ -235,8 +235,7 @@ def _tpu_flash_attention(
   q_max_block_size = 1024 if dtype == jnp.bfloat16 else 512
   # This is the case for cross-attn.
   if key.shape[1] != query.shape[1]:
-    assert key.shape[1] % 128 == 0
-    kv_max_block_size = key.shape[1]
+    kv_max_block_size = ((key.shape[1] + 127) // 128) * 128
   else:
     kv_max_block_size = q_max_block_size
   # ensure that for cross attention we override the block sizes.
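The attention_flax.py change replaces a hard assertion (the cross-attention key length must be exactly divisible by 128) with rounding the kv block size up to the next multiple of 128, so sequence lengths like 300 no longer crash. A minimal standalone sketch of the rounding idiom used in the diff (the helper name is illustrative, not part of the repo):

```python
def round_up_to_multiple(n: int, multiple: int = 128) -> int:
    # Ceiling-division trick from the diff: ((n + 127) // 128) * 128
    return ((n + multiple - 1) // multiple) * multiple

# Old behavior: `assert key.shape[1] % 128 == 0` fails for a kv length of 300.
# New behavior: the block size is padded up instead.
print(round_up_to_multiple(300))  # 384
print(round_up_to_multiple(128))  # 128 (already aligned, unchanged)
```

Rounding up rather than asserting keeps `kv_max_block_size` aligned to the TPU flash-attention block granularity while accepting arbitrary cross-attention key lengths.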
