Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
jax>=0.4.30
--extra-index-url https://download.pytorch.org/whl/cpu
jax==0.5.3
jaxlib>=0.4.30
grain-nightly==0.0.10
google-cloud-storage==2.17.0
absl-py
datasets
flax>=0.10.2
optax>=0.2.3
torch==2.5.1
torchvision==0.20.1
torch==2.6.0
torchvision>=0.20.1
ftfy
tensorboard>=2.17.0
tensorboardx==2.6.2.2
Expand Down
4 changes: 2 additions & 2 deletions src/maxdiffusion/generate_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import numpy as np
from PIL import Image
import jax
from jax.sharding import Mesh, PositionalSharding, PartitionSpec as P
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
import jax.numpy as jnp
import flax.linen as nn
from chex import Array
Expand Down Expand Up @@ -343,7 +343,7 @@ def run(config):
config.t5xxl_model_name_or_path, max_length=config.max_sequence_length, use_fast=True
)

encoders_sharding = PositionalSharding(devices_array).replicate()
encoders_sharding = NamedSharding(devices_array, P())
partial_device_put_replicated = functools.partial(device_put_replicated, sharding=encoders_sharding)
clip_text_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), clip_text_encoder.params)
clip_text_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, clip_text_encoder.params)
Expand Down
4 changes: 2 additions & 2 deletions src/maxdiffusion/generate_flux_multi_res.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import numpy as np
from PIL import Image
import jax
from jax.sharding import Mesh, PositionalSharding, PartitionSpec as P
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
import jax.numpy as jnp
import flax.linen as nn
from chex import Array
Expand Down Expand Up @@ -381,7 +381,7 @@ def run(config):
config.t5xxl_model_name_or_path, max_length=config.max_sequence_length, use_fast=True
)

encoders_sharding = PositionalSharding(devices_array).replicate()
encoders_sharding = NamedSharding(devices_array, P())
partial_device_put_replicated = functools.partial(device_put_replicated, sharding=encoders_sharding)
clip_text_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), clip_text_encoder.params)
clip_text_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, clip_text_encoder.params)
Expand Down
53 changes: 51 additions & 2 deletions src/maxdiffusion/input_pipeline/_tfds_data_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from maxdiffusion import multihost_dataloading

AUTOTUNE = tf.data.experimental.AUTOTUNE
AUTOTUNE = tf.data.AUTOTUNE


def load_as_tf_dataset(dataset, global_batch_size, shuffle, dataloading_host_count):
Expand All @@ -31,7 +31,7 @@ def load_as_tf_dataset(dataset, global_batch_size, shuffle, dataloading_host_cou
if shuffle:
tf_dataset = tf_dataset.shuffle(len(tf_dataset))
tf_dataset = tf_dataset.batch(global_batch_size // dataloading_host_count, drop_remainder=True)
tf_dataset = tf_dataset.prefetch(tf.data.experimental.AUTOTUNE)
tf_dataset = tf_dataset.prefetch(AUTOTUNE)
tf_dataset = tf_dataset.repeat(-1)

return tf_dataset
Expand Down Expand Up @@ -74,6 +74,53 @@ def make_tf_iterator(
return train_iter


def make_cached_tfrecord_iterator(
config,
dataloading_host_index,
dataloading_host_count,
mesh,
global_batch_size,
):
"""
New iterator for TFRecords that contain the full 4 pre-computed latents and embeddings:
latents, input_ids, prompt_embeds, and text_embeds.
"""
feature_description = {
"pixel_values": tf.io.FixedLenFeature([], tf.string),
"input_ids": tf.io.FixedLenFeature([], tf.string),
"prompt_embeds": tf.io.FixedLenFeature([], tf.string),
"text_embeds": tf.io.FixedLenFeature([], tf.string),
}

def _parse_tfrecord_fn(example):
return tf.io.parse_single_example(example, feature_description)

def prepare_sample(features):
pixel_values = tf.io.parse_tensor(features["pixel_values"], out_type=tf.float32)
input_ids = tf.io.parse_tensor(features["input_ids"], out_type=tf.int32)
prompt_embeds = tf.io.parse_tensor(features["prompt_embeds"], out_type=tf.float32)
text_embeds = tf.io.parse_tensor(features["text_embeds"], out_type=tf.float32)

return {"pixel_values": pixel_values, "input_ids": input_ids, "prompt_embeds": prompt_embeds, "text_embeds": text_embeds}

# This pipeline reads the sharded files and applies the parsing and preparation.
filenames = tf.io.gfile.glob(os.path.join(config.train_data_dir, "*"))
train_ds = (
tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
.shard(num_shards=dataloading_host_count, index=dataloading_host_index)
.map(_parse_tfrecord_fn, num_parallel_calls=AUTOTUNE)
.map(prepare_sample, num_parallel_calls=AUTOTUNE)
.shuffle(global_batch_size * 10)
.batch(global_batch_size // dataloading_host_count, drop_remainder=True)
.repeat(-1)
.prefetch(AUTOTUNE)
)

# This wraps the tf.data.Dataset for use in the multi-host JAX environment.
train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh)
return train_iter


# TODO - https://github.com/google/array_record/blob/main/beam/examples/example_gcs_conversion.py
def make_tfrecord_iterator(
config,
Expand All @@ -86,6 +133,8 @@ def make_tfrecord_iterator(
check out preparation script
maxdiffusion/pedagogical_examples/to_tfrecords.py
"""
if config.cache_latents_text_encoder_outputs and os.path.isdir(config.dataset_save_location):
return make_cached_tfrecord_iterator(config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size)
feature_description = {
"moments": tf.io.FixedLenFeature([], tf.string),
"clip_embeddings": tf.io.FixedLenFeature([], tf.string),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import os
import argparse
import tensorflow as tf
from datasets import load_from_disk
import numpy as np


def _bytes_feature(value):
"""Returns a bytes_list from a serialized tensor."""
if not isinstance(value, tf.Tensor):
value = tf.constant(value)
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value.numpy()]))


def create_4_feature_example(record):
"""Creates a tf.train.Example proto with all 4 pre-computed features."""
pixel_values = tf.io.serialize_tensor(record["pixel_values"])
input_ids = tf.io.serialize_tensor(record["input_ids"])
prompt_embeds = tf.io.serialize_tensor(record["prompt_embeds"])
text_embeds = tf.io.serialize_tensor(record["text_embeds"])

feature = {
"pixel_values": _bytes_feature(pixel_values),
"input_ids": _bytes_feature(input_ids),
"prompt_embeds": _bytes_feature(prompt_embeds),
"text_embeds": _bytes_feature(text_embeds),
}
return tf.train.Example(features=tf.train.Features(feature=feature))


def run(args):
"""Main processing function."""
# Load the cached dataset from the location specified in the arguments
print(f"Loading processed dataset from disk: {args.dataset_save_location}")
processed_ds = load_from_disk(args.dataset_save_location)
print("Dataset loaded successfully.")

# Get sharding and output directory from the arguments
tfrecords_dir = args.tfrecords_dir
num_shards = args.data_num_shards
os.makedirs(tfrecords_dir, exist_ok=True)

writers = [
tf.io.TFRecordWriter(os.path.join(tfrecords_dir, f"shard-{i:05d}-of-{num_shards:05d}.tfrecord"))
for i in range(num_shards)
]

print(f"Writing {len(processed_ds)} records into {num_shards} TFRecord shards...")

for i, record in enumerate(processed_ds):
# Create a new record with explicit casting for float types
casted_record = {
"pixel_values": np.float32(record["pixel_values"]),
"input_ids": record["input_ids"], # This is already integer type
"prompt_embeds": np.float32(record["prompt_embeds"]),
"text_embeds": np.float32(record["text_embeds"]),
}

writer_index = i % num_shards
tf_example = create_4_feature_example(casted_record)
writers[writer_index].write(tf_example.SerializeToString())

for writer in writers:
writer.close()

print("TFRecord conversion complete.")


def main():
"""Parses command-line arguments and runs the conversion."""
parser = argparse.ArgumentParser(description="Convert a cached Hugging Face dataset to sharded TFRecords.")
parser.add_argument(
"--dataset_save_location",
type=str,
required=False,
default="/tmp/pokemon-gpt4-captions_xl",
help="Path to the cached dataset created by the training pipeline.",
)
parser.add_argument(
"--tfrecords_dir",
type=str,
required=False,
default="/tmp/cached_pokemon_tfrecords_sharded",
help="Output directory to save the sharded TFRecord files.",
)
parser.add_argument(
"--data_num_shards", type=int, default=128, help="Number of shards to split the TFRecord dataset into."
)

args = parser.parse_args()
run(args)


if __name__ == "__main__":
main()
6 changes: 3 additions & 3 deletions src/maxdiffusion/pipelines/wan/wan_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PositionalSharding, PartitionSpec as P
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
import flax
import flax.linen as nn
from flax import nnx
Expand Down Expand Up @@ -195,7 +195,7 @@ def load_vae(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: H
# This replaces random params with the model.
params = load_wan_vae(config.pretrained_model_name_or_path, params, "cpu")
params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params)
params = jax.device_put(params, PositionalSharding(devices_array).replicate())
params = jax.device_put(params, NamedSharding(devices_array, P()))
wan_vae = nnx.merge(graphdef, params)
p_create_sharded_logical_model = partial(create_sharded_logical_model, logical_axis_rules=config.logical_axis_rules)
# Shard
Expand Down Expand Up @@ -395,7 +395,7 @@ def __call__(
num_channels_latents=num_channel_latents,
)

data_sharding = PositionalSharding(self.devices_array).replicate()
data_sharding = NamedSharding(self.devices_array, P())
if len(prompt) % jax.device_count() == 0:
data_sharding = jax.sharding.NamedSharding(self.mesh, P(*self.config.data_sharding))

Expand Down
46 changes: 35 additions & 11 deletions src/maxdiffusion/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import numpy as np
import jax
import jax.numpy as jnp
import queue

from maxdiffusion import max_utils, max_logging

Expand Down Expand Up @@ -68,10 +69,31 @@ def record_scalar_metrics(metrics, step_time_delta, per_device_tflops, lr):
metrics["scalar"].update({"learning/current_learning_rate": lr})


_metrics_queue = queue.Queue()
_buffered_step = None
_buffered_metrics = None


def _tensorboard_writer_worker(writer, config):
"""
A worker function that runs in a separate thread.
It waits for metrics to appear in the queue and writes them to TensorBoard.
"""
while True:
data = _metrics_queue.get()
if data is None:
break
metrics, step = data
if jax.process_index() == 0:
for metric_name in metrics.get("scalar", []):
writer.add_scalar(metric_name, np.array(metrics["scalar"][metric_name]), step)
for metric_name in metrics.get("scalars", []):
writer.add_scalars(metric_name, metrics["scalars"][metric_name], step)

if step % config.log_period == 0:
writer.flush()


def write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step, config):
"""Entry point for all metrics writing in Train's Main.
TODO: would be better as a Class in the future (that initialized all state!)
Expand All @@ -81,16 +103,18 @@ def write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step
The logic is that this ensures that Jax is able to queues train_steps and we
don't block when turning "lazy" Jax arrays into real Python numbers.
"""
global _buffered_step, _buffered_metrics
global _buffered_step, _buffered_metrics, _metrics_queue

if metrics:
_metrics_queue.put((metrics, step))
if _buffered_metrics is not None:
if config.metrics_file:
max_utils.write_metrics_locally(_buffered_metrics, _buffered_step, config, local_metrics_file)

if _buffered_step is None:
raise ValueError(f"When writing metrics, {_buffered_step=} was none")
write_metrics_to_tensorboard(writer, _buffered_metrics, _buffered_step, config)

if config.metrics_file:
max_utils.write_metrics_locally(_buffered_metrics, _buffered_step, config, local_metrics_file)

if config.gcs_metrics and jax.process_index() == 0:
running_gcs_metrics = max_utils.write_metrics_for_gcs(_buffered_metrics, _buffered_step, config, running_gcs_metrics)

Expand All @@ -100,13 +124,6 @@ def write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step

def write_metrics_to_tensorboard(writer, metrics, step, config):
"""Writes metrics to tensorboard"""
if jax.process_index() == 0:
for metric_name in metrics.get("scalar", []):
writer.add_scalar(metric_name, np.array(metrics["scalar"][metric_name]), step)
for metric_name in metrics.get("scalars", []):
writer.add_scalars(metric_name, metrics["scalars"][metric_name], step)

full_log = step % config.log_period == 0
if jax.process_index() == 0:
max_logging.log(
"completed step: {}, seconds: {:.3f}, TFLOP/s/device: {:.3f}, loss: {:.3f}".format(
Expand All @@ -116,6 +133,13 @@ def write_metrics_to_tensorboard(writer, metrics, step, config):
float(metrics["scalar"]["learning/loss"]),
)
)
if jax.process_index() == 0:
for metric_name in metrics.get("scalar", []):
writer.add_scalar(metric_name, np.array(metrics["scalar"][metric_name]), step)
for metric_name in metrics.get("scalars", []):
writer.add_scalars(metric_name, metrics["scalars"][metric_name], step)

full_log = step % config.log_period == 0

if full_log and jax.process_index() == 0:
max_logging.log(f"To see full metrics 'tensorboard --logdir={config.tensorboard_dir}'")
Expand Down
4 changes: 2 additions & 2 deletions src/maxdiffusion/trainers/flux_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import PositionalSharding, PartitionSpec as P
from jax.sharding import NamedSharding, PartitionSpec as P
from flax.linen import partitioning as nn_partitioning
from maxdiffusion.checkpointing.flux_checkpointer import (
FluxCheckpointer,
Expand Down Expand Up @@ -87,7 +87,7 @@ def start_training(self):
state_shardings = {}

# move params to accelerator
encoders_sharding = PositionalSharding(self.devices_array).replicate()
encoders_sharding = NamedSharding(self.mesh, P(None))
partial_device_put_replicated = partial(max_utils.device_put_replicated, sharding=encoders_sharding)
pipeline.clip_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), pipeline.clip_encoder.params)
pipeline.clip_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, pipeline.clip_encoder.params)
Expand Down
Loading
Loading