Skip to content

Commit 1da3909

Browse files
committed
updated code formatting and style
1 parent 18f2eb6 commit 1da3909

2 files changed

Lines changed: 23 additions & 20 deletions

File tree

src/maxtext/trainers/post_train/distillation/distillation_utils.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,12 +72,15 @@ class MaxTextTrainingInput(peft_trainer.TrainingInput):
7272
top_k_logits: jax.Array = None
7373
top_k_indices: jax.Array = None
7474

75+
7576
# -----------------------------------------------------------------------------
7677
# Data Loading Adapter
7778
# -----------------------------------------------------------------------------
7879

80+
7981
class OfflineArrayRecordIterator:
8082
"""Reads the pre-generated global top-k logits file."""
83+
8184
def __init__(self, data_dir: str, epochs: int = 100):
8285
# Check if the user passed a directory or a direct file path
8386
if tf.io.gfile.isdir(data_dir):
@@ -100,18 +103,18 @@ def __iter__(self):
100103
def __next__(self):
101104
if self.record_index < self.num_records:
102105
pass
103-
106+
104107
self.current_epoch += 1
105108
if self.current_epoch >= self.epochs:
106-
raise StopIteration
107-
109+
raise StopIteration
110+
108111
self.record_index = 0
109112
self.reader = array_record_module.ArrayRecordReader(self.filepath)
110113

111114
record = self.reader.read()
112115
self.record_index += 1
113116
data = pickle.loads(record)
114-
117+
115118
# Map the arrays to match MaxText's expected dictionary
116119
batch = {
117120
"inputs": data["tokens"],
@@ -121,9 +124,10 @@ def __next__(self):
121124
for key in ["inputs_position", "inputs_segmentation", "targets_segmentation", "targets"]:
122125
if key in data:
123126
batch[key] = data[key]
124-
127+
125128
return batch
126129

130+
127131
class MaxTextToTunixIterator:
128132
"""Adapts the raw dictionary output of MaxText's data loader to Tunix objects.
129133

src/maxtext/trainers/post_train/distillation/train_distill.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ def _prepare_inputs(
307307
targets_position=input_data.targets_position,
308308
targets_segmentation=input_data.targets_segmentation,
309309
top_k_logits=input_data.top_k_logits,
310-
top_k_indices=input_data.top_k_indices
310+
top_k_indices=input_data.top_k_indices,
311311
)
312312

313313
def _post_process_train_step(self, aux: dict[str, jax.Array]) -> None:
@@ -406,7 +406,12 @@ def get_maxtext_model(config: pyconfig.HyperParameters, mesh: jax.sharding.Mesh)
406406
# -----------------------------------------------------------------------------
407407

408408

409-
def train_distill(student_config: pyconfig.HyperParameters, teacher_config: pyconfig.HyperParameters, is_offline: bool = False, offline_data_dir: str | None = None) -> None:
409+
def train_distill(
410+
student_config: pyconfig.HyperParameters,
411+
teacher_config: pyconfig.HyperParameters,
412+
is_offline: bool = False,
413+
offline_data_dir: str | None = None,
414+
) -> None:
410415
"""Main distillation training loop.
411416
412417
Orchestrates the loading of both student and teacher models, configures the
@@ -550,29 +555,23 @@ def custom_gen_model_input_fn(batch):
550555
"targets_segmentation": batch.targets_segmentation,
551556
"cache": None,
552557
}
553-
558+
554559
# If we are in online mode then we exit
555560
if getattr(batch, "top_k_logits", None) is None:
556561
return inputs_dict
557562

558563
# Scatter the offline arrays into a dense tensor of -10000s
559564
dense_shape = batch.input_tokens.shape + (student_config.vocab_size,)
560565
dense_logits = jnp.full(dense_shape, -10000.0, dtype=jnp.float32)
561-
dense_logits = jnp.put_along_axis(
562-
dense_logits,
563-
batch.top_k_indices,
564-
batch.top_k_logits,
565-
axis=-1,
566-
inplace=False
567-
)
568-
566+
dense_logits = jnp.put_along_axis(dense_logits, batch.top_k_indices, batch.top_k_logits, axis=-1, inplace=False)
567+
569568
# Inject it as teacher_output so the trainer skips the teacher forward pass
570569
inputs_dict["teacher_output"] = distillation_utils.DistillationForwardOutput(
571570
logits=dense_logits, out_projection_activations=None
572571
)
573-
572+
574573
return inputs_dict
575-
574+
576575
trainer = trainer.with_gen_model_input_fn(custom_gen_model_input_fn)
577576

578577
# 9. Create Iterator Wrappers (Use Utils)
@@ -635,7 +634,7 @@ def main(argv: Sequence[str], local_args) -> None:
635634
student_config = pyconfig.initialize(argv, **student_overrides)
636635

637636
# 3. Initialize TEACHER Config
638-
# We isolate the Teacher from Student CLI arguments (like pruning params).
637+
# We isolate the Teacher from Student CLI arguments (like pruning params).
639638
teacher_overrides = global_config.teacher_overrides
640639

641640
# Ensure load_parameters_path is set in overrides
@@ -668,7 +667,7 @@ def main(argv: Sequence[str], local_args) -> None:
668667
default=None,
669668
help="GCS or local path to the pre-generated ArrayRecord teacher data.",
670669
)
671-
670+
672671
# parse_known_args separates our custom flags from MaxText's standard args
673672
local_arg, remaining_args = parser.parse_known_args()
674673

0 commit comments

Comments
 (0)