diff --git a/src/maxdiffusion/models/wan/wan_utils.py b/src/maxdiffusion/models/wan/wan_utils.py
index 2ceb0f7e6..5a27591d6 100644
--- a/src/maxdiffusion/models/wan/wan_utils.py
+++ b/src/maxdiffusion/models/wan/wan_utils.py
@@ -57,8 +57,10 @@ def rename_for_custom_trasformer(key):
   return renamed_pt_key
 
 
-def load_fusionx_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
-  device = jax.devices(device)[0]
+def load_fusionx_transformer(
+    pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True, num_layers: int = 40
+):
+  device = jax.local_devices(backend=device)[0]
   with jax.default_device(device):
     if hf_download:
       ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, filename="Wan14BT2VFusioniX_fp16_.safetensors")
@@ -97,7 +99,7 @@ def load_fusionx_transformer(pretrained_model_name_or_path: str, eval_shapes: di
         if flax_key in flax_state_dict:
           new_tensor = flax_state_dict[flax_key]
         else:
-          new_tensor = jnp.zeros((40,) + flax_tensor.shape)
+          new_tensor = jnp.zeros((num_layers,) + flax_tensor.shape)
         flax_tensor = new_tensor.at[block_index].set(flax_tensor)
       flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
   validate_flax_state_dict(eval_shapes, flax_state_dict)
@@ -107,8 +109,10 @@ def load_fusionx_transformer(pretrained_model_name_or_path: str, eval_shapes: di
   return flax_state_dict
 
 
-def load_causvid_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
-  device = jax.devices(device)[0]
+def load_causvid_transformer(
+    pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True, num_layers: int = 40
+):
+  device = jax.local_devices(backend=device)[0]
   with jax.default_device(device):
     if hf_download:
       ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, filename="causal_model.pt")
@@ -145,7 +149,7 @@ def load_causvid_transformer(pretrained_model_name_or_path: str, eval_shapes: di
         if flax_key in flax_state_dict:
           new_tensor = flax_state_dict[flax_key]
         else:
-          new_tensor = jnp.zeros((40,) + flax_tensor.shape)
+          new_tensor = jnp.zeros((num_layers,) + flax_tensor.shape)
         flax_tensor = new_tensor.at[block_index].set(flax_tensor)
       flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
   validate_flax_state_dict(eval_shapes, flax_state_dict)
@@ -155,18 +159,22 @@ def load_causvid_transformer(pretrained_model_name_or_path: str, eval_shapes: di
   return flax_state_dict
 
 
-def load_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
+def load_wan_transformer(
+    pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True, num_layers: int = 40
+):
   if pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH:
-    return load_causvid_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download)
+    return load_causvid_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers)
   elif pretrained_model_name_or_path == WAN_21_FUSION_X_MODEL_NAME_OR_PATH:
-    return load_fusionx_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download)
+    return load_fusionx_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers)
   else:
-    return load_base_wan_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download)
+    return load_base_wan_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers)
 
 
 
-def load_base_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
-  device = jax.devices(device)[0]
+def load_base_wan_transformer(
+    pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True, num_layers: int = 40
+):
+  device = jax.local_devices(backend=device)[0]
   subfolder = "transformer"
   filename = "diffusion_pytorch_model.safetensors.index.json"
   local_files = False
@@ -237,7 +245,7 @@ def load_base_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: d
         if flax_key in flax_state_dict:
           new_tensor = flax_state_dict[flax_key]
         else:
-          new_tensor = jnp.zeros((40,) + flax_tensor.shape)
+          new_tensor = jnp.zeros((num_layers,) + flax_tensor.shape)
         flax_tensor = new_tensor.at[block_index].set(flax_tensor)
       flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
   validate_flax_state_dict(eval_shapes, flax_state_dict)
diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py
index abf449291..9ca2e03b9 100644
--- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py
+++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py
@@ -95,7 +95,9 @@ def create_model(rngs: nnx.Rngs, wan_config: dict):
   # 4. Load pretrained weights and move them to device using the state shardings from (3) above.
   # This helps with loading sharded weights directly into the accelerators without fist copying them
   # all to one device and then distributing them, thus using low HBM memory.
-  params = load_wan_transformer(config.wan_transformer_pretrained_model_name_or_path, params, "cpu")
+  params = load_wan_transformer(
+      config.wan_transformer_pretrained_model_name_or_path, params, "cpu", num_layers=wan_config["num_layers"]
+  )
   params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params)
   for path, val in flax.traverse_util.flatten_dict(params).items():
     sharding = logical_state_sharding[path].value
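Note for reviewers: a minimal, self-contained sketch of the block-stacking pattern this diff parameterizes. All concrete values below (layer count 30, block index 7, the 1536x1536 weight shape) are illustrative assumptions, not taken from this PR; the point is that the leading axis of the zeros buffer, previously hard-coded to 40, now follows `num_layers`.

```python
import jax.numpy as jnp

# Assumed values for illustration only. A checkpoint key such as
# "blocks.7.attn1.to_q.weight" would yield block_index = 7; num_layers
# must match the checkpoint's block count (40 was the old hard-coded value).
num_layers = 30
block_index = 7
flax_tensor = jnp.ones((1536, 1536))  # one block's weight, shape assumed

# The loaders allocate the stacked buffer once, then scatter each block's
# tensor into its slot. With the axis fixed at 40, a checkpoint with fewer
# blocks would produce stacked weights whose shapes do not match eval_shapes
# when validate_flax_state_dict runs.
stacked = jnp.zeros((num_layers,) + flax_tensor.shape)
stacked = stacked.at[block_index].set(flax_tensor)
assert stacked.shape == (num_layers, 1536, 1536)
```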