@@ -213,7 +213,7 @@ def __init__(
213213 self .devices_array = devices_array
214214 self .mesh = mesh
215215 self .config = config
216- self .run_wan2_2 = config .run_wan2_2 if 'run_wan2_2' in self . config . __dict__ else False
216+ self .run_wan2_2 = config .model_name == "wan2.2"
217217
218218 self .vae_scale_factor_temporal = 2 ** sum (self .vae .temperal_downsample ) if getattr (self , "vae" , None ) else 4
219219 self .vae_scale_factor_spatial = 2 ** len (self .vae .temperal_downsample ) if getattr (self , "vae" , None ) else 8
@@ -379,7 +379,7 @@ def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_
379379 mesh = Mesh (devices_array , config .mesh_axes )
380380 rng = jax .random .key (config .seed )
381381 rngs = nnx .Rngs (rng )
382- run_wan2_2 = config .run_wan2_2 if 'run_wan2_2' in config . __dict__ else False
382+ run_wan2_2 = config .model_name == "wan2.2"
383383 low_noise_transformer = None
384384 high_noise_transformer = None
385385 tokenizer = None
@@ -421,7 +421,7 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transform
421421 mesh = Mesh (devices_array , config .mesh_axes )
422422 rng = jax .random .key (config .seed )
423423 rngs = nnx .Rngs (rng )
424- run_wan2_2 = config .run_wan2_2 if 'run_wan2_2' in config . __dict__ else False
424+ run_wan2_2 = config .model_name == "wan2.2"
425425 low_noise_transformer = None
426426 high_noise_transformer = None
427427 tokenizer = None
0 commit comments