We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent d0ad85c commit d5669bbCopy full SHA for d5669bb
1 file changed
src/maxtext/common/checkpointing.py
@@ -196,7 +196,11 @@ def combine_sharding(sds, shardings):
196
use_ocdbt=use_ocdbt,
197
use_zarr3=use_zarr3,
198
)
199
- return ocp.Checkpointer(handler).restore(p, abstract_unboxed_pre_state)
+ # Provide sharding info to ensure restoration returns JAX arrays (not NumPy arrays).
200
+ restore_args = jax.tree_util.tree_map(
201
+ lambda x: ocp.type_handlers.ArrayRestoreArgs(sharding=x.sharding), abstract_unboxed_pre_state
202
+ )
203
+ return ocp.Checkpointer(handler).restore(p, abstract_unboxed_pre_state, restore_args=restore_args)
204
205
206
def create_orbax_checkpoint_manager(
0 commit comments