Skip to content

Commit dfacf93

Browse files
committed
vocoder weights
1 parent c91595d commit dfacf93

2 files changed

Lines changed: 14 additions & 8 deletions

File tree

src/maxdiffusion/models/ltx2/ltx2_utils.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -369,14 +369,18 @@ def rename_for_ltx2_vocoder(key):
369369

370370

371371
def load_vocoder_weights(
372-
pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True, subfolder: str = "vocoder"
372+
pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True, subfolder: str = "vocoder", filename: str = None
373373
):
374-
tensors = load_sharded_checkpoint(pretrained_model_name_or_path, subfolder, device)
374+
tensors = load_sharded_checkpoint(pretrained_model_name_or_path, subfolder, device, filename=filename)
375375

376376
flax_state_dict = {}
377377
cpu = jax.local_devices(backend="cpu")[0]
378378

379379
for pt_key, tensor in tensors.items():
380+
if filename and not pt_key.startswith("vocoder."):
381+
continue
382+
if filename and pt_key.startswith("vocoder."):
383+
pt_key = pt_key[len("vocoder."):]
380384
key = rename_for_ltx2_vocoder(pt_key)
381385
parts = key.split(".")
382386

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -537,33 +537,35 @@ def load_vocoder(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, confi
537537
max_logging.log("Loading Vocoder...")
538538

539539
def create_model(rngs: nnx.Rngs, config: HyperParameters):
540+
vocoder_repo = "Lightricks/LTX-2" if getattr(config, "model_name", "") == "ltx2.3" else config.pretrained_model_name_or_path
540541
if getattr(config, "model_name", "") == "ltx2.3":
541542
vocoder_class = LTX2VocoderWithBWE
542543
else:
543544
vocoder_class = LTX2Vocoder
544-
545+
545546
vocoder = vocoder_class.from_config(
546-
config.pretrained_model_name_or_path,
547+
vocoder_repo,
547548
subfolder="vocoder",
548549
rngs=rngs,
549550
mesh=mesh,
550551
dtype=jnp.float32,
551552
weights_dtype=config.weights_dtype if hasattr(config, "weights_dtype") else jnp.float32,
552553
)
553554
return vocoder
554-
555+
555556
p_model_factory = partial(create_model, config=config)
556557
vocoder = nnx.eval_shape(p_model_factory, rngs=rngs)
557558
graphdef, state, rest_of_state = nnx.split(vocoder, nnx.Param, ...)
558559
rest_of_state = jax.tree_util.tree_map(cls._init_dummy_shape, rest_of_state)
559-
560+
560561
logical_state_spec = nnx.get_partition_spec(state)
561562
logical_state_sharding = nn.logical_to_mesh_sharding(logical_state_spec, mesh, config.logical_axis_rules)
562563
logical_state_sharding = dict(nnx.to_flat_state(logical_state_sharding))
563564
params = state.to_pure_dict()
564565
state = dict(nnx.to_flat_state(state))
565-
566-
params = load_vocoder_weights(config.pretrained_model_name_or_path, params, "cpu", subfolder="vocoder")
566+
567+
filename = "ltx-2.3-22b-dev.safetensors" if getattr(config, "model_name", "") == "ltx2.3" else None
568+
params = load_vocoder_weights(config.pretrained_model_name_or_path, params, "cpu", subfolder="vocoder", filename=filename)
567569
if hasattr(config, "weights_dtype"):
568570
params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params)
569571

0 commit comments

Comments
 (0)