Skip to content

Commit 23077de

Browse files
committed
sharding before call to encoder
1 parent 37361b1 commit 23077de

2 files changed

Lines changed: 5 additions & 5 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -134,11 +134,6 @@ def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None, idx=-1) ->
134134
else:
135135
x_padded = x
136136

137-
if self.mesh is not None:
138-
# (B, D, H, W, C)
139-
if x_padded.shape[0] % self.mesh.shape['data'] == 0:
140-
x_padded = with_sharding_constraint(x_padded, PartitionSpec('data', None, None, None, None))
141-
142137
out = self.conv(x_padded)
143138
return out
144139

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -529,6 +529,11 @@ def prepare_latents_i2v_base(
529529

530530
vae_dtype = getattr(self.vae, "dtype", jnp.float32)
531531
video_condition = video_condition.astype(vae_dtype)
532+
sharding_spec = P(self.config.mesh_axes[0], None, None, None, None) # typically P('data', None, ...)
533+
video_condition = jax.lax.with_sharding_constraint(
534+
video_condition,
535+
sharding_spec
536+
)
532537

533538
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
534539
encoded_output = self.vae.encode(video_condition, self.vae_cache)[0].mode()

0 commit comments

Comments
 (0)