Skip to content

Commit dbee23c

Browse files
committed
fix
1 parent bf6ca95 commit dbee23c

1 file changed

Lines changed: 7 additions & 5 deletions

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1207,14 +1207,16 @@ def __call__(
12071207
audio_latents_jax = jnp.concatenate([audio_latents_jax] * 2, axis=0)
12081208

12091209
if hasattr(self, "mesh") and self.mesh is not None:
1210-
data_sharding = NamedSharding(self.mesh, P())
1210+
data_sharding_3d = NamedSharding(self.mesh, P())
1211+
data_sharding_2d = NamedSharding(self.mesh, P())
12111212
if hasattr(self, "config") and hasattr(self.config, "data_sharding"):
1212-
data_sharding = NamedSharding(self.mesh, P(*self.config.data_sharding))
1213+
data_sharding_3d = NamedSharding(self.mesh, P(*self.config.data_sharding[:3]))
1214+
data_sharding_2d = NamedSharding(self.mesh, P(*self.config.data_sharding[:2]))
12131215
if isinstance(prompt_embeds_jax, list):
1214-
prompt_embeds_jax = [jax.device_put(x, data_sharding) for x in prompt_embeds_jax]
1216+
prompt_embeds_jax = [jax.device_put(x, data_sharding_3d) for x in prompt_embeds_jax]
12151217
else:
1216-
prompt_embeds_jax = jax.device_put(prompt_embeds_jax, data_sharding)
1217-
prompt_attention_mask_jax = jax.device_put(prompt_attention_mask_jax, data_sharding)
1218+
prompt_embeds_jax = jax.device_put(prompt_embeds_jax, data_sharding_3d)
1219+
prompt_attention_mask_jax = jax.device_put(prompt_attention_mask_jax, data_sharding_2d)
12181220

12191221
# GraphDef and State
12201222
graphdef, state = nnx.split(self.transformer)

0 commit comments

Comments
 (0)