Skip to content

Commit 2a1d2ea

Browse files
committed
fix
1 parent f8dd177 commit 2a1d2ea

1 file changed

Lines changed: 55 additions & 95 deletions

File tree

src/maxdiffusion/tests/wan_vae_test.py

Lines changed: 55 additions & 95 deletions
Original file line number | Diff line number | Diff line change
@@ -46,6 +46,7 @@
4646
from ..models.wan.wan_utils import load_wan_vae
4747
from ..utils import load_video
4848
from ..video_processor import VideoProcessor
49+
from flax.linen import partitioning as nn_partitioning
4950

5051
THIS_DIR = os.path.dirname(os.path.abspath(__file__))
5152

@@ -160,6 +161,16 @@ class WanVaeTest(unittest.TestCase):
160161

161162
def setUp(self):
162163
WanVaeTest.dummy_data = {}
164+
pyconfig.initialize(
165+
[
166+
None,
167+
os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"),
168+
],
169+
unittest=True,
170+
)
171+
self.config = pyconfig.config
172+
devices_array = create_device_mesh(self.config)
173+
self.mesh = Mesh(devices_array, self.config.mesh_axes)
163174

164175
def test_wanrms_norm(self):
165176
"""Test against the Pytorch implementation"""
@@ -209,7 +220,8 @@ def test_zero_padded_conv(self):
209220
output_torch = resample(input)
210221
assert output_torch.shape == (1, 96, 240, 360)
211222

212-
model = ZeroPaddedConv2D(dim=dim, rngs=rngs, kernel_size=(1, 3, 3), stride=(1, 2, 2))
223+
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
224+
model = ZeroPaddedConv2D(dim=dim, rngs=rngs, kernel_size=(1, 3, 3), stride=(1, 2, 2))
213225
dummy_input = jnp.ones(input_shape)
214226
dummy_input = jnp.transpose(dummy_input, (0, 2, 3, 1))
215227
output = model(dummy_input)
@@ -247,7 +259,8 @@ def test_wan_resample(self):
247259
torch_output = torch_wan_resample(dummy_input)
248260
assert torch_output.shape == (batch, dim, t, h // 2, w // 2)
249261

250-
wan_resample = WanResample(dim, mode=mode, rngs=rngs)
262+
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
263+
wan_resample = WanResample(dim, mode=mode, rngs=rngs)
251264
# channels is always last here
252265
input_shape = (batch, t, h, w, dim)
253266
dummy_input = jnp.ones(input_shape)
@@ -257,16 +270,6 @@ def test_wan_resample(self):
257270
def test_3d_conv(self):
258271
key = jax.random.key(0)
259272
rngs = nnx.Rngs(key)
260-
pyconfig.initialize(
261-
[
262-
None,
263-
os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"),
264-
],
265-
unittest=True,
266-
)
267-
config = pyconfig.config
268-
devices_array = create_device_mesh(config)
269-
mesh = Mesh(devices_array, config.mesh_axes)
270273

271274
batch_size = 1
272275
in_depth, in_height, in_width = 10, 32, 32
@@ -283,14 +286,15 @@ def test_3d_conv(self):
283286
dummy_cache = jnp.zeros((batch_size, cache_depth, in_height, in_width, in_channels))
284287

285288
# Instantiate the module
286-
causal_conv_layer = WanCausalConv3d(
287-
in_channels=in_channels,
288-
out_channels=out_channels,
289-
kernel_size=(kernel_d, kernel_h, kernel_w),
290-
padding=(padding_d, padding_h, padding_w),
291-
rngs=rngs, # Pass rngs for initialization,
292-
mesh=mesh,
293-
)
289+
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
290+
causal_conv_layer = WanCausalConv3d(
291+
in_channels=in_channels,
292+
out_channels=out_channels,
293+
kernel_size=(kernel_d, kernel_h, kernel_w),
294+
padding=(padding_d, padding_h, padding_w),
295+
rngs=rngs, # Pass rngs for initialization,
296+
mesh=self.mesh,
297+
)
294298

295299
# --- Test Case 1: No Cache ---
296300
output_no_cache = causal_conv_layer(dummy_input)
@@ -309,16 +313,6 @@ def test_3d_conv(self):
309313
def test_wan_residual(self):
310314
key = jax.random.key(0)
311315
rngs = nnx.Rngs(key)
312-
pyconfig.initialize(
313-
[
314-
None,
315-
os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"),
316-
],
317-
unittest=True,
318-
)
319-
config = pyconfig.config
320-
devices_array = create_device_mesh(config)
321-
mesh = Mesh(devices_array, config.mesh_axes)
322316
# --- Test Case 1: same in/out dim ---
323317
in_dim = out_dim = 96
324318
batch = 1
@@ -329,7 +323,8 @@ def test_wan_residual(self):
329323
input_shape = (batch, t, height, width, dim)
330324
expected_output_shape = (batch, t, height, width, dim)
331325

332-
wan_residual_block = WanResidualBlock(in_dim=in_dim, out_dim=out_dim, rngs=rngs, mesh=mesh)
326+
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
327+
wan_residual_block = WanResidualBlock(in_dim=in_dim, out_dim=out_dim, rngs=rngs, mesh=self.mesh)
333328
dummy_input = jnp.ones(input_shape)
334329
dummy_output = wan_residual_block(dummy_input)
335330
assert dummy_output.shape == expected_output_shape
@@ -339,7 +334,8 @@ def test_wan_residual(self):
339334
out_dim = 196
340335
expected_output_shape = (batch, t, height, width, out_dim)
341336

342-
wan_residual_block = WanResidualBlock(in_dim=in_dim, out_dim=out_dim, rngs=rngs, mesh=mesh)
337+
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
338+
wan_residual_block = WanResidualBlock(in_dim=in_dim, out_dim=out_dim, rngs=rngs, mesh=self.mesh)
343339
dummy_input = jnp.ones(input_shape)
344340
dummy_output = wan_residual_block(dummy_input)
345341
assert dummy_output.shape == expected_output_shape
@@ -361,56 +357,38 @@ def test_wan_attention(self):
361357
def test_wan_midblock(self):
362358
key = jax.random.key(0)
363359
rngs = nnx.Rngs(key)
364-
pyconfig.initialize(
365-
[
366-
None,
367-
os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"),
368-
],
369-
unittest=True,
370-
)
371-
config = pyconfig.config
372-
devices_array = create_device_mesh(config)
373-
mesh = Mesh(devices_array, config.mesh_axes)
374360
batch = 1
375361
t = 1
376362
dim = 384
377363
height = 60
378364
width = 90
379365
input_shape = (batch, t, height, width, dim)
380-
wan_midblock = WanMidBlock(dim=dim, rngs=rngs, mesh=mesh)
366+
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
367+
wan_midblock = WanMidBlock(dim=dim, rngs=rngs, mesh=self.mesh)
381368
dummy_input = jnp.ones(input_shape)
382369
output = wan_midblock(dummy_input)
383370
assert output.shape == input_shape
384371

385372
def test_wan_decode(self):
386373
key = jax.random.key(0)
387374
rngs = nnx.Rngs(key)
388-
pyconfig.initialize(
389-
[
390-
None,
391-
os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"),
392-
],
393-
unittest=True,
394-
)
395-
config = pyconfig.config
396-
devices_array = create_device_mesh(config)
397-
mesh = Mesh(devices_array, config.mesh_axes)
398375
dim = 96
399376
z_dim = 16
400377
dim_mult = [1, 2, 4, 4]
401378
num_res_blocks = 2
402379
attn_scales = []
403380
temperal_downsample = [False, True, True]
404-
wan_vae = AutoencoderKLWan(
405-
rngs=rngs,
406-
base_dim=dim,
407-
z_dim=z_dim,
408-
dim_mult=dim_mult,
409-
num_res_blocks=num_res_blocks,
410-
attn_scales=attn_scales,
411-
temperal_downsample=temperal_downsample,
412-
mesh=mesh,
413-
)
381+
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
382+
wan_vae = AutoencoderKLWan(
383+
rngs=rngs,
384+
base_dim=dim,
385+
z_dim=z_dim,
386+
dim_mult=dim_mult,
387+
num_res_blocks=num_res_blocks,
388+
attn_scales=attn_scales,
389+
temperal_downsample=temperal_downsample,
390+
mesh=self.mesh,
391+
)
414392
vae_cache = AutoencoderKLWanCache(wan_vae)
415393
batch = 1
416394
t = 13
@@ -429,32 +407,23 @@ def test_wan_decode(self):
429407
def test_wan_encode(self):
430408
key = jax.random.key(0)
431409
rngs = nnx.Rngs(key)
432-
pyconfig.initialize(
433-
[
434-
None,
435-
os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"),
436-
],
437-
unittest=True,
438-
)
439-
config = pyconfig.config
440-
devices_array = create_device_mesh(config)
441-
mesh = Mesh(devices_array, config.mesh_axes)
442410
dim = 96
443411
z_dim = 16
444412
dim_mult = [1, 2, 4, 4]
445413
num_res_blocks = 2
446414
attn_scales = []
447415
temperal_downsample = [False, True, True]
448-
wan_vae = AutoencoderKLWan(
449-
rngs=rngs,
450-
base_dim=dim,
451-
z_dim=z_dim,
452-
dim_mult=dim_mult,
453-
num_res_blocks=num_res_blocks,
454-
attn_scales=attn_scales,
455-
temperal_downsample=temperal_downsample,
456-
mesh=mesh,
457-
)
416+
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
417+
wan_vae = AutoencoderKLWan(
418+
rngs=rngs,
419+
base_dim=dim,
420+
z_dim=z_dim,
421+
dim_mult=dim_mult,
422+
num_res_blocks=num_res_blocks,
423+
attn_scales=attn_scales,
424+
temperal_downsample=temperal_downsample,
425+
mesh=self.mesh,
426+
)
458427
vae_cache = AutoencoderKLWanCache(wan_vae)
459428
batch = 1
460429
channels = 3
@@ -474,18 +443,9 @@ def vae_encode(video, wan_vae, vae_cache, key):
474443

475444
key = jax.random.key(0)
476445
rngs = nnx.Rngs(key)
477-
pyconfig.initialize(
478-
[
479-
None,
480-
os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"),
481-
],
482-
unittest=True,
483-
)
484-
config = pyconfig.config
485-
devices_array = create_device_mesh(config)
486-
mesh = Mesh(devices_array, config.mesh_axes)
487446

488-
wan_vae = AutoencoderKLWan.from_config(config.pretrained_model_name_or_path, subfolder="vae", rngs=rngs, mesh=mesh)
447+
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
448+
wan_vae = AutoencoderKLWan.from_config(self.config.pretrained_model_name_or_path, subfolder="vae", rngs=rngs, mesh=self.mesh)
489449
vae_cache = AutoencoderKLWanCache(wan_vae)
490450
video_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4"
491451
video = load_video(video_path)
@@ -499,7 +459,7 @@ def vae_encode(video, wan_vae, vae_cache, key):
499459
graphdef, state = nnx.split(wan_vae)
500460
params = state.to_pure_dict()
501461
# This replaces random params with the model.
502-
params = load_wan_vae(config.pretrained_model_name_or_path, params, "cpu")
462+
params = load_wan_vae(self.config.pretrained_model_name_or_path, params, "cpu")
503463
params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
504464
wan_vae = nnx.merge(graphdef, params)
505465

0 commit comments

Comments
 (0)