-
Notifications
You must be signed in to change notification settings - Fork 70
Expand file tree
/
Copy path_tfds_data_processing.py
More file actions
172 lines (142 loc) · 6.79 KB
/
_tfds_data_processing.py
File metadata and controls
172 lines (142 loc) · 6.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
"""
Copyright 2024 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import tensorflow as tf
import tensorflow.experimental.numpy as tnp
from datasets import load_dataset, load_from_disk
from maxdiffusion import multihost_dataloading
# Sentinel that lets tf.data tune parallelism / buffer sizes dynamically;
# shared by every tf.data pipeline built in this module.
AUTOTUNE = tf.data.AUTOTUNE
def load_as_tf_dataset(dataset, global_batch_size, shuffle, dataloading_host_count):
  """Convert a Hugging Face dataset into an infinite, batched tf.data pipeline.

  Args:
    dataset: A Hugging Face dataset whose columns have a fixed per-row shape;
      every column is materialized into a single tensor and then sliced with
      ``tf.data.Dataset.from_tensor_slices``.
    global_batch_size: Batch size summed over all data-loading hosts.
    shuffle: If True, shuffle with a buffer spanning the whole dataset.
    dataloading_host_count: Number of data-loading hosts; each emitted batch
      holds ``global_batch_size // dataloading_host_count`` examples.

  Returns:
    A ``tf.data.Dataset`` of per-host batches (incomplete trailing batches are
    dropped) that repeats forever.
  """
  # `[:]` materializes every row as a dict of column name -> tensor.
  dataset = dataset.with_format("tensorflow")[:]
  tf_dataset = tf.data.Dataset.from_tensor_slices(dataset)
  if shuffle:
    tf_dataset = tf_dataset.shuffle(len(tf_dataset))
  tf_dataset = tf_dataset.batch(global_batch_size // dataloading_host_count, drop_remainder=True)
  tf_dataset = tf_dataset.repeat(-1)
  # Prefetch last: placed before `repeat`, the prefetch iterator would be torn
  # down and refilled at every epoch boundary instead of running continuously.
  tf_dataset = tf_dataset.prefetch(AUTOTUNE)
  return tf_dataset
def make_tf_iterator(
    config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, tokenize_fn, image_transforms_fn
):
  """Build a multi-host iterator over a Hugging Face image/caption dataset.

  The preprocessed dataset is loaded from disk in one case: when
  ``config.cache_latents_text_encoder_outputs`` is set and
  ``config.dataset_save_location`` is an existing directory.

  Otherwise the pipeline loads ``config.dataset_name`` (split
  ``config.train_split``), keeps only the caption and image columns, tokenizes
  captions with ``tokenize_fn`` and transforms images with
  ``image_transforms_fn``. When caching is enabled, the result is saved to
  ``config.dataset_save_location``.

  Args:
    config: Configuration object providing the dataset / caching fields read below.
    dataloading_host_index: Index of this host among the data-loading hosts.
    dataloading_host_count: Total number of data-loading hosts.
    mesh: Device mesh handed to ``MultiHostDataLoadIterator``.
    global_batch_size: Batch size summed over all data-loading hosts.
    tokenize_fn: Batched ``Dataset.map`` function applied to the caption column.
    image_transforms_fn: Batched ``Dataset.map`` function applied to the image column.

  Returns:
    A ``multihost_dataloading.MultiHostDataLoadIterator`` over this host's
    shuffled, batched, infinitely repeating slice of the data.
  """
  if config.cache_latents_text_encoder_outputs and os.path.isdir(config.dataset_save_location):
    train_ds = load_from_disk(config.dataset_save_location)
  else:
    train_ds = load_dataset(config.dataset_name, split=config.train_split)
    train_ds = train_ds.select_columns([config.caption_column, config.image_column])
    train_ds = train_ds.map(
        function=tokenize_fn,
        batched=True,
        remove_columns=[config.caption_column],
        # Caching runs the text encoder inside the map, so keep it single-process.
        num_proc=1 if config.cache_latents_text_encoder_outputs else config.tokenize_captions_num_proc,
        desc="Running tokenizer on train dataset",
    )
    # need to do it before load_as_tf_dataset
    # since raw images are different sizes
    # will break from_tensor_slices
    train_ds = train_ds.map(
        function=image_transforms_fn,
        batched=True,
        remove_columns=[config.image_column],
        num_proc=1 if config.cache_latents_text_encoder_outputs else config.transform_images_num_proc,
        desc="Transforming images",
    )
    if config.cache_latents_text_encoder_outputs:
      # Save the full (unsharded) dataset so any host can reload it later.
      train_ds.save_to_disk(config.dataset_save_location)
      train_ds.cleanup_cache_files()
  # Shard rows *before* shuffling/batching. load_as_tf_dataset shuffles without
  # a seed, so sharding the tf.data pipeline afterwards only partitions the data
  # if every host draws the same shuffle order. Sharding here gives each host a
  # disjoint slice regardless, and avoids materializing the whole dataset as
  # tensors on every host.
  train_ds = train_ds.shard(num_shards=dataloading_host_count, index=dataloading_host_index, contiguous=True)
  train_ds = load_as_tf_dataset(train_ds, global_batch_size, True, dataloading_host_count)
  train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh)
  return train_iter
def make_cached_tfrecord_iterator(
    config,
    dataloading_host_index,
    dataloading_host_count,
    mesh,
    global_batch_size,
):
  """Iterator for TFRecords that hold fully pre-computed training inputs.

  Each record stores four serialized tensors:
    - ``pixel_values`` (float32): the pre-computed latents.
    - ``input_ids`` (int32).
    - ``prompt_embeds`` (float32).
    - ``text_embeds`` (float32).

  Args:
    config: Configuration object; ``config.train_data_dir`` holds the TFRecord files.
    dataloading_host_index: Index of this host among the data-loading hosts.
    dataloading_host_count: Total number of data-loading hosts.
    mesh: Device mesh handed to ``MultiHostDataLoadIterator``.
    global_batch_size: Batch size summed over all data-loading hosts.

  Returns:
    A ``multihost_dataloading.MultiHostDataLoadIterator`` yielding per-host
    batches of ``global_batch_size // dataloading_host_count`` examples forever.

  Raises:
    FileNotFoundError: If ``config.train_data_dir`` contains no files.
  """
  feature_description = {
      "pixel_values": tf.io.FixedLenFeature([], tf.string),
      "input_ids": tf.io.FixedLenFeature([], tf.string),
      "prompt_embeds": tf.io.FixedLenFeature([], tf.string),
      "text_embeds": tf.io.FixedLenFeature([], tf.string),
  }

  def _parse_tfrecord_fn(example):
    return tf.io.parse_single_example(example, feature_description)

  def prepare_sample(features):
    # Each feature is a serialized tensor; decode it back to its dtype.
    pixel_values = tf.io.parse_tensor(features["pixel_values"], out_type=tf.float32)
    input_ids = tf.io.parse_tensor(features["input_ids"], out_type=tf.int32)
    prompt_embeds = tf.io.parse_tensor(features["prompt_embeds"], out_type=tf.float32)
    text_embeds = tf.io.parse_tensor(features["text_embeds"], out_type=tf.float32)
    return {"pixel_values": pixel_values, "input_ids": input_ids, "prompt_embeds": prompt_embeds, "text_embeds": text_embeds}

  # Sort so every host builds the same file order: the record-level `shard`
  # below only hands out disjoint slices when all hosts read the same sequence.
  filenames = sorted(tf.io.gfile.glob(os.path.join(config.train_data_dir, "*")))
  if not filenames:
    raise FileNotFoundError(f"No TFRecord files found in {config.train_data_dir!r}")
  # This pipeline reads the sharded files and applies the parsing and preparation.
  train_ds = (
      tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
      .shard(num_shards=dataloading_host_count, index=dataloading_host_index)
      .map(_parse_tfrecord_fn, num_parallel_calls=AUTOTUNE)
      .map(prepare_sample, num_parallel_calls=AUTOTUNE)
      .shuffle(global_batch_size * 10)
      .batch(global_batch_size // dataloading_host_count, drop_remainder=True)
      .repeat(-1)
      .prefetch(AUTOTUNE)
  )
  # This wraps the tf.data.Dataset for use in the multi-host JAX environment.
  train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh)
  return train_iter
# TODO - https://github.com/google/array_record/blob/main/beam/examples/example_gcs_conversion.py
def make_tfrecord_iterator(
    config,
    dataloading_host_index,
    dataloading_host_count,
    mesh,
    global_batch_size,
):
  """Iterator for TFRecord format.

  For the LAION dataset, see the preparation script
  maxdiffusion/pedagogical_examples/to_tfrecords.py.

  Each record holds two serialized float32 tensors: ``moments``, yielded as
  ``pixel_values``, and ``clip_embeddings``, yielded as ``input_ids``.

  This delegates to ``make_cached_tfrecord_iterator`` instead when all of the
  following hold:
    - ``config.cache_latents_text_encoder_outputs`` is set;
    - ``config.dataset_save_location`` is an existing directory;
    - ``config.load_tfrecord_cached`` exists and is truthy.

  Args:
    config: Configuration object; ``config.train_data_dir`` holds the TFRecord files.
    dataloading_host_index: Index of this host among the data-loading hosts.
    dataloading_host_count: Total number of data-loading hosts.
    mesh: Device mesh handed to ``MultiHostDataLoadIterator``.
    global_batch_size: Batch size summed over all data-loading hosts.

  Returns:
    A ``multihost_dataloading.MultiHostDataLoadIterator`` yielding per-host
    batches of ``global_batch_size // dataloading_host_count`` examples forever.

  Raises:
    FileNotFoundError: On the non-cached path, if ``config.train_data_dir``
      contains no files.
  """
  # Set load_tfrecord_cached to True in the config to use a pre-processed TFRecord dataset.
  # Use pedagogical_examples/dataset_tf_cache_to_tfrecord.py to convert a tf-preprocessed
  # dataset to TFRecord.
  # The dataset cache in the GitHub runner tests is shared and doesn't contain all the
  # features, so those tests use the default TFRecord iterator below.
  # NOTE(review): the guard checks dataset_save_location, but the cached iterator
  # reads from train_data_dir — confirm this is intended.
  if (
      config.cache_latents_text_encoder_outputs
      and os.path.isdir(config.dataset_save_location)
      and "load_tfrecord_cached" in config.get_keys()
      and config.load_tfrecord_cached
  ):
    return make_cached_tfrecord_iterator(config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size)

  feature_description = {
      "moments": tf.io.FixedLenFeature([], tf.string),
      "clip_embeddings": tf.io.FixedLenFeature([], tf.string),
  }

  def _parse_tfrecord_fn(example):
    return tf.io.parse_single_example(example, feature_description)

  def prepare_sample(features):
    # parse_tensor takes the serialized tf.string tensor directly; no numpy wrapping needed.
    moments = tf.io.parse_tensor(features["moments"], out_type=tf.float32)
    clip_embeddings = tf.io.parse_tensor(features["clip_embeddings"], out_type=tf.float32)
    return {"pixel_values": moments, "input_ids": clip_embeddings}

  # Sort so every host builds the same file order: the record-level `shard`
  # below only hands out disjoint slices when all hosts read the same sequence.
  filenames = sorted(tf.io.gfile.glob(os.path.join(config.train_data_dir, "*")))
  if not filenames:
    raise FileNotFoundError(f"No TFRecord files found in {config.train_data_dir!r}")
  train_ds = (
      tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
      .shard(num_shards=dataloading_host_count, index=dataloading_host_index)
      .map(_parse_tfrecord_fn, num_parallel_calls=AUTOTUNE)
      .map(prepare_sample, num_parallel_calls=AUTOTUNE)
      .shuffle(global_batch_size * 10)
      .batch(global_batch_size // dataloading_host_count, drop_remainder=True)
      .repeat(-1)
      .prefetch(AUTOTUNE)
  )
  train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh)
  return train_iter