Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 21 additions & 13 deletions src/maxdiffusion/models/wan/wan_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,10 @@ def rename_for_custom_trasformer(key):
return renamed_pt_key


def load_fusionx_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
device = jax.devices(device)[0]
def load_fusionx_transformer(
pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True, num_layers: int = 40
):
device = jax.local_devices(backend=device)[0]
with jax.default_device(device):
if hf_download:
ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, filename="Wan14BT2VFusioniX_fp16_.safetensors")
Expand Down Expand Up @@ -97,7 +99,7 @@ def load_fusionx_transformer(pretrained_model_name_or_path: str, eval_shapes: di
if flax_key in flax_state_dict:
new_tensor = flax_state_dict[flax_key]
else:
new_tensor = jnp.zeros((40,) + flax_tensor.shape)
new_tensor = jnp.zeros((num_layers,) + flax_tensor.shape)
flax_tensor = new_tensor.at[block_index].set(flax_tensor)
flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
validate_flax_state_dict(eval_shapes, flax_state_dict)
Expand All @@ -107,8 +109,10 @@ def load_fusionx_transformer(pretrained_model_name_or_path: str, eval_shapes: di
return flax_state_dict


def load_causvid_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
device = jax.devices(device)[0]
def load_causvid_transformer(
pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True, num_layers: int = 40
):
device = jax.local_devices(backend=device)[0]
with jax.default_device(device):
if hf_download:
ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, filename="causal_model.pt")
Expand Down Expand Up @@ -145,7 +149,7 @@ def load_causvid_transformer(pretrained_model_name_or_path: str, eval_shapes: di
if flax_key in flax_state_dict:
new_tensor = flax_state_dict[flax_key]
else:
new_tensor = jnp.zeros((40,) + flax_tensor.shape)
new_tensor = jnp.zeros((num_layers,) + flax_tensor.shape)
flax_tensor = new_tensor.at[block_index].set(flax_tensor)
flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
validate_flax_state_dict(eval_shapes, flax_state_dict)
Expand All @@ -155,18 +159,22 @@ def load_causvid_transformer(pretrained_model_name_or_path: str, eval_shapes: di
return flax_state_dict


def load_wan_transformer(
    pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True, num_layers: int = 40
):
  """Load Wan transformer weights, dispatching on the pretrained model name.

  Routes to the checkpoint-specific loader for CausVid or FusionX fine-tunes,
  and falls back to the base Wan transformer loader for any other model id.

  Args:
    pretrained_model_name_or_path: HF repo id (or local path) of the checkpoint.
    eval_shapes: Expected flax state-dict shapes used to validate the load.
    device: JAX backend name (e.g. "cpu") on which to place weights while loading.
    hf_download: When True, fetch checkpoint files from the Hugging Face Hub.
    num_layers: Number of transformer blocks; used to size per-layer stacked
      arrays in the converted state dict (default 40 matches the 14B config).

  Returns:
    A flax state dict with weights converted from the PyTorch checkpoint.
  """
  if pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH:
    return load_causvid_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers)
  elif pretrained_model_name_or_path == WAN_21_FUSION_X_MODEL_NAME_OR_PATH:
    return load_fusionx_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers)
  else:
    return load_base_wan_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers)


def load_base_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
device = jax.devices(device)[0]
def load_base_wan_transformer(
pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True, num_layers: int = 40
):
device = jax.local_devices(backend=device)[0]
subfolder = "transformer"
filename = "diffusion_pytorch_model.safetensors.index.json"
local_files = False
Expand Down Expand Up @@ -237,7 +245,7 @@ def load_base_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: d
if flax_key in flax_state_dict:
new_tensor = flax_state_dict[flax_key]
else:
new_tensor = jnp.zeros((40,) + flax_tensor.shape)
new_tensor = jnp.zeros((num_layers,) + flax_tensor.shape)
flax_tensor = new_tensor.at[block_index].set(flax_tensor)
flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
validate_flax_state_dict(eval_shapes, flax_state_dict)
Expand Down
4 changes: 3 additions & 1 deletion src/maxdiffusion/pipelines/wan/wan_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,9 @@ def create_model(rngs: nnx.Rngs, wan_config: dict):
# 4. Load pretrained weights and move them to device using the state shardings from (3) above.
# This helps with loading sharded weights directly into the accelerators without fist copying them
# all to one device and then distributing them, thus using low HBM memory.
params = load_wan_transformer(config.wan_transformer_pretrained_model_name_or_path, params, "cpu")
params = load_wan_transformer(
config.wan_transformer_pretrained_model_name_or_path, params, "cpu", num_layers=wan_config["num_layers"]
)
params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params)
for path, val in flax.traverse_util.flatten_dict(params).items():
sharding = logical_state_sharding[path].value
Expand Down
Loading