@@ -71,14 +71,25 @@ def __init__(self, attention_head_dim: int, patch_size: Tuple[int, int, int], ma
7171 self .use_real = use_real
7272
7373 def __call__ (self , hidden_states : jax .Array ) -> jax .Array :
74+ print (f"[DEBUG] WanRotaryPosEmbed hidden_states shape: { hidden_states .shape } " )
7475 _ , num_frames , height , width , _ = hidden_states .shape
76+ print (f"[DEBUG] WanRotaryPosEmbed unpacked shapes: num_frames={ num_frames } , height={ height } , width={ width } " )
7577 p_t , p_h , p_w = self .patch_size
78+ print (f"[DEBUG] WanRotaryPosEmbed patch_size: p_t={ p_t } , p_h={ p_h } , p_w={ p_w } " )
7679 ppf , pph , ppw = num_frames // p_t , height // p_h , width // p_w
80+ print (f"[DEBUG] WanRotaryPosEmbed ppf={ ppf } , pph={ pph } , ppw={ ppw } " )
7781
7882 freqs_split = get_frequencies (self .max_seq_len , self .theta , self .attention_head_dim , self .use_real )
83+ print (f"[DEBUG] WanRotaryPosEmbed freqs_split shapes: { [f .shape for f in freqs_split ] } " )
7984
8085 freqs_f = jnp .expand_dims (jnp .expand_dims (freqs_split [0 ][:ppf ], axis = 1 ), axis = 1 )
81- freqs_f = jnp .broadcast_to (freqs_f , (ppf , pph , ppw , freqs_split [0 ].shape [- 1 ]))
86+ print (f"[DEBUG] WanRotaryPosEmbed freqs_f shape BEFORE broadcast: { freqs_f .shape } " )
87+
88+ target_shape_f = (ppf , pph , ppw , freqs_split [0 ].shape [- 1 ])
89+ print (f"[DEBUG] WanRotaryPosEmbed freqs_f TARGET shape: { target_shape_f } " )
90+
91+ freqs_f = jnp .broadcast_to (freqs_f , target_shape_f )
92+ print (f"[DEBUG] WanRotaryPosEmbed freqs_f shape AFTER broadcast: { freqs_f .shape } " )
8293
8394 freqs_h = jnp .expand_dims (jnp .expand_dims (freqs_split [1 ][:pph ], axis = 0 ), axis = 2 )
8495 freqs_h = jnp .broadcast_to (freqs_h , (ppf , pph , ppw , freqs_split [1 ].shape [- 1 ]))
@@ -580,17 +591,22 @@ def __call__(
580591 deterministic : bool = True ,
581592 rngs : nnx .Rngs = None ,
582593 ) -> Union [jax .Array , Dict [str , jax .Array ]]:
594+ print (f"[DEBUG] WanModel __call__ hidden_states IN shape: { hidden_states .shape } " )
583595 hidden_states = nn .with_logical_constraint (hidden_states , ("batch" , None , None , None , None ))
584596 batch_size , _ , num_frames , height , width = hidden_states .shape
597+ print (f"[DEBUG] WanModel __call__ unpacked: B={ batch_size } , C={ c } , T={ num_frames } , H={ height } , W={ width } " )
585598 p_t , p_h , p_w = self .config .patch_size
586599 post_patch_num_frames = num_frames // p_t
587600 post_patch_height = height // p_h
588601 post_patch_width = width // p_w
589602
590603 hidden_states = jnp .transpose (hidden_states , (0 , 2 , 3 , 4 , 1 ))
604+ print (f"[DEBUG] WanModel __call__ hidden_states AFTER transpose: { hidden_states .shape } " )
591605 rotary_emb = self .rope (hidden_states )
606+ print (f"[DEBUG] WanModel __call__ rotary_emb shape: { rotary_emb .shape } " )
592607
593608 hidden_states = self .patch_embedding (hidden_states )
609+ print (f"[DEBUG] WanModel __call__ hidden_states after patch_embedding: { hidden_states .shape } " )
594610 hidden_states = jax .lax .collapse (hidden_states , 1 , - 1 )
595611 if timestep .ndim == 2 :
596612 ts_seq_len = timestep .shape [1 ]
0 commit comments