Skip to content

Commit 14bee9e

Browse files
committed
Adding check for batch size divisibility before sharding video condition tensor
1 parent 5a05e75 commit 14bee9e

1 file changed

Lines changed: 5 additions & 2 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -544,8 +544,11 @@ def prepare_latents_i2v_base(
544544
vae_dtype = getattr(self.vae, "dtype", jnp.float32)
545545
video_condition = video_condition.astype(vae_dtype)
546546
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
547-
sharding_spec = P(self.config.mesh_axes[0], None, None, None, None)
548-
video_condition = jax.lax.with_sharding_constraint(video_condition, sharding_spec)
547+
data_axis_name = self.config.mesh_axes[0]
548+
data_mesh_size = self.mesh.shape[self.config.mesh_axes[0]]
549+
if video_condition.shape[0] % data_mesh_size == 0:
550+
sharding_spec = P(self.config.mesh_axes[0], None, None, None, None)
551+
video_condition = jax.lax.with_sharding_constraint(video_condition, sharding_spec)
549552
encoded_output = self.vae.encode(video_condition, self.vae_cache)[0].mode()
550553

551554
# Normalize latents

0 commit comments

Comments
 (0)