Skip to content

Commit f46674b

Browse files
committed
changed vae test logic
1 parent ec6fd38 commit f46674b

2 files changed

Lines changed: 36 additions & 10 deletions

File tree

src/maxdiffusion/pyconfig.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -257,11 +257,7 @@ def user_init(raw_keys):
257257
) = _HyperParameters.calculate_global_batch_sizes(raw_keys["per_device_batch_size"])
258258

259259
if raw_keys.get("vae_spatial", -1) == -1:
260-
total_device = len(jax.devices())
261-
dp = raw_keys.get("ici_data_parallelism", 1) * raw_keys.get("dcn_data_parallelism", 1)
262-
if dp == -1 or dp == 0:
263-
dp = 1
264-
raw_keys["vae_spatial"] = (total_device * 2) // dp
260+
raw_keys["vae_spatial"] = 1
265261

266262

267263
def get_num_slices(raw_keys):

src/maxdiffusion/tests/wan_vae_test.py

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,13 @@ def test_wan_residual(self):
340340
)
341341
config = pyconfig.config
342342
devices_array = create_device_mesh(config)
343-
mesh = Mesh(devices_array, config.mesh_axes)
343+
# Add vae_spatial axis to mesh for VAE operations
344+
mesh_axes = list(config.mesh_axes)
345+
if "vae_spatial" not in mesh_axes:
346+
mesh_axes.append("vae_spatial")
347+
# Reshape devices to include vae_spatial (size 1 for test)
348+
devices_array = devices_array.reshape(devices_array.shape + (1,))
349+
mesh = Mesh(devices_array, mesh_axes)
344350
# --- Test Case 1: same in/out dim ---
345351
in_dim = out_dim = 96
346352
batch = 1
@@ -392,7 +398,13 @@ def test_wan_midblock(self):
392398
)
393399
config = pyconfig.config
394400
devices_array = create_device_mesh(config)
395-
mesh = Mesh(devices_array, config.mesh_axes)
401+
# Add vae_spatial axis to mesh for VAE operations
402+
mesh_axes = list(config.mesh_axes)
403+
if "vae_spatial" not in mesh_axes:
404+
mesh_axes.append("vae_spatial")
405+
# Reshape devices to include vae_spatial (size 1 for test)
406+
devices_array = devices_array.reshape(devices_array.shape + (1,))
407+
mesh = Mesh(devices_array, mesh_axes)
396408
batch = 1
397409
t = 1
398410
dim = 384
@@ -417,7 +429,13 @@ def test_wan_decode(self):
417429
)
418430
config = pyconfig.config
419431
devices_array = create_device_mesh(config)
420-
mesh = Mesh(devices_array, config.mesh_axes)
432+
# Add vae_spatial axis to mesh for VAE operations
433+
mesh_axes = list(config.mesh_axes)
434+
if "vae_spatial" not in mesh_axes:
435+
mesh_axes.append("vae_spatial")
436+
# Reshape devices to include vae_spatial (size 1 for test)
437+
devices_array = devices_array.reshape(devices_array.shape + (1,))
438+
mesh = Mesh(devices_array, mesh_axes)
421439
dim = 96
422440
z_dim = 16
423441
dim_mult = [1, 2, 4, 4]
@@ -462,7 +480,13 @@ def test_wan_encode(self):
462480
)
463481
config = pyconfig.config
464482
devices_array = create_device_mesh(config)
465-
mesh = Mesh(devices_array, config.mesh_axes)
483+
# Add vae_spatial axis to mesh for VAE operations
484+
mesh_axes = list(config.mesh_axes)
485+
if "vae_spatial" not in mesh_axes:
486+
mesh_axes.append("vae_spatial")
487+
# Reshape devices to include vae_spatial (size 1 for test)
488+
devices_array = devices_array.reshape(devices_array.shape + (1,))
489+
mesh = Mesh(devices_array, mesh_axes)
466490
dim = 96
467491
z_dim = 16
468492
dim_mult = [1, 2, 4, 4]
@@ -508,7 +532,13 @@ def vae_encode(video, wan_vae, vae_cache, key):
508532
)
509533
config = pyconfig.config
510534
devices_array = create_device_mesh(config)
511-
mesh = Mesh(devices_array, config.mesh_axes)
535+
# Add vae_spatial axis to mesh for VAE operations
536+
mesh_axes = list(config.mesh_axes)
537+
if "vae_spatial" not in mesh_axes:
538+
mesh_axes.append("vae_spatial")
539+
# Reshape devices to include vae_spatial (size 1 for test)
540+
devices_array = devices_array.reshape(devices_array.shape + (1,))
541+
mesh = Mesh(devices_array, mesh_axes)
512542
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
513543
wan_vae = AutoencoderKLWan.from_config(config.pretrained_model_name_or_path, subfolder="vae", rngs=rngs, mesh=mesh)
514544
vae_cache = AutoencoderKLWanCache(wan_vae)

0 commit comments

Comments
 (0)