11"""
2- Copyright 2025 Google LLC
2+ Copyright 2026 Google LLC
33
44Licensed under the Apache License, Version 2.0 (the "License");
55you may not use this file except in compliance with the License.
1616
1717import os
1818import torch
19- import torch .nn as nn
20- import torch .nn .functional as F
2119import numpy as np
2220import jax .numpy as jnp
2321from flax import nnx
24- import math
25- import sys
26- import pytest
2722
2823# Import from the codebase
2924# Make sure to set Python path or import correctly
4237from jax .sharding import Mesh
4338
4439
45-
46-
47-
4840def transfer_conv_weights (pt_conv , jax_conv ):
4941 if hasattr (jax_conv , "weight" ):
5042 jax_conv .weight [...] = jnp .array (pt_conv .weight .detach ().numpy ())
@@ -227,8 +219,6 @@ def transfer_transformer_weights(pt_model, jax_model):
227219 jax_model .scale_shift_table [...] = jnp .array (pt_model .scale_shift_table .detach ().numpy ())
228220
229221
230-
231-
232222class TestWanAnimateTransformer :
233223
234224 def test_motion_conv_equivalence (self ):
@@ -244,6 +234,7 @@ def test_motion_conv_equivalence(self):
244234 x_jax = jnp .array (x_np )
245235
246236 from diffusers .models .transformers .transformer_wan_animate import MotionConv2d
237+
247238 pt_model = MotionConv2d (
248239 in_channels = C_in ,
249240 out_channels = C_out ,
@@ -697,5 +688,3 @@ def test_equivalence_wan_animate_transformer(self):
697688
698689 assert np_pt .shape == np_jax .shape
699690 np .testing .assert_allclose (np_pt , np_jax , rtol = 1e-4 , atol = 1e-4 )
700-
701-
0 commit comments