Skip to content

Commit a80ab23

Browse files
committed
Use torch CPU wheels, write TensorBoard metrics asynchronously, add a script to convert cached latents to TFRecords, add a batched iterator for cached TFRecords, and use NamedSharding instead of PositionalSharding
Signed-off-by: Kunjan <kunjanp@google.com>
1 parent 3d5ef04 commit a80ab23

6 files changed

Lines changed: 218 additions & 30 deletions

File tree

requirements.txt

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
1-
jax>=0.4.30
1+
--extra-index-url https://download.pytorch.org/whl/cpu
2+
jax==0.5.3
23
jaxlib>=0.4.30
34
grain-nightly==0.0.10
45
google-cloud-storage==2.17.0
56
absl-py
67
datasets
78
flax>=0.10.2
89
optax>=0.2.3
9-
torch==2.5.1
10-
torchvision==0.20.1
10+
torch==2.6.0
11+
torchvision>=0.20.1
1112
ftfy
1213
tensorboard>=2.17.0
1314
tensorboardx==2.6.2.2

src/maxdiffusion/input_pipeline/_tfds_data_processing.py

Lines changed: 59 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
from maxdiffusion import multihost_dataloading
2323

24-
AUTOTUNE = tf.data.experimental.AUTOTUNE
24+
AUTOTUNE = tf.data.AUTOTUNE
2525

2626

2727
def load_as_tf_dataset(dataset, global_batch_size, shuffle, dataloading_host_count):
@@ -31,7 +31,7 @@ def load_as_tf_dataset(dataset, global_batch_size, shuffle, dataloading_host_cou
3131
if shuffle:
3232
tf_dataset = tf_dataset.shuffle(len(tf_dataset))
3333
tf_dataset = tf_dataset.batch(global_batch_size // dataloading_host_count, drop_remainder=True)
34-
tf_dataset = tf_dataset.prefetch(tf.data.experimental.AUTOTUNE)
34+
tf_dataset = tf_dataset.prefetch(AUTOTUNE)
3535
tf_dataset = tf_dataset.repeat(-1)
3636

3737
return tf_dataset
@@ -74,6 +74,57 @@ def make_tf_iterator(
7474
return train_iter
7575

7676

77+
def make_cached_tfrecord_iterator(
    config,
    dataloading_host_index,
    dataloading_host_count,
    mesh,
    global_batch_size,
):
  """Creates a per-host iterator over TFRecords of fully pre-computed features.

  Each record is expected to hold the 4 pre-computed latents and embeddings as
  serialized tensors: latents ("pixel_values"), input_ids, prompt_embeds, and
  text_embeds.

  Args:
    config: config object; `config.train_data_dir` holds the TFRecord shards.
    dataloading_host_index: index of this host among the data-loading hosts.
    dataloading_host_count: total number of data-loading hosts.
    mesh: JAX device mesh passed to the multi-host dataloader.
    global_batch_size: batch size across all hosts; each host batches
      `global_batch_size // dataloading_host_count` examples.

  Returns:
    A MultiHostDataLoadIterator yielding dicts with the four feature tensors.
  """
  feature_description = {
      "pixel_values": tf.io.FixedLenFeature([], tf.string),
      "input_ids": tf.io.FixedLenFeature([], tf.string),
      "prompt_embeds": tf.io.FixedLenFeature([], tf.string),
      "text_embeds": tf.io.FixedLenFeature([], tf.string),
  }

  def _parse_tfrecord_fn(example):
    return tf.io.parse_single_example(example, feature_description)

  def prepare_sample(features):
    # Each feature is a tensor serialized with tf.io.serialize_tensor;
    # deserialize back to its original dtype.
    pixel_values = tf.io.parse_tensor(features["pixel_values"], out_type=tf.float32)
    input_ids = tf.io.parse_tensor(features["input_ids"], out_type=tf.int32)
    prompt_embeds = tf.io.parse_tensor(features["prompt_embeds"], out_type=tf.float32)
    text_embeds = tf.io.parse_tensor(features["text_embeds"], out_type=tf.float32)

    return {
        "pixel_values": pixel_values,
        "input_ids": input_ids,
        "prompt_embeds": prompt_embeds,
        "text_embeds": text_embeds,
    }

  # Sort the globbed shard list: glob order is not guaranteed to be stable, and
  # `.shard()` only partitions the data correctly when every host agrees on the
  # same file ordering.
  filenames = sorted(tf.io.gfile.glob(os.path.join(config.train_data_dir, "*")))
  # This pipeline reads the sharded files and applies the parsing and preparation.
  train_ds = (
      tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
      .shard(num_shards=dataloading_host_count, index=dataloading_host_index)
      .map(_parse_tfrecord_fn, num_parallel_calls=AUTOTUNE)
      .map(prepare_sample, num_parallel_calls=AUTOTUNE)
      .shuffle(global_batch_size * 10)
      .batch(global_batch_size // dataloading_host_count, drop_remainder=True)
      .repeat(-1)
      .prefetch(AUTOTUNE)
  )

  # This wraps the tf.data.Dataset for use in the multi-host JAX environment.
  train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh)
  return train_iter
127+
77128
# TODO - https://github.com/google/array_record/blob/main/beam/examples/example_gcs_conversion.py
78129
def make_tfrecord_iterator(
79130
config,
@@ -86,19 +137,23 @@ def make_tfrecord_iterator(
86137
check out preparation script
87138
maxdiffusion/pedagogical_examples/to_tfrecords.py
88139
"""
140+
if config.cache_latents_text_encoder_outputs and os.path.isdir(config.dataset_save_location):
141+
return make_cached_tfrecord_iterator(config, dataloading_host_index,
142+
dataloading_host_count, mesh,
143+
global_batch_size)
89144
feature_description = {
90145
"moments": tf.io.FixedLenFeature([], tf.string),
91146
"clip_embeddings": tf.io.FixedLenFeature([], tf.string),
92147
}
93-
148+
94149
def _parse_tfrecord_fn(example):
95150
return tf.io.parse_single_example(example, feature_description)
96151

97152
def prepare_sample(features):
98153
moments = tf.io.parse_tensor(tnp.asarray(features["moments"]), out_type=tf.float32)
99154
clip_embeddings = tf.io.parse_tensor(tnp.asarray(features["clip_embeddings"]), out_type=tf.float32)
100155
return {"pixel_values": moments, "input_ids": clip_embeddings}
101-
156+
102157
filenames = tf.io.gfile.glob(os.path.join(config.train_data_dir, "*"))
103158
train_ds = (
104159
tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
import os
2+
import argparse
3+
import tensorflow as tf
4+
from datasets import load_from_disk
5+
import numpy as np
6+
7+
def _bytes_feature(value):
  """Wraps a serialized tensor (or raw value) in a tf.train bytes Feature."""
  tensor = value if isinstance(value, tf.Tensor) else tf.constant(value)
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[tensor.numpy()]))
12+
13+
def create_4_feature_example(record):
  """Creates a tf.train.Example proto with all 4 pre-computed features."""
  # Serialize each tensor and wrap it as a bytes feature, keyed by name.
  feature = {
      name: _bytes_feature(tf.io.serialize_tensor(record[name]))
      for name in ("pixel_values", "input_ids", "prompt_embeds", "text_embeds")
  }
  return tf.train.Example(features=tf.train.Features(feature=feature))
27+
28+
def run(args):
  """Converts a cached Hugging Face dataset into sharded TFRecord files.

  Loads the dataset from `args.dataset_save_location`, casts the float
  features to float32, and round-robins the records across
  `args.data_num_shards` TFRecord files under `args.tfrecords_dir`.
  """
  # Load the cached dataset from the location specified in the arguments
  print(f"Loading processed dataset from disk: {args.dataset_save_location}")
  processed_ds = load_from_disk(args.dataset_save_location)
  print("Dataset loaded successfully.")

  # Get sharding and output directory from the arguments
  tfrecords_dir = args.tfrecords_dir
  num_shards = args.data_num_shards
  os.makedirs(tfrecords_dir, exist_ok=True)

  writers = [
      tf.io.TFRecordWriter(os.path.join(tfrecords_dir, f"shard-{i:05d}-of-{num_shards:05d}.tfrecord"))
      for i in range(num_shards)
  ]

  print(f"Writing {len(processed_ds)} records into {num_shards} TFRecord shards...")

  try:
    for i, record in enumerate(processed_ds):
      # Create a new record with explicit casting for float types;
      # input_ids is already an integer type and is passed through.
      casted_record = {
          "pixel_values": np.float32(record['pixel_values']),
          "input_ids": record['input_ids'],
          "prompt_embeds": np.float32(record['prompt_embeds']),
          "text_embeds": np.float32(record['text_embeds'])
      }

      # Round-robin records across shards for roughly even shard sizes.
      writer_index = i % num_shards
      tf_example = create_4_feature_example(casted_record)
      writers[writer_index].write(tf_example.SerializeToString())
  finally:
    # Close every writer even if conversion fails partway, so shards written
    # so far are flushed and all file handles are released.
    for writer in writers:
      writer.close()

  print("TFRecord conversion complete.")
64+
65+
66+
def main():
  """Parses command-line arguments and runs the conversion."""
  parser = argparse.ArgumentParser(
      description="Convert a cached Hugging Face dataset to sharded TFRecords."
  )
  # (flag, type, default, help) for each supported option; none are required.
  options = [
      ("--dataset_save_location", str, "/tmp/pokemon-gpt4-captions_xl",
       "Path to the cached dataset created by the training pipeline."),
      ("--tfrecords_dir", str, "/tmp/cached_pokemon_tfrecords_sharded",
       "Output directory to save the sharded TFRecord files."),
      ("--data_num_shards", int, 128,
       "Number of shards to split the TFRecord dataset into."),
  ]
  for flag, flag_type, default, help_text in options:
    parser.add_argument(flag, type=flag_type, required=False, default=default, help=help_text)

  run(parser.parse_args())


if __name__ == "__main__":
  main()

src/maxdiffusion/train_utils.py

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
import numpy as np
1818
import jax
1919
import jax.numpy as jnp
20+
import threading
21+
import queue
2022

2123
from maxdiffusion import max_utils, max_logging
2224

@@ -67,10 +69,28 @@ def record_scalar_metrics(metrics, step_time_delta, per_device_tflops, lr):
6769
metrics["scalar"].update({"perf/per_device_tflops_per_sec": per_device_tflops / step_time_delta.total_seconds()})
6870
metrics["scalar"].update({"learning/current_learning_rate": lr})
6971

70-
72+
_metrics_queue = queue.Queue()
7173
_buffered_step = None
7274
_buffered_metrics = None
7375

76+
def _tensorboard_writer_worker(writer, config):
  """
  A worker function that runs in a separate thread.
  It waits for metrics to appear in the queue and writes them to TensorBoard.

  Consumes (metrics, step) tuples from the module-level `_metrics_queue`;
  a `None` item is the shutdown sentinel that ends the loop.
  """
  while True:
    data = _metrics_queue.get()
    # `None` signals the producer is done; exit the worker thread.
    if data is None:
      break
    metrics, step = data
    # Only the lead JAX process writes events, to avoid duplicates in
    # multi-host runs.
    if jax.process_index() == 0:
      for metric_name in metrics.get("scalar", []):
        writer.add_scalar(metric_name, np.array(metrics["scalar"][metric_name]), step)
      for metric_name in metrics.get("scalars", []):
        writer.add_scalars(metric_name, metrics["scalars"][metric_name], step)

    # Flush periodically so metrics show up without waiting for the writer's
    # internal buffer to fill. NOTE(review): this flush runs on every process,
    # not just process 0 — presumably harmless since other processes wrote
    # nothing, but confirm.
    if step % config.log_period == 0:
      writer.flush()
7494

7595
def write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step, config):
7696
"""Entry point for all metrics writing in Train's Main.
@@ -81,15 +101,18 @@ def write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step
81101
The logic is that this ensures that Jax is able to queues train_steps and we
82102
don't block when turning "lazy" Jax arrays into real Python numbers.
83103
"""
84-
global _buffered_step, _buffered_metrics
104+
global _buffered_step, _buffered_metrics, _metrics_queue
85105

106+
if metrics:
107+
_metrics_queue.put((metrics, step))
86108
if _buffered_metrics is not None:
109+
if config.metrics_file:
110+
max_utils.write_metrics_locally(_buffered_metrics, _buffered_step, config, local_metrics_file)
111+
87112
if _buffered_step is None:
88113
raise ValueError(f"When writing metrics, {_buffered_step=} was none")
89114
write_metrics_to_tensorboard(writer, _buffered_metrics, _buffered_step, config)
90115

91-
if config.metrics_file:
92-
max_utils.write_metrics_locally(_buffered_metrics, _buffered_step, config, local_metrics_file)
93116

94117
if config.gcs_metrics and jax.process_index() == 0:
95118
running_gcs_metrics = max_utils.write_metrics_for_gcs(_buffered_metrics, _buffered_step, config, running_gcs_metrics)
@@ -100,13 +123,6 @@ def write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step
100123

101124
def write_metrics_to_tensorboard(writer, metrics, step, config):
102125
"""Writes metrics to tensorboard"""
103-
if jax.process_index() == 0:
104-
for metric_name in metrics.get("scalar", []):
105-
writer.add_scalar(metric_name, np.array(metrics["scalar"][metric_name]), step)
106-
for metric_name in metrics.get("scalars", []):
107-
writer.add_scalars(metric_name, metrics["scalars"][metric_name], step)
108-
109-
full_log = step % config.log_period == 0
110126
if jax.process_index() == 0:
111127
max_logging.log(
112128
"completed step: {}, seconds: {:.3f}, TFLOP/s/device: {:.3f}, loss: {:.3f}".format(
@@ -116,6 +132,13 @@ def write_metrics_to_tensorboard(writer, metrics, step, config):
116132
float(metrics["scalar"]["learning/loss"]),
117133
)
118134
)
135+
if jax.process_index() == 0:
136+
for metric_name in metrics.get("scalar", []):
137+
writer.add_scalar(metric_name, np.array(metrics["scalar"][metric_name]), step)
138+
for metric_name in metrics.get("scalars", []):
139+
writer.add_scalars(metric_name, metrics["scalars"][metric_name], step)
140+
141+
full_log = step % config.log_period == 0
119142

120143
if full_log and jax.process_index() == 0:
121144
max_logging.log(f"To see full metrics 'tensorboard --logdir={config.tensorboard_dir}'")

src/maxdiffusion/trainers/flux_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def start_training(self):
8787
state_shardings = {}
8888

8989
# move params to accelerator
90-
encoders_sharding = PositionalSharding(self.devices_array).replicate()
90+
encoders_sharding = jax.NamedSharding(self.mesh, P(None))
9191
partial_device_put_replicated = partial(max_utils.device_put_replicated, sharding=encoders_sharding)
9292
pipeline.clip_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), pipeline.clip_encoder.params)
9393
pipeline.clip_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, pipeline.clip_encoder.params)

0 commit comments

Comments
 (0)