Before starting, please refer to this repository's README and set up your environment.
This guide explains the following:
MaxDiffusion provides training scripts:
- train.py: supports training SD 1.x, SD 2 base, and SD 2.1.
- train_dreambooth.py: supports DreamBooth training for SD 1.x, SD 2 base, and SD 2.1.
- train_sdxl.py: supports SDXL training.
The maxdiffusion repo is configuration-driven: the idea is that running a training or inference job should require few to no code changes; instead, you modify config settings.
In this session, we'll explain some of the core config parameters and how they affect training. Let's start with the configuration-to-model mappings:
| config | model | supports |
|---|---|---|
| base14.yml | stable-diffusion-v1-4 | training / inference |
| base_2_base.yml | stable-diffusion-2-base | training / inference |
| base21.yml | stable-diffusion-2-1 | training / inference |
| base_xl.yml | stable-diffusion-xl-base-1.0 | training / inference |
| base_xl_lightning.yml | stable-diffusion-xl-base-1.0 & ByteDance/SDXL-Lightning | inference |
Changes to a config can be applied by editing the yml file directly or by passing those parameters on the CLI when creating a job. The only required parameters to pass to a job are run_name and output_dir.
Let's start with a simple example. After setting up your environment, create a training job as follows:
```bash
export LIBTPU_INIT_ARGS=""
python -m src.maxdiffusion.models.train src/maxdiffusion/configs/base14.yml \
  run_name="my_run" \
  jax_cache_dir=gs://your-bucket/cache_dir \
  activations_dtype=float32 \
  weights_dtype=float32 \
  per_device_batch_size=2 \
  precision=DEFAULT \
  dataset_save_location=/tmp/my_dataset/ \
  output_dir=gs://your-bucket/ \
  attention=flash
```

The job will use the predefined parameters in base14.yml; any parameters passed on the CLI override the values in the file.
MaxDiffusion configs come with predefined models, mostly based on the base models created by Stability AI and RunwayML. The base model can be changed by setting pretrained_model_name_or_path to a different model; the only requirement is that the model is in the diffusers format (full checkpoints will be supported in the future).
To load PyTorch weights, set from_pt=True and revision=main. Let's look at an example. Here we'll load Stable Diffusion 1.5 from a PyTorch checkpoint:
```bash
export LIBTPU_INIT_ARGS=""
python -m src.maxdiffusion.models.train src/maxdiffusion/configs/base14.yml \
  run_name="my_run" \
  output_dir="gs://your-bucket/" \
  pretrained_model_name_or_path=runwayml/stable-diffusion-v1-5 \
  from_pt=True \
  revision=main
```

After training, a new folder structure with weights and metrics is created under the output_dir folder:
```
├── output_dir
│   ├── run_name
│   │   ├── checkpoints
│   │   ├── metrics
│   │   ├── tensorboard
```

It is recommended to use a Google Cloud Storage bucket as the output_dir; this ensures all your work persists across VM creations. You can also use a local directory.
To run inference with the trained checkpoint, run:

```bash
python src/maxdiffusion/generate.py src/maxdiffusion/configs/base14.yml \
  output_dir="gs://your-bucket/" \
  run_name="my_run"
```

MaxDiffusion models use logical axis annotations, which allow users to explore different sharding layouts without changing the model code. To learn more about distributed arrays and Flax partitioning, check out JAX's Distributed arrays and automatic parallelization and then Flax's Scale up Flax Modules on multiple devices.
The main config values for these are:
- mesh_axes
- logical_axis_rules
- data_sharding
- dcn_data_parallelism
- dcn_fsdp_parallelism
- dcn_tensor_parallelism
- ici_data_parallelism
- ici_fsdp_parallelism
- ici_tensor_parallelism
Out of the box, all maxdiffusion configs use data parallelism.
mesh_axes supports three mesh axes: data, fsdp, and tensor.
logical_axis_rules define which weights and activations are sharded across which mesh axes.
data_sharding defines the sharding strategy for the input data.
dcn_* stands for data center network parallelism; these values define the parallelism strategy across TPU slices (multi-slice).
ici_* stands for inter-chip interconnect parallelism; these values define the parallelism strategy within a single TPU slice.
See Multislice vs single slice.
Note: maxdiffusion does not yet support multi-slice.
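Before choosing the ici_* values, it can help to confirm how many devices JAX sees on your slice, since the ici_* axis sizes should multiply to that device count. A quick check (a minimal sketch, assuming JAX is installed and the TPU is attached):

```python
import jax

# On a single TPUv4-8 slice this prints 4 (one JAX device per chip);
# the product of the ici_* parallelism values should equal this
# number (or use -1 to let maxdiffusion infer one axis).
print(jax.device_count())
```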
Let's look at how these settings implement data parallelism. Assume we're using a TPUv4-8 and define the ici parallelism strategy:
```yaml
mesh_axes: ['data', 'fsdp', 'tensor']
ici_data_parallelism: -1
ici_fsdp_parallelism: 1
ici_tensor_parallelism: 1
```

Recall that in a TPUv4-8 configuration, the number of chips is 4 (each TPU v4 chip contains two TensorCores). Passing -1 for an axis tells maxdiffusion to assign all remaining devices to that axis, so our mesh is created as Mesh('data': 4, 'fsdp': 1, 'tensor': 1).
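Under the hood this corresponds to a standard JAX device mesh. Here is a minimal sketch of the equivalent mesh construction (illustrative only; maxdiffusion builds the mesh for you, and this snippet assumes 4 visible devices):

```python
from jax.experimental import mesh_utils
from jax.sharding import Mesh

# Shape (4, 1, 1) mirrors ici_data_parallelism=-1 (resolved to 4),
# ici_fsdp_parallelism=1, ici_tensor_parallelism=1 on a TPUv4-8.
devices = mesh_utils.create_device_mesh((4, 1, 1))
mesh = Mesh(devices, axis_names=('data', 'fsdp', 'tensor'))
print(mesh.shape)  # OrderedDict([('data', 4), ('fsdp', 1), ('tensor', 1)])
```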
Now let's change the configuration as follows:
```yaml
mesh_axes: ['data', 'fsdp', 'tensor']
ici_data_parallelism: 2
ici_fsdp_parallelism: 2
ici_tensor_parallelism: 1
```

Then our mesh will look like Mesh('data': 2, 'fsdp': 2, 'tensor': 1).
The logical_axis_rules specify how named weight and activation axes are mapped onto the mesh; a sketch of how a rule resolves is shown below. You are encouraged to add or remove rules to find what works best for you.
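Here is a minimal sketch using Flax's partitioning helpers. The logical axis names below are hypothetical and not necessarily the ones used in maxdiffusion's configs:

```python
from flax.linen import partitioning as nn_partitioning

# Hypothetical rules mapping logical axis names to mesh axis names.
rules = (('batch', 'data'), ('embed', 'fsdp'), ('heads', 'tensor'))

# A weight whose logical dims are ('embed', 'heads') gets sharded
# along the 'fsdp' and 'tensor' mesh axes.
spec = nn_partitioning.logical_to_mesh_axes(('embed', 'heads'), rules)
print(spec)  # PartitionSpec('fsdp', 'tensor')
```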
Checkpointing can be enabled with checkpoint_every. Its value is measured in samples; each training step processes per_device_batch_size * jax.device_count() samples.
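As a worked example (the numbers here are illustrative):

```python
import jax

# checkpoint_every counts samples, and each training step processes
# per_device_batch_size * jax.device_count() samples.
per_device_batch_size = 2
samples_per_step = per_device_batch_size * jax.device_count()

# On a TPUv4-8 (4 JAX devices), samples_per_step == 8, so setting
# checkpoint_every=1024 saves a checkpoint about every 128 steps.
print(samples_per_step)
```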
Orbax is used to save checkpoints; however, Orbax does not currently store tokenizers. Instead, the tokenizer's model name or path is stored inside the checkpoint, and the tokenizer is loaded from it during inference.