Commit b2956a7

Formatting
Signed-off-by: Kunjan <kunjanp@google.com>
1 parent d76d5e8 commit b2956a7

4 files changed: 75 additions & 84 deletions

src/maxdiffusion/input_pipeline/_tfds_data_processing.py

Lines changed: 6 additions & 12 deletions
@@ -100,13 +100,8 @@ def prepare_sample(features):
    input_ids = tf.io.parse_tensor(features["input_ids"], out_type=tf.int32)
    prompt_embeds = tf.io.parse_tensor(features["prompt_embeds"], out_type=tf.float32)
    text_embeds = tf.io.parse_tensor(features["text_embeds"], out_type=tf.float32)
-
-    return {
-        "pixel_values": pixel_values,
-        "input_ids": input_ids,
-        "prompt_embeds": prompt_embeds,
-        "text_embeds": text_embeds
-    }
+
+    return {"pixel_values": pixel_values, "input_ids": input_ids, "prompt_embeds": prompt_embeds, "text_embeds": text_embeds}

  # This pipeline reads the sharded files and applies the parsing and preparation.
  filenames = tf.io.gfile.glob(os.path.join(config.train_data_dir, "*"))

@@ -125,6 +120,7 @@ def prepare_sample(features):
  train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh)
  return train_iter

+
# TODO - https://github.com/google/array_record/blob/main/beam/examples/example_gcs_conversion.py
def make_tfrecord_iterator(
    config,

@@ -138,22 +134,20 @@ def make_tfrecord_iterator(
  maxdiffusion/pedagogical_examples/to_tfrecords.py
  """
  if config.cache_latents_text_encoder_outputs and os.path.isdir(config.dataset_save_location):
-    return make_cached_tfrecord_iterator(config, dataloading_host_index,
-                                         dataloading_host_count, mesh,
-                                         global_batch_size)
+    return make_cached_tfrecord_iterator(config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size)
  feature_description = {
      "moments": tf.io.FixedLenFeature([], tf.string),
      "clip_embeddings": tf.io.FixedLenFeature([], tf.string),
  }
-
+
  def _parse_tfrecord_fn(example):
    return tf.io.parse_single_example(example, feature_description)

  def prepare_sample(features):
    moments = tf.io.parse_tensor(tnp.asarray(features["moments"]), out_type=tf.float32)
    clip_embeddings = tf.io.parse_tensor(tnp.asarray(features["clip_embeddings"]), out_type=tf.float32)
    return {"pixel_values": moments, "input_ids": clip_embeddings}
-
+
  filenames = tf.io.gfile.glob(os.path.join(config.train_data_dir, "*"))
  train_ds = (
      tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
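
The reader pattern these functions implement, deserializing tensors that were stored as string features, can be shown in isolation. Below is a minimal sketch using only standard TensorFlow APIs; the directory path and batch size are placeholders, and the multihost iterator wrapping is omitted, so this is not the maxdiffusion API itself.

import os

import tensorflow as tf

AUTOTUNE = tf.data.AUTOTUNE

# Each record stores tensors serialized into string features (see the diff above).
feature_description = {
    "moments": tf.io.FixedLenFeature([], tf.string),
    "clip_embeddings": tf.io.FixedLenFeature([], tf.string),
}


def parse_and_prepare(example):
  features = tf.io.parse_single_example(example, feature_description)
  # out_type must match the dtype that was serialized on the writer side.
  moments = tf.io.parse_tensor(features["moments"], out_type=tf.float32)
  clip_embeddings = tf.io.parse_tensor(features["clip_embeddings"], out_type=tf.float32)
  return {"pixel_values": moments, "input_ids": clip_embeddings}


filenames = tf.io.gfile.glob(os.path.join("/tmp/cached_tfrecords", "*"))  # placeholder dir
train_ds = (
    tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
    .map(parse_and_prepare, num_parallel_calls=AUTOTUNE)
    .batch(8, drop_remainder=True)  # placeholder batch size
    .repeat()
    .prefetch(AUTOTUNE)
)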

src/maxdiffusion/pedagogical_examples/dataset_tf_cache_to_tfrecord.py

Lines changed: 39 additions & 40 deletions
@@ -4,26 +4,29 @@
from datasets import load_from_disk
import numpy as np

+
def _bytes_feature(value):
  """Returns a bytes_list from a serialized tensor."""
  if not isinstance(value, tf.Tensor):
-      value = tf.constant(value)
+    value = tf.constant(value)
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value.numpy()]))

+
def create_4_feature_example(record):
-    """Creates a tf.train.Example proto with all 4 pre-computed features."""
-    pixel_values = tf.io.serialize_tensor(record['pixel_values'])
-    input_ids = tf.io.serialize_tensor(record['input_ids'])
-    prompt_embeds = tf.io.serialize_tensor(record['prompt_embeds'])
-    text_embeds = tf.io.serialize_tensor(record['text_embeds'])
-
-    feature = {
-        "pixel_values": _bytes_feature(pixel_values),
-        "input_ids": _bytes_feature(input_ids),
-        "prompt_embeds": _bytes_feature(prompt_embeds),
-        "text_embeds": _bytes_feature(text_embeds)
-    }
-    return tf.train.Example(features=tf.train.Features(feature=feature))
+  """Creates a tf.train.Example proto with all 4 pre-computed features."""
+  pixel_values = tf.io.serialize_tensor(record["pixel_values"])
+  input_ids = tf.io.serialize_tensor(record["input_ids"])
+  prompt_embeds = tf.io.serialize_tensor(record["prompt_embeds"])
+  text_embeds = tf.io.serialize_tensor(record["text_embeds"])
+
+  feature = {
+      "pixel_values": _bytes_feature(pixel_values),
+      "input_ids": _bytes_feature(input_ids),
+      "prompt_embeds": _bytes_feature(prompt_embeds),
+      "text_embeds": _bytes_feature(text_embeds),
+  }
+  return tf.train.Example(features=tf.train.Features(feature=feature))
+

def run(args):
  """Main processing function."""

@@ -41,56 +44,52 @@ def run(args):
      tf.io.TFRecordWriter(os.path.join(tfrecords_dir, f"shard-{i:05d}-of-{num_shards:05d}.tfrecord"))
      for i in range(num_shards)
  ]
-
+
  print(f"Writing {len(processed_ds)} records into {num_shards} TFRecord shards...")
-
+
  for i, record in enumerate(processed_ds):
-      # Create a new record with explicit casting for float types
-      casted_record = {
-          "pixel_values": np.float32(record['pixel_values']),
-          "input_ids": record['input_ids'],  # This is already integer type
-          "prompt_embeds": np.float32(record['prompt_embeds']),
-          "text_embeds": np.float32(record['text_embeds'])
-      }
-
-      writer_index = i % num_shards
-      tf_example = create_4_feature_example(casted_record)
-      writers[writer_index].write(tf_example.SerializeToString())
+    # Create a new record with explicit casting for float types
+    casted_record = {
+        "pixel_values": np.float32(record["pixel_values"]),
+        "input_ids": record["input_ids"],  # This is already integer type
+        "prompt_embeds": np.float32(record["prompt_embeds"]),
+        "text_embeds": np.float32(record["text_embeds"]),
+    }
+
+    writer_index = i % num_shards
+    tf_example = create_4_feature_example(casted_record)
+    writers[writer_index].write(tf_example.SerializeToString())

  for writer in writers:
-      writer.close()
-
+    writer.close()
+
  print("TFRecord conversion complete.")


def main():
  """Parses command-line arguments and runs the conversion."""
-  parser = argparse.ArgumentParser(
-      description="Convert a cached Hugging Face dataset to sharded TFRecords."
-  )
+  parser = argparse.ArgumentParser(description="Convert a cached Hugging Face dataset to sharded TFRecords.")
  parser.add_argument(
      "--dataset_save_location",
      type=str,
      required=False,
      default="/tmp/pokemon-gpt4-captions_xl",
-      help="Path to the cached dataset created by the training pipeline."
+      help="Path to the cached dataset created by the training pipeline.",
  )
  parser.add_argument(
      "--tfrecords_dir",
      type=str,
      required=False,
      default="/tmp/cached_pokemon_tfrecords_sharded",
-      help="Output directory to save the sharded TFRecord files."
+      help="Output directory to save the sharded TFRecord files.",
  )
  parser.add_argument(
-      "--data_num_shards",
-      type=int,
-      default=128,
-      help="Number of shards to split the TFRecord dataset into."
+      "--data_num_shards", type=int, default=128, help="Number of shards to split the TFRecord dataset into."
  )
-
+
  args = parser.parse_args()
  run(args)

+
if __name__ == "__main__":
-    main()
+  main()
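
The script relies on tf.io.serialize_tensor on the writer side pairing exactly with tf.io.parse_tensor on the reader side. A minimal round-trip check of that pairing, with an illustrative tensor shape that is not taken from the real dataset:

import numpy as np
import tensorflow as tf


def _bytes_feature(value):
  """Wraps a serialized tensor in a bytes_list feature, as in the script above."""
  if not isinstance(value, tf.Tensor):
    value = tf.constant(value)
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value.numpy()]))


original = np.float32(np.random.rand(2, 77, 2048))  # illustrative shape
serialized = tf.io.serialize_tensor(original)
example = tf.train.Example(
    features=tf.train.Features(feature={"prompt_embeds": _bytes_feature(serialized)})
)

# Reader side: parse the proto, then deserialize the inner tensor.
parsed = tf.io.parse_single_example(
    example.SerializeToString(),
    {"prompt_embeds": tf.io.FixedLenFeature([], tf.string)},
)
restored = tf.io.parse_tensor(parsed["prompt_embeds"], out_type=tf.float32)
assert np.allclose(original, restored.numpy())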

src/maxdiffusion/train_utils.py

Lines changed: 21 additions & 19 deletions
@@ -69,28 +69,31 @@ def record_scalar_metrics(metrics, step_time_delta, per_device_tflops, lr):
  metrics["scalar"].update({"perf/per_device_tflops_per_sec": per_device_tflops / step_time_delta.total_seconds()})
  metrics["scalar"].update({"learning/current_learning_rate": lr})

+
_metrics_queue = queue.Queue()
_buffered_step = None
_buffered_metrics = None

+
def _tensorboard_writer_worker(writer, config):
-    """
-    A worker function that runs in a separate thread.
-    It waits for metrics to appear in the queue and writes them to TensorBoard.
-    """
-    while True:
-        data = _metrics_queue.get()
-        if data is None:
-            break
-        metrics, step = data
-        if jax.process_index() == 0:
-            for metric_name in metrics.get("scalar", []):
-                writer.add_scalar(metric_name, np.array(metrics["scalar"][metric_name]), step)
-            for metric_name in metrics.get("scalars", []):
-                writer.add_scalars(metric_name, metrics["scalars"][metric_name], step)
-
-        if step % config.log_period == 0:
-            writer.flush()
+  """
+  A worker function that runs in a separate thread.
+  It waits for metrics to appear in the queue and writes them to TensorBoard.
+  """
+  while True:
+    data = _metrics_queue.get()
+    if data is None:
+      break
+    metrics, step = data
+    if jax.process_index() == 0:
+      for metric_name in metrics.get("scalar", []):
+        writer.add_scalar(metric_name, np.array(metrics["scalar"][metric_name]), step)
+      for metric_name in metrics.get("scalars", []):
+        writer.add_scalars(metric_name, metrics["scalars"][metric_name], step)
+
+    if step % config.log_period == 0:
+      writer.flush()
+

def write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step, config):
  """Entry point for all metrics writing in Train's Main.

@@ -108,12 +111,11 @@ def write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step, config):
  if _buffered_metrics is not None:
    if config.metrics_file:
      max_utils.write_metrics_locally(_buffered_metrics, _buffered_step, config, local_metrics_file)
-
+
    if _buffered_step is None:
      raise ValueError(f"When writing metrics, {_buffered_step=} was none")
    write_metrics_to_tensorboard(writer, _buffered_metrics, _buffered_step, config)

-
  if config.gcs_metrics and jax.process_index() == 0:
    running_gcs_metrics = max_utils.write_metrics_for_gcs(_buffered_metrics, _buffered_step, config, running_gcs_metrics)
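
The worker reformatted here is a standard sentinel-terminated producer/consumer loop: the training loop enqueues (metrics, step) tuples, the daemon thread drains the queue, and a None sentinel shuts the thread down so metrics writing never blocks a training step. A self-contained sketch of the same pattern, using a stand-in writer class rather than the real TensorBoard SummaryWriter:

import queue
import threading

metrics_queue = queue.Queue()


class PrintWriter:
  """Stand-in for the SummaryWriter used in train_utils."""

  def add_scalar(self, name, value, step):
    print(f"step={step} {name}={value}")

  def flush(self):
    pass


def writer_worker(writer, log_period=10):
  while True:
    data = metrics_queue.get()
    if data is None:  # Sentinel: the producer is done, exit the thread.
      break
    metrics, step = data
    for name, value in metrics.items():
      writer.add_scalar(name, value, step)
    if step % log_period == 0:
      writer.flush()


thread = threading.Thread(target=writer_worker, args=(PrintWriter(),), daemon=True)
thread.start()
for step in range(3):
  metrics_queue.put(({"learning/loss": 0.1 * step}, step))
metrics_queue.put(None)  # Ask the worker to finish, then wait for it.
thread.join()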

src/maxdiffusion/trainers/sdxl_trainer.py

Lines changed: 9 additions & 13 deletions
@@ -40,7 +40,7 @@
    load_next_batch,
    record_scalar_metrics,
    write_metrics,
-    _metrics_queue
+    _metrics_queue,
)

from maxdiffusion.checkpointing.base_stable_diffusion_checkpointer import (STABLE_DIFFUSION_XL_CHECKPOINT)

@@ -67,7 +67,7 @@ def get_shaped_batch(self, config, pipeline):
    total_train_batch_size = config.total_train_batch_size
    shaped_batch = {}

-    if self.config.dataset_type in ["tf","tfrecord"] and self.config.cache_latents_text_encoder_outputs:
+    if self.config.dataset_type in ["tf", "tfrecord"] and self.config.cache_latents_text_encoder_outputs:
      batch_image_shape = (
          total_train_batch_size,
          pipeline.unet.config.in_channels,

@@ -92,7 +92,7 @@ def get_shaped_batch(self, config, pipeline):

  def get_data_shardings(self):
    data_sharding = jax.sharding.NamedSharding(self.mesh, P(*self.config.data_sharding))
-    if self.config.dataset_type in ["tf","tfrecord"] and self.config.cache_latents_text_encoder_outputs:
+    if self.config.dataset_type in ["tf", "tfrecord"] and self.config.cache_latents_text_encoder_outputs:
      data_sharding = {
          "input_ids": data_sharding,
          "pixel_values": data_sharding,

@@ -188,11 +188,7 @@ def compile_train_step(self, pipeline, params, train_states, state_shardings, da
  def training_loop(self, p_train_step, pipeline, params, train_states, data_iterator, unet_learning_rate_scheduler):

    writer = max_utils.initialize_summary_writer(self.config)
-    writer_thread = threading.Thread(
-        target=_tensorboard_writer_worker,
-        args=(writer, self.config),
-        daemon=True
-    )
+    writer_thread = threading.Thread(target=_tensorboard_writer_worker, args=(writer, self.config), daemon=True)
    writer_thread.start()
    unet_state = train_states["unet_state"]
    vae_state = train_states["vae_state"]

@@ -228,14 +224,14 @@ def training_loop(self, p_train_step, pipeline, params, train_states, data_iterator, unet_learning_rate_scheduler):
    for step in np.arange(start_step, self.config.max_train_steps):
      if self.config.enable_profiler and step == first_profiling_step:
        max_utils.activate_profiler(self.config)
-
+
      next_batch_future = executor.submit(load_next_batch, data_iterator, example_batch, self.config)
      start_step_time = datetime.datetime.now()
      with jax.profiler.StepTraceAnnotation("train-new", step_num=step):
        (unet_state, train_metric, train_rngs) = p_train_step(
            unet_state, vae_state, text_encoder_state, text_encoder_2_state, example_batch, train_rngs
        )
-      train_metric['scalar']['learning/loss'].block_until_ready()
+      train_metric["scalar"]["learning/loss"].block_until_ready()
      samples_count = self.total_train_batch_size * (step + 1)
      last_step_completion = datetime.datetime.now()
      time_difference = last_step_completion - start_step_time

@@ -247,7 +243,7 @@ def training_loop(self, p_train_step, pipeline, params, train_states, data_iterator, unet_learning_rate_scheduler):
      if self.config.write_metrics:
        write_metrics(writer, local_metrics_file, running_gcs_metrics, train_metric, step, self.config)
      example_batch = next_batch_future.result()
-
+
      if step != 0 and self.config.checkpoint_every != -1 and samples_count % self.config.checkpoint_every == 0:
        train_states["unet_state"] = unet_state
        train_states["vae_state"] = vae_state

@@ -265,7 +261,7 @@ def training_loop(self, p_train_step, pipeline, params, train_states, data_iterator, unet_learning_rate_scheduler):
    _metrics_queue.put(None)
    writer_thread.join()
    if writer:
-        writer.flush()
+      writer.flush()
    train_states["unet_state"] = unet_state
    train_states["text_encoder_state"] = text_encoder_state
    train_states["text_encoder_2_state"] = text_encoder_2_state

@@ -369,5 +365,5 @@ def compute_loss(state_params):
    new_state = unet_state.apply_gradients(grads=grad["unet"])

    metrics = {"scalar": {"learning/loss": loss}, "scalars": {}}
-
+
    return new_state, metrics, new_train_rng
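
One pattern worth noting in this loop is the one-step batch prefetch: executor.submit(load_next_batch, ...) starts loading the next batch on the host while p_train_step runs on the accelerator, and next_batch_future.result() joins before the next iteration, hiding data-loading latency behind compute. A minimal sketch of that overlap with stand-in functions and illustrative timings:

import time
from concurrent.futures import ThreadPoolExecutor


def load_next_batch(step):
  """Stand-in for the host-side data loading."""
  time.sleep(0.05)  # pretend I/O
  return f"batch-{step}"


def train_step(batch):
  """Stand-in for the device-side p_train_step."""
  time.sleep(0.10)  # pretend compute
  return f"metrics for {batch}"


executor = ThreadPoolExecutor(max_workers=1)
example_batch = load_next_batch(0)
for step in range(3):
  # Kick off the next load; it overlaps with this step's compute.
  next_batch_future = executor.submit(load_next_batch, step + 1)
  metrics = train_step(example_batch)
  example_batch = next_batch_future.result()  # Join before the next step.
  print(step, metrics)
executor.shutdown()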
