Skip to content

Commit ad5d293

Browse files
committed
debug
1 parent 3646f84 commit ad5d293

1 file changed

Lines changed: 9 additions & 0 deletions

File tree

src/maxdiffusion/models/ltx_2/transformer_ltx2.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,15 @@ def __call__(
297297

298298
num_audio_ada_params = self.audio_scale_shift_table.shape[0]
299299
audio_scale_shift_table_reshaped = jnp.expand_dims(self.audio_scale_shift_table, axis=(0, 1))
300+
301+
print(f"DEBUG_BLOCK_AUDIO: audio_scale_shift_table_reshaped shape: {audio_scale_shift_table_reshaped.shape}")
302+
print(f"DEBUG_BLOCK_AUDIO: temb_audio shape before reshape: {temb_audio.shape}")
303+
sys.stdout.flush()
304+
300305
temb_audio_reshaped = temb_audio.reshape(batch_size, 1, num_audio_ada_params, -1)
306+
307+
print(f"DEBUG_BLOCK_AUDIO: temb_audio_reshaped shape: {temb_audio_reshaped.shape}")
308+
sys.stdout.flush()
301309
audio_ada_values = audio_scale_shift_table_reshaped + temb_audio_reshaped
302310

303311
audio_shift_msa = audio_ada_values[:, :, 0, :]
@@ -512,6 +520,7 @@ def __init__(
512520
)
513521

514522
print(f"DEBUG IN INIT: inner_dim={inner_dim}, num_attention_heads={num_attention_heads}, attention_head_dim={attention_head_dim}")
523+
sys.stdout.flush()
515524

516525
# 3. Timestep Modulation Params and Embedding
517526
self.time_embed = LTX2AdaLayerNormSingle(

0 commit comments

Comments
 (0)