4646from ..models .wan .wan_utils import load_wan_vae
4747from ..utils import load_video
4848from ..video_processor import VideoProcessor
49+ from flax .linen import partitioning as nn_partitioning
4950
5051THIS_DIR = os .path .dirname (os .path .abspath (__file__ ))
5152
@@ -160,6 +161,16 @@ class WanVaeTest(unittest.TestCase):
160161
161162 def setUp (self ):
162163 WanVaeTest .dummy_data = {}
164+ pyconfig .initialize (
165+ [
166+ None ,
167+ os .path .join (THIS_DIR , ".." , "configs" , "base_wan_14b.yml" ),
168+ ],
169+ unittest = True ,
170+ )
171+ self .config = pyconfig .config
172+ devices_array = create_device_mesh (self .config )
173+ self .mesh = Mesh (devices_array , self .config .mesh_axes )
163174
164175 def test_wanrms_norm (self ):
165176 """Test against the Pytorch implementation"""
@@ -209,7 +220,8 @@ def test_zero_padded_conv(self):
209220 output_torch = resample (input )
210221 assert output_torch .shape == (1 , 96 , 240 , 360 )
211222
212- model = ZeroPaddedConv2D (dim = dim , rngs = rngs , kernel_size = (1 , 3 , 3 ), stride = (1 , 2 , 2 ))
223+ with self .mesh , nn_partitioning .axis_rules (self .config .logical_axis_rules ):
224+ model = ZeroPaddedConv2D (dim = dim , rngs = rngs , kernel_size = (1 , 3 , 3 ), stride = (1 , 2 , 2 ))
213225 dummy_input = jnp .ones (input_shape )
214226 dummy_input = jnp .transpose (dummy_input , (0 , 2 , 3 , 1 ))
215227 output = model (dummy_input )
@@ -247,7 +259,8 @@ def test_wan_resample(self):
247259 torch_output = torch_wan_resample (dummy_input )
248260 assert torch_output .shape == (batch , dim , t , h // 2 , w // 2 )
249261
250- wan_resample = WanResample (dim , mode = mode , rngs = rngs )
262+ with self .mesh , nn_partitioning .axis_rules (self .config .logical_axis_rules ):
263+ wan_resample = WanResample (dim , mode = mode , rngs = rngs )
251264 # channels is always last here
252265 input_shape = (batch , t , h , w , dim )
253266 dummy_input = jnp .ones (input_shape )
@@ -257,16 +270,6 @@ def test_wan_resample(self):
257270 def test_3d_conv (self ):
258271 key = jax .random .key (0 )
259272 rngs = nnx .Rngs (key )
260- pyconfig .initialize (
261- [
262- None ,
263- os .path .join (THIS_DIR , ".." , "configs" , "base_wan_14b.yml" ),
264- ],
265- unittest = True ,
266- )
267- config = pyconfig .config
268- devices_array = create_device_mesh (config )
269- mesh = Mesh (devices_array , config .mesh_axes )
270273
271274 batch_size = 1
272275 in_depth , in_height , in_width = 10 , 32 , 32
@@ -283,14 +286,15 @@ def test_3d_conv(self):
283286 dummy_cache = jnp .zeros ((batch_size , cache_depth , in_height , in_width , in_channels ))
284287
285288 # Instantiate the module
286- causal_conv_layer = WanCausalConv3d (
287- in_channels = in_channels ,
288- out_channels = out_channels ,
289- kernel_size = (kernel_d , kernel_h , kernel_w ),
290- padding = (padding_d , padding_h , padding_w ),
291- rngs = rngs , # Pass rngs for initialization,
292- mesh = mesh ,
293- )
289+ with self .mesh , nn_partitioning .axis_rules (self .config .logical_axis_rules ):
290+ causal_conv_layer = WanCausalConv3d (
291+ in_channels = in_channels ,
292+ out_channels = out_channels ,
293+ kernel_size = (kernel_d , kernel_h , kernel_w ),
294+ padding = (padding_d , padding_h , padding_w ),
295+ rngs = rngs , # Pass rngs for initialization,
296+ mesh = self .mesh ,
297+ )
294298
295299 # --- Test Case 1: No Cache ---
296300 output_no_cache = causal_conv_layer (dummy_input )
@@ -309,16 +313,6 @@ def test_3d_conv(self):
309313 def test_wan_residual (self ):
310314 key = jax .random .key (0 )
311315 rngs = nnx .Rngs (key )
312- pyconfig .initialize (
313- [
314- None ,
315- os .path .join (THIS_DIR , ".." , "configs" , "base_wan_14b.yml" ),
316- ],
317- unittest = True ,
318- )
319- config = pyconfig .config
320- devices_array = create_device_mesh (config )
321- mesh = Mesh (devices_array , config .mesh_axes )
322316 # --- Test Case 1: same in/out dim ---
323317 in_dim = out_dim = 96
324318 batch = 1
@@ -329,7 +323,8 @@ def test_wan_residual(self):
329323 input_shape = (batch , t , height , width , dim )
330324 expected_output_shape = (batch , t , height , width , dim )
331325
332- wan_residual_block = WanResidualBlock (in_dim = in_dim , out_dim = out_dim , rngs = rngs , mesh = mesh )
326+ with self .mesh , nn_partitioning .axis_rules (self .config .logical_axis_rules ):
327+ wan_residual_block = WanResidualBlock (in_dim = in_dim , out_dim = out_dim , rngs = rngs , mesh = self .mesh )
333328 dummy_input = jnp .ones (input_shape )
334329 dummy_output = wan_residual_block (dummy_input )
335330 assert dummy_output .shape == expected_output_shape
@@ -339,7 +334,8 @@ def test_wan_residual(self):
339334 out_dim = 196
340335 expected_output_shape = (batch , t , height , width , out_dim )
341336
342- wan_residual_block = WanResidualBlock (in_dim = in_dim , out_dim = out_dim , rngs = rngs , mesh = mesh )
337+ with self .mesh , nn_partitioning .axis_rules (self .config .logical_axis_rules ):
338+ wan_residual_block = WanResidualBlock (in_dim = in_dim , out_dim = out_dim , rngs = rngs , mesh = self .mesh )
343339 dummy_input = jnp .ones (input_shape )
344340 dummy_output = wan_residual_block (dummy_input )
345341 assert dummy_output .shape == expected_output_shape
@@ -361,56 +357,38 @@ def test_wan_attention(self):
361357 def test_wan_midblock (self ):
362358 key = jax .random .key (0 )
363359 rngs = nnx .Rngs (key )
364- pyconfig .initialize (
365- [
366- None ,
367- os .path .join (THIS_DIR , ".." , "configs" , "base_wan_14b.yml" ),
368- ],
369- unittest = True ,
370- )
371- config = pyconfig .config
372- devices_array = create_device_mesh (config )
373- mesh = Mesh (devices_array , config .mesh_axes )
374360 batch = 1
375361 t = 1
376362 dim = 384
377363 height = 60
378364 width = 90
379365 input_shape = (batch , t , height , width , dim )
380- wan_midblock = WanMidBlock (dim = dim , rngs = rngs , mesh = mesh )
366+ with self .mesh , nn_partitioning .axis_rules (self .config .logical_axis_rules ):
367+ wan_midblock = WanMidBlock (dim = dim , rngs = rngs , mesh = self .mesh )
381368 dummy_input = jnp .ones (input_shape )
382369 output = wan_midblock (dummy_input )
383370 assert output .shape == input_shape
384371
385372 def test_wan_decode (self ):
386373 key = jax .random .key (0 )
387374 rngs = nnx .Rngs (key )
388- pyconfig .initialize (
389- [
390- None ,
391- os .path .join (THIS_DIR , ".." , "configs" , "base_wan_14b.yml" ),
392- ],
393- unittest = True ,
394- )
395- config = pyconfig .config
396- devices_array = create_device_mesh (config )
397- mesh = Mesh (devices_array , config .mesh_axes )
398375 dim = 96
399376 z_dim = 16
400377 dim_mult = [1 , 2 , 4 , 4 ]
401378 num_res_blocks = 2
402379 attn_scales = []
403380 temperal_downsample = [False , True , True ]
404- wan_vae = AutoencoderKLWan (
405- rngs = rngs ,
406- base_dim = dim ,
407- z_dim = z_dim ,
408- dim_mult = dim_mult ,
409- num_res_blocks = num_res_blocks ,
410- attn_scales = attn_scales ,
411- temperal_downsample = temperal_downsample ,
412- mesh = mesh ,
413- )
381+ with self .mesh , nn_partitioning .axis_rules (self .config .logical_axis_rules ):
382+ wan_vae = AutoencoderKLWan (
383+ rngs = rngs ,
384+ base_dim = dim ,
385+ z_dim = z_dim ,
386+ dim_mult = dim_mult ,
387+ num_res_blocks = num_res_blocks ,
388+ attn_scales = attn_scales ,
389+ temperal_downsample = temperal_downsample ,
390+ mesh = self .mesh ,
391+ )
414392 vae_cache = AutoencoderKLWanCache (wan_vae )
415393 batch = 1
416394 t = 13
@@ -429,32 +407,23 @@ def test_wan_decode(self):
429407 def test_wan_encode (self ):
430408 key = jax .random .key (0 )
431409 rngs = nnx .Rngs (key )
432- pyconfig .initialize (
433- [
434- None ,
435- os .path .join (THIS_DIR , ".." , "configs" , "base_wan_14b.yml" ),
436- ],
437- unittest = True ,
438- )
439- config = pyconfig .config
440- devices_array = create_device_mesh (config )
441- mesh = Mesh (devices_array , config .mesh_axes )
442410 dim = 96
443411 z_dim = 16
444412 dim_mult = [1 , 2 , 4 , 4 ]
445413 num_res_blocks = 2
446414 attn_scales = []
447415 temperal_downsample = [False , True , True ]
448- wan_vae = AutoencoderKLWan (
449- rngs = rngs ,
450- base_dim = dim ,
451- z_dim = z_dim ,
452- dim_mult = dim_mult ,
453- num_res_blocks = num_res_blocks ,
454- attn_scales = attn_scales ,
455- temperal_downsample = temperal_downsample ,
456- mesh = mesh ,
457- )
416+ with self .mesh , nn_partitioning .axis_rules (self .config .logical_axis_rules ):
417+ wan_vae = AutoencoderKLWan (
418+ rngs = rngs ,
419+ base_dim = dim ,
420+ z_dim = z_dim ,
421+ dim_mult = dim_mult ,
422+ num_res_blocks = num_res_blocks ,
423+ attn_scales = attn_scales ,
424+ temperal_downsample = temperal_downsample ,
425+ mesh = self .mesh ,
426+ )
458427 vae_cache = AutoencoderKLWanCache (wan_vae )
459428 batch = 1
460429 channels = 3
@@ -474,18 +443,9 @@ def vae_encode(video, wan_vae, vae_cache, key):
474443
475444 key = jax .random .key (0 )
476445 rngs = nnx .Rngs (key )
477- pyconfig .initialize (
478- [
479- None ,
480- os .path .join (THIS_DIR , ".." , "configs" , "base_wan_14b.yml" ),
481- ],
482- unittest = True ,
483- )
484- config = pyconfig .config
485- devices_array = create_device_mesh (config )
486- mesh = Mesh (devices_array , config .mesh_axes )
487446
488- wan_vae = AutoencoderKLWan .from_config (config .pretrained_model_name_or_path , subfolder = "vae" , rngs = rngs , mesh = mesh )
447+ with self .mesh , nn_partitioning .axis_rules (self .config .logical_axis_rules ):
448+ wan_vae = AutoencoderKLWan .from_config (self .config .pretrained_model_name_or_path , subfolder = "vae" , rngs = rngs , mesh = self .mesh )
489449 vae_cache = AutoencoderKLWanCache (wan_vae )
490450 video_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4"
491451 video = load_video (video_path )
@@ -499,7 +459,7 @@ def vae_encode(video, wan_vae, vae_cache, key):
499459 graphdef , state = nnx .split (wan_vae )
500460 params = state .to_pure_dict ()
501461 # This replaces random params with the model.
502- params = load_wan_vae (config .pretrained_model_name_or_path , params , "cpu" )
462+ params = load_wan_vae (self . config .pretrained_model_name_or_path , params , "cpu" )
503463 params = jax .tree_util .tree_map (lambda x : x .astype (jnp .bfloat16 ), params )
504464 wan_vae = nnx .merge (graphdef , params )
505465
0 commit comments