Skip to content

Commit 0b6410b

Browse files
Merge pull request #382 from AI-Hypercomputer:ltx2_bugfix
PiperOrigin-RevId: 903276594
2 parents bc3bca0 + 490dd70 commit 0b6410b

2 files changed

Lines changed: 3 additions & 1 deletion

File tree

src/maxdiffusion/models/attention_flax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -317,11 +317,11 @@ def _tpu_flash_attention(
317317
) -> jax.Array:
318318
"""TPU Flash Attention"""
319319

320-
block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, attention_kernel)
321320
num_context_shards = mesh.shape["context"]
322321
query, orig_q_seq_len = _reshape_data_for_flash(query, heads, num_context_shards)
323322
key, _ = _reshape_data_for_flash(key, heads, num_context_shards)
324323
value, _ = _reshape_data_for_flash(value, heads, num_context_shards)
324+
block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, attention_kernel)
325325

326326
q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
327327
kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,8 @@ def create_model(rngs: nnx.Rngs, ltx2_config: dict):
127127
ltx2_config["dtype"] = config.activations_dtype
128128
ltx2_config["weights_dtype"] = config.weights_dtype
129129
ltx2_config["attention_kernel"] = config.attention
130+
ltx2_config["a2v_attention_kernel"] = getattr(config, "a2v_attention_kernel", "flash")
131+
ltx2_config["v2a_attention_kernel"] = getattr(config, "v2a_attention_kernel", "dot_product")
130132
ltx2_config["precision"] = get_precision(config)
131133
ltx2_config["flash_block_sizes"] = get_flash_block_sizes(config)
132134
ltx2_config["flash_min_seq_length"] = getattr(config, "flash_min_seq_length", 4096)

0 commit comments

Comments (0)