Skip to content

Commit 222ddcb

Browse files
ninatumartinarroyo
andcommitted
Fix: Overhaul WAN checkpointers for robust multi-host restoration
This commit resolves several interrelated checkpointing issues by updating how Orbax handles metadata, sharding, and PyTree restoration. Key changes: * Add explicit `item_handlers`: Defined specific handlers (`JsonCheckpointHandler` for configs, `StandardCheckpointHandler` for states) in `CheckpointManager`. This ensures metadata is restored correctly, resolving known Orbax limitations (reference: google/orbax#986). * Bypass mesh validation during restore: Replaced `ocp.utils.to_shape_dtype_struct` with manual `jax.ShapeDtypeStruct` construction in `add_sharding_to_struct`. This makes restoration topology-agnostic, preventing `ValueError` when the current device mesh has fewer devices than the saved checkpoint's topology (e.g., restoring 32-device metadata on 4 devices). * Migrate to Standard API: Upgraded all WAN checkpointers from the `PyTreeSave`/`PyTreeRestore` APIs to `StandardSave`/`StandardRestore` to align with `item_handlers` defined in CheckpointManager. Co-authored-by: martinarroyo <martinarroyo@google.com>
1 parent 384d211 commit 222ddcb

6 files changed

Lines changed: 93 additions & 80 deletions

File tree

src/maxdiffusion/checkpointing/checkpointing_utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,17 @@ def create_orbax_checkpoint_manager(
5858
max_logging.log(f"checkpoint dir: {checkpoint_dir}")
5959
p = epath.Path(checkpoint_dir)
6060

61+
item_handlers = None
6162
if checkpoint_type == FLUX_CHECKPOINT:
6263
item_names = ("flux_state", "flux_config", "vae_state", "vae_config", "scheduler", "scheduler_config")
6364
elif checkpoint_type == WAN_CHECKPOINT:
6465
item_names = ("low_noise_transformer_state", "high_noise_transformer_state", "wan_state", "wan_config")
66+
item_handlers = {
67+
"wan_config": ocp.JsonCheckpointHandler(),
68+
"wan_state": ocp.StandardCheckpointHandler(),
69+
"low_noise_transformer_state": ocp.StandardCheckpointHandler(),
70+
"high_noise_transformer_state": ocp.StandardCheckpointHandler(),
71+
}
6572
else:
6673
item_names = (
6774
"unet_config",
@@ -89,6 +96,7 @@ def create_orbax_checkpoint_manager(
8996
options=CheckpointManagerOptions(
9097
create=True, save_interval_steps=save_interval_steps, enable_async_checkpointing=use_async
9198
),
99+
item_handlers=item_handlers,
92100
logger=orbax_logger,
93101
)
94102

src/maxdiffusion/checkpointing/wan_checkpointer_2_1.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
import json
1818
from typing import Optional, Tuple
19-
from etils import epath
2019
import jax
2120
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
2221
from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer
@@ -44,31 +43,26 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic
4443
metadatas = self.checkpoint_manager.item_metadata(step)
4544
state = metadatas.wan_state
4645

46+
# Manually constructs jax.ShapeDtypeStruct with a specific sharding.
47+
# This avoids device mesh validation (as in ocp.utils.to_shape_dtype_struct)
48+
# allowing for sharding with a different mesh than the one used during
49+
# saving.
4750
def add_sharding_to_struct(leaf_struct, sharding):
48-
struct = ocp.utils.to_shape_dtype_struct(leaf_struct)
49-
if hasattr(struct, "shape") and hasattr(struct, "dtype"):
50-
return jax.ShapeDtypeStruct(shape=struct.shape, dtype=struct.dtype, sharding=sharding)
51-
return struct
51+
if hasattr(leaf_struct, "shape") and hasattr(leaf_struct, "dtype"):
52+
return jax.ShapeDtypeStruct(shape=leaf_struct.shape, dtype=leaf_struct.dtype, sharding=sharding)
53+
return leaf_struct
5254

5355
target_shardings = jax.tree_util.tree_map(lambda x: replicated_sharding, state)
5456

5557
with mesh:
5658
abstract_train_state_with_sharding = jax.tree_util.tree_map(add_sharding_to_struct, state, target_shardings)
5759

58-
params_restore = ocp.args.PyTreeRestore(
59-
restore_args=jax.tree.map(
60-
lambda _: ocp.RestoreArgs(restore_type=jax.Array),
61-
abstract_train_state_with_sharding,
62-
)
63-
)
64-
6560
max_logging.log("Restoring WAN checkpoint")
6661
restored_checkpoint = self.checkpoint_manager.restore(
67-
directory=epath.Path(self.config.checkpoint_dir),
6862
step=step,
6963
args=ocp.args.Composite(
70-
wan_state=params_restore,
7164
wan_config=ocp.args.JsonRestore(),
65+
wan_state=ocp.args.StandardRestore(abstract_train_state_with_sharding),
7266
),
7367
)
7468
max_logging.log(f"restored checkpoint {restored_checkpoint.keys()}")
@@ -106,7 +100,7 @@ def config_to_json(model_or_config):
106100
"wan_config": ocp.args.JsonSave(config_to_json(pipeline.transformer)),
107101
}
108102

109-
items["wan_state"] = ocp.args.PyTreeSave(train_states)
103+
items["wan_state"] = ocp.args.StandardSave(train_states)
110104

111105
# Save the checkpoint
112106
self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))

src/maxdiffusion/checkpointing/wan_checkpointer_2_2.py

Lines changed: 30 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@
1717
import json
1818
import jax
1919
import numpy as np
20+
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
2021
from typing import Optional, Tuple
2122
from ..pipelines.wan.wan_pipeline_2_2 import WanPipeline2_2
2223
from .. import max_logging
2324
import orbax.checkpoint as ocp
24-
from etils import epath
2525
from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer
2626

2727

@@ -35,39 +35,45 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic
3535
max_logging.log("No WAN checkpoint found.")
3636
return None, None
3737
max_logging.log(f"Loading WAN checkpoint from step {step}")
38+
39+
cpu_devices = np.array(jax.devices(backend="cpu"))
40+
mesh = Mesh(cpu_devices, axis_names=("data",))
41+
replicated_sharding = NamedSharding(mesh, P())
42+
3843
metadatas = self.checkpoint_manager.item_metadata(step)
3944

4045
# Handle low_noise_transformer
4146
low_noise_transformer_metadata = metadatas.low_noise_transformer_state
42-
abstract_tree_structure_low_params = jax.tree_util.tree_map(
43-
ocp.utils.to_shape_dtype_struct, low_noise_transformer_metadata
44-
)
45-
low_params_restore = ocp.args.PyTreeRestore(
46-
restore_args=jax.tree.map(
47-
lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
48-
abstract_tree_structure_low_params,
49-
)
50-
)
47+
48+
# Manually constructs jax.ShapeDtypeStruct with a specific sharding.
49+
# This avoids device mesh validation (as in ocp.utils.to_shape_dtype_struct)
50+
# allowing for sharding with a different mesh than the one used during
51+
# saving.
52+
def add_sharding_to_struct(leaf_struct, sharding):
53+
if hasattr(leaf_struct, "shape") and hasattr(leaf_struct, "dtype"):
54+
return jax.ShapeDtypeStruct(shape=leaf_struct.shape, dtype=leaf_struct.dtype, sharding=sharding)
55+
return leaf_struct
56+
57+
target_shardings = jax.tree_util.tree_map(lambda x: replicated_sharding, low_noise_transformer_metadata)
58+
with mesh:
59+
abstract_tree_structure_low_params = jax.tree_util.tree_map(
60+
add_sharding_to_struct, low_noise_transformer_metadata, target_shardings
61+
)
5162

5263
# Handle high_noise_transformer
5364
high_noise_transformer_metadata = metadatas.high_noise_transformer_state
54-
abstract_tree_structure_high_params = jax.tree_util.tree_map(
55-
ocp.utils.to_shape_dtype_struct, high_noise_transformer_metadata
56-
)
57-
high_params_restore = ocp.args.PyTreeRestore(
58-
restore_args=jax.tree.map(
59-
lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
60-
abstract_tree_structure_high_params,
61-
)
62-
)
65+
target_shardings = jax.tree_util.tree_map(lambda x: replicated_sharding, high_noise_transformer_metadata)
66+
with mesh:
67+
abstract_tree_structure_high_params = jax.tree_util.tree_map(
68+
add_sharding_to_struct, high_noise_transformer_metadata, target_shardings
69+
)
6370

6471
max_logging.log("Restoring WAN 2.2 checkpoint")
6572
restored_checkpoint = self.checkpoint_manager.restore(
66-
directory=epath.Path(self.config.checkpoint_dir),
6773
step=step,
6874
args=ocp.args.Composite(
69-
low_noise_transformer_state=low_params_restore,
70-
high_noise_transformer_state=high_params_restore,
75+
low_noise_transformer_state=ocp.args.StandardRestore(abstract_tree_structure_low_params),
76+
high_noise_transformer_state=ocp.args.StandardRestore(abstract_tree_structure_high_params),
7177
wan_config=ocp.args.JsonRestore(),
7278
),
7379
)
@@ -119,8 +125,8 @@ def config_to_json(model_or_config):
119125
"wan_config": ocp.args.JsonSave(config_to_json(pipeline.low_noise_transformer)),
120126
}
121127

122-
items["low_noise_transformer_state"] = ocp.args.PyTreeSave(train_states["low_noise_transformer"])
123-
items["high_noise_transformer_state"] = ocp.args.PyTreeSave(train_states["high_noise_transformer"])
128+
items["low_noise_transformer_state"] = ocp.args.StandardSave(train_states["low_noise_transformer"])
129+
items["high_noise_transformer_state"] = ocp.args.StandardSave(train_states["high_noise_transformer"])
124130

125131
# Save the checkpoint
126132
self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))

src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p1.py

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -44,31 +44,26 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic
4444
metadatas = self.checkpoint_manager.item_metadata(step)
4545
state = metadatas.wan_state
4646

47+
# Manually constructs jax.ShapeDtypeStruct with a specific sharding.
48+
# This avoids device mesh validation (as in ocp.utils.to_shape_dtype_struct)
49+
# allowing for sharding with a different mesh than the one used during
50+
# saving.
4751
def add_sharding_to_struct(leaf_struct, sharding):
48-
struct = ocp.utils.to_shape_dtype_struct(leaf_struct)
49-
if hasattr(struct, "shape") and hasattr(struct, "dtype"):
50-
return jax.ShapeDtypeStruct(shape=struct.shape, dtype=struct.dtype, sharding=sharding)
51-
return struct
52+
if hasattr(leaf_struct, "shape") and hasattr(leaf_struct, "dtype"):
53+
return jax.ShapeDtypeStruct(shape=leaf_struct.shape, dtype=leaf_struct.dtype, sharding=sharding)
54+
return leaf_struct
5255

5356
target_shardings = jax.tree_util.tree_map(lambda x: replicated_sharding, state)
5457

5558
with mesh:
5659
abstract_train_state_with_sharding = jax.tree_util.tree_map(add_sharding_to_struct, state, target_shardings)
5760

58-
params_restore = ocp.args.PyTreeRestore(
59-
restore_args=jax.tree.map(
60-
lambda _: ocp.RestoreArgs(restore_type=jax.Array),
61-
abstract_train_state_with_sharding,
62-
)
63-
)
64-
6561
max_logging.log("Restoring WAN checkpoint")
6662
restored_checkpoint = self.checkpoint_manager.restore(
67-
directory=epath.Path(self.config.checkpoint_dir),
6863
step=step,
6964
args=ocp.args.Composite(
70-
wan_state=params_restore,
7165
wan_config=ocp.args.JsonRestore(),
66+
wan_state=ocp.args.StandardRestore(abstract_train_state_with_sharding),
7267
),
7368
)
7469
max_logging.log(f"restored checkpoint {restored_checkpoint.keys()}")
@@ -106,7 +101,7 @@ def config_to_json(model_or_config):
106101
"wan_config": ocp.args.JsonSave(config_to_json(pipeline.transformer)),
107102
}
108103

109-
items["wan_state"] = ocp.args.PyTreeSave(train_states)
104+
items["wan_state"] = ocp.args.StandardSave(train_states)
110105

111106
# Save the checkpoint
112107
self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))

src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p2.py

Lines changed: 30 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import json
1818
import jax
1919
import numpy as np
20+
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
2021
from typing import Optional, Tuple
2122
from ..pipelines.wan.wan_pipeline_i2v_2p2 import WanPipelineI2V_2_2
2223
from .. import max_logging
@@ -35,39 +36,45 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic
3536
max_logging.log("No WAN checkpoint found.")
3637
return None, None
3738
max_logging.log(f"Loading WAN checkpoint from step {step}")
39+
40+
cpu_devices = np.array(jax.devices(backend="cpu"))
41+
mesh = Mesh(cpu_devices, axis_names=("data",))
42+
replicated_sharding = NamedSharding(mesh, P())
43+
3844
metadatas = self.checkpoint_manager.item_metadata(step)
3945

4046
# Handle low_noise_transformer
4147
low_noise_transformer_metadata = metadatas.low_noise_transformer_state
42-
abstract_tree_structure_low_params = jax.tree_util.tree_map(
43-
ocp.utils.to_shape_dtype_struct, low_noise_transformer_metadata
44-
)
45-
low_params_restore = ocp.args.PyTreeRestore(
46-
restore_args=jax.tree.map(
47-
lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
48-
abstract_tree_structure_low_params,
49-
)
50-
)
48+
49+
# Manually constructs jax.ShapeDtypeStruct with a specific sharding.
50+
# This avoids device mesh validation (as in ocp.utils.to_shape_dtype_struct)
51+
# allowing for sharding with a different mesh than the one used during
52+
# saving.
53+
def add_sharding_to_struct(leaf_struct, sharding):
54+
if hasattr(leaf_struct, "shape") and hasattr(leaf_struct, "dtype"):
55+
return jax.ShapeDtypeStruct(shape=leaf_struct.shape, dtype=leaf_struct.dtype, sharding=sharding)
56+
return leaf_struct
57+
58+
target_shardings = jax.tree_util.tree_map(lambda x: replicated_sharding, low_noise_transformer_metadata)
59+
with mesh:
60+
abstract_tree_structure_low_params = jax.tree_util.tree_map(
61+
add_sharding_to_struct, low_noise_transformer_metadata, target_shardings
62+
)
5163

5264
# Handle high_noise_transformer
5365
high_noise_transformer_metadata = metadatas.high_noise_transformer_state
54-
abstract_tree_structure_high_params = jax.tree_util.tree_map(
55-
ocp.utils.to_shape_dtype_struct, high_noise_transformer_metadata
56-
)
57-
high_params_restore = ocp.args.PyTreeRestore(
58-
restore_args=jax.tree.map(
59-
lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
60-
abstract_tree_structure_high_params,
61-
)
62-
)
66+
target_shardings = jax.tree_util.tree_map(lambda x: replicated_sharding, high_noise_transformer_metadata)
67+
with mesh:
68+
abstract_tree_structure_high_params = jax.tree_util.tree_map(
69+
add_sharding_to_struct, high_noise_transformer_metadata, target_shardings
70+
)
6371

6472
max_logging.log("Restoring WAN 2.2 checkpoint")
6573
restored_checkpoint = self.checkpoint_manager.restore(
66-
directory=epath.Path(self.config.checkpoint_dir),
6774
step=step,
6875
args=ocp.args.Composite(
69-
low_noise_transformer_state=low_params_restore,
70-
high_noise_transformer_state=high_params_restore,
76+
low_noise_transformer_state=ocp.args.StandardRestore(abstract_tree_structure_low_params),
77+
high_noise_transformer_state=ocp.args.StandardRestore(abstract_tree_structure_high_params),
7178
wan_config=ocp.args.JsonRestore(),
7279
),
7380
)
@@ -119,8 +126,8 @@ def config_to_json(model_or_config):
119126
"wan_config": ocp.args.JsonSave(config_to_json(pipeline.low_noise_transformer)),
120127
}
121128

122-
items["low_noise_transformer_state"] = ocp.args.PyTreeSave(train_states["low_noise_transformer"])
123-
items["high_noise_transformer_state"] = ocp.args.PyTreeSave(train_states["high_noise_transformer"])
129+
items["low_noise_transformer_state"] = ocp.args.StandardSave(train_states["low_noise_transformer"])
130+
items["high_noise_transformer_state"] = ocp.args.StandardSave(train_states["high_noise_transformer"])
124131

125132
# Save the checkpoint
126133
self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))

src/maxdiffusion/checkpointing/wan_vace_checkpointer_2_1.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,14 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic
4242
metadatas = self.checkpoint_manager.item_metadata(step)
4343
state = metadatas.wan_state
4444

45+
# Manually constructs jax.ShapeDtypeStruct with a specific sharding.
46+
# This avoids device mesh validation (as in ocp.utils.to_shape_dtype_struct)
47+
# allowing for sharding with a different mesh than the one used during
48+
# saving.
4549
def add_sharding_to_struct(leaf_struct, sharding):
46-
struct = ocp.utils.to_shape_dtype_struct(leaf_struct)
47-
if hasattr(struct, "shape") and hasattr(struct, "dtype"):
48-
return jax.ShapeDtypeStruct(shape=struct.shape, dtype=struct.dtype, sharding=sharding)
49-
return struct
50+
if hasattr(leaf_struct, "shape") and hasattr(leaf_struct, "dtype"):
51+
return jax.ShapeDtypeStruct(shape=leaf_struct.shape, dtype=leaf_struct.dtype, sharding=sharding)
52+
return leaf_struct
5053

5154
target_shardings = jax.tree_util.tree_map(lambda x: replicated_sharding, state)
5255

0 commit comments

Comments
 (0)