Skip to content

Commit fa69260

Browse files
authored
gather videos for multihost (#229)
1 parent 6ec20ea commit fa69260

1 file changed

Lines changed: 1 addition & 0 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -527,6 +527,7 @@ def __call__(
527527
video = self.vae.decode(latents, self.vae_cache)[0]
528528

529529
video = jnp.transpose(video, (0, 4, 1, 2, 3))
530+
video = jax.experimental.multihost_utils.process_allgather(video, tiled=True)
530531
video = torch.from_numpy(np.array(video.astype(dtype=jnp.float32))).to(dtype=torch.bfloat16)
531532
video = self.video_processor.postprocess_video(video, output_type="np")
532533
return video

0 commit comments

Comments
 (0)