@@ -340,7 +340,13 @@ def test_wan_residual(self):
340340 )
341341 config = pyconfig .config
342342 devices_array = create_device_mesh (config )
343- mesh = Mesh (devices_array , config .mesh_axes )
343+ # Add vae_spatial axis to mesh for VAE operations
344+ mesh_axes = list (config .mesh_axes )
345+ if "vae_spatial" not in mesh_axes :
346+ mesh_axes .append ("vae_spatial" )
347+ # Reshape devices to include vae_spatial (size 1 for test)
348+ devices_array = devices_array .reshape (devices_array .shape + (1 ,))
349+ mesh = Mesh (devices_array , mesh_axes )
344350 # --- Test Case 1: same in/out dim ---
345351 in_dim = out_dim = 96
346352 batch = 1
@@ -392,7 +398,13 @@ def test_wan_midblock(self):
392398 )
393399 config = pyconfig .config
394400 devices_array = create_device_mesh (config )
395- mesh = Mesh (devices_array , config .mesh_axes )
401+ # Add vae_spatial axis to mesh for VAE operations
402+ mesh_axes = list (config .mesh_axes )
403+ if "vae_spatial" not in mesh_axes :
404+ mesh_axes .append ("vae_spatial" )
405+ # Reshape devices to include vae_spatial (size 1 for test)
406+ devices_array = devices_array .reshape (devices_array .shape + (1 ,))
407+ mesh = Mesh (devices_array , mesh_axes )
396408 batch = 1
397409 t = 1
398410 dim = 384
@@ -417,7 +429,13 @@ def test_wan_decode(self):
417429 )
418430 config = pyconfig .config
419431 devices_array = create_device_mesh (config )
420- mesh = Mesh (devices_array , config .mesh_axes )
432+ # Add vae_spatial axis to mesh for VAE operations
433+ mesh_axes = list (config .mesh_axes )
434+ if "vae_spatial" not in mesh_axes :
435+ mesh_axes .append ("vae_spatial" )
436+ # Reshape devices to include vae_spatial (size 1 for test)
437+ devices_array = devices_array .reshape (devices_array .shape + (1 ,))
438+ mesh = Mesh (devices_array , mesh_axes )
421439 dim = 96
422440 z_dim = 16
423441 dim_mult = [1 , 2 , 4 , 4 ]
@@ -462,7 +480,13 @@ def test_wan_encode(self):
462480 )
463481 config = pyconfig .config
464482 devices_array = create_device_mesh (config )
465- mesh = Mesh (devices_array , config .mesh_axes )
483+ # Add vae_spatial axis to mesh for VAE operations
484+ mesh_axes = list (config .mesh_axes )
485+ if "vae_spatial" not in mesh_axes :
486+ mesh_axes .append ("vae_spatial" )
487+ # Reshape devices to include vae_spatial (size 1 for test)
488+ devices_array = devices_array .reshape (devices_array .shape + (1 ,))
489+ mesh = Mesh (devices_array , mesh_axes )
466490 dim = 96
467491 z_dim = 16
468492 dim_mult = [1 , 2 , 4 , 4 ]
@@ -508,7 +532,13 @@ def vae_encode(video, wan_vae, vae_cache, key):
508532 )
509533 config = pyconfig .config
510534 devices_array = create_device_mesh (config )
511- mesh = Mesh (devices_array , config .mesh_axes )
535+ # Add vae_spatial axis to mesh for VAE operations
536+ mesh_axes = list (config .mesh_axes )
537+ if "vae_spatial" not in mesh_axes :
538+ mesh_axes .append ("vae_spatial" )
539+ # Reshape devices to include vae_spatial (size 1 for test)
540+ devices_array = devices_array .reshape (devices_array .shape + (1 ,))
541+ mesh = Mesh (devices_array , mesh_axes )
512542 with self .mesh , nn_partitioning .axis_rules (self .config .logical_axis_rules ):
513543 wan_vae = AutoencoderKLWan .from_config (config .pretrained_model_name_or_path , subfolder = "vae" , rngs = rngs , mesh = mesh )
514544 vae_cache = AutoencoderKLWanCache (wan_vae )
0 commit comments