Skip to content

Commit e0888db

Browse files
committed
fix
1 parent 0f83ffe commit e0888db

2 files changed

Lines changed: 1 addition & 23 deletions

File tree

src/maxdiffusion/models/ltx_2/transformer_ltx2.py

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -258,22 +258,12 @@ def __call__(
258258
# 1. Video and Audio Self-Attention
259259
norm_hidden_states = self.norm1(hidden_states)
260260

261-
import sys
262-
263261
# Calculate Video AdaLN values
264262
num_ada_params = self.scale_shift_table.shape[0]
265263
# table shape: (6, dim) -> (1, 1, 6, dim)
266264
scale_shift_table_reshaped = jnp.expand_dims(self.scale_shift_table, axis=(0, 1))
267265
# temb shape: (batch, temb_dim) -> (batch, 1, 6, dim) (assuming temb_dim is num_ada_params * dim)
268-
print(f"DEBUG_BLOCK: scale_shift_table_reshaped shape: {scale_shift_table_reshaped.shape}")
269-
print(f"DEBUG_BLOCK: temb shape before reshape: {temb.shape}")
270-
sys.stdout.flush()
271-
272266
temb_reshaped = temb.reshape(batch_size, 1, num_ada_params, -1)
273-
274-
print(f"DEBUG_BLOCK: temb_reshaped shape: {temb_reshaped.shape}")
275-
sys.stdout.flush()
276-
277267
ada_values = scale_shift_table_reshaped + temb_reshaped
278268

279269
shift_msa = ada_values[:, :, 0, :]
@@ -297,15 +287,7 @@ def __call__(
297287

298288
num_audio_ada_params = self.audio_scale_shift_table.shape[0]
299289
audio_scale_shift_table_reshaped = jnp.expand_dims(self.audio_scale_shift_table, axis=(0, 1))
300-
301-
print(f"DEBUG_BLOCK_AUDIO: audio_scale_shift_table_reshaped shape: {audio_scale_shift_table_reshaped.shape}")
302-
print(f"DEBUG_BLOCK_AUDIO: temb_audio shape before reshape: {temb_audio.shape}")
303-
sys.stdout.flush()
304-
305290
temb_audio_reshaped = temb_audio.reshape(batch_size, 1, num_audio_ada_params, -1)
306-
307-
print(f"DEBUG_BLOCK_AUDIO: temb_audio_reshaped shape: {temb_audio_reshaped.shape}")
308-
sys.stdout.flush()
309291
audio_ada_values = audio_scale_shift_table_reshaped + temb_audio_reshaped
310292

311293
audio_shift_msa = audio_ada_values[:, :, 0, :]
@@ -518,10 +500,6 @@ def __init__(
518500
self.audio_caption_projection = NNXPixArtAlphaTextProjection(
519501
rngs=rngs, in_features=self.caption_channels, hidden_size=audio_inner_dim, dtype=self.dtype, weights_dtype=self.weights_dtype
520502
)
521-
import sys
522-
print(f"DEBUG IN INIT: inner_dim={inner_dim}, num_attention_heads={num_attention_heads}, attention_head_dim={attention_head_dim}")
523-
sys.stdout.flush()
524-
525503
# 3. Timestep Modulation Params and Embedding
526504
self.time_embed = LTX2AdaLayerNormSingle(
527505
rngs=rngs, embedding_dim=inner_dim, num_mod_params=6, use_additional_conditions=False, dtype=self.dtype, weights_dtype=self.weights_dtype

src/maxdiffusion/tests/ltx_2_transformer_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ def test_transformer_3d_model_instantiation_and_forward(self):
221221
hidden_states = jnp.zeros((self.batch_size, self.seq_len, self.in_channels))
222222
audio_hidden_states = jnp.zeros((self.batch_size, 10, self.audio_in_channels))
223223

224-
timestep = jnp.array([1.0, 2.0]) # (B,)
224+
timestep = jnp.array([1.0]) # (B,)
225225

226226
encoder_hidden_states = jnp.zeros((self.batch_size, 5, 32)) # (B, Lc, Dc)
227227
audio_encoder_hidden_states = jnp.zeros((self.batch_size, 5, 32))

0 commit comments

Comments (0)