From abc6f6bbb65c92c962599042112351d4e5171172 Mon Sep 17 00:00:00 2001 From: ksikiric Date: Tue, 11 Feb 2025 08:09:36 +0000 Subject: [PATCH 1/9] Added training code, loss and results are stable --- .../checkpointing/checkpointing_utils.py | 5 +- .../checkpointing/flux_checkpointer.py | 210 ++++++++ src/maxdiffusion/configs/base_flux_dev.yml | 19 +- .../configs/base_flux_schnell.yml | 4 +- src/maxdiffusion/maxdiffusion_utils.py | 55 +++ src/maxdiffusion/pipelines/flux/__init__.py | 5 + .../pipelines/flux/flux_pipeline.py | 383 +++++++++++++++ .../scheduling_euler_discrete_flax.py | 22 +- src/maxdiffusion/train_flux.py | 49 ++ src/maxdiffusion/train_utils.py | 9 +- src/maxdiffusion/trainers/flux_trainer.py | 449 ++++++++++++++++++ 11 files changed, 1190 insertions(+), 20 deletions(-) create mode 100644 src/maxdiffusion/checkpointing/flux_checkpointer.py create mode 100644 src/maxdiffusion/pipelines/flux/__init__.py create mode 100644 src/maxdiffusion/pipelines/flux/flux_pipeline.py create mode 100644 src/maxdiffusion/train_flux.py create mode 100644 src/maxdiffusion/trainers/flux_trainer.py diff --git a/src/maxdiffusion/checkpointing/checkpointing_utils.py b/src/maxdiffusion/checkpointing/checkpointing_utils.py index b8710e1a6..d771a15de 100644 --- a/src/maxdiffusion/checkpointing/checkpointing_utils.py +++ b/src/maxdiffusion/checkpointing/checkpointing_utils.py @@ -33,6 +33,7 @@ STABLE_DIFFUSION_CHECKPOINT = "STABLE_DIFFUSION_CHECKPOINT" STABLE_DIFFUSION_XL_CHECKPOINT = "STABLE_DIFUSSION_XL_CHECKPOINT" +FLUX_CHECKPOINT = "FLUX_CHECKPOINT" def create_orbax_checkpoint_manager( @@ -66,7 +67,7 @@ def create_orbax_checkpoint_manager( "text_encoder_state", "tokenizer_config", ) - if checkpoint_type == STABLE_DIFFUSION_XL_CHECKPOINT: + if checkpoint_type == STABLE_DIFFUSION_XL_CHECKPOINT or checkpoint_type == FLUX_CHECKPOINT: item_names += ( "text_encoder_2_state", "text_encoder_2_config", @@ -117,7 +118,7 @@ def load_stable_diffusion_configs( "tokenizer_config": 
orbax.checkpoint.args.JsonRestore(), } - if checkpoint_type == STABLE_DIFFUSION_XL_CHECKPOINT: + if checkpoint_type == STABLE_DIFFUSION_XL_CHECKPOINT or checkpoint_type == FLUX_CHECKPOINT: restore_args["text_encoder_2_config"] = orbax.checkpoint.args.JsonRestore() return (checkpoint_manager.restore(step, args=orbax.checkpoint.args.Composite(**restore_args)), None) diff --git a/src/maxdiffusion/checkpointing/flux_checkpointer.py b/src/maxdiffusion/checkpointing/flux_checkpointer.py new file mode 100644 index 000000000..4da642beb --- /dev/null +++ b/src/maxdiffusion/checkpointing/flux_checkpointer.py @@ -0,0 +1,210 @@ +""" + Copyright 2024 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ """ + +from abc import ABC +from contextlib import nullcontext +import os +import json +import functools +import jax +import jax.numpy as jnp +from jax.sharding import Mesh, PositionalSharding, PartitionSpec as P +import orbax.checkpoint as ocp +import grain.python as grain +from maxdiffusion import ( + max_utils, + FlaxAutoencoderKL, + max_logging, +) +from maxdiffusion.models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel +from ..pipelines.flux.flux_pipeline import FluxPipeline + +from transformers import (CLIPTokenizer, FlaxCLIPTextModel, T5EncoderModel, FlaxT5EncoderModel, AutoTokenizer) + +from maxdiffusion.checkpointing.checkpointing_utils import ( + create_orbax_checkpoint_manager, + load_stable_diffusion_configs, +) +from maxdiffusion.models.flux.util import load_flow_model + +FLUX_CHECKPOINT = "FLUX_CHECKPOINT" +_CHECKPOINT_FORMAT_DIFFUSERS = "CHECKPOINT_FORMAT_DIFFUSERS" +_CHECKPOINT_FORMAT_ORBAX = "CHECKPOINT_FORMAT_ORBAX" + + +class FluxCheckpointer(ABC): + + def __init__(self, config, checkpoint_type): + self.config = config + self.checkpoint_type = checkpoint_type + self.checkpoint_format = None + + self.rng = jax.random.PRNGKey(self.config.seed) + self.devices_array = max_utils.create_device_mesh(config) + self.mesh = Mesh(self.devices_array, self.config.mesh_axes) + self.total_train_batch_size = self.config.total_train_batch_size + + self.checkpoint_manager = create_orbax_checkpoint_manager( + self.config.checkpoint_dir, + enable_checkpointing=True, + save_interval_steps=1, + checkpoint_type=checkpoint_type, + dataset_type=config.dataset_type, + ) + + def _create_optimizer(self, config, learning_rate): + + learning_rate_scheduler = max_utils.create_learning_rate_schedule( + learning_rate, config.learning_rate_schedule_steps, config.warmup_steps_fraction, config.max_train_steps + ) + tx = max_utils.create_optimizer(config, learning_rate_scheduler) + return tx, learning_rate_scheduler + + def create_flux_state(self, pipeline, 
params, checkpoint_item_name, is_training): + transformer = pipeline.flux + + tx, learning_rate_scheduler = None, None + if is_training: + learning_rate = self.config.learning_rate + + tx, learning_rate_scheduler = self._create_optimizer(self.config, learning_rate) + + transformer_eval_params = transformer.init_weights( + rngs=self.rng, max_sequence_length=self.config.max_sequence_length, eval_only=True + ) + + transformer_params = load_flow_model(self.config.flux_name, transformer_eval_params, "cpu") + + weights_init_fn = functools.partial(pipeline.flux.init_weights, rngs=self.rng, max_sequence_length=self.config.max_sequence_length) + flux_state, state_mesh_shardings = max_utils.setup_initial_state( + model=pipeline.flux, + tx=tx, + config=self.config, + mesh=self.mesh, + weights_init_fn=weights_init_fn, + model_params=None, + checkpoint_manager=self.checkpoint_manager, + checkpoint_item=checkpoint_item_name, + training=is_training, + ) + if not self.config.train_new_flux: + flux_state = flux_state.replace(params=transformer_params) + flux_state = jax.device_put(flux_state, state_mesh_shardings) + return flux_state, state_mesh_shardings, learning_rate_scheduler + + def create_vae_state(self, pipeline, params, checkpoint_item_name, is_training=False): + + # Currently VAE training is not supported. 
+ weights_init_fn = functools.partial(pipeline.vae.init_weights, rng=self.rng) + return max_utils.setup_initial_state( + model=pipeline.vae, + tx=None, + config=self.config, + mesh=self.mesh, + weights_init_fn=weights_init_fn, + model_params=params, + checkpoint_manager=self.checkpoint_manager, + checkpoint_item=checkpoint_item_name, + training=is_training, + ) + + def restore_data_iterator_state(self, data_iterator): + if ( + self.config.dataset_type == "grain" + and data_iterator is not None + and (self.checkpoint_manager.directory / str(self.checkpoint_manager.latest_step()) / "iter").exists() + ): + max_logging.log("Restoring data iterator from checkpoint") + restored = self.checkpoint_manager.restore( + self.checkpoint_manager.latest_step(), + args=ocp.args.Composite(iter=grain.PyGrainCheckpointRestore(data_iterator.local_iterator)), + ) + data_iterator.local_iterator = restored["iter"] + else: + max_logging.log("data iterator checkpoint not found") + return data_iterator + + def _get_pipeline_class(self): + return FluxPipeline + + def _set_checkpoint_format(self, checkpoint_format): + self.checkpoint_format = checkpoint_format + + def save_checkpoint(self, train_step, pipeline, train_states): + items = { + "config": ocp.args.JsonSave({"model_name": self.config.model_name}), + } + + items["flux_state"] = ocp.args.PyTreeSave(train_states["flux_state"]) + + self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items)) + + def load_params(self, step=None): + + self.checkpoint_format = _CHECKPOINT_FORMAT_ORBAX + + def load_checkpoint(self, step=None, scheduler_class=None): + clip_encoder = FlaxCLIPTextModel.from_pretrained( + self.config.clip_model_name_or_path, dtype=self.config.weights_dtype + ) + clip_tokenizer = CLIPTokenizer.from_pretrained( + self.config.clip_model_name_or_path, max_length=77, use_fast=True + ) + + t5_encoder = FlaxT5EncoderModel.from_pretrained(self.config.t5xxl_model_name_or_path, dtype=self.config.weights_dtype) + 
t5_tokenizer = AutoTokenizer.from_pretrained( + self.config.t5xxl_model_name_or_path, max_length=self.config.max_sequence_length, use_fast=True + ) + encoders_sharding = PositionalSharding(self.devices_array).replicate() + partial_device_put_replicated = functools.partial(max_utils.device_put_replicated, sharding=encoders_sharding) + clip_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), clip_encoder.params) + clip_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, clip_encoder.params) + t5_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), t5_encoder.params) + t5_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, t5_encoder.params) + + + + vae, vae_params = FlaxAutoencoderKL.from_pretrained( + self.config.pretrained_model_name_or_path, subfolder="vae", from_pt=True, use_safetensors=True, dtype="bfloat16" + ) + + flash_block_sizes = max_utils.get_flash_block_sizes(self.config) + # loading from pretrained here causes a crash when trying to compile the model + # Failed to load HSACO: HIP_ERROR_NoBinaryForGpu + transformer = FluxTransformer2DModel.from_config( + self.config.pretrained_model_name_or_path, + subfolder="transformer", + mesh=self.mesh, + split_head_dim=self.config.split_head_dim, + attention_kernel=self.config.attention, + flash_block_sizes=flash_block_sizes, + dtype=self.config.activations_dtype, + weights_dtype=self.config.weights_dtype, + precision=max_utils.get_precision(self.config), + ) + + return FluxPipeline(t5_encoder, + clip_encoder, + vae, + t5_tokenizer, + clip_tokenizer, + transformer, + None, + dtype=self.config.activations_dtype, + mesh=self.mesh, + config=self.config, + rng=self.rng), vae_params + diff --git a/src/maxdiffusion/configs/base_flux_dev.yml b/src/maxdiffusion/configs/base_flux_dev.yml index 944153d64..8c992bffe 100644 --- a/src/maxdiffusion/configs/base_flux_dev.yml +++ b/src/maxdiffusion/configs/base_flux_dev.yml @@ -77,7 +77,7 @@ 
norm_num_groups: 32 # If train_new_unet, unet weights will be randomly initialized to train the unet from scratch # else they will be loaded from pretrained_model_name_or_path -train_new_unet: False +train_new_flux: False # train text_encoder - Currently not supported for SDXL train_text_encoder: False @@ -115,7 +115,7 @@ diffusion_scheduler_config: { base_output_directory: "" # Hardware -hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' +hardware: 'gpu' # Supported hardware types are 'tpu', 'gpu' # Parallelism mesh_axes: ['data', 'fsdp', 'tensor'] @@ -177,7 +177,7 @@ hf_train_files: '' hf_access_token: '' image_column: 'image' caption_column: 'text' -resolution: 1024 +resolution: 512 center_crop: False random_flip: False # If cache_latents_text_encoder_outputs is True @@ -193,17 +193,17 @@ checkpoint_every: -1 enable_single_replica_ckpt_restoring: False # Training loop -learning_rate: 4.e-7 +learning_rate: 1.e-5 scale_lr: False max_train_samples: -1 # max_train_steps takes priority over num_train_epochs. -max_train_steps: 200 +max_train_steps: 1500 num_train_epochs: 1 seed: 0 output_dir: 'sdxl-model-finetuned' per_device_batch_size: 1 -warmup_steps_fraction: 0.0 +warmup_steps_fraction: 0.1 learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps. # However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before @@ -213,7 +213,7 @@ learning_rate_schedule_steps: -1 # By default the length of the schedule is set adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients. adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients. adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root. 
-adam_weight_decay: 1.e-2 # AdamW Weight decay +adam_weight_decay: 0 # AdamW Weight decay max_grad_norm: 1.0 enable_profiler: False @@ -223,14 +223,15 @@ skip_first_n_steps_for_profiler: 5 profiler_steps: 10 # Generation parameters -prompt: "A magical castle in the middle of a forest, artistic drawing" -prompt_2: "A magical castle in the middle of a forest, artistic drawing" +prompt: "A Cubone, the lonely Pokémon, sits clutching its signature bone, its face hidden by a skull helmet." +prompt_2: "A Cubone, the lonely Pokémon, sits clutching its signature bone, its face hidden by a skull helmet." negative_prompt: "purple, red" do_classifier_free_guidance: True guidance_scale: 3.5 # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf guidance_rescale: 0.0 num_inference_steps: 50 +save_final_checkpoint: False # SDXL Lightning parameters lightning_from_pt: True diff --git a/src/maxdiffusion/configs/base_flux_schnell.yml b/src/maxdiffusion/configs/base_flux_schnell.yml index 3106255a9..ebc901da3 100644 --- a/src/maxdiffusion/configs/base_flux_schnell.yml +++ b/src/maxdiffusion/configs/base_flux_schnell.yml @@ -123,7 +123,7 @@ diffusion_scheduler_config: { base_output_directory: "" # Hardware -hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' +hardware: 'gpu' # Supported hardware types are 'tpu', 'gpu' # Parallelism mesh_axes: ['data', 'fsdp', 'tensor'] @@ -238,7 +238,7 @@ do_classifier_free_guidance: True guidance_scale: 0.0 # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf guidance_rescale: 0.0 -num_inference_steps: 4 +num_inference_steps: 50 # SDXL Lightning parameters lightning_from_pt: True diff --git a/src/maxdiffusion/maxdiffusion_utils.py b/src/maxdiffusion/maxdiffusion_utils.py index 05f9802a4..bd1f13746 100644 --- a/src/maxdiffusion/maxdiffusion_utils.py +++ b/src/maxdiffusion/maxdiffusion_utils.py @@ -255,6 +255,61 @@ def calculate_unet_tflops(config, pipeline, batch_size, rngs, train): / jax.local_device_count() ) +def get_dummy_flux_inputs(config, pipeline, batch_size): + """Returns randomly initialized flux inputs.""" + latents, latents_ids = pipeline.prepare_latents( + batch_size=batch_size, + num_channels_latents=pipeline.flux.in_channels // 4, + height=config.resolution, + width=config.resolution, + vae_scale_factor=pipeline.vae_scale_factor, + dtype=config.activations_dtype, + rng=pipeline.rng + ) + guidance_vec = jnp.asarray([config.guidance_scale] * batch_size, dtype=config.activations_dtype) + + timesteps = jnp.ones((batch_size,), dtype=config.weights_dtype) + t5_hidden_states_shape = ( + batch_size, + config.max_sequence_length, + 4096, + ) + t5_hidden_states = jnp.zeros(t5_hidden_states_shape, dtype=config.weights_dtype) + t5_ids = jnp.zeros((batch_size, t5_hidden_states.shape[1], 3), dtype=config.weights_dtype) + + clip_hidden_states_shape = ( + batch_size, + 768, + ) + clip_hidden_states = jnp.zeros(clip_hidden_states_shape, dtype=config.weights_dtype) + + return (latents, timesteps, latents_ids, guidance_vec, t5_hidden_states, t5_ids, clip_hidden_states) + + +def calculate_flux_tflops(config, pipeline, batch_size, rngs, train): + """ + Calculates jflux tflops. + batch_size should be per_device_batch_size * jax.local_device_count() or attention's shard_map won't + cache the compilation when flash is enabled. 
+ """ + + (latents, timesteps, latents_ids, guidance_vec, t5_hidden_states, t5_ids, clip_hidden_states) = get_dummy_flux_inputs(config, pipeline, batch_size) + return ( + max_utils.calculate_model_tflops( + pipeline.flux, + rngs, + train, + hidden_states=latents, + img_ids=latents_ids, + encoder_hidden_states=t5_hidden_states, + txt_ids=t5_ids, + pooled_projections=clip_hidden_states, + timestep=timesteps, + guidance=guidance_vec, + ) + / jax.local_device_count() + ) + def tokenize_captions(examples, caption_column, tokenizer, input_ids_key="input_ids", p_encode=None): """Tokenize captions for sd1.x,sd2.x models.""" diff --git a/src/maxdiffusion/pipelines/flux/__init__.py b/src/maxdiffusion/pipelines/flux/__init__.py new file mode 100644 index 000000000..076f6350e --- /dev/null +++ b/src/maxdiffusion/pipelines/flux/__init__.py @@ -0,0 +1,5 @@ +_import_structure = { "pipeline_jflux" : "JfluxPipeline" } + +from .flux_pipeline import ( + FluxPipeline, +) \ No newline at end of file diff --git a/src/maxdiffusion/pipelines/flux/flux_pipeline.py b/src/maxdiffusion/pipelines/flux/flux_pipeline.py new file mode 100644 index 000000000..386944b41 --- /dev/null +++ b/src/maxdiffusion/pipelines/flux/flux_pipeline.py @@ -0,0 +1,383 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from functools import partial +from typing import Dict, List, Optional, Union + +import jax +import jax.numpy as jnp +import numpy as np +import math +from flax.core.frozen_dict import FrozenDict +from transformers import (CLIPTokenizer, FlaxCLIPTextModel, FlaxT5EncoderModel, AutoTokenizer) +from einops import rearrange, repeat +from jax.typing import DTypeLike +from chex import Array + +from flax.linen import partitioning as nn_partitioning + +from maxdiffusion.utils import logging + +from ...models import FlaxAutoencoderKL +from ...schedulers import ( + FlaxEulerDiscreteScheduler +) +from ..pipeline_flax_utils import FlaxDiffusionPipeline +from maxdiffusion.models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +# Set to True to use python for loop instead of jax.fori_loop for easier debugging +DEBUG = False + + +class FluxPipeline(FlaxDiffusionPipeline): + + def __init__( + self, + t5_encoder: FlaxCLIPTextModel, + clip_encoder: FlaxCLIPTextModel, + vae: FlaxAutoencoderKL, + t5_tokenizer: FlaxT5EncoderModel, + clip_tokenizer: CLIPTokenizer, + flux: FluxTransformer2DModel, + scheduler: FlaxEulerDiscreteScheduler, + dtype: jnp.dtype = jnp.float32, + mesh: Optional = None, + config: Optional = None, + rng: Optional = None, + ): + super().__init__() + self.dtype = dtype + self.register_modules( + vae=vae, + t5_encoder=t5_encoder, + clip_encoder=clip_encoder, + t5_tokenizer=t5_tokenizer, + clip_tokenizer=clip_tokenizer, + flux=flux, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1) + self.mesh = mesh + self._config = config + self.rng = rng + + def create_noise( + self, + num_samples: int, + height: int, + width: int, + dtype: DTypeLike, + seed: jax.random.PRNGKey, + ): + return jax.random.normal( + key=seed, + shape=(num_samples, 16, 2 * math.ceil(height / 16), 2 * math.ceil(width / 16)), + dtype=dtype, + ) + + def 
unpack(self, x: Array, height: int, width: int) -> Array: + return rearrange( + x, + "b (h w) (c ph pw) -> b c (h ph) (w pw)", + h=math.ceil(height / 16), + w=math.ceil(width / 16), + ph=2, + pw=2, + ) + + def vae_decode(self, latents, vae, state, config): + img = self.unpack(x=latents, height=config.resolution, width=config.resolution) + img = img / vae.config.scaling_factor + vae.config.shift_factor + img = vae.apply({"params": state.params}, img, deterministic=True, method=vae.decode).sample + return img + + def vae_encode(self, latents, vae, state): + img = vae.apply( + {"params": state.params}, + latents, + deterministic=True, + method=vae.encode).latent_dist.mode() + img = vae.config.scaling_factor * (img - vae.config.shift_factor) + return img + + # this is the reverse of the unpack function + def pack_latents( + self, + latents: Array, + batch_size: int, + num_channels_latents: int, + height: int, + width: int, + ): + latents = jnp.reshape(latents, (batch_size, num_channels_latents, height // 2, 2, width // 2, 2)) + latents = jnp.permute_dims(latents, (0, 2, 4, 1, 3, 5)) + latents = jnp.reshape(latents, (batch_size, (height // 2) * (width // 2), num_channels_latents * 4)) + + return latents + + def prepare_latents( + self, batch_size: int, num_channels_latents: int, height: int, width: int, vae_scale_factor: int, dtype: jnp.dtype, rng: Array + ): + + # VAE applies 8x compression on images but we must also account for packing which + # requires latent height and width to be divisibly by 2. 
+ height = 2 * (height // (vae_scale_factor * 2)) + width = 2 * (width // (vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + latents = jax.random.normal(rng, shape=shape, dtype=jnp.bfloat16) + # pack latents + latents = self.pack_latents(latents, batch_size, num_channels_latents, height, width) + + latent_image_ids = self.prepare_latent_image_ids(height // 2, width // 2) + latent_image_ids = jnp.tile(latent_image_ids, (batch_size, 1, 1)) + + return latents, latent_image_ids + + + def prepare_latent_image_ids(self, height, width): + latent_image_ids = jnp.zeros((height, width, 3)) + latent_image_ids = latent_image_ids.at[..., 1].set(latent_image_ids[..., 1] + jnp.arange(height)[:, None]) + latent_image_ids = latent_image_ids.at[..., 2].set(latent_image_ids[..., 2] + jnp.arange(width)[None, :]) + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape(latent_image_id_height * latent_image_id_width, latent_image_id_channels) + + return latent_image_ids.astype(jnp.bfloat16) + + def get_clip_prompt_embeds( + self, prompt: Union[str, List[str]], num_images_per_prompt: int, tokenizer: CLIPTokenizer, text_encoder: FlaxCLIPTextModel + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="np", + ) + + text_input_ids = text_inputs.input_ids + + prompt_embeds = text_encoder(text_input_ids, params=text_encoder.params, train=False) + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=-1) + prompt_embeds = np.reshape(prompt_embeds, (batch_size * num_images_per_prompt, -1)) + return prompt_embeds + + + def get_t5_prompt_embeds( + self, + prompt: Union[str, 
List[str]], + num_images_per_prompt: int, + tokenizer: AutoTokenizer, + text_encoder: FlaxT5EncoderModel, + max_sequence_length: int = 512, + ): + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = tokenizer( + prompt, + truncation=True, + max_length=max_sequence_length, + return_length=False, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="np", + ) + text_input_ids = text_inputs.input_ids + prompt_embeds = text_encoder(text_input_ids, attention_mask=None, output_hidden_states=False)["last_hidden_state"] + dtype = text_encoder.dtype + prompt_embeds = prompt_embeds.astype(dtype) + _, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = jnp.tile(prompt_embeds, (1, num_images_per_prompt, 1)) + prompt_embeds = jnp.reshape(prompt_embeds, (batch_size * num_images_per_prompt, seq_len, -1)) + + return prompt_embeds + + + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]], + clip_tokenizer: CLIPTokenizer, + clip_text_encoder: FlaxCLIPTextModel, + t5_tokenizer: AutoTokenizer, + t5_text_encoder: FlaxT5EncoderModel, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + ): + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt_2 = prompt or prompt_2 + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + pooled_prompt_embeds = self.get_clip_prompt_embeds( + prompt=prompt, num_images_per_prompt=num_images_per_prompt, tokenizer=clip_tokenizer, text_encoder=clip_text_encoder + ) + + prompt_embeds = self.get_t5_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + tokenizer=t5_tokenizer, + text_encoder=t5_text_encoder, + max_sequence_length=max_sequence_length, + ) + + text_ids = jnp.zeros((prompt_embeds.shape[0], prompt_embeds.shape[1], 3)).astype(jnp.bfloat16) + return 
prompt_embeds, pooled_prompt_embeds, text_ids + + + def _generate( + self, + flux_params, + vae_params, + latents, + latent_image_ids, + prompt_embeds, + txt_ids, + vec, + guidance_vec, + timesteps, + ): + + def loop_body( + step, + args, + transformer, + latent_image_ids, + prompt_embeds, + txt_ids, + vec, + guidance_vec, + ): + latents, state, c_ts, p_ts = args + latents_dtype = latents.dtype + t_curr = c_ts[step] + t_prev = p_ts[step] + t_vec = jnp.full((latents.shape[0],), t_curr, dtype=latents.dtype) + pred = transformer.apply( + {"params": state.params}, + hidden_states=latents, + img_ids=latent_image_ids, + encoder_hidden_states=prompt_embeds, + txt_ids=txt_ids, + timestep=t_vec, + guidance=guidance_vec, + pooled_projections=vec, + ).sample + latents = latents + (t_prev - t_curr) * pred + latents = jnp.array(latents, dtype=latents_dtype) + return latents, state, c_ts, p_ts + + c_ts = timesteps[:-1] + p_ts = timesteps[1:] + + loop_body_p = partial( + loop_body, + transformer=self.flux, + latent_image_ids=latent_image_ids, + prompt_embeds=prompt_embeds, + txt_ids=txt_ids, + vec=vec, + guidance_vec=guidance_vec, + ) + + vae_decode_p = partial(self.vae_decode, vae=self.vae, state=vae_params, config=self._config) + + with self.mesh, nn_partitioning.axis_rules(self._config.logical_axis_rules): + latents, _, _, _ = jax.lax.fori_loop(0, len(timesteps) - 1, loop_body_p, (latents, flux_params, c_ts, p_ts)) + image = vae_decode_p(latents) + return image + + def __call__( + self, + timesteps: int, + flux_params, + vae_params + ): + r""" + The call function to the pipeline for generation. 
+ + Args: + txt: jnp.array, + txt_ids: jnp.array, + vec: jnp.array, + num_inference_steps: int, + height: int, + width: int, + guidance_scale: float, + img: Optional[jnp.ndarray] = None, + shift: bool = False, + jit (`bool`, defaults to `False`): + + Examples: + + """ + + if isinstance(timesteps, int): + timesteps = jnp.linspace(1, 0, timesteps + 1) + + global_batch_size = 1 * jax.local_device_count() + + prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt( + prompt=self._config.prompt, + prompt_2=self._config.prompt_2, + clip_tokenizer=self.clip_tokenizer, + clip_text_encoder=self.clip_encoder, + t5_tokenizer=self.t5_tokenizer, + t5_text_encoder=self.t5_encoder, + num_images_per_prompt=global_batch_size, + max_sequence_length=self._config.max_sequence_length, + ) + + num_channels_latents = self.flux.in_channels // 4 + latents, latent_image_ids = self.prepare_latents( + batch_size=global_batch_size, + num_channels_latents=num_channels_latents, + height=self._config.resolution, + width=self._config.resolution, + dtype=jnp.bfloat16, + vae_scale_factor=self.vae_scale_factor, + rng=self.rng, + ) + + #timesteps = jnp.asarray([1.0] * global_batch_size, dtype=jnp.bfloat16) + guidance = jnp.asarray([self._config.guidance_scale] * global_batch_size, dtype=jnp.bfloat16) + + images = self._generate( + flux_params, + vae_params, + latents, + latent_image_ids, + prompt_embeds, + text_ids, + pooled_prompt_embeds, + guidance, + timesteps, + ) + + images = images + return images diff --git a/src/maxdiffusion/schedulers/scheduling_euler_discrete_flax.py b/src/maxdiffusion/schedulers/scheduling_euler_discrete_flax.py index bc9013cbf..ea1694af2 100644 --- a/src/maxdiffusion/schedulers/scheduling_euler_discrete_flax.py +++ b/src/maxdiffusion/schedulers/scheduling_euler_discrete_flax.py @@ -17,6 +17,7 @@ import flax import jax.numpy as jnp +import max_logging from ..configuration_utils import ConfigMixin, register_to_config from .scheduling_utils_flax import ( @@ -144,7 
+145,7 @@ def scale_model_input(self, state: EulerDiscreteSchedulerState, sample: jnp.ndar return sample def set_timesteps( - self, state: EulerDiscreteSchedulerState, num_inference_steps: int, shape: Tuple = () + self, state: EulerDiscreteSchedulerState, num_inference_steps: int, shape: Tuple = (), timestep_spacing: str = "" ) -> EulerDiscreteSchedulerState: """ Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. @@ -155,17 +156,22 @@ def set_timesteps( num_inference_steps (`int`): the number of diffusion steps used when generating samples with a pre-trained model. """ + if timestep_spacing == "": + timestep_spacing = self.config.timestep_spacing - if self.config.timestep_spacing == "linspace": + if timestep_spacing == "linspace": timesteps = jnp.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=self.dtype) - elif self.config.timestep_spacing == "leading": + elif timestep_spacing == "leading": step_ratio = self.config.num_train_timesteps // num_inference_steps timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float) timesteps += 1 - elif self.config.timestep_spacing == "trailing": + elif timestep_spacing == "trailing": step_ratio = self.config.num_train_timesteps / num_inference_steps timesteps = (jnp.arange(self.config.num_train_timesteps, 0, -step_ratio)).round() timesteps -= 1 + elif timestep_spacing == "flux": + max_logging.log("Using flux timestep spacing") + timesteps = jnp.linspace(1, 0, num_inference_steps + 1) else: raise ValueError( f"timestep_spacing must be one of ['linspace', 'leading', 'trailing'], got {self.config.timestep_spacing}" @@ -250,7 +256,15 @@ def add_noise( original_samples: jnp.ndarray, noise: jnp.ndarray, timesteps: jnp.ndarray, + flux: bool=False, ) -> jnp.ndarray: + + if flux: + t = state.timesteps[timesteps] + t = t[:, None, None] + noisy_samples = t * noise + (1-t) * original_samples + return noisy_samples + sigma = 
state.sigmas[timesteps].flatten() sigma = broadcast_to_shape_from_left(sigma, noise.shape) diff --git a/src/maxdiffusion/train_flux.py b/src/maxdiffusion/train_flux.py new file mode 100644 index 000000000..c432c0e69 --- /dev/null +++ b/src/maxdiffusion/train_flux.py @@ -0,0 +1,49 @@ +""" + Copyright 2024 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + """ + +from typing import Sequence + +import jax +from absl import app +from maxdiffusion import ( + max_logging, + pyconfig, + mllog_utils, +) + +from maxdiffusion.trainers.flux_trainer import FluxTrainer + +from maxdiffusion.train_utils import ( + validate_train_config, +) + + +def train(config): + trainer = FluxTrainer(config) + trainer.start_training() + + +def main(argv: Sequence[str]) -> None: + pyconfig.initialize(argv) + config = pyconfig.config + mllog_utils.train_init_start(config) + validate_train_config(config) + max_logging.log(f"Found {jax.device_count()} devices.") + train(config) + + +if __name__ == "__main__": + app.run(main) diff --git a/src/maxdiffusion/train_utils.py b/src/maxdiffusion/train_utils.py index 648c10d7e..67bbd2197 100644 --- a/src/maxdiffusion/train_utils.py +++ b/src/maxdiffusion/train_utils.py @@ -111,9 +111,12 @@ def write_metrics_to_tensorboard(writer, metrics, step, config): full_log = step % config.log_period == 0 if jax.process_index() == 0: max_logging.log( - f"completed step: {step}, seconds: {metrics['scalar']['perf/step_time_seconds']:.3f}, " - f"TFLOP/s/device: 
{metrics['scalar']['perf/per_device_tflops_per_sec']:.3f}, " - f"loss: {metrics['scalar']['learning/loss']:.3f}" + "completed step: {}, seconds: {:.3f}, TFLOP/s/device: {:.3f}, loss: {:.3f}".format( + step, + metrics['scalar']['perf/step_time_seconds'], + metrics['scalar']['perf/per_device_tflops_per_sec'], + float(metrics['scalar']['learning/loss']) + ) ) if full_log and jax.process_index() == 0: diff --git a/src/maxdiffusion/trainers/flux_trainer.py b/src/maxdiffusion/trainers/flux_trainer.py new file mode 100644 index 000000000..c30db18f6 --- /dev/null +++ b/src/maxdiffusion/trainers/flux_trainer.py @@ -0,0 +1,449 @@ +""" + Copyright 2024 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ """ + +import os +from functools import partial +import datetime +import time +import numpy as np +import jax +import optax +import jax.numpy as jnp +from jax.sharding import PartitionSpec as P +from flax.linen import partitioning as nn_partitioning +from maxdiffusion.checkpointing.flux_checkpointer import (FluxCheckpointer, FLUX_CHECKPOINT) + +from maxdiffusion.input_pipeline.input_pipeline_interface import (make_data_iterator) + +from maxdiffusion import (max_utils, max_logging) +from PIL import Image + +from maxdiffusion.train_utils import ( + generate_timestep_weights, + get_first_step, + load_next_batch, + record_scalar_metrics, + write_metrics, +) + +from maxdiffusion.maxdiffusion_utils import calculate_flux_tflops + +from ..schedulers import ( + FlaxEulerDiscreteScheduler +) + + +class FluxTrainer(FluxCheckpointer): + + def __init__(self, config): + FluxCheckpointer.__init__(self, config, FLUX_CHECKPOINT) + + self.text_encoder_2_learning_rate_scheduler = None + + if config.train_text_encoder: + raise ValueError("this script currently doesn't support training text_encoders") + + def post_training_steps(self, pipeline, params, train_states, msg=""): + imgs = pipeline(flux_params=train_states["flux_state"], + timesteps=50, + vae_params=train_states["vae_state"]) + imgs = np.array(imgs) + imgs = (imgs * 0.5 + 0.5).clip(0, 1) + imgs = np.transpose(imgs, (0, 2, 3, 1)) + imgs = np.uint8(imgs * 255) + for i, image in enumerate(imgs): + Image.fromarray(image).save(f"flux_{msg}_{i}.png") + + def create_scheduler(self, pipeline, params): + noise_scheduler, noise_scheduler_state = FlaxEulerDiscreteScheduler.from_pretrained( + pretrained_model_name_or_path=self.config.pretrained_model_name_or_path, subfolder="scheduler", dtype=jnp.float32 + ) + noise_scheduler_state = noise_scheduler.set_timesteps( + state=noise_scheduler_state, + num_inference_steps=self.config.num_inference_steps, + timestep_spacing="flux") + return noise_scheduler, noise_scheduler_state + + def 
calculate_tflops(self, pipeline): + per_device_tflops = calculate_flux_tflops( + self.config, pipeline, self.total_train_batch_size, self.rng, train=True + ) + max_logging.log(f"JFLUX per device TFLOPS: {per_device_tflops}") + return per_device_tflops + + def start_training(self): + + # Hook + #self.pre_training_steps() + # Load checkpoint - will load or create states + pipeline, params = self.load_checkpoint() + + # create train states + train_states = {} + state_shardings = {} + vae_state, vae_state_mesh_shardings = self.create_vae_state( + pipeline=pipeline, params=params, checkpoint_item_name="vae_state", is_training=False + ) + train_states["vae_state"] = vae_state + state_shardings["vae_state_shardings"] = vae_state_mesh_shardings + + # Load dataset + data_iterator = self.load_dataset(pipeline, params, train_states) + if self.config.dataset_type == "grain": + data_iterator = self.restore_data_iterator_state(data_iterator) + + + flux_state, flux_state_mesh_shardings, flux_learning_rate_scheduler = self.create_flux_state( + # ambiguous here, but if self.params.get("unet") doesn't exist + # Then its 1 of 2 scenarios: + # 1. unet state will be loaded directly from orbax + # 2. a new unet is being trained from scratch. 
+ pipeline=pipeline, + params=None, # Params are loaded inside create_flux_state + checkpoint_item_name="flux_state", + is_training=True, + ) + train_states["flux_state"] = flux_state + state_shardings["flux_state_shardings"] = flux_state_mesh_shardings + self.post_training_steps(pipeline, params, train_states, msg="before_training") + + # Create scheduler + noise_scheduler, noise_scheduler_state = self.create_scheduler(pipeline, params) + pipeline.scheduler = noise_scheduler + train_states["scheduler"] = noise_scheduler_state + + # Calculate tflops + per_device_tflops = self.calculate_tflops(pipeline) + self.per_device_tflops = per_device_tflops + + data_shardings = self.get_data_shardings() + # Compile train_step + p_train_step = self.compile_train_step(pipeline, params, train_states, state_shardings, data_shardings) + # Start training + train_states = self.training_loop( + p_train_step, pipeline, params, train_states, data_iterator, flux_learning_rate_scheduler + ) + # 6. save final checkpoint + # Hook + self.post_training_steps(pipeline, params, train_states, "after_training") + + def get_shaped_batch(self, config, pipeline=None): + """Return the shape of the batch - this is what eval_shape would return for the + output of create_data_iterator_with_tokenizer, but eval_shape doesn't work, see b/306901078. + """ + + scale_factor = 16 # hardcoded in jflux.get_noise + h = config.resolution // scale_factor + w = config.resolution // scale_factor + c = 16 + ph = pw = 2 + batch_image_shape = ( + self.total_train_batch_size, # b + h*w, + c*ph*pw + ) + img_ids_shape = ( + self.total_train_batch_size, + (2*h // 2) * (2*w // 2), + 3 + ) + text_shape = ( + self.total_train_batch_size, + config.max_sequence_length, + 4096, # Sequence length of text encoder, how to get this programmatically? 
+ ) + text_ids_shape = ( + self.total_train_batch_size, + config.max_sequence_length, + 3, + ) + prompt_embeds_shape = ( + self.total_train_batch_size, + 768, # Sequence length of clip, how to get this programmatically? + ) + input_ids_dtype = self.config.activations_dtype + + shaped_batch = {} + shaped_batch["pixel_values"] = jax.ShapeDtypeStruct(batch_image_shape, input_ids_dtype) + shaped_batch["text_embeds"] = jax.ShapeDtypeStruct(text_shape, input_ids_dtype) + shaped_batch["input_ids"] = jax.ShapeDtypeStruct(text_ids_shape, input_ids_dtype) + shaped_batch["prompt_embeds"] = jax.ShapeDtypeStruct(prompt_embeds_shape, input_ids_dtype) + shaped_batch["img_ids"] = jax.ShapeDtypeStruct(img_ids_shape, input_ids_dtype) + return shaped_batch + + def get_data_shardings(self): + data_sharding = jax.sharding.NamedSharding(self.mesh, P(*self.config.data_sharding)) + data_sharding = {"text_embeds": data_sharding, + "input_ids": data_sharding, + "prompt_embeds": data_sharding, + "pixel_values": data_sharding, + "img_ids": data_sharding} + + return data_sharding + + # adapted from max_utils.tokenize_captions_xl + @staticmethod + def tokenize_captions(examples, caption_column, encoder): + prompt = list(examples[caption_column]) + + prompt_embeds, pooled_prompt_embeds, text_ids = encoder(prompt, prompt) + + examples["text_embeds"] = jnp.float16(prompt_embeds) + examples["input_ids"] = jnp.float16(text_ids) + examples["prompt_embeds"] = jnp.float16(pooled_prompt_embeds) + + return examples + + @staticmethod + def transform_images( + examples, + image_column, + image_resolution, + vae_encode, + pack_latents, + prepare_latent_imgage_ids + ): + """Preprocess images to latents.""" + images = list(examples[image_column]) + + images = [ + jax.image.resize( + jnp.asarray(image) / 127.5 - 1.0, + [image_resolution, image_resolution, 3], + method="bilinear", + antialias=True + ) + for image in images + ] + + images = jnp.stack(images, axis=0, dtype=jnp.float16) + batch_size = 8 + 
num_batches = len(images) // batch_size + int(len(images) % batch_size != 0) + encoded_images = [] + for i in range(num_batches): + batch_images = images[i * batch_size:(i + 1) * batch_size] + batch_images = jnp.transpose(batch_images, (0, 3, 1, 2)) + batch_images = vae_encode(batch_images) + batch_images = jnp.transpose(batch_images, (0, 3, 1, 2)) + encoded_images.append(batch_images) + + images = jnp.concatenate(encoded_images, axis=0, dtype=jnp.float16) + b, c, h, w = images.shape + images = pack_latents( + latents=images, + batch_size=b, + num_channels_latents=c, + height=h, + width=w) + + img_ids = prepare_latent_imgage_ids(h // 2, w // 2) + img_ids = jnp.tile(img_ids, (b, 1, 1)) + + + examples["pixel_values"] = jnp.float16(images) + examples["img_ids"] = jnp.float16(img_ids) + + return examples + + def load_dataset(self, pipeline, params, train_states): + config = self.config + total_train_batch_size = self.total_train_batch_size + mesh = self.mesh + + encode_fn = partial(pipeline.encode_prompt, + clip_tokenizer=pipeline.clip_tokenizer, + t5_tokenizer=pipeline.t5_tokenizer, + clip_text_encoder=pipeline.clip_encoder, + t5_text_encoder=pipeline.t5_encoder + ) + pack_latents_p = partial(pipeline.pack_latents) + prepare_latent_image_ids_p = partial(pipeline.prepare_latent_image_ids) + vae_encode_p = partial(pipeline.vae_encode, vae=pipeline.vae, state=train_states["vae_state"]) + + + tokenize_fn = partial( + FluxTrainer.tokenize_captions, + caption_column=config.caption_column, + encoder=encode_fn + ) + image_transforms_fn = partial( + FluxTrainer.transform_images, + image_column=config.image_column, + image_resolution=config.resolution, + vae_encode=vae_encode_p, + pack_latents=pack_latents_p, + prepare_latent_imgage_ids=prepare_latent_image_ids_p + ) + + data_iterator = make_data_iterator( + config, + jax.process_index(), + jax.process_count(), + mesh, + total_train_batch_size, + tokenize_fn=tokenize_fn, + image_transforms_fn=image_transforms_fn, + ) + + return 
data_iterator + + def compile_train_step(self, pipeline, params, train_states, state_shardings, data_shardings): + self.rng, train_rngs = jax.random.split(self.rng) + guidance_vec = jnp.full((self.total_train_batch_size,), self.config.guidance_scale, dtype=self.config.activations_dtype) + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + p_train_step = jax.jit( + partial(_train_step, guidance_vec=guidance_vec, pipeline=pipeline, scheduler=train_states["scheduler"], config=self.config), + in_shardings=( + state_shardings["flux_state_shardings"], + data_shardings, + None, + ), + out_shardings=(state_shardings["flux_state_shardings"], None, None), + donate_argnums=(0,), + ) + max_logging.log("Precompiling...") + s = time.time() + dummy_batch = self.get_shaped_batch(self.config, pipeline) + p_train_step = p_train_step.lower(train_states["flux_state"], dummy_batch, train_rngs) + p_train_step = p_train_step.compile() + max_logging.log(f"Compile time: {(time.time() - s )}") + return p_train_step + + def training_loop(self, p_train_step, pipeline, params, train_states, data_iterator, unet_learning_rate_scheduler): + + writer = max_utils.initialize_summary_writer(self.config) + flux_state = train_states["flux_state"] + num_model_parameters = max_utils.calculate_num_params_from_pytree(flux_state.params) + + max_utils.add_text_to_summary_writer("number_model_parameters", str(num_model_parameters), writer) + max_utils.add_text_to_summary_writer("libtpu_init_args", os.environ.get("LIBTPU_INIT_ARGS", ""), writer) + max_utils.add_config_to_summary_writer(self.config, writer) + + if jax.process_index() == 0: + max_logging.log("***** Running training *****") + max_logging.log(f" Instantaneous batch size per device = {self.config.per_device_batch_size}") + max_logging.log(f" Total train batch size (w. 
parallel & distributed) = {self.total_train_batch_size}") + max_logging.log(f" Total optimization steps = {self.config.max_train_steps}") + + last_step_completion = datetime.datetime.now() + local_metrics_file = open(self.config.metrics_file, "a", encoding="utf8") if self.config.metrics_file else None + running_gcs_metrics = [] if self.config.gcs_metrics else None + example_batch = None + + first_profiling_step = self.config.skip_first_n_steps_for_profiler + if self.config.enable_profiler and first_profiling_step >= self.config.max_train_steps: + raise ValueError("Profiling requested but initial profiling step set past training final step") + last_profiling_step = np.clip( + first_profiling_step + self.config.profiler_steps - 1, first_profiling_step, self.config.max_train_steps - 1 + ) + start_step = get_first_step(train_states["flux_state"]) + _, train_rngs = jax.random.split(self.rng) + times = [] + for step in np.arange(start_step, self.config.max_train_steps): + if self.config.enable_profiler and step == first_profiling_step: + max_utils.activate_profiler(self.config) + + example_batch = load_next_batch(data_iterator, example_batch, self.config) + example_batch = {key: jnp.asarray(value, dtype=self.config.activations_dtype) for key, value in example_batch.items()} + + with jax.profiler.StepTraceAnnotation("train", step_num=step): + with self.mesh: + flux_state, train_metric, train_rngs = p_train_step(flux_state, example_batch, train_rngs) + + samples_count = self.total_train_batch_size * (step + 1) + new_time = datetime.datetime.now() + + record_scalar_metrics( + train_metric, new_time - last_step_completion, self.per_device_tflops, unet_learning_rate_scheduler(step) + ) + if self.config.write_metrics: + write_metrics(writer, local_metrics_file, running_gcs_metrics, train_metric, step, self.config) + times.append(new_time - last_step_completion) + last_step_completion = new_time + + if step != 0 and self.config.checkpoint_every != -1 and samples_count % 
self.config.checkpoint_every == 0: + max_logging.log(f"Saving checkpoint for step {step}") + train_states["flux_state"] = flux_state + self.save_checkpoint(step, pipeline, train_states) + + if self.config.enable_profiler and step == last_profiling_step: + max_utils.deactivate_profiler(self.config) + + if self.config.write_metrics: + write_metrics( + writer, local_metrics_file, running_gcs_metrics, train_metric, self.config.max_train_steps - 1, self.config + ) + + train_states["flux_state"] = flux_state + max_logging.log(f"Average time per step: {sum(times[2:], datetime.timedelta(0)) / len(times[2:])}") + if self.config.save_final_checkpoint: + max_logging.log(f"Saving checkpoint for step {step}") + self.save_checkpoint(step, pipeline, train_states) + self.checkpoint_manager.wait_until_finished() + return train_states + + +def _train_step(flux_state, batch, train_rng, guidance_vec, pipeline, scheduler, config): + _, gen_dummy_rng = jax.random.split(train_rng) + sample_rng, timestep_bias_rng, new_train_rng = jax.random.split(gen_dummy_rng, 3) + state_params = {"flux_state": flux_state.params} + + def compute_loss(state_params): + latents = batch["pixel_values"] + text_embeds_ids = batch["input_ids"] + text_embeds = batch["text_embeds"] + prompt_embeds = batch["prompt_embeds"] + img_ids = batch["img_ids"] + + # Sample noise that we'll add to the latents + noise_rng, timestep_rng = jax.random.split(sample_rng) + noise = jax.random.normal( + key=noise_rng, + shape=latents.shape, + dtype=latents.dtype, + ) + # Sample a random timestep for each image + bsz = latents.shape[0] + timesteps = jax.random.randint(timestep_rng, shape=(bsz,), minval=0, maxval=len(scheduler.timesteps)-1) + noisy_latents = pipeline.scheduler.add_noise(scheduler, latents, noise, timesteps, flux=True) + + model_pred = pipeline.flux.apply( + {"params": state_params["flux_state"]}, + hidden_states=noisy_latents, + img_ids=img_ids, + encoder_hidden_states=text_embeds, + txt_ids=text_embeds_ids, + 
timestep=scheduler.timesteps[timesteps], + guidance=guidance_vec, + pooled_projections=prompt_embeds, + ).sample + + target = noise - latents + loss = (target - model_pred) ** 2 + + loss = jnp.mean(loss) + + return loss + + grad_fn = jax.value_and_grad(compute_loss) + loss, grad = grad_fn(state_params) + + new_state = flux_state.apply_gradients(grads=grad["flux_state"]) + + metrics = {"scalar": {"learning/loss": loss}, "scalars": {}} + + return new_state, metrics, new_train_rng + From 419519554a10438d73c07749c004c0d1c3c12f59 Mon Sep 17 00:00:00 2001 From: ksikiric Date: Thu, 13 Feb 2025 11:30:48 +0000 Subject: [PATCH 2/9] Rebased on flux_lora and aligned flux_pipeline with changes in generate_flux.py --- .../pipelines/flux/flux_pipeline.py | 45 +++++++++++++------ 1 file changed, 32 insertions(+), 13 deletions(-) diff --git a/src/maxdiffusion/pipelines/flux/flux_pipeline.py b/src/maxdiffusion/pipelines/flux/flux_pipeline.py index 386944b41..917bcacb2 100644 --- a/src/maxdiffusion/pipelines/flux/flux_pipeline.py +++ b/src/maxdiffusion/pipelines/flux/flux_pipeline.py @@ -13,7 +13,7 @@ # limitations under the License. 
from functools import partial -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Union, Callable import jax import jax.numpy as jnp @@ -152,8 +152,8 @@ def prepare_latents( def prepare_latent_image_ids(self, height, width): latent_image_ids = jnp.zeros((height, width, 3)) - latent_image_ids = latent_image_ids.at[..., 1].set(latent_image_ids[..., 1] + jnp.arange(height)[:, None]) - latent_image_ids = latent_image_ids.at[..., 2].set(latent_image_ids[..., 2] + jnp.arange(width)[None, :]) + latent_image_ids = latent_image_ids.at[..., 1].set(jnp.arange(height)[:, None]) + latent_image_ids = latent_image_ids.at[..., 2].set(jnp.arange(width)[None, :]) latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape @@ -165,7 +165,6 @@ def get_clip_prompt_embeds( self, prompt: Union[str, List[str]], num_images_per_prompt: int, tokenizer: CLIPTokenizer, text_encoder: FlaxCLIPTextModel ): prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) text_inputs = tokenizer( prompt, padding="max_length", @@ -180,8 +179,7 @@ def get_clip_prompt_embeds( prompt_embeds = text_encoder(text_input_ids, params=text_encoder.params, train=False) prompt_embeds = prompt_embeds.pooler_output - prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=-1) - prompt_embeds = np.reshape(prompt_embeds, (batch_size * num_images_per_prompt, -1)) + prompt_embeds = jnp.tile(prompt_embeds, (num_images_per_prompt, 1)) return prompt_embeds @@ -260,7 +258,8 @@ def _generate( txt_ids, vec, guidance_vec, - timesteps, + c_ts, + p_ts ): def loop_body( @@ -292,9 +291,6 @@ def loop_body( latents = jnp.array(latents, dtype=latents_dtype) return latents, state, c_ts, p_ts - c_ts = timesteps[:-1] - p_ts = timesteps[1:] - loop_body_p = partial( loop_body, transformer=self.flux, @@ -308,10 +304,28 @@ def loop_body( vae_decode_p = partial(self.vae_decode, vae=self.vae, state=vae_params, config=self._config) 
with self.mesh, nn_partitioning.axis_rules(self._config.logical_axis_rules): - latents, _, _, _ = jax.lax.fori_loop(0, len(timesteps) - 1, loop_body_p, (latents, flux_params, c_ts, p_ts)) + latents, _, _, _ = jax.lax.fori_loop(0, len(c_ts), loop_body_p, (latents, flux_params, c_ts, p_ts)) image = vae_decode_p(latents) return image + def do_time_shift(self, mu: float, sigma: float, t: Array): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + + def get_lin_function(self, x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]: + m = (y2 - y1) / (x2 - x1) + b = y1 - m * x1 + return lambda x: m * x + b + + def time_shift(self, latents, timesteps): + # estimate mu based on linear estimation between two points + lin_function = self.get_lin_function(x1=self._config.max_sequence_length, + y1=self._config.base_shift, + y2=self._config.max_shift) + mu = lin_function(latents.shape[1]) + timesteps = self.do_time_shift(mu, 1.0, timesteps) + return timesteps + def __call__( self, timesteps: int, @@ -364,7 +378,11 @@ def __call__( rng=self.rng, ) - #timesteps = jnp.asarray([1.0] * global_batch_size, dtype=jnp.bfloat16) + if self._config.time_shift: + timesteps = self.time_shift(latents, timesteps) + c_ts = timesteps[:-1] + p_ts = timesteps[1:] + guidance = jnp.asarray([self._config.guidance_scale] * global_batch_size, dtype=jnp.bfloat16) images = self._generate( @@ -376,7 +394,8 @@ def __call__( text_ids, pooled_prompt_embeds, guidance, - timesteps, + c_ts, + p_ts ) images = images From 26da5fd8d9fe126c7bf8a2cb7360bad9d5631c2f Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Thu, 27 Feb 2025 17:30:45 +0000 Subject: [PATCH 3/9] batch text encoding. 
---
 src/maxdiffusion/configs/base_flux_dev.yml    |  2 +-
 .../pipelines/flux/flux_pipeline.py           | 29 +++++++++++++++----
 src/maxdiffusion/train_flux.py                |  3 +-
 src/maxdiffusion/trainers/flux_trainer.py     |  8 +++--
 4 files changed, 32 insertions(+), 10 deletions(-)

diff --git a/src/maxdiffusion/configs/base_flux_dev.yml b/src/maxdiffusion/configs/base_flux_dev.yml
index 8c992bffe..d6220cf25 100644
--- a/src/maxdiffusion/configs/base_flux_dev.yml
+++ b/src/maxdiffusion/configs/base_flux_dev.yml
@@ -115,7 +115,7 @@ diffusion_scheduler_config: {
 base_output_directory: ""
 
 # Hardware
-hardware: 'gpu' # Supported hardware types are 'tpu', 'gpu'
+hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
 
 # Parallelism
 mesh_axes: ['data', 'fsdp', 'tensor']
diff --git a/src/maxdiffusion/pipelines/flux/flux_pipeline.py b/src/maxdiffusion/pipelines/flux/flux_pipeline.py
index 917bcacb2..6a7657a6e 100644
--- a/src/maxdiffusion/pipelines/flux/flux_pipeline.py
+++ b/src/maxdiffusion/pipelines/flux/flux_pipeline.py
@@ -190,6 +190,8 @@ def get_t5_prompt_embeds(
       tokenizer: AutoTokenizer,
       text_encoder: FlaxT5EncoderModel,
       max_sequence_length: int = 512,
+      encode_in_batches=False,
+      encode_batch_size=None,
   ):
     prompt = [prompt] if isinstance(prompt, str) else prompt
 
@@ -205,13 +207,23 @@ def get_t5_prompt_embeds(
         return_tensors="np",
     )
     text_input_ids = text_inputs.input_ids
-    prompt_embeds = text_encoder(text_input_ids, attention_mask=None, output_hidden_states=False)["last_hidden_state"]
+    if encode_in_batches:
+      prompt_embeds = None
+      for i in range(0, text_input_ids.shape[0], encode_batch_size):
+        batch_prompt_embeds = text_encoder(text_input_ids[i:i+encode_batch_size], attention_mask=None, output_hidden_states=False)["last_hidden_state"]
+        if prompt_embeds is None:
+          prompt_embeds = batch_prompt_embeds
+        else:
+          prompt_embeds = jnp.concatenate([prompt_embeds, batch_prompt_embeds])
+    else:
+      prompt_embeds = text_encoder(text_input_ids, attention_mask=None, 
output_hidden_states=False)["last_hidden_state"]
+
     _, seq_len, _ = prompt_embeds.shape
+    # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+    prompt_embeds = jnp.tile(prompt_embeds, (1, num_images_per_prompt, 1))
+    prompt_embeds = jnp.reshape(prompt_embeds, (batch_size * num_images_per_prompt, seq_len, -1))
+
     dtype = text_encoder.dtype
     prompt_embeds = prompt_embeds.astype(dtype)
 
-    _, seq_len, _ = prompt_embeds.shape
-    # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
-    prompt_embeds = jnp.tile(prompt_embeds, (1, num_images_per_prompt, 1))
-    prompt_embeds = jnp.reshape(prompt_embeds, (batch_size * num_images_per_prompt, seq_len, -1))
 
     return prompt_embeds
 
@@ -226,7 +238,12 @@ def encode_prompt(
       t5_text_encoder: FlaxT5EncoderModel,
       num_images_per_prompt: int = 1,
       max_sequence_length: int = 512,
+      encode_in_batches: bool = False,
+      encode_batch_size: int = None
   ):
+
+    if encode_in_batches:
+      assert encode_batch_size is not None
     prompt = [prompt] if isinstance(prompt, str) else prompt
 
     prompt_2 = prompt or prompt_2
@@ -242,6 +259,8 @@ def encode_prompt(
         tokenizer=t5_tokenizer,
         text_encoder=t5_text_encoder,
         max_sequence_length=max_sequence_length,
+        encode_in_batches=encode_in_batches,
+        encode_batch_size=encode_batch_size
     )
 
     text_ids = jnp.zeros((prompt_embeds.shape[0], prompt_embeds.shape[1], 3)).astype(jnp.bfloat16)
diff --git a/src/maxdiffusion/train_flux.py b/src/maxdiffusion/train_flux.py
index c432c0e69..40abaff80 100644
--- a/src/maxdiffusion/train_flux.py
+++ b/src/maxdiffusion/train_flux.py
@@ -24,14 +24,13 @@
     mllog_utils,
 )
 
-from maxdiffusion.trainers.flux_trainer import FluxTrainer
-
 from maxdiffusion.train_utils import (
     validate_train_config,
 )
 
 
 def train(config):
+  from maxdiffusion.trainers.flux_trainer import FluxTrainer
   trainer = FluxTrainer(config)
   trainer.start_training()
 
diff --git a/src/maxdiffusion/trainers/flux_trainer.py 
b/src/maxdiffusion/trainers/flux_trainer.py index c30db18f6..92f4ab11f 100644 --- a/src/maxdiffusion/trainers/flux_trainer.py +++ b/src/maxdiffusion/trainers/flux_trainer.py @@ -105,6 +105,8 @@ def start_training(self): if self.config.dataset_type == "grain": data_iterator = self.restore_data_iterator_state(data_iterator) + # don't need this anymore, clear some memory. + del pipeline.t5_encoder flux_state, flux_state_mesh_shardings, flux_learning_rate_scheduler = self.create_flux_state( # ambiguous here, but if self.params.get("unet") doesn't exist @@ -138,7 +140,7 @@ def start_training(self): ) # 6. save final checkpoint # Hook - self.post_training_steps(pipeline, params, train_states, "after_training") + #self.post_training_steps(pipeline, params, train_states, "after_training") def get_shaped_batch(self, config, pipeline=None): """Return the shape of the batch - this is what eval_shape would return for the @@ -267,7 +269,9 @@ def load_dataset(self, pipeline, params, train_states): clip_tokenizer=pipeline.clip_tokenizer, t5_tokenizer=pipeline.t5_tokenizer, clip_text_encoder=pipeline.clip_encoder, - t5_text_encoder=pipeline.t5_encoder + t5_text_encoder=pipeline.t5_encoder, + encode_in_batches=True, + encode_batch_size=16 ) pack_latents_p = partial(pipeline.pack_latents) prepare_latent_image_ids_p = partial(pipeline.prepare_latent_image_ids) From ee7d422e61128f08879523541616284dfd1a187e Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Thu, 27 Feb 2025 23:35:22 +0000 Subject: [PATCH 4/9] comment out post_training_steps --- src/maxdiffusion/trainers/flux_trainer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/maxdiffusion/trainers/flux_trainer.py b/src/maxdiffusion/trainers/flux_trainer.py index 92f4ab11f..5dfc05c30 100644 --- a/src/maxdiffusion/trainers/flux_trainer.py +++ b/src/maxdiffusion/trainers/flux_trainer.py @@ -107,7 +107,6 @@ def start_training(self): # don't need this anymore, clear some memory. 
del pipeline.t5_encoder - flux_state, flux_state_mesh_shardings, flux_learning_rate_scheduler = self.create_flux_state( # ambiguous here, but if self.params.get("unet") doesn't exist # Then its 1 of 2 scenarios: @@ -120,7 +119,7 @@ def start_training(self): ) train_states["flux_state"] = flux_state state_shardings["flux_state_shardings"] = flux_state_mesh_shardings - self.post_training_steps(pipeline, params, train_states, msg="before_training") + #self.post_training_steps(pipeline, params, train_states, msg="before_training") # Create scheduler noise_scheduler, noise_scheduler_state = self.create_scheduler(pipeline, params) From d05161d149d0e7a799509495e6f8085502dd3b81 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Fri, 28 Feb 2025 18:16:28 +0000 Subject: [PATCH 5/9] refactor some code for similarity to sd trainers. --- .../checkpointing/flux_checkpointer.py | 145 +++++++++++------- src/maxdiffusion/trainers/flux_trainer.py | 59 ++++--- 2 files changed, 134 insertions(+), 70 deletions(-) diff --git a/src/maxdiffusion/checkpointing/flux_checkpointer.py b/src/maxdiffusion/checkpointing/flux_checkpointer.py index 4da642beb..0507b2c65 100644 --- a/src/maxdiffusion/checkpointing/flux_checkpointer.py +++ b/src/maxdiffusion/checkpointing/flux_checkpointer.py @@ -16,12 +16,10 @@ from abc import ABC from contextlib import nullcontext -import os -import json import functools import jax import jax.numpy as jnp -from jax.sharding import Mesh, PositionalSharding, PartitionSpec as P +from jax.sharding import Mesh import orbax.checkpoint as ocp import grain.python as grain from maxdiffusion import ( @@ -35,15 +33,19 @@ from transformers import (CLIPTokenizer, FlaxCLIPTextModel, T5EncoderModel, FlaxT5EncoderModel, AutoTokenizer) from maxdiffusion.checkpointing.checkpointing_utils import ( - create_orbax_checkpoint_manager, - load_stable_diffusion_configs, + create_orbax_checkpoint_manager ) from maxdiffusion.models.flux.util import load_flow_model FLUX_CHECKPOINT = 
"FLUX_CHECKPOINT" -_CHECKPOINT_FORMAT_DIFFUSERS = "CHECKPOINT_FORMAT_DIFFUSERS" _CHECKPOINT_FORMAT_ORBAX = "CHECKPOINT_FORMAT_ORBAX" +FLUX_STATE_KEY = "flux_state" +FLUX_TRANSFORMER_PARAMS_KEY = "flux_transformer_params" +FLUX_STATE_SHARDINGS_KEY = "flux_state_shardings" +FLUX_VAE_PARAMS_KEY = "flux_vae" +VAE_STATE_KEY = "vae_state" +VAE_STATE_SHARDINGS_KEY = "vae_state_shardings" class FluxCheckpointer(ABC): @@ -144,67 +146,106 @@ def _set_checkpoint_format(self, checkpoint_format): self.checkpoint_format = checkpoint_format def save_checkpoint(self, train_step, pipeline, train_states): + def config_to_json(model_or_config): + return json.loads(model_or_config.to_json_string()) items = { "config": ocp.args.JsonSave({"model_name": self.config.model_name}), } - items["flux_state"] = ocp.args.PyTreeSave(train_states["flux_state"]) + items[FLUX_STATE_KEY] = ocp.args.PyTreeSave(train_states[FLUX_STATE_KEY]) self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items)) def load_params(self, step=None): self.checkpoint_format = _CHECKPOINT_FORMAT_ORBAX + + def load_flux_configs_from_orbax(self): + # TODO - load configs from orbax + return None - def load_checkpoint(self, step=None, scheduler_class=None): - clip_encoder = FlaxCLIPTextModel.from_pretrained( - self.config.clip_model_name_or_path, dtype=self.config.weights_dtype - ) - clip_tokenizer = CLIPTokenizer.from_pretrained( - self.config.clip_model_name_or_path, max_length=77, use_fast=True - ) + def load_diffusers_checkpoint(self): + flash_block_sizes = max_utils.get_flash_block_sizes(self.config) - t5_encoder = FlaxT5EncoderModel.from_pretrained(self.config.t5xxl_model_name_or_path, dtype=self.config.weights_dtype) - t5_tokenizer = AutoTokenizer.from_pretrained( - self.config.t5xxl_model_name_or_path, max_length=self.config.max_sequence_length, use_fast=True + if jax.device_count() == jax.local_device_count(): + context = jax.default_device(jax.devices("cpu")[0]) + else: + context = nullcontext() + + 
with context: + clip_encoder = FlaxCLIPTextModel.from_pretrained( + self.config.clip_model_name_or_path, dtype=self.config.weights_dtype + ) + clip_tokenizer = CLIPTokenizer.from_pretrained( + self.config.clip_model_name_or_path, + max_length=77, + use_fast=True + ) + t5_encoder = FlaxT5EncoderModel.from_pretrained(self.config.t5xxl_model_name_or_path, dtype=self.config.weights_dtype) + t5_tokenizer = AutoTokenizer.from_pretrained( + self.config.t5xxl_model_name_or_path, + max_length=self.config.max_sequence_length, + use_fast=True + ) + + vae, vae_params = FlaxAutoencoderKL.from_pretrained( + self.config.pretrained_model_name_or_path, + subfolder="vae", + from_pt=True, + use_safetensors=True, + dtype=self.config.weights_dtype + ) + + # loading from pretrained here causes a crash when trying to compile the model + # Failed to load HSACO: HIP_ERROR_NoBinaryForGpu + transformer = FluxTransformer2DModel.from_config( + self.config.pretrained_model_name_or_path, + subfolder="transformer", + mesh=self.mesh, + split_head_dim=self.config.split_head_dim, + attention_kernel=self.config.attention, + flash_block_sizes=flash_block_sizes, + dtype=self.config.activations_dtype, + weights_dtype=self.config.weights_dtype, + precision=max_utils.get_precision(self.config), + ) + transformer_eval_params = transformer.init_weights( + rngs=self.rng, max_sequence_length=self.config.max_sequence_length, eval_only=True + ) + + transformer_params = load_flow_model(self.config.flux_name, transformer_eval_params, "cpu") + + pipeline = FluxPipeline( + t5_encoder, + clip_encoder, + vae, + t5_tokenizer, + clip_tokenizer, + transformer, + None, + dtype=self.config.activations_dtype, + mesh=self.mesh, + config=self.config, + rng=self.rng ) - encoders_sharding = PositionalSharding(self.devices_array).replicate() - partial_device_put_replicated = functools.partial(max_utils.device_put_replicated, sharding=encoders_sharding) - clip_encoder.params = jax.tree_util.tree_map(lambda x: 
x.astype(jnp.bfloat16), clip_encoder.params) - clip_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, clip_encoder.params) - t5_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), t5_encoder.params) - t5_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, t5_encoder.params) + params = { + FLUX_VAE_PARAMS_KEY : vae_params, + FLUX_TRANSFORMER_PARAMS_KEY : transformer_params + } + return pipeline, params - vae, vae_params = FlaxAutoencoderKL.from_pretrained( - self.config.pretrained_model_name_or_path, subfolder="vae", from_pt=True, use_safetensors=True, dtype="bfloat16" - ) + def load_checkpoint(self, step=None, scheduler_class=None): - flash_block_sizes = max_utils.get_flash_block_sizes(self.config) - # loading from pretrained here causes a crash when trying to compile the model - # Failed to load HSACO: HIP_ERROR_NoBinaryForGpu - transformer = FluxTransformer2DModel.from_config( - self.config.pretrained_model_name_or_path, - subfolder="transformer", - mesh=self.mesh, - split_head_dim=self.config.split_head_dim, - attention_kernel=self.config.attention, - flash_block_sizes=flash_block_sizes, - dtype=self.config.activations_dtype, - weights_dtype=self.config.weights_dtype, - precision=max_utils.get_precision(self.config), - ) - - return FluxPipeline(t5_encoder, - clip_encoder, - vae, - t5_tokenizer, - clip_tokenizer, - transformer, - None, - dtype=self.config.activations_dtype, - mesh=self.mesh, - config=self.config, - rng=self.rng), vae_params + model_configs = self.load_flux_configs_from_orbax() + + pipeline, params = None, {} + + if model_configs: + print("TODO - load configs from orbax") + else: + pipeline, params = self.load_diffusers_checkpoint() + + return pipeline, params diff --git a/src/maxdiffusion/trainers/flux_trainer.py b/src/maxdiffusion/trainers/flux_trainer.py index 5dfc05c30..c9346091d 100644 --- a/src/maxdiffusion/trainers/flux_trainer.py +++ b/src/maxdiffusion/trainers/flux_trainer.py 
@@ -22,9 +22,17 @@ import jax import optax import jax.numpy as jnp -from jax.sharding import PartitionSpec as P +from jax.sharding import PositionalSharding, PartitionSpec as P from flax.linen import partitioning as nn_partitioning -from maxdiffusion.checkpointing.flux_checkpointer import (FluxCheckpointer, FLUX_CHECKPOINT) +from maxdiffusion.checkpointing.flux_checkpointer import ( + FluxCheckpointer, + FLUX_CHECKPOINT, + FLUX_TRANSFORMER_PARAMS_KEY, + FLUX_STATE_KEY, + FLUX_STATE_SHARDINGS_KEY, + FLUX_VAE_PARAMS_KEY, + VAE_STATE_KEY, + VAE_STATE_SHARDINGS_KEY) from maxdiffusion.input_pipeline.input_pipeline_interface import (make_data_iterator) @@ -57,7 +65,7 @@ def __init__(self, config): raise ValueError("this script currently doesn't support training text_encoders") def post_training_steps(self, pipeline, params, train_states, msg=""): - imgs = pipeline(flux_params=train_states["flux_state"], + imgs = pipeline(flux_params=train_states[FLUX_STATE_KEY], timesteps=50, vae_params=train_states["vae_state"]) imgs = np.array(imgs) @@ -94,11 +102,21 @@ def start_training(self): # create train states train_states = {} state_shardings = {} + + # move params to accelerator + encoders_sharding = PositionalSharding(self.devices_array).replicate() + partial_device_put_replicated = partial(max_utils.device_put_replicated, sharding=encoders_sharding) + pipeline.clip_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), pipeline.clip_encoder.params) + pipeline.clip_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, pipeline.clip_encoder.params) + pipeline.t5_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), pipeline.t5_encoder.params) + pipeline.t5_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, pipeline.t5_encoder.params) + + vae_state, vae_state_mesh_shardings = self.create_vae_state( - pipeline=pipeline, params=params, checkpoint_item_name="vae_state", is_training=False + 
pipeline=pipeline, params=params[FLUX_VAE_PARAMS_KEY], checkpoint_item_name=VAE_STATE_KEY, is_training=False ) - train_states["vae_state"] = vae_state - state_shardings["vae_state_shardings"] = vae_state_mesh_shardings + train_states[VAE_STATE_KEY] = vae_state + state_shardings[VAE_STATE_SHARDINGS_KEY] = vae_state_mesh_shardings # Load dataset data_iterator = self.load_dataset(pipeline, params, train_states) @@ -107,18 +125,23 @@ def start_training(self): # don't need this anymore, clear some memory. del pipeline.t5_encoder + + # evaluate shapes + flux_state, flux_state_mesh_shardings, flux_learning_rate_scheduler = self.create_flux_state( - # ambiguous here, but if self.params.get("unet") doesn't exist + # ambiguous here, but if params=None # Then its 1 of 2 scenarios: # 1. unet state will be loaded directly from orbax # 2. a new unet is being trained from scratch. pipeline=pipeline, params=None, # Params are loaded inside create_flux_state - checkpoint_item_name="flux_state", + checkpoint_item_name=FLUX_STATE_KEY, is_training=True, ) - train_states["flux_state"] = flux_state - state_shardings["flux_state_shardings"] = flux_state_mesh_shardings + flux_state = flux_state.replace(params=params[FLUX_TRANSFORMER_PARAMS_KEY]) + flux_state = jax.device_put(flux_state, flux_state_mesh_shardings) + train_states[FLUX_STATE_KEY] = flux_state + state_shardings[FLUX_STATE_SHARDINGS_KEY] = flux_state_mesh_shardings #self.post_training_steps(pipeline, params, train_states, msg="before_training") # Create scheduler @@ -320,7 +343,7 @@ def compile_train_step(self, pipeline, params, train_states, state_shardings, da max_logging.log("Precompiling...") s = time.time() dummy_batch = self.get_shaped_batch(self.config, pipeline) - p_train_step = p_train_step.lower(train_states["flux_state"], dummy_batch, train_rngs) + p_train_step = p_train_step.lower(train_states[FLUX_STATE_KEY], dummy_batch, train_rngs) p_train_step = p_train_step.compile() max_logging.log(f"Compile time: 
{(time.time() - s )}") return p_train_step @@ -328,7 +351,7 @@ def compile_train_step(self, pipeline, params, train_states, state_shardings, da def training_loop(self, p_train_step, pipeline, params, train_states, data_iterator, unet_learning_rate_scheduler): writer = max_utils.initialize_summary_writer(self.config) - flux_state = train_states["flux_state"] + flux_state = train_states[FLUX_STATE_KEY] num_model_parameters = max_utils.calculate_num_params_from_pytree(flux_state.params) max_utils.add_text_to_summary_writer("number_model_parameters", str(num_model_parameters), writer) @@ -352,7 +375,7 @@ def training_loop(self, p_train_step, pipeline, params, train_states, data_itera last_profiling_step = np.clip( first_profiling_step + self.config.profiler_steps - 1, first_profiling_step, self.config.max_train_steps - 1 ) - start_step = get_first_step(train_states["flux_state"]) + start_step = get_first_step(train_states[FLUX_STATE_KEY]) _, train_rngs = jax.random.split(self.rng) times = [] for step in np.arange(start_step, self.config.max_train_steps): @@ -379,7 +402,7 @@ def training_loop(self, p_train_step, pipeline, params, train_states, data_itera if step != 0 and self.config.checkpoint_every != -1 and samples_count % self.config.checkpoint_every == 0: max_logging.log(f"Saving checkpoint for step {step}") - train_states["flux_state"] = flux_state + train_states[FLUX_STATE_KEY] = flux_state self.save_checkpoint(step, pipeline, train_states) if self.config.enable_profiler and step == last_profiling_step: @@ -390,7 +413,7 @@ def training_loop(self, p_train_step, pipeline, params, train_states, data_itera writer, local_metrics_file, running_gcs_metrics, train_metric, self.config.max_train_steps - 1, self.config ) - train_states["flux_state"] = flux_state + train_states[FLUX_STATE_KEY] = flux_state max_logging.log(f"Average time per step: {sum(times[2:], datetime.timedelta(0)) / len(times[2:])}") if self.config.save_final_checkpoint: max_logging.log(f"Saving 
checkpoint for step {step}") @@ -402,7 +425,7 @@ def training_loop(self, p_train_step, pipeline, params, train_states, data_itera def _train_step(flux_state, batch, train_rng, guidance_vec, pipeline, scheduler, config): _, gen_dummy_rng = jax.random.split(train_rng) sample_rng, timestep_bias_rng, new_train_rng = jax.random.split(gen_dummy_rng, 3) - state_params = {"flux_state": flux_state.params} + state_params = {FLUX_STATE_KEY: flux_state.params} def compute_loss(state_params): latents = batch["pixel_values"] @@ -424,7 +447,7 @@ def compute_loss(state_params): noisy_latents = pipeline.scheduler.add_noise(scheduler, latents, noise, timesteps, flux=True) model_pred = pipeline.flux.apply( - {"params": state_params["flux_state"]}, + {"params": state_params[FLUX_STATE_KEY]}, hidden_states=noisy_latents, img_ids=img_ids, encoder_hidden_states=text_embeds, @@ -444,7 +467,7 @@ def compute_loss(state_params): grad_fn = jax.value_and_grad(compute_loss) loss, grad = grad_fn(state_params) - new_state = flux_state.apply_gradients(grads=grad["flux_state"]) + new_state = flux_state.apply_gradients(grads=grad[FLUX_STATE_KEY]) metrics = {"scalar": {"learning/loss": loss}, "scalars": {}} From bfec2c8c3f70c76fafdb4e3bfc3f49195efdfad8 Mon Sep 17 00:00:00 2001 From: ksikiric Date: Mon, 31 Mar 2025 15:35:02 +0000 Subject: [PATCH 6/9] Added orbax saving and a new file for inference that utilizes the pipeline. 
--- .../checkpointing/checkpointing_utils.py | 26 ++-- .../checkpointing/flux_checkpointer.py | 88 ++++++++++-- src/maxdiffusion/generate_flux_pipeline.py | 127 ++++++++++++++++++ .../pipelines/flux/flux_pipeline.py | 6 +- src/maxdiffusion/trainers/flux_trainer.py | 27 +--- 5 files changed, 233 insertions(+), 41 deletions(-) create mode 100644 src/maxdiffusion/generate_flux_pipeline.py diff --git a/src/maxdiffusion/checkpointing/checkpointing_utils.py b/src/maxdiffusion/checkpointing/checkpointing_utils.py index d771a15de..e383d6124 100644 --- a/src/maxdiffusion/checkpointing/checkpointing_utils.py +++ b/src/maxdiffusion/checkpointing/checkpointing_utils.py @@ -57,16 +57,21 @@ def create_orbax_checkpoint_manager( max_logging.log(f"checkpoint dir: {checkpoint_dir}") p = epath.Path(checkpoint_dir) - item_names = ( - "unet_config", - "vae_config", - "text_encoder_config", - "scheduler_config", - "unet_state", - "vae_state", - "text_encoder_state", - "tokenizer_config", - ) + if checkpoint_type == FLUX_CHECKPOINT: + item_names = ("flux_state", "flux_config", + "vae_state", "vae_config", + "scheduler", "scheduler_config") + else: + item_names = ( + "unet_config", + "vae_config", + "text_encoder_config", + "scheduler_config", + "unet_state", + "vae_state", + "text_encoder_state", + "tokenizer_config", + ) if checkpoint_type == STABLE_DIFFUSION_XL_CHECKPOINT or checkpoint_type == FLUX_CHECKPOINT: item_names += ( "text_encoder_2_state", @@ -140,6 +145,7 @@ def load_params_from_path( ckpt_path = os.path.join(config.checkpoint_dir, str(step), checkpoint_item) ckpt_path = epath.Path(ckpt_path) + ckpt_path = os.path.abspath(ckpt_path) restore_args = ocp.checkpoint_utils.construct_restore_args(unboxed_abstract_params) restored = ckptr.restore( diff --git a/src/maxdiffusion/checkpointing/flux_checkpointer.py b/src/maxdiffusion/checkpointing/flux_checkpointer.py index 0507b2c65..4c0f131bc 100644 --- a/src/maxdiffusion/checkpointing/flux_checkpointer.py +++ 
b/src/maxdiffusion/checkpointing/flux_checkpointer.py @@ -17,6 +17,8 @@ from abc import ABC from contextlib import nullcontext import functools +import json +import os import jax import jax.numpy as jnp from jax.sharding import Mesh @@ -59,8 +61,10 @@ def __init__(self, config, checkpoint_type): self.mesh = Mesh(self.devices_array, self.config.mesh_axes) self.total_train_batch_size = self.config.total_train_batch_size + checkpoint_dir = os.path.abspath(self.config.checkpoint_dir) + self.checkpoint_manager = create_orbax_checkpoint_manager( - self.config.checkpoint_dir, + checkpoint_dir, enable_checkpointing=True, save_interval_steps=1, checkpoint_type=checkpoint_type, @@ -117,7 +121,7 @@ def create_vae_state(self, pipeline, params, checkpoint_item_name, is_training=F config=self.config, mesh=self.mesh, weights_init_fn=weights_init_fn, - model_params=params, + model_params=params.get("flux_vae", None), checkpoint_manager=self.checkpoint_manager, checkpoint_item=checkpoint_item_name, training=is_training, @@ -149,10 +153,14 @@ def save_checkpoint(self, train_step, pipeline, train_states): def config_to_json(model_or_config): return json.loads(model_or_config.to_json_string()) items = { - "config": ocp.args.JsonSave({"model_name": self.config.model_name}), + "flux_config": ocp.args.JsonSave(config_to_json(pipeline.flux)), + "vae_config": ocp.args.JsonSave(config_to_json(pipeline.vae)), + "scheduler_config": ocp.args.JsonSave(config_to_json(pipeline.scheduler)) } items[FLUX_STATE_KEY] = ocp.args.PyTreeSave(train_states[FLUX_STATE_KEY]) + items["vae_state"] = ocp.args.PyTreeSave(train_states["vae_state"]) + items["scheduler"] = ocp.args.PyTreeSave(train_states["scheduler"]) self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items)) @@ -160,9 +168,20 @@ def load_params(self, step=None): self.checkpoint_format = _CHECKPOINT_FORMAT_ORBAX - def load_flux_configs_from_orbax(self): - # TODO - load configs from orbax - return None + def 
load_flux_configs_from_orbax(self, step): + max_logging.log("Restoring stable diffusion configs") + if step is None: + step = self.checkpoint_manager.latest_step() + if step is None: + return None + + restore_args = { + "flux_config": ocp.args.JsonRestore(), + "vae_config": ocp.args.JsonRestore(), + "scheduler_config": ocp.args.JsonRestore(), + } + + return (self.checkpoint_manager.restore(step, args=ocp.args.Composite(**restore_args)), None) def load_diffusers_checkpoint(self): flash_block_sizes = max_utils.get_flash_block_sizes(self.config) @@ -238,12 +257,65 @@ def load_diffusers_checkpoint(self): def load_checkpoint(self, step=None, scheduler_class=None): - model_configs = self.load_flux_configs_from_orbax() + model_configs = self.load_flux_configs_from_orbax(step) pipeline, params = None, {} if model_configs: - print("TODO - load configs from orbax") + if jax.device_count() == jax.local_device_count(): + context = jax.default_device(jax.devices("cpu")[0]) + else: + context = nullcontext() + + with context: + clip_encoder = FlaxCLIPTextModel.from_pretrained( + self.config.clip_model_name_or_path, dtype=self.config.weights_dtype + ) + clip_tokenizer = CLIPTokenizer.from_pretrained( + self.config.clip_model_name_or_path, + max_length=77, + use_fast=True + ) + t5_encoder = FlaxT5EncoderModel.from_pretrained(self.config.t5xxl_model_name_or_path, dtype=self.config.weights_dtype) + t5_tokenizer = AutoTokenizer.from_pretrained( + self.config.t5xxl_model_name_or_path, + max_length=self.config.max_sequence_length, + use_fast=True + ) + + vae = FlaxAutoencoderKL.from_config( + model_configs[0]["vae_config"], + dtype=self.config.activations_dtype, + weights_dtype=self.config.weights_dtype, + from_pt=self.config.from_pt, + ) + + transformer = FluxTransformer2DModel.from_config( + model_configs[0]["flux_config"], + mesh=self.mesh, + split_head_dim=self.config.split_head_dim, + attention_kernel=self.config.attention, + 
flash_block_sizes=max_utils.get_flash_block_sizes(self.config), + dtype=self.config.activations_dtype, + weights_dtype=self.config.weights_dtype, + precision=max_utils.get_precision(self.config), + from_pt=self.config.from_pt, + ) + + pipeline = FluxPipeline( + t5_encoder, + clip_encoder, + vae, + t5_tokenizer, + clip_tokenizer, + transformer, + None, + dtype=self.config.activations_dtype, + mesh=self.mesh, + config=self.config, + rng=self.rng + ) + else: pipeline, params = self.load_diffusers_checkpoint() diff --git a/src/maxdiffusion/generate_flux_pipeline.py b/src/maxdiffusion/generate_flux_pipeline.py new file mode 100644 index 000000000..f3e6c1195 --- /dev/null +++ b/src/maxdiffusion/generate_flux_pipeline.py @@ -0,0 +1,127 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+""" + +from typing import Sequence +from absl import app +from contextlib import ExitStack +import functools +import time +import numpy as np +from PIL import Image +import jax + +from maxdiffusion import pyconfig, max_logging, max_utils + +from maxdiffusion.checkpointing.flux_checkpointer import FluxCheckpointer +from maxdiffusion.checkpointing.checkpointing_utils import load_params_from_path +from maxdiffusion.max_utils import setup_initial_state + +def run(config): + checkpoint_loader = FluxCheckpointer(config, "FLUX_CHECKPOINT") + pipeline, params = checkpoint_loader.load_checkpoint() + + if not params: + ## VAE + weights_init_fn = functools.partial(pipeline.vae.init_weights, rng=checkpoint_loader.rng) + unboxed_abstract_state, _, _ = max_utils.get_abstract_state( + pipeline.vae, None, config, checkpoint_loader.mesh, weights_init_fn, False + ) + # load unet params from orbax checkpoint + vae_params = load_params_from_path( + config, checkpoint_loader.checkpoint_manager, unboxed_abstract_state.params, "vae_state" + ) + + vae_state = {"params": vae_params} + + ## Flux + weights_init_fn = functools.partial(pipeline.flux.init_weights, + rngs=checkpoint_loader.rng, + max_sequence_length=config.max_sequence_length) + + unboxed_abstract_state, _, _ = max_utils.get_abstract_state( + pipeline.flux, None, config, checkpoint_loader.mesh, weights_init_fn, False + ) + # load unet params from orbax checkpoint + flux_params = load_params_from_path( + config, checkpoint_loader.checkpoint_manager, unboxed_abstract_state.params, "flux_state" + ) + flux_state = {"params": flux_params} + else: + weights_init_fn = functools.partial( + pipeline.flux.init_weights, + rngs=checkpoint_loader.rng, + max_sequence_length=config.max_sequence_length, + eval_only=False + ) + transformer_state, flux_state_shardings = setup_initial_state( + model=pipeline.flux, + tx=None, + config=config, + mesh=checkpoint_loader.mesh, + weights_init_fn=weights_init_fn, + model_params=None, + training=False, + 
) + transformer_state = transformer_state.replace(params=params["flux_transformer_params"]) + transformer_state = jax.device_put(transformer_state, flux_state_shardings) + + weights_init_fn = functools.partial(pipeline.vae.init_weights, rng=checkpoint_loader.rng) + vae_state, _ = setup_initial_state( + model=pipeline.vae, + tx=None, + config=config, + mesh=checkpoint_loader.mesh, + weights_init_fn=weights_init_fn, + model_params=params['flux_vae'], + training=False, + ) + + vae_state = {"params": vae_state.params} + flux_state = {"params": transformer_state.params} + + t0 = time.perf_counter() + with ExitStack() as stack: + imgs = pipeline(flux_params=flux_state, + timesteps=50, + vae_params=vae_state).block_until_ready() + t1 = time.perf_counter() + max_logging.log(f"Compile time: {t1 - t0:.1f}s.") + + t0 = time.perf_counter() + with ExitStack() as stack: + imgs = pipeline(flux_params=flux_state, + timesteps=50, + vae_params=vae_state).block_until_ready() + imgs = jax.experimental.multihost_utils.process_allgather(imgs, tiled=True) + t1 = time.perf_counter() + max_logging.log(f"Inference time: {t1 - t0:.1f}s.") + imgs = np.array(imgs) + imgs = (imgs * 0.5 + 0.5).clip(0, 1) + imgs = np.transpose(imgs, (0, 2, 3, 1)) + imgs = np.uint8(imgs * 255) + for i, image in enumerate(imgs): + Image.fromarray(image).save(f"flux_{i}.png") + + return imgs + + +def main(argv: Sequence[str]) -> None: + pyconfig.initialize(argv) + run(pyconfig.config) + + +if __name__ == "__main__": + app.run(main) diff --git a/src/maxdiffusion/pipelines/flux/flux_pipeline.py b/src/maxdiffusion/pipelines/flux/flux_pipeline.py index 6a7657a6e..cd40a45be 100644 --- a/src/maxdiffusion/pipelines/flux/flux_pipeline.py +++ b/src/maxdiffusion/pipelines/flux/flux_pipeline.py @@ -102,12 +102,12 @@ def unpack(self, x: Array, height: int, width: int) -> Array: def vae_decode(self, latents, vae, state, config): img = self.unpack(x=latents, height=config.resolution, width=config.resolution) img = img / 
vae.config.scaling_factor + vae.config.shift_factor - img = vae.apply({"params": state.params}, img, deterministic=True, method=vae.decode).sample + img = vae.apply({"params": state["params"]}, img, deterministic=True, method=vae.decode).sample return img def vae_encode(self, latents, vae, state): img = vae.apply( - {"params": state.params}, + {"params": state["params"]}, latents, deterministic=True, method=vae.encode).latent_dist.mode() @@ -297,7 +297,7 @@ def loop_body( t_prev = p_ts[step] t_vec = jnp.full((latents.shape[0],), t_curr, dtype=latents.dtype) pred = transformer.apply( - {"params": state.params}, + {"params": state['params']}, hidden_states=latents, img_ids=latent_image_ids, encoder_hidden_states=prompt_embeds, diff --git a/src/maxdiffusion/trainers/flux_trainer.py b/src/maxdiffusion/trainers/flux_trainer.py index c9346091d..7970b4207 100644 --- a/src/maxdiffusion/trainers/flux_trainer.py +++ b/src/maxdiffusion/trainers/flux_trainer.py @@ -65,15 +65,7 @@ def __init__(self, config): raise ValueError("this script currently doesn't support training text_encoders") def post_training_steps(self, pipeline, params, train_states, msg=""): - imgs = pipeline(flux_params=train_states[FLUX_STATE_KEY], - timesteps=50, - vae_params=train_states["vae_state"]) - imgs = np.array(imgs) - imgs = (imgs * 0.5 + 0.5).clip(0, 1) - imgs = np.transpose(imgs, (0, 2, 3, 1)) - imgs = np.uint8(imgs * 255) - for i, image in enumerate(imgs): - Image.fromarray(image).save(f"flux_{msg}_{i}.png") + pass def create_scheduler(self, pipeline, params): noise_scheduler, noise_scheduler_state = FlaxEulerDiscreteScheduler.from_pretrained( @@ -113,7 +105,7 @@ def start_training(self): vae_state, vae_state_mesh_shardings = self.create_vae_state( - pipeline=pipeline, params=params[FLUX_VAE_PARAMS_KEY], checkpoint_item_name=VAE_STATE_KEY, is_training=False + pipeline=pipeline, params=params, checkpoint_item_name=VAE_STATE_KEY, is_training=False ) train_states[VAE_STATE_KEY] = vae_state 
state_shardings[VAE_STATE_SHARDINGS_KEY] = vae_state_mesh_shardings @@ -131,14 +123,13 @@ def start_training(self): flux_state, flux_state_mesh_shardings, flux_learning_rate_scheduler = self.create_flux_state( # ambiguous here, but if params=None # Then its 1 of 2 scenarios: - # 1. unet state will be loaded directly from orbax - # 2. a new unet is being trained from scratch. + # 1. flux state will be loaded directly from orbax + # 2. a new flux is being trained from scratch. pipeline=pipeline, params=None, # Params are loaded inside create_flux_state checkpoint_item_name=FLUX_STATE_KEY, is_training=True, ) - flux_state = flux_state.replace(params=params[FLUX_TRANSFORMER_PARAMS_KEY]) flux_state = jax.device_put(flux_state, flux_state_mesh_shardings) train_states[FLUX_STATE_KEY] = flux_state state_shardings[FLUX_STATE_SHARDINGS_KEY] = flux_state_mesh_shardings @@ -162,7 +153,7 @@ def start_training(self): ) # 6. save final checkpoint # Hook - #self.post_training_steps(pipeline, params, train_states, "after_training") + self.post_training_steps(pipeline, params, train_states, "after_training") def get_shaped_batch(self, config, pipeline=None): """Return the shape of the batch - this is what eval_shape would return for the @@ -408,13 +399,9 @@ def training_loop(self, p_train_step, pipeline, params, train_states, data_itera if self.config.enable_profiler and step == last_profiling_step: max_utils.deactivate_profiler(self.config) - if self.config.write_metrics: - write_metrics( - writer, local_metrics_file, running_gcs_metrics, train_metric, self.config.max_train_steps - 1, self.config - ) - train_states[FLUX_STATE_KEY] = flux_state - max_logging.log(f"Average time per step: {sum(times[2:], datetime.timedelta(0)) / len(times[2:])}") + if len(times) > 0: + max_logging.log(f"Average time per step: {sum(times[2:], datetime.timedelta(0)) / len(times[2:])}") if self.config.save_final_checkpoint: max_logging.log(f"Saving checkpoint for step {step}") self.save_checkpoint(step, 
pipeline, train_states) From 4600a72d6e768d3378445a2891173a0b937ef938 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Tue, 15 Apr 2025 14:34:16 -0700 Subject: [PATCH 7/9] Update generate_flux_pipeline.py --- src/maxdiffusion/generate_flux_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/maxdiffusion/generate_flux_pipeline.py b/src/maxdiffusion/generate_flux_pipeline.py index f3e6c1195..31ae90f9a 100644 --- a/src/maxdiffusion/generate_flux_pipeline.py +++ b/src/maxdiffusion/generate_flux_pipeline.py @@ -25,11 +25,11 @@ from maxdiffusion import pyconfig, max_logging, max_utils -from maxdiffusion.checkpointing.flux_checkpointer import FluxCheckpointer from maxdiffusion.checkpointing.checkpointing_utils import load_params_from_path from maxdiffusion.max_utils import setup_initial_state def run(config): + from maxdiffusion.checkpointing.flux_checkpointer import FluxCheckpointer checkpoint_loader = FluxCheckpointer(config, "FLUX_CHECKPOINT") pipeline, params = checkpoint_loader.load_checkpoint() From 5453b3ccd5f728ca9d633ff4489d10e2fe02d4ac Mon Sep 17 00:00:00 2001 From: ksikiric Date: Wed, 16 Apr 2025 07:41:07 +0000 Subject: [PATCH 8/9] Fixed comments and rebased on main --- src/maxdiffusion/checkpointing/flux_checkpointer.py | 4 +--- src/maxdiffusion/configs/base_flux_dev.yml | 6 +++--- src/maxdiffusion/configs/base_flux_schnell.yml | 8 ++++---- .../schedulers/scheduling_euler_discrete_flax.py | 2 -- 4 files changed, 8 insertions(+), 12 deletions(-) diff --git a/src/maxdiffusion/checkpointing/flux_checkpointer.py b/src/maxdiffusion/checkpointing/flux_checkpointer.py index 4c0f131bc..73a6f0568 100644 --- a/src/maxdiffusion/checkpointing/flux_checkpointer.py +++ b/src/maxdiffusion/checkpointing/flux_checkpointer.py @@ -61,10 +61,8 @@ def __init__(self, config, checkpoint_type): self.mesh = Mesh(self.devices_array, self.config.mesh_axes) self.total_train_batch_size = self.config.total_train_batch_size - checkpoint_dir = 
os.path.abspath(self.config.checkpoint_dir) - self.checkpoint_manager = create_orbax_checkpoint_manager( - checkpoint_dir, + self.config.checkpoint_dir, enable_checkpointing=True, save_interval_steps=1, checkpoint_type=checkpoint_type, diff --git a/src/maxdiffusion/configs/base_flux_dev.yml b/src/maxdiffusion/configs/base_flux_dev.yml index d6220cf25..167dc8bc8 100644 --- a/src/maxdiffusion/configs/base_flux_dev.yml +++ b/src/maxdiffusion/configs/base_flux_dev.yml @@ -75,7 +75,7 @@ flash_block_sizes: {} # GroupNorm groups norm_num_groups: 32 -# If train_new_unet, unet weights will be randomly initialized to train the unet from scratch +# If train_new_flux, flux weights will be randomly initialized to train flux from scratch # else they will be loaded from pretrained_model_name_or_path train_new_flux: False @@ -223,8 +223,8 @@ skip_first_n_steps_for_profiler: 5 profiler_steps: 10 # Generation parameters -prompt: "A Cubone, the lonely Pokémon, sits clutching its signature bone, its face hidden by a skull helmet." -prompt_2: "A Cubone, the lonely Pokémon, sits clutching its signature bone, its face hidden by a skull helmet." 
+prompt: "A magical castle in the middle of a forest, artistic drawing" +prompt_2: "A magical castle in the middle of a forest, artistic drawing" negative_prompt: "purple, red" do_classifier_free_guidance: True guidance_scale: 3.5 diff --git a/src/maxdiffusion/configs/base_flux_schnell.yml b/src/maxdiffusion/configs/base_flux_schnell.yml index ebc901da3..188074a5b 100644 --- a/src/maxdiffusion/configs/base_flux_schnell.yml +++ b/src/maxdiffusion/configs/base_flux_schnell.yml @@ -83,9 +83,9 @@ flash_block_sizes: { # GroupNorm groups norm_num_groups: 32 -# If train_new_unet, unet weights will be randomly initialized to train the unet from scratch +# If train_new_flux, flux weights will be randomly initialized to train flux from scratch # else they will be loaded from pretrained_model_name_or_path -train_new_unet: False +train_new_flux: False # train text_encoder - Currently not supported for SDXL train_text_encoder: False @@ -123,7 +123,7 @@ diffusion_scheduler_config: { base_output_directory: "" # Hardware -hardware: 'gpu' # Supported hardware types are 'tpu', 'gpu' +hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' # Parallelism mesh_axes: ['data', 'fsdp', 'tensor'] @@ -238,7 +238,7 @@ do_classifier_free_guidance: True guidance_scale: 0.0 # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf guidance_rescale: 0.0 -num_inference_steps: 50 +num_inference_steps: 4 # SDXL Lightning parameters lightning_from_pt: True diff --git a/src/maxdiffusion/schedulers/scheduling_euler_discrete_flax.py b/src/maxdiffusion/schedulers/scheduling_euler_discrete_flax.py index ea1694af2..c46cd5014 100644 --- a/src/maxdiffusion/schedulers/scheduling_euler_discrete_flax.py +++ b/src/maxdiffusion/schedulers/scheduling_euler_discrete_flax.py @@ -17,7 +17,6 @@ import flax import jax.numpy as jnp -import max_logging from ..configuration_utils import ConfigMixin, register_to_config from .scheduling_utils_flax import ( @@ -170,7 +169,6 @@ def set_timesteps( timesteps = (jnp.arange(self.config.num_train_timesteps, 0, -step_ratio)).round() timesteps -= 1 elif timestep_spacing == "flux": - max_logging.log("Using flux timestep spacing") timesteps = jnp.linspace(1, 0, num_inference_steps + 1) else: raise ValueError( From 7e130fc613eebeaf3c36761e59a1213478e968ba Mon Sep 17 00:00:00 2001 From: ksikiric Date: Wed, 16 Apr 2025 09:20:51 +0000 Subject: [PATCH 9/9] ruff + code_style --- .../checkpointing/checkpointing_utils.py | 4 +- .../checkpointing/flux_checkpointer.py | 160 ++++++++--------- src/maxdiffusion/generate_flux_pipeline.py | 30 ++-- src/maxdiffusion/maxdiffusion_utils.py | 23 +-- src/maxdiffusion/pipelines/flux/__init__.py | 4 +- .../pipelines/flux/flux_pipeline.py | 161 ++++++++---------- .../scheduling_euler_discrete_flax.py | 4 +- src/maxdiffusion/train_flux.py | 1 + src/maxdiffusion/train_utils.py | 8 +- src/maxdiffusion/trainers/flux_trainer.py | 153 +++++++---------- 10 files changed, 244 insertions(+), 304 deletions(-) diff --git a/src/maxdiffusion/checkpointing/checkpointing_utils.py b/src/maxdiffusion/checkpointing/checkpointing_utils.py index e383d6124..aa68267a5 100644 --- a/src/maxdiffusion/checkpointing/checkpointing_utils.py +++ b/src/maxdiffusion/checkpointing/checkpointing_utils.py @@ -58,9 +58,7 @@ def 
create_orbax_checkpoint_manager( p = epath.Path(checkpoint_dir) if checkpoint_type == FLUX_CHECKPOINT: - item_names = ("flux_state", "flux_config", - "vae_state", "vae_config", - "scheduler", "scheduler_config") + item_names = ("flux_state", "flux_config", "vae_state", "vae_config", "scheduler", "scheduler_config") else: item_names = ( "unet_config", diff --git a/src/maxdiffusion/checkpointing/flux_checkpointer.py b/src/maxdiffusion/checkpointing/flux_checkpointer.py index 73a6f0568..a5e1bfc2f 100644 --- a/src/maxdiffusion/checkpointing/flux_checkpointer.py +++ b/src/maxdiffusion/checkpointing/flux_checkpointer.py @@ -18,9 +18,7 @@ from contextlib import nullcontext import functools import json -import os import jax -import jax.numpy as jnp from jax.sharding import Mesh import orbax.checkpoint as ocp import grain.python as grain @@ -32,11 +30,9 @@ from maxdiffusion.models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel from ..pipelines.flux.flux_pipeline import FluxPipeline -from transformers import (CLIPTokenizer, FlaxCLIPTextModel, T5EncoderModel, FlaxT5EncoderModel, AutoTokenizer) +from transformers import (CLIPTokenizer, FlaxCLIPTextModel, FlaxT5EncoderModel, AutoTokenizer) -from maxdiffusion.checkpointing.checkpointing_utils import ( - create_orbax_checkpoint_manager -) +from maxdiffusion.checkpointing.checkpointing_utils import (create_orbax_checkpoint_manager) from maxdiffusion.models.flux.util import load_flow_model FLUX_CHECKPOINT = "FLUX_CHECKPOINT" @@ -49,6 +45,7 @@ VAE_STATE_KEY = "vae_state" VAE_STATE_SHARDINGS_KEY = "vae_state_shardings" + class FluxCheckpointer(ABC): def __init__(self, config, checkpoint_type): @@ -87,12 +84,14 @@ def create_flux_state(self, pipeline, params, checkpoint_item_name, is_training) tx, learning_rate_scheduler = self._create_optimizer(self.config, learning_rate) transformer_eval_params = transformer.init_weights( - rngs=self.rng, max_sequence_length=self.config.max_sequence_length, eval_only=True + 
rngs=self.rng, max_sequence_length=self.config.max_sequence_length, eval_only=True ) transformer_params = load_flow_model(self.config.flux_name, transformer_eval_params, "cpu") - weights_init_fn = functools.partial(pipeline.flux.init_weights, rngs=self.rng, max_sequence_length=self.config.max_sequence_length) + weights_init_fn = functools.partial( + pipeline.flux.init_weights, rngs=self.rng, max_sequence_length=self.config.max_sequence_length + ) flux_state, state_mesh_shardings = max_utils.setup_initial_state( model=pipeline.flux, tx=tx, @@ -150,10 +149,11 @@ def _set_checkpoint_format(self, checkpoint_format): def save_checkpoint(self, train_step, pipeline, train_states): def config_to_json(model_or_config): return json.loads(model_or_config.to_json_string()) + items = { "flux_config": ocp.args.JsonSave(config_to_json(pipeline.flux)), "vae_config": ocp.args.JsonSave(config_to_json(pipeline.vae)), - "scheduler_config": ocp.args.JsonSave(config_to_json(pipeline.scheduler)) + "scheduler_config": ocp.args.JsonSave(config_to_json(pipeline.scheduler)), } items[FLUX_STATE_KEY] = ocp.args.PyTreeSave(train_states[FLUX_STATE_KEY]) @@ -165,7 +165,7 @@ def config_to_json(model_or_config): def load_params(self, step=None): self.checkpoint_format = _CHECKPOINT_FORMAT_ORBAX - + def load_flux_configs_from_orbax(self, step): max_logging.log("Restoring stable diffusion configs") if step is None: @@ -188,68 +188,57 @@ def load_diffusers_checkpoint(self): context = jax.default_device(jax.devices("cpu")[0]) else: context = nullcontext() - + with context: - clip_encoder = FlaxCLIPTextModel.from_pretrained( - self.config.clip_model_name_or_path, dtype=self.config.weights_dtype - ) - clip_tokenizer = CLIPTokenizer.from_pretrained( - self.config.clip_model_name_or_path, - max_length=77, - use_fast=True - ) + clip_encoder = FlaxCLIPTextModel.from_pretrained(self.config.clip_model_name_or_path, dtype=self.config.weights_dtype) + clip_tokenizer = 
CLIPTokenizer.from_pretrained(self.config.clip_model_name_or_path, max_length=77, use_fast=True) t5_encoder = FlaxT5EncoderModel.from_pretrained(self.config.t5xxl_model_name_or_path, dtype=self.config.weights_dtype) t5_tokenizer = AutoTokenizer.from_pretrained( - self.config.t5xxl_model_name_or_path, - max_length=self.config.max_sequence_length, - use_fast=True + self.config.t5xxl_model_name_or_path, max_length=self.config.max_sequence_length, use_fast=True ) vae, vae_params = FlaxAutoencoderKL.from_pretrained( - self.config.pretrained_model_name_or_path, - subfolder="vae", - from_pt=True, - use_safetensors=True, - dtype=self.config.weights_dtype + self.config.pretrained_model_name_or_path, + subfolder="vae", + from_pt=True, + use_safetensors=True, + dtype=self.config.weights_dtype, ) # loading from pretrained here causes a crash when trying to compile the model # Failed to load HSACO: HIP_ERROR_NoBinaryForGpu transformer = FluxTransformer2DModel.from_config( - self.config.pretrained_model_name_or_path, - subfolder="transformer", - mesh=self.mesh, - split_head_dim=self.config.split_head_dim, - attention_kernel=self.config.attention, - flash_block_sizes=flash_block_sizes, - dtype=self.config.activations_dtype, - weights_dtype=self.config.weights_dtype, - precision=max_utils.get_precision(self.config), + self.config.pretrained_model_name_or_path, + subfolder="transformer", + mesh=self.mesh, + split_head_dim=self.config.split_head_dim, + attention_kernel=self.config.attention, + flash_block_sizes=flash_block_sizes, + dtype=self.config.activations_dtype, + weights_dtype=self.config.weights_dtype, + precision=max_utils.get_precision(self.config), ) transformer_eval_params = transformer.init_weights( - rngs=self.rng, max_sequence_length=self.config.max_sequence_length, eval_only=True + rngs=self.rng, max_sequence_length=self.config.max_sequence_length, eval_only=True ) - + transformer_params = load_flow_model(self.config.flux_name, transformer_eval_params, "cpu") 
pipeline = FluxPipeline( - t5_encoder, - clip_encoder, - vae, - t5_tokenizer, - clip_tokenizer, - transformer, - None, - dtype=self.config.activations_dtype, - mesh=self.mesh, - config=self.config, - rng=self.rng + t5_encoder, + clip_encoder, + vae, + t5_tokenizer, + clip_tokenizer, + transformer, + None, + dtype=self.config.activations_dtype, + mesh=self.mesh, + config=self.config, + rng=self.rng, ) - params = { - FLUX_VAE_PARAMS_KEY : vae_params, - FLUX_TRANSFORMER_PARAMS_KEY : transformer_params - } + params = {FLUX_VAE_PARAMS_KEY: vae_params, FLUX_TRANSFORMER_PARAMS_KEY: transformer_params} return pipeline, params @@ -267,55 +256,50 @@ def load_checkpoint(self, step=None, scheduler_class=None): with context: clip_encoder = FlaxCLIPTextModel.from_pretrained( - self.config.clip_model_name_or_path, dtype=self.config.weights_dtype + self.config.clip_model_name_or_path, dtype=self.config.weights_dtype ) - clip_tokenizer = CLIPTokenizer.from_pretrained( - self.config.clip_model_name_or_path, - max_length=77, - use_fast=True + clip_tokenizer = CLIPTokenizer.from_pretrained(self.config.clip_model_name_or_path, max_length=77, use_fast=True) + t5_encoder = FlaxT5EncoderModel.from_pretrained( + self.config.t5xxl_model_name_or_path, dtype=self.config.weights_dtype ) - t5_encoder = FlaxT5EncoderModel.from_pretrained(self.config.t5xxl_model_name_or_path, dtype=self.config.weights_dtype) t5_tokenizer = AutoTokenizer.from_pretrained( - self.config.t5xxl_model_name_or_path, - max_length=self.config.max_sequence_length, - use_fast=True + self.config.t5xxl_model_name_or_path, max_length=self.config.max_sequence_length, use_fast=True ) vae = FlaxAutoencoderKL.from_config( - model_configs[0]["vae_config"], - dtype=self.config.activations_dtype, - weights_dtype=self.config.weights_dtype, - from_pt=self.config.from_pt, + model_configs[0]["vae_config"], + dtype=self.config.activations_dtype, + weights_dtype=self.config.weights_dtype, + from_pt=self.config.from_pt, ) transformer = 
FluxTransformer2DModel.from_config( - model_configs[0]["flux_config"], - mesh=self.mesh, - split_head_dim=self.config.split_head_dim, - attention_kernel=self.config.attention, - flash_block_sizes=max_utils.get_flash_block_sizes(self.config), - dtype=self.config.activations_dtype, - weights_dtype=self.config.weights_dtype, - precision=max_utils.get_precision(self.config), - from_pt=self.config.from_pt, + model_configs[0]["flux_config"], + mesh=self.mesh, + split_head_dim=self.config.split_head_dim, + attention_kernel=self.config.attention, + flash_block_sizes=max_utils.get_flash_block_sizes(self.config), + dtype=self.config.activations_dtype, + weights_dtype=self.config.weights_dtype, + precision=max_utils.get_precision(self.config), + from_pt=self.config.from_pt, ) pipeline = FluxPipeline( - t5_encoder, - clip_encoder, - vae, - t5_tokenizer, - clip_tokenizer, - transformer, - None, - dtype=self.config.activations_dtype, - mesh=self.mesh, - config=self.config, - rng=self.rng + t5_encoder, + clip_encoder, + vae, + t5_tokenizer, + clip_tokenizer, + transformer, + None, + dtype=self.config.activations_dtype, + mesh=self.mesh, + config=self.config, + rng=self.rng, ) else: pipeline, params = self.load_diffusers_checkpoint() - - return pipeline, params + return pipeline, params diff --git a/src/maxdiffusion/generate_flux_pipeline.py b/src/maxdiffusion/generate_flux_pipeline.py index 31ae90f9a..8887375d0 100644 --- a/src/maxdiffusion/generate_flux_pipeline.py +++ b/src/maxdiffusion/generate_flux_pipeline.py @@ -28,8 +28,10 @@ from maxdiffusion.checkpointing.checkpointing_utils import load_params_from_path from maxdiffusion.max_utils import setup_initial_state + def run(config): from maxdiffusion.checkpointing.flux_checkpointer import FluxCheckpointer + checkpoint_loader = FluxCheckpointer(config, "FLUX_CHECKPOINT") pipeline, params = checkpoint_loader.load_checkpoint() @@ -47,9 +49,9 @@ def run(config): vae_state = {"params": vae_params} ## Flux - weights_init_fn = 
functools.partial(pipeline.flux.init_weights, - rngs=checkpoint_loader.rng, - max_sequence_length=config.max_sequence_length) + weights_init_fn = functools.partial( + pipeline.flux.init_weights, rngs=checkpoint_loader.rng, max_sequence_length=config.max_sequence_length + ) unboxed_abstract_state, _, _ = max_utils.get_abstract_state( pipeline.flux, None, config, checkpoint_loader.mesh, weights_init_fn, False @@ -61,10 +63,10 @@ def run(config): flux_state = {"params": flux_params} else: weights_init_fn = functools.partial( - pipeline.flux.init_weights, - rngs=checkpoint_loader.rng, - max_sequence_length=config.max_sequence_length, - eval_only=False + pipeline.flux.init_weights, + rngs=checkpoint_loader.rng, + max_sequence_length=config.max_sequence_length, + eval_only=False, ) transformer_state, flux_state_shardings = setup_initial_state( model=pipeline.flux, @@ -85,7 +87,7 @@ def run(config): config=config, mesh=checkpoint_loader.mesh, weights_init_fn=weights_init_fn, - model_params=params['flux_vae'], + model_params=params["flux_vae"], training=False, ) @@ -93,18 +95,14 @@ def run(config): flux_state = {"params": transformer_state.params} t0 = time.perf_counter() - with ExitStack() as stack: - imgs = pipeline(flux_params=flux_state, - timesteps=50, - vae_params=vae_state).block_until_ready() + with ExitStack(): + imgs = pipeline(flux_params=flux_state, timesteps=50, vae_params=vae_state).block_until_ready() t1 = time.perf_counter() max_logging.log(f"Compile time: {t1 - t0:.1f}s.") t0 = time.perf_counter() - with ExitStack() as stack: - imgs = pipeline(flux_params=flux_state, - timesteps=50, - vae_params=vae_state).block_until_ready() + with ExitStack(): + imgs = pipeline(flux_params=flux_state, timesteps=50, vae_params=vae_state).block_until_ready() imgs = jax.experimental.multihost_utils.process_allgather(imgs, tiled=True) t1 = time.perf_counter() max_logging.log(f"Inference time: {t1 - t0:.1f}s.") diff --git a/src/maxdiffusion/maxdiffusion_utils.py 
b/src/maxdiffusion/maxdiffusion_utils.py index bd1f13746..43400a62e 100644 --- a/src/maxdiffusion/maxdiffusion_utils.py +++ b/src/maxdiffusion/maxdiffusion_utils.py @@ -255,19 +255,20 @@ def calculate_unet_tflops(config, pipeline, batch_size, rngs, train): / jax.local_device_count() ) + def get_dummy_flux_inputs(config, pipeline, batch_size): """Returns randomly initialized flux inputs.""" latents, latents_ids = pipeline.prepare_latents( - batch_size=batch_size, - num_channels_latents=pipeline.flux.in_channels // 4, - height=config.resolution, - width=config.resolution, - vae_scale_factor=pipeline.vae_scale_factor, - dtype=config.activations_dtype, - rng=pipeline.rng + batch_size=batch_size, + num_channels_latents=pipeline.flux.in_channels // 4, + height=config.resolution, + width=config.resolution, + vae_scale_factor=pipeline.vae_scale_factor, + dtype=config.activations_dtype, + rng=pipeline.rng, ) guidance_vec = jnp.asarray([config.guidance_scale] * batch_size, dtype=config.activations_dtype) - + timesteps = jnp.ones((batch_size,), dtype=config.weights_dtype) t5_hidden_states_shape = ( batch_size, @@ -282,7 +283,7 @@ def get_dummy_flux_inputs(config, pipeline, batch_size): 768, ) clip_hidden_states = jnp.zeros(clip_hidden_states_shape, dtype=config.weights_dtype) - + return (latents, timesteps, latents_ids, guidance_vec, t5_hidden_states, t5_ids, clip_hidden_states) @@ -293,7 +294,9 @@ def calculate_flux_tflops(config, pipeline, batch_size, rngs, train): cache the compilation when flash is enabled. 
""" - (latents, timesteps, latents_ids, guidance_vec, t5_hidden_states, t5_ids, clip_hidden_states) = get_dummy_flux_inputs(config, pipeline, batch_size) + (latents, timesteps, latents_ids, guidance_vec, t5_hidden_states, t5_ids, clip_hidden_states) = get_dummy_flux_inputs( + config, pipeline, batch_size + ) return ( max_utils.calculate_model_tflops( pipeline.flux, diff --git a/src/maxdiffusion/pipelines/flux/__init__.py b/src/maxdiffusion/pipelines/flux/__init__.py index 076f6350e..fcc4854fb 100644 --- a/src/maxdiffusion/pipelines/flux/__init__.py +++ b/src/maxdiffusion/pipelines/flux/__init__.py @@ -1,5 +1,5 @@ -_import_structure = { "pipeline_jflux" : "JfluxPipeline" } +_import_structure = {"pipeline_jflux": "JfluxPipeline"} from .flux_pipeline import ( FluxPipeline, -) \ No newline at end of file +) diff --git a/src/maxdiffusion/pipelines/flux/flux_pipeline.py b/src/maxdiffusion/pipelines/flux/flux_pipeline.py index cd40a45be..e655b491d 100644 --- a/src/maxdiffusion/pipelines/flux/flux_pipeline.py +++ b/src/maxdiffusion/pipelines/flux/flux_pipeline.py @@ -13,15 +13,13 @@ # limitations under the License. 
from functools import partial -from typing import Dict, List, Optional, Union, Callable +from typing import List, Optional, Union, Callable import jax import jax.numpy as jnp -import numpy as np import math -from flax.core.frozen_dict import FrozenDict from transformers import (CLIPTokenizer, FlaxCLIPTextModel, FlaxT5EncoderModel, AutoTokenizer) -from einops import rearrange, repeat +from einops import rearrange from jax.typing import DTypeLike from chex import Array @@ -30,9 +28,7 @@ from maxdiffusion.utils import logging from ...models import FlaxAutoencoderKL -from ...schedulers import ( - FlaxEulerDiscreteScheduler -) +from ...schedulers import (FlaxEulerDiscreteScheduler) from ..pipeline_flax_utils import FlaxDiffusionPipeline from maxdiffusion.models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel @@ -76,13 +72,13 @@ def __init__( self.rng = rng def create_noise( - self, - num_samples: int, - height: int, - width: int, - dtype: DTypeLike, - seed: jax.random.PRNGKey, - ): + self, + num_samples: int, + height: int, + width: int, + dtype: DTypeLike, + seed: jax.random.PRNGKey, + ): return jax.random.normal( key=seed, shape=(num_samples, 16, 2 * math.ceil(height / 16), 2 * math.ceil(width / 16)), @@ -106,11 +102,7 @@ def vae_decode(self, latents, vae, state, config): return img def vae_encode(self, latents, vae, state): - img = vae.apply( - {"params": state["params"]}, - latents, - deterministic=True, - method=vae.encode).latent_dist.mode() + img = vae.apply({"params": state["params"]}, latents, deterministic=True, method=vae.encode).latent_dist.mode() img = vae.config.scaling_factor * (img - vae.config.shift_factor) return img @@ -122,7 +114,7 @@ def pack_latents( num_channels_latents: int, height: int, width: int, - ): + ): latents = jnp.reshape(latents, (batch_size, num_channels_latents, height // 2, 2, width // 2, 2)) latents = jnp.permute_dims(latents, (0, 2, 4, 1, 3, 5)) latents = jnp.reshape(latents, (batch_size, (height // 2) * 
(width // 2), num_channels_latents * 4)) @@ -130,7 +122,14 @@ def pack_latents( return latents def prepare_latents( - self, batch_size: int, num_channels_latents: int, height: int, width: int, vae_scale_factor: int, dtype: jnp.dtype, rng: Array + self, + batch_size: int, + num_channels_latents: int, + height: int, + width: int, + vae_scale_factor: int, + dtype: jnp.dtype, + rng: Array, ): # VAE applies 8x compression on images but we must also account for packing which @@ -149,7 +148,6 @@ def prepare_latents( return latents, latent_image_ids - def prepare_latent_image_ids(self, height, width): latent_image_ids = jnp.zeros((height, width, 3)) latent_image_ids = latent_image_ids.at[..., 1].set(jnp.arange(height)[:, None]) @@ -162,7 +160,11 @@ def prepare_latent_image_ids(self, height, width): return latent_image_ids.astype(jnp.bfloat16) def get_clip_prompt_embeds( - self, prompt: Union[str, List[str]], num_images_per_prompt: int, tokenizer: CLIPTokenizer, text_encoder: FlaxCLIPTextModel + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int, + tokenizer: CLIPTokenizer, + text_encoder: FlaxCLIPTextModel, ): prompt = [prompt] if isinstance(prompt, str) else prompt text_inputs = tokenizer( @@ -182,7 +184,6 @@ def get_clip_prompt_embeds( prompt_embeds = jnp.tile(prompt_embeds, (num_images_per_prompt, 1)) return prompt_embeds - def get_t5_prompt_embeds( self, prompt: Union[str, List[str]], @@ -210,8 +211,10 @@ def get_t5_prompt_embeds( if encode_in_batches: prompt_embeds = None for i in range(0, text_input_ids.shape[0], encode_batch_size): - batch_prompt_embeds = text_encoder(text_input_ids[i:i+encode_batch_size], attention_mask=None, output_hidden_states=False)["last_hidden_state"] - if prompt_embeds == None: + batch_prompt_embeds = text_encoder( + text_input_ids[i : i + encode_batch_size], attention_mask=None, output_hidden_states=False + )["last_hidden_state"] + if prompt_embeds is None: prompt_embeds = batch_prompt_embeds else: prompt_embeds = 
jnp.concatenate([prompt_embeds, batch_prompt_embeds]) @@ -227,7 +230,6 @@ def get_t5_prompt_embeds( return prompt_embeds - def encode_prompt( self, prompt: Union[str, List[str]], @@ -239,9 +241,9 @@ def encode_prompt( num_images_per_prompt: int = 1, max_sequence_length: int = 512, encode_in_batches: bool = False, - encode_batch_size: int = None + encode_batch_size: int = None, ): - + if encode_in_batches: assert encode_in_batches is not None @@ -260,36 +262,25 @@ def encode_prompt( text_encoder=t5_text_encoder, max_sequence_length=max_sequence_length, encode_in_batches=encode_in_batches, - encode_batch_size=encode_batch_size + encode_batch_size=encode_batch_size, ) text_ids = jnp.zeros((prompt_embeds.shape[0], prompt_embeds.shape[1], 3)).astype(jnp.bfloat16) return prompt_embeds, pooled_prompt_embeds, text_ids - def _generate( - self, - flux_params, - vae_params, - latents, - latent_image_ids, - prompt_embeds, - txt_ids, - vec, - guidance_vec, - c_ts, - p_ts + self, flux_params, vae_params, latents, latent_image_ids, prompt_embeds, txt_ids, vec, guidance_vec, c_ts, p_ts ): def loop_body( - step, - args, - transformer, - latent_image_ids, - prompt_embeds, - txt_ids, - vec, - guidance_vec, + step, + args, + transformer, + latent_image_ids, + prompt_embeds, + txt_ids, + vec, + guidance_vec, ): latents, state, c_ts, p_ts = args latents_dtype = latents.dtype @@ -297,7 +288,7 @@ def loop_body( t_prev = p_ts[step] t_vec = jnp.full((latents.shape[0],), t_curr, dtype=latents.dtype) pred = transformer.apply( - {"params": state['params']}, + {"params": state["params"]}, hidden_states=latents, img_ids=latent_image_ids, encoder_hidden_states=prompt_embeds, @@ -311,13 +302,13 @@ def loop_body( return latents, state, c_ts, p_ts loop_body_p = partial( - loop_body, - transformer=self.flux, - latent_image_ids=latent_image_ids, - prompt_embeds=prompt_embeds, - txt_ids=txt_ids, - vec=vec, - guidance_vec=guidance_vec, + loop_body, + transformer=self.flux, + 
latent_image_ids=latent_image_ids, + prompt_embeds=prompt_embeds, + txt_ids=txt_ids, + vec=vec, + guidance_vec=guidance_vec, ) vae_decode_p = partial(self.vae_decode, vae=self.vae, state=vae_params, config=self._config) @@ -330,27 +321,23 @@ def loop_body( def do_time_shift(self, mu: float, sigma: float, t: Array): return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) - - def get_lin_function(self, x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]: + def get_lin_function( + self, x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15 + ) -> Callable[[float], float]: m = (y2 - y1) / (x2 - x1) b = y1 - m * x1 return lambda x: m * x + b def time_shift(self, latents, timesteps): # estimate mu based on linear estimation between two points - lin_function = self.get_lin_function(x1=self._config.max_sequence_length, - y1=self._config.base_shift, - y2=self._config.max_shift) + lin_function = self.get_lin_function( + x1=self._config.max_sequence_length, y1=self._config.base_shift, y2=self._config.max_shift + ) mu = lin_function(latents.shape[1]) timesteps = self.do_time_shift(mu, 1.0, timesteps) return timesteps - def __call__( - self, - timesteps: int, - flux_params, - vae_params - ): + def __call__(self, timesteps: int, flux_params, vae_params): r""" The call function to the pipeline for generation. 
@@ -376,26 +363,26 @@ def __call__( global_batch_size = 1 * jax.local_device_count() prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt( - prompt=self._config.prompt, - prompt_2=self._config.prompt_2, - clip_tokenizer=self.clip_tokenizer, - clip_text_encoder=self.clip_encoder, - t5_tokenizer=self.t5_tokenizer, - t5_text_encoder=self.t5_encoder, - num_images_per_prompt=global_batch_size, - max_sequence_length=self._config.max_sequence_length, - ) + prompt=self._config.prompt, + prompt_2=self._config.prompt_2, + clip_tokenizer=self.clip_tokenizer, + clip_text_encoder=self.clip_encoder, + t5_tokenizer=self.t5_tokenizer, + t5_text_encoder=self.t5_encoder, + num_images_per_prompt=global_batch_size, + max_sequence_length=self._config.max_sequence_length, + ) num_channels_latents = self.flux.in_channels // 4 latents, latent_image_ids = self.prepare_latents( - batch_size=global_batch_size, - num_channels_latents=num_channels_latents, - height=self._config.resolution, - width=self._config.resolution, - dtype=jnp.bfloat16, - vae_scale_factor=self.vae_scale_factor, - rng=self.rng, - ) + batch_size=global_batch_size, + num_channels_latents=num_channels_latents, + height=self._config.resolution, + width=self._config.resolution, + dtype=jnp.bfloat16, + vae_scale_factor=self.vae_scale_factor, + rng=self.rng, + ) if self._config.time_shift: timesteps = self.time_shift(latents, timesteps) @@ -414,7 +401,7 @@ def __call__( pooled_prompt_embeds, guidance, c_ts, - p_ts + p_ts, ) images = images diff --git a/src/maxdiffusion/schedulers/scheduling_euler_discrete_flax.py b/src/maxdiffusion/schedulers/scheduling_euler_discrete_flax.py index c46cd5014..863fa26cd 100644 --- a/src/maxdiffusion/schedulers/scheduling_euler_discrete_flax.py +++ b/src/maxdiffusion/schedulers/scheduling_euler_discrete_flax.py @@ -254,13 +254,13 @@ def add_noise( original_samples: jnp.ndarray, noise: jnp.ndarray, timesteps: jnp.ndarray, - flux: bool=False, + flux: bool = False, ) -> jnp.ndarray: if 
flux: t = state.timesteps[timesteps] t = t[:, None, None] - noisy_samples = t * noise + (1-t) * original_samples + noisy_samples = t * noise + (1 - t) * original_samples return noisy_samples sigma = state.sigmas[timesteps].flatten() diff --git a/src/maxdiffusion/train_flux.py b/src/maxdiffusion/train_flux.py index 40abaff80..e3b161039 100644 --- a/src/maxdiffusion/train_flux.py +++ b/src/maxdiffusion/train_flux.py @@ -31,6 +31,7 @@ def train(config): from maxdiffusion.trainers.flux_trainer import FluxTrainer + trainer = FluxTrainer(config) trainer.start_training() diff --git a/src/maxdiffusion/train_utils.py b/src/maxdiffusion/train_utils.py index 67bbd2197..1337f2329 100644 --- a/src/maxdiffusion/train_utils.py +++ b/src/maxdiffusion/train_utils.py @@ -112,10 +112,10 @@ def write_metrics_to_tensorboard(writer, metrics, step, config): if jax.process_index() == 0: max_logging.log( "completed step: {}, seconds: {:.3f}, TFLOP/s/device: {:.3f}, loss: {:.3f}".format( - step, - metrics['scalar']['perf/step_time_seconds'], - metrics['scalar']['perf/per_device_tflops_per_sec'], - float(metrics['scalar']['learning/loss']) + step, + metrics["scalar"]["perf/step_time_seconds"], + metrics["scalar"]["perf/per_device_tflops_per_sec"], + float(metrics["scalar"]["learning/loss"]), ) ) diff --git a/src/maxdiffusion/trainers/flux_trainer.py b/src/maxdiffusion/trainers/flux_trainer.py index 7970b4207..ed29fe915 100644 --- a/src/maxdiffusion/trainers/flux_trainer.py +++ b/src/maxdiffusion/trainers/flux_trainer.py @@ -20,27 +20,23 @@ import time import numpy as np import jax -import optax import jax.numpy as jnp from jax.sharding import PositionalSharding, PartitionSpec as P from flax.linen import partitioning as nn_partitioning from maxdiffusion.checkpointing.flux_checkpointer import ( - FluxCheckpointer, - FLUX_CHECKPOINT, - FLUX_TRANSFORMER_PARAMS_KEY, - FLUX_STATE_KEY, - FLUX_STATE_SHARDINGS_KEY, - FLUX_VAE_PARAMS_KEY, - VAE_STATE_KEY, - VAE_STATE_SHARDINGS_KEY) + FluxCheckpointer, 
+ FLUX_CHECKPOINT, + FLUX_STATE_KEY, + FLUX_STATE_SHARDINGS_KEY, + VAE_STATE_KEY, + VAE_STATE_SHARDINGS_KEY, +) from maxdiffusion.input_pipeline.input_pipeline_interface import (make_data_iterator) from maxdiffusion import (max_utils, max_logging) -from PIL import Image from maxdiffusion.train_utils import ( - generate_timestep_weights, get_first_step, load_next_batch, record_scalar_metrics, @@ -49,9 +45,7 @@ from maxdiffusion.maxdiffusion_utils import calculate_flux_tflops -from ..schedulers import ( - FlaxEulerDiscreteScheduler -) +from ..schedulers import (FlaxEulerDiscreteScheduler) class FluxTrainer(FluxCheckpointer): @@ -72,22 +66,19 @@ def create_scheduler(self, pipeline, params): pretrained_model_name_or_path=self.config.pretrained_model_name_or_path, subfolder="scheduler", dtype=jnp.float32 ) noise_scheduler_state = noise_scheduler.set_timesteps( - state=noise_scheduler_state, - num_inference_steps=self.config.num_inference_steps, - timestep_spacing="flux") + state=noise_scheduler_state, num_inference_steps=self.config.num_inference_steps, timestep_spacing="flux" + ) return noise_scheduler, noise_scheduler_state - + def calculate_tflops(self, pipeline): - per_device_tflops = calculate_flux_tflops( - self.config, pipeline, self.total_train_batch_size, self.rng, train=True - ) + per_device_tflops = calculate_flux_tflops(self.config, pipeline, self.total_train_batch_size, self.rng, train=True) max_logging.log(f"JFLUX per device TFLOPS: {per_device_tflops}") return per_device_tflops def start_training(self): # Hook - #self.pre_training_steps() + # self.pre_training_steps() # Load checkpoint - will load or create states pipeline, params = self.load_checkpoint() @@ -103,7 +94,6 @@ def start_training(self): pipeline.t5_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), pipeline.t5_encoder.params) pipeline.t5_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, pipeline.t5_encoder.params) - vae_state, 
vae_state_mesh_shardings = self.create_vae_state( pipeline=pipeline, params=params, checkpoint_item_name=VAE_STATE_KEY, is_training=False ) @@ -119,21 +109,21 @@ def start_training(self): del pipeline.t5_encoder # evaluate shapes - + flux_state, flux_state_mesh_shardings, flux_learning_rate_scheduler = self.create_flux_state( # ambiguous here, but if params=None # Then its 1 of 2 scenarios: # 1. flux state will be loaded directly from orbax # 2. a new flux is being trained from scratch. pipeline=pipeline, - params=None, # Params are loaded inside create_flux_state + params=None, # Params are loaded inside create_flux_state checkpoint_item_name=FLUX_STATE_KEY, is_training=True, ) flux_state = jax.device_put(flux_state, flux_state_mesh_shardings) train_states[FLUX_STATE_KEY] = flux_state state_shardings[FLUX_STATE_SHARDINGS_KEY] = flux_state_mesh_shardings - #self.post_training_steps(pipeline, params, train_states, msg="before_training") + # self.post_training_steps(pipeline, params, train_states, msg="before_training") # Create scheduler noise_scheduler, noise_scheduler_state = self.create_scheduler(pipeline, params) @@ -165,20 +155,12 @@ def get_shaped_batch(self, config, pipeline=None): w = config.resolution // scale_factor c = 16 ph = pw = 2 - batch_image_shape = ( - self.total_train_batch_size, # b - h*w, - c*ph*pw - ) - img_ids_shape = ( - self.total_train_batch_size, - (2*h // 2) * (2*w // 2), - 3 - ) + batch_image_shape = (self.total_train_batch_size, h * w, c * ph * pw) # b + img_ids_shape = (self.total_train_batch_size, (2 * h // 2) * (2 * w // 2), 3) text_shape = ( self.total_train_batch_size, config.max_sequence_length, - 4096, # Sequence length of text encoder, how to get this programmatically? + 4096, # Sequence length of text encoder, how to get this programmatically? 
) text_ids_shape = ( self.total_train_batch_size, @@ -187,7 +169,7 @@ def get_shaped_batch(self, config, pipeline=None): ) prompt_embeds_shape = ( self.total_train_batch_size, - 768, # Sequence length of clip, how to get this programmatically? + 768, # Sequence length of clip, how to get this programmatically? ) input_ids_dtype = self.config.activations_dtype @@ -201,11 +183,13 @@ def get_shaped_batch(self, config, pipeline=None): def get_data_shardings(self): data_sharding = jax.sharding.NamedSharding(self.mesh, P(*self.config.data_sharding)) - data_sharding = {"text_embeds": data_sharding, - "input_ids": data_sharding, - "prompt_embeds": data_sharding, - "pixel_values": data_sharding, - "img_ids": data_sharding} + data_sharding = { + "text_embeds": data_sharding, + "input_ids": data_sharding, + "prompt_embeds": data_sharding, + "pixel_values": data_sharding, + "img_ids": data_sharding, + } return data_sharding @@ -223,33 +207,23 @@ def tokenize_captions(examples, caption_column, encoder): return examples @staticmethod - def transform_images( - examples, - image_column, - image_resolution, - vae_encode, - pack_latents, - prepare_latent_imgage_ids - ): + def transform_images(examples, image_column, image_resolution, vae_encode, pack_latents, prepare_latent_imgage_ids): """Preprocess images to latents.""" images = list(examples[image_column]) images = [ - jax.image.resize( - jnp.asarray(image) / 127.5 - 1.0, - [image_resolution, image_resolution, 3], - method="bilinear", - antialias=True - ) + jax.image.resize( + jnp.asarray(image) / 127.5 - 1.0, [image_resolution, image_resolution, 3], method="bilinear", antialias=True + ) for image in images - ] + ] images = jnp.stack(images, axis=0, dtype=jnp.float16) batch_size = 8 num_batches = len(images) // batch_size + int(len(images) % batch_size != 0) encoded_images = [] for i in range(num_batches): - batch_images = images[i * batch_size:(i + 1) * batch_size] + batch_images = images[i * batch_size : (i + 1) * batch_size] 
batch_images = jnp.transpose(batch_images, (0, 3, 1, 2)) batch_images = vae_encode(batch_images) batch_images = jnp.transpose(batch_images, (0, 3, 1, 2)) @@ -257,17 +231,11 @@ def transform_images( images = jnp.concatenate(encoded_images, axis=0, dtype=jnp.float16) b, c, h, w = images.shape - images = pack_latents( - latents=images, - batch_size=b, - num_channels_latents=c, - height=h, - width=w) + images = pack_latents(latents=images, batch_size=b, num_channels_latents=c, height=h, width=w) img_ids = prepare_latent_imgage_ids(h // 2, w // 2) img_ids = jnp.tile(img_ids, (b, 1, 1)) - examples["pixel_values"] = jnp.float16(images) examples["img_ids"] = jnp.float16(img_ids) @@ -278,31 +246,27 @@ def load_dataset(self, pipeline, params, train_states): total_train_batch_size = self.total_train_batch_size mesh = self.mesh - encode_fn = partial(pipeline.encode_prompt, - clip_tokenizer=pipeline.clip_tokenizer, - t5_tokenizer=pipeline.t5_tokenizer, - clip_text_encoder=pipeline.clip_encoder, - t5_text_encoder=pipeline.t5_encoder, - encode_in_batches=True, - encode_batch_size=16 - ) + encode_fn = partial( + pipeline.encode_prompt, + clip_tokenizer=pipeline.clip_tokenizer, + t5_tokenizer=pipeline.t5_tokenizer, + clip_text_encoder=pipeline.clip_encoder, + t5_text_encoder=pipeline.t5_encoder, + encode_in_batches=True, + encode_batch_size=16, + ) pack_latents_p = partial(pipeline.pack_latents) prepare_latent_image_ids_p = partial(pipeline.prepare_latent_image_ids) vae_encode_p = partial(pipeline.vae_encode, vae=pipeline.vae, state=train_states["vae_state"]) - - tokenize_fn = partial( - FluxTrainer.tokenize_captions, - caption_column=config.caption_column, - encoder=encode_fn - ) + tokenize_fn = partial(FluxTrainer.tokenize_captions, caption_column=config.caption_column, encoder=encode_fn) image_transforms_fn = partial( FluxTrainer.transform_images, image_column=config.image_column, image_resolution=config.resolution, vae_encode=vae_encode_p, pack_latents=pack_latents_p, - 
prepare_latent_imgage_ids=prepare_latent_image_ids_p + prepare_latent_imgage_ids=prepare_latent_image_ids_p, ) data_iterator = make_data_iterator( @@ -322,7 +286,13 @@ def compile_train_step(self, pipeline, params, train_states, state_shardings, da guidance_vec = jnp.full((self.total_train_batch_size,), self.config.guidance_scale, dtype=self.config.activations_dtype) with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): p_train_step = jax.jit( - partial(_train_step, guidance_vec=guidance_vec, pipeline=pipeline, scheduler=train_states["scheduler"], config=self.config), + partial( + _train_step, + guidance_vec=guidance_vec, + pipeline=pipeline, + scheduler=train_states["scheduler"], + config=self.config, + ), in_shardings=( state_shardings["flux_state_shardings"], data_shardings, @@ -430,18 +400,18 @@ def compute_loss(state_params): ) # Sample a random timestep for each image bsz = latents.shape[0] - timesteps = jax.random.randint(timestep_rng, shape=(bsz,), minval=0, maxval=len(scheduler.timesteps)-1) + timesteps = jax.random.randint(timestep_rng, shape=(bsz,), minval=0, maxval=len(scheduler.timesteps) - 1) noisy_latents = pipeline.scheduler.add_noise(scheduler, latents, noise, timesteps, flux=True) model_pred = pipeline.flux.apply( - {"params": state_params[FLUX_STATE_KEY]}, - hidden_states=noisy_latents, - img_ids=img_ids, - encoder_hidden_states=text_embeds, - txt_ids=text_embeds_ids, - timestep=scheduler.timesteps[timesteps], - guidance=guidance_vec, - pooled_projections=prompt_embeds, + {"params": state_params[FLUX_STATE_KEY]}, + hidden_states=noisy_latents, + img_ids=img_ids, + encoder_hidden_states=text_embeds, + txt_ids=text_embeds_ids, + timestep=scheduler.timesteps[timesteps], + guidance=guidance_vec, + pooled_projections=prompt_embeds, ).sample target = noise - latents @@ -459,4 +429,3 @@ def compute_loss(state_params): metrics = {"scalar": {"learning/loss": loss}, "scalars": {}} return new_state, metrics, new_train_rng -