Skip to content

Commit a4c45d6

Browse files
committed
Fix transformer weight loading bug when scan_layers = False
1 parent 740063b commit a4c45d6

2 files changed

Lines changed: 7 additions & 1 deletion

File tree

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -907,6 +907,8 @@ def init_block(rngs):
907907
norm_elementwise_affine=self.norm_elementwise_affine,
908908
norm_eps=self.norm_eps,
909909
rope_type=self.rope_type,
910+
gated_attn=self.gated_attn,
911+
cross_attn_mod=self.cross_attn_mod,
910912
dtype=self.dtype,
911913
weights_dtype=self.weights_dtype,
912914
mesh=self.mesh,

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,11 @@ def create_model(rngs: nnx.Rngs, ltx2_config: dict):
187187
for path, val in flax.traverse_util.flatten_dict(params).items():
188188
if restored_checkpoint:
189189
path = path[:-1]
190-
sharding = logical_state_sharding[path].value
190+
try:
191+
sharding = logical_state_sharding[path].value
192+
except KeyError:
193+
path_str = tuple(str(k) for k in path)
194+
sharding = logical_state_sharding[path_str].value
191195
state[path].value = device_put_replicated(val, sharding)
192196
state = nnx.from_flat_state(state)
193197

0 commit comments

Comments (0)