File tree Expand file tree Collapse file tree
trainers/post_train/distillation Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -1081,15 +1081,13 @@ class Distillation(BaseModel):
10811081 default_factory = dict ,
10821082 description = "Overrides specific to the Teacher model (e.g., {'num_query_heads': 64})." ,
10831083 )
1084-
1084+
10851085 # --- Offline Distillation Fields ---
10861086 offline_distillation : bool = Field (
1087- False ,
1088- description = "If True, enables offline distillation using pre-generated teacher data."
1087+ False , description = "If True, enables offline distillation using pre-generated teacher data."
10891088 )
10901089 offline_data_dir : Optional [str ] = Field (
1091- None ,
1092- description = "GCS or local path to the pre-generated ArrayRecord teacher data."
1090+ None , description = "GCS or local path to the pre-generated ArrayRecord teacher data."
10931091 )
10941092
10951093 # --- Loss Params ---
Original file line number Diff line number Diff line change 1818model structures with Tunix's training interfaces.
1919"""
2020
21- import os
2221import pickle
2322import tensorflow as tf
2423from array_record .python import array_record_module
@@ -82,11 +81,7 @@ class OfflineArrayRecordIterator:
8281 """Reads the pre-generated global top-k logits file."""
8382
8483 def __init__ (self , data_dir : str , epochs : int = 100 ):
85- # Check if the user passed a directory or a direct file path
86- if tf .io .gfile .isdir (data_dir ):
87- self .filepath = os .path .join (data_dir , "teacher_top_k_global.array_record" )
88- else :
89- self .filepath = data_dir
84+ self .filepath = data_dir
9085
9186 if not tf .io .gfile .exists (self .filepath ):
9287 raise FileNotFoundError (f"Offline distillation file not found: { self .filepath } " )
Original file line number Diff line number Diff line change 32323. **Tunix Integration**: We wrap the MaxText models in `TunixMaxTextAdapter` to expose
3333 a standard interface (call signature) that the Tunix `DistillationTrainer` expects.
3434"""
35- import argparse
36- import functools
37- import sys
38-
3935from typing import Sequence , Callable
4036from absl import app
4137from flax import nnx
@@ -654,4 +650,4 @@ def main(argv: Sequence[str]) -> None:
654650
655651
656652if __name__ == "__main__" :
657- app .run (main )
653+ app .run (main )
You can’t perform that action at this time.
0 commit comments