We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent e2cb67f commit 6318fa6Copy full SHA for 6318fa6
1 file changed
src/maxdiffusion/pipelines/wan/wan_pipeline.py
@@ -397,7 +397,7 @@ def __call__(
397
num_channels_latents=num_channel_latents,
398
)
399
400
- data_sharding = NamedSharding(self.devices_array, P())
+ data_sharding = NamedSharding(self.mesh, P())
401
if len(prompt) % jax.device_count() == 0:
402
data_sharding = jax.sharding.NamedSharding(self.mesh, P(*self.config.data_sharding))
403
0 commit comments