@@ -503,9 +503,7 @@ def fake_encode(video, dtype):
503503 latent_t = (video .shape [2 ] - 1 ) // self .max_pipeline .vae_scale_factor_temporal + 1
504504 latent_h = video .shape [3 ] // self .max_pipeline .vae_scale_factor_spatial
505505 latent_w = video .shape [4 ] // self .max_pipeline .vae_scale_factor_spatial
506- return jnp .zeros (
507- (video .shape [0 ], latent_t , latent_h , latent_w , self .max_pipeline .vae .z_dim ), dtype = jnp .float32
508- )
506+ return jnp .zeros ((video .shape [0 ], latent_t , latent_h , latent_w , self .max_pipeline .vae .z_dim ), dtype = jnp .float32 )
509507
510508 self .max_pipeline ._encode_video_to_latents = fake_encode
511509
@@ -722,9 +720,7 @@ def __call__(
722720 np .testing .assert_allclose (to_numpy (capture ["pose_hidden_states" ]), to_numpy (expected_pose ), atol = 0.0 , rtol = 0.0 )
723721 np .testing .assert_allclose (to_numpy (capture ["face_pixel_values" ]), to_numpy (face_video ), atol = 0.0 , rtol = 0.0 )
724722 np .testing .assert_allclose (to_numpy (capture ["encoder_hidden_states" ]), to_numpy (prompt_embeds ), atol = 0.0 , rtol = 0.0 )
725- np .testing .assert_allclose (
726- to_numpy (capture ["encoder_hidden_states_image" ]), to_numpy (image_embeds ), atol = 0.0 , rtol = 0.0
727- )
723+ np .testing .assert_allclose (to_numpy (capture ["encoder_hidden_states_image" ]), to_numpy (image_embeds ), atol = 0.0 , rtol = 0.0 )
728724 self .assertEqual (capture ["motion_encode_batch_size" ], 7 )
729725 self .assertFalse (capture ["return_dict" ])
730726 np .testing .assert_allclose (to_numpy (noise_pred ), to_numpy (latents ), atol = 0.0 , rtol = 0.0 )
@@ -872,9 +868,7 @@ def test_flax_unipc_flow_sigmas_match_diffusers(self):
872868 max_model_output = jnp .array (to_numpy (hf_model_output ))
873869
874870 hf_sample = hf_scheduler .step (hf_model_output , int (timestep ), hf_sample , return_dict = False )[0 ]
875- max_sample , max_state = max_scheduler .step (
876- max_state , max_model_output , int (timestep ), max_sample , return_dict = False
877- )
871+ max_sample , max_state = max_scheduler .step (max_state , max_model_output , int (timestep ), max_sample , return_dict = False )
878872
879873 np .testing .assert_allclose (to_numpy (max_sample ), to_numpy (hf_sample ), atol = 1e-4 , rtol = 1e-5 )
880874
0 commit comments