Skip to content

Commit 09deaaa

Browse files
ninatu and martinarroyo committed
Fix: Overhaul WAN checkpointers for robust multi-host restoration
This commit resolves several interrelated checkpointing issues by updating how Orbax handles metadata, sharding, and PyTree restoration.

Key changes:

* Add explicit `item_handlers`: Defined specific handlers (`JsonCheckpointHandler` for configs, `StandardCheckpointHandler` for states) in `CheckpointManager`. This ensures metadata is restored correctly, resolving known Orbax limitations (reference: google/orbax#986).
* Bypass mesh validation during restore: Replaced `ocp.utils.to_shape_dtype_struct` with manual `jax.ShapeDtypeStruct` construction in `add_sharding_to_struct`. This makes restoration topology-agnostic, preventing a `ValueError` when the current device mesh has fewer devices than the saved checkpoint's topology (e.g., restoring 32-device metadata on 4 devices).
* Migrate to Standard API: Upgraded all WAN checkpointers from the `PyTreeSave`/`PyTreeRestore` APIs to `StandardSave`/`StandardRestore` to align with the `item_handlers` defined in `CheckpointManager`.

Co-authored-by: martinarroyo <martinarroyo@google.com>
1 parent 384d211 commit 09deaaa

7 files changed

Lines changed: 83 additions & 94 deletions

src/maxdiffusion/checkpointing/checkpointing_utils.py

Lines changed: 29 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@
1717

1818
"""Create an Orbax CheckpointManager with specified (Async or not) Checkpointer."""
1919

20-
from typing import Optional, Tuple
20+
from typing import Any, Optional, Tuple
2121
import jax
2222
import numpy as np
2323
import os
@@ -58,10 +58,17 @@ def create_orbax_checkpoint_manager(
5858
max_logging.log(f"checkpoint dir: {checkpoint_dir}")
5959
p = epath.Path(checkpoint_dir)
6060

61+
item_handlers = None
6162
if checkpoint_type == FLUX_CHECKPOINT:
6263
item_names = ("flux_state", "flux_config", "vae_state", "vae_config", "scheduler", "scheduler_config")
6364
elif checkpoint_type == WAN_CHECKPOINT:
6465
item_names = ("low_noise_transformer_state", "high_noise_transformer_state", "wan_state", "wan_config")
66+
item_handlers = {
67+
"wan_config": ocp.JsonCheckpointHandler(),
68+
"wan_state": ocp.StandardCheckpointHandler(),
69+
"low_noise_transformer_state": ocp.StandardCheckpointHandler(),
70+
"high_noise_transformer_state": ocp.StandardCheckpointHandler(),
71+
}
6572
else:
6673
item_names = (
6774
"unet_config",
@@ -89,6 +96,7 @@ def create_orbax_checkpoint_manager(
8996
options=CheckpointManagerOptions(
9097
create=True, save_interval_steps=save_interval_steps, enable_async_checkpointing=use_async
9198
),
99+
item_handlers=item_handlers,
92100
logger=orbax_logger,
93101
)
94102

@@ -255,3 +263,23 @@ def map_to_pspec(data):
255263
except:
256264
max_logging.log(f"could not load {checkpoint_item} from orbax")
257265
return None
266+
267+
268+
def add_sharding_to_struct(leaf_struct: Any, sharding: jax.sharding.Sharding) -> Any:
269+
"""Manually constructs jax.ShapeDtypeStruct with a specific sharding.
270+
271+
This avoids device mesh validation (as in ocp.utils.to_shape_dtype_struct)
272+
allowing for sharding with a different mesh than the one used during
273+
saving.
274+
275+
Args:
276+
leaf_struct: A leaf of a pytree.
277+
sharding: The sharding to apply to the leaf.
278+
279+
Returns:
280+
A jax.ShapeDtypeStruct if leaf_struct has shape and dtype attributes,
281+
otherwise returns leaf_struct.
282+
"""
283+
if hasattr(leaf_struct, "shape") and hasattr(leaf_struct, "dtype"):
284+
return jax.ShapeDtypeStruct(shape=leaf_struct.shape, dtype=leaf_struct.dtype, sharding=sharding)
285+
return leaf_struct

src/maxdiffusion/checkpointing/wan_checkpointer_2_1.py

Lines changed: 3 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -16,9 +16,9 @@
1616

1717
import json
1818
from typing import Optional, Tuple
19-
from etils import epath
2019
import jax
2120
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
21+
from maxdiffusion.checkpointing.checkpointing_utils import add_sharding_to_struct
2222
from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer
2323
import numpy as np
2424
import orbax.checkpoint as ocp
@@ -44,31 +44,17 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic
4444
metadatas = self.checkpoint_manager.item_metadata(step)
4545
state = metadatas.wan_state
4646

47-
def add_sharding_to_struct(leaf_struct, sharding):
48-
struct = ocp.utils.to_shape_dtype_struct(leaf_struct)
49-
if hasattr(struct, "shape") and hasattr(struct, "dtype"):
50-
return jax.ShapeDtypeStruct(shape=struct.shape, dtype=struct.dtype, sharding=sharding)
51-
return struct
52-
5347
target_shardings = jax.tree_util.tree_map(lambda x: replicated_sharding, state)
5448

5549
with mesh:
5650
abstract_train_state_with_sharding = jax.tree_util.tree_map(add_sharding_to_struct, state, target_shardings)
5751

58-
params_restore = ocp.args.PyTreeRestore(
59-
restore_args=jax.tree.map(
60-
lambda _: ocp.RestoreArgs(restore_type=jax.Array),
61-
abstract_train_state_with_sharding,
62-
)
63-
)
64-
6552
max_logging.log("Restoring WAN checkpoint")
6653
restored_checkpoint = self.checkpoint_manager.restore(
67-
directory=epath.Path(self.config.checkpoint_dir),
6854
step=step,
6955
args=ocp.args.Composite(
70-
wan_state=params_restore,
7156
wan_config=ocp.args.JsonRestore(),
57+
wan_state=ocp.args.StandardRestore(abstract_train_state_with_sharding),
7258
),
7359
)
7460
max_logging.log(f"restored checkpoint {restored_checkpoint.keys()}")
@@ -106,7 +92,7 @@ def config_to_json(model_or_config):
10692
"wan_config": ocp.args.JsonSave(config_to_json(pipeline.transformer)),
10793
}
10894

109-
items["wan_state"] = ocp.args.PyTreeSave(train_states)
95+
items["wan_state"] = ocp.args.StandardSave(train_states)
11096

11197
# Save the checkpoint
11298
self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))

src/maxdiffusion/checkpointing/wan_checkpointer_2_2.py

Lines changed: 21 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -17,11 +17,12 @@
1717
import json
1818
import jax
1919
import numpy as np
20+
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
2021
from typing import Optional, Tuple
2122
from ..pipelines.wan.wan_pipeline_2_2 import WanPipeline2_2
2223
from .. import max_logging
2324
import orbax.checkpoint as ocp
24-
from etils import epath
25+
from maxdiffusion.checkpointing.checkpointing_utils import add_sharding_to_struct
2526
from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer
2627

2728

@@ -35,39 +36,35 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic
3536
max_logging.log("No WAN checkpoint found.")
3637
return None, None
3738
max_logging.log(f"Loading WAN checkpoint from step {step}")
39+
40+
cpu_devices = np.array(jax.devices(backend="cpu"))
41+
mesh = Mesh(cpu_devices, axis_names=("data",))
42+
replicated_sharding = NamedSharding(mesh, P())
43+
3844
metadatas = self.checkpoint_manager.item_metadata(step)
3945

4046
# Handle low_noise_transformer
4147
low_noise_transformer_metadata = metadatas.low_noise_transformer_state
42-
abstract_tree_structure_low_params = jax.tree_util.tree_map(
43-
ocp.utils.to_shape_dtype_struct, low_noise_transformer_metadata
44-
)
45-
low_params_restore = ocp.args.PyTreeRestore(
46-
restore_args=jax.tree.map(
47-
lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
48-
abstract_tree_structure_low_params,
49-
)
50-
)
48+
target_shardings = jax.tree_util.tree_map(lambda x: replicated_sharding, low_noise_transformer_metadata)
49+
with mesh:
50+
abstract_tree_structure_low_params = jax.tree_util.tree_map(
51+
add_sharding_to_struct, low_noise_transformer_metadata, target_shardings
52+
)
5153

5254
# Handle high_noise_transformer
5355
high_noise_transformer_metadata = metadatas.high_noise_transformer_state
54-
abstract_tree_structure_high_params = jax.tree_util.tree_map(
55-
ocp.utils.to_shape_dtype_struct, high_noise_transformer_metadata
56-
)
57-
high_params_restore = ocp.args.PyTreeRestore(
58-
restore_args=jax.tree.map(
59-
lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
60-
abstract_tree_structure_high_params,
61-
)
62-
)
56+
target_shardings = jax.tree_util.tree_map(lambda x: replicated_sharding, high_noise_transformer_metadata)
57+
with mesh:
58+
abstract_tree_structure_high_params = jax.tree_util.tree_map(
59+
add_sharding_to_struct, high_noise_transformer_metadata, target_shardings
60+
)
6361

6462
max_logging.log("Restoring WAN 2.2 checkpoint")
6563
restored_checkpoint = self.checkpoint_manager.restore(
66-
directory=epath.Path(self.config.checkpoint_dir),
6764
step=step,
6865
args=ocp.args.Composite(
69-
low_noise_transformer_state=low_params_restore,
70-
high_noise_transformer_state=high_params_restore,
66+
low_noise_transformer_state=ocp.args.StandardRestore(abstract_tree_structure_low_params),
67+
high_noise_transformer_state=ocp.args.StandardRestore(abstract_tree_structure_high_params),
7168
wan_config=ocp.args.JsonRestore(),
7269
),
7370
)
@@ -119,8 +116,8 @@ def config_to_json(model_or_config):
119116
"wan_config": ocp.args.JsonSave(config_to_json(pipeline.low_noise_transformer)),
120117
}
121118

122-
items["low_noise_transformer_state"] = ocp.args.PyTreeSave(train_states["low_noise_transformer"])
123-
items["high_noise_transformer_state"] = ocp.args.PyTreeSave(train_states["high_noise_transformer"])
119+
items["low_noise_transformer_state"] = ocp.args.StandardSave(train_states["low_noise_transformer"])
120+
items["high_noise_transformer_state"] = ocp.args.StandardSave(train_states["high_noise_transformer"])
124121

125122
# Save the checkpoint
126123
self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))

src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p1.py

Lines changed: 3 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -16,9 +16,9 @@
1616

1717
import json
1818
from typing import Optional, Tuple
19-
from etils import epath
2019
import jax
2120
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
21+
from maxdiffusion.checkpointing.checkpointing_utils import add_sharding_to_struct
2222
from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer
2323
import numpy as np
2424
import orbax.checkpoint as ocp
@@ -44,31 +44,17 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic
4444
metadatas = self.checkpoint_manager.item_metadata(step)
4545
state = metadatas.wan_state
4646

47-
def add_sharding_to_struct(leaf_struct, sharding):
48-
struct = ocp.utils.to_shape_dtype_struct(leaf_struct)
49-
if hasattr(struct, "shape") and hasattr(struct, "dtype"):
50-
return jax.ShapeDtypeStruct(shape=struct.shape, dtype=struct.dtype, sharding=sharding)
51-
return struct
52-
5347
target_shardings = jax.tree_util.tree_map(lambda x: replicated_sharding, state)
5448

5549
with mesh:
5650
abstract_train_state_with_sharding = jax.tree_util.tree_map(add_sharding_to_struct, state, target_shardings)
5751

58-
params_restore = ocp.args.PyTreeRestore(
59-
restore_args=jax.tree.map(
60-
lambda _: ocp.RestoreArgs(restore_type=jax.Array),
61-
abstract_train_state_with_sharding,
62-
)
63-
)
64-
6552
max_logging.log("Restoring WAN checkpoint")
6653
restored_checkpoint = self.checkpoint_manager.restore(
67-
directory=epath.Path(self.config.checkpoint_dir),
6854
step=step,
6955
args=ocp.args.Composite(
70-
wan_state=params_restore,
7156
wan_config=ocp.args.JsonRestore(),
57+
wan_state=ocp.args.StandardRestore(abstract_train_state_with_sharding),
7258
),
7359
)
7460
max_logging.log(f"restored checkpoint {restored_checkpoint.keys()}")
@@ -106,7 +92,7 @@ def config_to_json(model_or_config):
10692
"wan_config": ocp.args.JsonSave(config_to_json(pipeline.transformer)),
10793
}
10894

109-
items["wan_state"] = ocp.args.PyTreeSave(train_states)
95+
items["wan_state"] = ocp.args.StandardSave(train_states)
11096

11197
# Save the checkpoint
11298
self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))

src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p2.py

Lines changed: 21 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -17,11 +17,12 @@
1717
import json
1818
import jax
1919
import numpy as np
20+
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
2021
from typing import Optional, Tuple
2122
from ..pipelines.wan.wan_pipeline_i2v_2p2 import WanPipelineI2V_2_2
2223
from .. import max_logging
2324
import orbax.checkpoint as ocp
24-
from etils import epath
25+
from maxdiffusion.checkpointing.checkpointing_utils import add_sharding_to_struct
2526
from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer
2627

2728

@@ -35,39 +36,35 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic
3536
max_logging.log("No WAN checkpoint found.")
3637
return None, None
3738
max_logging.log(f"Loading WAN checkpoint from step {step}")
39+
40+
cpu_devices = np.array(jax.devices(backend="cpu"))
41+
mesh = Mesh(cpu_devices, axis_names=("data",))
42+
replicated_sharding = NamedSharding(mesh, P())
43+
3844
metadatas = self.checkpoint_manager.item_metadata(step)
3945

4046
# Handle low_noise_transformer
4147
low_noise_transformer_metadata = metadatas.low_noise_transformer_state
42-
abstract_tree_structure_low_params = jax.tree_util.tree_map(
43-
ocp.utils.to_shape_dtype_struct, low_noise_transformer_metadata
44-
)
45-
low_params_restore = ocp.args.PyTreeRestore(
46-
restore_args=jax.tree.map(
47-
lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
48-
abstract_tree_structure_low_params,
49-
)
50-
)
48+
target_shardings = jax.tree_util.tree_map(lambda x: replicated_sharding, low_noise_transformer_metadata)
49+
with mesh:
50+
abstract_tree_structure_low_params = jax.tree_util.tree_map(
51+
add_sharding_to_struct, low_noise_transformer_metadata, target_shardings
52+
)
5153

5254
# Handle high_noise_transformer
5355
high_noise_transformer_metadata = metadatas.high_noise_transformer_state
54-
abstract_tree_structure_high_params = jax.tree_util.tree_map(
55-
ocp.utils.to_shape_dtype_struct, high_noise_transformer_metadata
56-
)
57-
high_params_restore = ocp.args.PyTreeRestore(
58-
restore_args=jax.tree.map(
59-
lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
60-
abstract_tree_structure_high_params,
61-
)
62-
)
56+
target_shardings = jax.tree_util.tree_map(lambda x: replicated_sharding, high_noise_transformer_metadata)
57+
with mesh:
58+
abstract_tree_structure_high_params = jax.tree_util.tree_map(
59+
add_sharding_to_struct, high_noise_transformer_metadata, target_shardings
60+
)
6361

6462
max_logging.log("Restoring WAN 2.2 checkpoint")
6563
restored_checkpoint = self.checkpoint_manager.restore(
66-
directory=epath.Path(self.config.checkpoint_dir),
6764
step=step,
6865
args=ocp.args.Composite(
69-
low_noise_transformer_state=low_params_restore,
70-
high_noise_transformer_state=high_params_restore,
66+
low_noise_transformer_state=ocp.args.StandardRestore(abstract_tree_structure_low_params),
67+
high_noise_transformer_state=ocp.args.StandardRestore(abstract_tree_structure_high_params),
7168
wan_config=ocp.args.JsonRestore(),
7269
),
7370
)
@@ -119,8 +116,8 @@ def config_to_json(model_or_config):
119116
"wan_config": ocp.args.JsonSave(config_to_json(pipeline.low_noise_transformer)),
120117
}
121118

122-
items["low_noise_transformer_state"] = ocp.args.PyTreeSave(train_states["low_noise_transformer"])
123-
items["high_noise_transformer_state"] = ocp.args.PyTreeSave(train_states["high_noise_transformer"])
119+
items["low_noise_transformer_state"] = ocp.args.StandardSave(train_states["low_noise_transformer"])
120+
items["high_noise_transformer_state"] = ocp.args.StandardSave(train_states["high_noise_transformer"])
124121

125122
# Save the checkpoint
126123
self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))

src/maxdiffusion/checkpointing/wan_vace_checkpointer_2_1.py

Lines changed: 1 addition & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@
1717
from typing import Optional, Tuple
1818
import jax
1919
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
20+
from maxdiffusion.checkpointing.checkpointing_utils import add_sharding_to_struct
2021
from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer
2122
import numpy as np
2223
import orbax.checkpoint as ocp
@@ -42,12 +43,6 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic
4243
metadatas = self.checkpoint_manager.item_metadata(step)
4344
state = metadatas.wan_state
4445

45-
def add_sharding_to_struct(leaf_struct, sharding):
46-
struct = ocp.utils.to_shape_dtype_struct(leaf_struct)
47-
if hasattr(struct, "shape") and hasattr(struct, "dtype"):
48-
return jax.ShapeDtypeStruct(shape=struct.shape, dtype=struct.dtype, sharding=sharding)
49-
return struct
50-
5146
target_shardings = jax.tree_util.tree_map(lambda x: replicated_sharding, state)
5247

5348
with mesh:

src/maxdiffusion/tests/wan_checkpointer_test.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -71,7 +71,7 @@ def test_load_checkpoint_no_optimizer(self, mock_wan_pipeline, mock_create_manag
7171
checkpointer = WanCheckpointer2_1(config=self.config)
7272
pipeline, opt_state, step = checkpointer.load_checkpoint(step=1)
7373

74-
mock_manager.restore.assert_called_once_with(directory=unittest.mock.ANY, step=1, args=unittest.mock.ANY)
74+
mock_manager.restore.assert_called_once_with(step=1, args=unittest.mock.ANY)
7575
mock_wan_pipeline.from_checkpoint.assert_called_with(self.config, mock_manager.restore.return_value)
7676
self.assertEqual(pipeline, mock_pipeline_instance)
7777
self.assertIsNone(opt_state)
@@ -101,7 +101,7 @@ def test_load_checkpoint_with_optimizer(self, mock_wan_pipeline, mock_create_man
101101
checkpointer = WanCheckpointer2_1(config=self.config)
102102
pipeline, opt_state, step = checkpointer.load_checkpoint(step=1)
103103

104-
mock_manager.restore.assert_called_once_with(directory=unittest.mock.ANY, step=1, args=unittest.mock.ANY)
104+
mock_manager.restore.assert_called_once_with(step=1, args=unittest.mock.ANY)
105105
mock_wan_pipeline.from_checkpoint.assert_called_with(self.config, mock_manager.restore.return_value)
106106
self.assertEqual(pipeline, mock_pipeline_instance)
107107
self.assertIsNotNone(opt_state)
@@ -164,7 +164,7 @@ def test_load_checkpoint_no_optimizer(self, mock_wan_pipeline, mock_create_manag
164164
checkpointer = WanCheckpointer2_2(config=self.config)
165165
pipeline, opt_state, step = checkpointer.load_checkpoint(step=1)
166166

167-
mock_manager.restore.assert_called_once_with(directory=unittest.mock.ANY, step=1, args=unittest.mock.ANY)
167+
mock_manager.restore.assert_called_once_with(step=1, args=unittest.mock.ANY)
168168
mock_wan_pipeline.from_checkpoint.assert_called_with(self.config, mock_manager.restore.return_value)
169169
self.assertEqual(pipeline, mock_pipeline_instance)
170170
self.assertIsNone(opt_state)
@@ -197,7 +197,7 @@ def test_load_checkpoint_with_optimizer_in_low_noise(self, mock_wan_pipeline, mo
197197
checkpointer = WanCheckpointer2_2(config=self.config)
198198
pipeline, opt_state, step = checkpointer.load_checkpoint(step=1)
199199

200-
mock_manager.restore.assert_called_once_with(directory=unittest.mock.ANY, step=1, args=unittest.mock.ANY)
200+
mock_manager.restore.assert_called_once_with(step=1, args=unittest.mock.ANY)
201201
mock_wan_pipeline.from_checkpoint.assert_called_with(self.config, mock_manager.restore.return_value)
202202
self.assertEqual(pipeline, mock_pipeline_instance)
203203
self.assertIsNotNone(opt_state)
@@ -231,7 +231,7 @@ def test_load_checkpoint_with_optimizer_in_high_noise(self, mock_wan_pipeline, m
231231
checkpointer = WanCheckpointer2_2(config=self.config)
232232
pipeline, opt_state, step = checkpointer.load_checkpoint(step=1)
233233

234-
mock_manager.restore.assert_called_once_with(directory=unittest.mock.ANY, step=1, args=unittest.mock.ANY)
234+
mock_manager.restore.assert_called_once_with(step=1, args=unittest.mock.ANY)
235235
mock_wan_pipeline.from_checkpoint.assert_called_with(self.config, mock_manager.restore.return_value)
236236
self.assertEqual(pipeline, mock_pipeline_instance)
237237
self.assertIsNotNone(opt_state)

0 commit comments

Comments
 (0)