From 1203400f193aa178a1ab2cc624d9ae23fef50625 Mon Sep 17 00:00:00 2001 From: Kunjan Date: Thu, 12 Jun 2025 07:01:17 +0000 Subject: [PATCH 1/7] fix profiling --- src/maxdiffusion/trainers/sdxl_trainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/maxdiffusion/trainers/sdxl_trainer.py b/src/maxdiffusion/trainers/sdxl_trainer.py index bae48a57a..f76134869 100644 --- a/src/maxdiffusion/trainers/sdxl_trainer.py +++ b/src/maxdiffusion/trainers/sdxl_trainer.py @@ -223,6 +223,7 @@ def training_loop(self, p_train_step, pipeline, params, train_states, data_itera (unet_state, train_metric, train_rngs) = p_train_step( unet_state, vae_state, text_encoder_state, text_encoder_2_state, example_batch, train_rngs ) + train_metric['scalar']['learning/loss'].block_until_ready() samples_count = self.total_train_batch_size * (step + 1) new_time = datetime.datetime.now() From 04735f4b6d7dea12407b92d447c2969f523af5e6 Mon Sep 17 00:00:00 2001 From: Kunjan Date: Mon, 9 Jun 2025 19:57:32 +0000 Subject: [PATCH 2/7] Use torch cpu, async write to tensorboard, script to convert latents to tfrecord, batch iterator for tfrecord cached, namedsharding instead of positional sharding Signed-off-by: Kunjan --- requirements.txt | 7 +- .../input_pipeline/_tfds_data_processing.py | 63 +++++++++++- .../dataset_tf_cache_to_tfrecord.py | 96 +++++++++++++++++++ src/maxdiffusion/train_utils.py | 45 ++++++--- src/maxdiffusion/trainers/flux_trainer.py | 2 +- src/maxdiffusion/trainers/sdxl_trainer.py | 88 ++++++++++------- 6 files changed, 246 insertions(+), 55 deletions(-) create mode 100644 src/maxdiffusion/pedagogical_examples/dataset_tf_cache_to_tfrecord.py diff --git a/requirements.txt b/requirements.txt index e26b45b80..eeaf2c9e3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ -jax>=0.4.30 +--extra-index-url https://download.pytorch.org/whl/cpu +jax==0.5.3 jaxlib>=0.4.30 grain-nightly==0.0.10 google-cloud-storage==2.17.0 @@ -6,8 +7,8 @@ absl-py datasets flax>=0.10.2 
optax>=0.2.3 -torch==2.5.1 -torchvision==0.20.1 +torch==2.6.0 +torchvision>=0.20.1 ftfy tensorboard>=2.17.0 tensorboardx==2.6.2.2 diff --git a/src/maxdiffusion/input_pipeline/_tfds_data_processing.py b/src/maxdiffusion/input_pipeline/_tfds_data_processing.py index 89226f377..d5b774071 100644 --- a/src/maxdiffusion/input_pipeline/_tfds_data_processing.py +++ b/src/maxdiffusion/input_pipeline/_tfds_data_processing.py @@ -21,7 +21,7 @@ from maxdiffusion import multihost_dataloading -AUTOTUNE = tf.data.experimental.AUTOTUNE +AUTOTUNE = tf.data.AUTOTUNE def load_as_tf_dataset(dataset, global_batch_size, shuffle, dataloading_host_count): @@ -31,7 +31,7 @@ def load_as_tf_dataset(dataset, global_batch_size, shuffle, dataloading_host_cou if shuffle: tf_dataset = tf_dataset.shuffle(len(tf_dataset)) tf_dataset = tf_dataset.batch(global_batch_size // dataloading_host_count, drop_remainder=True) - tf_dataset = tf_dataset.prefetch(tf.data.experimental.AUTOTUNE) + tf_dataset = tf_dataset.prefetch(AUTOTUNE) tf_dataset = tf_dataset.repeat(-1) return tf_dataset @@ -74,6 +74,57 @@ def make_tf_iterator( return train_iter +def make_cached_tfrecord_iterator( + config, + dataloading_host_index, + dataloading_host_count, + mesh, + global_batch_size, +): + """ + New iterator for TFRecords that contain the full 4 pre-computed latents and embeddings: + latents, input_ids, prompt_embeds, and text_embeds. 
+ """ + feature_description = { + "pixel_values": tf.io.FixedLenFeature([], tf.string), + "input_ids": tf.io.FixedLenFeature([], tf.string), + "prompt_embeds": tf.io.FixedLenFeature([], tf.string), + "text_embeds": tf.io.FixedLenFeature([], tf.string), + } + + def _parse_tfrecord_fn(example): + return tf.io.parse_single_example(example, feature_description) + + def prepare_sample(features): + pixel_values = tf.io.parse_tensor(features["pixel_values"], out_type=tf.float32) + input_ids = tf.io.parse_tensor(features["input_ids"], out_type=tf.int32) + prompt_embeds = tf.io.parse_tensor(features["prompt_embeds"], out_type=tf.float32) + text_embeds = tf.io.parse_tensor(features["text_embeds"], out_type=tf.float32) + + return { + "pixel_values": pixel_values, + "input_ids": input_ids, + "prompt_embeds": prompt_embeds, + "text_embeds": text_embeds + } + + # This pipeline reads the sharded files and applies the parsing and preparation. + filenames = tf.io.gfile.glob(os.path.join(config.train_data_dir, "*")) + train_ds = ( + tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE) + .shard(num_shards=dataloading_host_count, index=dataloading_host_index) + .map(_parse_tfrecord_fn, num_parallel_calls=AUTOTUNE) + .map(prepare_sample, num_parallel_calls=AUTOTUNE) + .shuffle(global_batch_size * 10) + .batch(global_batch_size // dataloading_host_count, drop_remainder=True) + .repeat(-1) + .prefetch(AUTOTUNE) + ) + + # This wraps the tf.data.Dataset for use in the multi-host JAX environment. 
+ train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh) + return train_iter + # TODO - https://github.com/google/array_record/blob/main/beam/examples/example_gcs_conversion.py def make_tfrecord_iterator( config, @@ -86,11 +137,15 @@ def make_tfrecord_iterator( check out preparation script maxdiffusion/pedagogical_examples/to_tfrecords.py """ + if config.cache_latents_text_encoder_outputs and os.path.isdir(config.dataset_save_location): + return make_cached_tfrecord_iterator(config, dataloading_host_index, + dataloading_host_count, mesh, + global_batch_size) feature_description = { "moments": tf.io.FixedLenFeature([], tf.string), "clip_embeddings": tf.io.FixedLenFeature([], tf.string), } - + def _parse_tfrecord_fn(example): return tf.io.parse_single_example(example, feature_description) @@ -98,7 +153,7 @@ def prepare_sample(features): moments = tf.io.parse_tensor(tnp.asarray(features["moments"]), out_type=tf.float32) clip_embeddings = tf.io.parse_tensor(tnp.asarray(features["clip_embeddings"]), out_type=tf.float32) return {"pixel_values": moments, "input_ids": clip_embeddings} - + filenames = tf.io.gfile.glob(os.path.join(config.train_data_dir, "*")) train_ds = ( tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE) diff --git a/src/maxdiffusion/pedagogical_examples/dataset_tf_cache_to_tfrecord.py b/src/maxdiffusion/pedagogical_examples/dataset_tf_cache_to_tfrecord.py new file mode 100644 index 000000000..7012284e9 --- /dev/null +++ b/src/maxdiffusion/pedagogical_examples/dataset_tf_cache_to_tfrecord.py @@ -0,0 +1,96 @@ +import os +import argparse +import tensorflow as tf +from datasets import load_from_disk +import numpy as np + +def _bytes_feature(value): + """Returns a bytes_list from a serialized tensor.""" + if not isinstance(value, tf.Tensor): + value = tf.constant(value) + return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value.numpy()])) + +def create_4_feature_example(record): + """Creates a tf.train.Example proto 
with all 4 pre-computed features.""" + pixel_values = tf.io.serialize_tensor(record['pixel_values']) + input_ids = tf.io.serialize_tensor(record['input_ids']) + prompt_embeds = tf.io.serialize_tensor(record['prompt_embeds']) + text_embeds = tf.io.serialize_tensor(record['text_embeds']) + + feature = { + "pixel_values": _bytes_feature(pixel_values), + "input_ids": _bytes_feature(input_ids), + "prompt_embeds": _bytes_feature(prompt_embeds), + "text_embeds": _bytes_feature(text_embeds) + } + return tf.train.Example(features=tf.train.Features(feature=feature)) + +def run(args): + """Main processing function.""" + # Load the cached dataset from the location specified in the arguments + print(f"Loading processed dataset from disk: {args.dataset_save_location}") + processed_ds = load_from_disk(args.dataset_save_location) + print("Dataset loaded successfully.") + + # Get sharding and output directory from the arguments + tfrecords_dir = args.tfrecords_dir + num_shards = args.data_num_shards + os.makedirs(tfrecords_dir, exist_ok=True) + + writers = [ + tf.io.TFRecordWriter(os.path.join(tfrecords_dir, f"shard-{i:05d}-of-{num_shards:05d}.tfrecord")) + for i in range(num_shards) + ] + + print(f"Writing {len(processed_ds)} records into {num_shards} TFRecord shards...") + + for i, record in enumerate(processed_ds): + # Create a new record with explicit casting for float types + casted_record = { + "pixel_values": np.float32(record['pixel_values']), + "input_ids": record['input_ids'], # This is already integer type + "prompt_embeds": np.float32(record['prompt_embeds']), + "text_embeds": np.float32(record['text_embeds']) + } + + writer_index = i % num_shards + tf_example = create_4_feature_example(casted_record) + writers[writer_index].write(tf_example.SerializeToString()) + + for writer in writers: + writer.close() + + print("TFRecord conversion complete.") + + +def main(): + """Parses command-line arguments and runs the conversion.""" + parser = argparse.ArgumentParser( + 
description="Convert a cached Hugging Face dataset to sharded TFRecords." + ) + parser.add_argument( + "--dataset_save_location", + type=str, + required=False, + default="/tmp/pokemon-gpt4-captions_xl", + help="Path to the cached dataset created by the training pipeline." + ) + parser.add_argument( + "--tfrecords_dir", + type=str, + required=False, + default="/tmp/cached_pokemon_tfrecords_sharded", + help="Output directory to save the sharded TFRecord files." + ) + parser.add_argument( + "--data_num_shards", + type=int, + default=128, + help="Number of shards to split the TFRecord dataset into." + ) + + args = parser.parse_args() + run(args) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/maxdiffusion/train_utils.py b/src/maxdiffusion/train_utils.py index a040f85d2..41118cf96 100644 --- a/src/maxdiffusion/train_utils.py +++ b/src/maxdiffusion/train_utils.py @@ -17,6 +17,8 @@ import numpy as np import jax import jax.numpy as jnp +import threading +import queue from maxdiffusion import max_utils, max_logging @@ -67,10 +69,28 @@ def record_scalar_metrics(metrics, step_time_delta, per_device_tflops, lr): metrics["scalar"].update({"perf/per_device_tflops_per_sec": per_device_tflops / step_time_delta.total_seconds()}) metrics["scalar"].update({"learning/current_learning_rate": lr}) - +_metrics_queue = queue.Queue() _buffered_step = None _buffered_metrics = None +def _tensorboard_writer_worker(writer, config): + """ + A worker function that runs in a separate thread. + It waits for metrics to appear in the queue and writes them to TensorBoard. 
+ """ + while True: + data = _metrics_queue.get() + if data is None: + break + metrics, step = data + if jax.process_index() == 0: + for metric_name in metrics.get("scalar", []): + writer.add_scalar(metric_name, np.array(metrics["scalar"][metric_name]), step) + for metric_name in metrics.get("scalars", []): + writer.add_scalars(metric_name, metrics["scalars"][metric_name], step) + + if step % config.log_period == 0: + writer.flush() def write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step, config): """Entry point for all metrics writing in Train's Main. @@ -81,15 +101,18 @@ def write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step The logic is that this ensures that Jax is able to queues train_steps and we don't block when turning "lazy" Jax arrays into real Python numbers. """ - global _buffered_step, _buffered_metrics + global _buffered_step, _buffered_metrics, _metrics_queue + if metrics: + _metrics_queue.put((metrics, step)) if _buffered_metrics is not None: + if config.metrics_file: + max_utils.write_metrics_locally(_buffered_metrics, _buffered_step, config, local_metrics_file) + if _buffered_step is None: raise ValueError(f"When writing metrics, {_buffered_step=} was none") write_metrics_to_tensorboard(writer, _buffered_metrics, _buffered_step, config) - if config.metrics_file: - max_utils.write_metrics_locally(_buffered_metrics, _buffered_step, config, local_metrics_file) if config.gcs_metrics and jax.process_index() == 0: running_gcs_metrics = max_utils.write_metrics_for_gcs(_buffered_metrics, _buffered_step, config, running_gcs_metrics) @@ -100,13 +123,6 @@ def write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step def write_metrics_to_tensorboard(writer, metrics, step, config): """Writes metrics to tensorboard""" - if jax.process_index() == 0: - for metric_name in metrics.get("scalar", []): - writer.add_scalar(metric_name, np.array(metrics["scalar"][metric_name]), step) - for metric_name 
in metrics.get("scalars", []): - writer.add_scalars(metric_name, metrics["scalars"][metric_name], step) - - full_log = step % config.log_period == 0 if jax.process_index() == 0: max_logging.log( "completed step: {}, seconds: {:.3f}, TFLOP/s/device: {:.3f}, loss: {:.3f}".format( @@ -116,6 +132,13 @@ def write_metrics_to_tensorboard(writer, metrics, step, config): float(metrics["scalar"]["learning/loss"]), ) ) + if jax.process_index() == 0: + for metric_name in metrics.get("scalar", []): + writer.add_scalar(metric_name, np.array(metrics["scalar"][metric_name]), step) + for metric_name in metrics.get("scalars", []): + writer.add_scalars(metric_name, metrics["scalars"][metric_name], step) + + full_log = step % config.log_period == 0 if full_log and jax.process_index() == 0: max_logging.log(f"To see full metrics 'tensorboard --logdir={config.tensorboard_dir}'") diff --git a/src/maxdiffusion/trainers/flux_trainer.py b/src/maxdiffusion/trainers/flux_trainer.py index b6a47c0d5..da6cbf45b 100644 --- a/src/maxdiffusion/trainers/flux_trainer.py +++ b/src/maxdiffusion/trainers/flux_trainer.py @@ -87,7 +87,7 @@ def start_training(self): state_shardings = {} # move params to accelerator - encoders_sharding = PositionalSharding(self.devices_array).replicate() + encoders_sharding = jax.NamedSharding(self.mesh, P(None)) partial_device_put_replicated = partial(max_utils.device_put_replicated, sharding=encoders_sharding) pipeline.clip_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), pipeline.clip_encoder.params) pipeline.clip_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, pipeline.clip_encoder.params) diff --git a/src/maxdiffusion/trainers/sdxl_trainer.py b/src/maxdiffusion/trainers/sdxl_trainer.py index f76134869..1c46d0551 100644 --- a/src/maxdiffusion/trainers/sdxl_trainer.py +++ b/src/maxdiffusion/trainers/sdxl_trainer.py @@ -17,12 +17,15 @@ import os from functools import partial import datetime +import threading import time 
import numpy as np import jax import jax.numpy as jnp from jax.sharding import PartitionSpec as P from flax.linen import partitioning as nn_partitioning +from jax.experimental import io_callback +from concurrent.futures import ThreadPoolExecutor from maxdiffusion.trainers.stable_diffusion_trainer import (StableDiffusionTrainer) from maxdiffusion.input_pipeline.input_pipeline_interface import (make_data_iterator) @@ -30,12 +33,14 @@ from maxdiffusion import (max_utils, maxdiffusion_utils, max_logging) from maxdiffusion.train_utils import ( + _tensorboard_writer_worker, compute_snr, generate_timestep_weights, get_first_step, load_next_batch, record_scalar_metrics, write_metrics, + _metrics_queue ) from maxdiffusion.checkpointing.base_stable_diffusion_checkpointer import (STABLE_DIFFUSION_XL_CHECKPOINT) @@ -62,7 +67,7 @@ def get_shaped_batch(self, config, pipeline): total_train_batch_size = config.total_train_batch_size shaped_batch = {} - if self.config.dataset_type == "tf" and self.config.cache_latents_text_encoder_outputs: + if self.config.dataset_type in ["tf","tfrecord"] and self.config.cache_latents_text_encoder_outputs: batch_image_shape = ( total_train_batch_size, pipeline.unet.config.in_channels, @@ -87,7 +92,7 @@ def get_shaped_batch(self, config, pipeline): def get_data_shardings(self): data_sharding = jax.sharding.NamedSharding(self.mesh, P(*self.config.data_sharding)) - if self.config.dataset_type == "tf" and self.config.cache_latents_text_encoder_outputs: + if self.config.dataset_type in ["tf","tfrecord"] and self.config.cache_latents_text_encoder_outputs: data_sharding = { "input_ids": data_sharding, "pixel_values": data_sharding, @@ -183,6 +188,12 @@ def compile_train_step(self, pipeline, params, train_states, state_shardings, da def training_loop(self, p_train_step, pipeline, params, train_states, data_iterator, unet_learning_rate_scheduler): writer = max_utils.initialize_summary_writer(self.config) + writer_thread = threading.Thread( + 
target=_tensorboard_writer_worker, + args=(writer, self.config), + daemon=True + ) + writer_thread.start() unet_state = train_states["unet_state"] vae_state = train_states["vae_state"] text_encoder_state = train_states["text_encoder_state"] @@ -212,44 +223,49 @@ def training_loop(self, p_train_step, pipeline, params, train_states, data_itera ) start_step = get_first_step(train_states["unet_state"]) _, train_rngs = jax.random.split(self.rng) - - for step in np.arange(start_step, self.config.max_train_steps): - if self.config.enable_profiler and step == first_profiling_step: - max_utils.activate_profiler(self.config) - - example_batch = load_next_batch(data_iterator, example_batch, self.config) - - with jax.profiler.StepTraceAnnotation("train", step_num=step): - (unet_state, train_metric, train_rngs) = p_train_step( - unet_state, vae_state, text_encoder_state, text_encoder_2_state, example_batch, train_rngs + example_batch = load_next_batch(data_iterator, None, self.config) + with ThreadPoolExecutor(max_workers=1) as executor: + for step in np.arange(start_step, self.config.max_train_steps): + if self.config.enable_profiler and step == first_profiling_step: + max_utils.activate_profiler(self.config) + + next_batch_future = executor.submit(load_next_batch, data_iterator, example_batch, self.config) + start_step_time = datetime.datetime.now() + with jax.profiler.StepTraceAnnotation("train-new", step_num=step): + (unet_state, train_metric, train_rngs) = p_train_step( + unet_state, vae_state, text_encoder_state, text_encoder_2_state, example_batch, train_rngs + ) + train_metric['scalar']['learning/loss'].block_until_ready() + samples_count = self.total_train_batch_size * (step + 1) + last_step_completion = datetime.datetime.now() + time_difference = last_step_completion - start_step_time + difference_in_ms = time_difference.total_seconds() * 1000 + max_logging.log(f"Step time {difference_in_ms}ms") + record_scalar_metrics( + train_metric, last_step_completion - 
start_step_time, self.per_device_tflops, unet_learning_rate_scheduler(step) ) - train_metric['scalar']['learning/loss'].block_until_ready() - - samples_count = self.total_train_batch_size * (step + 1) - new_time = datetime.datetime.now() - - record_scalar_metrics( - train_metric, new_time - last_step_completion, self.per_device_tflops, unet_learning_rate_scheduler(step) - ) - if self.config.write_metrics: - write_metrics(writer, local_metrics_file, running_gcs_metrics, train_metric, step, self.config) - last_step_completion = new_time - - if step != 0 and self.config.checkpoint_every != -1 and samples_count % self.config.checkpoint_every == 0: - train_states["unet_state"] = unet_state - train_states["vae_state"] = vae_state - train_states["text_encoder_state"] = text_encoder_state - train_states["text_encoder_2_state"] = text_encoder_2_state - self.save_checkpoint(step, pipeline, params, train_states) - - if self.config.enable_profiler and step == last_profiling_step: - max_utils.deactivate_profiler(self.config) + if self.config.write_metrics: + write_metrics(writer, local_metrics_file, running_gcs_metrics, train_metric, step, self.config) + example_batch = next_batch_future.result() + + if step != 0 and self.config.checkpoint_every != -1 and samples_count % self.config.checkpoint_every == 0: + train_states["unet_state"] = unet_state + train_states["vae_state"] = vae_state + train_states["text_encoder_state"] = text_encoder_state + train_states["text_encoder_2_state"] = text_encoder_2_state + self.save_checkpoint(step, pipeline, params, train_states) + + if self.config.enable_profiler and step == last_profiling_step: + max_utils.deactivate_profiler(self.config) if self.config.write_metrics: write_metrics( writer, local_metrics_file, running_gcs_metrics, train_metric, self.config.max_train_steps - 1, self.config ) - + _metrics_queue.put(None) + writer_thread.join() + if writer: + writer.flush() train_states["unet_state"] = unet_state 
train_states["text_encoder_state"] = text_encoder_state train_states["text_encoder_2_state"] = text_encoder_2_state @@ -267,7 +283,7 @@ def _train_step(unet_state, vae_state, text_encoder_state, text_encoder_2_state, state_params = {"unet": unet_state.params} def compute_loss(state_params): - if config.dataset_type == "tf" and config.cache_latents_text_encoder_outputs: + if config.dataset_type in ["tf", "tfrecord"] and config.cache_latents_text_encoder_outputs: latents = batch["pixel_values"] prompt_embeds = batch["prompt_embeds"] text_embeds = batch["text_embeds"] @@ -353,5 +369,5 @@ def compute_loss(state_params): new_state = unet_state.apply_gradients(grads=grad["unet"]) metrics = {"scalar": {"learning/loss": loss}, "scalars": {}} - + return new_state, metrics, new_train_rng From d76d5e8a602d6d3a15d364f4a4233213a841305b Mon Sep 17 00:00:00 2001 From: Kunjan Date: Tue, 17 Jun 2025 13:33:32 +0000 Subject: [PATCH 3/7] Replace positional sharding with named sharding Signed-off-by: Kunjan --- src/maxdiffusion/generate_flux.py | 4 ++-- src/maxdiffusion/generate_flux_multi_res.py | 4 ++-- src/maxdiffusion/pipelines/wan/wan_pipeline.py | 6 +++--- src/maxdiffusion/trainers/flux_trainer.py | 4 ++-- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/maxdiffusion/generate_flux.py b/src/maxdiffusion/generate_flux.py index 615f3c241..0e6866346 100644 --- a/src/maxdiffusion/generate_flux.py +++ b/src/maxdiffusion/generate_flux.py @@ -23,7 +23,7 @@ import numpy as np from PIL import Image import jax -from jax.sharding import Mesh, PositionalSharding, PartitionSpec as P +from jax.sharding import Mesh, NamedSharding, PartitionSpec as P import jax.numpy as jnp import flax.linen as nn from chex import Array @@ -343,7 +343,7 @@ def run(config): config.t5xxl_model_name_or_path, max_length=config.max_sequence_length, use_fast=True ) - encoders_sharding = PositionalSharding(devices_array).replicate() + encoders_sharding = NamedSharding(devices_array, P()) 
partial_device_put_replicated = functools.partial(device_put_replicated, sharding=encoders_sharding) clip_text_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), clip_text_encoder.params) clip_text_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, clip_text_encoder.params) diff --git a/src/maxdiffusion/generate_flux_multi_res.py b/src/maxdiffusion/generate_flux_multi_res.py index ed1baa67a..4c824db8b 100644 --- a/src/maxdiffusion/generate_flux_multi_res.py +++ b/src/maxdiffusion/generate_flux_multi_res.py @@ -23,7 +23,7 @@ import numpy as np from PIL import Image import jax -from jax.sharding import Mesh, PositionalSharding, PartitionSpec as P +from jax.sharding import Mesh, NamedSharding, PartitionSpec as P import jax.numpy as jnp import flax.linen as nn from chex import Array @@ -381,7 +381,7 @@ def run(config): config.t5xxl_model_name_or_path, max_length=config.max_sequence_length, use_fast=True ) - encoders_sharding = PositionalSharding(devices_array).replicate() + encoders_sharding = NamedSharding(devices_array, P()) partial_device_put_replicated = functools.partial(device_put_replicated, sharding=encoders_sharding) clip_text_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), clip_text_encoder.params) clip_text_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, clip_text_encoder.params) diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 8d9a2986b..cf97c890a 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -17,7 +17,7 @@ import numpy as np import jax import jax.numpy as jnp -from jax.sharding import Mesh, PositionalSharding, PartitionSpec as P +from jax.sharding import Mesh, NamedSharding, PartitionSpec as P import flax import flax.linen as nn from flax import nnx @@ -195,7 +195,7 @@ def load_vae(cls, devices_array: np.array, mesh: Mesh, rngs: 
nnx.Rngs, config: H # This replaces random params with the model. params = load_wan_vae(config.pretrained_model_name_or_path, params, "cpu") params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params) - params = jax.device_put(params, PositionalSharding(devices_array).replicate()) + params = jax.device_put(params, NamedSharding(devices_array, P())) wan_vae = nnx.merge(graphdef, params) p_create_sharded_logical_model = partial(create_sharded_logical_model, logical_axis_rules=config.logical_axis_rules) # Shard @@ -395,7 +395,7 @@ def __call__( num_channels_latents=num_channel_latents, ) - data_sharding = PositionalSharding(self.devices_array).replicate() + data_sharding = NamedSharding(self.devices_array, P()) if len(prompt) % jax.device_count() == 0: data_sharding = jax.sharding.NamedSharding(self.mesh, P(*self.config.data_sharding)) diff --git a/src/maxdiffusion/trainers/flux_trainer.py b/src/maxdiffusion/trainers/flux_trainer.py index da6cbf45b..32139faef 100644 --- a/src/maxdiffusion/trainers/flux_trainer.py +++ b/src/maxdiffusion/trainers/flux_trainer.py @@ -21,7 +21,7 @@ import numpy as np import jax import jax.numpy as jnp -from jax.sharding import PositionalSharding, PartitionSpec as P +from jax.sharding import NamedSharding, PartitionSpec as P from flax.linen import partitioning as nn_partitioning from maxdiffusion.checkpointing.flux_checkpointer import ( FluxCheckpointer, @@ -87,7 +87,7 @@ def start_training(self): state_shardings = {} # move params to accelerator - encoders_sharding = jax.NamedSharding(self.mesh, P(None)) + encoders_sharding = NamedSharding(self.mesh, P(None)) partial_device_put_replicated = partial(max_utils.device_put_replicated, sharding=encoders_sharding) pipeline.clip_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), pipeline.clip_encoder.params) pipeline.clip_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, pipeline.clip_encoder.params) From 
b2956a77607745ac5f24b3120a30e757dc759550 Mon Sep 17 00:00:00 2001 From: Kunjan Date: Tue, 17 Jun 2025 13:34:39 +0000 Subject: [PATCH 4/7] Formatting Signed-off-by: Kunjan --- .../input_pipeline/_tfds_data_processing.py | 18 ++--- .../dataset_tf_cache_to_tfrecord.py | 79 +++++++++---------- src/maxdiffusion/train_utils.py | 40 +++++----- src/maxdiffusion/trainers/sdxl_trainer.py | 22 +++--- 4 files changed, 75 insertions(+), 84 deletions(-) diff --git a/src/maxdiffusion/input_pipeline/_tfds_data_processing.py b/src/maxdiffusion/input_pipeline/_tfds_data_processing.py index d5b774071..f81527bd9 100644 --- a/src/maxdiffusion/input_pipeline/_tfds_data_processing.py +++ b/src/maxdiffusion/input_pipeline/_tfds_data_processing.py @@ -100,13 +100,8 @@ def prepare_sample(features): input_ids = tf.io.parse_tensor(features["input_ids"], out_type=tf.int32) prompt_embeds = tf.io.parse_tensor(features["prompt_embeds"], out_type=tf.float32) text_embeds = tf.io.parse_tensor(features["text_embeds"], out_type=tf.float32) - - return { - "pixel_values": pixel_values, - "input_ids": input_ids, - "prompt_embeds": prompt_embeds, - "text_embeds": text_embeds - } + + return {"pixel_values": pixel_values, "input_ids": input_ids, "prompt_embeds": prompt_embeds, "text_embeds": text_embeds} # This pipeline reads the sharded files and applies the parsing and preparation. 
filenames = tf.io.gfile.glob(os.path.join(config.train_data_dir, "*")) @@ -125,6 +120,7 @@ def prepare_sample(features): train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh) return train_iter + # TODO - https://github.com/google/array_record/blob/main/beam/examples/example_gcs_conversion.py def make_tfrecord_iterator( config, @@ -138,14 +134,12 @@ def make_tfrecord_iterator( maxdiffusion/pedagogical_examples/to_tfrecords.py """ if config.cache_latents_text_encoder_outputs and os.path.isdir(config.dataset_save_location): - return make_cached_tfrecord_iterator(config, dataloading_host_index, - dataloading_host_count, mesh, - global_batch_size) + return make_cached_tfrecord_iterator(config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size) feature_description = { "moments": tf.io.FixedLenFeature([], tf.string), "clip_embeddings": tf.io.FixedLenFeature([], tf.string), } - + def _parse_tfrecord_fn(example): return tf.io.parse_single_example(example, feature_description) @@ -153,7 +147,7 @@ def prepare_sample(features): moments = tf.io.parse_tensor(tnp.asarray(features["moments"]), out_type=tf.float32) clip_embeddings = tf.io.parse_tensor(tnp.asarray(features["clip_embeddings"]), out_type=tf.float32) return {"pixel_values": moments, "input_ids": clip_embeddings} - + filenames = tf.io.gfile.glob(os.path.join(config.train_data_dir, "*")) train_ds = ( tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE) diff --git a/src/maxdiffusion/pedagogical_examples/dataset_tf_cache_to_tfrecord.py b/src/maxdiffusion/pedagogical_examples/dataset_tf_cache_to_tfrecord.py index 7012284e9..7db00f6d8 100644 --- a/src/maxdiffusion/pedagogical_examples/dataset_tf_cache_to_tfrecord.py +++ b/src/maxdiffusion/pedagogical_examples/dataset_tf_cache_to_tfrecord.py @@ -4,26 +4,29 @@ from datasets import load_from_disk import numpy as np + def _bytes_feature(value): """Returns a bytes_list from a serialized tensor.""" if not isinstance(value, 
tf.Tensor): - value = tf.constant(value) + value = tf.constant(value) return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value.numpy()])) + def create_4_feature_example(record): - """Creates a tf.train.Example proto with all 4 pre-computed features.""" - pixel_values = tf.io.serialize_tensor(record['pixel_values']) - input_ids = tf.io.serialize_tensor(record['input_ids']) - prompt_embeds = tf.io.serialize_tensor(record['prompt_embeds']) - text_embeds = tf.io.serialize_tensor(record['text_embeds']) - - feature = { - "pixel_values": _bytes_feature(pixel_values), - "input_ids": _bytes_feature(input_ids), - "prompt_embeds": _bytes_feature(prompt_embeds), - "text_embeds": _bytes_feature(text_embeds) - } - return tf.train.Example(features=tf.train.Features(feature=feature)) + """Creates a tf.train.Example proto with all 4 pre-computed features.""" + pixel_values = tf.io.serialize_tensor(record["pixel_values"]) + input_ids = tf.io.serialize_tensor(record["input_ids"]) + prompt_embeds = tf.io.serialize_tensor(record["prompt_embeds"]) + text_embeds = tf.io.serialize_tensor(record["text_embeds"]) + + feature = { + "pixel_values": _bytes_feature(pixel_values), + "input_ids": _bytes_feature(input_ids), + "prompt_embeds": _bytes_feature(prompt_embeds), + "text_embeds": _bytes_feature(text_embeds), + } + return tf.train.Example(features=tf.train.Features(feature=feature)) + def run(args): """Main processing function.""" @@ -41,56 +44,52 @@ def run(args): tf.io.TFRecordWriter(os.path.join(tfrecords_dir, f"shard-{i:05d}-of-{num_shards:05d}.tfrecord")) for i in range(num_shards) ] - + print(f"Writing {len(processed_ds)} records into {num_shards} TFRecord shards...") - + for i, record in enumerate(processed_ds): - # Create a new record with explicit casting for float types - casted_record = { - "pixel_values": np.float32(record['pixel_values']), - "input_ids": record['input_ids'], # This is already integer type - "prompt_embeds": np.float32(record['prompt_embeds']), - 
"text_embeds": np.float32(record['text_embeds']) - } - - writer_index = i % num_shards - tf_example = create_4_feature_example(casted_record) - writers[writer_index].write(tf_example.SerializeToString()) + # Create a new record with explicit casting for float types + casted_record = { + "pixel_values": np.float32(record["pixel_values"]), + "input_ids": record["input_ids"], # This is already integer type + "prompt_embeds": np.float32(record["prompt_embeds"]), + "text_embeds": np.float32(record["text_embeds"]), + } + + writer_index = i % num_shards + tf_example = create_4_feature_example(casted_record) + writers[writer_index].write(tf_example.SerializeToString()) for writer in writers: - writer.close() - + writer.close() + print("TFRecord conversion complete.") def main(): """Parses command-line arguments and runs the conversion.""" - parser = argparse.ArgumentParser( - description="Convert a cached Hugging Face dataset to sharded TFRecords." - ) + parser = argparse.ArgumentParser(description="Convert a cached Hugging Face dataset to sharded TFRecords.") parser.add_argument( "--dataset_save_location", type=str, required=False, default="/tmp/pokemon-gpt4-captions_xl", - help="Path to the cached dataset created by the training pipeline." + help="Path to the cached dataset created by the training pipeline.", ) parser.add_argument( "--tfrecords_dir", type=str, required=False, default="/tmp/cached_pokemon_tfrecords_sharded", - help="Output directory to save the sharded TFRecord files." + help="Output directory to save the sharded TFRecord files.", ) parser.add_argument( - "--data_num_shards", - type=int, - default=128, - help="Number of shards to split the TFRecord dataset into." + "--data_num_shards", type=int, default=128, help="Number of shards to split the TFRecord dataset into." 
) - + args = parser.parse_args() run(args) + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/src/maxdiffusion/train_utils.py b/src/maxdiffusion/train_utils.py index 41118cf96..455ef31f2 100644 --- a/src/maxdiffusion/train_utils.py +++ b/src/maxdiffusion/train_utils.py @@ -69,28 +69,31 @@ def record_scalar_metrics(metrics, step_time_delta, per_device_tflops, lr): metrics["scalar"].update({"perf/per_device_tflops_per_sec": per_device_tflops / step_time_delta.total_seconds()}) metrics["scalar"].update({"learning/current_learning_rate": lr}) + _metrics_queue = queue.Queue() _buffered_step = None _buffered_metrics = None + def _tensorboard_writer_worker(writer, config): - """ - A worker function that runs in a separate thread. - It waits for metrics to appear in the queue and writes them to TensorBoard. - """ - while True: - data = _metrics_queue.get() - if data is None: - break - metrics, step = data - if jax.process_index() == 0: - for metric_name in metrics.get("scalar", []): - writer.add_scalar(metric_name, np.array(metrics["scalar"][metric_name]), step) - for metric_name in metrics.get("scalars", []): - writer.add_scalars(metric_name, metrics["scalars"][metric_name], step) - - if step % config.log_period == 0: - writer.flush() + """ + A worker function that runs in a separate thread. + It waits for metrics to appear in the queue and writes them to TensorBoard. 
+ """ + while True: + data = _metrics_queue.get() + if data is None: + break + metrics, step = data + if jax.process_index() == 0: + for metric_name in metrics.get("scalar", []): + writer.add_scalar(metric_name, np.array(metrics["scalar"][metric_name]), step) + for metric_name in metrics.get("scalars", []): + writer.add_scalars(metric_name, metrics["scalars"][metric_name], step) + + if step % config.log_period == 0: + writer.flush() + def write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step, config): """Entry point for all metrics writing in Train's Main. @@ -108,12 +111,11 @@ def write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step if _buffered_metrics is not None: if config.metrics_file: max_utils.write_metrics_locally(_buffered_metrics, _buffered_step, config, local_metrics_file) - + if _buffered_step is None: raise ValueError(f"When writing metrics, {_buffered_step=} was none") write_metrics_to_tensorboard(writer, _buffered_metrics, _buffered_step, config) - if config.gcs_metrics and jax.process_index() == 0: running_gcs_metrics = max_utils.write_metrics_for_gcs(_buffered_metrics, _buffered_step, config, running_gcs_metrics) diff --git a/src/maxdiffusion/trainers/sdxl_trainer.py b/src/maxdiffusion/trainers/sdxl_trainer.py index 1c46d0551..6daa1cc21 100644 --- a/src/maxdiffusion/trainers/sdxl_trainer.py +++ b/src/maxdiffusion/trainers/sdxl_trainer.py @@ -40,7 +40,7 @@ load_next_batch, record_scalar_metrics, write_metrics, - _metrics_queue + _metrics_queue, ) from maxdiffusion.checkpointing.base_stable_diffusion_checkpointer import (STABLE_DIFFUSION_XL_CHECKPOINT) @@ -67,7 +67,7 @@ def get_shaped_batch(self, config, pipeline): total_train_batch_size = config.total_train_batch_size shaped_batch = {} - if self.config.dataset_type in ["tf","tfrecord"] and self.config.cache_latents_text_encoder_outputs: + if self.config.dataset_type in ["tf", "tfrecord"] and self.config.cache_latents_text_encoder_outputs: 
batch_image_shape = ( total_train_batch_size, pipeline.unet.config.in_channels, @@ -92,7 +92,7 @@ def get_shaped_batch(self, config, pipeline): def get_data_shardings(self): data_sharding = jax.sharding.NamedSharding(self.mesh, P(*self.config.data_sharding)) - if self.config.dataset_type in ["tf","tfrecord"] and self.config.cache_latents_text_encoder_outputs: + if self.config.dataset_type in ["tf", "tfrecord"] and self.config.cache_latents_text_encoder_outputs: data_sharding = { "input_ids": data_sharding, "pixel_values": data_sharding, @@ -188,11 +188,7 @@ def compile_train_step(self, pipeline, params, train_states, state_shardings, da def training_loop(self, p_train_step, pipeline, params, train_states, data_iterator, unet_learning_rate_scheduler): writer = max_utils.initialize_summary_writer(self.config) - writer_thread = threading.Thread( - target=_tensorboard_writer_worker, - args=(writer, self.config), - daemon=True - ) + writer_thread = threading.Thread(target=_tensorboard_writer_worker, args=(writer, self.config), daemon=True) writer_thread.start() unet_state = train_states["unet_state"] vae_state = train_states["vae_state"] @@ -228,14 +224,14 @@ def training_loop(self, p_train_step, pipeline, params, train_states, data_itera for step in np.arange(start_step, self.config.max_train_steps): if self.config.enable_profiler and step == first_profiling_step: max_utils.activate_profiler(self.config) - + next_batch_future = executor.submit(load_next_batch, data_iterator, example_batch, self.config) start_step_time = datetime.datetime.now() with jax.profiler.StepTraceAnnotation("train-new", step_num=step): (unet_state, train_metric, train_rngs) = p_train_step( unet_state, vae_state, text_encoder_state, text_encoder_2_state, example_batch, train_rngs ) - train_metric['scalar']['learning/loss'].block_until_ready() + train_metric["scalar"]["learning/loss"].block_until_ready() samples_count = self.total_train_batch_size * (step + 1) last_step_completion = 
datetime.datetime.now() time_difference = last_step_completion - start_step_time @@ -247,7 +243,7 @@ def training_loop(self, p_train_step, pipeline, params, train_states, data_itera if self.config.write_metrics: write_metrics(writer, local_metrics_file, running_gcs_metrics, train_metric, step, self.config) example_batch = next_batch_future.result() - + if step != 0 and self.config.checkpoint_every != -1 and samples_count % self.config.checkpoint_every == 0: train_states["unet_state"] = unet_state train_states["vae_state"] = vae_state @@ -265,7 +261,7 @@ def training_loop(self, p_train_step, pipeline, params, train_states, data_itera _metrics_queue.put(None) writer_thread.join() if writer: - writer.flush() + writer.flush() train_states["unet_state"] = unet_state train_states["text_encoder_state"] = text_encoder_state train_states["text_encoder_2_state"] = text_encoder_2_state @@ -369,5 +365,5 @@ def compute_loss(state_params): new_state = unet_state.apply_gradients(grads=grad["unet"]) metrics = {"scalar": {"learning/loss": loss}, "scalars": {}} - + return new_state, metrics, new_train_rng From f418a0cfa3d8668b7b68d579c07c13a2a5ac159d Mon Sep 17 00:00:00 2001 From: Kunjan Date: Tue, 17 Jun 2025 16:12:50 +0000 Subject: [PATCH 5/7] Formatting Signed-off-by: Kunjan --- src/maxdiffusion/train_utils.py | 1 - src/maxdiffusion/trainers/sdxl_trainer.py | 1 - 2 files changed, 2 deletions(-) diff --git a/src/maxdiffusion/train_utils.py b/src/maxdiffusion/train_utils.py index 455ef31f2..e3e75971c 100644 --- a/src/maxdiffusion/train_utils.py +++ b/src/maxdiffusion/train_utils.py @@ -17,7 +17,6 @@ import numpy as np import jax import jax.numpy as jnp -import threading import queue from maxdiffusion import max_utils, max_logging diff --git a/src/maxdiffusion/trainers/sdxl_trainer.py b/src/maxdiffusion/trainers/sdxl_trainer.py index 6daa1cc21..4cc81955b 100644 --- a/src/maxdiffusion/trainers/sdxl_trainer.py +++ b/src/maxdiffusion/trainers/sdxl_trainer.py @@ -24,7 +24,6 @@ import 
jax.numpy as jnp from jax.sharding import PartitionSpec as P from flax.linen import partitioning as nn_partitioning -from jax.experimental import io_callback from concurrent.futures import ThreadPoolExecutor from maxdiffusion.trainers.stable_diffusion_trainer import (StableDiffusionTrainer) From 02ca045fc0d66b89deb750004ea85c3b4d16af7d Mon Sep 17 00:00:00 2001 From: Kunjan Date: Wed, 18 Jun 2025 14:18:18 +0000 Subject: [PATCH 6/7] Fallback to regular tfrecord iterator for datasets without all the processed features Signed-off-by: Kunjan --- .../input_pipeline/_tfds_data_processing.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/maxdiffusion/input_pipeline/_tfds_data_processing.py b/src/maxdiffusion/input_pipeline/_tfds_data_processing.py index f81527bd9..73b4f7016 100644 --- a/src/maxdiffusion/input_pipeline/_tfds_data_processing.py +++ b/src/maxdiffusion/input_pipeline/_tfds_data_processing.py @@ -73,7 +73,6 @@ def make_tf_iterator( train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh) return train_iter - def make_cached_tfrecord_iterator( config, dataloading_host_index, @@ -105,6 +104,7 @@ def prepare_sample(features): # This pipeline reads the sharded files and applies the parsing and preparation. filenames = tf.io.gfile.glob(os.path.join(config.train_data_dir, "*")) + train_ds = ( tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE) .shard(num_shards=dataloading_host_count, index=dataloading_host_index) @@ -133,8 +133,16 @@ def make_tfrecord_iterator( check out preparation script maxdiffusion/pedagogical_examples/to_tfrecords.py """ - if config.cache_latents_text_encoder_outputs and os.path.isdir(config.dataset_save_location): + + # set load_tfrecord_cached to True in config to use pre-processed tfrecord dataset. + # pedagogical_examples/dataset_tf_cache_to_tfrecord.py to convert tf preprocessed dataset to tfrecord. 
+ # Datset cache in github runner test doesn't contain all the features since its shared, Use the default tfrecord iterator. + if (config.cache_latents_text_encoder_outputs and + os.path.isdir(config.dataset_save_location) and + hasattr(config, 'load_tfrecord_cached') and + config.load_tfrecord_cached): return make_cached_tfrecord_iterator(config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size) + feature_description = { "moments": tf.io.FixedLenFeature([], tf.string), "clip_embeddings": tf.io.FixedLenFeature([], tf.string), From fe888893a57e55445d79f9e8dd3e9ab74d13a9aa Mon Sep 17 00:00:00 2001 From: Kunjan Date: Wed, 18 Jun 2025 17:33:10 +0000 Subject: [PATCH 7/7] README update --- .github/workflows/UnitTests.yml | 2 +- README.md | 2 +- src/maxdiffusion/input_pipeline/_tfds_data_processing.py | 8 ++++---- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 73dad0756..728d2f2e3 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -50,7 +50,7 @@ jobs: ruff check . - name: PyTest run: | - HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ python3 -m pytest + HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ python3 -m pytest -x # add_pull_ready: # if: github.ref != 'refs/heads/main' # permissions: diff --git a/README.md b/README.md index e14603ac4..68b887f9f 100644 --- a/README.md +++ b/README.md @@ -35,13 +35,13 @@ MaxDiffusion supports * Stable Diffusion 2 base (training and inference) * Stable Diffusion 2.1 (training and inference) * Stable Diffusion XL (training and inference). +* Flux Dev and Schnell (Training and inference). * Stable Diffusion Lightning (inference). * Hyper-SD XL LoRA loading (inference). * Load Multiple LoRA (SDXL inference). * ControlNet inference (Stable Diffusion 1.4 & SDXL). * Dreambooth training support for Stable Diffusion 1.x,2.x. 
-**WARNING: The training code is purely experimental and is under development.**

# Table of Contents

diff --git a/src/maxdiffusion/input_pipeline/_tfds_data_processing.py b/src/maxdiffusion/input_pipeline/_tfds_data_processing.py
index 73b4f7016..6b588ed2d 100644
--- a/src/maxdiffusion/input_pipeline/_tfds_data_processing.py
+++ b/src/maxdiffusion/input_pipeline/_tfds_data_processing.py
@@ -137,10 +137,10 @@ def make_tfrecord_iterator(
   # set load_tfrecord_cached to True in config to use pre-processed tfrecord dataset.
   # pedagogical_examples/dataset_tf_cache_to_tfrecord.py to convert tf preprocessed dataset to tfrecord.
   # Datset cache in github runner test doesn't contain all the features since its shared, Use the default tfrecord iterator.
-  if (config.cache_latents_text_encoder_outputs and
-      os.path.isdir(config.dataset_save_location) and
-      hasattr(config, 'load_tfrecord_cached') and
-      config.load_tfrecord_cached):
+  if (config.cache_latents_text_encoder_outputs
+      and os.path.isdir(config.dataset_save_location)
+      and 'load_tfrecord_cached' in config.get_keys()
+      and config.load_tfrecord_cached):
     return make_cached_tfrecord_iterator(config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size)
 
   feature_description = {