Skip to content

Commit 673f533

Browse files
committed
adding vae spatial axis to conv_in
1 parent 67bcef8 commit 673f533

6 files changed

Lines changed: 16 additions & 11 deletions

File tree

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -192,6 +192,7 @@ vae_logical_axis_rules: [
192192
['conv_batch', 'redundant'],
193193
['out_channels', 'vae_spatial'],
194194
['conv_out', 'vae_spatial'],
195+
['conv_in', 'vae_spatial'],
195196
]
196197
data_sharding: [['data', 'fsdp', 'context', 'tensor']]
197198

src/maxdiffusion/configs/base_wan_1_3b.yml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -168,6 +168,7 @@ vae_logical_axis_rules: [
168168
['conv_batch', 'redundant'],
169169
['out_channels', 'vae_spatial'],
170170
['conv_out', 'vae_spatial'],
171+
['conv_in', 'vae_spatial'],
171172
]
172173
data_sharding: [['data', 'fsdp', 'context', 'tensor']]
173174

src/maxdiffusion/configs/base_wan_27b.yml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -180,6 +180,7 @@ vae_logical_axis_rules: [
180180
['conv_batch', 'redundant'],
181181
['out_channels', 'vae_spatial'],
182182
['conv_out', 'vae_spatial'],
183+
['conv_in', 'vae_spatial'],
183184
]
184185
data_sharding: [['data', 'fsdp', 'context', 'tensor']]
185186

src/maxdiffusion/configs/base_wan_i2v_14b.yml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -174,6 +174,7 @@ vae_logical_axis_rules: [
174174
['conv_batch', 'redundant'],
175175
['out_channels', 'vae_spatial'],
176176
['conv_out', 'vae_spatial'],
177+
['conv_in', 'vae_spatial'],
177178
]
178179
data_sharding: [['data', 'fsdp', 'context', 'tensor']]
179180

src/maxdiffusion/configs/base_wan_i2v_27b.yml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -175,6 +175,7 @@ vae_logical_axis_rules: [
175175
['conv_batch', 'redundant'],
176176
['out_channels', 'vae_spatial'],
177177
['conv_out', 'vae_spatial'],
178+
['conv_in', 'vae_spatial'],
178179
]
179180
data_sharding: [['data', 'fsdp', 'context', 'tensor']]
180181

src/maxdiffusion/tests/wan_vae_test.py

Lines changed: 11 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -21,8 +21,8 @@
2121
import jax
2222
import jax.numpy as jnp
2323
from flax import nnx
24-
from flax import linen as nn
2524
from flax.linen import partitioning as nn_partitioning
25+
from flax.linen import logical_to_mesh_sharding
2626
from jax.sharding import Mesh
2727
from .. import pyconfig
2828
from ..max_utils import (
@@ -224,7 +224,7 @@ def test_zero_padded_conv(self):
224224
output_torch = resample(input)
225225
assert output_torch.shape == (1, 96, 240, 360)
226226

227-
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
227+
with self.mesh, nn_partitioning.axis_rules(self.config.vae_logical_axis_rules):
228228
model = ZeroPaddedConv2D(dim=dim, rngs=rngs, kernel_size=(1, 3, 3), stride=(1, 2, 2))
229229
dummy_input = jnp.ones(input_shape)
230230
dummy_input = jnp.transpose(dummy_input, (0, 2, 3, 1))
@@ -262,7 +262,7 @@ def test_wan_resample(self):
262262
torch_wan_resample = TorchWanResample(dim=dim, mode=mode)
263263
torch_output = torch_wan_resample(dummy_input)
264264
assert torch_output.shape == (batch, dim, t, h // 2, w // 2)
265-
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
265+
with self.mesh, nn_partitioning.axis_rules(self.config.vae_logical_axis_rules):
266266
wan_resample = WanResample(dim, mode=mode, rngs=rngs)
267267
# channels is always last here
268268
input_shape = (batch, t, h, w, dim)
@@ -305,7 +305,7 @@ def test_3d_conv(self):
305305
dummy_cache = jnp.zeros((batch_size, cache_depth, in_height, in_width, in_channels))
306306

307307
# Instantiate the module
308-
with self.mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
308+
with self.mesh, nn_partitioning.axis_rules(config.vae_logical_axis_rules):
309309
causal_conv_layer = WanCausalConv3d(
310310
in_channels=in_channels,
311311
out_channels=out_channels,
@@ -357,7 +357,7 @@ def test_wan_residual(self):
357357
dim = 96
358358
input_shape = (batch, t, height, width, dim)
359359
expected_output_shape = (batch, t, height, width, dim)
360-
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
360+
with mesh, nn_partitioning.axis_rules(config.vae_logical_axis_rules):
361361
wan_residual_block = WanResidualBlock(in_dim=in_dim, out_dim=out_dim, rngs=rngs, mesh=mesh)
362362
dummy_input = jnp.ones(input_shape)
363363
dummy_output, _, _ = wan_residual_block(dummy_input)
@@ -381,7 +381,7 @@ def test_wan_attention(self):
381381
height = 60
382382
width = 90
383383
input_shape = (batch, t, height, width, dim)
384-
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
384+
with self.mesh, nn_partitioning.axis_rules(self.config.vae_logical_axis_rules):
385385
wan_attention = WanAttentionBlock(dim=dim, rngs=rngs)
386386
dummy_input = jnp.ones(input_shape)
387387
output, _, _ = wan_attention(dummy_input)
@@ -412,7 +412,7 @@ def test_wan_midblock(self):
412412
height = 60
413413
width = 90
414414
input_shape = (batch, t, height, width, dim)
415-
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
415+
with mesh, nn_partitioning.axis_rules(config.vae_logical_axis_rules):
416416
wan_midblock = WanMidBlock(dim=dim, rngs=rngs, mesh=mesh)
417417
dummy_input = jnp.ones(input_shape)
418418
output, _, _ = wan_midblock(dummy_input)
@@ -443,7 +443,7 @@ def test_wan_decode(self):
443443
num_res_blocks = 2
444444
attn_scales = []
445445
temperal_downsample = [False, True, True]
446-
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
446+
with mesh, nn_partitioning.axis_rules(config.vae_logical_axis_rules):
447447
wan_vae = AutoencoderKLWan(
448448
rngs=rngs,
449449
base_dim=dim,
@@ -494,7 +494,7 @@ def test_wan_encode(self):
494494
num_res_blocks = 2
495495
attn_scales = []
496496
temperal_downsample = [False, True, True]
497-
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
497+
with mesh, nn_partitioning.axis_rules(config.vae_logical_axis_rules):
498498
wan_vae = AutoencoderKLWan(
499499
rngs=rngs,
500500
base_dim=dim,
@@ -540,7 +540,7 @@ def vae_encode(video, wan_vae, vae_cache, key):
540540
# Reshape devices to include vae_spatial (size 1 for test)
541541
devices_array = devices_array.reshape(devices_array.shape + (1,))
542542
mesh = Mesh(devices_array, mesh_axes)
543-
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
543+
with self.mesh, nn_partitioning.axis_rules(self.config.vae_logical_axis_rules):
544544
wan_vae = AutoencoderKLWan.from_config(config.pretrained_model_name_or_path, subfolder="vae", rngs=rngs, mesh=mesh)
545545
vae_cache = AutoencoderKLWanCache(wan_vae)
546546
video_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4"
@@ -559,7 +559,7 @@ def vae_encode(video, wan_vae, vae_cache, key):
559559
params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
560560

561561
logical_state_spec = nnx.get_partition_spec(state)
562-
logical_state_sharding = nn.logical_to_mesh_sharding(logical_state_spec, mesh, config.logical_axis_rules)
562+
logical_state_sharding = logical_to_mesh_sharding(logical_state_spec, mesh, config.vae_logical_axis_rules)
563563
logical_state_sharding = dict(nnx.to_flat_state(logical_state_sharding))
564564

565565
state_flat = dict(nnx.to_flat_state(state))

0 commit comments

Comments (0)