We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 6ec20ea commit fa69260Copy full SHA for fa69260
1 file changed
src/maxdiffusion/pipelines/wan/wan_pipeline.py
@@ -527,6 +527,7 @@ def __call__(
527
video = self.vae.decode(latents, self.vae_cache)[0]
528
529
video = jnp.transpose(video, (0, 4, 1, 2, 3))
530
+ video = jax.experimental.multihost_utils.process_allgather(video, tiled=True)
531
video = torch.from_numpy(np.array(video.astype(dtype=jnp.float32))).to(dtype=torch.bfloat16)
532
video = self.video_processor.postprocess_video(video, output_type="np")
533
return video
0 commit comments