We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent bce5842 commit 8c2b884Copy full SHA for 8c2b884
1 file changed
src/maxdiffusion/models/ltx2/ltx2_3_utils.py
@@ -52,14 +52,19 @@ def load_connectors_weights(
52
53
segments = flax_key_str.split(".")
54
55
- # Find if there is a layer index (digit)
+ # Only extract digit if it immediately follows 'stacked_blocks'
56
layer_idx = None
57
base_segments = []
58
- for seg in segments:
59
- if seg.isdigit():
60
- layer_idx = int(seg)
+ i = 0
+ while i < len(segments):
+ seg = segments[i]
61
+ if seg == "stacked_blocks" and i + 1 < len(segments) and segments[i+1].isdigit():
62
+ base_segments.append(seg)
63
+ layer_idx = int(segments[i+1])
64
+ i += 2
65
else:
66
base_segments.append(seg)
67
+ i += 1
68
69
if layer_idx is not None:
70
base_key = _tuple_str_to_int(base_segments)
0 commit comments