Skip to content

Commit c00f53d

Browse files
committed
debug added for latents + nornalisation changed
1 parent 5f26c70 commit c00f53d

3 files changed

Lines changed: 36 additions & 2 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1291,6 +1291,8 @@ def scan_fn(carry, input_slice):
12911291
input_slice = jnp.expand_dims(input_slice, 1)
12921292
out_slice, new_carry = self.decoder(input_slice, carry)
12931293
out_swapped = out_slice[:, jnp.array([0, 2, 1, 3]), ...]
1294+
jax.debug.print("Decoder output shape: {shape}", shape=out_slice.shape)
1295+
jax.debug.print("After swap shape: {shape}", shape=out_swapped.shape)
12941296

12951297
return new_carry, out_swapped
12961298

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -532,11 +532,13 @@ def prepare_latents_i2v_base(
532532

533533
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
534534
encoded_output = self.vae.encode(video_condition)[0].mode()
535+
536+
encoded_output = jnp.transpose(encoded_output, (0, 2, 3, 4, 1))
535537

536538
# Normalize latents
537539
latents_mean = jnp.array(self.vae.latents_mean).reshape(1, 1, 1, 1, self.vae.z_dim)
538-
latents_std = 1.0 / jnp.array(self.vae.latents_std).reshape(1, 1, 1, 1, self.vae.z_dim)
539-
latent_condition = (encoded_output - latents_mean) * latents_std
540+
latents_std = jnp.array(self.vae.latents_std).reshape(1, 1, 1, 1, self.vae.z_dim)
541+
latent_condition = (encoded_output - latents_mean) / latents_std
540542
latent_condition = latent_condition.astype(dtype)
541543

542544
return latent_condition, video_condition

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,10 @@ def prepare_latents(
107107

108108
num_channels_latents = self.vae.z_dim
109109
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
110+
jax.debug.print("num_frames: {nf}, num_latent_frames: {nlf}, expected: {exp}",
111+
nf=num_frames,
112+
nlf=latents.shape[1],
113+
exp=num_latent_frames)
110114
latent_height = height // self.vae_scale_factor_spatial
111115
latent_width = width // self.vae_scale_factor_spatial
112116

@@ -124,6 +128,13 @@ def prepare_latents(
124128
mask_lat_size = mask_lat_size.at[:, :, 1:-1, :, :].set(0)
125129
first_frame_mask = mask_lat_size[:, :, 0:1]
126130
first_frame_mask = jnp.repeat(first_frame_mask, self.vae_scale_factor_temporal, axis=2)
131+
jax.debug.print("first_frame_mask.shape:{shape}, is None:{isnone}",
132+
shape = first_frame_mask.shape if first_frame_mask is not None else (-1,),
133+
isnone = first_frame_mask is None)
134+
jax.debug.print("first_frame_mask_stats: min={mn:.2f}, max={mx:.2f}, mean={mean:.2f}",
135+
mn=jnp.min(first_frame_mask) if first_frame_mask is not None else 0.0,
136+
mx=jnp.max(first_frame_mask) if first_frame_mask is not None else 0.0,
137+
mean=jnp.mean(first_frame_mask) if first_frame_mask is not None else 0.0)
127138
mask_lat_size = jnp.concatenate([first_frame_mask, mask_lat_size[:, :, 1:]], axis=2)
128139
mask_lat_size = mask_lat_size.reshape(
129140
batch_size,
@@ -135,6 +146,12 @@ def prepare_latents(
135146
)
136147
mask_lat_size = jnp.transpose(mask_lat_size, (0, 2, 4, 5, 3, 1)).squeeze(-1)
137148
condition = jnp.concatenate([mask_lat_size, latent_condition], axis=-1)
149+
jax.debug.print("condition shape: {shape}, channel dim: {c}",
150+
shape=condition.shape,
151+
c=condition.shape[-1])
152+
jax.debug.print("condition stats: mask_mean={mm:.4f}, latent_mean={lm:.4f}",
153+
mm=jnp.mean(condition[..., 0]),
154+
lm=jnp.mean(condition[..., 1:]))
138155

139156
return latents, condition, None
140157

@@ -300,11 +317,24 @@ def loop_body(step, vals):
300317
encoder_hidden_states_image=image_embeds_input,
301318
)
302319
noise_pred = jnp.transpose(noise_pred, (0, 2, 3, 4, 1))
320+
jax.debug.print("Step {s}: latents_prev std={std:.6f}, mean={mean:.6f}",
321+
s=step,
322+
std=jnp.std(latents),
323+
mean=jnp.mean(latents))
303324
latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
325+
jax.debug.print("Step {s}: latents_next std={std:.6f}, mean={mean:.6f}",
326+
s=step,
327+
std=jnp.std(latents),
328+
mean=jnp.mean(latents))
304329
latents = latents.astype(original_dtype)
305330
return latents, scheduler_state, rng
306331

307332
max_logging.log(f"Running fori_loop for {num_inference_steps} steps.")
308333
latents, _, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state, rng))
334+
jax.debug.print("Final latents states: min={lmin:.6f}, max={lmax:.6f}, mean={lmean:.6f}, std={lstd:.6f}",
335+
lmin=jnp.min(latents),
336+
lmax=jnp.max(latents),
337+
lmean=jnp.mean(latents),
338+
lstd=jnp.std(latents))
309339
max_logging.log("Finished fori_loop.")
310340
return latents

0 commit comments

Comments
 (0)