Skip to content

Commit a95c50f

Browse files
committed
Modified wan_utils.py, dim errors corrected
1 parent d7431a9 commit a95c50f

2 files changed

Lines changed: 7 additions & 6 deletions

File tree

src/maxdiffusion/models/wan/wan_utils.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -255,10 +255,11 @@ def load_base_wan_transformer(
255255
del flattened_dict
256256
for pt_key, tensor in tensors.items():
257257
renamed_pt_key = rename_key(pt_key)
258-
renamed_pt_key = renamed_pt_key.replace("image_embedder.ff.net.0", "image_embedder.ff.net_0")
259-
renamed_pt_key = renamed_pt_key.replace("image_embedder.ff.net.2", "image_embedder.ff.net_2")
260-
if "image_embedder.norm1.scale" in renamed_pt_key:
261-
renamed_pt_key = renamed_pt_key.replace("norm1.scale", "norm1.kernel")
258+
if "image_embedder" in renamed_pt_key:
259+
renamed_pt_key = renamed_pt_key.replace("net.0", "net_0")
260+
renamed_pt_key = renamed_pt_key.replace("net.2", "net_2")
261+
renamed_pt_key = renamed_pt_key.replace("norm1", "norm1.layer_norm")
262+
renamed_pt_key = renamed_pt_key.replace("norm2", "norm2.layer_norm")
262263
renamed_pt_key = renamed_pt_key.replace("blocks_", "blocks.")
263264
renamed_pt_key = renamed_pt_key.replace(".scale_shift_table", ".adaln_scale_shift_table")
264265
renamed_pt_key = renamed_pt_key.replace("to_out_0", "proj_attn")

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -532,8 +532,8 @@ def prepare_latents_i2v_base(
532532
encoded_output = self.vae.encode(video_condition, self.vae_cache)[0].mode()
533533

534534
# Normalize latents
535-
latents_mean = jnp.array(self.vae.latents_mean).reshape(1, self.vae.z_dim, 1, 1, 1)
536-
latents_std = 1.0 / jnp.array(self.vae.latents_std).reshape(1, self.vae.z_dim, 1, 1, 1)
535+
latents_mean = jnp.array(self.vae.latents_mean).reshape(1, 1, 1, 1, self.vae.z_dim)
536+
latents_std = 1.0 / jnp.array(self.vae.latents_std).reshape(1, 1, 1, 1, self.vae.z_dim)
537537
latent_condition = (encoded_output - latents_mean) * latents_std
538538
latent_condition = latent_condition.astype(dtype)
539539

0 commit comments

Comments
 (0)