You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Fix: Overhaul WAN checkpointers for robust multi-host restoration
This commit resolves several interrelated checkpointing issues by updating
how Orbax handles metadata, sharding, and PyTree restoration.
Key changes:
* Add explicit `item_handlers`: Defined specific handlers (`JsonCheckpointHandler`
for configs, `StandardCheckpointHandler` for states) in `CheckpointManager`.
This ensures metadata is restored correctly, resolving known Orbax limitations
(reference: google/orbax#986).
* Bypass mesh validation during restore: Replaced `ocp.utils.to_shape_dtype_struct`
with manual `jax.ShapeDtypeStruct` construction in `add_sharding_to_struct`.
This makes restoration topology-agnostic, preventing `ValueError` when the
current device mesh has fewer devices than the saved checkpoint's topology
(e.g., restoring 32-device metadata on 4 devices).
* Migrate to Standard API: Upgraded all WAN checkpointers from
the `PyTreeSave`/`PyTreeRestore` APIs to `StandardSave`/`StandardRestore`
to align with `item_handlers` defined in CheckpointManager.
Co-authored-by: martinarroyo <martinarroyo@google.com>
0 commit comments