From 490dd70d82f5ede4ee49f80605d841ea097bb58e Mon Sep 17 00:00:00 2001 From: mbohlool Date: Fri, 17 Apr 2026 21:05:21 +0000 Subject: [PATCH] fix(ltx2): resolve flash attention block size mismatch and missing config overrides This commit addresses two issues in the LTX-2 pipeline: 1. Pipeline Config Overrides: Fixed a bug in `ltx2_pipeline.py` where `a2v_attention_kernel` and `v2a_attention_kernel` configurations were ignored. The model previously hardcoded a fallback to "flash" because these values were not mapped from the user config to `ltx2_config`. 2. Flash Attention Padding Mismatch: Fixed a `ValueError` (e.g., `kv_block_size=126 should divide kv_seq_len=128`) in `attention_flax.py` that occurred for specific video frame counts. A previous fix padded sequences to satisfy `shard_map` context dimension requirements, but `_select_flash_block_sizes` was calculating block sizes based on the unpadded length. Moved the block size calculation to occur *after* `_reshape_data_for_flash` so that the dynamic `min()` bounds correctly align with the newly padded sequence lengths, keeping cross-attention optimizations intact and unit tests passing. --- src/maxdiffusion/models/attention_flax.py | 2 +- src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index 583695fa7..7b707bff3 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -287,11 +287,11 @@ def _tpu_flash_attention( ) -> jax.Array: """TPU Flash Attention""" - block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, attention_kernel) num_context_shards = mesh.shape["context"] query, orig_q_seq_len = _reshape_data_for_flash(query, heads, num_context_shards) key, _ = _reshape_data_for_flash(key, heads, num_context_shards) value, _ = _reshape_data_for_flash(value, heads, num_context_shards) + block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, attention_kernel) q_axis_names = nn.logical_to_mesh_axes(axis_names_q) kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv) diff --git a/src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py b/src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py index 485d500ef..569f194ab 100644 --- a/src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py +++ b/src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py @@ -127,6 +127,8 @@ def create_model(rngs: nnx.Rngs, ltx2_config: dict): ltx2_config["dtype"] = config.activations_dtype ltx2_config["weights_dtype"] = config.weights_dtype ltx2_config["attention_kernel"] = config.attention + ltx2_config["a2v_attention_kernel"] = getattr(config, "a2v_attention_kernel", "flash") + ltx2_config["v2a_attention_kernel"] = getattr(config, "v2a_attention_kernel", "dot_product") ltx2_config["precision"] = get_precision(config) ltx2_config["flash_block_sizes"] = get_flash_block_sizes(config) ltx2_config["flash_min_seq_length"] = getattr(config, "flash_min_seq_length", 4096)