Skip to content

Commit 0771fe1

Browse files
committed
transformer weight
1 parent fed003b commit 0771fe1

2 files changed

Lines changed: 5 additions & 1 deletion

File tree

src/maxdiffusion/models/ltx2/ltx2_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,13 +220,14 @@ def load_transformer_weights(
220220
num_layers: int = 48,
221221
scan_layers: bool = True,
222222
subfolder: str = "transformer",
223+
filename: str = None,
223224
):
224225
device = jax.local_devices(backend=device)[0]
225226
max_logging.log(f"Load and port {pretrained_model_name_or_path} {subfolder} on {device}")
226227

227228
with jax.default_device(device):
228229
# Support sharded loading
229-
tensors = load_sharded_checkpoint(pretrained_model_name_or_path, subfolder, device)
230+
tensors = load_sharded_checkpoint(pretrained_model_name_or_path, subfolder, device, filename=filename)
230231

231232
flax_state_dict = {}
232233
cpu = jax.local_devices(backend="cpu")[0]

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,12 +168,15 @@ def create_model(rngs: nnx.Rngs, ltx2_config: dict):
168168
else:
169169
params = restored_checkpoint["ltx2_state"]
170170
else:
171+
filename = "ltx-2.3-22b-dev.safetensors" if getattr(config, "model_name", "") == "ltx2.3" else None
172+
subfolder = "" if getattr(config, "model_name", "") == "ltx2.3" else subfolder
171173
params = load_transformer_weights(
172174
config.pretrained_model_name_or_path,
173175
params, # eval_shapes
174176
"cpu",
175177
scan_layers=getattr(config, "scan_layers", True),
176178
subfolder=subfolder,
179+
filename=filename,
177180
)
178181

179182
params = jax.tree_util.tree_map_with_path(

0 commit comments

Comments (0)