@@ -593,8 +593,18 @@ def __call__(
593593 ) -> Union [jax .Array , Dict [str , jax .Array ]]:
594594 print (f"[DEBUG] WanModel __call__ hidden_states IN shape: { hidden_states .shape } " )
595595 hidden_states = nn .with_logical_constraint (hidden_states , ("batch" , None , None , None , None ))
596- batch_size , _ , num_frames , height , width = hidden_states .shape
597- print (f"[DEBUG] WanModel __call__ unpacked: B={ batch_size } , C={ c } , T={ num_frames } , H={ height } , W={ width } " )
596+ dim0 , dim1 , dim2 , dim3 , dim4 = hidden_states .shape
597+ print (f"[DEBUG] WanModel __call__ unpacked: dim0={ dim0 } , dim1={ dim1 } , dim2={ dim2 } , dim3={ dim3 } , dim4={ dim4 } " )
598+
599+ batch_size = dim0
600+ c = dim1 # This is ACTUALLY Time
601+ num_frames = dim2 # This is ACTUALLY Height
602+ height = dim3 # This is ACTUALLY Width
603+ width = dim4 # This is ACTUALLY Channels
604+
605+
606+ # batch_size, _, num_frames, height, width = hidden_states.shape
607+ print (f"[DEBUG] WanModel __call__ INTERPRETED as B,C,T,H,W: B={ batch_size } , C={ c } , T={ num_frames } , H={ height } , W={ width } " )
598608 p_t , p_h , p_w = self .config .patch_size
599609 post_patch_num_frames = num_frames // p_t
600610 post_patch_height = height // p_h
0 commit comments