Skip to content

Commit f61a2e5

Browse files
committed
updated iterator to ensure weight updates when training student model
1 parent 1da3909 commit f61a2e5

1 file changed

Lines changed: 6 additions & 8 deletions

File tree

src/maxtext/trainers/post_train/distillation/distillation_utils.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -101,15 +101,13 @@ def __iter__(self):
101101
return self
102102

103103
def __next__(self):
104-
if self.record_index < self.num_records:
105-
pass
104+
if self.record_index >= self.num_records:
105+
self.current_epoch += 1
106+
if self.current_epoch >= self.epochs:
107+
raise StopIteration
106108

107-
self.current_epoch += 1
108-
if self.current_epoch >= self.epochs:
109-
raise StopIteration
110-
111-
self.record_index = 0
112-
self.reader = array_record_module.ArrayRecordReader(self.filepath)
109+
self.record_index = 0
110+
self.reader = array_record_module.ArrayRecordReader(self.filepath)
113111

114112
record = self.reader.read()
115113
self.record_index += 1

0 commit comments

Comments
 (0)