Skip to content

Commit 8a752e7

Browse files
committed
added model_name in config file
1 parent 33bf49c commit 8a752e7

3 files changed

Lines changed: 5 additions & 4 deletions

File tree

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ save_config_to_gcs: False
2828
log_period: 100
2929

3030
pretrained_model_name_or_path: 'Wan-AI/Wan2.1-T2V-14B-Diffusers'
31+
model_name: wan2.1
3132

3233
# Overrides the transformer from pretrained_model_name_or_path
3334
wan_transformer_pretrained_model_name_or_path: ''

src/maxdiffusion/configs/base_wan_27b.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ save_config_to_gcs: False
2828
log_period: 100
2929

3030
pretrained_model_name_or_path: 'Wan-AI/Wan2.2-T2V-A14B-Diffusers'
31-
run_wan2_2: True
31+
model_name: wan2.2
3232

3333
# Overrides the transformer from pretrained_model_name_or_path
3434
wan_transformer_pretrained_model_name_or_path: ''

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ def __init__(
213213
self.devices_array = devices_array
214214
self.mesh = mesh
215215
self.config = config
216-
self.run_wan2_2 = config.run_wan2_2 if 'run_wan2_2' in self.config.__dict__ else False
216+
self.run_wan2_2 = config.model_name == "wan2.2"
217217

218218
self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
219219
self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
@@ -379,7 +379,7 @@ def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_
379379
mesh = Mesh(devices_array, config.mesh_axes)
380380
rng = jax.random.key(config.seed)
381381
rngs = nnx.Rngs(rng)
382-
run_wan2_2 = config.run_wan2_2 if 'run_wan2_2' in config.__dict__ else False
382+
run_wan2_2 = config.model_name == "wan2.2"
383383
low_noise_transformer = None
384384
high_noise_transformer = None
385385
tokenizer = None
@@ -421,7 +421,7 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transform
421421
mesh = Mesh(devices_array, config.mesh_axes)
422422
rng = jax.random.key(config.seed)
423423
rngs = nnx.Rngs(rng)
424-
run_wan2_2 = config.run_wan2_2 if 'run_wan2_2' in config.__dict__ else False
424+
run_wan2_2 = config.model_name == "wan2.2"
425425
low_noise_transformer = None
426426
high_noise_transformer = None
427427
tokenizer = None

0 commit comments

Comments
 (0)