Skip to content

Commit 6318fa6

Browse files
committed
Fix namedsharding
1 parent e2cb67f commit 6318fa6

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,7 @@ def __call__(
397397
num_channels_latents=num_channel_latents,
398398
)
399399

400-
data_sharding = NamedSharding(self.devices_array, P())
400+
data_sharding = NamedSharding(self.mesh, P())
401401
if len(prompt) % jax.device_count() == 0:
402402
data_sharding = jax.sharding.NamedSharding(self.mesh, P(*self.config.data_sharding))
403403

0 commit comments

Comments
 (0)