Skip to content

Commit 55d7252

Browse files
committed
Only batch sharding added
1 parent 05e1b28 commit 55d7252

1 file changed

Lines changed: 1 addition & 0 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ def __init__(
7373

7474
# Store the amount of padding needed *before* the depth dimension for caching logic
7575
self._depth_padding_before = self._causal_padding[1][0] # 2 * padding_tuple[0]
76+
self.mesh = mesh
7677

7778
# Set sharding dynamically based on out_channels.
7879
num_fsdp_axis_devices = mesh.device_ids.shape[1]

0 commit comments

Comments
 (0)