Skip to content

Commit 2965670

Browse files
Merge pull request #379 from AI-Hypercomputer:ninatu/fix-sharding-mismatch
PiperOrigin-RevId: 901337846
2 parents 702cadd + 69f7701 commit 2965670

7 files changed

Lines changed: 94 additions & 114 deletions

src/maxdiffusion/checkpointing/checkpointing_utils.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,12 @@
1717

1818
"""Create an Orbax CheckpointManager with specified (Async or not) Checkpointer."""
1919

20-
from typing import Optional, Tuple
20+
from typing import Any, Optional, Tuple
2121
import jax
2222
import numpy as np
2323
import os
2424
import orbax.checkpoint
25+
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
2526
from maxdiffusion import max_logging
2627
from etils import epath
2728
from flax.training import train_state
@@ -58,10 +59,17 @@ def create_orbax_checkpoint_manager(
5859
max_logging.log(f"checkpoint dir: {checkpoint_dir}")
5960
p = epath.Path(checkpoint_dir)
6061

62+
item_handlers = None
6163
if checkpoint_type == FLUX_CHECKPOINT:
6264
item_names = ("flux_state", "flux_config", "vae_state", "vae_config", "scheduler", "scheduler_config")
6365
elif checkpoint_type == WAN_CHECKPOINT:
6466
item_names = ("low_noise_transformer_state", "high_noise_transformer_state", "wan_state", "wan_config")
67+
item_handlers = {
68+
"wan_config": ocp.JsonCheckpointHandler(),
69+
"wan_state": ocp.StandardCheckpointHandler(),
70+
"low_noise_transformer_state": ocp.StandardCheckpointHandler(),
71+
"high_noise_transformer_state": ocp.StandardCheckpointHandler(),
72+
}
6573
else:
6674
item_names = (
6775
"unet_config",
@@ -89,6 +97,7 @@ def create_orbax_checkpoint_manager(
8997
options=CheckpointManagerOptions(
9098
create=True, save_interval_steps=save_interval_steps, enable_async_checkpointing=use_async
9199
),
100+
item_handlers=item_handlers,
92101
logger=orbax_logger,
93102
)
94103

@@ -255,3 +264,38 @@ def map_to_pspec(data):
255264
except:
256265
max_logging.log(f"could not load {checkpoint_item} from orbax")
257266
return None
267+
268+
269+
def get_cpu_mesh_and_sharding() -> Tuple[Mesh, NamedSharding]:
  """Builds a one-dimensional CPU device mesh and a fully replicated sharding.

  The mesh spans every device reported by ``jax.devices(backend="cpu")`` along
  a single axis named ``"data"``. The returned sharding uses an empty
  ``PartitionSpec``, i.e. no array dimension is partitioned over that axis, so
  each array is replicated across the mesh.

  Intended for checkpoint restoration where the state should be materialised
  on host CPU devices rather than on the mesh that was used when saving.

  Returns:
    A ``(mesh, replicated_sharding)`` tuple: the CPU ``Mesh`` and a
    ``NamedSharding`` over that mesh with an empty ``PartitionSpec``.
  """
  # Mesh expects an ndarray of devices; a flat array gives a 1-D mesh.
  cpu_device_grid = np.array(jax.devices(backend="cpu"))
  cpu_mesh = Mesh(cpu_device_grid, axis_names=("data",))
  # P() names no mesh axes, which means "replicate on every device".
  fully_replicated = NamedSharding(cpu_mesh, P())
  return cpu_mesh, fully_replicated
282+
283+
284+
def add_sharding_to_struct(leaf_struct: Any, sharding: jax.sharding.Sharding) -> Any:
  """Wraps an array-like pytree leaf in a ``jax.ShapeDtypeStruct`` with ``sharding``.

  The struct is built directly from the leaf's ``shape`` and ``dtype`` instead
  of going through ``ocp.utils.to_shape_dtype_struct``. Per the original
  author's note, this sidesteps the device-mesh validation done there, so the
  target sharding may refer to a different mesh than the one used when the
  checkpoint was saved.

  Args:
    leaf_struct: A leaf of a pytree (for example checkpoint array metadata).
    sharding: The sharding to attach to the resulting struct.

  Returns:
    A ``jax.ShapeDtypeStruct`` carrying the leaf's shape and dtype plus
    ``sharding`` when the leaf exposes both ``shape`` and ``dtype``;
    otherwise ``leaf_struct`` itself, unchanged.
  """
  # Leaves that do not look like arrays (scalars, strings, None, ...) are
  # passed through untouched so tree_map preserves the tree structure.
  if not (hasattr(leaf_struct, "shape") and hasattr(leaf_struct, "dtype")):
    return leaf_struct
  return jax.ShapeDtypeStruct(
      shape=leaf_struct.shape,
      dtype=leaf_struct.dtype,
      sharding=sharding,
  )

src/maxdiffusion/checkpointing/wan_checkpointer_2_1.py

Lines changed: 4 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,9 @@
1616

1717
import json
1818
from typing import Optional, Tuple
19-
from etils import epath
2019
import jax
21-
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
20+
from maxdiffusion.checkpointing.checkpointing_utils import add_sharding_to_struct, get_cpu_mesh_and_sharding
2221
from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer
23-
import numpy as np
2422
import orbax.checkpoint as ocp
2523
from .. import max_logging
2624
from ..pipelines.wan.wan_pipeline_2_1 import WanPipeline2_1
@@ -37,38 +35,21 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic
3735
return None, None
3836
max_logging.log(f"Loading WAN checkpoint from step {step}")
3937

40-
cpu_devices = np.array(jax.devices(backend="cpu"))
41-
mesh = Mesh(cpu_devices, axis_names=("data",))
42-
replicated_sharding = NamedSharding(mesh, P())
43-
38+
mesh, replicated_sharding = get_cpu_mesh_and_sharding()
4439
metadatas = self.checkpoint_manager.item_metadata(step)
4540
state = metadatas.wan_state
4641

47-
def add_sharding_to_struct(leaf_struct, sharding):
48-
struct = ocp.utils.to_shape_dtype_struct(leaf_struct)
49-
if hasattr(struct, "shape") and hasattr(struct, "dtype"):
50-
return jax.ShapeDtypeStruct(shape=struct.shape, dtype=struct.dtype, sharding=sharding)
51-
return struct
52-
5342
target_shardings = jax.tree_util.tree_map(lambda x: replicated_sharding, state)
5443

5544
with mesh:
5645
abstract_train_state_with_sharding = jax.tree_util.tree_map(add_sharding_to_struct, state, target_shardings)
5746

58-
params_restore = ocp.args.PyTreeRestore(
59-
restore_args=jax.tree.map(
60-
lambda _: ocp.RestoreArgs(restore_type=jax.Array),
61-
abstract_train_state_with_sharding,
62-
)
63-
)
64-
6547
max_logging.log("Restoring WAN checkpoint")
6648
restored_checkpoint = self.checkpoint_manager.restore(
67-
directory=epath.Path(self.config.checkpoint_dir),
6849
step=step,
6950
args=ocp.args.Composite(
70-
wan_state=params_restore,
7151
wan_config=ocp.args.JsonRestore(),
52+
wan_state=ocp.args.StandardRestore(abstract_train_state_with_sharding),
7253
),
7354
)
7455
max_logging.log(f"restored checkpoint {restored_checkpoint.keys()}")
@@ -106,7 +87,7 @@ def config_to_json(model_or_config):
10687
"wan_config": ocp.args.JsonSave(config_to_json(pipeline.transformer)),
10788
}
10889

109-
items["wan_state"] = ocp.args.PyTreeSave(train_states)
90+
items["wan_state"] = ocp.args.StandardSave(train_states)
11091

11192
# Save the checkpoint
11293
self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))

src/maxdiffusion/checkpointing/wan_checkpointer_2_2.py

Lines changed: 17 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,11 @@
1616

1717
import json
1818
import jax
19-
import numpy as np
2019
from typing import Optional, Tuple
2120
from ..pipelines.wan.wan_pipeline_2_2 import WanPipeline2_2
2221
from .. import max_logging
2322
import orbax.checkpoint as ocp
24-
from etils import epath
23+
from maxdiffusion.checkpointing.checkpointing_utils import add_sharding_to_struct, get_cpu_mesh_and_sharding
2524
from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer
2625

2726

@@ -35,39 +34,32 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic
3534
max_logging.log("No WAN checkpoint found.")
3635
return None, None
3736
max_logging.log(f"Loading WAN checkpoint from step {step}")
37+
38+
mesh, replicated_sharding = get_cpu_mesh_and_sharding()
3839
metadatas = self.checkpoint_manager.item_metadata(step)
3940

4041
# Handle low_noise_transformer
4142
low_noise_transformer_metadata = metadatas.low_noise_transformer_state
42-
abstract_tree_structure_low_params = jax.tree_util.tree_map(
43-
ocp.utils.to_shape_dtype_struct, low_noise_transformer_metadata
44-
)
45-
low_params_restore = ocp.args.PyTreeRestore(
46-
restore_args=jax.tree.map(
47-
lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
48-
abstract_tree_structure_low_params,
49-
)
50-
)
43+
target_shardings = jax.tree_util.tree_map(lambda x: replicated_sharding, low_noise_transformer_metadata)
44+
with mesh:
45+
abstract_tree_structure_low_params = jax.tree_util.tree_map(
46+
add_sharding_to_struct, low_noise_transformer_metadata, target_shardings
47+
)
5148

5249
# Handle high_noise_transformer
5350
high_noise_transformer_metadata = metadatas.high_noise_transformer_state
54-
abstract_tree_structure_high_params = jax.tree_util.tree_map(
55-
ocp.utils.to_shape_dtype_struct, high_noise_transformer_metadata
56-
)
57-
high_params_restore = ocp.args.PyTreeRestore(
58-
restore_args=jax.tree.map(
59-
lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
60-
abstract_tree_structure_high_params,
61-
)
62-
)
51+
target_shardings = jax.tree_util.tree_map(lambda x: replicated_sharding, high_noise_transformer_metadata)
52+
with mesh:
53+
abstract_tree_structure_high_params = jax.tree_util.tree_map(
54+
add_sharding_to_struct, high_noise_transformer_metadata, target_shardings
55+
)
6356

6457
max_logging.log("Restoring WAN 2.2 checkpoint")
6558
restored_checkpoint = self.checkpoint_manager.restore(
66-
directory=epath.Path(self.config.checkpoint_dir),
6759
step=step,
6860
args=ocp.args.Composite(
69-
low_noise_transformer_state=low_params_restore,
70-
high_noise_transformer_state=high_params_restore,
61+
low_noise_transformer_state=ocp.args.StandardRestore(abstract_tree_structure_low_params),
62+
high_noise_transformer_state=ocp.args.StandardRestore(abstract_tree_structure_high_params),
7163
wan_config=ocp.args.JsonRestore(),
7264
),
7365
)
@@ -119,8 +111,8 @@ def config_to_json(model_or_config):
119111
"wan_config": ocp.args.JsonSave(config_to_json(pipeline.low_noise_transformer)),
120112
}
121113

122-
items["low_noise_transformer_state"] = ocp.args.PyTreeSave(train_states["low_noise_transformer"])
123-
items["high_noise_transformer_state"] = ocp.args.PyTreeSave(train_states["high_noise_transformer"])
114+
items["low_noise_transformer_state"] = ocp.args.StandardSave(train_states["low_noise_transformer"])
115+
items["high_noise_transformer_state"] = ocp.args.StandardSave(train_states["high_noise_transformer"])
124116

125117
# Save the checkpoint
126118
self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))

src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p1.py

Lines changed: 4 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,9 @@
1616

1717
import json
1818
from typing import Optional, Tuple
19-
from etils import epath
2019
import jax
21-
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
20+
from maxdiffusion.checkpointing.checkpointing_utils import add_sharding_to_struct, get_cpu_mesh_and_sharding
2221
from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer
23-
import numpy as np
2422
import orbax.checkpoint as ocp
2523
from .. import max_logging
2624
from ..pipelines.wan.wan_pipeline_i2v_2p1 import WanPipelineI2V_2_1
@@ -37,38 +35,21 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic
3735
return None, None
3836
max_logging.log(f"Loading WAN checkpoint from step {step}")
3937

40-
cpu_devices = np.array(jax.devices(backend="cpu"))
41-
mesh = Mesh(cpu_devices, axis_names=("data",))
42-
replicated_sharding = NamedSharding(mesh, P())
43-
38+
mesh, replicated_sharding = get_cpu_mesh_and_sharding()
4439
metadatas = self.checkpoint_manager.item_metadata(step)
4540
state = metadatas.wan_state
4641

47-
def add_sharding_to_struct(leaf_struct, sharding):
48-
struct = ocp.utils.to_shape_dtype_struct(leaf_struct)
49-
if hasattr(struct, "shape") and hasattr(struct, "dtype"):
50-
return jax.ShapeDtypeStruct(shape=struct.shape, dtype=struct.dtype, sharding=sharding)
51-
return struct
52-
5342
target_shardings = jax.tree_util.tree_map(lambda x: replicated_sharding, state)
5443

5544
with mesh:
5645
abstract_train_state_with_sharding = jax.tree_util.tree_map(add_sharding_to_struct, state, target_shardings)
5746

58-
params_restore = ocp.args.PyTreeRestore(
59-
restore_args=jax.tree.map(
60-
lambda _: ocp.RestoreArgs(restore_type=jax.Array),
61-
abstract_train_state_with_sharding,
62-
)
63-
)
64-
6547
max_logging.log("Restoring WAN checkpoint")
6648
restored_checkpoint = self.checkpoint_manager.restore(
67-
directory=epath.Path(self.config.checkpoint_dir),
6849
step=step,
6950
args=ocp.args.Composite(
70-
wan_state=params_restore,
7151
wan_config=ocp.args.JsonRestore(),
52+
wan_state=ocp.args.StandardRestore(abstract_train_state_with_sharding),
7253
),
7354
)
7455
max_logging.log(f"restored checkpoint {restored_checkpoint.keys()}")
@@ -106,7 +87,7 @@ def config_to_json(model_or_config):
10687
"wan_config": ocp.args.JsonSave(config_to_json(pipeline.transformer)),
10788
}
10889

109-
items["wan_state"] = ocp.args.PyTreeSave(train_states)
90+
items["wan_state"] = ocp.args.StandardSave(train_states)
11091

11192
# Save the checkpoint
11293
self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))

src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p2.py

Lines changed: 17 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,11 @@
1616

1717
import json
1818
import jax
19-
import numpy as np
2019
from typing import Optional, Tuple
2120
from ..pipelines.wan.wan_pipeline_i2v_2p2 import WanPipelineI2V_2_2
2221
from .. import max_logging
2322
import orbax.checkpoint as ocp
24-
from etils import epath
23+
from maxdiffusion.checkpointing.checkpointing_utils import add_sharding_to_struct, get_cpu_mesh_and_sharding
2524
from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer
2625

2726

@@ -35,39 +34,32 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic
3534
max_logging.log("No WAN checkpoint found.")
3635
return None, None
3736
max_logging.log(f"Loading WAN checkpoint from step {step}")
37+
38+
mesh, replicated_sharding = get_cpu_mesh_and_sharding()
3839
metadatas = self.checkpoint_manager.item_metadata(step)
3940

4041
# Handle low_noise_transformer
4142
low_noise_transformer_metadata = metadatas.low_noise_transformer_state
42-
abstract_tree_structure_low_params = jax.tree_util.tree_map(
43-
ocp.utils.to_shape_dtype_struct, low_noise_transformer_metadata
44-
)
45-
low_params_restore = ocp.args.PyTreeRestore(
46-
restore_args=jax.tree.map(
47-
lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
48-
abstract_tree_structure_low_params,
49-
)
50-
)
43+
target_shardings = jax.tree_util.tree_map(lambda x: replicated_sharding, low_noise_transformer_metadata)
44+
with mesh:
45+
abstract_tree_structure_low_params = jax.tree_util.tree_map(
46+
add_sharding_to_struct, low_noise_transformer_metadata, target_shardings
47+
)
5148

5249
# Handle high_noise_transformer
5350
high_noise_transformer_metadata = metadatas.high_noise_transformer_state
54-
abstract_tree_structure_high_params = jax.tree_util.tree_map(
55-
ocp.utils.to_shape_dtype_struct, high_noise_transformer_metadata
56-
)
57-
high_params_restore = ocp.args.PyTreeRestore(
58-
restore_args=jax.tree.map(
59-
lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
60-
abstract_tree_structure_high_params,
61-
)
62-
)
51+
target_shardings = jax.tree_util.tree_map(lambda x: replicated_sharding, high_noise_transformer_metadata)
52+
with mesh:
53+
abstract_tree_structure_high_params = jax.tree_util.tree_map(
54+
add_sharding_to_struct, high_noise_transformer_metadata, target_shardings
55+
)
6356

6457
max_logging.log("Restoring WAN 2.2 checkpoint")
6558
restored_checkpoint = self.checkpoint_manager.restore(
66-
directory=epath.Path(self.config.checkpoint_dir),
6759
step=step,
6860
args=ocp.args.Composite(
69-
low_noise_transformer_state=low_params_restore,
70-
high_noise_transformer_state=high_params_restore,
61+
low_noise_transformer_state=ocp.args.StandardRestore(abstract_tree_structure_low_params),
62+
high_noise_transformer_state=ocp.args.StandardRestore(abstract_tree_structure_high_params),
7163
wan_config=ocp.args.JsonRestore(),
7264
),
7365
)
@@ -119,8 +111,8 @@ def config_to_json(model_or_config):
119111
"wan_config": ocp.args.JsonSave(config_to_json(pipeline.low_noise_transformer)),
120112
}
121113

122-
items["low_noise_transformer_state"] = ocp.args.PyTreeSave(train_states["low_noise_transformer"])
123-
items["high_noise_transformer_state"] = ocp.args.PyTreeSave(train_states["high_noise_transformer"])
114+
items["low_noise_transformer_state"] = ocp.args.StandardSave(train_states["low_noise_transformer"])
115+
items["high_noise_transformer_state"] = ocp.args.StandardSave(train_states["high_noise_transformer"])
124116

125117
# Save the checkpoint
126118
self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))

src/maxdiffusion/checkpointing/wan_vace_checkpointer_2_1.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,8 @@
1616
import json
1717
from typing import Optional, Tuple
1818
import jax
19-
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
19+
from maxdiffusion.checkpointing.checkpointing_utils import add_sharding_to_struct, get_cpu_mesh_and_sharding
2020
from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer
21-
import numpy as np
2221
import orbax.checkpoint as ocp
2322
from .. import max_logging
2423
from ..pipelines.wan.wan_vace_pipeline_2_1 import VaceWanPipeline2_1
@@ -35,19 +34,10 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic
3534
return None, None
3635
max_logging.log(f"Loading WAN checkpoint from step {step}")
3736

38-
cpu_devices = np.array(jax.devices(backend="cpu"))
39-
mesh = Mesh(cpu_devices, axis_names=("data",))
40-
replicated_sharding = NamedSharding(mesh, P())
41-
37+
mesh, replicated_sharding = get_cpu_mesh_and_sharding()
4238
metadatas = self.checkpoint_manager.item_metadata(step)
4339
state = metadatas.wan_state
4440

45-
def add_sharding_to_struct(leaf_struct, sharding):
46-
struct = ocp.utils.to_shape_dtype_struct(leaf_struct)
47-
if hasattr(struct, "shape") and hasattr(struct, "dtype"):
48-
return jax.ShapeDtypeStruct(shape=struct.shape, dtype=struct.dtype, sharding=sharding)
49-
return struct
50-
5141
target_shardings = jax.tree_util.tree_map(lambda x: replicated_sharding, state)
5242

5343
with mesh:

0 commit comments

Comments
 (0)