1515"""
1616
1717from abc import ABC
18- from maxdiffusion .checkpointing .checkpointing_utils import (create_orbax_checkpoint_manager )
18+ import json
19+
20+ import jax
21+ import numpy as np
22+ from maxdiffusion .checkpointing .checkpointing_utils import (create_orbax_checkpoint_manager , load_params_from_path )
1923from ..pipelines .wan .wan_pipeline import WanPipeline
2024from .. import max_logging , max_utils
25+ import orbax .checkpoint as ocp
2126
2227WAN_CHECKPOINT = "WAN_CHECKPOINT"
2328
@@ -44,22 +49,123 @@ def _create_optimizer(self, model, config, learning_rate):
4449 return tx , learning_rate_scheduler
4550
4651 def load_wan_configs_from_orbax (self , step ):
47- max_logging .log ("Restoring stable diffusion configs" )
4852 if step is None :
4953 step = self .checkpoint_manager .latest_step ()
54+ max_logging .log (f"Latest WAN checkpoint step: { step } " )
5055 if step is None :
5156 return None
57+ max_logging .log (f"Loading WAN checkpoint from step { step } " )
58+ metadatas = self .checkpoint_manager .item_metadata (step )
59+
60+ transformer_metadata = metadatas .wan_state
61+ abstract_tree_structure_params = jax .tree_util .tree_map (
62+ ocp .utils .to_shape_dtype_struct , transformer_metadata
63+ )
64+ params_restore = ocp .args .PyTreeRestore (
65+ restore_args = jax .tree .map (
66+ lambda _ : ocp .RestoreArgs (restore_type = np .ndarray ),
67+ abstract_tree_structure_params ,
68+ )
69+ )
70+
71+ params_restore_util_way = load_params_from_path (
72+ self .config ,
73+ self .checkpoint_manager ,
74+ abstract_tree_structure_params ,
75+ "wan_state" ,
76+ step
77+ )
78+
79+ max_logging .log ("Restoring WAN checkpoint" )
80+ restored_checkpoint = self .checkpoint_manager .restore (
81+ step ,
82+ args = ocp .args .Composite (
83+ wan_state = params_restore ,
84+ # wan_state=params_restore_util_way,
85+ wan_config = ocp .args .JsonRestore (),
86+ ),
87+ )
88+ return restored_checkpoint
5289
5390 def load_diffusers_checkpoint (self ):
5491 pipeline = WanPipeline .from_pretrained (self .config )
5592 return pipeline
5693
5794 def load_checkpoint (self , step = None ):
58- model_configs = self .load_wan_configs_from_orbax (step )
95+ restored_checkpoint = self .load_wan_configs_from_orbax (step )
5996
60- if model_configs :
61- raise NotImplementedError ("model configs should not exist in orbax" )
97+ if restored_checkpoint :
98+ max_logging .log ("Loading WAN pipeline from checkpoint" )
99+ pipeline = WanPipeline .from_checkpoint (self .config , restored_checkpoint )
62100 else :
101+ max_logging .log ("No checkpoint found, loading default pipeline." )
63102 pipeline = self .load_diffusers_checkpoint ()
64103
65104 return pipeline
105+
106+ def save_checkpoint (self , train_step , pipeline : WanPipeline , train_states : dict ):
107+ """Saves the training state and model configurations."""
108+ def config_to_json (model_or_config ):
109+ return json .loads (model_or_config .to_json_string ())
110+ max_logging .log (f"Saving checkpoint for step { train_step } " )
111+ items = {
112+ "wan_config" : ocp .args .JsonSave (config_to_json (pipeline .transformer )),
113+ }
114+
115+ items ["wan_state" ] = ocp .args .PyTreeSave (train_states )
116+
117+ # Save the checkpoint
118+ self .checkpoint_manager .save (train_step , args = ocp .args .Composite (** items ))
119+ max_logging .log (f"Checkpoint for step { train_step } saved." )
120+
121+ def save_checkpoint_orig (self , train_step , pipeline : WanPipeline , train_states : dict ):
122+ """Saves the training state and model configurations."""
123+ def config_to_json (model_or_config ):
124+ """
125+ only save the config that is needed and can be serialized to JSON.
126+ """
127+ if not hasattr (model_or_config , "config" ):
128+ return None
129+ source_config = dict (model_or_config .config )
130+
131+ # 1. configs that can be serialized to JSON
132+ SAFE_KEYS = [
133+ '_class_name' , '_diffusers_version' , 'model_type' , 'patch_size' ,
134+ 'num_attention_heads' , 'attention_head_dim' , 'in_channels' ,
135+ 'out_channels' , 'text_dim' , 'freq_dim' , 'ffn_dim' , 'num_layers' ,
136+ 'cross_attn_norm' , 'qk_norm' , 'eps' , 'image_dim' ,
137+ 'added_kv_proj_dim' , 'rope_max_seq_len' , 'pos_embed_seq_len' ,
138+ 'flash_min_seq_length' , 'flash_block_sizes' , 'attention' ,
139+ '_use_default_values'
140+ ]
141+
142+ # 2. save the config that are in the SAFE_KEYS list
143+ clean_config = {}
144+ for key in SAFE_KEYS :
145+ if key in source_config :
146+ clean_config [key ] = source_config [key ]
147+
148+ # 3. deal with special data type and precision
149+ if 'dtype' in source_config and hasattr (source_config ['dtype' ], 'name' ):
150+ clean_config ['dtype' ] = source_config ['dtype' ].name # e.g 'bfloat16'
151+
152+ if 'weights_dtype' in source_config and hasattr (source_config ['weights_dtype' ], 'name' ):
153+ clean_config ['weights_dtype' ] = source_config ['weights_dtype' ].name
154+
155+ if 'precision' in source_config and isinstance (source_config ['precision' ], Precision ):
156+ clean_config ['precision' ] = source_config ['precision' ].name # e.g. 'HIGHEST'
157+
158+ return clean_config
159+
160+ items_to_save = {
161+ "transformer_config" : ocp .args .JsonSave (config_to_json (pipeline .transformer )),
162+ }
163+
164+ items_to_save ["transformer_states" ] = ocp .args .PyTreeSave (train_states )
165+
166+ # Create CompositeArgs for Orbax
167+ save_args = ocp .args .Composite (** items_to_save )
168+
169+ # Save the checkpoint
170+ self .checkpoint_manager .save (train_step , args = save_args )
171+ max_logging .log (f"Checkpoint for step { train_step } saved." )