 # -----------------------------------------------------------------------------
 
 
+@flax.struct.dataclass(frozen=True)
+class DistillationForwardOutput:
+  """Dataclass to carry MaxText-specific output fields."""
+
+  #: logits
+  logits: jax.Array = None
+  #: out_projection_activations
+  out_projection_activations: jax.Array = None
+
+
 @flax.struct.dataclass(frozen=True)
 class MaxTextTrainingInput(distillation_trainer.TrainingInput):
   """Extended TrainingInput dataclass to carry MaxText-specific fields."""
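
Note: a minimal usage sketch (dummy shapes are assumptions, not from this PR) showing that downstream code now reads these outputs by field name rather than by tuple index:

```python
import jax.numpy as jnp

# Hypothetical shapes: [batch, seq, vocab] logits and
# [num_layers, batch, seq, hidden] projection activations.
out = DistillationForwardOutput(
    logits=jnp.zeros((2, 4, 8)),
    out_projection_activations=jnp.zeros((2, 2, 4, 16)),
)
# Callers use out.logits / out.out_projection_activations
# instead of output[0] / output[-1].
```
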
@@ -115,8 +125,8 @@ class CombinedDistillationStrategy(logit.LogitStrategy):
 
   def __init__(
       self,
-      student_forward_fn: Callable[..., jax.Array],
-      teacher_forward_fn: Callable[..., jax.Array],
+      student_forward_fn: Callable[..., DistillationForwardOutput],
+      teacher_forward_fn: Callable[..., DistillationForwardOutput],
       labels_fn: Callable[..., jax.Array],
       temperature: float = 2.0,
       alpha: float = 0.5,
@@ -158,20 +168,20 @@ def __init__(
 
   def compute_loss(
       self,
-      student_output: jax.Array,
-      teacher_output: jax.Array,
+      student_output: DistillationForwardOutput,
+      teacher_output: DistillationForwardOutput,
       labels: jax.Array,
   ) -> tuple[jax.Array, dict[str, jax.Array]]:
     """Computes Loss and Auxiliary Metrics."""
     # Calculate Distillation Loss (KL Divergence)
     # Scale logits by temperature T for soft targets
     # We use explicit float32 casting for stability in loss calculation
-    s_logits = student_output[0].astype(jnp.float32)
-    t_logits = teacher_output[0].astype(jnp.float32)
+    s_logits = student_output.logits.astype(jnp.float32)
+    t_logits = teacher_output.logits.astype(jnp.float32)
 
     # Shape: [num_layers, batch, seq, hidden_dim]
-    s_features = student_output[-1]
-    t_features = teacher_output[-1]
+    s_features = student_output.out_projection_activations
+    t_features = teacher_output.out_projection_activations
 
     if (s_features is None or t_features is None) and self.beta_feature > 0.0:
       raise ValueError(
@@ -210,6 +220,9 @@ def compute_loss(
       s_features_sliced = s_features
       t_features_sliced = t_features
 
+    s_features_sliced = s_features_sliced.astype(jnp.float32)
+    t_features_sliced = t_features_sliced.astype(jnp.float32)
+
     feature_loss = self.beta_feature * self.feature_loss_fn(s_features_sliced, t_features_sliced)
 
     total_loss = base_logit_loss + feature_loss
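
For orientation, a self-contained sketch of the kind of combined objective this hunk feeds into; the temperature-scaled KL term and the MSE feature term are assumptions for illustration, and the PR's actual logit loss and `feature_loss_fn` may differ:

```python
import jax
import jax.numpy as jnp


def combined_loss_sketch(s_logits, t_logits, s_feats, t_feats,
                         temperature=2.0, beta_feature=0.1):
  """Illustrative KL-on-logits plus MSE-on-features distillation loss."""
  s_logits = s_logits.astype(jnp.float32)
  t_logits = t_logits.astype(jnp.float32)
  s_log_probs = jax.nn.log_softmax(s_logits / temperature, axis=-1)
  t_probs = jax.nn.softmax(t_logits / temperature, axis=-1)
  # KL(teacher || student), scaled by T^2 as in standard logit distillation.
  kl = jnp.sum(t_probs * (jnp.log(t_probs + 1e-9) - s_log_probs), axis=-1)
  base_logit_loss = (temperature ** 2) * jnp.mean(kl)
  # Feature term computed in float32, mirroring the casts added above.
  feature_loss = beta_feature * jnp.mean(
      (s_feats.astype(jnp.float32) - t_feats.astype(jnp.float32)) ** 2)
  return base_logit_loss + feature_loss
```
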
@@ -227,13 +240,13 @@ def compute_loss(
 
   def compute_eval_loss(
       self,
-      student_output: jax.Array,
+      student_output: DistillationForwardOutput,
       labels: jax.Array,
   ) -> tuple[jax.Array, dict[str, jax.Array]]:
     """Computes Eval Loss and returns empty aux dict (required for consistency)."""
     # Parent logic for task loss
     # We re-implement simple CE here to ensure float32 casting
-    s_logits = student_output.astype(jnp.float32)
+    s_logits = student_output.logits.astype(jnp.float32)
     ce_loss = optax.softmax_cross_entropy(logits=s_logits, labels=labels)
     task_loss = jnp.mean(ce_loss)
 
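
A small standalone check of the eval path (dummy values assumed): `optax.softmax_cross_entropy` takes one-hot/probability targets, and the logits are cast to float32 before the loss, mirroring the change above:

```python
import jax.numpy as jnp
import optax

logits = jnp.array([[2.0, 0.5, -1.0]], dtype=jnp.bfloat16)  # e.g. bf16 model output
labels = jnp.array([[1.0, 0.0, 0.0]])                        # one-hot targets

ce_loss = optax.softmax_cross_entropy(logits=logits.astype(jnp.float32), labels=labels)
task_loss = jnp.mean(ce_loss)
```
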