Skip to content

Commit 00a1d51

Browse files
committed
Add attention block
1 parent b801ebf commit 00a1d51

2 files changed

Lines changed: 2 additions & 3 deletions

File tree

src/maxdiffusion/models/ltx2/__init__.py

Whitespace-only changes.

src/maxdiffusion/tests/ltx2/test_attention_ltx2.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import pandas as pd
2424
from jax.sharding import Mesh
2525
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel
26-
26+
from ...models.ltx2.attention_ltx2 import LTX2Attention, LTX2RotaryPosEmbed
2727
# Set JAX to use float32 for higher precision checks
2828
jax.config.update("jax_default_matmul_precision", "float32")
2929

@@ -175,9 +175,8 @@ def forward(self, x, context=None, q_rope=None, k_rope=None, mask=None):
175175

176176

177177
# ==========================================
178-
# 2. JAX Imports & Test Suite
178+
# 2. JAX Test Suite
179179
# ==========================================
180-
from ...models.ltx2.attention_ltx2 import LTX2Attention, LTX2RotaryPosEmbed
181180

182181

183182
class LTX2AttentionTest(unittest.TestCase):

0 commit comments

Comments
 (0)