Skip to content

Commit cd3616a

Browse files
committed
sharding attempt
1 parent c4f15c9 commit cd3616a

2 files changed

Lines changed: 12 additions & 0 deletions

File tree

src/maxdiffusion/configs/ltx2_video.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ skip_jax_distributed_system: False
44
attention: 'flash'
55
attention_sharding_uniform: True
66
precision: 'bf16'
7+
data_sharding: ['data', 'fsdp', 'context', 'tensor']
78
remat_policy: "NONE"
89
names_which_can_be_saved: []
910
names_which_can_be_offloaded: []

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -530,6 +530,7 @@ def _load_and_init(cls, config: HyperParameters, restored_checkpoint, vae_only=F
530530
vocoder=components["vocoder"]
531531
)
532532
pipeline.mesh = components["mesh"]
533+
pipeline.config = config
533534
if load_transformer:
534535
pipeline.transformer = cls.quantize_transformer(config, pipeline.transformer, pipeline, pipeline.mesh)
535536
return pipeline, pipeline.transformer
@@ -1204,6 +1205,16 @@ def __call__(
12041205
prompt_attention_mask_jax = jnp.concatenate([negative_prompt_attention_mask_jax, prompt_attention_mask_jax], axis=0)
12051206
latents_jax = jnp.concatenate([latents_jax] * 2, axis=0)
12061207
audio_latents_jax = jnp.concatenate([audio_latents_jax] * 2, axis=0)
1208+
1209+
if hasattr(self, "mesh") and self.mesh is not None:
1210+
data_sharding = NamedSharding(self.mesh, P())
1211+
if hasattr(self, "config") and hasattr(self.config, "data_sharding"):
1212+
data_sharding = NamedSharding(self.mesh, P(*self.config.data_sharding))
1213+
if isinstance(prompt_embeds_jax, list):
1214+
prompt_embeds_jax = [jax.device_put(x, data_sharding) for x in prompt_embeds_jax]
1215+
else:
1216+
prompt_embeds_jax = jax.device_put(prompt_embeds_jax, data_sharding)
1217+
prompt_attention_mask_jax = jax.device_put(prompt_attention_mask_jax, data_sharding)
12071218

12081219
# GraphDef and State
12091220
graphdef, state = nnx.split(self.transformer)

0 commit comments

Comments
 (0)