@@ -907,8 +907,7 @@ def __call__(
907907 p_t = self .patch_size_t
908908
909909 hidden_states = sample .reshape (B , T // p_t , p_t , H // p , p , W // p , p , C )
910- # 0:B, 1:T_p, 3:H_p, 5:W_p, 7:C, 2:p_t, 4:p_h, 6:p_w
911- hidden_states = hidden_states .transpose (0 , 1 , 3 , 5 , 7 , 2 , 4 , 6 )
910+ hidden_states = hidden_states .transpose (0 , 1 , 3 , 5 , 7 , 2 , 6 , 4 )
912911 hidden_states = hidden_states .reshape (B , T // p_t , H // p , W // p , - 1 )
913912
914913 num_blocks = len (self .down_blocks ) + 1
@@ -1108,8 +1107,7 @@ def __call__(
11081107 hidden_states = hidden_states .reshape (B , T , H , W , C_out_final , p_t , p , p )
11091108
11101109 # Pair H (2) with p_h (7) and W (3) with p_w (6)
1111- # 0:B, 1:T, 5:p_t, 2:H, 6:p_h, 3:W, 7:p_w, 4:C_out_final
1112- hidden_states = hidden_states .transpose (0 , 1 , 5 , 2 , 6 , 3 , 7 , 4 )
1110+ hidden_states = hidden_states .transpose (0 , 1 , 5 , 2 , 7 , 3 , 6 , 4 )
11131111 hidden_states = hidden_states .reshape (B , T * p_t , H * p , W * p , C_out_final )
11141112
11151113 return hidden_states
0 commit comments