@@ -52,7 +52,7 @@ def transfer_conv_weights(pt_conv, jax_conv):
5252 else :
5353 jax_conv .kernel [...] = jnp .array (pt_conv .weight .detach ().numpy ().transpose (2 , 3 , 1 , 0 ))
5454 if pt_conv .bias is not None :
55- jax_conv .bias . value = jnp .array (pt_conv .bias .detach ().numpy ())
55+ jax_conv .bias [...] = jnp .array (pt_conv .bias .detach ().numpy ())
5656
5757
5858def transfer_linear_weights (pt_linear , jax_linear ):
@@ -67,7 +67,7 @@ def transfer_linear_weights(pt_linear, jax_linear):
6767 elif hasattr (jax_linear , "kernel" ):
6868 jax_linear .kernel [...] = jnp .array (pt_linear .weight .detach ().numpy ().T )
6969 if pt_linear .bias is not None :
70- jax_linear .bias . value = jnp .array (pt_linear .bias .detach ().numpy ())
70+ jax_linear .bias [...] = jnp .array (pt_linear .bias .detach ().numpy ())
7171
7272
7373def transfer_transformer_weights (pt_model , jax_model ):
@@ -273,50 +273,50 @@ def test_motion_conv_equivalence(self):
273273 max_diff = np .max (np .abs (np_out_pt - np_out_jax ))
274274 print (f"Max absolute difference: { max_diff :.8f} " )
275275
276- assert np .allclose (np_out_pt , np_out_jax , atol = 1e-5 ), f"Outputs do not match! max_diff={ max_diff } "
276+ np .testing . assert_allclose (np_out_pt , np_out_jax , rtol = 1e-3 , atol = 5e-3 , err_msg = f"Outputs do not match! max_diff={ max_diff } " )
277277
278278 def test_fused_leaky_relu_shape (self ):
279279 rngs = nnx .Rngs (0 )
280280 x = jnp .ones ((2 , 4 , 16 , 16 ))
281281 model = FlaxFusedLeakyReLU (rngs = rngs , bias_channels = 4 )
282282 out = model (x )
283- assert out .shape == x .shape
283+ np . testing . assert_equal ( out .shape , x .shape )
284284
285285 def test_motion_linear_shape (self ):
286286 rngs = nnx .Rngs (0 )
287287 x = jnp .ones ((2 , 4 ))
288288 model = FlaxMotionLinear (rngs = rngs , in_dim = 4 , out_dim = 8 )
289289 out = model (x )
290- assert out .shape == (2 , 8 )
290+ np . testing . assert_equal ( out .shape , (2 , 8 ) )
291291
292292 def test_motion_encoder_res_block_shape (self ):
293293 rngs = nnx .Rngs (0 )
294294 x = jnp .ones ((2 , 4 , 16 , 16 ))
295295 model = FlaxMotionEncoderResBlock (rngs = rngs , in_channels = 4 , out_channels = 8 )
296296 out = model (x )
297- assert out .shape == (2 , 8 , 8 , 8 )
297+ np . testing . assert_equal ( out .shape , (2 , 8 , 8 , 8 ) )
298298
299299 def test_wan_animate_motion_encoder_shape (self ):
300300 rngs = nnx .Rngs (0 )
301301 x = jnp .ones ((2 , 3 , 512 , 512 )) # size size
302302 model = FlaxWanAnimateMotionEncoder (rngs = rngs , size = 512 , style_dim = 512 , motion_dim = 20 , out_dim = 512 )
303303 out = model (x )
304- assert out .shape == (2 , 512 )
304+ np . testing . assert_equal ( out .shape , (2 , 512 ) )
305305
306306 def test_wan_animate_face_encoder_shape (self ):
307307 rngs = nnx .Rngs (0 )
308308 x = jnp .ones ((2 , 10 , 512 )) # Batch, Time, Dim
309309 model = FlaxWanAnimateFaceEncoder (rngs = rngs , in_dim = 512 , out_dim = 512 , num_heads = 4 )
310310 out = model (x )
311- assert out .shape == (2 , 3 , 5 , 512 )
311+ np . testing . assert_equal ( out .shape , (2 , 3 , 5 , 512 ) )
312312
313313 def test_wan_animate_face_block_cross_attention_shape (self ):
314314 rngs = nnx .Rngs (0 )
315315 hidden_states = jnp .ones ((2 , 10 , 512 )) # B, Q_len, Dim
316316 encoder_hidden_states = jnp .ones ((2 , 1 , 5 , 512 )) # B, T, N, Dim
317317 model = FlaxWanAnimateFaceBlockCrossAttention (rngs = rngs , dim = 512 , heads = 8 )
318318 out = model (hidden_states , encoder_hidden_states )
319- assert out .shape == hidden_states .shape
319+ np . testing . assert_equal ( out .shape , hidden_states .shape )
320320
321321 def test_nnx_wan_animate_transformer_3d_model_shape (self ):
322322 rngs = nnx .Rngs (0 )
@@ -370,7 +370,7 @@ def test_nnx_wan_animate_transformer_3d_model_shape(self):
370370 )
371371 if isinstance (out , (list , tuple )):
372372 out = out [0 ]
373- assert out .shape == (batch_size , 16 , num_frames , height , width )
373+ np . testing . assert_equal ( out .shape , (batch_size , 16 , num_frames , height , width ) )
374374
375375 def test_nnx_wan_animate_transformer_3d_model_shape_with_face (self ):
376376 rngs = nnx .Rngs (0 )
@@ -424,7 +424,7 @@ def test_nnx_wan_animate_transformer_3d_model_shape_with_face(self):
424424 )
425425 if isinstance (out , (list , tuple )):
426426 out = out [0 ]
427- assert out .shape == (batch_size , 16 , num_frames , height , width )
427+ np . testing . assert_equal ( out .shape , (batch_size , 16 , num_frames , height , width ) )
428428
429429 def test_equivalence_motion_encoder (self ):
430430 from diffusers .models .transformers .transformer_wan_animate import (
@@ -686,5 +686,5 @@ def test_equivalence_wan_animate_transformer(self):
686686 np_pt = pt_out .detach ().numpy ()
687687 np_jax = np .array (jax_out )
688688
689- assert np_pt .shape == np_jax .shape
689+ np . testing . assert_equal ( np_pt .shape , np_jax .shape )
690690 np .testing .assert_allclose (np_pt , np_jax , rtol = 1e-4 , atol = 1e-4 )
0 commit comments