2 changes: 1 addition & 1 deletion .github/workflows/UnitTests.yml
@@ -50,7 +50,7 @@ jobs:
ruff check .
- name: PyTest
run: |
HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ python3 -m pytest
HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ python3 -m pytest -x
# add_pull_ready:
# if: github.ref != 'refs/heads/main'
# permissions:
2 changes: 1 addition & 1 deletion README.md
@@ -35,13 +35,13 @@ MaxDiffusion supports
* Stable Diffusion 2 base (training and inference)
* Stable Diffusion 2.1 (training and inference)
* Stable Diffusion XL (training and inference).
* Flux Dev and Schnell (training and inference).
* Stable Diffusion Lightning (inference).
* Hyper-SD XL LoRA loading (inference).
* Load Multiple LoRA (SDXL inference).
* ControlNet inference (Stable Diffusion 1.4 & SDXL).
* Dreambooth training support for Stable Diffusion 1.x, 2.x.

**WARNING: The training code is purely experimental and is under development.**

# Table of Contents

7 changes: 4 additions & 3 deletions requirements.txt
@@ -1,13 +1,14 @@
jax>=0.4.30
--extra-index-url https://download.pytorch.org/whl/cpu
jax==0.5.3
jaxlib>=0.4.30
grain-nightly==0.0.10
google-cloud-storage==2.17.0
absl-py
datasets
flax>=0.10.2
optax>=0.2.3
torch==2.5.1
torchvision==0.20.1
torch==2.6.0
torchvision>=0.20.1
ftfy
tensorboard>=2.17.0
tensorboardx==2.6.2.2
4 changes: 2 additions & 2 deletions src/maxdiffusion/generate_flux.py
@@ -23,7 +23,7 @@
import numpy as np
from PIL import Image
import jax
from jax.sharding import Mesh, PositionalSharding, PartitionSpec as P
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
import jax.numpy as jnp
import flax.linen as nn
from chex import Array
@@ -343,7 +343,7 @@ def run(config):
config.t5xxl_model_name_or_path, max_length=config.max_sequence_length, use_fast=True
)

encoders_sharding = PositionalSharding(devices_array).replicate()
encoders_sharding = NamedSharding(devices_array, P())
partial_device_put_replicated = functools.partial(device_put_replicated, sharding=encoders_sharding)
clip_text_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), clip_text_encoder.params)
clip_text_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, clip_text_encoder.params)
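jax.sharding.PositionalSharding has been deprecated in recent JAX releases in favor of NamedSharding, and an empty PartitionSpec expresses the same fully replicated placement; note that NamedSharding expects a Mesh as its first argument. A minimal sketch of the replacement pattern, assuming a one-axis mesh built from all local devices (the axis name and array shape are illustrative, not taken from the repo's config):

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a one-axis mesh over all available devices (illustrative only).
devices_array = np.array(jax.devices())
mesh = Mesh(devices_array, axis_names=("data",))

# P() names no mesh axes, so every dimension is unsharded and the value is
# fully replicated, matching the placement PositionalSharding(...).replicate() used to give.
replicated = NamedSharding(mesh, P())
params = jax.device_put(jnp.ones((4, 4), dtype=jnp.bfloat16), replicated)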
4 changes: 2 additions & 2 deletions src/maxdiffusion/generate_flux_multi_res.py
@@ -23,7 +23,7 @@
import numpy as np
from PIL import Image
import jax
from jax.sharding import Mesh, PositionalSharding, PartitionSpec as P
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
import jax.numpy as jnp
import flax.linen as nn
from chex import Array
@@ -381,7 +381,7 @@ def run(config):
config.t5xxl_model_name_or_path, max_length=config.max_sequence_length, use_fast=True
)

encoders_sharding = PositionalSharding(devices_array).replicate()
encoders_sharding = NamedSharding(devices_array, P())
partial_device_put_replicated = functools.partial(device_put_replicated, sharding=encoders_sharding)
clip_text_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), clip_text_encoder.params)
clip_text_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, clip_text_encoder.params)
61 changes: 59 additions & 2 deletions src/maxdiffusion/input_pipeline/_tfds_data_processing.py
@@ -21,7 +21,7 @@

from maxdiffusion import multihost_dataloading

AUTOTUNE = tf.data.experimental.AUTOTUNE
AUTOTUNE = tf.data.AUTOTUNE


def load_as_tf_dataset(dataset, global_batch_size, shuffle, dataloading_host_count):
@@ -31,7 +31,7 @@ def load_as_tf_dataset(dataset, global_batch_size, shuffle, dataloading_host_count):
  if shuffle:
    tf_dataset = tf_dataset.shuffle(len(tf_dataset))
  tf_dataset = tf_dataset.batch(global_batch_size // dataloading_host_count, drop_remainder=True)
  tf_dataset = tf_dataset.prefetch(AUTOTUNE)
  tf_dataset = tf_dataset.repeat(-1)

  return tf_dataset
@@ -73,6 +73,53 @@ def make_tf_iterator(
  train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh)
  return train_iter

def make_cached_tfrecord_iterator(
    config,
    dataloading_host_index,
    dataloading_host_count,
    mesh,
    global_batch_size,
):
  """
  Iterator for TFRecords that contain the full 4 pre-computed latents and embeddings:
  pixel_values (latents), input_ids, prompt_embeds, and text_embeds.
  """
  feature_description = {
      "pixel_values": tf.io.FixedLenFeature([], tf.string),
      "input_ids": tf.io.FixedLenFeature([], tf.string),
      "prompt_embeds": tf.io.FixedLenFeature([], tf.string),
      "text_embeds": tf.io.FixedLenFeature([], tf.string),
  }

  def _parse_tfrecord_fn(example):
    return tf.io.parse_single_example(example, feature_description)

  def prepare_sample(features):
    pixel_values = tf.io.parse_tensor(features["pixel_values"], out_type=tf.float32)
    input_ids = tf.io.parse_tensor(features["input_ids"], out_type=tf.int32)
    prompt_embeds = tf.io.parse_tensor(features["prompt_embeds"], out_type=tf.float32)
    text_embeds = tf.io.parse_tensor(features["text_embeds"], out_type=tf.float32)

    return {"pixel_values": pixel_values, "input_ids": input_ids, "prompt_embeds": prompt_embeds, "text_embeds": text_embeds}

  # This pipeline reads the sharded files and applies the parsing and preparation.
  filenames = tf.io.gfile.glob(os.path.join(config.train_data_dir, "*"))

  train_ds = (
      tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
      .shard(num_shards=dataloading_host_count, index=dataloading_host_index)
      .map(_parse_tfrecord_fn, num_parallel_calls=AUTOTUNE)
      .map(prepare_sample, num_parallel_calls=AUTOTUNE)
      .shuffle(global_batch_size * 10)
      .batch(global_batch_size // dataloading_host_count, drop_remainder=True)
      .repeat(-1)
      .prefetch(AUTOTUNE)
  )

  # This wraps the tf.data.Dataset for use in the multi-host JAX environment.
  train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh)
  return train_iter


# TODO - https://github.com/google/array_record/blob/main/beam/examples/example_gcs_conversion.py
def make_tfrecord_iterator(
@@ -86,6 +133,16 @@ def make_tfrecord_iterator(
  check out preparation script
  maxdiffusion/pedagogical_examples/to_tfrecords.py
  """

  # Set load_tfrecord_cached to True in config to use a pre-processed TFRecord dataset.
  # Use pedagogical_examples/dataset_tf_cache_to_tfrecord.py to convert a tf-preprocessed cached dataset to TFRecords.
  # The dataset cache in the GitHub runner test doesn't contain all the features since it's shared, so use the default tfrecord iterator there.
  if (config.cache_latents_text_encoder_outputs
      and os.path.isdir(config.dataset_save_location)
      and 'load_tfrecord_cached' in config.get_keys()
      and config.load_tfrecord_cached):
    return make_cached_tfrecord_iterator(config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size)

  feature_description = {
      "moments": tf.io.FixedLenFeature([], tf.string),
      "clip_embeddings": tf.io.FixedLenFeature([], tf.string),
95 changes: 95 additions & 0 deletions src/maxdiffusion/pedagogical_examples/dataset_tf_cache_to_tfrecord.py
@@ -0,0 +1,95 @@
import os
import argparse
import tensorflow as tf
from datasets import load_from_disk
import numpy as np


def _bytes_feature(value):
  """Returns a bytes_list from a serialized tensor."""
  if not isinstance(value, tf.Tensor):
    value = tf.constant(value)
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value.numpy()]))


def create_4_feature_example(record):
  """Creates a tf.train.Example proto with all 4 pre-computed features."""
  pixel_values = tf.io.serialize_tensor(record["pixel_values"])
  input_ids = tf.io.serialize_tensor(record["input_ids"])
  prompt_embeds = tf.io.serialize_tensor(record["prompt_embeds"])
  text_embeds = tf.io.serialize_tensor(record["text_embeds"])

  feature = {
      "pixel_values": _bytes_feature(pixel_values),
      "input_ids": _bytes_feature(input_ids),
      "prompt_embeds": _bytes_feature(prompt_embeds),
      "text_embeds": _bytes_feature(text_embeds),
  }
  return tf.train.Example(features=tf.train.Features(feature=feature))


def run(args):
  """Main processing function."""
  # Load the cached dataset from the location specified in the arguments
  print(f"Loading processed dataset from disk: {args.dataset_save_location}")
  processed_ds = load_from_disk(args.dataset_save_location)
  print("Dataset loaded successfully.")

  # Get sharding and output directory from the arguments
  tfrecords_dir = args.tfrecords_dir
  num_shards = args.data_num_shards
  os.makedirs(tfrecords_dir, exist_ok=True)

  writers = [
      tf.io.TFRecordWriter(os.path.join(tfrecords_dir, f"shard-{i:05d}-of-{num_shards:05d}.tfrecord"))
      for i in range(num_shards)
  ]

  print(f"Writing {len(processed_ds)} records into {num_shards} TFRecord shards...")

  for i, record in enumerate(processed_ds):
    # Create a new record with explicit casting for float types
    casted_record = {
        "pixel_values": np.float32(record["pixel_values"]),
        "input_ids": record["input_ids"],  # This is already integer type
        "prompt_embeds": np.float32(record["prompt_embeds"]),
        "text_embeds": np.float32(record["text_embeds"]),
    }

    writer_index = i % num_shards
    tf_example = create_4_feature_example(casted_record)
    writers[writer_index].write(tf_example.SerializeToString())

  for writer in writers:
    writer.close()

  print("TFRecord conversion complete.")


def main():
  """Parses command-line arguments and runs the conversion."""
  parser = argparse.ArgumentParser(description="Convert a cached Hugging Face dataset to sharded TFRecords.")
  parser.add_argument(
      "--dataset_save_location",
      type=str,
      required=False,
      default="/tmp/pokemon-gpt4-captions_xl",
      help="Path to the cached dataset created by the training pipeline.",
  )
  parser.add_argument(
      "--tfrecords_dir",
      type=str,
      required=False,
      default="/tmp/cached_pokemon_tfrecords_sharded",
      help="Output directory to save the sharded TFRecord files.",
  )
  parser.add_argument(
      "--data_num_shards", type=int, default=128, help="Number of shards to split the TFRecord dataset into."
  )

  args = parser.parse_args()
  run(args)


if __name__ == "__main__":
  main()
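One way to sanity-check the conversion is to read a record back with the same parsing calls that make_cached_tfrecord_iterator uses. A minimal sketch, assuming the script above was run with its default --tfrecords_dir; the verification itself is not part of this change:

# Hypothetical check, not part of the PR:
#   python dataset_tf_cache_to_tfrecord.py --data_num_shards 128
import tensorflow as tf

files = tf.io.gfile.glob("/tmp/cached_pokemon_tfrecords_sharded/*.tfrecord")
raw = next(iter(tf.data.TFRecordDataset(files[:1])))
features = tf.io.parse_single_example(
    raw,
    {
        "pixel_values": tf.io.FixedLenFeature([], tf.string),
        "input_ids": tf.io.FixedLenFeature([], tf.string),
        "prompt_embeds": tf.io.FixedLenFeature([], tf.string),
        "text_embeds": tf.io.FixedLenFeature([], tf.string),
    },
)
# parse_tensor undoes serialize_tensor; dtypes must match what the writer stored.
latents = tf.io.parse_tensor(features["pixel_values"], out_type=tf.float32)
print(latents.shape)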
6 changes: 3 additions & 3 deletions src/maxdiffusion/pipelines/wan/wan_pipeline.py
@@ -17,7 +17,7 @@
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PositionalSharding, PartitionSpec as P
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
import flax
import flax.linen as nn
from flax import nnx
@@ -195,7 +195,7 @@ def load_vae(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: H
# This replaces random params with the model.
params = load_wan_vae(config.pretrained_model_name_or_path, params, "cpu")
params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params)
params = jax.device_put(params, PositionalSharding(devices_array).replicate())
params = jax.device_put(params, NamedSharding(devices_array, P()))
wan_vae = nnx.merge(graphdef, params)
p_create_sharded_logical_model = partial(create_sharded_logical_model, logical_axis_rules=config.logical_axis_rules)
# Shard
@@ -395,7 +395,7 @@ def __call__(
num_channels_latents=num_channel_latents,
)

data_sharding = PositionalSharding(self.devices_array).replicate()
data_sharding = NamedSharding(self.devices_array, P())
if len(prompt) % jax.device_count() == 0:
data_sharding = jax.sharding.NamedSharding(self.mesh, P(*self.config.data_sharding))

46 changes: 35 additions & 11 deletions src/maxdiffusion/train_utils.py
@@ -17,6 +17,7 @@
import numpy as np
import jax
import jax.numpy as jnp
import queue

from maxdiffusion import max_utils, max_logging

@@ -68,10 +69,31 @@ def record_scalar_metrics(metrics, step_time_delta, per_device_tflops, lr):
metrics["scalar"].update({"learning/current_learning_rate": lr})


_metrics_queue = queue.Queue()
_buffered_step = None
_buffered_metrics = None


def _tensorboard_writer_worker(writer, config):
"""
A worker function that runs in a separate thread.
It waits for metrics to appear in the queue and writes them to TensorBoard.
"""
while True:
data = _metrics_queue.get()
if data is None:
break
metrics, step = data
if jax.process_index() == 0:
for metric_name in metrics.get("scalar", []):
writer.add_scalar(metric_name, np.array(metrics["scalar"][metric_name]), step)
for metric_name in metrics.get("scalars", []):
writer.add_scalars(metric_name, metrics["scalars"][metric_name], step)

if step % config.log_period == 0:
writer.flush()


def write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step, config):
"""Entry point for all metrics writing in Train's Main.
TODO: would be better as a Class in the future (that initialized all state!)
Expand All @@ -81,16 +103,18 @@ def write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step
The logic is that this ensures that Jax is able to queues train_steps and we
don't block when turning "lazy" Jax arrays into real Python numbers.
"""
  global _buffered_step, _buffered_metrics
  global _buffered_step, _buffered_metrics, _metrics_queue

  if metrics:
    _metrics_queue.put((metrics, step))
  if _buffered_metrics is not None:
    if config.metrics_file:
      max_utils.write_metrics_locally(_buffered_metrics, _buffered_step, config, local_metrics_file)

    if _buffered_step is None:
      raise ValueError(f"When writing metrics, {_buffered_step=} was none")
    write_metrics_to_tensorboard(writer, _buffered_metrics, _buffered_step, config)

    if config.metrics_file:
      max_utils.write_metrics_locally(_buffered_metrics, _buffered_step, config, local_metrics_file)

    if config.gcs_metrics and jax.process_index() == 0:
      running_gcs_metrics = max_utils.write_metrics_for_gcs(_buffered_metrics, _buffered_step, config, running_gcs_metrics)

@@ -100,13 +124,6 @@ def write_metrics_to_tensorboard(writer, metrics, step, config):

def write_metrics_to_tensorboard(writer, metrics, step, config):
  """Writes metrics to tensorboard"""
  if jax.process_index() == 0:
    for metric_name in metrics.get("scalar", []):
      writer.add_scalar(metric_name, np.array(metrics["scalar"][metric_name]), step)
    for metric_name in metrics.get("scalars", []):
      writer.add_scalars(metric_name, metrics["scalars"][metric_name], step)

  full_log = step % config.log_period == 0
  if jax.process_index() == 0:
    max_logging.log(
        "completed step: {}, seconds: {:.3f}, TFLOP/s/device: {:.3f}, loss: {:.3f}".format(
@@ -116,6 +133,13 @@ def write_metrics_to_tensorboard(writer, metrics, step, config):
            float(metrics["scalar"]["learning/loss"]),
        )
    )
  if jax.process_index() == 0:
    for metric_name in metrics.get("scalar", []):
      writer.add_scalar(metric_name, np.array(metrics["scalar"][metric_name]), step)
    for metric_name in metrics.get("scalars", []):
      writer.add_scalars(metric_name, metrics["scalars"][metric_name], step)

  full_log = step % config.log_period == 0

  if full_log and jax.process_index() == 0:
    max_logging.log(f"To see full metrics 'tensorboard --logdir={config.tensorboard_dir}'")
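The hunk above defines _tensorboard_writer_worker and enqueues metrics from write_metrics, but the thread start-up and shutdown are not shown in this diff. A minimal sketch of how such a worker could be wired up; everything here other than _tensorboard_writer_worker and _metrics_queue is an assumption, not code from the PR:

import threading

# Hypothetical wiring (not shown in this diff): start the writer thread once at
# startup, then stop it with the None sentinel the worker loop checks for.
writer_thread = threading.Thread(target=_tensorboard_writer_worker, args=(writer, config), daemon=True)
writer_thread.start()

# ... the training loop repeatedly calls write_metrics(...), which enqueues (metrics, step) ...

_metrics_queue.put(None)  # sentinel: tells the worker to exit its while loop
writer_thread.join()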
4 changes: 2 additions & 2 deletions src/maxdiffusion/trainers/flux_trainer.py
@@ -21,7 +21,7 @@
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import PositionalSharding, PartitionSpec as P
from jax.sharding import NamedSharding, PartitionSpec as P
from flax.linen import partitioning as nn_partitioning
from maxdiffusion.checkpointing.flux_checkpointer import (
FluxCheckpointer,
@@ -87,7 +87,7 @@ def start_training(self):
state_shardings = {}

# move params to accelerator
encoders_sharding = PositionalSharding(self.devices_array).replicate()
encoders_sharding = NamedSharding(self.mesh, P(None))
partial_device_put_replicated = partial(max_utils.device_put_replicated, sharding=encoders_sharding)
pipeline.clip_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), pipeline.clip_encoder.params)
pipeline.clip_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, pipeline.clip_encoder.params)
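Here the replicated placement is written as NamedSharding(self.mesh, P(None)) rather than P(); both specs leave every dimension unsharded, so the encoder parameters end up fully replicated across the mesh, matching the behavior of the removed PositionalSharding(...).replicate() call.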