@@ -48,28 +48,28 @@ class DistillationForwardOutput:
4848 """Dataclass to carry MaxText-specific output fields."""
4949
5050 #: logits
51- logits : jax .Array = None
51+ logits : jax .Array
5252 #: out_projection_activations
53- out_projection_activations : jax .Array = None
53+ out_projection_activations : jax .Array | None = None
5454
5555
@flax.struct.dataclass(frozen=True)
class MaxTextTrainingInput(peft_trainer.TrainingInput):
  """Extended TrainingInput dataclass to carry MaxText-specific fields."""

  #: Position indices for the tokens (for RoPE).
  positions: jax.Array | None = None
  #: Segment IDs for packed sequences (0=padding, 1+=examples).
  decoder_segment_ids: jax.Array | None = None
  #: Ground truth target tokens (used for loss calculation and logging).
  targets: jax.Array | None = None
  #: Position indices for the target tokens.
  targets_position: jax.Array | None = None
  #: Segment IDs for packed target tokens.
  targets_segmentation: jax.Array | None = None
  #: Top-K logits from the teacher model.
  top_k_logits: jax.Array | None = None
  #: Vocabulary indices paired with top_k_logits (presumably the teacher's
  #: top-K token positions — confirm against the teacher forward pass).
  top_k_indices: jax.Array | None = None
7474
7575# -----------------------------------------------------------------------------
@@ -275,7 +275,7 @@ def compute_loss(
275275 # 3. Combine losses
276276 base_logit_loss = (self .alpha * soft_loss ) + ((1.0 - self .alpha ) * hard_loss )
277277
278- feature_loss = 0.0
278+ feature_loss = jnp . array ( 0.0 )
279279 if self .beta_feature > 0.0 :
280280
281281 if self .layer_indices is not None :
@@ -420,6 +420,86 @@ def save(self, step, model, optimizer=None, save_only_lora_params=False, force=F
420420 force = force ,
421421 )
422422
def maybe_restore(
    self,
    model: Any,
    optimizer: Any = None,
    restore_only_lora_params: bool = False,
) -> tuple[int, dict[str, Any]]:
  """Restore model (and optionally optimizer) state from the latest checkpoint.

  Looks up the newest step tracked by the checkpoint manager. When one exists,
  the model parameters — and, if given, the optimizer state — are restored
  in-place. Each parameter's `sharding` attribute is translated into Orbax
  restore arguments so restored tensors land on the correct device mesh.

  Args:
    model: Model to restore into. A `ModelBundle` is unwrapped to its
      `student_model` automatically.
    optimizer: Optimizer whose state should also be restored; skipped if None.
    restore_only_lora_params: When True, restrict restoration to parameters
      tagged as `nnx.LoRAParam`.

  Returns:
    Tuple of (restored step, custom metadata dict). `(0, {})` when no
    checkpoint manager is configured or no checkpoint exists yet.
  """
  if self._checkpoint_manager is None:
    return 0, {}

  latest = self._checkpoint_manager.latest_step()
  if latest is None:
    return 0, {}

  max_logging.log(f"Restoring from checkpoint step {latest}...")

  # A ModelBundle carries the student model; plain models pass through.
  student = getattr(model, "student_model", model)

  # Select either the full parameter state or just the LoRA subset.
  state_filter = (nnx.LoRAParam,) if restore_only_lora_params else ()
  param_state = nnx.state(student, *state_filter)

  def _as_restore_args(leaf):
    # Only sharded leaves get explicit placement instructions; everything
    # else falls back to Orbax defaults (None).
    if not hasattr(leaf, "sharding"):
      return None
    return checkpoint.type_handlers.ArrayRestoreArgs(sharding=leaf.sharding)

  composite_kwargs = {
      "model_params": checkpoint.args.PyTreeRestore(
          item=param_state,
          restore_args=jax.tree.map(_as_restore_args, param_state),
      )
  }

  if optimizer is not None:
    opt_state = nnx.state(optimizer, nnx.optimizer.OptState)
    composite_kwargs["optimizer_state"] = checkpoint.args.PyTreeRestore(
        item=opt_state,
        restore_args=jax.tree.map(_as_restore_args, opt_state),
    )

  restored = self._checkpoint_manager.restore(
      latest,
      args=checkpoint.args.Composite(**composite_kwargs),
  )

  # Write the restored trees back into the live objects.
  nnx.update(student, restored.model_params)
  if optimizer is not None:
    nnx.update(optimizer, restored.optimizer_state)

  # Surface any custom metadata stored alongside the checkpoint.
  meta = self._checkpoint_manager.metadata(latest)
  custom = getattr(meta, "custom_metadata", None) if meta else None
  return latest, (dict(custom) if custom is not None else {})
502+
423503 def restore_iterator (self ):
424504 """Restores the iterator using MaxText's logic."""
425505 if self ._checkpoint_manager is None or self ._iterator is None :
0 commit comments