Skip to content

Commit e7cd3c4

Browse files
committed
Renaming VAE sharding axis to vae_spatial
1 parent a0c377f commit e7cd3c4

9 files changed

Lines changed: 20 additions & 8 deletions

File tree

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ activations_dtype: 'bfloat16'
4343

4444
# Replicates the VAE across devices instead of using the model's sharding annotations.
4545
replicate_vae: False
46+
vae_spatial: -1 # number of devices for the VAE spatial sharding axis; -1 means use the default: total_device_count * 2 // data_parallelism
4647

4748
# matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision
4849
# Options are "DEFAULT", "HIGH", "HIGHEST"

src/maxdiffusion/configs/base_wan_27b.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ activations_dtype: 'bfloat16'
4343

4444
# Replicates the VAE across devices instead of using the model's sharding annotations.
4545
replicate_vae: False
46+
vae_spatial: -1 # number of devices for the VAE spatial sharding axis; -1 means use the default: total_device_count * 2 // data_parallelism
4647

4748
# matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision
4849
# Options are "DEFAULT", "HIGH", "HIGHEST"

src/maxdiffusion/generate_wan.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ def run(config, pipeline=None, filename_prefix=""):
166166
max_logging.log(f"hardware: {jax.devices()[0].platform}")
167167
max_logging.log(f"number of devices: {jax.device_count()}")
168168
max_logging.log(f"per_device_batch_size: {config.per_device_batch_size}")
169+
max_logging.log(f"vae_spatial: {config.vae_spatial}")
169170
max_logging.log("============================================================")
170171

171172
compile_time = time.perf_counter() - s0

src/maxdiffusion/models/attention_flax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ def _tpu_flash_attention(
255255
kv_max_block_size = key.shape[1]
256256
else:
257257
kv_max_block_size = q_max_block_size
258-
258+
259259
# ensure that for cross attention we override the block sizes.
260260
if flash_block_sizes and key.shape[1] == query.shape[1]:
261261
block_sizes = flash_block_sizes

src/maxdiffusion/models/vae_flax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from ..configuration_utils import ConfigMixin, flax_register_to_config
2929
from ..utils import BaseOutput
3030
from .modeling_flax_utils import FlaxModelMixin
31-
31+
3232

3333

3434
@flax.struct.dataclass

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -99,10 +99,10 @@ def __init__(
9999
self.mesh = mesh
100100

101101
# Weight sharding (Kernel is sharded along output channels)
102-
num_fsdp_devices = mesh.shape["fsdp"]
102+
num_fsdp_devices = mesh.shape["vae_spatial"]
103103
kernel_sharding = (None, None, None, None, None)
104104
if out_channels % num_fsdp_devices == 0:
105-
kernel_sharding = (None, None, None, None, "fsdp")
105+
kernel_sharding = (None, None, None, None, "vae_spatial")
106106

107107
self.conv = nnx.Conv(
108108
in_features=in_channels,
@@ -121,7 +121,7 @@ def __init__(
121121
def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None, idx=-1) -> jax.Array:
122122
# Sharding Width (index 3)
123123
# Spec: (Batch, Time, Height, Width, Channels)
124-
spatial_sharding = NamedSharding(self.mesh, P(None, None, None, "fsdp", None))
124+
spatial_sharding = NamedSharding(self.mesh, P(None, None, None, "vae_spatial", None))
125125
x = jax.lax.with_sharding_constraint(x, spatial_sharding)
126126

127127
current_padding = list(self._causal_padding)
@@ -1098,7 +1098,7 @@ def _encode(self, x: jax.Array, feat_cache: AutoencoderKLWanCache):
10981098
iter_ = 1 + (t - 1) // 4
10991099
enc_feat_map = feat_cache._enc_feat_map
11001100

1101-
spatial_sharding = NamedSharding(self.mesh, P(None, None, None, "fsdp", None))
1101+
spatial_sharding = NamedSharding(self.mesh, P(None, None, None, "vae_spatial", None))
11021102

11031103
# First iteration (i=0): size 1
11041104
chunk_0 = x[:, :1, ...]
@@ -1180,7 +1180,7 @@ def _decode(
11801180

11811181
dec_feat_map = feat_cache._feat_map
11821182
# NamedSharding for the Width axis (axis 3)
1183-
spatial_sharding = NamedSharding(self.mesh, P(None, None, None, "fsdp", None))
1183+
spatial_sharding = NamedSharding(self.mesh, P(None, None, None, "vae_spatial", None))
11841184

11851185
# First chunk (i=0)
11861186
chunk_in_0 = jax.lax.with_sharding_constraint(x[:, 0:1, ...], spatial_sharding)
@@ -1264,4 +1264,4 @@ def decode(
12641264
decoded = self._decode(z, feat_cache).sample
12651265
if not return_dict:
12661266
return (decoded,)
1267-
return FlaxDecoderOutput(sample=decoded)
1267+
return FlaxDecoderOutput(sample=decoded)

src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_t
5454
scheduler_state=common_components["scheduler_state"],
5555
devices_array=common_components["devices_array"],
5656
mesh=common_components["mesh"],
57+
vae_mesh=common_components["vae_mesh"],
5758
config=config,
5859
)
5960

src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_t
6363
scheduler_state=common_components["scheduler_state"],
6464
devices_array=common_components["devices_array"],
6565
mesh=common_components["mesh"],
66+
vae_mesh=common_components["vae_mesh"],
6667
config=config,
6768
)
6869
return pipeline, low_noise_transformer, high_noise_transformer

src/maxdiffusion/pyconfig.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,13 @@ def user_init(raw_keys):
248248
_HyperParameters.calculate_global_batch_sizes(raw_keys["per_device_batch_size"])
249249
)
250250

251+
if getattr(raw_keys, "vae_spatial", -1) == -1 or "vae_spatial" in raw_keys and raw_keys["vae_spatial"] == -1:
252+
total_device = len(jax.devices())
253+
dp = raw_keys.get("ici_data_parallelism", 1) * raw_keys.get("dcn_data_parallelism", 1)
254+
if dp == -1 or dp == 0:
255+
dp = 1
256+
raw_keys["vae_spatial"] = (total_device * 2) // dp
257+
251258

252259
def get_num_slices(raw_keys):
253260
if int(raw_keys["compile_topology_num_slices"]) > 0:

0 commit comments

Comments (0)