Skip to content

Commit e075de4

Browse files
committed
Modified wan_utils.py, dim errors corrected
1 parent a95c50f commit e075de4

2 files changed

Lines changed: 6 additions & 3 deletions

File tree

src/maxdiffusion/models/wan/wan_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,8 @@ def load_base_wan_transformer(
259259
renamed_pt_key = renamed_pt_key.replace("net.0", "net_0")
260260
renamed_pt_key = renamed_pt_key.replace("net.2", "net_2")
261261
renamed_pt_key = renamed_pt_key.replace("norm1", "norm1.layer_norm")
262-
renamed_pt_key = renamed_pt_key.replace("norm2", "norm2.layer_norm")
262+
if "norm1" in renamed_pt_key or "norm2" in renamed_pt_key:
263+
renamed_pt_key = renamed_pt_key.replace("kernel", "scale")
263264
renamed_pt_key = renamed_pt_key.replace("blocks_", "blocks.")
264265
renamed_pt_key = renamed_pt_key.replace(".scale_shift_table", ".adaln_scale_shift_table")
265266
renamed_pt_key = renamed_pt_key.replace("to_out_0", "proj_attn")

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,15 +112,17 @@ def prepare_latents(
112112
if last_image is None:
113113
mask_lat_size = mask_lat_size.at[:, :, 1:, :, :].set(0)
114114
else:
115-
mask_lat_size = mask_lat_size.at[:, :, 1:-1, :, :].set(0)
115+
mask_lat_size = mask_lat_size.at[:, :, 1:-1, :, :].set(0)
116116
first_frame_mask = mask_lat_size[:, :, 0:1]
117117
first_frame_mask = jnp.repeat(first_frame_mask, self.vae_scale_factor_temporal, axis=2)
118118
mask_lat_size = jnp.concatenate([first_frame_mask, mask_lat_size[:, :, 1:]], axis=2)
119119
mask_lat_size = mask_lat_size.reshape(
120120
batch_size, -1, self.vae_scale_factor_temporal, latent_height, latent_width
121121
)
122122
mask_lat_size = jnp.swapaxes(mask_lat_size, 1, 2)
123-
condition = jnp.concatenate([mask_lat_size, latent_condition], axis=1)
123+
mask_lat_size = jnp.transpose(mask_lat_size, (0, 2, 3, 4, 1))
124+
condition = jnp.concatenate([mask_lat_size, latent_condition], axis=-1)
125+
124126
return latents, condition, None
125127

126128

0 commit comments

Comments
 (0)