Commit c1fc264

moving vae logic to config files
1 parent 57aaf0d commit c1fc264

11 files changed: 87 additions & 142 deletions


dependencies/requirements/generated_requirements/requirements.txt

Lines changed: 2 additions & 0 deletions

@@ -2,6 +2,7 @@
 # If you need to modify dependencies, please do so in the host requirements file and run seed-env again.
 
 absl-py>=2.3.1
+accelerate>=1.13.0
 aiofiles>=25.1.0
 aiohappyeyeballs>=2.6.1
 aiohttp>=3.13.3
@@ -80,6 +81,7 @@ isort>=8.0.1
 jaraco-functools>=4.4.0
 jax>=0.9.0
 jaxlib>=0.9.0
+jaxopt>=0.8.5
 jaxtyping>=0.3.9
 jinja2>=3.1.6
 keras>=3.13.1

requirements.txt

Lines changed: 0 additions & 40 deletions
This file was deleted.

requirements_with_jax_ai_image.txt

Lines changed: 0 additions & 41 deletions
This file was deleted.

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 12 additions & 0 deletions

@@ -181,6 +181,18 @@ logical_axis_rules: [
   ['out_channels', 'tensor'],
   ['conv_out', 'context'],
 ]
+vae_logical_axis_rules: [
+  ['activation_batch', 'redundant'],
+  ['activation_length', 'vae_spatial'],
+  ['activation_heads', null],
+  ['activation_kv_length', null],
+  ['embed', null],
+  ['heads', null],
+  ['norm', null],
+  ['conv_batch', 'redundant'],
+  ['out_channels', 'vae_spatial'],
+  ['conv_out', 'vae_spatial'],
+]
 data_sharding: [['data', 'fsdp', 'context', 'tensor']]
 
 # One axis for each parallelism type may hold a placeholder (-1)
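The new vae_logical_axis_rules block maps each logical axis name used inside the VAE to a mesh axis ('redundant' or 'vae_spatial'); null means the axis is replicated rather than sharded. A minimal sketch of how such rules resolve to a PartitionSpec via Flax's logical partitioning helpers — the rule subset below is illustrative, not the full config:

# Sketch: resolving logical axis names to mesh axes with Flax's spmd helpers.
from flax import linen as nn

vae_logical_axis_rules = (
    ("activation_batch", "redundant"),
    ("activation_length", "vae_spatial"),
    ("embed", None),  # null in YAML -> None: this axis stays replicated
)

# A VAE activation annotated as (batch, length, embed) resolves to:
spec = nn.logical_to_mesh_axes(
    ("activation_batch", "activation_length", "embed"),
    vae_logical_axis_rules,
)
print(spec)  # PartitionSpec('redundant', 'vae_spatial', None)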

src/maxdiffusion/configs/base_wan_1_3b.yml

Lines changed: 13 additions & 1 deletion

@@ -1,4 +1,4 @@
-# Copyright 2023 Google LLC
+# Copyright 2023 Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -157,6 +157,18 @@ logical_axis_rules: [
   ['out_channels', 'tensor'],
   ['conv_out', 'context'],
 ]
+vae_logical_axis_rules: [
+  ['activation_batch', 'redundant'],
+  ['activation_length', 'vae_spatial'],
+  ['activation_heads', null],
+  ['activation_kv_length', null],
+  ['embed', null],
+  ['heads', null],
+  ['norm', null],
+  ['conv_batch', 'redundant'],
+  ['out_channels', 'vae_spatial'],
+  ['conv_out', 'vae_spatial'],
+]
 data_sharding: [['data', 'fsdp', 'context', 'tensor']]
 
 # One axis for each parallelism type may hold a placeholder (-1)

src/maxdiffusion/configs/base_wan_27b.yml

Lines changed: 12 additions & 0 deletions

@@ -169,6 +169,18 @@ logical_axis_rules: [
   ['out_channels', 'tensor'],
   ['conv_out', 'context'],
 ]
+vae_logical_axis_rules: [
+  ['activation_batch', 'redundant'],
+  ['activation_length', 'vae_spatial'],
+  ['activation_heads', null],
+  ['activation_kv_length', null],
+  ['embed', null],
+  ['heads', null],
+  ['norm', null],
+  ['conv_batch', 'redundant'],
+  ['out_channels', 'vae_spatial'],
+  ['conv_out', 'vae_spatial'],
+]
 data_sharding: [['data', 'fsdp', 'context', 'tensor']]
 
 # One axis for each parallelism type may hold a placeholder (-1)

src/maxdiffusion/configs/base_wan_i2v_14b.yml

Lines changed: 12 additions & 0 deletions

@@ -163,6 +163,18 @@ logical_axis_rules: [
   ['out_channels', 'tensor'],
   ['conv_out', 'context'],
 ]
+vae_logical_axis_rules: [
+  ['activation_batch', 'redundant'],
+  ['activation_length', 'vae_spatial'],
+  ['activation_heads', null],
+  ['activation_kv_length', null],
+  ['embed', null],
+  ['heads', null],
+  ['norm', null],
+  ['conv_batch', 'redundant'],
+  ['out_channels', 'vae_spatial'],
+  ['conv_out', 'vae_spatial'],
+]
 data_sharding: [['data', 'fsdp', 'context', 'tensor']]
 
 # One axis for each parallelism type may hold a placeholder (-1)

src/maxdiffusion/configs/base_wan_i2v_27b.yml

Lines changed: 12 additions & 0 deletions

@@ -164,6 +164,18 @@ logical_axis_rules: [
   ['out_channels', 'tensor'],
   ['conv_out', 'context'],
 ]
+vae_logical_axis_rules: [
+  ['activation_batch', 'redundant'],
+  ['activation_length', 'vae_spatial'],
+  ['activation_heads', null],
+  ['activation_kv_length', null],
+  ['embed', null],
+  ['heads', null],
+  ['norm', null],
+  ['conv_batch', 'redundant'],
+  ['out_channels', 'vae_spatial'],
+  ['conv_out', 'vae_spatial'],
+]
 data_sharding: [['data', 'fsdp', 'context', 'tensor']]
 
 # One axis for each parallelism type may hold a placeholder (-1)

src/maxdiffusion/models/wan/wan_utils.py

Lines changed: 8 additions & 48 deletions

@@ -212,33 +212,18 @@ def load_base_wan_transformer(
   device = jax.local_devices(backend=device)[0]
   filename = "diffusion_pytorch_model.safetensors.index.json"
   local_files = False
-
-  # Only rank 0 downloads; others wait for cache to be populated
-  process_index = jax.process_index()
   if os.path.isdir(pretrained_model_name_or_path):
     index_file_path = os.path.join(pretrained_model_name_or_path, subfolder, filename)
     if not os.path.isfile(index_file_path):
       raise FileNotFoundError(f"File {index_file_path} not found for local directory.")
     local_files = True
   elif hf_download:
-    # Only rank 0 downloads; synchronize across all ranks
-    if process_index == 0:
-      # download the index file for sharded models.
-      index_file_path = hf_hub_download(
-          pretrained_model_name_or_path,
-          subfolder=subfolder,
-          filename=filename,
-      )
-    jax.experimental.multihost_utils.sync_global_devices("model_index_download")
-
-    if process_index != 0:
-      # Non-rank-0 processes wait and use the cached path
-      index_file_path = hf_hub_download(
-          pretrained_model_name_or_path,
-          subfolder=subfolder,
-          filename=filename,
-          force_download=False,  # Use cache, don't download
-      )
+    # download the index file for sharded models.
+    index_file_path = hf_hub_download(
+        pretrained_model_name_or_path,
+        subfolder=subfolder,
+        filename=filename,
+    )
   with jax.default_device(device):
     # open the index file.
     with open(index_file_path, "r") as f:
@@ -253,19 +238,7 @@ def load_base_wan_transformer(
       if local_files:
         ckpt_shard_path = os.path.join(pretrained_model_name_or_path, subfolder, model_file)
       else:
-        # Only rank 0 downloads new files; others use cached versions
-        if process_index == 0:
-          ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=model_file)
-          jax.experimental.multihost_utils.sync_global_devices(f"model_download_{model_file}")
-
-        if process_index != 0:
-          # Non-rank-0: use cached version
-          ckpt_shard_path = hf_hub_download(
-              pretrained_model_name_or_path,
-              subfolder=subfolder,
-              filename=model_file,
-              force_download=False,  # Use cache
-          )
+        ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=model_file)
       # now get all the filenames for the model that need downloading
       max_logging.log(f"Load and port {pretrained_model_name_or_path} {subfolder} on {device}")
 
@@ -331,25 +304,12 @@ def load_wan_vae(pretrained_model_name_or_path: str, eval_shapes: dict, device:
   device = jax.devices(device)[0]
   subfolder = "vae"
   filename = "diffusion_pytorch_model.safetensors"
-  process_index = jax.process_index()
-
   if os.path.isdir(pretrained_model_name_or_path):
     ckpt_path = os.path.join(pretrained_model_name_or_path, subfolder, filename)
     if not os.path.isfile(ckpt_path):
       raise FileNotFoundError(f"File {ckpt_path} not found for local directory.")
   elif hf_download:
-    # Only rank 0 downloads; others use cache
-    if process_index == 0:
-      ckpt_path = hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=filename)
-      jax.experimental.multihost_utils.sync_global_devices("vae_download")
-
-    if process_index != 0:
-      ckpt_path = hf_hub_download(
-          pretrained_model_name_or_path,
-          subfolder=subfolder,
-          filename=filename,
-          force_download=False,  # Use cache
-      )
+    ckpt_path = hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=filename)
   max_logging.log(f"Load and port {pretrained_model_name_or_path} VAE on {device}")
   with jax.default_device(device):
     if ckpt_path is not None:
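With the rank-0 gating removed, every process calls hf_hub_download directly; the call is cache-backed, so a file already present locally resolves to its cached path without a re-download (this relies on huggingface_hub's cache and drops the explicit cross-host synchronization the deleted code provided). A minimal sketch of the index-then-shards flow the remaining code follows; the repo id is an illustrative example, not taken from this commit:

# Sketch of the sharded-checkpoint flow: fetch the index, then each shard.
import json
from huggingface_hub import hf_hub_download

repo = "Wan-AI/Wan2.1-T2V-14B-Diffusers"  # illustrative repo id

index_path = hf_hub_download(
    repo, subfolder="transformer",
    filename="diffusion_pytorch_model.safetensors.index.json",
)
with open(index_path, "r") as f:
    index = json.load(f)

# weight_map maps each tensor name to the shard file that holds it.
shard_files = sorted(set(index["weight_map"].values()))
for shard in shard_files:
    # Repeat calls with the same arguments return the cached local path.
    path = hf_hub_download(repo, subfolder="transformer", filename=shard)
    print(f"{shard} -> {path}")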

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 14 additions & 12 deletions

@@ -635,18 +635,20 @@ def _create_common_components(cls, config, vae_only=False, i2v=False):
     )
 
     # logical axis rules for VAE encoding/decoding
-    vae_logical_axis_rules = (
-        ("activation_batch", "redundant"),
-        ("activation_length", "vae_spatial"),
-        ("activation_heads", None),
-        ("activation_kv_length", None),
-        ("embed", None),
-        ("heads", None),
-        ("norm", None),
-        ("conv_batch", "redundant"),
-        ("out_channels", "vae_spatial"),
-        ("conv_out", "vae_spatial"),
-    )
+    vae_logical_axis_rules = getattr(config, "vae_logical_axis_rules", None)
+    if vae_logical_axis_rules is None:
+      vae_logical_axis_rules = (
+          ("activation_batch", "redundant"),
+          ("activation_length", "vae_spatial"),
+          ("activation_heads", None),
+          ("activation_kv_length", None),
+          ("embed", None),
+          ("heads", None),
+          ("norm", None),
+          ("conv_batch", "redundant"),
+          ("out_channels", "vae_spatial"),
+          ("conv_out", "vae_spatial"),
+      )
 
     rng = jax.random.key(config.seed)
     rngs = nnx.Rngs(rng)
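The pipeline now prefers config-supplied rules and keeps the previous hardcoded tuple as a fallback. One subtlety: YAML parsing yields nested lists with None for null, while the fallback is a tuple of tuples. A hypothetical normalization helper (not part of this commit) shows the two forms coerce to the same shape:

# Hypothetical helper: coerce YAML-sourced rules (lists, null -> None)
# to the tuple-of-tuples shape of the hardcoded fallback.
def normalize_rules(rules):
    return tuple((name, mesh_axis) for name, mesh_axis in rules)

yaml_rules = [["activation_batch", "redundant"], ["embed", None]]  # from config
fallback = (("activation_batch", "redundant"), ("embed", None))    # hardcoded

assert normalize_rules(yaml_rules) == fallback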
