Skip to content

Commit 1026e0a

Browse files
committed
fix and ltx2 backward compatibility
1 parent b83a1ab commit 1026e0a

1 file changed

Lines changed: 2 additions & 0 deletions

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,8 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
379379
logical_state_spec = nnx.get_partition_spec(state)
380380
logical_state_sharding = nn.logical_to_mesh_sharding(logical_state_spec, mesh, config.logical_axis_rules)
381381
logical_state_sharding = dict(nnx.to_flat_state(logical_state_sharding))
382+
params = state.to_pure_dict()
383+
state = dict(nnx.to_flat_state(state))
382384
filename = "ltx-2.3-22b-dev.safetensors" if getattr(config, "model_name", "") == "ltx2.3" else None
383385
params = load_connectors_weights(
384386
config.pretrained_model_name_or_path,

0 commit comments

Comments
 (0)