2121
2222if TYPE_CHECKING :
2323 import datasets
24+ import tensorflow as tf
2425
2526import grain .python as grain
2627import numpy as np
27- import tensorflow as tf
28+ from maxtext . input_pipeline . protos import example_pb2
2829from maxtext .input_pipeline import tokenizer
2930from maxtext .multimodal import processor as mm_processor
3031from maxtext .multimodal import utils as mm_utils
3132from maxtext .utils import max_logging
3233
# Feature dictionary keyed by feature name. Values are typed as `Any` so that
# this alias does not require TensorFlow to be imported at module load time.
Features = dict[str, Any]
# Key under which token ids are stored in a feature dictionary.
INPUT_TOKENS_KEY = "input_ids"
3636
3737########## Functions used by TFDS pipeline
@@ -58,6 +58,8 @@ def shift_data_by_truncation(x):
5858
5959
6060def add_segmentation_and_position (x , data_columns , padding_token = 0 ):
61+ import tensorflow as tf # pylint: disable=import-outside-toplevel
62+
6163 for data_column in data_columns :
6264 x [f"{ data_column } _segmentation" ] = tf .cast (x [data_column ] != padding_token , tf .int32 )
6365 x [f"{ data_column } _position" ] = tf .broadcast_to (
@@ -68,6 +70,7 @@ def add_segmentation_and_position(x, data_columns, padding_token=0):
6870
6971def TokenizeOp (tokenizer_model , features : Features , data_keys : Iterable [str ] = ("inputs" , "targets" )) -> Features :
7072 """Op for tokenization"""
73+ import tensorflow as tf # pylint: disable=import-outside-toplevel
7174
7275 def _process_string (string_tensor ):
7376 # Extract string value and decode it if necessary
@@ -421,20 +424,23 @@ class ParseFeatures(grain.MapTransform):
421424
def __init__(self, data_columns, tokenize):
    """Store the columns to parse and whether they hold untokenized data.

    Args:
        data_columns: Names of the features to read from each example.
        tokenize: Truthy when the columns hold raw bytes (text still to be
            tokenized); falsy when they hold integer values.
    """
    self.data_columns = data_columns
    # Keep the flag itself instead of a TensorFlow dtype so that building
    # this transform does not need TensorFlow at runtime.
    self.tokenize = tokenize
428428
def map(self, element):
    """Parse a serialized tf.train.Example proto and extract features.

    Args:
        element: Serialized Example message, handed to `ParseFromString`.

    Returns:
        A dict keyed by column name. Each value is a NumPy array: an object
        array of the feature's bytes values when `self.tokenize` is truthy,
        otherwise an int32 array of the feature's int64 values. Columns
        that are absent from the example are left out of the result.
    """
    example = example_pb2.Example()
    example.ParseFromString(element)
    features = example.features.feature

    parsed = {}
    for col in self.data_columns:
        # Explicit membership test: columns missing from this record are
        # skipped rather than looked up.
        if col not in features:
            continue
        feature = features[col]
        if self.tokenize:
            parsed[col] = np.array(feature.bytes_list.value, dtype=object)
        else:
            # NOTE(review): int64 values are narrowed to int32 here --
            # confirm the stored values always fit.
            parsed[col] = np.array(feature.int64_list.value, dtype=np.int32)
    return parsed
438444
439445
440446@dataclasses .dataclass
@@ -447,9 +453,9 @@ def __init__(self, column_names, tokenize):
447453
def map(self, element):
    """Select the configured columns from an already-parsed element.

    Args:
        element: Mapping from column name to a sequence of values.

    Returns:
        A dict with one entry per name in `self.column_names`. When
        `self.tokenize` is truthy, each value is the first item of the
        column decoded from bytes to `str`; otherwise the column value is
        passed through unchanged.
    """
    if self.tokenize:
        # Values are indexed directly, so no `.numpy()` conversion is needed.
        return {col: element[col][0].decode() for col in self.column_names}
    return {col: element[col] for col in self.column_names}
453459
454460
455461@dataclasses .dataclass
0 commit comments