@@ -108,6 +108,7 @@ def __init__(
108108 self .current_shard = dataloading_host_index
109109 self .dataset_shard = split_dataset_by_node (dataset , world_size = self .n_shards , rank = self .current_shard )
110110 self .data_iter = None
111+ self .out_of_data = False
111112
112113 def _check_shard_count (self ):
113114 if self .n_shards < self .dataloading_host_count :
@@ -119,11 +120,15 @@ def _check_shard_count(self):
119120 self .n_shards = self .dataloading_host_count
120121
def _update_shard(self):
    """Advance this host to its next dataset shard, if one remains.

    Shards are striped across hosts: each host owns shard indices
    ``dataloading_host_index``, ``dataloading_host_index + dataloading_host_count``,
    and so on. When the next index in that stride would run past
    ``n_shards``, this host has consumed all of its data, so instead of
    wrapping around (which would re-read shards) we set ``out_of_data``.
    """
    candidate = self.current_shard + self.dataloading_host_count
    if candidate >= self.n_shards:
        # No further shard exists for this host; flag exhaustion so the
        # consumer (__getitem__) can stop pulling data.
        max_logging.log(f"Run out of shards on host {self.dataloading_host_index}, shard {candidate} is not available")
        self.out_of_data = True
        return
    max_logging.log(f"Updating host {self.dataloading_host_index} dataset from shard {self.current_shard} to {candidate}")
    self.current_shard = candidate
    # Re-slice the full dataset for the new rank and restart iteration.
    self.dataset_shard = split_dataset_by_node(self.dataset, world_size=self.n_shards, rank=self.current_shard)
    self.data_iter = iter(self.dataset_shard)
128133 def __len__ (self ):
129134 """Return length of the HF dataset. Since HuggingFace IterableDataset does not have length,
@@ -138,6 +143,8 @@ def __getitem__(self, index):
138143
139144 while True :
140145 try :
146+ if self .out_of_data :
147+ return None
141148 data = next (self .data_iter )
142149 return data
143150 except StopIteration :
0 commit comments