33import jax
44import jax .numpy as jnp
55from flax import nnx
6+ from flax .linen import partitioning as nn_partitioning
7+ from jax .sharding import Mesh
8+ import os
9+ from maxdiffusion import pyconfig
10+ from maxdiffusion .max_utils import create_device_mesh
611from maxdiffusion .models .ltx_2 .transformer_ltx2 import LTX2VideoTransformerBlock , LTX2VideoTransformer3DModel
712
813class LTX2TransformerTest (unittest .TestCase ):
@@ -13,9 +18,19 @@ class LTX2TransformerTest(unittest.TestCase):
1318 """
1419
1520 def setUp (self ):
21+ # Initialize config and mesh for sharding
22+ # using standard MaxDiffusion pattern
23+ pyconfig .initialize (
24+ [None , os .path .join (os .path .dirname (__file__ ), ".." , "configs" , "ltx_video.yml" )],
25+ unittest = True ,
26+ )
27+ self .config = pyconfig .config
28+ devices_array = create_device_mesh (self .config )
29+ self .mesh = Mesh (devices_array , self .config .mesh_axes )
30+
1631 # random seed for reproducibility
1732 self .rngs = nnx .Rngs (0 )
18- self .batch_size = 2
33+ self .batch_size = 1 # batch size of 1 keeps the unit test small and deterministic
1934 self .num_frames = 4
2035 self .height = 32
2136 self .width = 32
@@ -54,76 +69,78 @@ def test_transformer_block_shapes(self):
5469 audio_dim = 16
5570 cross_dim = 20 # context dim
5671
57- block = LTX2VideoTransformerBlock (
58- rngs = self .rngs ,
59- dim = dim ,
60- num_attention_heads = 4 ,
61- attention_head_dim = 8 ,
62- cross_attention_dim = cross_dim ,
63- audio_dim = audio_dim ,
64- audio_num_attention_heads = 4 ,
65- audio_attention_head_dim = 4 ,
66- audio_cross_attention_dim = cross_dim ,
67- activation_fn = "gelu" ,
68- qk_norm = "rms_norm_across_heads" ,
69- )
70-
71- # Create dummy inputs
72- hidden_states = jnp .zeros ((self .batch_size , self .seq_len , dim ))
73- audio_hidden_states = jnp .zeros ((self .batch_size , 10 , audio_dim )) # 10 audio frames
74- encoder_hidden_states = jnp .zeros ((self .batch_size , 5 , cross_dim ))
75- audio_encoder_hidden_states = jnp .zeros ((self .batch_size , 5 , cross_dim )) # reusing cross_dim for audio context
76-
77- # Dummy scale/shift/gate modulations
78- # These match the shapes expected by the block internal calculation logic
79- # For simplicity, we create them to match 'temb_reshaped' broadcasting or direct add
80- # The block expects raw scale/shift/gate inputs often, OR temb vectors.
81- # Let's check block calls:
82- # It takes `temb` and `temb_ca...`
83- # temb: (B, 1, 6, -1) or similar depending on reshape.
84- # Actually in `transformer_ltx2.py`, call signature takes:
85- # temb: jax.Array
86- # And reshapes it: temb.reshape(batch_size, 1, num_ada_params, -1)
87- # So input `temb` should be (B, num_ada_params * dim) roughly, or (B, num_ada_params, dim)
88-
89- num_ada_params = 6
90- te_dim = num_ada_params * dim # simplified assumption for test
91- temb = jnp .zeros ((self .batch_size , te_dim ))
92-
93- num_audio_ada_params = 6
94- te_audio_dim = num_audio_ada_params * audio_dim
95- temb_audio = jnp .zeros ((self .batch_size , te_audio_dim ))
96-
97- # CA modulations
98- # 4 params for scale/shift, 1 for gate
99- temb_ca_scale_shift = jnp .zeros ((self .batch_size , 4 * dim ))
100- temb_ca_audio_scale_shift = jnp .zeros ((self .batch_size , 4 * audio_dim ))
101- temb_ca_gate = jnp .zeros ((self .batch_size , 1 * dim ))
102- temb_ca_audio_gate = jnp .zeros ((self .batch_size , 1 * audio_dim ))
72+ # NNX sharding context
73+ with self .mesh , nn_partitioning .axis_rules (self .config .logical_axis_rules ):
74+ block = LTX2VideoTransformerBlock (
75+ rngs = self .rngs ,
76+ dim = dim ,
77+ num_attention_heads = 4 ,
78+ attention_head_dim = 8 ,
79+ cross_attention_dim = cross_dim ,
80+ audio_dim = audio_dim ,
81+ audio_num_attention_heads = 4 ,
82+ audio_attention_head_dim = 4 ,
83+ audio_cross_attention_dim = cross_dim ,
84+ activation_fn = "gelu" ,
85+ qk_norm = "rms_norm_across_heads" ,
86+ )
87+
88+ # Create dummy inputs
89+ hidden_states = jnp .zeros ((self .batch_size , self .seq_len , dim ))
90+ audio_hidden_states = jnp .zeros ((self .batch_size , 10 , audio_dim )) # 10 audio frames
91+ encoder_hidden_states = jnp .zeros ((self .batch_size , 5 , cross_dim ))
92+ audio_encoder_hidden_states = jnp .zeros ((self .batch_size , 5 , cross_dim )) # reusing cross_dim for audio context
93+
94+ # Dummy scale/shift/gate modulations.
95+ # Shapes match what the block's internal AdaLN logic expects:
96+ # the block takes flat `temb` vectors and reshapes them itself.
97+ # In `transformer_ltx2.py` the call signature takes `temb: jax.Array`
98+ # and reshapes it as temb.reshape(batch_size, 1, num_ada_params, -1),
99+ # so the input `temb` should be (B, num_ada_params * dim)
100+ # (equivalently (B, num_ada_params, dim) before flattening).
101+ # The cross-attention modulations are supplied separately:
102+ # 4 params for scale/shift and 1 for the gate, per stream.
103+ # NOTE(review): assumes num_ada_params == 6 for both the video and
104+ # audio streams — confirm against the block implementation.
105+
106+ num_ada_params = 6
107+ te_dim = num_ada_params * dim # simplified assumption for test
108+ temb = jnp .zeros ((self .batch_size , te_dim ))
109+
110+ num_audio_ada_params = 6
111+ te_audio_dim = num_audio_ada_params * audio_dim
112+ temb_audio = jnp .zeros ((self .batch_size , te_audio_dim ))
113+
114+ # CA modulations
115+ # 4 params for scale/shift, 1 for gate
116+ temb_ca_scale_shift = jnp .zeros ((self .batch_size , 4 * dim ))
117+ temb_ca_audio_scale_shift = jnp .zeros ((self .batch_size , 4 * audio_dim ))
118+ temb_ca_gate = jnp .zeros ((self .batch_size , 1 * dim ))
119+ temb_ca_audio_gate = jnp .zeros ((self .batch_size , 1 * audio_dim ))
103120
104- # Perform forward
105- out_hidden , out_audio = block (
106- hidden_states = hidden_states ,
107- audio_hidden_states = audio_hidden_states ,
108- encoder_hidden_states = encoder_hidden_states ,
109- audio_encoder_hidden_states = audio_encoder_hidden_states ,
110- temb = temb ,
111- temb_audio = temb_audio ,
112- temb_ca_scale_shift = temb_ca_scale_shift ,
113- temb_ca_audio_scale_shift = temb_ca_audio_scale_shift ,
114- temb_ca_gate = temb_ca_gate ,
115- temb_ca_audio_gate = temb_ca_audio_gate ,
116- video_rotary_emb = None , # Dummy takes None
117- audio_rotary_emb = None
118- )
119-
120- print (f"Input Video Shape: { hidden_states .shape } " )
121- print (f"Output Video Shape: { out_hidden .shape } " )
122- print (f"Input Audio Shape: { audio_hidden_states .shape } " )
123- print (f"Output Audio Shape: { out_audio .shape } " )
124-
125- self .assertEqual (out_hidden .shape , hidden_states .shape )
126- self .assertEqual (out_audio .shape , audio_hidden_states .shape )
121+ # Perform forward
122+ out_hidden , out_audio = block (
123+ hidden_states = hidden_states ,
124+ audio_hidden_states = audio_hidden_states ,
125+ encoder_hidden_states = encoder_hidden_states ,
126+ audio_encoder_hidden_states = audio_encoder_hidden_states ,
127+ temb = temb ,
128+ temb_audio = temb_audio ,
129+ temb_ca_scale_shift = temb_ca_scale_shift ,
130+ temb_ca_audio_scale_shift = temb_ca_audio_scale_shift ,
131+ temb_ca_gate = temb_ca_gate ,
132+ temb_ca_audio_gate = temb_ca_audio_gate ,
133+ video_rotary_emb = None , # None exercises the no-rotary-embedding path in this test
134+ audio_rotary_emb = None
135+ )
136+
137+ print (f"Input Video Shape: { hidden_states .shape } " )
138+ print (f"Output Video Shape: { out_hidden .shape } " )
139+ print (f"Input Audio Shape: { audio_hidden_states .shape } " )
140+ print (f"Output Audio Shape: { out_audio .shape } " )
141+
142+ self .assertEqual (out_hidden .shape , hidden_states .shape )
143+ self .assertEqual (out_audio .shape , audio_hidden_states .shape )
127144
128145
129146 def test_transformer_3d_model_instantiation_and_forward (self ):
@@ -151,26 +168,6 @@ def test_transformer_3d_model_instantiation_and_forward(self):
151168 """
152169 print ("\n === Testing LTX2VideoTransformer3DModel Integration ===" )
153170
154- model = LTX2VideoTransformer3DModel (
155- rngs = self .rngs ,
156- in_channels = self .in_channels ,
157- out_channels = self .out_channels ,
158- patch_size = self .patch_size ,
159- patch_size_t = self .patch_size_t ,
160- num_attention_heads = 2 ,
161- attention_head_dim = 8 ,
162- num_layers = 1 , # 1 layer for speed
163- caption_channels = 32 , # small for test
164- cross_attention_dim = 32 ,
165- audio_in_channels = self .audio_in_channels ,
166- audio_out_channels = self .audio_in_channels ,
167- audio_num_attention_heads = 2 ,
168- audio_attention_head_dim = 8 ,
169- audio_cross_attention_dim = 32
170- )
171-
172- # Inputs
173- # hidden_states: (B, F, H, W, C) or (B, L, C)?
174171 # Diffusers `forward` takes `hidden_states` usually as latents.
175172 # If it's 3D, it might expect (B, C, F, H, W) or (B, F, C, H, W)?
176173 # Checking `transformer_ltx2.py` `__call__` Line 680:
0 commit comments