Skip to content

Commit e2a78ae

Browse files
committed
debug added in transformer_wan.py
1 parent 7012885 commit e2a78ae

1 file changed

Lines changed: 12 additions & 2 deletions

File tree

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -593,8 +593,18 @@ def __call__(
593593
) -> Union[jax.Array, Dict[str, jax.Array]]:
594594
print(f"[DEBUG] WanModel __call__ hidden_states IN shape: {hidden_states.shape}")
595595
hidden_states = nn.with_logical_constraint(hidden_states, ("batch", None, None, None, None))
596-
batch_size, _, num_frames, height, width = hidden_states.shape
597-
print(f"[DEBUG] WanModel __call__ unpacked: B={batch_size}, C={c}, T={num_frames}, H={height}, W={width}")
596+
dim0, dim1, dim2, dim3, dim4 = hidden_states.shape
597+
print(f"[DEBUG] WanModel __call__ unpacked: dim0={dim0}, dim1={dim1}, dim2={dim2}, dim3={dim3}, dim4={dim4}")
598+
599+
batch_size = dim0
600+
c = dim1 # This is ACTUALLY Time
601+
num_frames = dim2 # This is ACTUALLY Height
602+
height = dim3 # This is ACTUALLY Width
603+
width = dim4 # This is ACTUALLY Channels
604+
605+
606+
# batch_size, _, num_frames, height, width = hidden_states.shape
607+
print(f"[DEBUG] WanModel __call__ INTERPRETED as B,C,T,H,W: B={batch_size}, C={c}, T={num_frames}, H={height}, W={width}")
598608
p_t, p_h, p_w = self.config.patch_size
599609
post_patch_num_frames = num_frames // p_t
600610
post_patch_height = height // p_h

0 commit comments

Comments
 (0)