diff --git a/src/maxdiffusion/checkpointing/checkpointing_utils.py b/src/maxdiffusion/checkpointing/checkpointing_utils.py
index b8710e1a6..aa68267a5 100644
--- a/src/maxdiffusion/checkpointing/checkpointing_utils.py
+++ b/src/maxdiffusion/checkpointing/checkpointing_utils.py
@@ -33,6 +33,7 @@
 STABLE_DIFFUSION_CHECKPOINT = "STABLE_DIFFUSION_CHECKPOINT"
 STABLE_DIFFUSION_XL_CHECKPOINT = "STABLE_DIFUSSION_XL_CHECKPOINT"
+FLUX_CHECKPOINT = "FLUX_CHECKPOINT"
 
 
 def create_orbax_checkpoint_manager(
@@ -56,17 +57,20 @@ def create_orbax_checkpoint_manager(
   max_logging.log(f"checkpoint dir: {checkpoint_dir}")
   p = epath.Path(checkpoint_dir)
 
-  item_names = (
-      "unet_config",
-      "vae_config",
-      "text_encoder_config",
-      "scheduler_config",
-      "unet_state",
-      "vae_state",
-      "text_encoder_state",
-      "tokenizer_config",
-  )
-  if checkpoint_type == STABLE_DIFFUSION_XL_CHECKPOINT:
+  if checkpoint_type == FLUX_CHECKPOINT:
+    item_names = ("flux_state", "flux_config", "vae_state", "vae_config", "scheduler", "scheduler_config")
+  else:
+    item_names = (
+        "unet_config",
+        "vae_config",
+        "text_encoder_config",
+        "scheduler_config",
+        "unet_state",
+        "vae_state",
+        "text_encoder_state",
+        "tokenizer_config",
+    )
+  if checkpoint_type == STABLE_DIFFUSION_XL_CHECKPOINT or checkpoint_type == FLUX_CHECKPOINT:
     item_names += (
         "text_encoder_2_state",
         "text_encoder_2_config",
@@ -117,7 +121,7 @@ def load_stable_diffusion_configs(
       "tokenizer_config": orbax.checkpoint.args.JsonRestore(),
   }
 
-  if checkpoint_type == STABLE_DIFFUSION_XL_CHECKPOINT:
+  if checkpoint_type == STABLE_DIFFUSION_XL_CHECKPOINT or checkpoint_type == FLUX_CHECKPOINT:
     restore_args["text_encoder_2_config"] = orbax.checkpoint.args.JsonRestore()
 
   return (checkpoint_manager.restore(step, args=orbax.checkpoint.args.Composite(**restore_args)), None)
@@ -139,6 +143,7 @@ def load_params_from_path(
 
   ckpt_path = os.path.join(config.checkpoint_dir, str(step), checkpoint_item)
   ckpt_path = epath.Path(ckpt_path)
+  ckpt_path = os.path.abspath(ckpt_path)
 
   restore_args = ocp.checkpoint_utils.construct_restore_args(unboxed_abstract_params)
   restored = ckptr.restore(
diff --git a/src/maxdiffusion/checkpointing/flux_checkpointer.py b/src/maxdiffusion/checkpointing/flux_checkpointer.py
new file mode 100644
index 000000000..a5e1bfc2f
--- /dev/null
+++ b/src/maxdiffusion/checkpointing/flux_checkpointer.py
@@ -0,0 +1,305 @@
+"""
+ Copyright 2024 Google LLC
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+      https://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """ + +from abc import ABC +from contextlib import nullcontext +import functools +import json +import jax +from jax.sharding import Mesh +import orbax.checkpoint as ocp +import grain.python as grain +from maxdiffusion import ( + max_utils, + FlaxAutoencoderKL, + max_logging, +) +from maxdiffusion.models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel +from ..pipelines.flux.flux_pipeline import FluxPipeline + +from transformers import (CLIPTokenizer, FlaxCLIPTextModel, FlaxT5EncoderModel, AutoTokenizer) + +from maxdiffusion.checkpointing.checkpointing_utils import (create_orbax_checkpoint_manager) +from maxdiffusion.models.flux.util import load_flow_model + +FLUX_CHECKPOINT = "FLUX_CHECKPOINT" +_CHECKPOINT_FORMAT_ORBAX = "CHECKPOINT_FORMAT_ORBAX" + +FLUX_STATE_KEY = "flux_state" +FLUX_TRANSFORMER_PARAMS_KEY = "flux_transformer_params" +FLUX_STATE_SHARDINGS_KEY = "flux_state_shardings" +FLUX_VAE_PARAMS_KEY = "flux_vae" +VAE_STATE_KEY = "vae_state" +VAE_STATE_SHARDINGS_KEY = "vae_state_shardings" + + +class FluxCheckpointer(ABC): + + def __init__(self, config, checkpoint_type): + self.config = config + self.checkpoint_type = checkpoint_type + self.checkpoint_format = None + + self.rng = jax.random.PRNGKey(self.config.seed) + self.devices_array = max_utils.create_device_mesh(config) + self.mesh = Mesh(self.devices_array, self.config.mesh_axes) + self.total_train_batch_size = self.config.total_train_batch_size + + self.checkpoint_manager = create_orbax_checkpoint_manager( + self.config.checkpoint_dir, + enable_checkpointing=True, + save_interval_steps=1, + checkpoint_type=checkpoint_type, + dataset_type=config.dataset_type, + ) + + def _create_optimizer(self, config, learning_rate): + + learning_rate_scheduler = max_utils.create_learning_rate_schedule( + learning_rate, config.learning_rate_schedule_steps, config.warmup_steps_fraction, config.max_train_steps + ) + tx = max_utils.create_optimizer(config, learning_rate_scheduler) + return tx, learning_rate_scheduler + + def create_flux_state(self, pipeline, params, checkpoint_item_name, is_training): + transformer = pipeline.flux + + tx, learning_rate_scheduler = None, None + if is_training: + learning_rate = self.config.learning_rate + + tx, learning_rate_scheduler = self._create_optimizer(self.config, learning_rate) + + transformer_eval_params = transformer.init_weights( + rngs=self.rng, max_sequence_length=self.config.max_sequence_length, eval_only=True + ) + + transformer_params = load_flow_model(self.config.flux_name, transformer_eval_params, "cpu") + + weights_init_fn = functools.partial( + pipeline.flux.init_weights, rngs=self.rng, max_sequence_length=self.config.max_sequence_length + ) + flux_state, state_mesh_shardings = max_utils.setup_initial_state( + model=pipeline.flux, + tx=tx, + config=self.config, + mesh=self.mesh, + weights_init_fn=weights_init_fn, + model_params=None, + checkpoint_manager=self.checkpoint_manager, + checkpoint_item=checkpoint_item_name, + training=is_training, + ) + if not self.config.train_new_flux: + flux_state = flux_state.replace(params=transformer_params) + flux_state = jax.device_put(flux_state, state_mesh_shardings) + return flux_state, state_mesh_shardings, learning_rate_scheduler + + def create_vae_state(self, pipeline, params, checkpoint_item_name, is_training=False): + + # Currently VAE training is not supported. 
+    weights_init_fn = functools.partial(pipeline.vae.init_weights, rng=self.rng)
+    return max_utils.setup_initial_state(
+        model=pipeline.vae,
+        tx=None,
+        config=self.config,
+        mesh=self.mesh,
+        weights_init_fn=weights_init_fn,
+        model_params=params.get(FLUX_VAE_PARAMS_KEY, None),
+        checkpoint_manager=self.checkpoint_manager,
+        checkpoint_item=checkpoint_item_name,
+        training=is_training,
+    )
+
+  def restore_data_iterator_state(self, data_iterator):
+    if (
+        self.config.dataset_type == "grain"
+        and data_iterator is not None
+        and (self.checkpoint_manager.directory / str(self.checkpoint_manager.latest_step()) / "iter").exists()
+    ):
+      max_logging.log("Restoring data iterator from checkpoint")
+      restored = self.checkpoint_manager.restore(
+          self.checkpoint_manager.latest_step(),
+          args=ocp.args.Composite(iter=grain.PyGrainCheckpointRestore(data_iterator.local_iterator)),
+      )
+      data_iterator.local_iterator = restored["iter"]
+    else:
+      max_logging.log("data iterator checkpoint not found")
+    return data_iterator
+
+  def _get_pipeline_class(self):
+    return FluxPipeline
+
+  def _set_checkpoint_format(self, checkpoint_format):
+    self.checkpoint_format = checkpoint_format
+
+  def save_checkpoint(self, train_step, pipeline, train_states):
+    def config_to_json(model_or_config):
+      return json.loads(model_or_config.to_json_string())
+
+    items = {
+        "flux_config": ocp.args.JsonSave(config_to_json(pipeline.flux)),
+        "vae_config": ocp.args.JsonSave(config_to_json(pipeline.vae)),
+        "scheduler_config": ocp.args.JsonSave(config_to_json(pipeline.scheduler)),
+    }
+
+    items[FLUX_STATE_KEY] = ocp.args.PyTreeSave(train_states[FLUX_STATE_KEY])
+    items[VAE_STATE_KEY] = ocp.args.PyTreeSave(train_states[VAE_STATE_KEY])
+    items["scheduler"] = ocp.args.PyTreeSave(train_states["scheduler"])
+
+    self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))
+
+  def load_params(self, step=None):
+
+    self.checkpoint_format = _CHECKPOINT_FORMAT_ORBAX
+
+  def load_flux_configs_from_orbax(self, step):
+    max_logging.log("Restoring Flux configs")
+    if step is None:
+      step = self.checkpoint_manager.latest_step()
+    if step is None:
+      return None
+
+    restore_args = {
+        "flux_config": ocp.args.JsonRestore(),
+        "vae_config": ocp.args.JsonRestore(),
+        "scheduler_config": ocp.args.JsonRestore(),
+    }
+
+    return (self.checkpoint_manager.restore(step, args=ocp.args.Composite(**restore_args)), None)
+
+  def load_diffusers_checkpoint(self):
+    flash_block_sizes = max_utils.get_flash_block_sizes(self.config)
+
+    if jax.device_count() == jax.local_device_count():
+      context = jax.default_device(jax.devices("cpu")[0])
+    else:
+      context = nullcontext()
+
+    with context:
+      clip_encoder = FlaxCLIPTextModel.from_pretrained(self.config.clip_model_name_or_path, dtype=self.config.weights_dtype)
+      clip_tokenizer = CLIPTokenizer.from_pretrained(self.config.clip_model_name_or_path, max_length=77, use_fast=True)
+      t5_encoder = FlaxT5EncoderModel.from_pretrained(self.config.t5xxl_model_name_or_path, dtype=self.config.weights_dtype)
+      t5_tokenizer = AutoTokenizer.from_pretrained(
+          self.config.t5xxl_model_name_or_path, max_length=self.config.max_sequence_length, use_fast=True
+      )
+
+      vae, vae_params = FlaxAutoencoderKL.from_pretrained(
+          self.config.pretrained_model_name_or_path,
+          subfolder="vae",
+          from_pt=True,
+          use_safetensors=True,
+          dtype=self.config.weights_dtype,
+      )
+
+      # loading from pretrained here causes a crash when trying to compile the model
+      # Failed to load HSACO: HIP_ERROR_NoBinaryForGpu
+      transformer = FluxTransformer2DModel.from_config(
+          self.config.pretrained_model_name_or_path,
+          subfolder="transformer",
+          mesh=self.mesh,
+          split_head_dim=self.config.split_head_dim,
+          attention_kernel=self.config.attention,
+          flash_block_sizes=flash_block_sizes,
+          dtype=self.config.activations_dtype,
+          weights_dtype=self.config.weights_dtype,
+          precision=max_utils.get_precision(self.config),
+      )
+      transformer_eval_params = transformer.init_weights(
+          rngs=self.rng, max_sequence_length=self.config.max_sequence_length, eval_only=True
+      )
+
+      transformer_params = load_flow_model(self.config.flux_name, transformer_eval_params, "cpu")
+
+      pipeline = FluxPipeline(
+          t5_encoder,
+          clip_encoder,
+          vae,
+          t5_tokenizer,
+          clip_tokenizer,
+          transformer,
+          None,
+          dtype=self.config.activations_dtype,
+          mesh=self.mesh,
+          config=self.config,
+          rng=self.rng,
+      )
+
+      params = {FLUX_VAE_PARAMS_KEY: vae_params, FLUX_TRANSFORMER_PARAMS_KEY: transformer_params}
+
+    return pipeline, params
+
+  def load_checkpoint(self, step=None, scheduler_class=None):
+
+    model_configs = self.load_flux_configs_from_orbax(step)
+
+    pipeline, params = None, {}
+
+    if model_configs:
+      if jax.device_count() == jax.local_device_count():
+        context = jax.default_device(jax.devices("cpu")[0])
+      else:
+        context = nullcontext()
+
+      with context:
+        clip_encoder = FlaxCLIPTextModel.from_pretrained(
+            self.config.clip_model_name_or_path, dtype=self.config.weights_dtype
+        )
+        clip_tokenizer = CLIPTokenizer.from_pretrained(self.config.clip_model_name_or_path, max_length=77, use_fast=True)
+        t5_encoder = FlaxT5EncoderModel.from_pretrained(
+            self.config.t5xxl_model_name_or_path, dtype=self.config.weights_dtype
+        )
+        t5_tokenizer = AutoTokenizer.from_pretrained(
+            self.config.t5xxl_model_name_or_path, max_length=self.config.max_sequence_length, use_fast=True
+        )
+
+        vae = FlaxAutoencoderKL.from_config(
+            model_configs[0]["vae_config"],
+            dtype=self.config.activations_dtype,
+            weights_dtype=self.config.weights_dtype,
+            from_pt=self.config.from_pt,
+        )
+
+        transformer = FluxTransformer2DModel.from_config(
+            model_configs[0]["flux_config"],
+            mesh=self.mesh,
+            split_head_dim=self.config.split_head_dim,
+            attention_kernel=self.config.attention,
+            flash_block_sizes=max_utils.get_flash_block_sizes(self.config),
+            dtype=self.config.activations_dtype,
+            weights_dtype=self.config.weights_dtype,
+            precision=max_utils.get_precision(self.config),
+            from_pt=self.config.from_pt,
+        )
+
+        pipeline = FluxPipeline(
+            t5_encoder,
+            clip_encoder,
+            vae,
+            t5_tokenizer,
+            clip_tokenizer,
+            transformer,
+            None,
+            dtype=self.config.activations_dtype,
+            mesh=self.mesh,
+            config=self.config,
+            rng=self.rng,
+        )
+
+    else:
+      pipeline, params = self.load_diffusers_checkpoint()
+
+    return pipeline, params
diff --git a/src/maxdiffusion/configs/base_flux_dev.yml b/src/maxdiffusion/configs/base_flux_dev.yml
index 944153d64..167dc8bc8 100644
--- a/src/maxdiffusion/configs/base_flux_dev.yml
+++ b/src/maxdiffusion/configs/base_flux_dev.yml
@@ -75,9 +75,9 @@ flash_block_sizes: {}
 # GroupNorm groups
 norm_num_groups: 32
 
-# If train_new_unet, unet weights will be randomly initialized to train the unet from scratch
+# If train_new_flux, flux weights will be randomly initialized to train flux from scratch
 # else they will be loaded from pretrained_model_name_or_path
-train_new_unet: False
+train_new_flux: False
 
 # train text_encoder - Currently not supported for SDXL
 train_text_encoder: False
@@ -177,7 +177,7 @@ hf_train_files: ''
 hf_access_token: ''
 image_column: 'image'
 caption_column: 'text'
-resolution: 1024
+resolution: 512
 center_crop: False
 random_flip: False
 # If cache_latents_text_encoder_outputs is True
@@ -193,17 +193,17 @@ checkpoint_every: -1
 enable_single_replica_ckpt_restoring: False
 
 # Training loop
-learning_rate: 4.e-7
+learning_rate: 1.e-5
 scale_lr: False
 max_train_samples: -1
 # max_train_steps takes priority over num_train_epochs.
-max_train_steps: 200
+max_train_steps: 1500
 num_train_epochs: 1
 seed: 0
 output_dir: 'sdxl-model-finetuned'
 per_device_batch_size: 1
 
-warmup_steps_fraction: 0.0
+warmup_steps_fraction: 0.1
 learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps.
 # However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before
@@ -213,7 +213,7 @@ learning_rate_schedule_steps: -1 # By default the length of the schedule is set
 adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients.
 adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
 adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
-adam_weight_decay: 1.e-2 # AdamW Weight decay
+adam_weight_decay: 0 # AdamW Weight decay
 max_grad_norm: 1.0
 
 enable_profiler: False
@@ -231,6 +231,7 @@ guidance_scale: 3.5
 # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
 guidance_rescale: 0.0
 num_inference_steps: 50
+save_final_checkpoint: False
 
 # SDXL Lightning parameters
 lightning_from_pt: True
diff --git a/src/maxdiffusion/configs/base_flux_schnell.yml b/src/maxdiffusion/configs/base_flux_schnell.yml
index 3106255a9..188074a5b 100644
--- a/src/maxdiffusion/configs/base_flux_schnell.yml
+++ b/src/maxdiffusion/configs/base_flux_schnell.yml
@@ -83,9 +83,9 @@ flash_block_sizes: {
 # GroupNorm groups
 norm_num_groups: 32
 
-# If train_new_unet, unet weights will be randomly initialized to train the unet from scratch
+# If train_new_flux, flux weights will be randomly initialized to train flux from scratch
 # else they will be loaded from pretrained_model_name_or_path
-train_new_unet: False
+train_new_flux: False
 
 # train text_encoder - Currently not supported for SDXL
 train_text_encoder: False
diff --git a/src/maxdiffusion/generate_flux_pipeline.py b/src/maxdiffusion/generate_flux_pipeline.py
new file mode 100644
index 000000000..8887375d0
--- /dev/null
+++ b/src/maxdiffusion/generate_flux_pipeline.py
@@ -0,0 +1,125 @@
+"""
+ Copyright 2025 Google LLC
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+      https://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+""" + +from typing import Sequence +from absl import app +from contextlib import ExitStack +import functools +import time +import numpy as np +from PIL import Image +import jax + +from maxdiffusion import pyconfig, max_logging, max_utils + +from maxdiffusion.checkpointing.checkpointing_utils import load_params_from_path +from maxdiffusion.max_utils import setup_initial_state + + +def run(config): + from maxdiffusion.checkpointing.flux_checkpointer import FluxCheckpointer + + checkpoint_loader = FluxCheckpointer(config, "FLUX_CHECKPOINT") + pipeline, params = checkpoint_loader.load_checkpoint() + + if not params: + ## VAE + weights_init_fn = functools.partial(pipeline.vae.init_weights, rng=checkpoint_loader.rng) + unboxed_abstract_state, _, _ = max_utils.get_abstract_state( + pipeline.vae, None, config, checkpoint_loader.mesh, weights_init_fn, False + ) + # load unet params from orbax checkpoint + vae_params = load_params_from_path( + config, checkpoint_loader.checkpoint_manager, unboxed_abstract_state.params, "vae_state" + ) + + vae_state = {"params": vae_params} + + ## Flux + weights_init_fn = functools.partial( + pipeline.flux.init_weights, rngs=checkpoint_loader.rng, max_sequence_length=config.max_sequence_length + ) + + unboxed_abstract_state, _, _ = max_utils.get_abstract_state( + pipeline.flux, None, config, checkpoint_loader.mesh, weights_init_fn, False + ) + # load unet params from orbax checkpoint + flux_params = load_params_from_path( + config, checkpoint_loader.checkpoint_manager, unboxed_abstract_state.params, "flux_state" + ) + flux_state = {"params": flux_params} + else: + weights_init_fn = functools.partial( + pipeline.flux.init_weights, + rngs=checkpoint_loader.rng, + max_sequence_length=config.max_sequence_length, + eval_only=False, + ) + transformer_state, flux_state_shardings = setup_initial_state( + model=pipeline.flux, + tx=None, + config=config, + mesh=checkpoint_loader.mesh, + weights_init_fn=weights_init_fn, + model_params=None, + training=False, + ) + transformer_state = transformer_state.replace(params=params["flux_transformer_params"]) + transformer_state = jax.device_put(transformer_state, flux_state_shardings) + + weights_init_fn = functools.partial(pipeline.vae.init_weights, rng=checkpoint_loader.rng) + vae_state, _ = setup_initial_state( + model=pipeline.vae, + tx=None, + config=config, + mesh=checkpoint_loader.mesh, + weights_init_fn=weights_init_fn, + model_params=params["flux_vae"], + training=False, + ) + + vae_state = {"params": vae_state.params} + flux_state = {"params": transformer_state.params} + + t0 = time.perf_counter() + with ExitStack(): + imgs = pipeline(flux_params=flux_state, timesteps=50, vae_params=vae_state).block_until_ready() + t1 = time.perf_counter() + max_logging.log(f"Compile time: {t1 - t0:.1f}s.") + + t0 = time.perf_counter() + with ExitStack(): + imgs = pipeline(flux_params=flux_state, timesteps=50, vae_params=vae_state).block_until_ready() + imgs = jax.experimental.multihost_utils.process_allgather(imgs, tiled=True) + t1 = time.perf_counter() + max_logging.log(f"Inference time: {t1 - t0:.1f}s.") + imgs = np.array(imgs) + imgs = (imgs * 0.5 + 0.5).clip(0, 1) + imgs = np.transpose(imgs, (0, 2, 3, 1)) + imgs = np.uint8(imgs * 255) + for i, image in enumerate(imgs): + Image.fromarray(image).save(f"flux_{i}.png") + + return imgs + + +def main(argv: Sequence[str]) -> None: + pyconfig.initialize(argv) + run(pyconfig.config) + + +if __name__ == "__main__": + app.run(main) diff --git a/src/maxdiffusion/maxdiffusion_utils.py 
b/src/maxdiffusion/maxdiffusion_utils.py index 05f9802a4..43400a62e 100644 --- a/src/maxdiffusion/maxdiffusion_utils.py +++ b/src/maxdiffusion/maxdiffusion_utils.py @@ -256,6 +256,64 @@ def calculate_unet_tflops(config, pipeline, batch_size, rngs, train): ) +def get_dummy_flux_inputs(config, pipeline, batch_size): + """Returns randomly initialized flux inputs.""" + latents, latents_ids = pipeline.prepare_latents( + batch_size=batch_size, + num_channels_latents=pipeline.flux.in_channels // 4, + height=config.resolution, + width=config.resolution, + vae_scale_factor=pipeline.vae_scale_factor, + dtype=config.activations_dtype, + rng=pipeline.rng, + ) + guidance_vec = jnp.asarray([config.guidance_scale] * batch_size, dtype=config.activations_dtype) + + timesteps = jnp.ones((batch_size,), dtype=config.weights_dtype) + t5_hidden_states_shape = ( + batch_size, + config.max_sequence_length, + 4096, + ) + t5_hidden_states = jnp.zeros(t5_hidden_states_shape, dtype=config.weights_dtype) + t5_ids = jnp.zeros((batch_size, t5_hidden_states.shape[1], 3), dtype=config.weights_dtype) + + clip_hidden_states_shape = ( + batch_size, + 768, + ) + clip_hidden_states = jnp.zeros(clip_hidden_states_shape, dtype=config.weights_dtype) + + return (latents, timesteps, latents_ids, guidance_vec, t5_hidden_states, t5_ids, clip_hidden_states) + + +def calculate_flux_tflops(config, pipeline, batch_size, rngs, train): + """ + Calculates jflux tflops. + batch_size should be per_device_batch_size * jax.local_device_count() or attention's shard_map won't + cache the compilation when flash is enabled. + """ + + (latents, timesteps, latents_ids, guidance_vec, t5_hidden_states, t5_ids, clip_hidden_states) = get_dummy_flux_inputs( + config, pipeline, batch_size + ) + return ( + max_utils.calculate_model_tflops( + pipeline.flux, + rngs, + train, + hidden_states=latents, + img_ids=latents_ids, + encoder_hidden_states=t5_hidden_states, + txt_ids=t5_ids, + pooled_projections=clip_hidden_states, + timestep=timesteps, + guidance=guidance_vec, + ) + / jax.local_device_count() + ) + + def tokenize_captions(examples, caption_column, tokenizer, input_ids_key="input_ids", p_encode=None): """Tokenize captions for sd1.x,sd2.x models.""" captions = list(examples[caption_column]) diff --git a/src/maxdiffusion/pipelines/flux/__init__.py b/src/maxdiffusion/pipelines/flux/__init__.py new file mode 100644 index 000000000..fcc4854fb --- /dev/null +++ b/src/maxdiffusion/pipelines/flux/__init__.py @@ -0,0 +1,5 @@ +_import_structure = {"pipeline_jflux": "JfluxPipeline"} + +from .flux_pipeline import ( + FluxPipeline, +) diff --git a/src/maxdiffusion/pipelines/flux/flux_pipeline.py b/src/maxdiffusion/pipelines/flux/flux_pipeline.py new file mode 100644 index 000000000..e655b491d --- /dev/null +++ b/src/maxdiffusion/pipelines/flux/flux_pipeline.py @@ -0,0 +1,408 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
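+
+# Overview (comment added for clarity): this pipeline packs 2x2 latent patches into a
+# token sequence, runs the flux transformer over a decreasing timestep schedule inside
+# jax.lax.fori_loop, and unpacks/decodes the result through the VAE.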
+
+from functools import partial
+from typing import List, Optional, Union, Callable
+
+import jax
+import jax.numpy as jnp
+import math
+from transformers import (CLIPTokenizer, FlaxCLIPTextModel, FlaxT5EncoderModel, AutoTokenizer)
+from einops import rearrange
+from jax.sharding import Mesh
+from jax.typing import DTypeLike
+from chex import Array
+
+from flax.linen import partitioning as nn_partitioning
+
+from maxdiffusion.utils import logging
+
+from ...models import FlaxAutoencoderKL
+from ...schedulers import (FlaxEulerDiscreteScheduler)
+from ..pipeline_flax_utils import FlaxDiffusionPipeline
+from maxdiffusion.models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+# Set to True to use python for loop instead of jax.fori_loop for easier debugging
+DEBUG = False
+
+
+class FluxPipeline(FlaxDiffusionPipeline):
+
+  def __init__(
+      self,
+      t5_encoder: FlaxT5EncoderModel,
+      clip_encoder: FlaxCLIPTextModel,
+      vae: FlaxAutoencoderKL,
+      t5_tokenizer: AutoTokenizer,
+      clip_tokenizer: CLIPTokenizer,
+      flux: FluxTransformer2DModel,
+      scheduler: FlaxEulerDiscreteScheduler,
+      dtype: jnp.dtype = jnp.float32,
+      mesh: Optional[Mesh] = None,
+      config=None,
+      rng: Optional[Array] = None,
+  ):
+    super().__init__()
+    self.dtype = dtype
+    self.register_modules(
+        vae=vae,
+        t5_encoder=t5_encoder,
+        clip_encoder=clip_encoder,
+        t5_tokenizer=t5_tokenizer,
+        clip_tokenizer=clip_tokenizer,
+        flux=flux,
+        scheduler=scheduler,
+    )
+    self.vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
+    self.mesh = mesh
+    self._config = config
+    self.rng = rng
+
+  def create_noise(
+      self,
+      num_samples: int,
+      height: int,
+      width: int,
+      dtype: DTypeLike,
+      seed: jax.random.PRNGKey,
+  ):
+    return jax.random.normal(
+        key=seed,
+        shape=(num_samples, 16, 2 * math.ceil(height / 16), 2 * math.ceil(width / 16)),
+        dtype=dtype,
+    )
+
+  def unpack(self, x: Array, height: int, width: int) -> Array:
+    return rearrange(
+        x,
+        "b (h w) (c ph pw) -> b c (h ph) (w pw)",
+        h=math.ceil(height / 16),
+        w=math.ceil(width / 16),
+        ph=2,
+        pw=2,
+    )
+
+  def vae_decode(self, latents, vae, state, config):
+    img = self.unpack(x=latents, height=config.resolution, width=config.resolution)
+    img = img / vae.config.scaling_factor + vae.config.shift_factor
+    img = vae.apply({"params": state["params"]}, img, deterministic=True, method=vae.decode).sample
+    return img
+
+  def vae_encode(self, latents, vae, state):
+    img = vae.apply({"params": state["params"]}, latents, deterministic=True, method=vae.encode).latent_dist.mode()
+    img = vae.config.scaling_factor * (img - vae.config.shift_factor)
+    return img
+
+  # this is the reverse of the unpack function
+  def pack_latents(
+      self,
+      latents: Array,
+      batch_size: int,
+      num_channels_latents: int,
+      height: int,
+      width: int,
+  ):
+    latents = jnp.reshape(latents, (batch_size, num_channels_latents, height // 2, 2, width // 2, 2))
+    latents = jnp.permute_dims(latents, (0, 2, 4, 1, 3, 5))
+    latents = jnp.reshape(latents, (batch_size, (height // 2) * (width // 2), num_channels_latents * 4))
+
+    return latents
+
+  def prepare_latents(
+      self,
+      batch_size: int,
+      num_channels_latents: int,
+      height: int,
+      width: int,
+      vae_scale_factor: int,
+      dtype: jnp.dtype,
+      rng: Array,
+  ):
+
+    # VAE applies 8x compression on images but we must also account for packing which
+    # requires latent height and width to be divisible by 2.
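+    # e.g. resolution=512, vae_scale_factor=8: height = width = 2 * (512 // 16) = 64, so the
+    # packed sequence has (64 // 2) * (64 // 2) = 1024 tokens of num_channels_latents * 4 features.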
+    height = 2 * (height // (vae_scale_factor * 2))
+    width = 2 * (width // (vae_scale_factor * 2))
+
+    shape = (batch_size, num_channels_latents, height, width)
+
+    latents = jax.random.normal(rng, shape=shape, dtype=dtype)
+    # pack latents
+    latents = self.pack_latents(latents, batch_size, num_channels_latents, height, width)
+
+    latent_image_ids = self.prepare_latent_image_ids(height // 2, width // 2)
+    latent_image_ids = jnp.tile(latent_image_ids, (batch_size, 1, 1))
+
+    return latents, latent_image_ids
+
+  def prepare_latent_image_ids(self, height, width):
+    latent_image_ids = jnp.zeros((height, width, 3))
+    latent_image_ids = latent_image_ids.at[..., 1].set(jnp.arange(height)[:, None])
+    latent_image_ids = latent_image_ids.at[..., 2].set(jnp.arange(width)[None, :])
+
+    latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
+
+    latent_image_ids = latent_image_ids.reshape(latent_image_id_height * latent_image_id_width, latent_image_id_channels)
+
+    return latent_image_ids.astype(jnp.bfloat16)
+
+  def get_clip_prompt_embeds(
+      self,
+      prompt: Union[str, List[str]],
+      num_images_per_prompt: int,
+      tokenizer: CLIPTokenizer,
+      text_encoder: FlaxCLIPTextModel,
+  ):
+    prompt = [prompt] if isinstance(prompt, str) else prompt
+    text_inputs = tokenizer(
+        prompt,
+        padding="max_length",
+        max_length=tokenizer.model_max_length,
+        truncation=True,
+        return_overflowing_tokens=False,
+        return_length=False,
+        return_tensors="np",
+    )
+
+    text_input_ids = text_inputs.input_ids
+
+    prompt_embeds = text_encoder(text_input_ids, params=text_encoder.params, train=False)
+    prompt_embeds = prompt_embeds.pooler_output
+    prompt_embeds = jnp.tile(prompt_embeds, (num_images_per_prompt, 1))
+    return prompt_embeds
+
+  def get_t5_prompt_embeds(
+      self,
+      prompt: Union[str, List[str]],
+      num_images_per_prompt: int,
+      tokenizer: AutoTokenizer,
+      text_encoder: FlaxT5EncoderModel,
+      max_sequence_length: int = 512,
+      encode_in_batches=False,
+      encode_batch_size=None,
+  ):
+
+    prompt = [prompt] if isinstance(prompt, str) else prompt
+    batch_size = len(prompt)
+
+    text_inputs = tokenizer(
+        prompt,
+        truncation=True,
+        max_length=max_sequence_length,
+        return_length=False,
+        return_overflowing_tokens=False,
+        padding="max_length",
+        return_tensors="np",
+    )
+    text_input_ids = text_inputs.input_ids
+    if encode_in_batches:
+      prompt_embeds = None
+      for i in range(0, text_input_ids.shape[0], encode_batch_size):
+        batch_prompt_embeds = text_encoder(
+            text_input_ids[i : i + encode_batch_size], attention_mask=None, output_hidden_states=False
+        )["last_hidden_state"]
+        if prompt_embeds is None:
+          prompt_embeds = batch_prompt_embeds
+        else:
+          prompt_embeds = jnp.concatenate([prompt_embeds, batch_prompt_embeds])
+    else:
+      prompt_embeds = text_encoder(text_input_ids, attention_mask=None, output_hidden_states=False)["last_hidden_state"]
+    _, seq_len, _ = prompt_embeds.shape
+    # duplicate text embeddings for each generation per prompt
+    prompt_embeds = jnp.tile(prompt_embeds, (1, num_images_per_prompt, 1))
+    prompt_embeds = jnp.reshape(prompt_embeds, (batch_size * num_images_per_prompt, seq_len, -1))
+
+    dtype = text_encoder.dtype
+    prompt_embeds = prompt_embeds.astype(dtype)
+
+    return prompt_embeds
+
+  def encode_prompt(
+      self,
+      prompt: Union[str, List[str]],
+      prompt_2: Union[str, List[str]],
+      clip_tokenizer: CLIPTokenizer,
+      clip_text_encoder: FlaxCLIPTextModel,
+      t5_tokenizer: AutoTokenizer,
+      t5_text_encoder: FlaxT5EncoderModel,
+      num_images_per_prompt: int = 1,
+      max_sequence_length: int = 512,
+      encode_in_batches: bool = False,
+      encode_batch_size: Optional[int] = None,
+  ):
+
+    if encode_in_batches:
+      assert encode_batch_size is not None, "encode_batch_size must be set when encode_in_batches=True"
+
+    prompt = [prompt] if isinstance(prompt, str) else prompt
+    prompt_2 = prompt_2 or prompt
+    prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+    pooled_prompt_embeds = self.get_clip_prompt_embeds(
+        prompt=prompt, num_images_per_prompt=num_images_per_prompt, tokenizer=clip_tokenizer, text_encoder=clip_text_encoder
+    )
+
+    prompt_embeds = self.get_t5_prompt_embeds(
+        prompt=prompt_2,
+        num_images_per_prompt=num_images_per_prompt,
+        tokenizer=t5_tokenizer,
+        text_encoder=t5_text_encoder,
+        max_sequence_length=max_sequence_length,
+        encode_in_batches=encode_in_batches,
+        encode_batch_size=encode_batch_size,
+    )
+
+    text_ids = jnp.zeros((prompt_embeds.shape[0], prompt_embeds.shape[1], 3)).astype(jnp.bfloat16)
+    return prompt_embeds, pooled_prompt_embeds, text_ids
+
+  def _generate(
+      self, flux_params, vae_params, latents, latent_image_ids, prompt_embeds, txt_ids, vec, guidance_vec, c_ts, p_ts
+  ):
+
+    def loop_body(
+        step,
+        args,
+        transformer,
+        latent_image_ids,
+        prompt_embeds,
+        txt_ids,
+        vec,
+        guidance_vec,
+    ):
+      latents, state, c_ts, p_ts = args
+      latents_dtype = latents.dtype
+      t_curr = c_ts[step]
+      t_prev = p_ts[step]
+      t_vec = jnp.full((latents.shape[0],), t_curr, dtype=latents.dtype)
+      pred = transformer.apply(
+          {"params": state["params"]},
+          hidden_states=latents,
+          img_ids=latent_image_ids,
+          encoder_hidden_states=prompt_embeds,
+          txt_ids=txt_ids,
+          timestep=t_vec,
+          guidance=guidance_vec,
+          pooled_projections=vec,
+      ).sample
+      latents = latents + (t_prev - t_curr) * pred
+      latents = jnp.array(latents, dtype=latents_dtype)
+      return latents, state, c_ts, p_ts
+
+    loop_body_p = partial(
+        loop_body,
+        transformer=self.flux,
+        latent_image_ids=latent_image_ids,
+        prompt_embeds=prompt_embeds,
+        txt_ids=txt_ids,
+        vec=vec,
+        guidance_vec=guidance_vec,
+    )
+
+    vae_decode_p = partial(self.vae_decode, vae=self.vae, state=vae_params, config=self._config)
+
+    with self.mesh, nn_partitioning.axis_rules(self._config.logical_axis_rules):
+      latents, _, _, _ = jax.lax.fori_loop(0, len(c_ts), loop_body_p, (latents, flux_params, c_ts, p_ts))
+      image = vae_decode_p(latents)
+    return image
+
+  def do_time_shift(self, mu: float, sigma: float, t: Array):
+    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
+
+  def get_lin_function(
+      self, x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
+  ) -> Callable[[float], float]:
+    m = (y2 - y1) / (x2 - x1)
+    b = y1 - m * x1
+    return lambda x: m * x + b
+
+  def time_shift(self, latents, timesteps):
+    # estimate mu based on linear estimation between two points
+    lin_function = self.get_lin_function(
+        x1=self._config.max_sequence_length, y1=self._config.base_shift, y2=self._config.max_shift
+    )
+    mu = lin_function(latents.shape[1])
+    timesteps = self.do_time_shift(mu, 1.0, timesteps)
+    return timesteps
+
+  def __call__(self, timesteps: int, flux_params, vae_params):
+    r"""
+    The call function to the pipeline for generation.
+
+    Args:
+        timesteps (`int` or `jnp.ndarray`):
+            the number of denoising steps; an explicit monotonically decreasing
+            schedule in [0, 1] may also be passed.
+        flux_params: pytree holding the flux transformer parameters under `"params"`.
+        vae_params: pytree holding the VAE parameters under `"params"`.
+
+    Returns:
+        `jnp.ndarray`: the decoded images.
+    """
+
+    if isinstance(timesteps, int):
+      timesteps = jnp.linspace(1, 0, timesteps + 1)
+
+    global_batch_size = 1 * jax.local_device_count()
+
+    prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(
+        prompt=self._config.prompt,
+        prompt_2=self._config.prompt_2,
+        clip_tokenizer=self.clip_tokenizer,
+        clip_text_encoder=self.clip_encoder,
+        t5_tokenizer=self.t5_tokenizer,
+        t5_text_encoder=self.t5_encoder,
+        num_images_per_prompt=global_batch_size,
+        max_sequence_length=self._config.max_sequence_length,
+    )
+
+    num_channels_latents = self.flux.in_channels // 4
+    latents, latent_image_ids = self.prepare_latents(
+        batch_size=global_batch_size,
+        num_channels_latents=num_channels_latents,
+        height=self._config.resolution,
+        width=self._config.resolution,
+        dtype=jnp.bfloat16,
+        vae_scale_factor=self.vae_scale_factor,
+        rng=self.rng,
+    )
+
+    if self._config.time_shift:
+      timesteps = self.time_shift(latents, timesteps)
+    c_ts = timesteps[:-1]
+    p_ts = timesteps[1:]
+
+    guidance = jnp.asarray([self._config.guidance_scale] * global_batch_size, dtype=jnp.bfloat16)
+
+    images = self._generate(
+        flux_params,
+        vae_params,
+        latents,
+        latent_image_ids,
+        prompt_embeds,
+        text_ids,
+        pooled_prompt_embeds,
+        guidance,
+        c_ts,
+        p_ts,
+    )
+
+    return images
diff --git a/src/maxdiffusion/schedulers/scheduling_euler_discrete_flax.py b/src/maxdiffusion/schedulers/scheduling_euler_discrete_flax.py
index bc9013cbf..863fa26cd 100644
--- a/src/maxdiffusion/schedulers/scheduling_euler_discrete_flax.py
+++ b/src/maxdiffusion/schedulers/scheduling_euler_discrete_flax.py
@@ -144,7 +144,7 @@ def scale_model_input(self, state: EulerDiscreteSchedulerState, sample: jnp.ndar
     return sample
 
   def set_timesteps(
-      self, state: EulerDiscreteSchedulerState, num_inference_steps: int, shape: Tuple = ()
+      self, state: EulerDiscreteSchedulerState, num_inference_steps: int, shape: Tuple = (), timestep_spacing: str = ""
   ) -> EulerDiscreteSchedulerState:
     """
     Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -155,17 +155,21 @@ def set_timesteps(
         num_inference_steps (`int`):
             the number of diffusion steps used when generating samples with a pre-trained model.
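+        timestep_spacing (`str`, *optional*, defaults to `""`):
+            when non-empty, overrides `self.config.timestep_spacing`; the added `"flux"` option
+            produces `jnp.linspace(1, 0, num_inference_steps + 1)`.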
""" + if timestep_spacing == "": + timestep_spacing = self.config.timestep_spacing - if self.config.timestep_spacing == "linspace": + if timestep_spacing == "linspace": timesteps = jnp.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=self.dtype) - elif self.config.timestep_spacing == "leading": + elif timestep_spacing == "leading": step_ratio = self.config.num_train_timesteps // num_inference_steps timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float) timesteps += 1 - elif self.config.timestep_spacing == "trailing": + elif timestep_spacing == "trailing": step_ratio = self.config.num_train_timesteps / num_inference_steps timesteps = (jnp.arange(self.config.num_train_timesteps, 0, -step_ratio)).round() timesteps -= 1 + elif timestep_spacing == "flux": + timesteps = jnp.linspace(1, 0, num_inference_steps + 1) else: raise ValueError( f"timestep_spacing must be one of ['linspace', 'leading', 'trailing'], got {self.config.timestep_spacing}" @@ -250,7 +254,15 @@ def add_noise( original_samples: jnp.ndarray, noise: jnp.ndarray, timesteps: jnp.ndarray, + flux: bool = False, ) -> jnp.ndarray: + + if flux: + t = state.timesteps[timesteps] + t = t[:, None, None] + noisy_samples = t * noise + (1 - t) * original_samples + return noisy_samples + sigma = state.sigmas[timesteps].flatten() sigma = broadcast_to_shape_from_left(sigma, noise.shape) diff --git a/src/maxdiffusion/train_flux.py b/src/maxdiffusion/train_flux.py new file mode 100644 index 000000000..e3b161039 --- /dev/null +++ b/src/maxdiffusion/train_flux.py @@ -0,0 +1,49 @@ +""" + Copyright 2024 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ """ + +from typing import Sequence + +import jax +from absl import app +from maxdiffusion import ( + max_logging, + pyconfig, + mllog_utils, +) + +from maxdiffusion.train_utils import ( + validate_train_config, +) + + +def train(config): + from maxdiffusion.trainers.flux_trainer import FluxTrainer + + trainer = FluxTrainer(config) + trainer.start_training() + + +def main(argv: Sequence[str]) -> None: + pyconfig.initialize(argv) + config = pyconfig.config + mllog_utils.train_init_start(config) + validate_train_config(config) + max_logging.log(f"Found {jax.device_count()} devices.") + train(config) + + +if __name__ == "__main__": + app.run(main) diff --git a/src/maxdiffusion/train_utils.py b/src/maxdiffusion/train_utils.py index 648c10d7e..1337f2329 100644 --- a/src/maxdiffusion/train_utils.py +++ b/src/maxdiffusion/train_utils.py @@ -111,9 +111,12 @@ def write_metrics_to_tensorboard(writer, metrics, step, config): full_log = step % config.log_period == 0 if jax.process_index() == 0: max_logging.log( - f"completed step: {step}, seconds: {metrics['scalar']['perf/step_time_seconds']:.3f}, " - f"TFLOP/s/device: {metrics['scalar']['perf/per_device_tflops_per_sec']:.3f}, " - f"loss: {metrics['scalar']['learning/loss']:.3f}" + "completed step: {}, seconds: {:.3f}, TFLOP/s/device: {:.3f}, loss: {:.3f}".format( + step, + metrics["scalar"]["perf/step_time_seconds"], + metrics["scalar"]["perf/per_device_tflops_per_sec"], + float(metrics["scalar"]["learning/loss"]), + ) ) if full_log and jax.process_index() == 0: diff --git a/src/maxdiffusion/trainers/flux_trainer.py b/src/maxdiffusion/trainers/flux_trainer.py new file mode 100644 index 000000000..ed29fe915 --- /dev/null +++ b/src/maxdiffusion/trainers/flux_trainer.py @@ -0,0 +1,431 @@ +""" + Copyright 2024 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ """ + +import os +from functools import partial +import datetime +import time +import numpy as np +import jax +import jax.numpy as jnp +from jax.sharding import PositionalSharding, PartitionSpec as P +from flax.linen import partitioning as nn_partitioning +from maxdiffusion.checkpointing.flux_checkpointer import ( + FluxCheckpointer, + FLUX_CHECKPOINT, + FLUX_STATE_KEY, + FLUX_STATE_SHARDINGS_KEY, + VAE_STATE_KEY, + VAE_STATE_SHARDINGS_KEY, +) + +from maxdiffusion.input_pipeline.input_pipeline_interface import (make_data_iterator) + +from maxdiffusion import (max_utils, max_logging) + +from maxdiffusion.train_utils import ( + get_first_step, + load_next_batch, + record_scalar_metrics, + write_metrics, +) + +from maxdiffusion.maxdiffusion_utils import calculate_flux_tflops + +from ..schedulers import (FlaxEulerDiscreteScheduler) + + +class FluxTrainer(FluxCheckpointer): + + def __init__(self, config): + FluxCheckpointer.__init__(self, config, FLUX_CHECKPOINT) + + self.text_encoder_2_learning_rate_scheduler = None + + if config.train_text_encoder: + raise ValueError("this script currently doesn't support training text_encoders") + + def post_training_steps(self, pipeline, params, train_states, msg=""): + pass + + def create_scheduler(self, pipeline, params): + noise_scheduler, noise_scheduler_state = FlaxEulerDiscreteScheduler.from_pretrained( + pretrained_model_name_or_path=self.config.pretrained_model_name_or_path, subfolder="scheduler", dtype=jnp.float32 + ) + noise_scheduler_state = noise_scheduler.set_timesteps( + state=noise_scheduler_state, num_inference_steps=self.config.num_inference_steps, timestep_spacing="flux" + ) + return noise_scheduler, noise_scheduler_state + + def calculate_tflops(self, pipeline): + per_device_tflops = calculate_flux_tflops(self.config, pipeline, self.total_train_batch_size, self.rng, train=True) + max_logging.log(f"JFLUX per device TFLOPS: {per_device_tflops}") + return per_device_tflops + + def start_training(self): + + # Hook + # self.pre_training_steps() + # Load checkpoint - will load or create states + pipeline, params = self.load_checkpoint() + + # create train states + train_states = {} + state_shardings = {} + + # move params to accelerator + encoders_sharding = PositionalSharding(self.devices_array).replicate() + partial_device_put_replicated = partial(max_utils.device_put_replicated, sharding=encoders_sharding) + pipeline.clip_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), pipeline.clip_encoder.params) + pipeline.clip_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, pipeline.clip_encoder.params) + pipeline.t5_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), pipeline.t5_encoder.params) + pipeline.t5_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, pipeline.t5_encoder.params) + + vae_state, vae_state_mesh_shardings = self.create_vae_state( + pipeline=pipeline, params=params, checkpoint_item_name=VAE_STATE_KEY, is_training=False + ) + train_states[VAE_STATE_KEY] = vae_state + state_shardings[VAE_STATE_SHARDINGS_KEY] = vae_state_mesh_shardings + + # Load dataset + data_iterator = self.load_dataset(pipeline, params, train_states) + if self.config.dataset_type == "grain": + data_iterator = self.restore_data_iterator_state(data_iterator) + + # don't need this anymore, clear some memory. 
+    del pipeline.t5_encoder
+
+    # evaluate shapes
+
+    flux_state, flux_state_mesh_shardings, flux_learning_rate_scheduler = self.create_flux_state(
+        # If params is None, one of two scenarios applies:
+        # 1. the flux state will be loaded directly from orbax, or
+        # 2. a new flux model is being trained from scratch.
+        pipeline=pipeline,
+        params=None,  # Params are loaded inside create_flux_state
+        checkpoint_item_name=FLUX_STATE_KEY,
+        is_training=True,
+    )
+    flux_state = jax.device_put(flux_state, flux_state_mesh_shardings)
+    train_states[FLUX_STATE_KEY] = flux_state
+    state_shardings[FLUX_STATE_SHARDINGS_KEY] = flux_state_mesh_shardings
+    # self.post_training_steps(pipeline, params, train_states, msg="before_training")
+
+    # Create scheduler
+    noise_scheduler, noise_scheduler_state = self.create_scheduler(pipeline, params)
+    pipeline.scheduler = noise_scheduler
+    train_states["scheduler"] = noise_scheduler_state
+
+    # Calculate tflops
+    per_device_tflops = self.calculate_tflops(pipeline)
+    self.per_device_tflops = per_device_tflops
+
+    data_shardings = self.get_data_shardings()
+    # Compile train_step
+    p_train_step = self.compile_train_step(pipeline, params, train_states, state_shardings, data_shardings)
+    # Start training
+    train_states = self.training_loop(
+        p_train_step, pipeline, params, train_states, data_iterator, flux_learning_rate_scheduler
+    )
+    # Save final checkpoint
+    # Hook
+    self.post_training_steps(pipeline, params, train_states, "after_training")
+
+  def get_shaped_batch(self, config, pipeline=None):
+    """Return the shape of the batch - this is what eval_shape would return for the
+    output of create_data_iterator_with_tokenizer, but eval_shape doesn't work, see b/306901078.
+    """
+
+    scale_factor = 16  # hardcoded in jflux.get_noise
+    h = config.resolution // scale_factor
+    w = config.resolution // scale_factor
+    c = 16
+    ph = pw = 2
+    batch_image_shape = (self.total_train_batch_size, h * w, c * ph * pw)
+    img_ids_shape = (self.total_train_batch_size, (2 * h // 2) * (2 * w // 2), 3)
+    text_shape = (
+        self.total_train_batch_size,
+        config.max_sequence_length,
+        4096,  # T5-XXL hidden size; ideally read from the text encoder config.
+    )
+    text_ids_shape = (
+        self.total_train_batch_size,
+        config.max_sequence_length,
+        3,
+    )
+    prompt_embeds_shape = (
+        self.total_train_batch_size,
+        768,  # CLIP pooled-output width; ideally read from the CLIP encoder config.
+ ) + input_ids_dtype = self.config.activations_dtype + + shaped_batch = {} + shaped_batch["pixel_values"] = jax.ShapeDtypeStruct(batch_image_shape, input_ids_dtype) + shaped_batch["text_embeds"] = jax.ShapeDtypeStruct(text_shape, input_ids_dtype) + shaped_batch["input_ids"] = jax.ShapeDtypeStruct(text_ids_shape, input_ids_dtype) + shaped_batch["prompt_embeds"] = jax.ShapeDtypeStruct(prompt_embeds_shape, input_ids_dtype) + shaped_batch["img_ids"] = jax.ShapeDtypeStruct(img_ids_shape, input_ids_dtype) + return shaped_batch + + def get_data_shardings(self): + data_sharding = jax.sharding.NamedSharding(self.mesh, P(*self.config.data_sharding)) + data_sharding = { + "text_embeds": data_sharding, + "input_ids": data_sharding, + "prompt_embeds": data_sharding, + "pixel_values": data_sharding, + "img_ids": data_sharding, + } + + return data_sharding + + # adapted from max_utils.tokenize_captions_xl + @staticmethod + def tokenize_captions(examples, caption_column, encoder): + prompt = list(examples[caption_column]) + + prompt_embeds, pooled_prompt_embeds, text_ids = encoder(prompt, prompt) + + examples["text_embeds"] = jnp.float16(prompt_embeds) + examples["input_ids"] = jnp.float16(text_ids) + examples["prompt_embeds"] = jnp.float16(pooled_prompt_embeds) + + return examples + + @staticmethod + def transform_images(examples, image_column, image_resolution, vae_encode, pack_latents, prepare_latent_imgage_ids): + """Preprocess images to latents.""" + images = list(examples[image_column]) + + images = [ + jax.image.resize( + jnp.asarray(image) / 127.5 - 1.0, [image_resolution, image_resolution, 3], method="bilinear", antialias=True + ) + for image in images + ] + + images = jnp.stack(images, axis=0, dtype=jnp.float16) + batch_size = 8 + num_batches = len(images) // batch_size + int(len(images) % batch_size != 0) + encoded_images = [] + for i in range(num_batches): + batch_images = images[i * batch_size : (i + 1) * batch_size] + batch_images = jnp.transpose(batch_images, (0, 3, 1, 2)) + batch_images = vae_encode(batch_images) + batch_images = jnp.transpose(batch_images, (0, 3, 1, 2)) + encoded_images.append(batch_images) + + images = jnp.concatenate(encoded_images, axis=0, dtype=jnp.float16) + b, c, h, w = images.shape + images = pack_latents(latents=images, batch_size=b, num_channels_latents=c, height=h, width=w) + + img_ids = prepare_latent_imgage_ids(h // 2, w // 2) + img_ids = jnp.tile(img_ids, (b, 1, 1)) + + examples["pixel_values"] = jnp.float16(images) + examples["img_ids"] = jnp.float16(img_ids) + + return examples + + def load_dataset(self, pipeline, params, train_states): + config = self.config + total_train_batch_size = self.total_train_batch_size + mesh = self.mesh + + encode_fn = partial( + pipeline.encode_prompt, + clip_tokenizer=pipeline.clip_tokenizer, + t5_tokenizer=pipeline.t5_tokenizer, + clip_text_encoder=pipeline.clip_encoder, + t5_text_encoder=pipeline.t5_encoder, + encode_in_batches=True, + encode_batch_size=16, + ) + pack_latents_p = partial(pipeline.pack_latents) + prepare_latent_image_ids_p = partial(pipeline.prepare_latent_image_ids) + vae_encode_p = partial(pipeline.vae_encode, vae=pipeline.vae, state=train_states["vae_state"]) + + tokenize_fn = partial(FluxTrainer.tokenize_captions, caption_column=config.caption_column, encoder=encode_fn) + image_transforms_fn = partial( + FluxTrainer.transform_images, + image_column=config.image_column, + image_resolution=config.resolution, + vae_encode=vae_encode_p, + pack_latents=pack_latents_p, + 
+        prepare_latent_image_ids=prepare_latent_image_ids_p,
+    )
+
+    data_iterator = make_data_iterator(
+        config,
+        jax.process_index(),
+        jax.process_count(),
+        mesh,
+        total_train_batch_size,
+        tokenize_fn=tokenize_fn,
+        image_transforms_fn=image_transforms_fn,
+    )
+
+    return data_iterator
+
+  def compile_train_step(self, pipeline, params, train_states, state_shardings, data_shardings):
+    self.rng, train_rngs = jax.random.split(self.rng)
+    guidance_vec = jnp.full((self.total_train_batch_size,), self.config.guidance_scale, dtype=self.config.activations_dtype)
+    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      p_train_step = jax.jit(
+          partial(
+              _train_step,
+              guidance_vec=guidance_vec,
+              pipeline=pipeline,
+              scheduler=train_states["scheduler"],
+              config=self.config,
+          ),
+          in_shardings=(
+              state_shardings["flux_state_shardings"],
+              data_shardings,
+              None,
+          ),
+          out_shardings=(state_shardings["flux_state_shardings"], None, None),
+          donate_argnums=(0,),
+      )
+      max_logging.log("Precompiling...")
+      s = time.time()
+      dummy_batch = self.get_shaped_batch(self.config, pipeline)
+      p_train_step = p_train_step.lower(train_states[FLUX_STATE_KEY], dummy_batch, train_rngs)
+      p_train_step = p_train_step.compile()
+      max_logging.log(f"Compile time: {time.time() - s}")
+      return p_train_step
+
+  def training_loop(self, p_train_step, pipeline, params, train_states, data_iterator, flux_learning_rate_scheduler):
+
+    writer = max_utils.initialize_summary_writer(self.config)
+    flux_state = train_states[FLUX_STATE_KEY]
+    num_model_parameters = max_utils.calculate_num_params_from_pytree(flux_state.params)
+
+    max_utils.add_text_to_summary_writer("number_model_parameters", str(num_model_parameters), writer)
+    max_utils.add_text_to_summary_writer("libtpu_init_args", os.environ.get("LIBTPU_INIT_ARGS", ""), writer)
+    max_utils.add_config_to_summary_writer(self.config, writer)
+
+    if jax.process_index() == 0:
+      max_logging.log("***** Running training *****")
+      max_logging.log(f"  Instantaneous batch size per device = {self.config.per_device_batch_size}")
+      max_logging.log(f"  Total train batch size (w. parallel & distributed) = {self.total_train_batch_size}")
+      max_logging.log(f"  Total optimization steps = {self.config.max_train_steps}")
+
+    last_step_completion = datetime.datetime.now()
+    local_metrics_file = open(self.config.metrics_file, "a", encoding="utf8") if self.config.metrics_file else None
+    running_gcs_metrics = [] if self.config.gcs_metrics else None
+    example_batch = None
+
+    first_profiling_step = self.config.skip_first_n_steps_for_profiler
+    if self.config.enable_profiler and first_profiling_step >= self.config.max_train_steps:
+      raise ValueError("Profiling requested but initial profiling step set past training final step")
+    last_profiling_step = np.clip(
+        first_profiling_step + self.config.profiler_steps - 1, first_profiling_step, self.config.max_train_steps - 1
+    )
+    start_step = get_first_step(train_states[FLUX_STATE_KEY])
+    _, train_rngs = jax.random.split(self.rng)
+    times = []
+    for step in np.arange(start_step, self.config.max_train_steps):
+      if self.config.enable_profiler and step == first_profiling_step:
+        max_utils.activate_profiler(self.config)
+
+      example_batch = load_next_batch(data_iterator, example_batch, self.config)
+      example_batch = {key: jnp.asarray(value, dtype=self.config.activations_dtype) for key, value in example_batch.items()}
+
+      with jax.profiler.StepTraceAnnotation("train", step_num=step):
+        with self.mesh:
+          flux_state, train_metric, train_rngs = p_train_step(flux_state, example_batch, train_rngs)
+
+      samples_count = self.total_train_batch_size * (step + 1)
+      new_time = datetime.datetime.now()
+
+      record_scalar_metrics(
+          train_metric, new_time - last_step_completion, self.per_device_tflops, flux_learning_rate_scheduler(step)
+      )
+      if self.config.write_metrics:
+        write_metrics(writer, local_metrics_file, running_gcs_metrics, train_metric, step, self.config)
+      times.append(new_time - last_step_completion)
+      last_step_completion = new_time
+
+      if step != 0 and self.config.checkpoint_every != -1 and samples_count % self.config.checkpoint_every == 0:
+        max_logging.log(f"Saving checkpoint for step {step}")
+        train_states[FLUX_STATE_KEY] = flux_state
+        self.save_checkpoint(step, pipeline, train_states)
+
+      if self.config.enable_profiler and step == last_profiling_step:
+        max_utils.deactivate_profiler(self.config)
+
+    train_states[FLUX_STATE_KEY] = flux_state
+    if len(times) > 2:
+      max_logging.log(f"Average time per step: {sum(times[2:], datetime.timedelta(0)) / len(times[2:])}")
+    if self.config.save_final_checkpoint:
+      max_logging.log(f"Saving checkpoint for step {step}")
+      self.save_checkpoint(step, pipeline, train_states)
+      self.checkpoint_manager.wait_until_finished()
+    return train_states
+
+
+def _train_step(flux_state, batch, train_rng, guidance_vec, pipeline, scheduler, config):
+  _, gen_dummy_rng = jax.random.split(train_rng)
+  sample_rng, timestep_bias_rng, new_train_rng = jax.random.split(gen_dummy_rng, 3)
+  state_params = {FLUX_STATE_KEY: flux_state.params}
+
+  def compute_loss(state_params):
+    latents = batch["pixel_values"]
+    text_embeds_ids = batch["input_ids"]
+    text_embeds = batch["text_embeds"]
+    prompt_embeds = batch["prompt_embeds"]
+    img_ids = batch["img_ids"]
+
+    # Sample noise that we'll add to the latents
+    noise_rng, timestep_rng = jax.random.split(sample_rng)
+    noise = jax.random.normal(
+        key=noise_rng,
+        shape=latents.shape,
+        dtype=latents.dtype,
+    )
+    # Sample a random timestep for each image
+    bsz = latents.shape[0]
+    timesteps = jax.random.randint(timestep_rng, shape=(bsz,), minval=0, maxval=len(scheduler.timesteps) - 1)
+    noisy_latents = pipeline.scheduler.add_noise(scheduler, latents, noise, timesteps, flux=True)
+
+    model_pred = pipeline.flux.apply(
+        {"params": state_params[FLUX_STATE_KEY]},
+        hidden_states=noisy_latents,
+        img_ids=img_ids,
+        encoder_hidden_states=text_embeds,
+        txt_ids=text_embeds_ids,
+        timestep=scheduler.timesteps[timesteps],
+        guidance=guidance_vec,
+        pooled_projections=prompt_embeds,
+    ).sample
+
+    target = noise - latents
+    loss = (target - model_pred) ** 2
+
+    loss = jnp.mean(loss)
+
+    return loss
+
+  grad_fn = jax.value_and_grad(compute_loss)
+  loss, grad = grad_fn(state_params)
+
+  new_state = flux_state.apply_gradients(grads=grad[FLUX_STATE_KEY])
+
+  metrics = {"scalar": {"learning/loss": loss}, "scalars": {}}
+
+  return new_state, metrics, new_train_rng
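
A note on the objective implemented in `_train_step` above: with the `flux=True` branch of
`add_noise`, the noisy sample is the straight-line interpolation `x_t = t * noise + (1 - t) * latents`,
and the regression target `noise - latents` is exactly its time derivative. A minimal sketch of that
identity (illustrative names only, not part of the diff):

    import jax.numpy as jnp

    t = 0.7                          # interpolation time in [0, 1]
    x0 = jnp.array([1.0, -2.0])      # clean latents
    eps = jnp.array([0.3, 0.5])      # Gaussian noise
    xt = t * eps + (1 - t) * x0      # add_noise(..., flux=True)
    velocity = eps - x0              # target = noise - latents
    # d(xt)/dt equals `velocity` for every t, so a model trained to predict it learns the
    # straight-line flow used at sampling time: latents += (t_prev - t_curr) * pred
    # (see FluxPipeline._generate).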