
Common Learning Guide

Overview

Please refer to the README of this repository beforehand and set up the environment.

The following is explained:

  1. Training Scripts.
  2. Configs.

Training Scripts

MaxDiffusion provides training scripts for each of the supported models listed in the Configs section below.

Configs

The maxdiffusion repo is configuration-driven: the intent is that running a training or inference job requires few to no code changes; instead, you modify config settings.

In this section, we'll explain some of the core config parameters and how they affect training. Let's start with the configuration-to-model mappings:

| config | model | supports |
| --- | --- | --- |
| base14.yml | stable-diffusion-v1-4 | training / inference |
| base_2_base.yml | stable-diffusion-2-base | training / inference |
| base21.yml | stable-diffusion-2-1 | training / inference |
| base_xl.yml | stable-diffusion-xl-base-1.0 | training / inference |
| base_xl_lightning.yml | stable-diffusion-xl-base-1.0 & ByteDance/SDXL-Lightning | inference |

Changes to a config can be applied by editing the yml file directly or by passing those parameters on the command line when creating a job. The only required parameters to pass to a job are run_name and output_dir.

Let's start with a simple example. After setting up your environment, create a training job as follows:

export LIBTPU_INIT_ARGS=""
python -m src.maxdiffusion.models.train src/maxdiffusion/configs/base14.yml run_name="my_run" jax_cache_dir=gs://your-bucket/cache_dir activations_dtype=float32 weights_dtype=float32 per_device_batch_size=2 precision=DEFAULT dataset_save_location=/tmp/my_dataset/ output_dir=gs://your-bucket/ attention=flash

The job will use the predefined parameters in base14.yml and will overwrite any parameters that are passed on the command line.

Changing The Base Model

MaxDiffusion configs come with predefined models, mostly based on the base models created by Stability AI and Runway. The base model can be changed by setting pretrained_model_name_or_path to a different model; the only requirement is that the model is in the diffusers format (full checkpoints will be supported in the future).

To load PyTorch weights, set from_pt=True and revision=main. Let's look at an example. Here we'll load Stable Diffusion 1.5 from a PyTorch checkpoint.

export LIBTPU_INIT_ARGS=""
python -m src.maxdiffusion.models.train src/maxdiffusion/configs/base14.yml run_name="my_run" output_dir="gs://your-bucket/" pretrained_model_name_or_path=runwayml/stable-diffusion-v1-5 from_pt=True revision=main

After training, a new folder structure with weights and metrics will have been created under the output_dir folder:

├── output_dir
│   ├── run_name
│       ├── checkpoints
│       ├── metrics
│       ├── tensorboard

It is recommended to use a Google Cloud Storage bucket as the output_dir. This will ensure all your work persists across VM creations. You can also use a local directory.

To generate images with the trained checkpoint, run:

python src/maxdiffusion/generate.py src/maxdiffusion/configs/base14.yml output_dir="gs://your-bucket/" run_name="my_run"

Changing The Sharding Strategy

MaxDiffusion models use logical axis annotations, which allow users to explore different sharding layouts without changing the model code. To learn more about distributed arrays and Flax partitioning, check out JAX's Distributed arrays and automatic parallelization and then Flax's Scale up Flax Modules on multiple devices.

The main config values for these are:

  • mesh_axes
  • logical_axis_rules
  • data_sharding
  • dcn_data_parallelism
  • dcn_fsdp_parallelism
  • dcn_tensor_parallelism
  • ici_data_parallelism
  • ici_fsdp_parallelism
  • ici_tensor_parallelism

Out of the box, all maxdiffusion configs use data parallelism.

mesh_axes supports 3 mesh axes: data, fsdp and tensor.

logical_axis_rules define which weights and activations should be sharded across which mesh axes.

data_sharding defines the data sharding strategy.

dcn_* stands for data center network parallelism; these values define parallelism strategies across TPU slices (multi-slice).

ici_* stands for interchip interconnect parallelism; these values define parallelism strategies within a single TPU slice.

See Multislice vs single slice.

Note: maxdiffusion does not yet support multi-slice.
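Putting these together, a purely data-parallel setup might look like the following config fragment. The values shown are illustrative, not the verbatim contents of any shipped config; check your chosen yml file for the real defaults:

```yaml
mesh_axes: ['data', 'fsdp', 'tensor']
data_sharding: [['data', 'fsdp', 'tensor']]
# Single-slice only, so all dcn_* values stay at 1.
dcn_data_parallelism: 1
dcn_fsdp_parallelism: 1
dcn_tensor_parallelism: 1
# -1 assigns all remaining devices to the data axis.
ici_data_parallelism: -1
ici_fsdp_parallelism: 1
ici_tensor_parallelism: 1
```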

Let's look at how these settings work to implement data parallelism. Let's assume we're using a TPUv4-8 and define the ici parallelism strategy:

mesh_axes: ['data', 'fsdp', 'tensor']
ici_data_parallelism: -1
ici_fsdp_parallelism: 1  
ici_tensor_parallelism: 1

Recall that in a TPUv4-8 configuration, the number of chips is 4 (each TPU v4 chip contains two TensorCores). Passing -1 for an axis tells maxdiffusion to assign all remaining devices to that axis, so our mesh is created as Mesh('data': 4, 'fsdp': 1, 'tensor': 1).

Now let's change the configuration as follows:

mesh_axes: ['data', 'fsdp', 'tensor']
ici_data_parallelism: 2
ici_fsdp_parallelism: 2  
ici_tensor_parallelism: 1

Then our mesh will look like Mesh('data': 2, 'fsdp': 2, 'tensor': 1).
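The way these values combine into a mesh shape can be sketched in plain Python. This is an illustrative sketch, not maxdiffusion's actual implementation; resolve_mesh_shape is a hypothetical helper:

```python
def resolve_mesh_shape(num_devices, parallelism):
    """Resolve a parallelism spec where -1 means 'all remaining devices'.

    parallelism is [data, fsdp, tensor], mirroring the ici_* config values.
    """
    shape = list(parallelism)
    if -1 in shape:
        # Product of the explicitly specified axes.
        known = 1
        for p in shape:
            if p != -1:
                known *= p
        # The -1 axis absorbs whatever devices are left over.
        shape[shape.index(-1)] = num_devices // known
    return shape

# A TPU v4-8 slice exposes 4 devices.
print(resolve_mesh_shape(4, [-1, 1, 1]))  # [4, 1, 1]
print(resolve_mesh_shape(4, [2, 2, 1]))   # [2, 2, 1]
```

Note that the product of the resolved axes must equal the device count, which is why [2, 2, 1] is valid on 4 devices but would not be on 8.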

logical_axis_rules specify how logical axes are sharded across the mesh. You are encouraged to add or remove rules and find what works best for you.
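For illustration, a rule set might look like the following fragment; the logical axis names here are hypothetical, so consult your config file for the actual rules. Each rule maps a logical axis name used in the model code to a mesh axis:

```yaml
logical_axis_rules: [
    ['batch', 'data'],
    ['activation_batch', 'data'],
    ['embed', 'fsdp'],
    ['heads', 'tensor'],
]
```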

Checkpointing

Checkpointing can be enabled with checkpoint_every. It is measured in samples, where one training step processes per_device_batch_size * jax.device_count() samples.
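As a hypothetical example of the arithmetic (all parameter values here are made up), the number of training steps between checkpoints follows from the global batch size:

```python
# Hypothetical values for a 4-device slice.
per_device_batch_size = 2
device_count = 4            # what jax.device_count() would return
checkpoint_every = 2048     # expressed in samples

# Each training step consumes one global batch of samples.
samples_per_step = per_device_batch_size * device_count
steps_between_checkpoints = checkpoint_every // samples_per_step
print(steps_between_checkpoints)  # 256
```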

Orbax is used to save checkpoints; however, Orbax does not currently store tokenizers. Instead, the tokenizer's model name or path is stored inside the checkpoint and then loaded during inference.