Skip to content

Commit 9af19a4

Browse files
committed
Remove extra params loading error
1 parent 698a5d0 commit 9af19a4

1 file changed

Lines changed: 2 additions & 0 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,8 @@ def create_model(rngs: nnx.Rngs, wan_config: dict):
160160
for path, val in flax.traverse_util.flatten_dict(params).items():
161161
if restored_checkpoint:
162162
path = path[:-1]
163+
if path not in logical_state_sharding:
164+
continue
163165
sharding = logical_state_sharding[path].value
164166
state[path].value = device_put_replicated(val, sharding)
165167
state = nnx.from_flat_state(state)

0 commit comments

Comments
 (0)