
Commit d3fef93

initial checkpointing
Signed-off-by: Kunjan Patel <kunjan@ucla.edu>
1 parent aad9839 commit d3fef93

7 files changed: 233 additions & 24 deletions


src/maxdiffusion/checkpointing/checkpointing_utils.py
Lines changed: 13 additions & 4 deletions

@@ -17,15 +17,16 @@

 """Create an Orbax CheckpointManager with specified (Async or not) Checkpointer."""

-from typing import Optional, Any
+from typing import Optional, Tuple
 import jax
 import numpy as np
 import os
-
+from jaxtyping import PyTree
 import orbax.checkpoint
 from maxdiffusion import max_logging
 from etils import epath
 from flax.training import train_state
+from flax.traverse_util import flatten_dict, unflatten_dict
 import orbax
 import orbax.checkpoint as ocp
 from orbax.checkpoint.logging import AbstractLogger

@@ -34,6 +35,7 @@
 STABLE_DIFFUSION_CHECKPOINT = "STABLE_DIFFUSION_CHECKPOINT"
 STABLE_DIFFUSION_XL_CHECKPOINT = "STABLE_DIFUSSION_XL_CHECKPOINT"
 FLUX_CHECKPOINT = "FLUX_CHECKPOINT"
+WAN_CHECKPOINT = "WAN_CHECKPOINT"


 def create_orbax_checkpoint_manager(

@@ -59,6 +61,8 @@ def create_orbax_checkpoint_manager(

   if checkpoint_type == FLUX_CHECKPOINT:
     item_names = ("flux_state", "flux_config", "vae_state", "vae_config", "scheduler", "scheduler_config")
+  elif checkpoint_type == WAN_CHECKPOINT:
+    item_names = ("wan_state", "wan_config")
   else:
     item_names = (
         "unet_config",

@@ -78,7 +82,7 @@
   if dataset_type == "grain":
     item_names += ("iter",)

-  print("item_names: ", item_names)
+  max_logging.log(f"item_names: {item_names}")

   mngr = CheckpointManager(
       p,

@@ -133,6 +137,7 @@ def load_params_from_path(
     unboxed_abstract_params,
     checkpoint_item: str,
     step: Optional[int] = None,
+    checkpoint_item_config: Optional[str] = None
 ):
   ckptr = ocp.PyTreeCheckpointer()

@@ -148,7 +153,11 @@

   restore_args = ocp.checkpoint_utils.construct_restore_args(unboxed_abstract_params)
   restored = ckptr.restore(
-      ckpt_path, item={"params": unboxed_abstract_params}, transforms={}, restore_args={"params": restore_args}
+      ckpt_path,
+      item={"params": unboxed_abstract_params},
+      transforms={},
+      restore_args={
+          "params": restore_args}
   )
   return restored["params"]
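
The restore path above (load_params_from_path) follows a common Orbax pattern: describe the expected pytree abstractly, derive matching restore args, then restore into that structure. A minimal, runnable sketch of the same pattern; the temp directory and toy params are hypothetical stand-ins, not part of this commit:

import tempfile
import jax
import numpy as np
import orbax.checkpoint as ocp

# Save a toy params pytree under a "params" item, as the loader expects.
ckpt_path = tempfile.mkdtemp() + "/params_demo"
params = {"dense": {"kernel": np.zeros((4, 4), np.float32)}}
ckptr = ocp.PyTreeCheckpointer()
ckptr.save(ckpt_path, {"params": params})

# Describe the expected structure abstractly, then restore into it.
abstract_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, params)
restore_args = ocp.checkpoint_utils.construct_restore_args(abstract_params)
restored = ckptr.restore(
    ckpt_path,
    item={"params": abstract_params},
    transforms={},
    restore_args={"params": restore_args},
)
print(jax.tree_util.tree_map(lambda x: x.shape, restored["params"]))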

src/maxdiffusion/checkpointing/wan_checkpointer.py
Lines changed: 111 additions & 5 deletions

@@ -15,9 +15,14 @@
 """

 from abc import ABC
-from maxdiffusion.checkpointing.checkpointing_utils import (create_orbax_checkpoint_manager)
+import json
+
+import jax
+import numpy as np
+from maxdiffusion.checkpointing.checkpointing_utils import (create_orbax_checkpoint_manager, load_params_from_path)
 from ..pipelines.wan.wan_pipeline import WanPipeline
 from .. import max_logging, max_utils
+import orbax.checkpoint as ocp

 WAN_CHECKPOINT = "WAN_CHECKPOINT"

@@ -44,22 +49,123 @@ def _create_optimizer(self, model, config, learning_rate):
     return tx, learning_rate_scheduler

   def load_wan_configs_from_orbax(self, step):
-    max_logging.log("Restoring stable diffusion configs")
     if step is None:
       step = self.checkpoint_manager.latest_step()
+      max_logging.log(f"Latest WAN checkpoint step: {step}")
     if step is None:
       return None
+    max_logging.log(f"Loading WAN checkpoint from step {step}")
+    metadatas = self.checkpoint_manager.item_metadata(step)
+
+    transformer_metadata = metadatas.wan_state
+    abstract_tree_structure_params = jax.tree_util.tree_map(
+        ocp.utils.to_shape_dtype_struct, transformer_metadata
+    )
+    params_restore = ocp.args.PyTreeRestore(
+        restore_args=jax.tree.map(
+            lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
+            abstract_tree_structure_params,
+        )
+    )
+
+    params_restore_util_way = load_params_from_path(
+        self.config,
+        self.checkpoint_manager,
+        abstract_tree_structure_params,
+        "wan_state",
+        step
+    )
+
+    max_logging.log("Restoring WAN checkpoint")
+    restored_checkpoint = self.checkpoint_manager.restore(
+        step,
+        args=ocp.args.Composite(
+            wan_state=params_restore,
+            # wan_state=params_restore_util_way,
+            wan_config=ocp.args.JsonRestore(),
+        ),
+    )
+    return restored_checkpoint

   def load_diffusers_checkpoint(self):
     pipeline = WanPipeline.from_pretrained(self.config)
     return pipeline

   def load_checkpoint(self, step=None):
-    model_configs = self.load_wan_configs_from_orbax(step)
+    restored_checkpoint = self.load_wan_configs_from_orbax(step)

-    if model_configs:
-      raise NotImplementedError("model configs should not exist in orbax")
+    if restored_checkpoint:
+      max_logging.log("Loading WAN pipeline from checkpoint")
+      pipeline = WanPipeline.from_checkpoint(self.config, restored_checkpoint)
     else:
+      max_logging.log("No checkpoint found, loading default pipeline.")
       pipeline = self.load_diffusers_checkpoint()

     return pipeline
+
+  def save_checkpoint(self, train_step, pipeline: WanPipeline, train_states: dict):
+    """Saves the training state and model configurations."""
+    def config_to_json(model_or_config):
+      return json.loads(model_or_config.to_json_string())
+    max_logging.log(f"Saving checkpoint for step {train_step}")
+    items = {
+        "wan_config": ocp.args.JsonSave(config_to_json(pipeline.transformer)),
+    }
+
+    items["wan_state"] = ocp.args.PyTreeSave(train_states)
+
+    # Save the checkpoint
+    self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))
+    max_logging.log(f"Checkpoint for step {train_step} saved.")
+
+  def save_checkpoint_orig(self, train_step, pipeline: WanPipeline, train_states: dict):
+    """Saves the training state and model configurations."""
+    def config_to_json(model_or_config):
+      """
+      Only save the config keys that are needed and can be serialized to JSON.
+      """
+      if not hasattr(model_or_config, "config"):
+        return None
+      source_config = dict(model_or_config.config)
+
+      # 1. configs that can be serialized to JSON
+      SAFE_KEYS = [
+          '_class_name', '_diffusers_version', 'model_type', 'patch_size',
+          'num_attention_heads', 'attention_head_dim', 'in_channels',
+          'out_channels', 'text_dim', 'freq_dim', 'ffn_dim', 'num_layers',
+          'cross_attn_norm', 'qk_norm', 'eps', 'image_dim',
+          'added_kv_proj_dim', 'rope_max_seq_len', 'pos_embed_seq_len',
+          'flash_min_seq_length', 'flash_block_sizes', 'attention',
+          '_use_default_values'
+      ]
+
+      # 2. keep only the keys that are in the SAFE_KEYS list
+      clean_config = {}
+      for key in SAFE_KEYS:
+        if key in source_config:
+          clean_config[key] = source_config[key]
+
+      # 3. deal with special data types and precision
+      if 'dtype' in source_config and hasattr(source_config['dtype'], 'name'):
+        clean_config['dtype'] = source_config['dtype'].name  # e.g. 'bfloat16'

+      if 'weights_dtype' in source_config and hasattr(source_config['weights_dtype'], 'name'):
+        clean_config['weights_dtype'] = source_config['weights_dtype'].name
+
+      if 'precision' in source_config and isinstance(source_config['precision'], jax.lax.Precision):
+        clean_config['precision'] = source_config['precision'].name  # e.g. 'HIGHEST'
+
+      return clean_config
+
+    items_to_save = {
+        "transformer_config": ocp.args.JsonSave(config_to_json(pipeline.transformer)),
+    }
+
+    items_to_save["transformer_states"] = ocp.args.PyTreeSave(train_states)
+
+    # Create CompositeArgs for Orbax
+    save_args = ocp.args.Composite(**items_to_save)
+
+    # Save the checkpoint
+    self.checkpoint_manager.save(train_step, args=save_args)
+    max_logging.log(f"Checkpoint for step {train_step} saved.")

src/maxdiffusion/configs/base_wan_14b.yml
Lines changed: 1 addition & 0 deletions

@@ -198,6 +198,7 @@ remat_policy: "NONE"

 # checkpoint every number of samples, -1 means don't checkpoint.
 checkpoint_every: -1
+checkpoint_dir: "/mnt/disks/kunjanp-dev/output-dir/test-wan-training-new/checkpoints"
 # enables one replica to read the ckpt then broadcast to the rest
 enable_single_replica_ckpt_restoring: False

src/maxdiffusion/configuration_utils.py
Lines changed: 43 additions & 4 deletions

@@ -30,7 +30,8 @@
 from huggingface_hub import create_repo, hf_hub_download
 from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
 from requests import HTTPError
-
+from maxdiffusion import max_logging
+import jax.numpy as jnp
 from . import __version__
 from .utils import (
     DIFFUSERS_CACHE,

@@ -47,7 +48,22 @@

 _re_configuration_file = re.compile(r"config\.(.*)\.json")

-
+class CustomEncoder(json.JSONEncoder):
+  """
+  Custom JSON encoder to handle non-serializable types like JAX/NumPy dtypes.
+  """
+  def default(self, o):
+    # This will catch the `dtype[bfloat16]` object and convert it to the string "bfloat16".
+    if isinstance(o, type(jnp.dtype('bfloat16'))):
+      return str(o)
+    # Add fallbacks for other numpy types if needed.
+    if isinstance(o, np.integer):
+      return int(o)
+    if isinstance(o, np.floating):
+      return float(o)
+    # Let the base class default method raise the TypeError for other types.
+    return super().default(o)
+
 class FrozenDict(OrderedDict):

   def __init__(self, *args, **kwargs):

@@ -579,8 +595,31 @@ def to_json_saveable(value):
     config_dict.pop("precision", None)
     config_dict.pop("weights_dtype", None)
     config_dict.pop("quant", None)
-
-    return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
+    keys_to_remove = []
+    for key, value in config_dict.items():
+      # Check the type of the value by its class name to avoid import issues.
+      if type(value).__name__ == 'Rngs':
+        keys_to_remove.append(key)
+
+    if keys_to_remove:
+      max_logging.log(f"Skipping non-serializable config keys: {keys_to_remove}")
+      for key in keys_to_remove:
+        config_dict.pop(key)
+
+    try:
+      json_str = json.dumps(config_dict, indent=2, sort_keys=True, cls=CustomEncoder)
+    except Exception as e:
+      max_logging.log(f"Error serializing config to JSON: {e}")
+      non_serializable_keys = []
+      for key in config_dict.keys():
+        if not isinstance(key, str):
+          non_serializable_keys.append(key)
+      max_logging.log(f"Non-serializable keys: {non_serializable_keys}")
+      raise e
+
+    return json_str + "\n"

   def to_json_file(self, json_file_path: Union[str, os.PathLike]):
     """

src/maxdiffusion/generate_wan.py
Lines changed: 3 additions & 0 deletions

@@ -25,6 +25,9 @@

 def run(config, pipeline=None, filename_prefix=""):
   print("seed: ", config.seed)
+  from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer
+  checkpoint_loader = WanCheckpointer(config, "WAN_CHECKPOINT")
+  pipeline = checkpoint_loader.load_checkpoint()
   if pipeline is None:
     pipeline = WanPipeline.from_pretrained(config)
   s0 = time.perf_counter()
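
With this change, run() always consults the checkpointer before falling back to from_pretrained (note that, as written, the loaded pipeline overrides any pipeline argument passed to run()). A hypothetical minimal driver, assuming maxdiffusion's pyconfig initialization API, which is not shown in this diff:

from maxdiffusion import pyconfig
from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer

pyconfig.initialize(["", "src/maxdiffusion/configs/base_wan_14b.yml"])
config = pyconfig.config
# Loads the latest Orbax checkpoint if one exists, else the default pipeline.
pipeline = WanCheckpointer(config, "WAN_CHECKPOINT").load_checkpoint()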

src/maxdiffusion/pipelines/wan/wan_pipeline.py
Lines changed: 52 additions & 7 deletions

@@ -66,14 +66,17 @@ def _add_sharding_rule(vs: nnx.VariableState, logical_axis_rules) -> nnx.VariableState:


 # For some reason, jitting this function increases the memory significantly, so instead manually move weights to device.
-def create_sharded_logical_transformer(devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters):
+def create_sharded_logical_transformer(devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters, restored_checkpoint=None):

   def create_model(rngs: nnx.Rngs, wan_config: dict):
     wan_transformer = WanModel(**wan_config, rngs=rngs)
     return wan_transformer

   # 1. Load config.
-  wan_config = WanModel.load_config(config.pretrained_model_name_or_path, subfolder="transformer")
+  if restored_checkpoint:
+    wan_config = restored_checkpoint["wan_config"]
+  else:
+    wan_config = WanModel.load_config(config.pretrained_model_name_or_path, subfolder="transformer")
   wan_config["mesh"] = mesh
   wan_config["dtype"] = config.activations_dtype
   wan_config["weights_dtype"] = config.weights_dtype

@@ -99,11 +102,16 @@ def create_model(rngs: nnx.Rngs, wan_config: dict):
   # 4. Load pretrained weights and move them to device using the state shardings from (3) above.
   # This helps with loading sharded weights directly into the accelerators without first copying them
   # all to one device and then distributing them, thus using low HBM memory.
-  params = load_wan_transformer(
-      config.wan_transformer_pretrained_model_name_or_path, params, "cpu", num_layers=wan_config["num_layers"]
-  )
+  if restored_checkpoint:
+    params = restored_checkpoint["wan_state"]
+  else:
+    params = load_wan_transformer(
+        config.wan_transformer_pretrained_model_name_or_path, params, "cpu", num_layers=wan_config["num_layers"]
+    )
   params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params)
   for path, val in flax.traverse_util.flatten_dict(params).items():
+    if restored_checkpoint:
+      path = path[:-1]
     sharding = logical_state_sharding[path].value
     state[path].value = device_put_replicated(val, sharding)
   state = nnx.from_flat_state(state)

@@ -295,9 +303,9 @@ def quantize_transformer(cls, config: HyperParameters, model: WanModel, pipeline
     return quantized_model

   @classmethod
-  def load_transformer(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters):
+  def load_transformer(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters, restored_checkpoint=None):
     with mesh:
-      wan_transformer = create_sharded_logical_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config)
+      wan_transformer = create_sharded_logical_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, restored_checkpoint=restored_checkpoint)
     return wan_transformer

   @classmethod

@@ -309,6 +317,43 @@ def load_scheduler(cls, config):
     )
     return scheduler, scheduler_state

+  @classmethod
+  def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_only=False, load_transformer=True):
+    devices_array = max_utils.create_device_mesh(config)
+    mesh = Mesh(devices_array, config.mesh_axes)
+    rng = jax.random.key(config.seed)
+    rngs = nnx.Rngs(rng)
+    transformer = None
+    tokenizer = None
+    scheduler = None
+    scheduler_state = None
+    text_encoder = None
+    if not vae_only:
+      if load_transformer:
+        with mesh:
+          transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, restored_checkpoint=restored_checkpoint)
+
+      text_encoder = cls.load_text_encoder(config=config)
+      tokenizer = cls.load_tokenizer(config=config)
+
+    scheduler, scheduler_state = cls.load_scheduler(config=config)
+
+    with mesh:
+      wan_vae, vae_cache = cls.load_vae(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config)
+
+    return WanPipeline(
+        tokenizer=tokenizer,
+        text_encoder=text_encoder,
+        transformer=transformer,
+        vae=wan_vae,
+        vae_cache=vae_cache,
+        scheduler=scheduler,
+        scheduler_state=scheduler_state,
+        devices_array=devices_array,
+        mesh=mesh,
+        config=config,
+    )
+
   @classmethod
   def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transformer=True):
     devices_array = max_utils.create_device_mesh(config)
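
One subtlety above is the path = path[:-1] trim: when weights come from an Orbax restore rather than load_wan_transformer, each flattened leaf path evidently carries one extra trailing component relative to the logical sharding table. A toy sketch of that flatten_dict bookkeeping; the trailing "value" key is an assumed layout for illustration, not confirmed by this diff:

from flax.traverse_util import flatten_dict

# The sharding table is keyed by logical parameter paths.
logical_paths = set(flatten_dict({"blocks_0": {"attn": {"kernel": None}}}).keys())

# A restored checkpoint (assumed layout) nests each tensor one level deeper.
restored = {"blocks_0": {"attn": {"kernel": {"value": 1.0}}}}

for path, val in flatten_dict(restored).items():
  path = path[:-1]  # drop the extra trailing component, as in the diff
  assert path in logical_paths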
