
Commit 5c76bdd

Merge pull request #3260 from AI-Hypercomputer:cleanup_distillation_classes

PiperOrigin-RevId: 875907919
2 parents f255755 + bc7b717, commit 5c76bdd

3 files changed: 58 additions & 20 deletions

File tree

- src/maxtext/trainers/post_train/distillation/distillation_utils.py
- src/maxtext/trainers/post_train/distillation/train_distill.py
- tests/unit/train_distill_test.py

File: src/maxtext/trainers/post_train/distillation/distillation_utils.py
Lines changed: 23 additions & 10 deletions
```diff
@@ -40,6 +40,16 @@
 # -----------------------------------------------------------------------------
 
 
+@flax.struct.dataclass(frozen=True)
+class DistillationForwardOutput:
+  """Dataclass to carry MaxText-specific output fields."""
+
+  #: logits
+  logits: jax.Array = None
+  #: out_projection_activations
+  out_projection_activations: jax.Array = None
+
+
 @flax.struct.dataclass(frozen=True)
 class MaxTextTrainingInput(distillation_trainer.TrainingInput):
   """Extended TrainingInput dataclass to carry MaxText-specific fields."""
```
```diff
@@ -115,8 +125,8 @@ class CombinedDistillationStrategy(logit.LogitStrategy):
 
   def __init__(
       self,
-      student_forward_fn: Callable[..., jax.Array],
-      teacher_forward_fn: Callable[..., jax.Array],
+      student_forward_fn: Callable[..., DistillationForwardOutput],
+      teacher_forward_fn: Callable[..., DistillationForwardOutput],
       labels_fn: Callable[..., jax.Array],
       temperature: float = 2.0,
       alpha: float = 0.5,
@@ -158,20 +168,20 @@ def __init__(
 
   def compute_loss(
       self,
-      student_output: jax.Array,
-      teacher_output: jax.Array,
+      student_output: DistillationForwardOutput,
+      teacher_output: DistillationForwardOutput,
       labels: jax.Array,
   ) -> tuple[jax.Array, dict[str, jax.Array]]:
     """Computes Loss and Auxiliary Metrics."""
     # Calculate Distillation Loss (KL Divergence)
     # Scale logits by temperature T for soft targets
     # We use explicit float32 casting for stability in loss calculation
-    s_logits = student_output[0].astype(jnp.float32)
-    t_logits = teacher_output[0].astype(jnp.float32)
+    s_logits = student_output.logits.astype(jnp.float32)
+    t_logits = teacher_output.logits.astype(jnp.float32)
 
     # Shape: [num_layers, batch, seq, hidden_dim]
-    s_features = student_output[-1]
-    t_features = teacher_output[-1]
+    s_features = student_output.out_projection_activations
+    t_features = teacher_output.out_projection_activations
 
     if (s_features is None or t_features is None) and self.beta_feature > 0.0:
       raise ValueError(
@@ -210,6 +220,9 @@ def compute_loss(
       s_features_sliced = s_features
       t_features_sliced = t_features
 
+    s_features_sliced = s_features_sliced.astype(jnp.float32)
+    t_features_sliced = t_features_sliced.astype(jnp.float32)
+
     feature_loss = self.beta_feature * self.feature_loss_fn(s_features_sliced, t_features_sliced)
 
     total_loss = base_logit_loss + feature_loss
```
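To make the combined quantities concrete, here is an illustrative stand-in, assuming a conventional temperature-scaled KL divergence for the logit term and a mean-squared-error `feature_loss_fn`. The actual `base_logit_loss` comes from the parent `LogitStrategy` and the real feature loss is configurable, so treat this as a sketch of the math, not the implementation:

```python
import jax
import jax.numpy as jnp


def sketch_combined_loss(s_logits, t_logits, s_feats, t_feats,
                         temperature=2.0, beta_feature=0.1):
  """Temperature-scaled KL on logits plus an MSE feature term (illustrative)."""
  # Explicit float32 casting for stability, matching the diff.
  s_logits = s_logits.astype(jnp.float32)
  t_logits = t_logits.astype(jnp.float32)

  # Soft targets at temperature T.
  t_probs = jax.nn.softmax(t_logits / temperature, axis=-1)
  s_log_probs = jax.nn.log_softmax(s_logits / temperature, axis=-1)

  # KL(teacher || student); the T**2 factor restores gradient scale.
  kl = jnp.sum(t_probs * (jnp.log(t_probs + 1e-9) - s_log_probs), axis=-1)
  base_logit_loss = (temperature**2) * jnp.mean(kl)

  # Feature term on float32-cast activations, as the new hunk enforces.
  feature_loss = beta_feature * jnp.mean(
      (s_feats.astype(jnp.float32) - t_feats.astype(jnp.float32)) ** 2
  )
  return base_logit_loss + feature_loss
```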
```diff
@@ -227,13 +240,13 @@ def compute_loss(
 
   def compute_eval_loss(
       self,
-      student_output: jax.Array,
+      student_output: DistillationForwardOutput,
       labels: jax.Array,
   ) -> tuple[jax.Array, dict[str, jax.Array]]:
     """Computes Eval Loss and returns empty aux dict (required for consistency)."""
     # Parent logic for task loss
     # We re-implement simple CE here to ensure float32 casting
-    s_logits = student_output.astype(jnp.float32)
+    s_logits = student_output.logits.astype(jnp.float32)
     ce_loss = optax.softmax_cross_entropy(logits=s_logits, labels=labels)
     task_loss = jnp.mean(ce_loss)
 
```
File: src/maxtext/trainers/post_train/distillation/train_distill.py
Lines changed: 12 additions & 6 deletions
```diff
@@ -33,7 +33,7 @@
 a standard interface (call signature) that the Tunix `DistillationTrainer` expects.
 """
 
-from typing import Sequence
+from typing import Sequence, Callable
 from absl import app
 from flax import nnx
 from flax.linen import partitioning as nn_partitioning
@@ -119,7 +119,7 @@ def optimizer_factory(learning_rate):
   return optimizer
 
 
-def create_forward_fn(config: pyconfig.HyperParameters):
+def create_forward_fn(config: pyconfig.HyperParameters) -> Callable[..., distillation_utils.DistillationForwardOutput]:
   """Creates a forward function closure that binds the specific model configuration.
 
   Args:
@@ -130,7 +130,9 @@ def create_forward_fn(config: pyconfig.HyperParameters):
     Tunix `LogitStrategy` and handles the MaxText-specific forward call.
   """
 
-  def model_forward_fn(model, input_tokens, positions, attention_mask, decoder_segment_ids=None, cache=None, **kwargs):
+  def model_forward_fn(
+      model, input_tokens, positions, attention_mask, decoder_segment_ids=None, cache=None, **kwargs
+  ) -> distillation_utils.DistillationForwardOutput:
     """Forward pass wrapper adapted for raw MaxText models."""
     del kwargs  # Unused
     del attention_mask  # Unused
@@ -141,10 +143,14 @@ def model_forward_fn(model, input_tokens, positions, attention_mask, decoder_seg
         decoder_segment_ids=decoder_segment_ids,
         enable_dropout=config.enable_dropout,
     )
-    hidden_features = None
+    out_projection_activations = None
     if config.distill_beta > 0.0:
-      hidden_features = maxtext_utils.get_intermediate_value(model, "out_projection_activations", clear=True)
-    return logits, hidden_features
+      out_projection_activations = maxtext_utils.get_intermediate_value(model, "out_projection_activations", clear=True)
+
+    retval = distillation_utils.DistillationForwardOutput(
+        logits=logits, out_projection_activations=out_projection_activations
+    )
+    return retval
 
   return model_forward_fn
 
```
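A hedged usage sketch of the closure (the `batch` keys and trainer wiring here are assumptions, not part of this diff): callers now read named fields from the returned `DistillationForwardOutput` instead of unpacking a positional tuple, which is what let `compute_loss` drop the `student_output[0]` / `student_output[-1]` indexing above.

```python
# Hypothetical call site; only the function names come from the diff.
forward_fn = create_forward_fn(config)

output = forward_fn(
    model,
    input_tokens=batch["inputs"],
    positions=batch["positions"],
    attention_mask=None,  # accepted but deleted inside the wrapper
    decoder_segment_ids=batch.get("segment_ids"),
)

logits = output.logits
# None unless config.distill_beta > 0.0 captured the intermediates.
features = output.out_projection_activations
```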

File: tests/unit/train_distill_test.py
Lines changed: 23 additions & 4 deletions
```diff
@@ -169,15 +169,21 @@ def test_monitored_strategy(self):
 
     # Dummy inputs (batch=1, seq=2, vocab=4)
     # Note: Shapes must align for broadcasting
-    student_logits = (jnp.array([[[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]]) * 10, jnp.ones((32, 1, 1, 8)))
-    teacher_logits = (jnp.array([[[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]]) * 10, jnp.ones((32, 1, 1, 8)))
+    student_output = distillation_utils.DistillationForwardOutput(
+        logits=jnp.array([[[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]]) * 10,
+        out_projection_activations=jnp.ones((32, 1, 1, 8)),
+    )
+    teacher_output = distillation_utils.DistillationForwardOutput(
+        logits=jnp.array([[[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]]) * 10,
+        out_projection_activations=jnp.ones((32, 1, 1, 8)),
+    )
 
     # Labels must be One-Hot Encoded to match logits shape (1, 2, 4)
     labels_indices = jnp.array([[0, 1]])
     labels = jax.nn.one_hot(labels_indices, 4)
 
     # Run calculation
-    _, metrics = strategy.compute_loss(student_logits, teacher_logits, labels)
+    _, metrics = strategy.compute_loss(student_output, teacher_output, labels)
 
     # Verify structure
     self.assertIsInstance(metrics, dict)
@@ -203,7 +209,20 @@ def test_strategy_compute_eval_loss(self):
     strategy = distillation_utils.CombinedDistillationStrategy(
         student_forward_fn=mock.Mock(), teacher_forward_fn=mock.Mock(), labels_fn=mock.Mock(), temperature=1.0, alpha=0.5
     )
-    logits = jnp.array([[[10.0, 0.0]]])
+    # Case where feature loss is enabled
+    logits = distillation_utils.DistillationForwardOutput(
+        logits=jnp.array([[[10.0, 0.0]]]), out_projection_activations=np.ones((32, 1, 1, 8))
+    )
+    labels = jnp.array([[[1.0, 0.0]]])
+
+    loss, aux = strategy.compute_eval_loss(logits, labels)
+    self.assertTrue(isinstance(loss, jax.Array))
+    self.assertEqual(aux, {})
+
+    # Case where feature loss is disabled.
+    logits = distillation_utils.DistillationForwardOutput(
+        logits=jnp.array([[[10.0, 0.0]]]), out_projection_activations=None
+    )
     labels = jnp.array([[[1.0, 0.0]]])
 
     loss, aux = strategy.compute_eval_loss(logits, labels)
```
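As a sanity check on what the new eval-path cases assert (a sketch using only `jax.numpy` and `optax`): with logits `[[[10.0, 0.0]]]` and one-hot labels `[[[1.0, 0.0]]]`, the cross entropy is `log(1 + e^(-10)) ≈ 4.54e-05`, and it is identical in both cases because `compute_eval_loss` only ever reads `student_output.logits`:

```python
import jax.numpy as jnp
import optax

logits = jnp.array([[[10.0, 0.0]]], dtype=jnp.float32)
labels = jnp.array([[[1.0, 0.0]]], dtype=jnp.float32)

# The same computation compute_eval_loss performs on student_output.logits.
ce_loss = optax.softmax_cross_entropy(logits=logits, labels=labels)
print(jnp.mean(ce_loss))  # ~4.54e-05, i.e. log(1 + exp(-10))
```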
