Skip to content

Commit f1e245b

Browse files
committed
fix and ltx2 backward compatibility
1 parent e03eda6 commit f1e245b

1 file changed

Lines changed: 5 additions & 1 deletion

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,10 +188,14 @@ def create_model(rngs: nnx.Rngs, ltx2_config: dict):
188188
for path, val in flax.traverse_util.flatten_dict(params).items():
189189
if restored_checkpoint:
190190
path = path[:-1]
191+
192+
path_str = tuple(str(k) for k in path)
193+
if path not in logical_state_sharding and path_str not in logical_state_sharding:
194+
continue
195+
191196
try:
192197
sharding = logical_state_sharding[path].value
193198
except KeyError:
194-
path_str = tuple(str(k) for k in path)
195199
sharding = logical_state_sharding[path_str].value
196200
state[path].value = device_put_replicated(val, sharding)
197201
state = nnx.from_flat_state(state)

0 commit comments

Comments
 (0)