Skip to content

Commit 8d4546e

Browse files
committed
transformer weights fix
1 parent 852f57a commit 8d4546e

1 file changed

Lines changed: 7 additions & 29 deletions

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 7 additions & 29 deletions
@@ -140,9 +140,6 @@ def create_model(rngs: nnx.Rngs, ltx2_config: dict):
   else:
     ltx2_config = LTX2VideoTransformer3DModel.load_config(config.pretrained_model_name_or_path, subfolder=subfolder)
 
-  # Align RoPE type with connectors
-  ltx2_config["rope_type"] = "split"
-
   if ltx2_config.get("activation_fn") == "gelu-approximate":
     ltx2_config["activation_fn"] = "gelu"
 
@@ -157,13 +154,6 @@ def create_model(rngs: nnx.Rngs, ltx2_config: dict):
   ltx2_config["remat_policy"] = config.remat_policy
   ltx2_config["names_which_can_be_saved"] = config.names_which_can_be_saved
   ltx2_config["names_which_can_be_offloaded"] = config.names_which_can_be_offloaded
-  ltx2_config["use_prompt_embeddings"] = True
-
-  if getattr(config, "model_name", "") == "ltx2.3":
-    ltx2_config["gated_attn"] = True
-    ltx2_config["cross_attn_mod"] = True
-    ltx2_config["perturbed_attn"] = True
-    ltx2_config["use_prompt_embeddings"] = False
 
168158
# 2. eval_shape
169159
p_model_factory = partial(create_model, ltx2_config=ltx2_config)
@@ -184,25 +174,13 @@ def create_model(rngs: nnx.Rngs, ltx2_config: dict):
     else:
       params = restored_checkpoint["ltx2_state"]
   else:
-    filename = "ltx-2.3-22b-dev.safetensors" if getattr(config, "model_name", "") == "ltx2.3" else None
-    subfolder = "" if getattr(config, "model_name", "") == "ltx2.3" else subfolder
-
-    if tensors is not None and getattr(config, "model_name", "") == "ltx2.3":
-      from maxdiffusion.models.ltx2.ltx2_3_utils import load_transformer_weights_2_3
-      params = load_transformer_weights_2_3(
-        params,  # eval_shapes
-        "cpu",
-        tensors,
-        scan_layers=getattr(config, "scan_layers", True),
-      )
-    else:
-      params = load_transformer_weights(
-        config.pretrained_model_name_or_path,
-        params,  # eval_shapes
-        "cpu",
-        scan_layers=getattr(config, "scan_layers", True),
-        subfolder=subfolder,
-      )
+    params = load_transformer_weights(
+      config.pretrained_model_name_or_path,
+      params,  # eval_shapes
+      "cpu",
+      scan_layers=getattr(config, "scan_layers", True),
+      subfolder=subfolder,
+    )
 
   params = jax.tree_util.tree_map_with_path(
     lambda path, x: cast_with_exclusion(path, x, dtype_to_cast=config.weights_dtype), params

0 commit comments

Comments (0)