
Commit 8722641

mesh init in test
1 parent 7bd49ec commit 8722641

1 file changed: 87 additions & 90 deletions

src/maxdiffusion/tests/ltx_2_transformer_test.py

@@ -3,6 +3,11 @@
 import jax
 import jax.numpy as jnp
 from flax import nnx
+from flax.linen import partitioning as nn_partitioning
+from jax.sharding import Mesh
+import os
+from maxdiffusion import pyconfig
+from maxdiffusion.max_utils import create_device_mesh
 from maxdiffusion.models.ltx_2.transformer_ltx2 import LTX2VideoTransformerBlock, LTX2VideoTransformer3DModel
 
 class LTX2TransformerTest(unittest.TestCase):
@@ -13,9 +18,19 @@ class LTX2TransformerTest(unittest.TestCase):
   """
 
   def setUp(self):
+    # Initialize config and mesh for sharding
+    # using the standard MaxDiffusion pattern
+    pyconfig.initialize(
+        [None, os.path.join(os.path.dirname(__file__), "..", "configs", "ltx_video.yml")],
+        unittest=True,
+    )
+    self.config = pyconfig.config
+    devices_array = create_device_mesh(self.config)
+    self.mesh = Mesh(devices_array, self.config.mesh_axes)
+
     # random seed for reproducibility
     self.rngs = nnx.Rngs(0)
-    self.batch_size = 2
+    self.batch_size = 1  # batch size 1 keeps the unit test small and deterministic
     self.num_frames = 4
     self.height = 32
     self.width = 32
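
Note: the additions to setUp() follow the standard MaxDiffusion mesh-initialization pattern: load a YAML config through pyconfig, build a device array from it, and wrap that array in a jax.sharding.Mesh named by the config's mesh axes. A minimal standalone sketch of the same steps, mirroring the diff (the leading None stands in for argv[0]; create_device_mesh comes from maxdiffusion.max_utils):

    import os
    from jax.sharding import Mesh
    from maxdiffusion import pyconfig
    from maxdiffusion.max_utils import create_device_mesh

    # Load the LTX-Video config relative to this file, in test mode (as the diff does).
    config_path = os.path.join(os.path.dirname(__file__), "..", "configs", "ltx_video.yml")
    pyconfig.initialize([None, config_path], unittest=True)
    config = pyconfig.config

    devices_array = create_device_mesh(config)    # arranges the available JAX devices per the config
    mesh = Mesh(devices_array, config.mesh_axes)  # named mesh, later entered as a context manager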
@@ -54,76 +69,78 @@ def test_transformer_block_shapes(self):
     audio_dim = 16
     cross_dim = 20  # context dim
 
-    block = LTX2VideoTransformerBlock(
-        rngs=self.rngs,
-        dim=dim,
-        num_attention_heads=4,
-        attention_head_dim=8,
-        cross_attention_dim=cross_dim,
-        audio_dim=audio_dim,
-        audio_num_attention_heads=4,
-        audio_attention_head_dim=4,
-        audio_cross_attention_dim=cross_dim,
-        activation_fn="gelu",
-        qk_norm="rms_norm_across_heads",
-    )
-
-    # Create dummy inputs
-    hidden_states = jnp.zeros((self.batch_size, self.seq_len, dim))
-    audio_hidden_states = jnp.zeros((self.batch_size, 10, audio_dim))  # 10 audio frames
-    encoder_hidden_states = jnp.zeros((self.batch_size, 5, cross_dim))
-    audio_encoder_hidden_states = jnp.zeros((self.batch_size, 5, cross_dim))  # reusing cross_dim for audio context
-
-    # Dummy scale/shift/gate modulations.
-    # These match the shapes expected by the block's internal calculations;
-    # for simplicity, we create them to broadcast against 'temb_reshaped' or add directly.
-    # The block accepts either raw scale/shift/gate inputs or packed temb vectors.
-    # From the block's call signature:
-    # it takes `temb` and the `temb_ca_*` tensors, where
-    # temb is (B, 1, 6, -1) or similar, depending on the reshape.
-    # In `transformer_ltx2.py`, the call signature declares
-    # temb: jax.Array
-    # and reshapes it via temb.reshape(batch_size, 1, num_ada_params, -1),
-    # so the input `temb` should be (B, num_ada_params * dim), or equivalently (B, num_ada_params, dim).
-
-    num_ada_params = 6
-    te_dim = num_ada_params * dim  # simplified assumption for the test
-    temb = jnp.zeros((self.batch_size, te_dim))
-
-    num_audio_ada_params = 6
-    te_audio_dim = num_audio_ada_params * audio_dim
-    temb_audio = jnp.zeros((self.batch_size, te_audio_dim))
-
-    # CA modulations:
-    # 4 params for scale/shift, 1 for gate
-    temb_ca_scale_shift = jnp.zeros((self.batch_size, 4 * dim))
-    temb_ca_audio_scale_shift = jnp.zeros((self.batch_size, 4 * audio_dim))
-    temb_ca_gate = jnp.zeros((self.batch_size, 1 * dim))
-    temb_ca_audio_gate = jnp.zeros((self.batch_size, 1 * audio_dim))
+    # NNX sharding context
+    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      block = LTX2VideoTransformerBlock(
+          rngs=self.rngs,
+          dim=dim,
+          num_attention_heads=4,
+          attention_head_dim=8,
+          cross_attention_dim=cross_dim,
+          audio_dim=audio_dim,
+          audio_num_attention_heads=4,
+          audio_attention_head_dim=4,
+          audio_cross_attention_dim=cross_dim,
+          activation_fn="gelu",
+          qk_norm="rms_norm_across_heads",
+      )
+
+      # Create dummy inputs
+      hidden_states = jnp.zeros((self.batch_size, self.seq_len, dim))
+      audio_hidden_states = jnp.zeros((self.batch_size, 10, audio_dim))  # 10 audio frames
+      encoder_hidden_states = jnp.zeros((self.batch_size, 5, cross_dim))
+      audio_encoder_hidden_states = jnp.zeros((self.batch_size, 5, cross_dim))  # reusing cross_dim for audio context
+
+      # Dummy scale/shift/gate modulations.
+      # These match the shapes expected by the block's internal calculations;
+      # for simplicity, we create them to broadcast against 'temb_reshaped' or add directly.
+      # The block accepts either raw scale/shift/gate inputs or packed temb vectors.
+      # From the block's call signature:
+      # it takes `temb` and the `temb_ca_*` tensors, where
+      # temb is (B, 1, 6, -1) or similar, depending on the reshape.
+      # In `transformer_ltx2.py`, the call signature declares
+      # temb: jax.Array
+      # and reshapes it via temb.reshape(batch_size, 1, num_ada_params, -1),
+      # so the input `temb` should be (B, num_ada_params * dim), or equivalently (B, num_ada_params, dim).
+
+      num_ada_params = 6
+      te_dim = num_ada_params * dim  # simplified assumption for the test
+      temb = jnp.zeros((self.batch_size, te_dim))
+
+      num_audio_ada_params = 6
+      te_audio_dim = num_audio_ada_params * audio_dim
+      temb_audio = jnp.zeros((self.batch_size, te_audio_dim))
+
+      # CA modulations:
+      # 4 params for scale/shift, 1 for gate
+      temb_ca_scale_shift = jnp.zeros((self.batch_size, 4 * dim))
+      temb_ca_audio_scale_shift = jnp.zeros((self.batch_size, 4 * audio_dim))
+      temb_ca_gate = jnp.zeros((self.batch_size, 1 * dim))
+      temb_ca_audio_gate = jnp.zeros((self.batch_size, 1 * audio_dim))
 
-    # Perform forward
-    out_hidden, out_audio = block(
-        hidden_states=hidden_states,
-        audio_hidden_states=audio_hidden_states,
-        encoder_hidden_states=encoder_hidden_states,
-        audio_encoder_hidden_states=audio_encoder_hidden_states,
-        temb=temb,
-        temb_audio=temb_audio,
-        temb_ca_scale_shift=temb_ca_scale_shift,
-        temb_ca_audio_scale_shift=temb_ca_audio_scale_shift,
-        temb_ca_gate=temb_ca_gate,
-        temb_ca_audio_gate=temb_ca_audio_gate,
-        video_rotary_emb=None,  # dummy input; None is accepted
-        audio_rotary_emb=None,
-    )
-
-    print(f"Input Video Shape: {hidden_states.shape}")
-    print(f"Output Video Shape: {out_hidden.shape}")
-    print(f"Input Audio Shape: {audio_hidden_states.shape}")
-    print(f"Output Audio Shape: {out_audio.shape}")
-
-    self.assertEqual(out_hidden.shape, hidden_states.shape)
-    self.assertEqual(out_audio.shape, audio_hidden_states.shape)
+      # Perform forward
+      out_hidden, out_audio = block(
+          hidden_states=hidden_states,
+          audio_hidden_states=audio_hidden_states,
+          encoder_hidden_states=encoder_hidden_states,
+          audio_encoder_hidden_states=audio_encoder_hidden_states,
+          temb=temb,
+          temb_audio=temb_audio,
+          temb_ca_scale_shift=temb_ca_scale_shift,
+          temb_ca_audio_scale_shift=temb_ca_audio_scale_shift,
+          temb_ca_gate=temb_ca_gate,
+          temb_ca_audio_gate=temb_ca_audio_gate,
+          video_rotary_emb=None,  # dummy input; None is accepted
+          audio_rotary_emb=None,
+      )
+
+      print(f"Input Video Shape: {hidden_states.shape}")
+      print(f"Output Video Shape: {out_hidden.shape}")
+      print(f"Input Audio Shape: {audio_hidden_states.shape}")
+      print(f"Output Audio Shape: {out_audio.shape}")
+
+      self.assertEqual(out_hidden.shape, hidden_states.shape)
+      self.assertEqual(out_audio.shape, audio_hidden_states.shape)
 
 
   def test_transformer_3d_model_instantiation_and_forward(self):
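
Note: the comment block in this hunk works out the packed layout of `temb`. The same arithmetic as a runnable sketch (the per-slice interpretation is an illustrative assumption; the authoritative reshape and split order live in transformer_ltx2.py):

    import jax.numpy as jnp

    batch_size, dim, num_ada_params = 1, 32, 6
    temb = jnp.zeros((batch_size, num_ada_params * dim))             # packed input: (B, 6 * dim)
    temb_reshaped = temb.reshape(batch_size, 1, num_ada_params, -1)  # -> (B, 1, 6, dim)
    assert temb_reshaped.shape == (1, 1, 6, 32)
    # Each temb_reshaped[:, :, i] is one (B, 1, dim) modulation vector,
    # e.g. shift/scale/gate pairs for the attention and FFN sub-layers (ordering assumed).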
@@ -151,26 +168,6 @@ def test_transformer_3d_model_instantiation_and_forward(self):
     """
     print("\n=== Testing LTX2VideoTransformer3DModel Integration ===")
 
-    model = LTX2VideoTransformer3DModel(
-        rngs=self.rngs,
-        in_channels=self.in_channels,
-        out_channels=self.out_channels,
-        patch_size=self.patch_size,
-        patch_size_t=self.patch_size_t,
-        num_attention_heads=2,
-        attention_head_dim=8,
-        num_layers=1,  # 1 layer for speed
-        caption_channels=32,  # small for test
-        cross_attention_dim=32,
-        audio_in_channels=self.audio_in_channels,
-        audio_out_channels=self.audio_in_channels,
-        audio_num_attention_heads=2,
-        audio_attention_head_dim=8,
-        audio_cross_attention_dim=32,
-    )
-
-    # Inputs
-    # hidden_states: (B, F, H, W, C) or (B, L, C)?
     # Diffusers' `forward` usually takes `hidden_states` as latents.
     # If it's 3D, it might expect (B, C, F, H, W) or (B, F, C, H, W)?
     # Checking `transformer_ltx2.py` `__call__` Line 680:
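
Note: the pattern this commit introduces at the block level, constructing the NNX module inside both the Mesh context and flax.linen's logical axis-rules context, is what allows logically annotated parameters to be resolved onto mesh axes at creation time. A self-contained toy sketch of the two nested contexts (the single-device mesh, axis name, and placeholder rules are illustrative assumptions; the real test uses create_device_mesh(config) and config.logical_axis_rules):

    import jax
    import jax.numpy as jnp
    import numpy as np
    from flax import nnx
    from flax.linen import partitioning as nn_partitioning
    from jax.sharding import Mesh

    mesh = Mesh(np.array(jax.devices()[:1]), ("data",))  # toy 1-device, 1-axis mesh
    toy_rules = (("embed", None),)                       # placeholder logical-to-mesh rules

    with mesh, nn_partitioning.axis_rules(toy_rules):
      # Modules constructed here can consult the active mesh and axis rules
      # when creating (and sharding) their parameters.
      layer = nnx.Linear(8, 8, rngs=nnx.Rngs(0))
      y = layer(jnp.ones((1, 8)))
    print(y.shape)  # (1, 8)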
