@@ -273,7 +273,9 @@ def test_motion_conv_equivalence(self):
273273 max_diff = np .max (np .abs (np_out_pt - np_out_jax ))
274274 print (f"Max absolute difference: { max_diff :.8f} " )
275275
276- np .testing .assert_allclose (np_out_pt , np_out_jax , rtol = 1e-3 , atol = 5e-3 , err_msg = f"Outputs do not match! max_diff={ max_diff } " )
276+ np .testing .assert_allclose (
277+ np_out_pt , np_out_jax , rtol = 1e-3 , atol = 5e-3 , err_msg = f"Outputs do not match! max_diff={ max_diff } "
278+ )
277279
278280 def test_fused_leaky_relu_shape (self ):
279281 rngs = nnx .Rngs (0 )
@@ -461,7 +463,7 @@ def test_equivalence_motion_encoder(self):
461463
462464 jax_out = jax_model (jax_input )
463465
464- np .testing .assert_allclose (pt_out .numpy (), np .array (jax_out ), rtol = 1e-4 , atol = 1e-4 )
466+ np .testing .assert_allclose (pt_out .numpy (), np .array (jax_out ), rtol = 1e-3 , atol = 5e-3 )
465467
466468 def test_equivalence_face_encoder (self ):
467469 from diffusers .models .transformers .transformer_wan_animate import (
@@ -503,8 +505,8 @@ def test_equivalence_face_encoder(self):
503505 np .testing .assert_allclose (
504506 pt_out .numpy (),
505507 np .array (jax_out ),
506- rtol = 1e-4 ,
507- atol = 1e -3 , # Slightly higher tolerance for convolutions
508+ rtol = 1e-3 ,
509+ atol = 5e -3 , # Slightly higher tolerance for convolutions
508510 )
509511
510512 def test_equivalence_face_block_cross_attention (self ):
@@ -557,7 +559,7 @@ def test_equivalence_face_block_cross_attention(self):
557559
558560 jax_out = jax_model (hidden_states_jax , encoder_hidden_states_jax )
559561
560- np .testing .assert_allclose (pt_out .numpy (), np .array (jax_out ), rtol = 1e-4 , atol = 1e-4 )
562+ np .testing .assert_allclose (pt_out .numpy (), np .array (jax_out ), rtol = 1e-3 , atol = 5e-3 )
561563
562564 def test_equivalence_wan_animate_transformer (self ):
563565 from diffusers .models .transformers .transformer_wan_animate import (
@@ -687,4 +689,4 @@ def test_equivalence_wan_animate_transformer(self):
687689 np_jax = np .array (jax_out )
688690
689691 np .testing .assert_equal (np_pt .shape , np_jax .shape )
690- np .testing .assert_allclose (np_pt , np_jax , rtol = 1e-4 , atol = 1e-4 )
692+ np .testing .assert_allclose (np_pt , np_jax , rtol = 1e-3 , atol = 5e-3 )
0 commit comments