Skip to content

Commit d7431a9

Browse files
committed
Modified wan_utils.py, changed added_kv_proj_dim to 5120, errors corrected
1 parent 251e61f commit d7431a9

2 files changed

Lines changed: 6 additions & 2 deletions

File tree

src/maxdiffusion/models/wan/wan_utils.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -255,6 +255,10 @@ def load_base_wan_transformer(
255255
del flattened_dict
256256
for pt_key, tensor in tensors.items():
257257
renamed_pt_key = rename_key(pt_key)
258+
renamed_pt_key = renamed_pt_key.replace("image_embedder.ff.net.0", "image_embedder.ff.net_0")
259+
renamed_pt_key = renamed_pt_key.replace("image_embedder.ff.net.2", "image_embedder.ff.net_2")
260+
if "image_embedder.norm1.scale" in renamed_pt_key:
261+
renamed_pt_key = renamed_pt_key.replace("norm1.scale", "norm1.kernel")
258262
renamed_pt_key = renamed_pt_key.replace("blocks_", "blocks.")
259263
renamed_pt_key = renamed_pt_key.replace(".scale_shift_table", ".adaln_scale_shift_table")
260264
renamed_pt_key = renamed_pt_key.replace("to_out_0", "proj_attn")

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -108,7 +108,7 @@ def create_model(rngs: nnx.Rngs, wan_config: dict):
108108
else:
109109
wan_config = WanModel.load_config(config.pretrained_model_name_or_path, subfolder=subfolder)
110110
if config.model_type == "I2V":
111-
wan_config["added_kv_proj_dim"] = 1024
111+
wan_config["added_kv_proj_dim"] = 5120
112112
wan_config["mesh"] = mesh
113113
wan_config["dtype"] = config.activations_dtype
114114
wan_config["weights_dtype"] = config.weights_dtype
@@ -529,7 +529,7 @@ def prepare_latents_i2v_base(
529529
video_condition = video_condition.astype(vae_dtype)
530530

531531
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
532-
encoded_output = self.vae.encode(video_condition, self.vae_cache)[0]
532+
encoded_output = self.vae.encode(video_condition, self.vae_cache)[0].mode()
533533

534534
# Normalize latents
535535
latents_mean = jnp.array(self.vae.latents_mean).reshape(1, self.vae.z_dim, 1, 1, 1)

0 commit comments

Comments
 (0)