Skip to content

Commit 3646f84

Browse files
committed
debug
1 parent f2cbf5a commit 3646f84

1 file changed

Lines changed: 7 additions & 3 deletions

File tree

src/maxdiffusion/models/ltx_2/transformer_ltx2.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -258,17 +258,21 @@ def __call__(
258258
# 1. Video and Audio Self-Attention
259259
norm_hidden_states = self.norm1(hidden_states)
260260

261+
import sys
262+
261263
# Calculate Video AdaLN values
262264
num_ada_params = self.scale_shift_table.shape[0]
263265
# table shape: (6, dim) -> (1, 1, 6, dim)
264266
scale_shift_table_reshaped = jnp.expand_dims(self.scale_shift_table, axis=(0, 1))
265267
# temb shape: (batch, temb_dim) -> (batch, 1, 6, dim) (assuming temb_dim is num_ada_params * dim)
266-
print(f"DEBUG: scale_shift_table_reshaped shape: {scale_shift_table_reshaped.shape}")
267-
print(f"DEBUG: temb shape before reshape: {temb.shape}")
268+
print(f"DEBUG_BLOCK: scale_shift_table_reshaped shape: {scale_shift_table_reshaped.shape}")
269+
print(f"DEBUG_BLOCK: temb shape before reshape: {temb.shape}")
270+
sys.stdout.flush()
268271

269272
temb_reshaped = temb.reshape(batch_size, 1, num_ada_params, -1)
270273

271-
print(f"DEBUG: temb_reshaped shape: {temb_reshaped.shape}")
274+
print(f"DEBUG_BLOCK: temb_reshaped shape: {temb_reshaped.shape}")
275+
sys.stdout.flush()
272276

273277
ada_values = scale_shift_table_reshaped + temb_reshaped
274278

0 commit comments

Comments
 (0)