Skip to content

Commit f0e04ff

Browse files
committed
fix
1 parent 4c4446c commit f0e04ff

2 files changed

Lines changed: 47 additions & 9 deletions

File tree

src/maxdiffusion/models/ltx2/ltx2_utils.py

Lines changed: 29 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -221,19 +221,39 @@ def load_vae_weights(
221221

222222
for pt_key, tensor in tensors.items():
223223
renamed_pt_key = rename_key(pt_key)
224-
if ".resnets." in renamed_pt_key:
225-
# pattern: resnets.0 -> resnets_0
226-
# We need to capture the number after resnets
227-
import re
228-
# Replace resnets.N with resnets_N
229-
renamed_pt_key = re.sub(r"resnets\.(\d+)", r"resnets_\1", renamed_pt_key)
230-
224+
231225
pt_tuple_key = tuple(renamed_pt_key.split("."))
232226

227+
# Handle resnets.N -> resnets with stacking
228+
resnet_index = None
229+
if "resnets" in pt_tuple_key:
230+
pt_list = list(pt_tuple_key)
231+
# Iterate backwards to safely pop
232+
for i in range(len(pt_list) - 1, -1, -1):
233+
if pt_list[i] == "resnets" and i + 1 < len(pt_list) and pt_list[i+1].isdigit():
234+
resnet_index = int(pt_list[i+1])
235+
pt_list.pop(i+1)
236+
break
237+
pt_tuple_key = tuple(pt_list)
238+
233239
flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict)
234240
flax_key = _tuple_str_to_int(flax_key)
235-
236-
flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
241+
242+
if resnet_index is not None:
243+
if flax_key in flax_state_dict:
244+
current_tensor = flax_state_dict[flax_key]
245+
else:
246+
# Initialize with correct shape from random_flax_state_dict
247+
target_shape = random_flax_state_dict[flax_key].shape
248+
current_tensor = jnp.zeros(target_shape, dtype=flax_tensor.dtype)
249+
250+
# Place the tensor at the correct index
251+
# flax_tensor is (..., C), target is (N_resnets, ..., C)
252+
# We need to ensure dims match for assignment
253+
current_tensor = current_tensor.at[resnet_index].set(flax_tensor)
254+
flax_state_dict[flax_key] = current_tensor
255+
else:
256+
flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
237257

238258
validate_flax_state_dict(eval_shapes, flax_state_dict)
239259
flax_state_dict = unflatten_dict(flax_state_dict)

src/maxdiffusion/tests/test_ltx2_utils.py

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,24 @@
1010
from maxdiffusion.models.modeling_flax_pytorch_utils import validate_flax_state_dict
1111
from flax.traverse_util import flatten_dict
1212

13+
class LTX2VideoConfig:
14+
def __init__(self):
15+
self.in_channels = 128
16+
self.out_channels = 128
17+
self.patch_size = 1
18+
self.patch_size_t = 1
19+
self.num_attention_heads = 32
20+
self.attention_head_dim = 128
21+
self.cross_attention_dim = 4096
22+
self.audio_in_channels = 128
23+
self.audio_out_channels = 128
24+
self.audio_patch_size = 1
25+
self.audio_patch_size_t = 1
26+
self.audio_num_attention_heads = 32
27+
self.audio_attention_head_dim = 128 # Default is 64 but we want 128
28+
self.audio_cross_attention_dim = 2048
29+
self.num_layers = 48
30+
1331
class LTX2UtilsTest(unittest.TestCase):
1432
def setUp(self):
1533
self.device = "cpu"

0 commit comments

Comments
 (0)