Skip to content

Commit b64c2bf

Browse files
committed
Fix IndivisibleError in VAE sharding by checking tensor axis size
1 parent ef956e2 commit b64c2bf

1 file changed

Lines changed: 2 additions & 1 deletion

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,9 @@ def __init__(
100100
self.mesh = mesh
101101
# Set sharding dynamically based on out_channels.
102102
num_context_axis_devices = mesh.shape["context"]
103+
tensor_axis_size = mesh.shape.get("tensor", 1)
103104
kernel_sharding = (None, None, None, None, None)
104-
if out_channels % num_context_axis_devices == 0:
105+
if out_channels % num_context_axis_devices == 0 and out_channels % tensor_axis_size == 0:
105106
kernel_sharding = (None, None, None, None, "conv_out")
106107

107108
self.conv = nnx.Conv(

0 commit comments

Comments
 (0)