Skip to content

Commit e7047f7

Browse files
committed
debug
1 parent 122ef73 commit e7047f7

1 file changed

Lines changed: 6 additions & 0 deletions

File tree

src/maxdiffusion/models/ltx_2/transformer_ltx2.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,13 @@ def __call__(
263263
# table shape: (6, dim) -> (1, 1, 6, dim)
264264
scale_shift_table_reshaped = jnp.expand_dims(self.scale_shift_table, axis=(0, 1))
265265
# temb shape: (batch, temb_dim) -> (batch, 1, 6, dim) (assuming temb_dim is num_ada_params * dim)
266+
print(f"DEBUG: scale_shift_table_reshaped shape: {scale_shift_table_reshaped.shape}")
267+
print(f"DEBUG: temb shape before reshape: {temb.shape}")
268+
266269
temb_reshaped = temb.reshape(batch_size, 1, num_ada_params, -1)
270+
271+
print(f"DEBUG: temb_reshaped shape: {temb_reshaped.shape}")
272+
267273
ada_values = scale_shift_table_reshaped + temb_reshaped
268274

269275
shift_msa = ada_values[:, :, 0, :]

0 commit comments

Comments
 (0)