2323from flax import nnx
2424import numpy as np
2525import unittest
26- import pytest
2726from absl .testing import absltest
2827from skimage .metrics import structural_similarity as ssim
2928from ..models .wan .autoencoder_kl_wan import (
@@ -172,7 +171,7 @@ def test_wanrms_norm(self):
172171 dummy_input = jnp .ones (input_shape )
173172 output = wanrms_norm (dummy_input )
174173 output_np = np .array (output )
175- assert np .allclose (output_np , torch_output_np ) == True
174+ assert np .allclose (output_np , torch_output_np ) is True
176175
177176 # --- Test Case 2: images == False ---
178177 model = TorchWanRMS_norm (dim , images = False )
@@ -186,7 +185,7 @@ def test_wanrms_norm(self):
186185 dummy_input = jnp .ones (input_shape )
187186 output = wanrms_norm (dummy_input )
188187 output_np = np .array (output )
189- assert np .allclose (output_np , torch_output_np ) == True
188+ assert np .allclose (output_np , torch_output_np ) is True
190189
191190 def test_zero_padded_conv (self ):
192191
@@ -235,8 +234,6 @@ def test_wan_resample(self):
235234 w = 720
236235 mode = "downsample2d"
237236 input_shape = (batch , dim , t , h , w )
238- expected_output_shape = (1 , dim , 1 , 240 , 360 )
239- # output dim should be (1, 96, 1, 480, 720)
240237 dummy_input = torch .ones (input_shape )
241238 torch_wan_resample = TorchWanResample (dim = dim , mode = mode )
242239 torch_output = torch_wan_resample (dummy_input )
@@ -426,7 +423,7 @@ def vae_encode(video, wan_vae, vae_cache, key):
426423 rngs = nnx .Rngs (key )
427424 wan_vae = AutoencoderKLWan .from_config (pretrained_model_name_or_path , subfolder = "vae" , rngs = rngs )
428425 vae_cache = AutoencoderKLWanCache (wan_vae )
429- video_path , fps = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4" , 8
426+ video_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4"
430427 video = load_video (video_path )
431428
432429 vae_scale_factor_spatial = 2 ** len (wan_vae .temperal_downsample )
0 commit comments