1616
1717import json
1818import jax
19- import numpy as np
2019from typing import Optional , Tuple
2120from ..pipelines .wan .wan_pipeline_i2v_2p2 import WanPipelineI2V_2_2
2221from .. import max_logging
2322import orbax .checkpoint as ocp
24- from etils import epath
23+ from maxdiffusion . checkpointing . checkpointing_utils import add_sharding_to_struct , get_cpu_mesh_and_sharding
2524from maxdiffusion .checkpointing .wan_checkpointer import WanCheckpointer
2625
2726
@@ -35,39 +34,32 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic
3534 max_logging .log ("No WAN checkpoint found." )
3635 return None , None
3736 max_logging .log (f"Loading WAN checkpoint from step { step } " )
37+
38+ mesh , replicated_sharding = get_cpu_mesh_and_sharding ()
3839 metadatas = self .checkpoint_manager .item_metadata (step )
3940
4041 # Handle low_noise_transformer
4142 low_noise_transformer_metadata = metadatas .low_noise_transformer_state
42- abstract_tree_structure_low_params = jax .tree_util .tree_map (
43- ocp .utils .to_shape_dtype_struct , low_noise_transformer_metadata
44- )
45- low_params_restore = ocp .args .PyTreeRestore (
46- restore_args = jax .tree .map (
47- lambda _ : ocp .RestoreArgs (restore_type = np .ndarray ),
48- abstract_tree_structure_low_params ,
49- )
50- )
43+ target_shardings = jax .tree_util .tree_map (lambda x : replicated_sharding , low_noise_transformer_metadata )
44+ with mesh :
45+ abstract_tree_structure_low_params = jax .tree_util .tree_map (
46+ add_sharding_to_struct , low_noise_transformer_metadata , target_shardings
47+ )
5148
5249 # Handle high_noise_transformer
5350 high_noise_transformer_metadata = metadatas .high_noise_transformer_state
54- abstract_tree_structure_high_params = jax .tree_util .tree_map (
55- ocp .utils .to_shape_dtype_struct , high_noise_transformer_metadata
56- )
57- high_params_restore = ocp .args .PyTreeRestore (
58- restore_args = jax .tree .map (
59- lambda _ : ocp .RestoreArgs (restore_type = np .ndarray ),
60- abstract_tree_structure_high_params ,
61- )
62- )
51+ target_shardings = jax .tree_util .tree_map (lambda x : replicated_sharding , high_noise_transformer_metadata )
52+ with mesh :
53+ abstract_tree_structure_high_params = jax .tree_util .tree_map (
54+ add_sharding_to_struct , high_noise_transformer_metadata , target_shardings
55+ )
6356
6457 max_logging .log ("Restoring WAN 2.2 checkpoint" )
6558 restored_checkpoint = self .checkpoint_manager .restore (
66- directory = epath .Path (self .config .checkpoint_dir ),
6759 step = step ,
6860 args = ocp .args .Composite (
69- low_noise_transformer_state = low_params_restore ,
70- high_noise_transformer_state = high_params_restore ,
61+ low_noise_transformer_state = ocp . args . StandardRestore ( abstract_tree_structure_low_params ) ,
62+ high_noise_transformer_state = ocp . args . StandardRestore ( abstract_tree_structure_high_params ) ,
7163 wan_config = ocp .args .JsonRestore (),
7264 ),
7365 )
@@ -119,8 +111,8 @@ def config_to_json(model_or_config):
119111 "wan_config" : ocp .args .JsonSave (config_to_json (pipeline .low_noise_transformer )),
120112 }
121113
122- items ["low_noise_transformer_state" ] = ocp .args .PyTreeSave (train_states ["low_noise_transformer" ])
123- items ["high_noise_transformer_state" ] = ocp .args .PyTreeSave (train_states ["high_noise_transformer" ])
114+ items ["low_noise_transformer_state" ] = ocp .args .StandardSave (train_states ["low_noise_transformer" ])
115+ items ["high_noise_transformer_state" ] = ocp .args .StandardSave (train_states ["high_noise_transformer" ])
124116
125117 # Save the checkpoint
126118 self .checkpoint_manager .save (train_step , args = ocp .args .Composite (** items ))
0 commit comments