@@ -97,23 +97,7 @@ def test_ltx2_rope(self):
             base_width=base_width,
             modality="video"
         )
-
-        # Create dummy grid
-        # For video: (B, T, H, W) -> flattened indices? No, RoPE takes `ids` often, or computes them internally?
-        # LTX2RotaryPosEmbed.__call__ takes `ids`.
-        # Let's check how it's called in `transformer_ltx2.py`.
-        # It seems `transformer_ltx2.py` calls `pool_patches` or similar to generate grid?
-        # Actually `LTX2RotaryPosEmbed` seems to have logic to generate embeddings from indices.
-
-        # Let's try calling it with dummy IDs if expected, or if it generates them.
-        # Looking at `test_attention_ltx2.py`, it passes `ids` or similar.
-        # In `transformer_ltx2.py`, `prepare_video_coords` generates coordinates.
-        # But `LTX2RotaryPosEmbed` forward might take `ids`.
-        # Wait, `transformer_ltx2.py` defines `self.rope`.
-        # Let's verify `LTX2RotaryPosEmbed` signature in `test_attention_ltx2.py` or implementation.
-        # `test_attention_ltx2.py`: `rope_jax(jnp.array(np_ids))`
-
-        ids = jnp.ones((1, 10, 3))  # (B, S, 3) for 3D coords
+        ids = jnp.ones((1, 3, 10))  # (B, Axes, S) for 3D coords
         cos, sin = rope(ids)

         # Check output shape
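The new `ids` layout is `(B, Axes, S)` rather than `(B, S, 3)`. As a point of reference, here is a minimal sketch (not part of this commit) of one way frame/row/column coordinates could be packed into that layout; the grid sizes are illustrative, and the model itself generates coordinates via `prepare_video_coords` in `transformer_ltx2.py` (per the removed comments above).

```python
# Minimal sketch, assuming the (B, Axes, S) layout used by the test above.
import jax.numpy as jnp

num_frames, height, width = 2, 4, 4  # assumed toy grid, not from the test
t, h, w = jnp.meshgrid(
    jnp.arange(num_frames), jnp.arange(height), jnp.arange(width), indexing="ij"
)
# Stack frame/height/width indices into (3, S), then add a batch dim -> (1, 3, S).
ids = jnp.stack([t.reshape(-1), h.reshape(-1), w.reshape(-1)], axis=0)[None, ...]
assert ids.shape == (1, 3, num_frames * height * width)
# cos, sin = rope(ids)  # same call as in the test above
```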
@@ -224,15 +208,15 @@ def test_ltx2_transformer_model(self):
             num_attention_heads=self.num_heads,
             attention_head_dim=self.head_dim,
             cross_attention_dim=self.cross_dim,
-            caption_channels=32,  # kept small for now, or match parity if needed
+            caption_channels=32,
             audio_in_channels=audio_in_channels,
             audio_out_channels=audio_in_channels,
             audio_num_attention_heads=self.audio_num_heads,
             audio_attention_head_dim=self.audio_head_dim,
             audio_cross_attention_dim=self.audio_cross_dim,
             num_layers=1,
             mesh=self.mesh,
-            attention_kernel="dot_product"  # Force dot_product for test stability on CPU/small config
+            attention_kernel="dot_product"
         )

         batch_size = self.batch_size