Skip to content

Commit e57ed73

Browse files
Merge pull request #3464 from AI-Hypercomputer:agagik-distill-checkpoint
PiperOrigin-RevId: 888255077
2 parents 5cda5eb + 359ac86 commit e57ed73

File tree

3 files changed

+582
-186
lines changed

3 files changed

+582
-186
lines changed

src/maxtext/trainers/post_train/distillation/distillation_utils.py

Lines changed: 90 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -48,28 +48,28 @@ class DistillationForwardOutput:
4848
"""Dataclass to carry MaxText-specific output fields."""
4949

5050
#: logits
51-
logits: jax.Array = None
51+
logits: jax.Array
5252
#: out_projection_activations
53-
out_projection_activations: jax.Array = None
53+
out_projection_activations: jax.Array | None = None
5454

5555

5656
@flax.struct.dataclass(frozen=True)
class MaxTextTrainingInput(peft_trainer.TrainingInput):
  """Extended TrainingInput dataclass to carry MaxText-specific fields.

  All fields default to ``None`` so callers may populate only the subset
  relevant to their training mode (e.g. the ``top_k_*`` fields are only
  set when distilling from cached teacher logits).
  """

  #: Position indices for the tokens (for RoPE).
  positions: jax.Array | None = None
  #: Segment IDs for packed sequences (0=padding, 1+=examples).
  decoder_segment_ids: jax.Array | None = None
  #: Ground truth target tokens (used for loss calculation and logging).
  targets: jax.Array | None = None
  #: Position indices for the target tokens.
  targets_position: jax.Array | None = None
  #: Segment IDs for packed target tokens.
  targets_segmentation: jax.Array | None = None
  #: Top-K logits from the teacher model.
  top_k_logits: jax.Array | None = None
  #: Vocabulary indices corresponding to ``top_k_logits``.
  top_k_indices: jax.Array | None = None
7373

7474

7575
# -----------------------------------------------------------------------------
@@ -275,7 +275,7 @@ def compute_loss(
275275
# 3. Combine losses
276276
base_logit_loss = (self.alpha * soft_loss) + ((1.0 - self.alpha) * hard_loss)
277277

278-
feature_loss = 0.0
278+
feature_loss = jnp.array(0.0)
279279
if self.beta_feature > 0.0:
280280

281281
if self.layer_indices is not None:
@@ -420,6 +420,86 @@ def save(self, step, model, optimizer=None, save_only_lora_params=False, force=F
420420
force=force,
421421
)
422422

423+
def maybe_restore(
    self,
    model: Any,
    optimizer: Any = None,
    restore_only_lora_params: bool = False,
) -> tuple[int, dict[str, Any]]:
  """Restores model (and optionally optimizer) state from the latest checkpoint.

  Looks up the most recent step known to the checkpoint manager. When one
  exists, the model parameters — and the optimizer state, if an optimizer is
  given — are restored in-place. Each parameter's ``sharding`` attribute is
  translated into an Orbax ``ArrayRestoreArgs`` so restored arrays land on
  the correct device mesh.

  Args:
    model: The model to restore. If the object carries a ``student_model``
      attribute (e.g. a ``ModelBundle``), only that sub-model is restored.
    optimizer: Optional optimizer whose state should also be restored.
    restore_only_lora_params: When True, restore only the parameters
      registered as ``nnx.LoRAParam``.

  Returns:
    A ``(step, custom_metadata)`` tuple; ``(0, {})`` when no checkpoint
    manager is configured or no checkpoint exists yet.
  """
  manager = self._checkpoint_manager
  if manager is None:
    return 0, {}

  step = manager.latest_step()
  if step is None:
    return 0, {}

  max_logging.log(f"Restoring from checkpoint step {step}...")

  # A bundle exposes the trainable sub-model via `student_model`; plain
  # models are restored directly.
  target_model = getattr(model, "student_model", model)

  model_state = (
      nnx.state(target_model, nnx.LoRAParam)
      if restore_only_lora_params
      else nnx.state(target_model)
  )

  def to_restore_arg(leaf):
    # Map a leaf's sharding onto Orbax restore args; leaves without a
    # sharding attribute get no special placement handling.
    if hasattr(leaf, "sharding"):
      return checkpoint.type_handlers.ArrayRestoreArgs(sharding=leaf.sharding)
    return None

  composite_items = {
      "model_params": checkpoint.args.PyTreeRestore(
          item=model_state,
          restore_args=jax.tree.map(to_restore_arg, model_state),
      )
  }

  if optimizer is not None:
    opt_state = nnx.state(optimizer, nnx.optimizer.OptState)
    composite_items["optimizer_state"] = checkpoint.args.PyTreeRestore(
        item=opt_state,
        restore_args=jax.tree.map(to_restore_arg, opt_state),
    )

  restored = manager.restore(
      step,
      args=checkpoint.args.Composite(**composite_items),
  )

  # Write the restored trees back into the live objects.
  nnx.update(target_model, restored.model_params)
  if optimizer is not None:
    nnx.update(optimizer, restored.optimizer_state)

  metadata = manager.metadata(step)
  raw_custom = getattr(metadata, "custom_metadata", None) if metadata else None
  custom_metadata = raw_custom if raw_custom is not None else {}

  return step, dict(custom_metadata)
502+
423503
def restore_iterator(self):
424504
"""Restores the iterator using MaxText's logic."""
425505
if self._checkpoint_manager is None or self._iterator is None:

0 commit comments

Comments
 (0)