Skip to content

Commit 478030e

Browse files
committed
fix
1 parent f0e04ff commit 478030e

2 files changed

Lines changed: 33 additions & 16 deletions

File tree

src/maxdiffusion/models/ltx2/ltx2_utils.py

Lines changed: 33 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -224,34 +224,52 @@ def load_vae_weights(
224224

225225
pt_tuple_key = tuple(renamed_pt_key.split("."))
226226

227-
# Handle resnets.N -> resnets with stacking
227+
pt_list = []
228228
resnet_index = None
229-
if "resnets" in pt_tuple_key:
230-
pt_list = list(pt_tuple_key)
231-
# Iterate backwards to safely pop
232-
for i in range(len(pt_list) - 1, -1, -1):
233-
if pt_list[i] == "resnets" and i + 1 < len(pt_list) and pt_list[i+1].isdigit():
234-
resnet_index = int(pt_list[i+1])
235-
pt_list.pop(i+1)
236-
break
237-
pt_tuple_key = tuple(pt_list)
229+
230+
for part in pt_tuple_key:
231+
# Check for name_N pattern
232+
if "_" in part and part.split("_")[-1].isdigit():
233+
name = "_".join(part.split("_")[:-1])
234+
idx = int(part.split("_")[-1])
235+
236+
if name == "resnets":
237+
resnet_index = idx
238+
pt_list.append("resnets")
239+
elif name in ["down_blocks", "up_blocks", "downsamplers", "upsamplers"]:
240+
pt_list.append(name)
241+
pt_list.append(idx)
242+
else:
243+
pt_list.append(part)
244+
else:
245+
pt_list.append(part)
246+
247+
pt_tuple_key = tuple(pt_list)
238248

239249
flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict)
250+
# _tuple_str_to_int might not be needed if we already injected ints, but it's safe
240251
flax_key = _tuple_str_to_int(flax_key)
241252

242253
if resnet_index is not None:
243254
if flax_key in flax_state_dict:
244255
current_tensor = flax_state_dict[flax_key]
245256
else:
246257
# Initialize with correct shape from random_flax_state_dict
247-
target_shape = random_flax_state_dict[flax_key].shape
248-
current_tensor = jnp.zeros(target_shape, dtype=flax_tensor.dtype)
258+
if flax_key in random_flax_state_dict:
259+
target_shape = random_flax_state_dict[flax_key].shape
260+
current_tensor = jnp.zeros(target_shape, dtype=flax_tensor.dtype)
261+
else:
262+
# Fallback if key missing (shouldn't happen with correct mapping)
263+
print(f"Warning: Key {flax_key} not found in random_flax_state_dict, cannot stack.")
264+
current_tensor = flax_tensor # Might fail shape check later
249265

250266
# Place the tensor at the correct index
251267
# flax_tensor is (..., C), target is (N_resnets, ..., C)
252-
# We need to ensure dims match for assignment
253-
current_tensor = current_tensor.at[resnet_index].set(flax_tensor)
254-
flax_state_dict[flax_key] = current_tensor
268+
if flax_key in random_flax_state_dict: # Only stack if we have a valid target
269+
current_tensor = current_tensor.at[resnet_index].set(flax_tensor)
270+
flax_state_dict[flax_key] = current_tensor
271+
else:
272+
flax_state_dict[flax_key] = flax_tensor
255273
else:
256274
flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
257275

src/maxdiffusion/tests/test_ltx2_utils.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -67,7 +67,6 @@ def test_load_transformer_weights(self):
6767
audio_cross_attention_dim=self.config.audio_cross_attention_dim,
6868
num_layers=self.config.num_layers,
6969
scan_layers=True,
70-
param_dtype=jnp.bfloat16,
7170
rngs=nnx.Rngs(0),
7271
)
7372

0 commit comments

Comments
 (0)