Skip to content

Commit 1ca1aa8

Browse files
Committed by Kunjan Patel
Add logical axis to mesh mapping in init
Signed-off-by: Kunjan Patel <kunjanp@google.com>
1 parent 073d831 commit 1ca1aa8

9 files changed

Lines changed: 33 additions & 15 deletions

File tree

src/maxdiffusion/configs/base_2_base.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ write_timing_metrics: True
2727
save_config_to_gcs: False
2828
log_period: 10000000000 # Flushes Tensorboard
2929

30+
tmp_dir: '/tmp' # directory for downloading gs:// files
31+
3032
pretrained_model_name_or_path: 'stabilityai/stable-diffusion-2-base'
3133
unet_checkpoint: ''
3234
revision: 'main'

src/maxdiffusion/configs/base_flux_dev.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ gcs_metrics: False
2727
save_config_to_gcs: False
2828
log_period: 100
2929

30+
tmp_dir: '/tmp' # directory for downloading gs:// files
31+
3032
pretrained_model_name_or_path: 'black-forest-labs/FLUX.1-dev'
3133
clip_model_name_or_path: 'ariG23498/clip-vit-large-patch14-text-flax'
3234
t5xxl_model_name_or_path: 'ariG23498/t5-v1-1-xxl-flax'

src/maxdiffusion/configs/base_flux_dev_multi_res.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ write_timing_metrics: True
2727
save_config_to_gcs: False
2828
log_period: 100
2929

30+
tmp_dir: '/tmp' # directory for downloading gs:// files
31+
3032
pretrained_model_name_or_path: 'black-forest-labs/FLUX.1-dev'
3133
clip_model_name_or_path: 'ariG23498/clip-vit-large-patch14-text-flax'
3234
t5xxl_model_name_or_path: 'ariG23498/t5-v1-1-xxl-flax'

src/maxdiffusion/configs/base_flux_schnell.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ write_timing_metrics: True
2727
save_config_to_gcs: False
2828
log_period: 100
2929

30+
tmp_dir: '/tmp' # directory for downloading gs:// files
31+
3032
pretrained_model_name_or_path: 'black-forest-labs/FLUX.1-schnell'
3133
clip_model_name_or_path: 'ariG23498/clip-vit-large-patch14-text-flax'
3234
t5xxl_model_name_or_path: 'ariG23498/t5-v1-1-xxl-flax'

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ gcs_metrics: False
2727
save_config_to_gcs: False
2828
log_period: 100
2929

30+
tmp_dir: '/tmp' # directory for downloading gs:// files
31+
3032
pretrained_model_name_or_path: 'Wan-AI/Wan2.1-T2V-14B-Diffusers'
3133

3234
# Overrides the transformer from pretrained_model_name_or_path
@@ -151,6 +153,7 @@ logical_axis_rules: [
151153
['conv_batch', ['data','fsdp']],
152154
['out_channels', 'tensor'],
153155
['conv_out', 'fsdp'],
156+
['conv_in', ''] # not sharded
154157
]
155158
data_sharding: [['data', 'fsdp', 'tensor']]
156159

src/maxdiffusion/configs/base_xl.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ gcs_metrics: False
2727
save_config_to_gcs: False
2828
log_period: 100
2929

30+
tmp_dir: '/tmp' # directory for downloading gs:// files
31+
3032
pretrained_model_name_or_path: 'stabilityai/stable-diffusion-xl-base-1.0'
3133
unet_checkpoint: ''
3234
revision: 'refs/pr/95'

src/maxdiffusion/configs/base_xl_lightning.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ write_timing_metrics: True
2626
save_config_to_gcs: False
2727
log_period: 100
2828

29+
tmp_dir: '/tmp' # directory for downloading gs:// files
30+
2931
pretrained_model_name_or_path: 'stabilityai/stable-diffusion-xl-base-1.0'
3032
unet_checkpoint: ''
3133
revision: 'refs/pr/95'

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,9 @@ def create_sharded_logical_transformer(
7171
):
7272

7373
def create_model(rngs: nnx.Rngs, wan_config: dict):
74-
wan_transformer = WanModel(**wan_config, rngs=rngs)
75-
return wan_transformer
74+
with nn_partitioning.axis_rules(config.logical_axis_rules):
75+
wan_transformer = WanModel(**wan_config, rngs=rngs)
76+
return wan_transformer
7677

7778
# 1. Load config.
7879
if restored_checkpoint:
@@ -204,15 +205,16 @@ def load_tokenizer(cls, config: HyperParameters):
204205
def load_vae(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters):
205206

206207
def create_model(rngs: nnx.Rngs, config: HyperParameters):
207-
wan_vae = AutoencoderKLWan.from_config(
208-
config.pretrained_model_name_or_path,
209-
subfolder="vae",
210-
rngs=rngs,
211-
mesh=mesh,
212-
dtype=config.activations_dtype,
213-
weights_dtype=config.weights_dtype,
214-
)
215-
return wan_vae
208+
with nn_partitioning.axis_rules(config.logical_axis_rules):
209+
wan_vae = AutoencoderKLWan.from_config(
210+
config.pretrained_model_name_or_path,
211+
subfolder="vae",
212+
rngs=rngs,
213+
mesh=mesh,
214+
dtype=config.activations_dtype,
215+
weights_dtype=config.weights_dtype,
216+
)
217+
return wan_vae
216218

217219
# 1. eval shape
218220
p_model_factory = partial(create_model, config=config)

src/maxdiffusion/pyconfig.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -196,14 +196,15 @@ def user_init(raw_keys):
196196

197197
# Orbax doesn't save the tokenizer params, instead it loads them from the pretrained_model_name_or_path
198198
raw_keys["tokenizer_model_name_or_path"] = raw_keys["pretrained_model_name_or_path"]
199+
tmp_dir = raw_keys.get("tmp_dir", "/tmp")
199200
if "gs://" in raw_keys["tokenizer_model_name_or_path"]:
200-
raw_keys["pretrained_model_name_or_path"] = max_utils.download_blobs(raw_keys["pretrained_model_name_or_path"], "/tmp")
201+
raw_keys["pretrained_model_name_or_path"] = max_utils.download_blobs(raw_keys["pretrained_model_name_or_path"], tmp_dir)
201202
if "gs://" in raw_keys["pretrained_model_name_or_path"]:
202-
raw_keys["pretrained_model_name_or_path"] = max_utils.download_blobs(raw_keys["pretrained_model_name_or_path"], "/tmp")
203+
raw_keys["pretrained_model_name_or_path"] = max_utils.download_blobs(raw_keys["pretrained_model_name_or_path"], tmp_dir)
203204
if "gs://" in raw_keys["unet_checkpoint"]:
204-
raw_keys["unet_checkpoint"] = max_utils.download_blobs(raw_keys["unet_checkpoint"], "/tmp")
205+
raw_keys["unet_checkpoint"] = max_utils.download_blobs(raw_keys["unet_checkpoint"], tmp_dir)
205206
if "gs://" in raw_keys["tokenizer_model_name_or_path"]:
206-
raw_keys["tokenizer_model_name_or_path"] = max_utils.download_blobs(raw_keys["tokenizer_model_name_or_path"], "/tmp")
207+
raw_keys["tokenizer_model_name_or_path"] = max_utils.download_blobs(raw_keys["tokenizer_model_name_or_path"], tmp_dir)
207208
if "gs://" in raw_keys["dataset_name"]:
208209
raw_keys["dataset_name"] = max_utils.download_blobs(raw_keys["dataset_name"], raw_keys["dataset_save_location"])
209210
raw_keys["dataset_save_location"] = raw_keys["dataset_name"]

0 commit comments

Comments (0)