Skip to content

Commit 95ed966

Browse files
Merge pull request #3145 from AI-Hypercomputer:agagik-distill-grain
PiperOrigin-RevId: 871545114
2 parents 0f59a69 + 1435982 commit 95ed966

4 files changed

Lines changed: 609 additions & 157 deletions

File tree

Lines changed: 299 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,299 @@
1+
# Copyright 2023-2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Utility classes for MaxText Distillation with Tunix.
16+
17+
This module contains adapter classes that bridge MaxText's data loading and
18+
model structures with Tunix's training interfaces.
19+
"""
20+
21+
from typing import Any, Iterator
22+
23+
import flax
24+
from flax import nnx
25+
import jax
26+
import jax.numpy as jnp
27+
import optax
28+
from orbax import checkpoint
29+
30+
from maxtext.utils import max_logging
31+
# Reuse MaxText's native checkpointing logic
32+
from maxtext.common.checkpointing import GrainCheckpointHandler, GrainCheckpointSave, GrainCheckpointRestore
33+
from tunix.distillation import distillation_trainer
34+
from tunix.distillation.strategies import logit
35+
from tunix.sft import checkpoint_manager as tunix_checkpoint_manager
36+
37+
38+
# -----------------------------------------------------------------------------
39+
# Custom Data Structures
40+
# -----------------------------------------------------------------------------
41+
42+
43+
@flax.struct.dataclass(frozen=True)
class MaxTextTrainingInput(distillation_trainer.TrainingInput):
  """Extended TrainingInput dataclass to carry MaxText-specific fields.

  Adds the extra per-batch arrays produced by MaxText's input pipeline on top
  of the fields Tunix's ``TrainingInput`` already defines. All extra fields
  default to ``None`` so the base-class construction path keeps working.
  """

  #: Position indices for the tokens (for RoPE).
  positions: jax.Array | None = None
  #: Segment IDs for packed sequences (0=padding, 1+=examples); ``None`` for
  #: non-packed datasets.
  decoder_segment_ids: jax.Array | None = None
  #: Ground truth target tokens (used for loss calculation and logging).
  targets: jax.Array | None = None
53+
54+
55+
# -----------------------------------------------------------------------------
56+
# Data Loading Adapter
57+
# -----------------------------------------------------------------------------
58+
59+
60+
class MaxTextToTunixIterator:
  """Adapts the raw dictionary output of MaxText's data loader to Tunix objects.

  MaxText's `input_pipeline_interface.create_data_iterator` yields a dictionary.
  Tunix expects an object with specific attributes (input_tokens, etc.).
  """

  def __init__(self, maxtext_iterator: Iterator):
    """Initializes the adapter.

    Args:
      maxtext_iterator: The upstream iterator created by MaxText's input pipeline.
    """
    self._iterator = maxtext_iterator

  def __iter__(self):
    """Returns self as the iterator."""
    return self

  def __next__(self) -> MaxTextTrainingInput:
    """Fetches the next batch and converts it to the Tunix data class.

    Returns:
      A MaxTextTrainingInput object containing the batch data.

    Raises:
      StopIteration: If the upstream iterator is exhausted.
    """
    raw_batch = next(self._iterator)

    # Packed datasets carry an "inputs_segmentation" array (0 marks padding);
    # non-packed datasets omit it, in which case every token is treated as valid.
    segmentation = raw_batch.get("inputs_segmentation")
    if segmentation is None:
      mask = jnp.ones_like(raw_batch["inputs"], dtype=bool)
    else:
      mask = segmentation != 0

    # pylint: disable=unexpected-keyword-arg
    return MaxTextTrainingInput(
        input_tokens=raw_batch["inputs"],
        input_mask=mask,
        teacher_output=None,
        positions=raw_batch["inputs_position"],
        decoder_segment_ids=segmentation,
        targets=raw_batch["targets"],
    )
108+
109+
110+
# -----------------------------------------------------------------------------
111+
# Distillation Strategy
112+
# -----------------------------------------------------------------------------
113+
class MonitoredLogitStrategy(logit.LogitStrategy):
  """Logit Strategy that returns detailed metrics for TensorBoard."""

  def compute_loss(
      self,
      student_output: jax.Array,
      teacher_output: jax.Array,
      labels: jax.Array,
  ) -> tuple[jax.Array, dict[str, jax.Array]]:
    """Computes the blended distillation loss and auxiliary logging metrics.

    Args:
      student_output: Raw student logits.
      teacher_output: Raw teacher logits.
      labels: Target distribution used for the hard cross-entropy terms.

    Returns:
      A ``(total_loss, metrics)`` tuple where ``metrics`` maps TensorBoard tag
      names to scalar values.
    """
    # Cast to float32 so all loss math is numerically stable regardless of the
    # model's compute dtype.
    student_logits = student_output.astype(jnp.float32)
    teacher_logits = teacher_output.astype(jnp.float32)
    temperature = self.temperature

    # Soft targets: temperature-scaled distributions for KL(Teacher || Student).
    student_log_probs = jax.nn.log_softmax(student_logits / temperature, axis=-1)
    teacher_probs = jax.nn.softmax(teacher_logits / temperature, axis=-1)
    mean_kl = jnp.mean(optax.kl_divergence(student_log_probs, teacher_probs))

    # Soft-target gradients scale as 1/T^2, so multiply back by T^2
    # (Hinton et al., "Distilling the Knowledge in a Neural Network").
    soft_loss = mean_kl * (temperature**2)

    # Hard loss: student cross-entropy against the ground-truth labels.
    hard_loss = jnp.mean(optax.softmax_cross_entropy(logits=student_logits, labels=labels))

    # Teacher hard loss is logged purely to verify the teacher's quality.
    teacher_hard_loss = jnp.mean(optax.softmax_cross_entropy(logits=teacher_logits, labels=labels))

    # Blend soft and hard objectives with the alpha mixing weight.
    total_loss = (self.alpha * soft_loss) + ((1.0 - self.alpha) * hard_loss)

    metrics = {
        "distill/soft_loss": soft_loss,
        "distill/hard_loss": hard_loss,
        "distill/kl_div": mean_kl,
        "distill/teacher_loss": teacher_hard_loss,
    }
    return total_loss, metrics

  def compute_eval_loss(
      self,
      student_output: jax.Array,
      labels: jax.Array,
  ) -> tuple[jax.Array, dict[str, jax.Array]]:
    """Computes Eval Loss and returns empty aux dict (required for consistency)."""
    # Re-implement plain cross-entropy here (instead of deferring to the
    # parent) to guarantee the float32 cast used on the training path.
    logits_f32 = student_output.astype(jnp.float32)
    task_loss = jnp.mean(optax.softmax_cross_entropy(logits=logits_f32, labels=labels))

    # Must return a tuple because _has_aux=True expects it.
    return task_loss, {}
172+
173+
174+
# -----------------------------------------------------------------------------
175+
# Checkpoint Manager
176+
# -----------------------------------------------------------------------------
177+
178+
179+
class MaxTextCheckpointManager(tunix_checkpoint_manager.CheckpointManager):
  """Custom CheckpointManager that uses MaxText's native handlers.

  This manager extends Tunix to support saving/restoring the MaxText input pipeline
  (Grain) alongside the model and optimizer. It rebuilds the internal Orbax
  CheckpointManager so that the composite checkpoint contains an extra "iter"
  item handled by MaxText's GrainCheckpointHandler.
  """

  def __init__(
      self,
      raw_iterator: Any | None,
      root_directory: str | None = None,
      options: checkpoint.CheckpointManagerOptions | None = None,
  ):
    """Initializes the manager and swaps in the Grain-aware Orbax manager.

    Args:
      raw_iterator: The MaxText data iterator whose state should be
        checkpointed, or None to disable input-pipeline checkpointing.
      root_directory: Checkpoint directory forwarded to the Tunix base class.
      options: Orbax CheckpointManagerOptions forwarded to the base class.
    """
    super().__init__(root_directory=root_directory, options=options)
    self._iterator = raw_iterator

    # Re-initialize internal Orbax manager with MaxText's Grain handler.
    # The base class may have created a manager without the "iter" item, so we
    # close it and build a replacement with the full handler set.
    # pylint: disable=access-member-before-definition
    if self._checkpoint_manager is not None:
      # Reuse the directory the base class resolved (root_directory may be None).
      root_directory = self._checkpoint_manager.directory

      # Prefer the caller's options; otherwise recover them from the existing
      # manager so scheduling behavior (save interval, etc.) is preserved.
      if options is None:
        options = getattr(self._checkpoint_manager, "options", None)

      item_handlers = {
          "model_params": checkpoint.PyTreeCheckpointHandler(),
          "optimizer_state": checkpoint.PyTreeCheckpointHandler(),
          "custom_metadata": checkpoint.JsonCheckpointHandler(),
          # Use MaxText's handler for the iterator
          "iter": GrainCheckpointHandler(),
      }

      # Close before replacing to release the old manager's resources.
      self._checkpoint_manager.close()
      self._checkpoint_manager = checkpoint.CheckpointManager(
          root_directory,
          item_handlers=item_handlers,
          options=options,
      )
    # pylint: enable=access-member-before-definition

  def save(self, step, model, optimizer=None, save_only_lora_params=False, force=False, custom_metadata=None) -> bool:
    """Saves the checkpoint including the input pipeline state (if available).

    Args:
      step: Training step number to save under.
      model: The nnx model whose parameter state is saved.
      optimizer: Optional nnx optimizer; its state is saved when provided.
      save_only_lora_params: If True, save only nnx.LoRAParam state instead of
        the full model state.
      force: If True, save even when the manager's schedule says not to.
      custom_metadata: Optional JSON-serializable dict stored alongside the step.

    Returns:
      The result of the underlying Orbax save call, or False when there is no
      manager or the step is skipped by the save schedule.
    """
    if self._checkpoint_manager is None:
      return False
    # Respect the Orbax save schedule unless explicitly forced.
    if not force and not self._checkpoint_manager.should_save(step):
      return False

    # Standard Tunix Logic for Model/Optimizer
    if save_only_lora_params:
      params = nnx.state(model, nnx.LoRAParam)
    else:
      params = nnx.state(model)

    # Define standard SaveArgs once to reuse across every leaf of the pytree.
    default_save_args = checkpoint.SaveArgs()
    cp_save_args = {
        "model_params": checkpoint.args.PyTreeSave(
            item=params, save_args=jax.tree.map(lambda _: default_save_args, params)
        ),
    }
    if optimizer is not None:
      optimizer_state = nnx.state(optimizer, nnx.optimizer.OptState)
      cp_save_args["optimizer_state"] = checkpoint.args.PyTreeSave(
          item=optimizer_state, save_args=jax.tree.map(lambda _: default_save_args, optimizer_state)
      )

    if self._iterator is not None:
      # Follow MaxText's logic to handle multi-process saving
      # Logic extracted from src/MaxText/common/checkpointing.py:save_checkpoint
      data_iterator = self._iterator
      if not isinstance(data_iterator, list):
        data_iterator = [data_iterator]

      grain_iters_to_save = []
      # Total logical process count = JAX processes x iterators per process.
      process_count_total = jax.process_count() * len(data_iterator)

      for i, data_iter in enumerate(data_iterator):
        # Each (process, iterator) pair gets a unique index in the flat space.
        process_index = jax.process_index() + i * jax.process_count()
        # MaxText iterators (MultiHostDataLoadIterator) wrap the actual Grain iterator in .local_iterator
        local_iter = data_iter.local_iterator if hasattr(data_iter, "local_iterator") else data_iter
        grain_iters_to_save.append((local_iter, process_index, process_count_total))

      # Use GrainCheckpointSave wrapper
      cp_save_args["iter"] = GrainCheckpointSave(item=grain_iters_to_save)

    return self._checkpoint_manager.save(
        step,
        args=checkpoint.args.Composite(**cp_save_args),
        custom_metadata=custom_metadata or {},
        force=force,
    )

  def restore_iterator(self):
    """Restores the iterator using MaxText's logic.

    Returns:
      The original iterator object with its state restored in place, or None
      when there is nothing to restore or restoration fails.
    """
    if self._checkpoint_manager is None or self._iterator is None:
      return None

    # Restore from the most recent saved step only.
    step = self._checkpoint_manager.latest_step()
    if step is None:
      return None

    try:
      # MaxText logic for restoration (simplified for standard case)
      # We assume 1-to-1 process mapping for now (no elasticity logic here yet)
      data_iter = self._iterator
      # Unwrap MultiHostDataLoadIterator to reach the raw Grain iterator.
      local_iter = data_iter.local_iterator if hasattr(data_iter, "local_iterator") else data_iter

      restore_args = GrainCheckpointRestore(item=local_iter)

      self._checkpoint_manager.restore(step, args=checkpoint.args.Composite(iter=restore_args))
      # Since Grain restores in-place via set_state(), we return the original object
      return self._iterator

    # NOTE(review): broad catch is a deliberate best-effort — training can
    # proceed from a fresh pipeline if restoration fails; the failure is logged.
    except Exception as e:  # pylint: disable=broad-exception-caught
      max_logging.log(f"Warning: Could not restore input pipeline: {e}")
      return None

  def wait_until_finished(self) -> None:
    """Blocks until all outstanding checkpoint operations are complete."""
    if self._checkpoint_manager is not None:
      self._checkpoint_manager.wait_until_finished()

0 commit comments

Comments
 (0)