Skip to content

Commit 2984677

Browse files
committed
applying decode fix
1 parent 7cfcce0 commit 2984677

1 file changed

Lines changed: 2 additions & 0 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -556,6 +556,8 @@ def _decode_latents_to_video(self, latents: jax.Array) -> np.ndarray:
556556

557557
video = jnp.transpose(video, (0, 4, 1, 2, 3))
558558
video = jax.experimental.multihost_utils.process_allgather(video, tiled=True)
559+
video = (video / 2.0) + 0.5
560+
video = jnp.clip(video, 0.0, 1.0)
559561
video = torch.from_numpy(np.array(video.astype(dtype=jnp.float32))).to(dtype=torch.bfloat16)
560562
return self.video_processor.postprocess_video(video, output_type="np")
561563

0 commit comments

Comments
 (0)