Skip to content

Commit 494b3b0

Browse files
Merge pull request #3082 from AI-Hypercomputer:hengtaoguo-integration
PiperOrigin-RevId: 865645256
2 parents 02bc28c + d5669bb commit 494b3b0

1 file changed

Lines changed: 5 additions & 1 deletion

File tree

src/maxtext/common/checkpointing.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,11 @@ def combine_sharding(sds, shardings):
196196
use_ocdbt=use_ocdbt,
197197
use_zarr3=use_zarr3,
198198
)
199-
return ocp.Checkpointer(handler).restore(p, abstract_unboxed_pre_state)
199+
# Provide sharding info to ensure restoration returns JAX arrays (not NumPy arrays).
200+
restore_args = jax.tree_util.tree_map(
201+
lambda x: ocp.type_handlers.ArrayRestoreArgs(sharding=x.sharding), abstract_unboxed_pre_state
202+
)
203+
return ocp.Checkpointer(handler).restore(p, abstract_unboxed_pre_state, restore_args=restore_args)
200204

201205

202206
def create_orbax_checkpoint_manager(

0 commit comments

Comments
 (0)