Skip to content

Commit 39b227b

Browse files
authored
Wan checkpointing (#238)
* initial checkpointing (Signed-off-by: Kunjan Patel <kunjan@ucla.edu>)
* Support loading from GCS
* Formatting (several iterations; Signed-off-by: Kunjan Patel <kunjan@ucla.edu>)
* Set checkpoint_dir default to empty (Signed-off-by: Kunjan Patel <kunjan@ucla.edu>)

Signed-off-by: Kunjan Patel <kunjan@ucla.edu>
1 parent 224a951 commit 39b227b

7 files changed

Lines changed: 237 additions & 20 deletions

File tree

src/maxdiffusion/checkpointing/checkpointing_utils.py

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -17,15 +17,15 @@
1717

1818
"""Create an Orbax CheckpointManager with specified (Async or not) Checkpointer."""
1919

20-
from typing import Optional, Any
20+
from typing import Optional, Tuple
2121
import jax
2222
import numpy as np
2323
import os
24-
2524
import orbax.checkpoint
2625
from maxdiffusion import max_logging
2726
from etils import epath
2827
from flax.training import train_state
28+
from flax.traverse_util import flatten_dict, unflatten_dict
2929
import orbax
3030
import orbax.checkpoint as ocp
3131
from orbax.checkpoint.logging import AbstractLogger
@@ -34,6 +34,7 @@
3434
STABLE_DIFFUSION_CHECKPOINT = "STABLE_DIFFUSION_CHECKPOINT"
3535
STABLE_DIFFUSION_XL_CHECKPOINT = "STABLE_DIFUSSION_XL_CHECKPOINT"
3636
FLUX_CHECKPOINT = "FLUX_CHECKPOINT"
37+
WAN_CHECKPOINT = "WAN_CHECKPOINT"
3738

3839

3940
def create_orbax_checkpoint_manager(
@@ -59,6 +60,8 @@ def create_orbax_checkpoint_manager(
5960

6061
if checkpoint_type == FLUX_CHECKPOINT:
6162
item_names = ("flux_state", "flux_config", "vae_state", "vae_config", "scheduler", "scheduler_config")
63+
elif checkpoint_type == WAN_CHECKPOINT:
64+
item_names = ("wan_state", "wan_config")
6265
else:
6366
item_names = (
6467
"unet_config",
@@ -78,7 +81,7 @@ def create_orbax_checkpoint_manager(
7881
if dataset_type == "grain":
7982
item_names += ("iter",)
8083

81-
print("item_names: ", item_names)
84+
max_logging.log(f"item_names: {item_names}")
8285

8386
mngr = CheckpointManager(
8487
p,
@@ -133,6 +136,7 @@ def load_params_from_path(
133136
unboxed_abstract_params,
134137
checkpoint_item: str,
135138
step: Optional[int] = None,
139+
checkpoint_item_config: Optional[str] = None,
136140
):
137141
ckptr = ocp.PyTreeCheckpointer()
138142

src/maxdiffusion/checkpointing/wan_checkpointer.py

Lines changed: 123 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,15 @@
1515
"""
1616

1717
from abc import ABC
18+
import json
19+
20+
import jax
21+
import numpy as np
1822
from maxdiffusion.checkpointing.checkpointing_utils import (create_orbax_checkpoint_manager)
1923
from ..pipelines.wan.wan_pipeline import WanPipeline
2024
from .. import max_logging, max_utils
25+
import orbax.checkpoint as ocp
26+
from etils import epath
2127

2228
WAN_CHECKPOINT = "WAN_CHECKPOINT"
2329

@@ -28,7 +34,7 @@ def __init__(self, config, checkpoint_type):
2834
self.config = config
2935
self.checkpoint_type = checkpoint_type
3036

31-
self.checkpoint_manager = create_orbax_checkpoint_manager(
37+
self.checkpoint_manager: ocp.CheckpointManager = create_orbax_checkpoint_manager(
3238
self.config.checkpoint_dir,
3339
enable_checkpointing=True,
3440
save_interval_steps=1,
@@ -44,22 +50,134 @@ def _create_optimizer(self, model, config, learning_rate):
4450
return tx, learning_rate_scheduler
4551

4652
def load_wan_configs_from_orbax(self, step):
47-
max_logging.log("Restoring stable diffusion configs")
4853
if step is None:
4954
step = self.checkpoint_manager.latest_step()
55+
max_logging.log(f"Latest WAN checkpoint step: {step}")
5056
if step is None:
5157
return None
58+
max_logging.log(f"Loading WAN checkpoint from step {step}")
59+
metadatas = self.checkpoint_manager.item_metadata(step)
60+
61+
transformer_metadata = metadatas.wan_state
62+
abstract_tree_structure_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, transformer_metadata)
63+
params_restore = ocp.args.PyTreeRestore(
64+
restore_args=jax.tree.map(
65+
lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
66+
abstract_tree_structure_params,
67+
)
68+
)
69+
70+
max_logging.log("Restoring WAN checkpoint")
71+
restored_checkpoint = self.checkpoint_manager.restore(
72+
directory=epath.Path(self.config.checkpoint_dir),
73+
step=step,
74+
args=ocp.args.Composite(
75+
wan_state=params_restore,
76+
# wan_state=params_restore_util_way,
77+
wan_config=ocp.args.JsonRestore(),
78+
),
79+
)
80+
return restored_checkpoint
5281

5382
def load_diffusers_checkpoint(self):
5483
pipeline = WanPipeline.from_pretrained(self.config)
5584
return pipeline
5685

5786
def load_checkpoint(self, step=None):
58-
model_configs = self.load_wan_configs_from_orbax(step)
87+
restored_checkpoint = self.load_wan_configs_from_orbax(step)
5988

60-
if model_configs:
61-
raise NotImplementedError("model configs should not exist in orbax")
89+
if restored_checkpoint:
90+
max_logging.log("Loading WAN pipeline from checkpoint")
91+
pipeline = WanPipeline.from_checkpoint(self.config, restored_checkpoint)
6292
else:
93+
max_logging.log("No checkpoint found, loading default pipeline.")
6394
pipeline = self.load_diffusers_checkpoint()
6495

6596
return pipeline
97+
98+
def save_checkpoint(self, train_step, pipeline: WanPipeline, train_states: dict):
99+
"""Saves the training state and model configurations."""
100+
101+
def config_to_json(model_or_config):
102+
return json.loads(model_or_config.to_json_string())
103+
104+
max_logging.log(f"Saving checkpoint for step {train_step}")
105+
items = {
106+
"wan_config": ocp.args.JsonSave(config_to_json(pipeline.transformer)),
107+
}
108+
109+
items["wan_state"] = ocp.args.PyTreeSave(train_states)
110+
111+
# Save the checkpoint
112+
self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))
113+
max_logging.log(f"Checkpoint for step {train_step} saved.")
114+
115+
116+
def save_checkpoint_orig(self, train_step, pipeline: WanPipeline, train_states: dict):
117+
"""Saves the training state and model configurations."""
118+
119+
def config_to_json(model_or_config):
120+
"""
121+
only save the config that is needed and can be serialized to JSON.
122+
"""
123+
if not hasattr(model_or_config, "config"):
124+
return None
125+
source_config = dict(model_or_config.config)
126+
127+
# 1. configs that can be serialized to JSON
128+
SAFE_KEYS = [
129+
"_class_name",
130+
"_diffusers_version",
131+
"model_type",
132+
"patch_size",
133+
"num_attention_heads",
134+
"attention_head_dim",
135+
"in_channels",
136+
"out_channels",
137+
"text_dim",
138+
"freq_dim",
139+
"ffn_dim",
140+
"num_layers",
141+
"cross_attn_norm",
142+
"qk_norm",
143+
"eps",
144+
"image_dim",
145+
"added_kv_proj_dim",
146+
"rope_max_seq_len",
147+
"pos_embed_seq_len",
148+
"flash_min_seq_length",
149+
"flash_block_sizes",
150+
"attention",
151+
"_use_default_values",
152+
]
153+
154+
# 2. save the config that are in the SAFE_KEYS list
155+
clean_config = {}
156+
for key in SAFE_KEYS:
157+
if key in source_config:
158+
clean_config[key] = source_config[key]
159+
160+
# 3. deal with special data type and precision
161+
if "dtype" in source_config and hasattr(source_config["dtype"], "name"):
162+
clean_config["dtype"] = source_config["dtype"].name # e.g 'bfloat16'
163+
164+
if "weights_dtype" in source_config and hasattr(source_config["weights_dtype"], "name"):
165+
clean_config["weights_dtype"] = source_config["weights_dtype"].name
166+
167+
if "precision" in source_config and isinstance(source_config["precision"]):
168+
clean_config["precision"] = source_config["precision"].name # e.g. 'HIGHEST'
169+
170+
return clean_config
171+
172+
items_to_save = {
173+
"transformer_config": ocp.args.JsonSave(config_to_json(pipeline.transformer)),
174+
}
175+
176+
items_to_save["transformer_states"] = ocp.args.PyTreeSave(train_states)
177+
178+
# Create CompositeArgs for Orbax
179+
save_args = ocp.args.Composite(**items_to_save)
180+
181+
# Save the checkpoint
182+
self.checkpoint_manager.save(train_step, args=save_args)
183+
max_logging.log(f"Checkpoint for step {train_step} saved.")

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,7 @@ remat_policy: "NONE"
198198

199199
# checkpoint every number of samples, -1 means don't checkpoint.
200200
checkpoint_every: -1
201+
checkpoint_dir: ""
201202
# enables one replica to read the ckpt then broadcast to the rest
202203
enable_single_replica_ckpt_restoring: False
203204

src/maxdiffusion/configuration_utils.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,13 @@
2424
from collections import OrderedDict
2525
from pathlib import PosixPath
2626
from typing import Any, Dict, Tuple, Union
27-
27+
from . import max_logging
2828
import numpy as np
2929

3030
from huggingface_hub import create_repo, hf_hub_download
3131
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
3232
from requests import HTTPError
33-
33+
import jax.numpy as jnp
3434
from . import __version__
3535
from .utils import (
3636
DIFFUSERS_CACHE,
@@ -47,6 +47,21 @@
4747

4848
_re_configuration_file = re.compile(r"config\.(.*)\.json")
4949

50+
class CustomEncoder(json.JSONEncoder):
51+
"""
52+
Custom JSON encoder to handle non-serializable types like JAX/Numpy dtypes.
53+
"""
54+
def default(self, o):
55+
# This will catch the `dtype[bfloat16]` object and convert it to the string "bfloat16"
56+
if isinstance(o, type(jnp.dtype('bfloat16'))):
57+
return str(o)
58+
# Add fallbacks for other numpy types if needed
59+
if isinstance(o, np.integer):
60+
return int(o)
61+
if isinstance(o, np.floating):
62+
return float(o)
63+
# Let the base class default method raise the TypeError for other types
64+
return super().default(o)
5065

5166
class FrozenDict(OrderedDict):
5267

@@ -579,8 +594,25 @@ def to_json_saveable(value):
579594
config_dict.pop("precision", None)
580595
config_dict.pop("weights_dtype", None)
581596
config_dict.pop("quant", None)
597+
keys_to_remove = []
598+
for key, value in config_dict.items():
599+
# Check the type of the value by its class name to avoid import issues
600+
if type(value).__name__ == 'Rngs':
601+
keys_to_remove.append(key)
602+
603+
if keys_to_remove:
604+
max_logging.log(f"Skipping non-serializable config keys: {keys_to_remove}")
605+
for key in keys_to_remove:
606+
config_dict.pop(key)
607+
608+
try:
609+
610+
json_str = json.dumps(config_dict, indent=2, sort_keys=True, cls=CustomEncoder)
611+
except Exception as e:
612+
max_logging.log(f"Error serializing config to JSON: {e}")
613+
raise e
582614

583-
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
615+
return json_str + "\n"
584616

585617
def to_json_file(self, json_file_path: Union[str, os.PathLike]):
586618
"""

src/maxdiffusion/generate_wan.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@
2525

2626
def run(config, pipeline=None, filename_prefix=""):
2727
print("seed: ", config.seed)
28+
from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer
29+
30+
checkpoint_loader = WanCheckpointer(config, "WAN_CHECKPOINT")
31+
pipeline = checkpoint_loader.load_checkpoint()
2832
if pipeline is None:
2933
pipeline = WanPipeline.from_pretrained(config)
3034
s0 = time.perf_counter()

0 commit comments

Comments (0)