Skip to content

Commit 4850947

Browse files
ninatu and martinarroyo committed
Fix: Overhaul WAN checkpointers for robust multi-host restoration
This commit resolves several interrelated checkpointing issues by updating how Orbax handles metadata, sharding, and PyTree restoration.

Key changes:

* Add explicit `item_handlers`: Defined specific handlers (`JsonCheckpointHandler` for configs, `StandardCheckpointHandler` for states) in `CheckpointManager`. This ensures metadata is restored correctly, resolving known Orbax limitations (reference: google/orbax#986).
* Bypass mesh validation during restore: Replaced `ocp.utils.to_shape_dtype_struct` with manual `jax.ShapeDtypeStruct` construction in `add_sharding_to_struct`. This makes restoration topology-agnostic, preventing `ValueError` when the current device mesh has fewer devices than the saved checkpoint's topology (e.g., restoring 32-device metadata on 4 devices).
* Migrate to Standard API: Upgraded all WAN checkpointers from the `PyTreeSave`/`PyTreeRestore` APIs to `StandardSave`/`StandardRestore` to align with `item_handlers` defined in CheckpointManager.

Co-authored-by: martinarroyo <martinarroyo@google.com>
1 parent 384d211 commit 4850947

7 files changed

Lines changed: 96 additions & 106 deletions

src/maxdiffusion/checkpointing/checkpointing_utils.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,12 @@
1717

1818
"""Create an Orbax CheckpointManager with specified (Async or not) Checkpointer."""
1919

20-
from typing import Optional, Tuple
20+
from typing import Any, Optional, Tuple
2121
import jax
2222
import numpy as np
2323
import os
2424
import orbax.checkpoint
25+
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
2526
from maxdiffusion import max_logging
2627
from etils import epath
2728
from flax.training import train_state
@@ -58,10 +59,17 @@ def create_orbax_checkpoint_manager(
5859
max_logging.log(f"checkpoint dir: {checkpoint_dir}")
5960
p = epath.Path(checkpoint_dir)
6061

62+
item_handlers = None
6163
if checkpoint_type == FLUX_CHECKPOINT:
6264
item_names = ("flux_state", "flux_config", "vae_state", "vae_config", "scheduler", "scheduler_config")
6365
elif checkpoint_type == WAN_CHECKPOINT:
6466
item_names = ("low_noise_transformer_state", "high_noise_transformer_state", "wan_state", "wan_config")
67+
item_handlers = {
68+
"wan_config": ocp.JsonCheckpointHandler(),
69+
"wan_state": ocp.StandardCheckpointHandler(),
70+
"low_noise_transformer_state": ocp.StandardCheckpointHandler(),
71+
"high_noise_transformer_state": ocp.StandardCheckpointHandler(),
72+
}
6573
else:
6674
item_names = (
6775
"unet_config",
@@ -89,6 +97,7 @@ def create_orbax_checkpoint_manager(
8997
options=CheckpointManagerOptions(
9098
create=True, save_interval_steps=save_interval_steps, enable_async_checkpointing=use_async
9199
),
100+
item_handlers=item_handlers,
92101
logger=orbax_logger,
93102
)
94103

@@ -255,3 +264,38 @@ def map_to_pspec(data):
255264
except:
256265
max_logging.log(f"could not load {checkpoint_item} from orbax")
257266
return None
267+
268+
269+
def get_cpu_mesh_and_sharding() -> Tuple[Mesh, NamedSharding]:
270+
"""Creates a JAX mesh using CPU devices and a fully replicated sharding.
271+
272+
This is useful for checkpointing when the full model state needs to be
273+
loaded onto a single device or when restoring on a different topology.
274+
275+
Returns:
276+
A tuple containing the CPU mesh and the replicated NamedSharding.
277+
"""
278+
cpu_devices = np.array(jax.devices(backend="cpu"))
279+
mesh = Mesh(cpu_devices, axis_names=("data",))
280+
replicated_sharding = NamedSharding(mesh, P())
281+
return mesh, replicated_sharding
282+
283+
284+
def add_sharding_to_struct(leaf_struct: Any, sharding: jax.sharding.Sharding) -> Any:
285+
"""Manually constructs jax.ShapeDtypeStruct with a specific sharding.
286+
287+
This avoids device mesh validation (as in ocp.utils.to_shape_dtype_struct)
288+
allowing for sharding with a different mesh than the one used during
289+
saving.
290+
291+
Args:
292+
leaf_struct: A leaf of a pytree.
293+
sharding: The sharding to apply to the leaf.
294+
295+
Returns:
296+
A jax.ShapeDtypeStruct if leaf_struct has shape and dtype attributes,
297+
otherwise returns leaf_struct.
298+
"""
299+
if hasattr(leaf_struct, "shape") and hasattr(leaf_struct, "dtype"):
300+
return jax.ShapeDtypeStruct(shape=leaf_struct.shape, dtype=leaf_struct.dtype, sharding=sharding)
301+
return leaf_struct

src/maxdiffusion/checkpointing/wan_checkpointer_2_1.py

Lines changed: 4 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616

1717
import json
1818
from typing import Optional, Tuple
19-
from etils import epath
2019
import jax
2120
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
21+
from maxdiffusion.checkpointing.checkpointing_utils import add_sharding_to_struct, get_cpu_mesh_and_sharding
2222
from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer
2323
import numpy as np
2424
import orbax.checkpoint as ocp
@@ -37,38 +37,21 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic
3737
return None, None
3838
max_logging.log(f"Loading WAN checkpoint from step {step}")
3939

40-
cpu_devices = np.array(jax.devices(backend="cpu"))
41-
mesh = Mesh(cpu_devices, axis_names=("data",))
42-
replicated_sharding = NamedSharding(mesh, P())
43-
40+
mesh, replicated_sharding = get_cpu_mesh_and_sharding()
4441
metadatas = self.checkpoint_manager.item_metadata(step)
4542
state = metadatas.wan_state
4643

47-
def add_sharding_to_struct(leaf_struct, sharding):
48-
struct = ocp.utils.to_shape_dtype_struct(leaf_struct)
49-
if hasattr(struct, "shape") and hasattr(struct, "dtype"):
50-
return jax.ShapeDtypeStruct(shape=struct.shape, dtype=struct.dtype, sharding=sharding)
51-
return struct
52-
5344
target_shardings = jax.tree_util.tree_map(lambda x: replicated_sharding, state)
5445

5546
with mesh:
5647
abstract_train_state_with_sharding = jax.tree_util.tree_map(add_sharding_to_struct, state, target_shardings)
5748

58-
params_restore = ocp.args.PyTreeRestore(
59-
restore_args=jax.tree.map(
60-
lambda _: ocp.RestoreArgs(restore_type=jax.Array),
61-
abstract_train_state_with_sharding,
62-
)
63-
)
64-
6549
max_logging.log("Restoring WAN checkpoint")
6650
restored_checkpoint = self.checkpoint_manager.restore(
67-
directory=epath.Path(self.config.checkpoint_dir),
6851
step=step,
6952
args=ocp.args.Composite(
70-
wan_state=params_restore,
7153
wan_config=ocp.args.JsonRestore(),
54+
wan_state=ocp.args.StandardRestore(abstract_train_state_with_sharding),
7255
),
7356
)
7457
max_logging.log(f"restored checkpoint {restored_checkpoint.keys()}")
@@ -106,7 +89,7 @@ def config_to_json(model_or_config):
10689
"wan_config": ocp.args.JsonSave(config_to_json(pipeline.transformer)),
10790
}
10891

109-
items["wan_state"] = ocp.args.PyTreeSave(train_states)
92+
items["wan_state"] = ocp.args.StandardSave(train_states)
11093

11194
# Save the checkpoint
11295
self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))

src/maxdiffusion/checkpointing/wan_checkpointer_2_2.py

Lines changed: 18 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,12 @@
1717
import json
1818
import jax
1919
import numpy as np
20+
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
2021
from typing import Optional, Tuple
2122
from ..pipelines.wan.wan_pipeline_2_2 import WanPipeline2_2
2223
from .. import max_logging
2324
import orbax.checkpoint as ocp
24-
from etils import epath
25+
from maxdiffusion.checkpointing.checkpointing_utils import add_sharding_to_struct, get_cpu_mesh_and_sharding
2526
from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer
2627

2728

@@ -35,39 +36,32 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic
3536
max_logging.log("No WAN checkpoint found.")
3637
return None, None
3738
max_logging.log(f"Loading WAN checkpoint from step {step}")
39+
40+
mesh, replicated_sharding = get_cpu_mesh_and_sharding()
3841
metadatas = self.checkpoint_manager.item_metadata(step)
3942

4043
# Handle low_noise_transformer
4144
low_noise_transformer_metadata = metadatas.low_noise_transformer_state
42-
abstract_tree_structure_low_params = jax.tree_util.tree_map(
43-
ocp.utils.to_shape_dtype_struct, low_noise_transformer_metadata
44-
)
45-
low_params_restore = ocp.args.PyTreeRestore(
46-
restore_args=jax.tree.map(
47-
lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
48-
abstract_tree_structure_low_params,
49-
)
50-
)
45+
target_shardings = jax.tree_util.tree_map(lambda x: replicated_sharding, low_noise_transformer_metadata)
46+
with mesh:
47+
abstract_tree_structure_low_params = jax.tree_util.tree_map(
48+
add_sharding_to_struct, low_noise_transformer_metadata, target_shardings
49+
)
5150

5251
# Handle high_noise_transformer
5352
high_noise_transformer_metadata = metadatas.high_noise_transformer_state
54-
abstract_tree_structure_high_params = jax.tree_util.tree_map(
55-
ocp.utils.to_shape_dtype_struct, high_noise_transformer_metadata
56-
)
57-
high_params_restore = ocp.args.PyTreeRestore(
58-
restore_args=jax.tree.map(
59-
lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
60-
abstract_tree_structure_high_params,
61-
)
62-
)
53+
target_shardings = jax.tree_util.tree_map(lambda x: replicated_sharding, high_noise_transformer_metadata)
54+
with mesh:
55+
abstract_tree_structure_high_params = jax.tree_util.tree_map(
56+
add_sharding_to_struct, high_noise_transformer_metadata, target_shardings
57+
)
6358

6459
max_logging.log("Restoring WAN 2.2 checkpoint")
6560
restored_checkpoint = self.checkpoint_manager.restore(
66-
directory=epath.Path(self.config.checkpoint_dir),
6761
step=step,
6862
args=ocp.args.Composite(
69-
low_noise_transformer_state=low_params_restore,
70-
high_noise_transformer_state=high_params_restore,
63+
low_noise_transformer_state=ocp.args.StandardRestore(abstract_tree_structure_low_params),
64+
high_noise_transformer_state=ocp.args.StandardRestore(abstract_tree_structure_high_params),
7165
wan_config=ocp.args.JsonRestore(),
7266
),
7367
)
@@ -119,8 +113,8 @@ def config_to_json(model_or_config):
119113
"wan_config": ocp.args.JsonSave(config_to_json(pipeline.low_noise_transformer)),
120114
}
121115

122-
items["low_noise_transformer_state"] = ocp.args.PyTreeSave(train_states["low_noise_transformer"])
123-
items["high_noise_transformer_state"] = ocp.args.PyTreeSave(train_states["high_noise_transformer"])
116+
items["low_noise_transformer_state"] = ocp.args.StandardSave(train_states["low_noise_transformer"])
117+
items["high_noise_transformer_state"] = ocp.args.StandardSave(train_states["high_noise_transformer"])
124118

125119
# Save the checkpoint
126120
self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))

src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p1.py

Lines changed: 4 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616

1717
import json
1818
from typing import Optional, Tuple
19-
from etils import epath
2019
import jax
2120
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
21+
from maxdiffusion.checkpointing.checkpointing_utils import add_sharding_to_struct, get_cpu_mesh_and_sharding
2222
from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer
2323
import numpy as np
2424
import orbax.checkpoint as ocp
@@ -37,38 +37,21 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic
3737
return None, None
3838
max_logging.log(f"Loading WAN checkpoint from step {step}")
3939

40-
cpu_devices = np.array(jax.devices(backend="cpu"))
41-
mesh = Mesh(cpu_devices, axis_names=("data",))
42-
replicated_sharding = NamedSharding(mesh, P())
43-
40+
mesh, replicated_sharding = get_cpu_mesh_and_sharding()
4441
metadatas = self.checkpoint_manager.item_metadata(step)
4542
state = metadatas.wan_state
4643

47-
def add_sharding_to_struct(leaf_struct, sharding):
48-
struct = ocp.utils.to_shape_dtype_struct(leaf_struct)
49-
if hasattr(struct, "shape") and hasattr(struct, "dtype"):
50-
return jax.ShapeDtypeStruct(shape=struct.shape, dtype=struct.dtype, sharding=sharding)
51-
return struct
52-
5344
target_shardings = jax.tree_util.tree_map(lambda x: replicated_sharding, state)
5445

5546
with mesh:
5647
abstract_train_state_with_sharding = jax.tree_util.tree_map(add_sharding_to_struct, state, target_shardings)
5748

58-
params_restore = ocp.args.PyTreeRestore(
59-
restore_args=jax.tree.map(
60-
lambda _: ocp.RestoreArgs(restore_type=jax.Array),
61-
abstract_train_state_with_sharding,
62-
)
63-
)
64-
6549
max_logging.log("Restoring WAN checkpoint")
6650
restored_checkpoint = self.checkpoint_manager.restore(
67-
directory=epath.Path(self.config.checkpoint_dir),
6851
step=step,
6952
args=ocp.args.Composite(
70-
wan_state=params_restore,
7153
wan_config=ocp.args.JsonRestore(),
54+
wan_state=ocp.args.StandardRestore(abstract_train_state_with_sharding),
7255
),
7356
)
7457
max_logging.log(f"restored checkpoint {restored_checkpoint.keys()}")
@@ -106,7 +89,7 @@ def config_to_json(model_or_config):
10689
"wan_config": ocp.args.JsonSave(config_to_json(pipeline.transformer)),
10790
}
10891

109-
items["wan_state"] = ocp.args.PyTreeSave(train_states)
92+
items["wan_state"] = ocp.args.StandardSave(train_states)
11093

11194
# Save the checkpoint
11295
self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))

src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p2.py

Lines changed: 18 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,12 @@
1717
import json
1818
import jax
1919
import numpy as np
20+
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
2021
from typing import Optional, Tuple
2122
from ..pipelines.wan.wan_pipeline_i2v_2p2 import WanPipelineI2V_2_2
2223
from .. import max_logging
2324
import orbax.checkpoint as ocp
24-
from etils import epath
25+
from maxdiffusion.checkpointing.checkpointing_utils import add_sharding_to_struct, get_cpu_mesh_and_sharding
2526
from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer
2627

2728

@@ -35,39 +36,32 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic
3536
max_logging.log("No WAN checkpoint found.")
3637
return None, None
3738
max_logging.log(f"Loading WAN checkpoint from step {step}")
39+
40+
mesh, replicated_sharding = get_cpu_mesh_and_sharding()
3841
metadatas = self.checkpoint_manager.item_metadata(step)
3942

4043
# Handle low_noise_transformer
4144
low_noise_transformer_metadata = metadatas.low_noise_transformer_state
42-
abstract_tree_structure_low_params = jax.tree_util.tree_map(
43-
ocp.utils.to_shape_dtype_struct, low_noise_transformer_metadata
44-
)
45-
low_params_restore = ocp.args.PyTreeRestore(
46-
restore_args=jax.tree.map(
47-
lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
48-
abstract_tree_structure_low_params,
49-
)
50-
)
45+
target_shardings = jax.tree_util.tree_map(lambda x: replicated_sharding, low_noise_transformer_metadata)
46+
with mesh:
47+
abstract_tree_structure_low_params = jax.tree_util.tree_map(
48+
add_sharding_to_struct, low_noise_transformer_metadata, target_shardings
49+
)
5150

5251
# Handle high_noise_transformer
5352
high_noise_transformer_metadata = metadatas.high_noise_transformer_state
54-
abstract_tree_structure_high_params = jax.tree_util.tree_map(
55-
ocp.utils.to_shape_dtype_struct, high_noise_transformer_metadata
56-
)
57-
high_params_restore = ocp.args.PyTreeRestore(
58-
restore_args=jax.tree.map(
59-
lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
60-
abstract_tree_structure_high_params,
61-
)
62-
)
53+
target_shardings = jax.tree_util.tree_map(lambda x: replicated_sharding, high_noise_transformer_metadata)
54+
with mesh:
55+
abstract_tree_structure_high_params = jax.tree_util.tree_map(
56+
add_sharding_to_struct, high_noise_transformer_metadata, target_shardings
57+
)
6358

6459
max_logging.log("Restoring WAN 2.2 checkpoint")
6560
restored_checkpoint = self.checkpoint_manager.restore(
66-
directory=epath.Path(self.config.checkpoint_dir),
6761
step=step,
6862
args=ocp.args.Composite(
69-
low_noise_transformer_state=low_params_restore,
70-
high_noise_transformer_state=high_params_restore,
63+
low_noise_transformer_state=ocp.args.StandardRestore(abstract_tree_structure_low_params),
64+
high_noise_transformer_state=ocp.args.StandardRestore(abstract_tree_structure_high_params),
7165
wan_config=ocp.args.JsonRestore(),
7266
),
7367
)
@@ -119,8 +113,8 @@ def config_to_json(model_or_config):
119113
"wan_config": ocp.args.JsonSave(config_to_json(pipeline.low_noise_transformer)),
120114
}
121115

122-
items["low_noise_transformer_state"] = ocp.args.PyTreeSave(train_states["low_noise_transformer"])
123-
items["high_noise_transformer_state"] = ocp.args.PyTreeSave(train_states["high_noise_transformer"])
116+
items["low_noise_transformer_state"] = ocp.args.StandardSave(train_states["low_noise_transformer"])
117+
items["high_noise_transformer_state"] = ocp.args.StandardSave(train_states["high_noise_transformer"])
124118

125119
# Save the checkpoint
126120
self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))

src/maxdiffusion/checkpointing/wan_vace_checkpointer_2_1.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from typing import Optional, Tuple
1818
import jax
1919
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
20+
from maxdiffusion.checkpointing.checkpointing_utils import add_sharding_to_struct, get_cpu_mesh_and_sharding
2021
from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer
2122
import numpy as np
2223
import orbax.checkpoint as ocp
@@ -35,19 +36,10 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic
3536
return None, None
3637
max_logging.log(f"Loading WAN checkpoint from step {step}")
3738

38-
cpu_devices = np.array(jax.devices(backend="cpu"))
39-
mesh = Mesh(cpu_devices, axis_names=("data",))
40-
replicated_sharding = NamedSharding(mesh, P())
41-
39+
mesh, replicated_sharding = get_cpu_mesh_and_sharding()
4240
metadatas = self.checkpoint_manager.item_metadata(step)
4341
state = metadatas.wan_state
4442

45-
def add_sharding_to_struct(leaf_struct, sharding):
46-
struct = ocp.utils.to_shape_dtype_struct(leaf_struct)
47-
if hasattr(struct, "shape") and hasattr(struct, "dtype"):
48-
return jax.ShapeDtypeStruct(shape=struct.shape, dtype=struct.dtype, sharding=sharding)
49-
return struct
50-
5143
target_shardings = jax.tree_util.tree_map(lambda x: replicated_sharding, state)
5244

5345
with mesh:

0 commit comments

Comments
 (0)