Skip to content

Commit 7012885

Browse files
committed
debug added in transformer_wan.py
1 parent 2eb2008 commit 7012885

1 file changed

Lines changed: 17 additions & 1 deletion

File tree

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,14 +71,25 @@ def __init__(self, attention_head_dim: int, patch_size: Tuple[int, int, int], ma
7171
self.use_real = use_real
7272

7373
def __call__(self, hidden_states: jax.Array) -> jax.Array:
74+
print(f"[DEBUG] WanRotaryPosEmbed hidden_states shape: {hidden_states.shape}")
7475
_, num_frames, height, width, _ = hidden_states.shape
76+
print(f"[DEBUG] WanRotaryPosEmbed unpacked shapes: num_frames={num_frames}, height={height}, width={width}")
7577
p_t, p_h, p_w = self.patch_size
78+
print(f"[DEBUG] WanRotaryPosEmbed patch_size: p_t={p_t}, p_h={p_h}, p_w={p_w}")
7679
ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
80+
print(f"[DEBUG] WanRotaryPosEmbed ppf={ppf}, pph={pph}, ppw={ppw}")
7781

7882
freqs_split = get_frequencies(self.max_seq_len, self.theta, self.attention_head_dim, self.use_real)
83+
print(f"[DEBUG] WanRotaryPosEmbed freqs_split shapes: { [f.shape for f in freqs_split] }")
7984

8085
freqs_f = jnp.expand_dims(jnp.expand_dims(freqs_split[0][:ppf], axis=1), axis=1)
81-
freqs_f = jnp.broadcast_to(freqs_f, (ppf, pph, ppw, freqs_split[0].shape[-1]))
86+
print(f"[DEBUG] WanRotaryPosEmbed freqs_f shape BEFORE broadcast: {freqs_f.shape}")
87+
88+
target_shape_f = (ppf, pph, ppw, freqs_split[0].shape[-1])
89+
print(f"[DEBUG] WanRotaryPosEmbed freqs_f TARGET shape: {target_shape_f}")
90+
91+
freqs_f = jnp.broadcast_to(freqs_f, target_shape_f)
92+
print(f"[DEBUG] WanRotaryPosEmbed freqs_f shape AFTER broadcast: {freqs_f.shape}")
8293

8394
freqs_h = jnp.expand_dims(jnp.expand_dims(freqs_split[1][:pph], axis=0), axis=2)
8495
freqs_h = jnp.broadcast_to(freqs_h, (ppf, pph, ppw, freqs_split[1].shape[-1]))
@@ -580,17 +591,22 @@ def __call__(
580591
deterministic: bool = True,
581592
rngs: nnx.Rngs = None,
582593
) -> Union[jax.Array, Dict[str, jax.Array]]:
594+
print(f"[DEBUG] WanModel __call__ hidden_states IN shape: {hidden_states.shape}")
583595
hidden_states = nn.with_logical_constraint(hidden_states, ("batch", None, None, None, None))
584596
batch_size, _, num_frames, height, width = hidden_states.shape
597+
print(f"[DEBUG] WanModel __call__ unpacked: B={batch_size}, C={hidden_states.shape[1]}, T={num_frames}, H={height}, W={width}")
585598
p_t, p_h, p_w = self.config.patch_size
586599
post_patch_num_frames = num_frames // p_t
587600
post_patch_height = height // p_h
588601
post_patch_width = width // p_w
589602

590603
hidden_states = jnp.transpose(hidden_states, (0, 2, 3, 4, 1))
604+
print(f"[DEBUG] WanModel __call__ hidden_states AFTER transpose: {hidden_states.shape}")
591605
rotary_emb = self.rope(hidden_states)
606+
print(f"[DEBUG] WanModel __call__ rotary_emb shape: {rotary_emb.shape}")
592607

593608
hidden_states = self.patch_embedding(hidden_states)
609+
print(f"[DEBUG] WanModel __call__ hidden_states after patch_embedding: {hidden_states.shape}")
594610
hidden_states = jax.lax.collapse(hidden_states, 1, -1)
595611
if timestep.ndim == 2:
596612
ts_seq_len = timestep.shape[1]

0 commit comments

Comments
 (0)