
Commit 9238bc9

enable grain and tfrecord for mlperf dataset in SD base_2_base training (#130)

1 parent 51e1db1

9 files changed: 242 additions & 19 deletions

.github/workflows/UnitTests.yml

Lines changed: 2 additions & 3 deletions
@@ -39,10 +39,9 @@ jobs:
       - name: Install dependencies
         run: |
           pip install -e .
-          pip install -U -r requirements.txt
-          export PATH=$PATH:$HOME/.local/bin
           pip uninstall jax jaxlib libtpu-nightly libtpu -y
-          pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
+          bash setup.sh MODE=stable
+          export PATH=$PATH:$HOME/.local/bin
           pip install ruff
           pip install isort
           pip install pytest

setup.sh

Lines changed: 11 additions & 1 deletion
@@ -23,6 +23,16 @@
 set -e
 export DEBIAN_FRONTEND=noninteractive

+(sudo bash || bash) <<'EOF'
+apt update && \
+apt install -y numactl lsb-release gnupg curl net-tools iproute2 procps lsof git ethtool && \
+export GCSFUSE_REPO=gcsfuse-`lsb_release -c -s`
+echo "deb https://packages.cloud.google.com/apt $GCSFUSE_REPO main" | tee /etc/apt/sources.list.d/gcsfuse.list
+curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key add -
+apt update -y && apt -y install gcsfuse
+rm -rf /var/lib/apt/lists/*
+EOF
+
 # Set environment variables from command line arguments
 for ARGUMENT in "$@"; do
   IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
@@ -97,4 +107,4 @@ else
 fi

 # Install dependencies from requirements.txt
-pip3 install -U -r requirements.txt
+pip3 install -U -r requirements.txt

setup_gcsfuse.sh

Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
+#!/bin/bash
+
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Description:
+# bash setup_gcsfuse.sh DATASET_GCS_BUCKET=maxdiffusion-github-runner-test-assets MOUNT_PATH=/tmp/gcsfuse
+
+set -e -x
+
+# Set environment variables
+for ARGUMENT in "$@"; do
+  IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
+  export "$KEY"="$VALUE"
+  echo "$KEY"="$VALUE"
+done
+
+if [[ -z ${DATASET_GCS_BUCKET} || -z ${MOUNT_PATH} ]]; then
+  echo "Please set arguments: DATASET_GCS_BUCKET and MOUNT_PATH"
+  exit 1
+fi
+
+if [[ "$DATASET_GCS_BUCKET" =~ gs:\/\/ ]] ; then
+  DATASET_GCS_BUCKET="${DATASET_GCS_BUCKET/gs:\/\//}"
+  echo "Removed gs:// from GCS bucket name, GCS bucket is $DATASET_GCS_BUCKET"
+fi
+
+if [[ -d ${MOUNT_PATH} ]]; then
+  echo "$MOUNT_PATH exists, removing..."
+  fusermount -u $MOUNT_PATH || rm -rf $MOUNT_PATH
+fi
+
+mkdir -p $MOUNT_PATH
+
+# See https://cloud.google.com/storage/docs/gcsfuse-cli for all configurable options of the gcsfuse CLI.
+# Grain uses _PROCESS_MANAGEMENT_MAX_THREADS = 64 (https://github.com/google/grain/blob/main/grain/_src/python/grain_pool.py)
+# Please make sure max-conns-per-host > grain_worker_count * _PROCESS_MANAGEMENT_MAX_THREADS.
+
+gcsfuse -o ro --implicit-dirs --http-client-timeout=5s --max-conns-per-host=2000 \
+  --debug_fuse_errors --debug_fuse --debug_gcs --debug_invariants --debug_mutex \
+  --log-file=$HOME/gcsfuse.json "$DATASET_GCS_BUCKET" "$MOUNT_PATH"
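
As a worked check on the two comments above: the config's default grain_worker_count is 4 (see base_2_base.yml below), so the requirement is max-conns-per-host > 4 * 64 = 256; the 2000 passed here leaves ample headroom even for much larger worker counts.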

src/maxdiffusion/configs/base_2_base.yml

Lines changed: 2 additions & 0 deletions
@@ -158,6 +158,8 @@ jax_cache_dir: ''
 hf_data_dir: ''
 hf_train_files: ''
 hf_access_token: ''
+grain_train_files: ''
+grain_worker_count: 4
 image_column: 'image'
 caption_column: 'text'
 resolution: 512
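
For reference, a minimal sketch of overriding the two new keys programmatically, mirroring the pyconfig.initialize call in the unit test below; the arrayrecord path is a placeholder, not a real bucket layout:

import os
from maxdiffusion import pyconfig

# Hypothetical override of the new grain keys (unittest=True as in the tests).
pyconfig.initialize(
    [
        None,
        os.path.join("src", "maxdiffusion", "configs", "base_2_base.yml"),
        "dataset_type=grain",
        "grain_train_files=/tmp/gcsfuse/some/path/*.arrayrecord",  # illustrative path
        "grain_worker_count=4",
    ],
    unittest=True,
)
config = pyconfig.config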
src/maxdiffusion/input_pipeline/_grain_data_processing.py

Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
+"""
+ Copyright 2024 Google LLC
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+      https://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+import dataclasses
+import glob
+import tensorflow as tf
+import numpy as np
+import grain.python as grain
+
+from maxdiffusion import multihost_dataloading
+
+
+def make_grain_iterator(
+    config,
+    dataloading_host_index,
+    dataloading_host_count,
+    mesh,
+    global_batch_size,
+):
+  """Use Grain data input pipeline with ArrayRecord data format"""
+  data_files = glob.glob(config.grain_train_files)
+  data_source = grain.ArrayRecordDataSource(data_files)
+
+  operations = []
+  operations.append(ParseFeatures())
+  operations.append(grain.Batch(batch_size=global_batch_size // dataloading_host_count, drop_remainder=True))
+
+  index_sampler = grain.IndexSampler(
+      num_records=len(data_source),
+      num_epochs=None,
+      shard_options=grain.ShardOptions(
+          shard_index=dataloading_host_index, shard_count=dataloading_host_count, drop_remainder=True
+      ),
+      shuffle=True,
+      seed=config.seed,
+  )
+
+  dataloader = grain.DataLoader(
+      data_source=data_source,
+      operations=operations,
+      sampler=index_sampler,
+      worker_count=config.grain_worker_count,
+  )
+
+  data_iter = multihost_dataloading.MultiHostDataLoadIterator(dataloader, mesh)
+  return data_iter
+
+
+@dataclasses.dataclass
+class ParseFeatures(grain.MapTransform):
+  """Parse serialized example"""
+
+  def __init__(self):
+    self.feature_description = {
+        "moments": tf.io.FixedLenFeature([], tf.string),
+        "clip_embeddings": tf.io.FixedLenFeature([], tf.string),
+    }
+
+  def map(self, example):
+    def _parse(example):
+      features = tf.io.parse_single_example(example, self.feature_description)
+      moments = tf.io.parse_tensor(np.asarray(features["moments"]), out_type=tf.float32)
+      clip_embeddings = tf.io.parse_tensor(np.asarray(features["clip_embeddings"]), out_type=tf.float32)
+      return {"pixel_values": moments, "input_ids": clip_embeddings}
+
+    return _parse(example)
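
A minimal single-host sketch of the pipeline this file builds, useful for local debugging; the file path, seed, and batch size are illustrative, and worker_count=0 keeps everything in-process:

import grain.python as grain
from maxdiffusion.input_pipeline._grain_data_processing import ParseFeatures

data_source = grain.ArrayRecordDataSource(["/tmp/sample.arrayrecord"])  # placeholder file
sampler = grain.IndexSampler(
    num_records=len(data_source),
    num_epochs=1,
    shard_options=grain.ShardOptions(shard_index=0, shard_count=1, drop_remainder=True),
    shuffle=True,
    seed=0,
)
loader = grain.DataLoader(
    data_source=data_source,
    operations=[ParseFeatures(), grain.Batch(batch_size=8, drop_remainder=True)],
    sampler=sampler,
    worker_count=0,  # no worker pool, runs in the main process
)
batch = next(iter(loader))  # {"pixel_values": ..., "input_ids": ...}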

src/maxdiffusion/input_pipeline/_tfds_data_processing.py

Lines changed: 7 additions & 8 deletions
@@ -87,30 +87,29 @@ def make_tfrecord_iterator(
       maxdiffusion/pedagogical_examples/to_tfrecords.py
   """
   feature_description = {
-      "latents": tf.io.FixedLenFeature([], tf.string),
-      "hidden_states": tf.io.FixedLenFeature([], tf.string),
+      "moments": tf.io.FixedLenFeature([], tf.string),
+      "clip_embeddings": tf.io.FixedLenFeature([], tf.string),
   }

   def _parse_tfrecord_fn(example):
     return tf.io.parse_single_example(example, feature_description)

   def prepare_sample(features):
-    latents = tf.io.parse_tensor(tnp.asarray(features["latents"]), out_type=tf.float32)
-    hidden_states = tf.io.parse_tensor(tnp.asarray(features["hidden_states"]), out_type=tf.float32)
-    return {"pixel_values": latents, "input_ids": hidden_states}
+    moments = tf.io.parse_tensor(tnp.asarray(features["moments"]), out_type=tf.float32)
+    clip_embeddings = tf.io.parse_tensor(tnp.asarray(features["clip_embeddings"]), out_type=tf.float32)
+    return {"pixel_values": moments, "input_ids": clip_embeddings}

   filenames = tf.io.gfile.glob(os.path.join(config.train_data_dir, "*"))
   train_ds = (
       tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
+      .shard(num_shards=dataloading_host_count, index=dataloading_host_index)
       .map(_parse_tfrecord_fn, num_parallel_calls=AUTOTUNE)
       .map(prepare_sample, num_parallel_calls=AUTOTUNE)
       .shuffle(global_batch_size * 10)
       .batch(global_batch_size // dataloading_host_count, drop_remainder=True)
-      .prefetch(AUTOTUNE)
       .repeat(-1)
+      .prefetch(AUTOTUNE)
   )

-  train_ds = train_ds.shard(num_shards=dataloading_host_count, index=dataloading_host_index)
-
   train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh)
   return train_iter
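
The reordering above matters for efficiency: sharding the raw TFRecordDataset before the map calls means each host reads and parses only its own 1/N of the records, instead of parsing the full stream and keeping every Nth batch at the end; prefetch also now sits last so the buffer stays filled across the infinite repeat. A toy illustration of the .shard() semantics:

import tensorflow as tf

ds = tf.data.Dataset.range(8)
host0 = ds.shard(num_shards=2, index=0)
host1 = ds.shard(num_shards=2, index=1)
print(list(host0.as_numpy_iterator()))  # [0, 2, 4, 6]
print(list(host1.as_numpy_iterator()))  # [1, 3, 5, 7]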

src/maxdiffusion/input_pipeline/input_pipeline_interface.py

Lines changed: 10 additions & 1 deletion
@@ -21,6 +21,7 @@
 import jax

 from maxdiffusion.input_pipeline import _hf_data_processing
+from maxdiffusion.input_pipeline import _grain_data_processing
 from maxdiffusion.input_pipeline import _tfds_data_processing
 from maxdiffusion import multihost_dataloading
 from maxdiffusion.maxdiffusion_utils import tokenize_captions, transform_images, vae_apply
@@ -61,6 +62,14 @@ def make_data_iterator(
         tokenize_fn=tokenize_fn,
         image_transforms_fn=image_transforms_fn,
     )
+  elif config.dataset_type == "grain":
+    return _grain_data_processing.make_grain_iterator(
+        config,
+        dataloading_host_index,
+        dataloading_host_count,
+        mesh,
+        global_batch_size,
+    )
   elif config.dataset_type == "tf":
     return _tfds_data_processing.make_tf_iterator(
         config,
@@ -80,7 +89,7 @@ def make_data_iterator(
         global_batch_size,
     )
   else:
-    assert False, f"Unknown dataset_type {config.dataset_type}, dataset_type must be in (tf, tfrecord, hf)"
+    assert False, f"Unknown dataset_type {config.dataset_type}, dataset_type must be in (tf, tfrecord, hf, grain)"


 def make_dreambooth_train_iterator(config, mesh, global_batch_size, tokenizer, vae, vae_params):

src/maxdiffusion/tests/input_pipeline_interface_test.py

Lines changed: 60 additions & 4 deletions
@@ -18,6 +18,7 @@
 from functools import partial
 import pathlib
 import shutil
+import subprocess
 import unittest
 from absl.testing import absltest
@@ -425,13 +426,68 @@ def test_make_pokemon_iterator_sdxl_cache(self):
         config.resolution // vae_scale_factor,
     )

+  def test_make_laion_grain_iterator(self):
+    try:
+      subprocess.check_output(
+          [
+              "bash",
+              "setup_gcsfuse.sh",
+              "DATASET_GCS_BUCKET=maxdiffusion-github-runner-test-assets",
+              "MOUNT_PATH=/tmp/gcsfuse",
+          ],
+          stderr=subprocess.STDOUT,
+      )
+    except subprocess.CalledProcessError as e:
+      raise ValueError(f"setup_gcsfuse failed with error: {e.output}") from e
+    pyconfig.initialize(
+        [
+            None,
+            os.path.join(THIS_DIR, "..", "configs", "base_2_base.yml"),
+            "grain_train_files=/tmp/gcsfuse/datasets/array-record/laion400m/tf_records_512_encoder_state_fp32/*.arrayrecord",
+            "dataset_type=grain",
+        ],
+        unittest=True,
+    )
+    config = pyconfig.config
+    global_batch_size = config.per_device_batch_size * jax.device_count()
+    devices_array = max_utils.create_device_mesh(config)
+    mesh = Mesh(devices_array, config.mesh_axes)
+
+    pipeline, _ = FlaxStableDiffusionPipeline.from_pretrained(
+        config.pretrained_model_name_or_path,
+        revision=config.revision,
+        dtype=config.activations_dtype,
+        safety_checker=None,
+        feature_extractor=None,
+        from_pt=config.from_pt,
+    )
+
+    train_iterator = make_data_iterator(config, jax.process_index(), jax.process_count(), mesh, global_batch_size)
+    data = next(train_iterator)
+    device_count = jax.device_count()
+
+    vae_scale_factor = 2 ** (len(pipeline.vae.config.block_out_channels) - 1)
+    encoder_hidden_states = data["input_ids"]
+
+    # TODO - laion dataset was prepared with an extra dim.
+    # need to preprocess the dataset with dim removed.
+    if len(encoder_hidden_states.shape) == 4:
+      encoder_hidden_states = jnp.squeeze(encoder_hidden_states)
+
+    assert encoder_hidden_states.shape == (device_count, 77, 1024)
+    assert data["pixel_values"].shape == (
+        config.total_train_batch_size,
+        config.resolution // vae_scale_factor,
+        config.resolution // vae_scale_factor,
+        8,
+    )
+
   def test_make_laion_tfrecord_iterator(self):
     pyconfig.initialize(
         [
             None,
             os.path.join(THIS_DIR, "..", "configs", "base_2_base.yml"),
-            "cache_latents_text_encoder_outputs=True",
-            "train_data_dir=gs://jfacevedo-maxdiffusion/laion400m/processed/laion400m_tfrec",
+            "train_data_dir=gs://jfacevedo-maxdiffusion/laion400m/raw_data/tf_records_512_encoder_state_fp32",
            "dataset_type=tfrecord",
        ],
        unittest=True,
@@ -464,10 +520,10 @@ def test_make_laion_tfrecord_iterator(self):

     assert encoder_hidden_states.shape == (device_count, 77, 1024)
     assert data["pixel_values"].shape == (
-        device_count,
-        pipeline.unet.config.in_channels,
+        config.total_train_batch_size,
         config.resolution // vae_scale_factor,
         config.resolution // vae_scale_factor,
+        8,
     )

   def test_tfrecord(self):
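
Note on the updated asserts: pixel_values is now NHWC with a trailing dimension of 8, which appears to pack the VAE posterior "moments" (mean and log-variance for the 4 latent channels, 2 x 4 = 8); the trainer change below samples actual 4-channel latents from these moments at train time.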

src/maxdiffusion/trainers/stable_diffusion_trainer.py

Lines changed: 19 additions & 2 deletions
@@ -29,7 +29,7 @@
 from maxdiffusion import (FlaxDDPMScheduler, maxdiffusion_utils, train_utils, max_utils, max_logging)

 from maxdiffusion.input_pipeline.input_pipeline_interface import (make_data_iterator)
-
+from maxdiffusion.models.vae_flax import FlaxDiagonalGaussianDistribution

 from maxdiffusion.checkpointing.base_stable_diffusion_checkpointer import (STABLE_DIFFUSION_CHECKPOINT)

@@ -67,6 +67,19 @@ def get_shaped_batch(self, config, pipeline):
         pipeline.text_encoder.config.hidden_size,
     )
     input_ids_dtype = jnp.float32
+  elif config.dataset_type in ("tfrecord", "grain"):
+    batch_image_shape = (
+        total_train_batch_size,
+        config.resolution // vae_scale_factor,
+        config.resolution // vae_scale_factor,
+        8,
+    )
+    batch_ids_shape = (
+        total_train_batch_size,
+        pipeline.text_encoder.config.max_position_embeddings,
+        pipeline.text_encoder.config.hidden_size,
+    )
+    input_ids_dtype = jnp.float32
   else:
     batch_image_shape = (total_train_batch_size, 3, config.resolution, config.resolution)
     batch_ids_shape = (total_train_batch_size, pipeline.text_encoder.config.max_position_embeddings)
@@ -240,10 +253,14 @@ def _train_step(unet_state, vae_state, text_encoder_state, batch, train_rng, pip
     state_params = {"unet": unet_state.params}

     def compute_loss(state_params):
-
       if config.dataset_type == "tf" and config.cache_latents_text_encoder_outputs:
         latents = batch["pixel_values"]
         encoder_hidden_states = batch["input_ids"]
+      elif config.dataset_type in ("tfrecord", "grain"):
+        latents = FlaxDiagonalGaussianDistribution(batch["pixel_values"]).sample(sample_rng)
+        latents = jnp.transpose(latents, (0, 3, 1, 2))
+        latents = latents * pipeline.vae.config.scaling_factor
+        encoder_hidden_states = batch["input_ids"]
       else:
         # Convert images to latent space
         vae_outputs = pipeline.vae.apply(
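
A hedged sketch of what the new branch computes, with an illustrative moments shape; the numeric scaling factor below is the common Stable Diffusion default and is assumed here rather than taken from this diff:

import jax
import jax.numpy as jnp
from maxdiffusion.models.vae_flax import FlaxDiagonalGaussianDistribution

sample_rng = jax.random.PRNGKey(0)
moments = jnp.zeros((4, 64, 64, 8))  # NHWC: mean and logvar stacked on the last axis
latents = FlaxDiagonalGaussianDistribution(moments).sample(sample_rng)  # (4, 64, 64, 4)
latents = jnp.transpose(latents, (0, 3, 1, 2))  # to NCHW, as the UNet expects
latents = latents * 0.18215  # assumed value of pipeline.vae.config.scaling_factor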
