Skip to content

Commit fa559ee

Browse files
committed
migrate arrayrecord parsing off TF
1 parent 05c5083 commit fa559ee

8 files changed

Lines changed: 587 additions & 18 deletions

File tree

.pre-commit-config.yaml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ repos:
88
- id: codespell
99
args:
1010
- '-w'
11-
- '--skip="*.txt,pylintrc,.*,src/maxtext/assets/*"'
11+
- '--skip="*.txt,pylintrc,.*,src/maxtext/assets/*,src/maxtext/input_pipeline/protos/*"'
1212
- '-L ND,nd,sems,TE,ROUGE,rouge,astroid,ags,dout'
1313
- '.'
1414
additional_dependencies:
@@ -30,6 +30,7 @@ repos:
3030
args:
3131
- '--disable=R0401,R0917,W0201,W0613'
3232
- "--ignore-patterns='.pytype,.*pyi$'"
33+
- '--ignore-paths=src/maxtext/input_pipeline/protos'
3334
- 'benchmarks'
3435
- 'src'
3536
- 'tests'
@@ -47,6 +48,7 @@ repos:
4748
rev: 24.10.1
4849
hooks:
4950
- id: pyink
51+
exclude: src/maxtext/input_pipeline/protos/
5052
args:
5153
- '--pyink-indentation=2'
5254
- '--line-length=122'

src/maxtext/input_pipeline/input_pipeline_utils.py

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,17 @@
2121

2222
if TYPE_CHECKING:
2323
import datasets
24+
import tensorflow as tf
2425

2526
import grain.python as grain
2627
import numpy as np
27-
import tensorflow as tf
28+
from maxtext.input_pipeline.protos import example_pb2
2829
from maxtext.input_pipeline import tokenizer
2930
from maxtext.multimodal import processor as mm_processor
3031
from maxtext.multimodal import utils as mm_utils
3132
from maxtext.utils import max_logging
3233

33-
Features = dict[str, tf.Tensor]
34-
AUTOTUNE = tf.data.experimental.AUTOTUNE
34+
Features = dict[str, Any]
3535
INPUT_TOKENS_KEY = "input_ids"
3636

3737
########## Functions used by TFDS pipeline
@@ -58,6 +58,8 @@ def shift_data_by_truncation(x):
5858

5959

6060
def add_segmentation_and_position(x, data_columns, padding_token=0):
61+
import tensorflow as tf # pylint: disable=import-outside-toplevel
62+
6163
for data_column in data_columns:
6264
x[f"{data_column}_segmentation"] = tf.cast(x[data_column] != padding_token, tf.int32)
6365
x[f"{data_column}_position"] = tf.broadcast_to(
@@ -68,6 +70,7 @@ def add_segmentation_and_position(x, data_columns, padding_token=0):
6870

6971
def TokenizeOp(tokenizer_model, features: Features, data_keys: Iterable[str] = ("inputs", "targets")) -> Features:
7072
"""Op for tokenization"""
73+
import tensorflow as tf # pylint: disable=import-outside-toplevel
7174

7275
def _process_string(string_tensor):
7376
# Extract string value and decode it if necessary
@@ -421,20 +424,23 @@ class ParseFeatures(grain.MapTransform):
421424

422425
def __init__(self, data_columns, tokenize):
423426
self.data_columns = data_columns
424-
if tokenize:
425-
self.dtype = tf.string
426-
else:
427-
self.dtype = tf.int64
427+
self.tokenize = tokenize
428428

429429
def map(self, element):
430-
def _parse(example):
431-
parsed = tf.io.parse_example(
432-
example,
433-
{col: tf.io.FixedLenSequenceFeature([], dtype=self.dtype, allow_missing=True) for col in self.data_columns},
434-
)
435-
return parsed
436-
437-
return _parse(element)
430+
"""Parse a serialized tf.train.Example proto and extract features."""
431+
example = example_pb2.Example()
432+
example.ParseFromString(element)
433+
features = example.features.feature
434+
435+
parsed = {}
436+
for col in self.data_columns:
437+
if col in features:
438+
f = features[col]
439+
if self.tokenize:
440+
parsed[col] = np.array(f.bytes_list.value, dtype=object)
441+
else:
442+
parsed[col] = np.array(f.int64_list.value, dtype=np.int32)
443+
return parsed
438444

439445

440446
@dataclasses.dataclass
@@ -447,9 +453,9 @@ def __init__(self, column_names, tokenize):
447453

448454
def map(self, element):
449455
if self.tokenize:
450-
return {col: element[col].numpy()[0].decode() for col in self.column_names}
456+
return {col: element[col][0].decode() for col in self.column_names}
451457
else:
452-
return {col: element[col].numpy() for col in self.column_names}
458+
return {col: element[col] for col in self.column_names}
453459

454460

455461
@dataclasses.dataclass
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright 2023–2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.

0 commit comments

Comments (0)