@@ -71,25 +71,14 @@ def __init__(self, attention_head_dim: int, patch_size: Tuple[int, int, int], ma
7171 self .use_real = use_real
7272
7373 def __call__ (self , hidden_states : jax .Array ) -> jax .Array :
74- print (f"[DEBUG] WanRotaryPosEmbed hidden_states shape: { hidden_states .shape } " )
7574 _ , num_frames , height , width , _ = hidden_states .shape
76- print (f"[DEBUG] WanRotaryPosEmbed unpacked shapes: num_frames={ num_frames } , height={ height } , width={ width } " )
7775 p_t , p_h , p_w = self .patch_size
78- print (f"[DEBUG] WanRotaryPosEmbed patch_size: p_t={ p_t } , p_h={ p_h } , p_w={ p_w } " )
7976 ppf , pph , ppw = num_frames // p_t , height // p_h , width // p_w
80- print (f"[DEBUG] WanRotaryPosEmbed ppf={ ppf } , pph={ pph } , ppw={ ppw } " )
8177
8278 freqs_split = get_frequencies (self .max_seq_len , self .theta , self .attention_head_dim , self .use_real )
83- print (f"[DEBUG] WanRotaryPosEmbed freqs_split shapes: { [f .shape for f in freqs_split ] } " )
8479
8580 freqs_f = jnp .expand_dims (jnp .expand_dims (freqs_split [0 ][:ppf ], axis = 1 ), axis = 1 )
86- print (f"[DEBUG] WanRotaryPosEmbed freqs_f shape BEFORE broadcast: { freqs_f .shape } " )
87-
88- target_shape_f = (ppf , pph , ppw , freqs_split [0 ].shape [- 1 ])
89- print (f"[DEBUG] WanRotaryPosEmbed freqs_f TARGET shape: { target_shape_f } " )
90-
91- freqs_f = jnp .broadcast_to (freqs_f , target_shape_f )
92- print (f"[DEBUG] WanRotaryPosEmbed freqs_f shape AFTER broadcast: { freqs_f .shape } " )
81+ freqs_f = jnp .broadcast_to (freqs_f , (ppf , pph , ppw , freqs_split [0 ].shape [- 1 ]))
9382
9483 freqs_h = jnp .expand_dims (jnp .expand_dims (freqs_split [1 ][:pph ], axis = 0 ), axis = 2 )
9584 freqs_h = jnp .broadcast_to (freqs_h , (ppf , pph , ppw , freqs_split [1 ].shape [- 1 ]))
@@ -591,32 +580,17 @@ def __call__(
591580 deterministic : bool = True ,
592581 rngs : nnx .Rngs = None ,
593582 ) -> Union [jax .Array , Dict [str , jax .Array ]]:
594- print (f"[DEBUG] WanModel __call__ hidden_states IN shape: { hidden_states .shape } " )
595583 hidden_states = nn .with_logical_constraint (hidden_states , ("batch" , None , None , None , None ))
596- dim0 , dim1 , dim2 , dim3 , dim4 = hidden_states .shape
597- print (f"[DEBUG] WanModel __call__ unpacked: dim0={ dim0 } , dim1={ dim1 } , dim2={ dim2 } , dim3={ dim3 } , dim4={ dim4 } " )
598-
599- batch_size = dim0
600- c = dim1 # This is ACTUALLY Time
601- num_frames = dim2 # This is ACTUALLY Height
602- height = dim3 # This is ACTUALLY Width
603- width = dim4 # This is ACTUALLY Channels
604-
605-
606- # batch_size, _, num_frames, height, width = hidden_states.shape
607- print (f"[DEBUG] WanModel __call__ INTERPRETED as B,C,T,H,W: B={ batch_size } , C={ c } , T={ num_frames } , H={ height } , W={ width } " )
584+ batch_size , _ , num_frames , height , width = hidden_states .shape
608585 p_t , p_h , p_w = self .config .patch_size
609586 post_patch_num_frames = num_frames // p_t
610587 post_patch_height = height // p_h
611588 post_patch_width = width // p_w
612589
613590 hidden_states = jnp .transpose (hidden_states , (0 , 2 , 3 , 4 , 1 ))
614- print (f"[DEBUG] WanModel __call__ hidden_states AFTER transpose: { hidden_states .shape } " )
615591 rotary_emb = self .rope (hidden_states )
616- print (f"[DEBUG] WanModel __call__ rotary_emb shape: { rotary_emb .shape } " )
617592
618593 hidden_states = self .patch_embedding (hidden_states )
619- print (f"[DEBUG] WanModel __call__ hidden_states after patch_embedding: { hidden_states .shape } " )
620594 hidden_states = jax .lax .collapse (hidden_states , 1 , - 1 )
621595 if timestep .ndim == 2 :
622596 ts_seq_len = timestep .shape [1 ]
@@ -689,4 +663,4 @@ def layer_forward(hidden_states):
689663 hidden_states = jax .lax .collapse (hidden_states , 6 , None )
690664 hidden_states = jax .lax .collapse (hidden_states , 4 , 6 )
691665 hidden_states = jax .lax .collapse (hidden_states , 2 , 4 )
692- return hidden_states
666+ return hidden_states
0 commit comments