Skip to content

Commit d5669bb

Browse files
committed
Fix checkpoint restore sharding error
jax array pyink
1 parent d0ad85c commit d5669bb

1 file changed

Lines changed: 5 additions & 1 deletion

File tree

src/maxtext/common/checkpointing.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,11 @@ def combine_sharding(sds, shardings):
196196
use_ocdbt=use_ocdbt,
197197
use_zarr3=use_zarr3,
198198
)
199-
return ocp.Checkpointer(handler).restore(p, abstract_unboxed_pre_state)
199+
# Provide sharding info to ensure restoration returns JAX arrays (not NumPy arrays).
200+
restore_args = jax.tree_util.tree_map(
201+
lambda x: ocp.type_handlers.ArrayRestoreArgs(sharding=x.sharding), abstract_unboxed_pre_state
202+
)
203+
return ocp.Checkpointer(handler).restore(p, abstract_unboxed_pre_state, restore_args=restore_args)
200204

201205

202206
def create_orbax_checkpoint_manager(

0 commit comments

Comments
 (0)