Skip to content

Commit 27b3790

Browse files
Merge pull request #3332 from AI-Hypercomputer:vladk/distill-sft-hf
PiperOrigin-RevId: 881116982
2 parents d2c172a + bf2ead8 commit 27b3790

7 files changed

Lines changed: 131 additions & 15 deletions

File tree

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Copyright 2023–2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# Soft Distillation Configuration
16+
17+
# Inherit MaxText defaults
18+
base_config: "post_train/distillation.yml"
19+
20+
use_sft: True
21+
sft_train_on_completion_only: True
22+
23+
# --- Dataset & Tokenizer ---
24+
hf_path: "HuggingFaceH4/ultrachat_200k"
25+
dataset_type: "hf"
26+
27+
# chat_template is required for sft mode & HF pipeline. Many tokenizers already provide it.
28+
# Some non-instruct versions of HF tokenizers have no chat template defined, so one has to specify it here
29+
chat_template: "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = '<|begin_of_text|>' + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}"
30+
train_split: "train_sft"
31+
eval_split: "test_sft"
32+
train_data_columns: ["messages"]
33+
eval_data_columns: ["messages"]

src/maxtext/configs/post_train/distillation.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ per_device_batch_size: 2
5151
gradient_accumulation_steps: 8
5252

5353
# --- Learning Rate Schedule ---
54-
learning_rate: 2.0e-4
54+
learning_rate: 2.0e-4
5555
learning_rate_schedule_steps: 200000
5656
warmup_steps_fraction: 0.1
57-
cosine_learning_rate_final_fraction: 0.1
57+
learning_rate_final_fraction: 0.1

src/maxtext/configs/types.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import datetime
2020
import enum
2121
from enum import Enum
22+
from jinja2 import Environment, TemplateSyntaxError
2223
import logging
2324
import math
2425
from math import prod
@@ -926,6 +927,9 @@ class Tokenizer(BaseModel):
926927
tokenizer_type: TokenizerType = Field(TokenizerType.SENTENCEPIECE, description="The type of tokenizer.")
927928
use_chat_template: bool = Field(False, description="Whether to use the chat template for tokenization.")
928929
chat_template_path: str = Field("", description="Path to chat template json file.")
930+
chat_template: str = Field(
931+
"", description="Chat template to use with HF tokenizers. It should be a valid Jinja2-formatted template."
932+
)
929933
tokenize_train_data: bool = Field(True, description="If False, assumes the training dataset is pre-tokenized.")
930934
tokenize_eval_data: bool = Field(True, description="If False, assumes the evaluation dataset is pre-tokenized.")
931935
add_bos: bool = Field(True, description="Whether to add a beginning-of-sentence token.")
@@ -1991,6 +1995,15 @@ def set_derived_and_validate_values(self) -> "MaxTextConfig":
19911995
)
19921996
self.tokenizer_path = tokenizer_path
19931997

1998+
# validate chat_template format if defined
1999+
chat_template = getattr(self, "chat_template", "")
2000+
if chat_template:
2001+
try:
2002+
env = Environment()
2003+
env.parse(chat_template)
2004+
except TemplateSyntaxError as e:
2005+
raise ValueError(f"Specified chat_template is invalid: {e}") from e
2006+
19942007
# C. SET PRIMARY DEPENDENCIES & DEFAULTS
19952008
# If learning_rate_schedule_steps is -1, it defaults to the total number of training steps.
19962009
if self.learning_rate_schedule_steps == -1:

src/maxtext/input_pipeline/hf_data_processing.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
"""Input pipeline using Huggingface datasets."""
1616

17+
from typing import Optional
18+
1719
import ml_collections
1820

1921
import jax
@@ -213,6 +215,7 @@ def preprocessing_pipeline(
213215
grain_worker_count=1, # only support 0 or 1
214216
max_segments_per_seq=None,
215217
num_epoch=1,
218+
chat_template: Optional[str] = None,
216219
):
217220
"""pipeline for preprocessing HF dataset"""
218221
import datasets # pylint: disable=import-outside-toplevel
@@ -242,19 +245,22 @@ def preprocessing_pipeline(
242245
token=hf_access_token,
243246
)
244247

248+
dataset = dataset.select_columns(data_column_names)
249+
245250
if use_sft:
246-
dataset = dataset.select_columns(data_column_names)
251+
if chat_template:
252+
tokenizer.chat_template = chat_template
247253

248254
supported_columns = [["prompt", "completion"], ["messages"], ["question", "answer"]]
249255
assert any(
250256
set(data_column_names) == set(supported) for supported in supported_columns
251257
), f"Dataset column names mismatch. Expected columns to match one of {supported_columns}, but got {data_column_names}"
252258

253259
# convert instruction dataset to conversational format
260+
# currently only works for Q&A datasets
254261
dataset, data_column_names = instruction_data_processing.convert_to_conversational_format(
255262
dataset=dataset, data_columns=data_column_names, chat_template_path=chat_template_path
256263
)
257-
258264
assert input_pipeline_utils.is_conversational(
259265
dataset.features, data_column_names
260266
), "Dataset is not in conversational format."
@@ -276,8 +282,6 @@ def preprocessing_pipeline(
276282
input_pipeline_utils.apply_chat_template,
277283
fn_kwargs={"tokenizer_model": tokenizer, "data_column_name": data_column_names[0]},
278284
)
279-
else:
280-
dataset = dataset.select_columns(data_column_names)
281285

282286
pad_id = _get_pad_id(tokenizer)
283287

@@ -426,6 +430,7 @@ def make_hf_train_iterator(
426430
chat_template_path=config.chat_template_path,
427431
max_segments_per_seq=config.max_segments_per_seq,
428432
num_epoch=config.num_epoch,
433+
chat_template=config.chat_template,
429434
)
430435
return train_iter
431436

@@ -482,5 +487,6 @@ def make_hf_eval_iterator(
482487
sft_train_on_completion_only=config.sft_train_on_completion_only,
483488
chat_template_path=config.chat_template_path,
484489
max_segments_per_seq=config.max_segments_per_seq,
490+
chat_template=config.chat_template,
485491
)
486492
return eval_iter

src/maxtext/trainers/post_train/distillation/distillation_utils.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,10 @@ class MaxTextTrainingInput(distillation_trainer.TrainingInput):
6060
decoder_segment_ids: jax.Array = None
6161
#: Ground truth target tokens (used for loss calculation and logging).
6262
targets: jax.Array = None
63+
#: Position indices for the target tokens.
64+
targets_position: jax.Array = None
65+
#: Segment IDs for packed target tokens.
66+
targets_segmentation: jax.Array = None
6367

6468

6569
# -----------------------------------------------------------------------------
@@ -106,6 +110,11 @@ def __next__(self) -> MaxTextTrainingInput:
106110
input_mask = jnp.ones_like(batch["inputs"], dtype=bool)
107111
seg_ids = None
108112

113+
# If in SFT-mode, 'targets' contains prompts which should be masked out when computing the loss.
114+
# When packing is used, the targets_segmentation mask is expected to be a combined target+packing mask
115+
targets_segmentation = batch.get("targets_segmentation", jnp.ones_like(batch["targets"]))
116+
targets_position = batch.get("targets_position", batch.get("inputs_position"))
117+
109118
# pylint: disable=unexpected-keyword-arg
110119
return MaxTextTrainingInput(
111120
input_tokens=batch["inputs"],
@@ -114,6 +123,8 @@ def __next__(self) -> MaxTextTrainingInput:
114123
positions=batch["inputs_position"],
115124
decoder_segment_ids=seg_ids,
116125
targets=batch["targets"],
126+
targets_position=targets_position,
127+
targets_segmentation=targets_segmentation,
117128
)
118129

119130

@@ -134,6 +145,7 @@ def __init__(
134145
layer_indices: Optional[List[int]] = None,
135146
feature_loss_fn: Callable[[jax.Array, jax.Array], jax.Array] | None = None,
136147
cosine_distance_axis: int | tuple[int, ...] = -1,
148+
sft_mode: bool = False,
137149
):
138150
"""Initializes the Combined strategy using tunix logit.LogitStrategy.
139151
@@ -165,6 +177,7 @@ def __init__(
165177
self.feature_loss_fn = lambda student_features, teacher_features: jnp.mean(
166178
optax.cosine_distance(student_features, teacher_features, axis=cosine_distance_axis)
167179
)
180+
self.sft_mode = sft_mode
168181

169182
def compute_loss(
170183
self,
@@ -192,19 +205,23 @@ def compute_loss(
192205
log_student_probs_temp = jax.nn.log_softmax(s_logits / self.temperature, axis=-1)
193206
teacher_probs_temp = jax.nn.softmax(t_logits / self.temperature, axis=-1)
194207

208+
# labels are expected to have all SFT masks applied by this point
209+
labels_mask = jnp.any(labels != 0, axis=-1, keepdims=True) if self.sft_mode else None
210+
mean_mask = jnp.squeeze(labels_mask, axis=-1) if labels_mask is not None else None
211+
195212
# KL(Teacher || Student)
196-
kl_div = optax.kl_divergence(log_student_probs_temp, teacher_probs_temp)
213+
kl_div = optax.kl_divergence(log_student_probs_temp, teacher_probs_temp, where=labels_mask)
197214

198215
# Scale gradients by T^2 (Hinton et al.)
199-
soft_loss = jnp.mean(kl_div) * (self.temperature**2)
216+
soft_loss = jnp.mean(kl_div, where=mean_mask) * (self.temperature**2)
200217

201218
# 1. Student Hard Loss (Existing)
202-
ce_loss_student = optax.softmax_cross_entropy(logits=s_logits, labels=labels)
203-
hard_loss = jnp.mean(ce_loss_student)
219+
ce_loss_student = optax.softmax_cross_entropy(logits=s_logits, labels=labels, where=labels_mask)
220+
hard_loss = jnp.mean(ce_loss_student, where=mean_mask)
204221

205222
# 2. Teacher Hard Loss (For Verification)
206-
ce_loss_teacher = optax.softmax_cross_entropy(logits=t_logits, labels=labels)
207-
teacher_hard_loss = jnp.mean(ce_loss_teacher)
223+
ce_loss_teacher = optax.softmax_cross_entropy(logits=t_logits, labels=labels, where=labels_mask)
224+
teacher_hard_loss = jnp.mean(ce_loss_teacher, where=mean_mask)
208225

209226
# 3. Combine losses
210227
base_logit_loss = (self.alpha * soft_loss) + ((1.0 - self.alpha) * hard_loss)

src/maxtext/trainers/post_train/distillation/train_distill.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,14 +134,15 @@ def model_forward_fn(
134134
model, input_tokens, positions, attention_mask, decoder_segment_ids=None, cache=None, **kwargs
135135
) -> distillation_utils.DistillationForwardOutput:
136136
"""Forward pass wrapper adapted for raw MaxText models."""
137-
del kwargs # Unused
138137
del attention_mask # Unused
139138
del cache # Unused
140139
logits = model(
141140
decoder_input_tokens=input_tokens,
142141
decoder_positions=positions,
143142
decoder_segment_ids=decoder_segment_ids,
144143
enable_dropout=config.enable_dropout,
144+
decoder_target_tokens=kwargs.get("targets", None),
145+
decoder_target_mask=kwargs.get("targets_segmentation", None),
145146
)
146147
out_projection_activations = None
147148
if config.distill_beta > 0.0:
@@ -213,6 +214,8 @@ def _prepare_inputs(
213214
positions=input_data.positions,
214215
decoder_segment_ids=input_data.decoder_segment_ids,
215216
targets=input_data.targets,
217+
targets_position=input_data.targets_position,
218+
targets_segmentation=input_data.targets_segmentation,
216219
)
217220

218221
def _post_process_train_step(self, aux: dict[str, jax.Array]) -> None:
@@ -352,11 +355,13 @@ def train_distill(student_config: pyconfig.HyperParameters, teacher_config: pyco
352355
teacher_model = get_maxtext_model(teacher_config, mesh)
353356

354357
# 3. Define Distillation Strategy
355-
def labels_fn(targets, **kwargs):
358+
def labels_fn(targets, targets_segmentation=None, **kwargs):
356359
"""Converts integer targets to masked one-hot vectors for hard label loss."""
357360
del kwargs # Unused
358361
one_hot = jax.nn.one_hot(targets, student_config.vocab_size)
359362
mask = jnp.not_equal(targets, pad_id).astype(one_hot.dtype)[..., None]
363+
if targets_segmentation is not None:
364+
mask = mask * (targets_segmentation != 0)[..., None]
360365
return one_hot * mask
361366

362367
# Both Student and Teacher use the same forward logic via the adapter
@@ -372,6 +377,7 @@ def labels_fn(targets, **kwargs):
372377
alpha=student_config.distill_alpha,
373378
beta_feature=student_config.distill_beta,
374379
layer_indices=student_config.distill_layer_indices,
380+
sft_mode=student_config.use_sft,
375381
)
376382

377383
student_model, teacher_model = strategy.pre_process_models(student_model, teacher_model)
@@ -436,6 +442,8 @@ def labels_fn(targets, **kwargs):
436442
"attention_mask": batch.input_mask,
437443
"decoder_segment_ids": batch.decoder_segment_ids,
438444
"targets": batch.targets, # Passed to strategy (labels_fn)
445+
"targets_position": batch.targets_position, # Passed to strategy (labels_fn)
446+
"targets_segmentation": batch.targets_segmentation, # Passed to strategy (labels_fn)
439447
"cache": None,
440448
}
441449
)

tests/unit/train_distill_test.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,27 @@ def test_maxtext_to_tunix_iterator(self):
7676
expected_mask = dummy_batch["inputs_segmentation"] != 0
7777
np.testing.assert_array_equal(tunix_input.input_mask, expected_mask)
7878

79+
def test_maxtext_to_tunix_iterator_sft(self):
80+
"""Verifies SFT-related fields are handled correctly."""
81+
# 1. Create a dummy batch with SFT fields
82+
dummy_batch_sft = {
83+
"inputs": np.array([[10, 11]]),
84+
"inputs_position": np.array([[0, 1]]),
85+
"targets": np.array([[11, 12]]),
86+
"targets_position": np.array([[100, 101]]), # Custom position
87+
"targets_segmentation": np.array([[0, 1]]), # Custom segmentation (mask)
88+
}
89+
dummy_iter_sft = iter([dummy_batch_sft])
90+
91+
# 2. Initialize Adapter and get output
92+
adapter_sft = distillation_utils.MaxTextToTunixIterator(dummy_iter_sft)
93+
tunix_input_sft = next(adapter_sft)
94+
95+
# 3. Verify SFT fields are passed through
96+
self.assertIsInstance(tunix_input_sft, distillation_utils.MaxTextTrainingInput)
97+
np.testing.assert_array_equal(tunix_input_sft.targets_position, dummy_batch_sft["targets_position"])
98+
np.testing.assert_array_equal(tunix_input_sft.targets_segmentation, dummy_batch_sft["targets_segmentation"])
99+
79100
def test_maxtext_to_tunix_iterator_packed_fallback(self):
80101
"""Verifies fallback behavior when segmentation is missing."""
81102
dummy_batch = {
@@ -161,6 +182,12 @@ def test_optimizer_factory(self):
161182
train_distill.get_distillation_optimizer(config, max_train_steps=100)
162183

163184
def test_monitored_strategy(self):
185+
self._test_monitored_strategy(False)
186+
187+
def test_monitored_strategy_sft(self):
188+
self._test_monitored_strategy(True)
189+
190+
def _test_monitored_strategy(self, sft_mode: bool):
164191
"""Verifies the strategy calculates metrics and returns the correct tuple."""
165192
strategy = distillation_utils.CombinedDistillationStrategy(
166193
student_forward_fn=lambda m, **k: None,
@@ -170,6 +197,7 @@ def test_monitored_strategy(self):
170197
alpha=0.5,
171198
beta_feature=1.0,
172199
layer_indices=None,
200+
sft_mode=sft_mode,
173201
)
174202

175203
# Dummy inputs (batch=1, seq=2, vocab=4)
@@ -210,9 +238,17 @@ def test_monitored_strategy(self):
210238
self.assertLess(metrics["distill/out_proj_feature_loss"], 1e-5)
211239

212240
def test_strategy_compute_eval_loss(self):
241+
self._verify_strategy_compute_eval_loss(sft_mode=False)
242+
243+
def _verify_strategy_compute_eval_loss(self, sft_mode):
213244
"""Covers MonitoredLogitStrategy.compute_eval_loss."""
214245
strategy = distillation_utils.CombinedDistillationStrategy(
215-
student_forward_fn=mock.Mock(), teacher_forward_fn=mock.Mock(), labels_fn=mock.Mock(), temperature=1.0, alpha=0.5
246+
student_forward_fn=mock.Mock(),
247+
teacher_forward_fn=mock.Mock(),
248+
labels_fn=mock.Mock(),
249+
temperature=1.0,
250+
alpha=0.5,
251+
sft_mode=sft_mode,
216252
)
217253
# Case where feature loss is enabled
218254
logits = distillation_utils.DistillationForwardOutput(
@@ -234,6 +270,9 @@ def test_strategy_compute_eval_loss(self):
234270
self.assertTrue(isinstance(loss, jax.Array))
235271
self.assertEqual(aux, {})
236272

273+
def test_strategy_compute_eval_loss_sft(self):
274+
self._verify_strategy_compute_eval_loss(sft_mode=True)
275+
237276
def test_setup_pipeline_grain_enabled(self):
238277
"""Covers _setup_and_restore_input_pipeline when Grain IS detected."""
239278
mock_trainer = mock.Mock()

0 commit comments

Comments
 (0)