Skip to content

Commit 2207e08

Browse files
committed
test for transformer
1 parent 6e66e84 commit 2207e08

1 file changed

Lines changed: 3 additions & 19 deletions

File tree

src/maxdiffusion/tests/ltx_2_transformer_test.py

Lines changed: 3 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -97,23 +97,7 @@ def test_ltx2_rope(self):
9797
base_width=base_width,
9898
modality="video"
9999
)
100-
101-
# Create dummy grid
102-
# For video: (B, T, H, W) -> flattened indices? No, RoPE takes `ids` often, or computes them internally?
103-
# LTX2RotaryPosEmbed.__call__ takes `ids`.
104-
# Let's check how it's called in `transformer_ltx2.py`.
105-
# It seems `transformer_ltx2.py` calls `pool_patches` or similar to generate grid?
106-
# Actually `LTX2RotaryPosEmbed` seems to have logic to generate embeddings from indices.
107-
108-
# Let's try calling it with dummy IDs if expected, or if it generates them.
109-
# Looking at `test_attention_ltx2.py`, it passes `ids` or similar.
110-
# In `transformer_ltx2.py`, `prepare_video_coords` generates coordinates.
111-
# But `LTX2RotaryPosEmbed` forward might take `ids`.
112-
# Wait, `transformer_ltx2.py` defines `self.rope`.
113-
# Let's verify `LTX2RotaryPosEmbed` signature in `test_attention_ltx2.py` or implementation.
114-
# `test_attention_ltx2.py`: `rope_jax(jnp.array(np_ids))`
115-
116-
ids = jnp.ones((1, 10, 3)) # (B, S, 3) for 3D coords
100+
ids = jnp.ones((1, 3, 10)) # (B, Axes, S) for 3D coords
117101
cos, sin = rope(ids)
118102

119103
# Check output shape
@@ -224,15 +208,15 @@ def test_ltx2_transformer_model(self):
224208
num_attention_heads=self.num_heads,
225209
attention_head_dim=self.head_dim,
226210
cross_attention_dim=self.cross_dim,
227-
caption_channels=32, # kept small for now, or match parity if needed
211+
caption_channels=32,
228212
audio_in_channels=audio_in_channels,
229213
audio_out_channels=audio_in_channels,
230214
audio_num_attention_heads=self.audio_num_heads,
231215
audio_attention_head_dim=self.audio_head_dim,
232216
audio_cross_attention_dim=self.audio_cross_dim,
233217
num_layers=1,
234218
mesh=self.mesh,
235-
attention_kernel="dot_product" # Force dot_product for test stability on CPU/small config
219+
attention_kernel="dot_product"
236220
)
237221

238222
batch_size = self.batch_size

0 commit comments

Comments
 (0)