1515"""
1616
1717from abc import ABC
18+ import json
19+
20+ import jax
21+ import numpy as np
1822from maxdiffusion .checkpointing .checkpointing_utils import (create_orbax_checkpoint_manager )
1923from ..pipelines .wan .wan_pipeline import WanPipeline
2024from .. import max_logging , max_utils
25+ import orbax .checkpoint as ocp
26+ from etils import epath
2127
# Checkpoint-type tag used to identify WAN-model checkpoints in the manager.
WAN_CHECKPOINT = "WAN_CHECKPOINT"
2329
@@ -28,7 +34,7 @@ def __init__(self, config, checkpoint_type):
2834 self .config = config
2935 self .checkpoint_type = checkpoint_type
3036
31- self .checkpoint_manager = create_orbax_checkpoint_manager (
37+ self .checkpoint_manager : ocp . CheckpointManager = create_orbax_checkpoint_manager (
3238 self .config .checkpoint_dir ,
3339 enable_checkpointing = True ,
3440 save_interval_steps = 1 ,
@@ -44,22 +50,134 @@ def _create_optimizer(self, model, config, learning_rate):
4450 return tx , learning_rate_scheduler
4551
def load_wan_configs_from_orbax(self, step):
  """Restore a WAN checkpoint (transformer state + config) from Orbax.

  Args:
    step: Checkpoint step to restore. When None, the latest step known to
      the checkpoint manager is used.

  Returns:
    The restored composite checkpoint containing ``wan_state`` (parameter
    pytree) and ``wan_config`` (JSON dict), or None when no checkpoint
    exists.
  """
  if step is None:
    step = self.checkpoint_manager.latest_step()
    max_logging.log(f"Latest WAN checkpoint step: {step}")
  if step is None:
    # No checkpoint has ever been written for this run.
    return None
  max_logging.log(f"Loading WAN checkpoint from step {step}")
  metadatas = self.checkpoint_manager.item_metadata(step)

  transformer_metadata = metadatas.wan_state
  abstract_tree_structure_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, transformer_metadata)
  # Restore every leaf as a plain numpy array; device placement/sharding is
  # applied later when the pipeline is rebuilt from the restored state.
  params_restore = ocp.args.PyTreeRestore(
      restore_args=jax.tree.map(
          lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
          abstract_tree_structure_params,
      )
  )

  max_logging.log("Restoring WAN checkpoint")
  restored_checkpoint = self.checkpoint_manager.restore(
      directory=epath.Path(self.config.checkpoint_dir),
      step=step,
      args=ocp.args.Composite(
          wan_state=params_restore,
          wan_config=ocp.args.JsonRestore(),
      ),
  )
  return restored_checkpoint
5281
def load_diffusers_checkpoint(self):
  """Build the default WAN pipeline from pretrained diffusers weights."""
  return WanPipeline.from_pretrained(self.config)
5685
def load_checkpoint(self, step=None):
  """Return a WanPipeline, restored from Orbax when a checkpoint exists.

  Falls back to the pretrained diffusers weights when no Orbax checkpoint
  is found for the requested (or latest) step.
  """
  restored = self.load_wan_configs_from_orbax(step)
  if not restored:
    max_logging.log("No checkpoint found, loading default pipeline.")
    return self.load_diffusers_checkpoint()
  max_logging.log("Loading WAN pipeline from checkpoint")
  return WanPipeline.from_checkpoint(self.config, restored)
97+
def save_checkpoint(self, train_step, pipeline: WanPipeline, train_states: dict):
  """Saves the training state and model configurations."""

  def config_to_json(model_or_config):
    # Round-trip through the JSON string so Orbax receives plain dict/list data.
    return json.loads(model_or_config.to_json_string())

  max_logging.log(f"Saving checkpoint for step {train_step}")
  composite_args = ocp.args.Composite(
      wan_config=ocp.args.JsonSave(config_to_json(pipeline.transformer)),
      wan_state=ocp.args.PyTreeSave(train_states),
  )

  # Save the checkpoint
  self.checkpoint_manager.save(train_step, args=composite_args)
  max_logging.log(f"Checkpoint for step {train_step} saved.")
114+
115+
def save_checkpoint_orig(self, train_step, pipeline: WanPipeline, train_states: dict):
  """Saves the training state and model configurations (legacy variant).

  Unlike ``save_checkpoint``, this filters the transformer config down to a
  whitelist of JSON-serializable keys before saving, and stores the items
  under ``transformer_config`` / ``transformer_states``.

  Args:
    train_step: Step number the checkpoint is saved under.
    pipeline: Pipeline whose transformer config is serialized.
    train_states: Parameter/optimizer pytree saved as the state item.
  """

  def config_to_json(model_or_config):
    """Extract only the JSON-serializable subset of the model's config."""
    if not hasattr(model_or_config, "config"):
      return None
    source_config = dict(model_or_config.config)

    # 1. configs that can be serialized to JSON
    SAFE_KEYS = [
        "_class_name",
        "_diffusers_version",
        "model_type",
        "patch_size",
        "num_attention_heads",
        "attention_head_dim",
        "in_channels",
        "out_channels",
        "text_dim",
        "freq_dim",
        "ffn_dim",
        "num_layers",
        "cross_attn_norm",
        "qk_norm",
        "eps",
        "image_dim",
        "added_kv_proj_dim",
        "rope_max_seq_len",
        "pos_embed_seq_len",
        "flash_min_seq_length",
        "flash_block_sizes",
        "attention",
        "_use_default_values",
    ]

    # 2. keep only whitelisted keys that are actually present
    clean_config = {key: source_config[key] for key in SAFE_KEYS if key in source_config}

    # 3. special data types: dtype/precision objects are not JSON-serializable,
    # so store their string names instead.
    if "dtype" in source_config and hasattr(source_config["dtype"], "name"):
      clean_config["dtype"] = source_config["dtype"].name  # e.g. 'bfloat16'

    if "weights_dtype" in source_config and hasattr(source_config["weights_dtype"], "name"):
      clean_config["weights_dtype"] = source_config["weights_dtype"].name

    # BUG FIX: the original called isinstance() with a single argument, which
    # raises TypeError whenever "precision" is present. Mirror the dtype
    # handling above instead.
    if "precision" in source_config and hasattr(source_config["precision"], "name"):
      clean_config["precision"] = source_config["precision"].name  # e.g. 'HIGHEST'

    return clean_config

  items_to_save = {
      "transformer_config": ocp.args.JsonSave(config_to_json(pipeline.transformer)),
      "transformer_states": ocp.args.PyTreeSave(train_states),
  }

  # Create CompositeArgs for Orbax
  save_args = ocp.args.Composite(**items_to_save)

  # Save the checkpoint
  self.checkpoint_manager.save(train_step, args=save_args)
  max_logging.log(f"Checkpoint for step {train_step} saved.")