Skip to content

Commit 5ca96b5

Browse files
committed
registering a2v and v2a attention
1 parent 56c2fc7 commit 5ca96b5

2 files changed

Lines changed: 3 additions & 1 deletion

File tree

src/maxdiffusion/configs/ltx2_video.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
 hardware: 'tpu'
 skip_jax_distributed_system: False
 attention: 'ulysses'
-a2v_attention_kernel: 'ulysses'
+a2v_attention_kernel: 'dot_product'
 v2a_attention_kernel: 'dot_product'
 attention_sharding_uniform: True
 precision: 'bf16'

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,8 @@ def create_model(rngs: nnx.Rngs, ltx2_config: dict):
   ltx2_config["dtype"] = config.activations_dtype
   ltx2_config["weights_dtype"] = config.weights_dtype
   ltx2_config["attention_kernel"] = config.attention
+  ltx2_config["a2v_attention_kernel"] = getattr(config, "a2v_attention_kernel", "flash")
+  ltx2_config["v2a_attention_kernel"] = getattr(config, "v2a_attention_kernel", "dot_product")
   ltx2_config["precision"] = get_precision(config)
   ltx2_config["flash_block_sizes"] = get_flash_block_sizes(config)
   ltx2_config["flash_min_seq_length"] = getattr(config, "flash_min_seq_length", 4096)

0 commit comments

Comments (0)