Skip to content

Commit a157b93

Browse files
committed
gather videos for multihost
1 parent f31d659 commit a157b93

1 file changed

Lines changed: 1 addition & 0 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,7 @@ def __call__(
457457
video = self.vae.decode(latents, self.vae_cache)[0]
458458

459459
video = jnp.transpose(video, (0, 4, 1, 2, 3))
460+
video = jax.experimental.multihost_utils.process_allgather(video, tiled=True)
460461
video = torch.from_numpy(np.array(video.astype(dtype=jnp.float32))).to(dtype=torch.bfloat16)
461462
video = self.video_processor.postprocess_video(video, output_type="np")
462463
return video

0 commit comments

Comments
 (0)