We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 1da3909 commit f61a2e5Copy full SHA for f61a2e5
1 file changed
src/maxtext/trainers/post_train/distillation/distillation_utils.py
@@ -101,15 +101,13 @@ def __iter__(self):
101
return self
102
103
def __next__(self):
104
- if self.record_index < self.num_records:
105
- pass
+ if self.record_index >= self.num_records:
+ self.current_epoch += 1
106
+ if self.current_epoch >= self.epochs:
107
+ raise StopIteration
108
- self.current_epoch += 1
- if self.current_epoch >= self.epochs:
109
- raise StopIteration
110
-
111
- self.record_index = 0
112
- self.reader = array_record_module.ArrayRecordReader(self.filepath)
+ self.record_index = 0
+ self.reader = array_record_module.ArrayRecordReader(self.filepath)
113
114
record = self.reader.read()
115
self.record_index += 1
0 commit comments