Skip to content

Commit ec9554a

Browse files
committed
removed debug prints in transformer_wan.py
1 parent 085774a commit ec9554a

1 file changed

Lines changed: 3 additions & 29 deletions

File tree

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 3 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -71,25 +71,14 @@ def __init__(self, attention_head_dim: int, patch_size: Tuple[int, int, int], ma
7171
self.use_real = use_real
7272

7373
def __call__(self, hidden_states: jax.Array) -> jax.Array:
74-
print(f"[DEBUG] WanRotaryPosEmbed hidden_states shape: {hidden_states.shape}")
7574
_, num_frames, height, width, _ = hidden_states.shape
76-
print(f"[DEBUG] WanRotaryPosEmbed unpacked shapes: num_frames={num_frames}, height={height}, width={width}")
7775
p_t, p_h, p_w = self.patch_size
78-
print(f"[DEBUG] WanRotaryPosEmbed patch_size: p_t={p_t}, p_h={p_h}, p_w={p_w}")
7976
ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
80-
print(f"[DEBUG] WanRotaryPosEmbed ppf={ppf}, pph={pph}, ppw={ppw}")
8177

8278
freqs_split = get_frequencies(self.max_seq_len, self.theta, self.attention_head_dim, self.use_real)
83-
print(f"[DEBUG] WanRotaryPosEmbed freqs_split shapes: { [f.shape for f in freqs_split] }")
8479

8580
freqs_f = jnp.expand_dims(jnp.expand_dims(freqs_split[0][:ppf], axis=1), axis=1)
86-
print(f"[DEBUG] WanRotaryPosEmbed freqs_f shape BEFORE broadcast: {freqs_f.shape}")
87-
88-
target_shape_f = (ppf, pph, ppw, freqs_split[0].shape[-1])
89-
print(f"[DEBUG] WanRotaryPosEmbed freqs_f TARGET shape: {target_shape_f}")
90-
91-
freqs_f = jnp.broadcast_to(freqs_f, target_shape_f)
92-
print(f"[DEBUG] WanRotaryPosEmbed freqs_f shape AFTER broadcast: {freqs_f.shape}")
81+
freqs_f = jnp.broadcast_to(freqs_f, (ppf, pph, ppw, freqs_split[0].shape[-1]))
9382

9483
freqs_h = jnp.expand_dims(jnp.expand_dims(freqs_split[1][:pph], axis=0), axis=2)
9584
freqs_h = jnp.broadcast_to(freqs_h, (ppf, pph, ppw, freqs_split[1].shape[-1]))
@@ -591,32 +580,17 @@ def __call__(
591580
deterministic: bool = True,
592581
rngs: nnx.Rngs = None,
593582
) -> Union[jax.Array, Dict[str, jax.Array]]:
594-
print(f"[DEBUG] WanModel __call__ hidden_states IN shape: {hidden_states.shape}")
595583
hidden_states = nn.with_logical_constraint(hidden_states, ("batch", None, None, None, None))
596-
dim0, dim1, dim2, dim3, dim4 = hidden_states.shape
597-
print(f"[DEBUG] WanModel __call__ unpacked: dim0={dim0}, dim1={dim1}, dim2={dim2}, dim3={dim3}, dim4={dim4}")
598-
599-
batch_size = dim0
600-
c = dim1 # This is ACTUALLY Time
601-
num_frames = dim2 # This is ACTUALLY Height
602-
height = dim3 # This is ACTUALLY Width
603-
width = dim4 # This is ACTUALLY Channels
604-
605-
606-
# batch_size, _, num_frames, height, width = hidden_states.shape
607-
print(f"[DEBUG] WanModel __call__ INTERPRETED as B,C,T,H,W: B={batch_size}, C={c}, T={num_frames}, H={height}, W={width}")
584+
batch_size, _, num_frames, height, width = hidden_states.shape
608585
p_t, p_h, p_w = self.config.patch_size
609586
post_patch_num_frames = num_frames // p_t
610587
post_patch_height = height // p_h
611588
post_patch_width = width // p_w
612589

613590
hidden_states = jnp.transpose(hidden_states, (0, 2, 3, 4, 1))
614-
print(f"[DEBUG] WanModel __call__ hidden_states AFTER transpose: {hidden_states.shape}")
615591
rotary_emb = self.rope(hidden_states)
616-
print(f"[DEBUG] WanModel __call__ rotary_emb shape: {rotary_emb.shape}")
617592

618593
hidden_states = self.patch_embedding(hidden_states)
619-
print(f"[DEBUG] WanModel __call__ hidden_states after patch_embedding: {hidden_states.shape}")
620594
hidden_states = jax.lax.collapse(hidden_states, 1, -1)
621595
if timestep.ndim == 2:
622596
ts_seq_len = timestep.shape[1]
@@ -689,4 +663,4 @@ def layer_forward(hidden_states):
689663
hidden_states = jax.lax.collapse(hidden_states, 6, None)
690664
hidden_states = jax.lax.collapse(hidden_states, 4, 6)
691665
hidden_states = jax.lax.collapse(hidden_states, 2, 4)
692-
return hidden_states
666+
return hidden_states

0 commit comments

Comments (0)