Skip to content

Commit 0144d4b

Browse files
committed
print statement added
1 parent 55d7252 commit 0144d4b

1 file changed

Lines changed: 4 additions & 0 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,12 @@ def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None, idx=-1) ->
123123

124124
if self.mesh is not None:
125125
# (B, D, H, W, C)
126+
print(f"DEBUG: Checking sharding logic. Shape: {x_padded.shape}, Data Mesh: {self.mesh.shape['data']}")
126127
if x_padded.shape[0] % self.mesh.shape['data'] == 0:
128+
print("DEBUG: Applying 'data' sharding constraint.")
127129
x_padded = with_sharding_constraint(x_padded, PartitionSpec('data', None, None, None, None))
130+
else:
131+
print("DEBUG: Skipping 'data' sharding constraint (not divisible).")
128132

129133
out = self.conv(x_padded)
130134
return out

0 commit comments

Comments
 (0)