Skip to content

Commit 6011145

Browse files
committed
feat(ltx2): pass sharding specs to VAE and connector in pipeline
1 parent 75db696 commit 6011145

1 file changed

Lines changed: 10 additions & 0 deletions

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,13 +329,18 @@ def load_connectors(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, co
329329
max_logging.log("Loading Connectors...")
330330

331331
def create_model(rngs: nnx.Rngs, config: HyperParameters):
332+
sharding_config = getattr(config, "sharding", {})
333+
connector_strategy = sharding_config.get("text_connector", "default")
334+
connector_specs = get_sharding_specs(connector_strategy, "text_connector")
335+
332336
connectors = LTX2AudioVideoGemmaTextEncoder.from_config(
333337
config.pretrained_model_name_or_path,
334338
subfolder="connectors",
335339
rngs=rngs,
336340
mesh=mesh,
337341
dtype=jnp.float32,
338342
weights_dtype=config.weights_dtype if hasattr(config, "weights_dtype") else jnp.float32,
343+
sharding_specs=connector_specs,
339344
)
340345
return connectors
341346

@@ -371,13 +376,18 @@ def load_vae(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: H
371376
max_logging.log("Loading Video VAE...")
372377

373378
def create_model(rngs: nnx.Rngs, config: HyperParameters):
379+
sharding_config = getattr(config, "sharding", {})
380+
vae_strategy = sharding_config.get("vae", "default")
381+
vae_specs = get_sharding_specs(vae_strategy, "vae")
382+
374383
vae = LTX2VideoAutoencoderKL.from_config(
375384
config.pretrained_model_name_or_path,
376385
subfolder="vae",
377386
rngs=rngs,
378387
mesh=mesh,
379388
dtype=jnp.float32,
380389
weights_dtype=config.weights_dtype if hasattr(config, "weights_dtype") else jnp.float32,
390+
sharding_specs=vae_specs,
381391
)
382392
return vae
383393

0 commit comments

Comments
 (0)