1 parent 1f2e44b commit 045dfac
2 files changed
src/maxdiffusion/configs/ltx2_video.yml
@@ -65,6 +65,7 @@ ici_data_parallelism: 1
 ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
 ici_context_parallelism: 1
 ici_tensor_parallelism: 1
+enable_profiler: False
 
 
 replicate_vae: False
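The new enable_profiler flag defaults to False. Below is a minimal, hypothetical sketch of how such a flag might gate JAX's programmatic profiler once the YAML is loaded; maxdiffusion's actual config plumbing is not part of this commit and may differ.

import yaml
import jax

# Illustration only: the flag is read from YAML with PyYAML and used to
# gate JAX's trace profiler. The trace directory is an assumption.
config = yaml.safe_load("enable_profiler: False\nreplicate_vae: False\n")

if config.get("enable_profiler", False):
  jax.profiler.start_trace("/tmp/jax-trace")  # assumed output dir
  # ... run the steps to be profiled ...
  jax.profiler.stop_trace()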
src/maxdiffusion/models/attention_flax.py
@@ -235,8 +235,7 @@ def _tpu_flash_attention(
   q_max_block_size = 1024 if dtype == jnp.bfloat16 else 512
   # This is the case for cross-attn.
   if key.shape[1] != query.shape[1]:
-    assert key.shape[1] % 128 == 0
-    kv_max_block_size = key.shape[1]
+    kv_max_block_size = ((key.shape[1] + 127) // 128) * 128
   else:
     kv_max_block_size = q_max_block_size
   # ensure that for cross attention we override the block sizes.
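This change replaces a hard assertion that the cross-attention key length is a multiple of 128 with a round-up, so the kv block size is padded to the next multiple of 128 instead of raising an AssertionError. A small standalone sketch of the idiom follows; the helper name is illustrative, not the repo's.

def round_up_to_multiple(n, multiple=128):
  """Round n up to the nearest multiple via the ((n + m - 1) // m) * m idiom."""
  return ((n + multiple - 1) // multiple) * multiple

# Previously, a cross-attn key length of 300 failed `assert 300 % 128 == 0`;
# now the block size is simply padded up to the next multiple of 128.
assert round_up_to_multiple(300) == 384
assert round_up_to_multiple(1024) == 1024  # exact multiples pass through unchanged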