1717import dataclasses
1818import warnings
1919from threading import current_thread
20- from typing import Any
21- import datasets
22- from datasets .distributed import split_dataset_by_node
20+ from typing import Any , TYPE_CHECKING
21+
22+ if TYPE_CHECKING :
23+ import datasets
24+
2325import grain .python as grain
2426import numpy as np
2527import tensorflow as tf
@@ -145,6 +147,8 @@ def is_conversational(features, data_columns):
145147 data_columns = ["prompt", "completion"]
146148 is_conversational(features, data_columns) returns False.
147149 """
150+ import datasets # pylint: disable=import-outside-toplevel
151+
148152 for column in data_columns :
149153 messages = features [column ]
150154 if isinstance (messages , datasets .Sequence ):
@@ -293,13 +297,16 @@ class HFDataSource(grain.RandomAccessDataSource):
293297
294298 def __init__ (
295299 self ,
296- dataset : datasets .IterableDataset ,
300+ dataset : " datasets.IterableDataset" ,
297301 dataloading_host_index : int ,
298302 dataloading_host_count : int ,
299303 num_threads : int ,
300304 max_target_length : int ,
301305 data_column_names : list [str ],
302306 ):
307+ from datasets .distributed import split_dataset_by_node # pylint: disable=import-outside-toplevel
308+
309+ self ._split_dataset_by_node = split_dataset_by_node
303310 self .dataset = dataset
304311 self .num_threads = num_threads
305312 self .dataloading_host_count = dataloading_host_count
@@ -312,7 +319,7 @@ def __init__(
312319 self .n_shards = 1
313320 self ._check_shard_count ()
314321 self .dataset_shards = [dataloading_host_index * self .num_threads + i for i in range (self .num_threads )]
315- self .datasets = [split_dataset_by_node (dataset , world_size = self .n_shards , rank = x ) for x in self .dataset_shards ]
322+ self .datasets = [self . _split_dataset_by_node (dataset , world_size = self .n_shards , rank = x ) for x in self .dataset_shards ]
316323 self .data_iters = []
317324
318325 def _check_shard_count (self ):
@@ -333,7 +340,9 @@ def _update_shard(self, idx):
333340 )
334341 max_logging .log (f"New shard is { new_shard } " )
335342 self .dataset_shards [idx ] = new_shard
336- self .datasets [idx ] = split_dataset_by_node (self .dataset , world_size = self .n_shards , rank = self .dataset_shards [idx ])
343+ self .datasets [idx ] = self ._split_dataset_by_node (
344+ self .dataset , world_size = self .n_shards , rank = self .dataset_shards [idx ]
345+ )
337346 self .data_iters [idx ] = iter (self .datasets [idx ])
338347 else :
339348 raise StopIteration (f"Run out of shards on host { self .dataloading_host_index } , shard { new_shard } is not available" )
0 commit comments