Skip to content

Commit ec6fd38

Browse files
committed
added vae_spatial
1 parent 4957a70 commit ec6fd38

1 file changed

Lines changed: 7 additions & 1 deletion

File tree

src/maxdiffusion/tests/wan_vae_test.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,13 @@ def test_3d_conv(self):
281281
)
282282
config = pyconfig.config
283283
devices_array = create_device_mesh(config)
284-
mesh = Mesh(devices_array, config.mesh_axes)
284+
# Add vae_spatial axis to mesh for VAE operations
285+
mesh_axes = list(config.mesh_axes)
286+
if "vae_spatial" not in mesh_axes:
287+
mesh_axes.append("vae_spatial")
288+
# Reshape devices to include vae_spatial (size 1 for test)
289+
devices_array = devices_array.reshape(devices_array.shape + (1,))
290+
mesh = Mesh(devices_array, mesh_axes)
285291

286292
batch_size = 1
287293
in_depth, in_height, in_width = 10, 32, 32

0 commit comments

Comments
 (0)