From bd46013f3d430e6b903f1de6f6622b972f77976c Mon Sep 17 00:00:00 2001 From: jcaraban Date: Mon, 27 Jan 2025 10:53:42 -0600 Subject: [PATCH 1/3] working flax.linen-flux.1 ported from nnx-jflux --- src/maxdiffusion/__init__.py | 2 + .../base_stable_diffusion_checkpointer.py | 17 +- .../checkpointing/checkpointing_utils.py | 57 +- .../checkpointing/jflux_checkpointer.py | 177 ++++ src/maxdiffusion/common_types.py | 13 + src/maxdiffusion/configs/base_jflux.yml | 260 +++++ src/maxdiffusion/create_jflux_checkpoints.py | 27 + src/maxdiffusion/generate_jflux.py | 109 +++ src/maxdiffusion/max_utils.py | 919 +++++++++++++++--- src/maxdiffusion/maxdiffusion_utils.py | 60 ++ src/maxdiffusion/models/__init__.py | 2 +- src/maxdiffusion/models/ae_flux_nnx.py | 583 +++++++++++ src/maxdiffusion/models/attention_flax.py | 188 ++++ src/maxdiffusion/models/embeddings_flax.py | 197 +++- src/maxdiffusion/models/flux_utils.py | 408 ++++++++ .../models/modeling_flax_pytorch_utils.py | 31 + .../models/modeling_flax_utils.py | 9 +- src/maxdiffusion/models/normalization_flax.py | 149 +++ .../transformers/transformer_flux_flax.py | 621 ++++++++++++ src/maxdiffusion/pipelines/__init__.py | 15 +- src/maxdiffusion/pipelines/jflux/__init__.py | 5 + .../pipelines/jflux/pipeline_jflux.py | 239 +++++ .../pipelines/pipeline_flax_utils.py | 2 +- .../scheduling_euler_discrete_flax.py | 28 +- src/maxdiffusion/tests/flux_tests.py | 35 + src/maxdiffusion/train_jflux.py | 48 + src/maxdiffusion/train_utils.py | 9 +- src/maxdiffusion/trainers/jflux_trainer.py | 494 ++++++++++ .../utils/dynamic_modules_utils.py | 2 +- 29 files changed, 4507 insertions(+), 199 deletions(-) create mode 100644 src/maxdiffusion/checkpointing/jflux_checkpointer.py create mode 100644 src/maxdiffusion/configs/base_jflux.yml create mode 100644 src/maxdiffusion/create_jflux_checkpoints.py create mode 100644 src/maxdiffusion/generate_jflux.py create mode 100644 src/maxdiffusion/models/ae_flux_nnx.py create mode 100644 
src/maxdiffusion/models/flux_utils.py create mode 100644 src/maxdiffusion/models/normalization_flax.py create mode 100644 src/maxdiffusion/models/transformers/transformer_flux_flax.py create mode 100644 src/maxdiffusion/pipelines/jflux/__init__.py create mode 100644 src/maxdiffusion/pipelines/jflux/pipeline_jflux.py create mode 100644 src/maxdiffusion/tests/flux_tests.py create mode 100644 src/maxdiffusion/train_jflux.py create mode 100644 src/maxdiffusion/trainers/jflux_trainer.py diff --git a/src/maxdiffusion/__init__.py b/src/maxdiffusion/__init__.py index bb2d5d25d..dc7768431 100644 --- a/src/maxdiffusion/__init__.py +++ b/src/maxdiffusion/__init__.py @@ -409,6 +409,7 @@ "FlaxStableDiffusionInpaintPipeline", "FlaxStableDiffusionPipeline", "FlaxStableDiffusionXLPipeline", + "JfluxPipeline", ] ) @@ -478,6 +479,7 @@ FlaxStableDiffusionInpaintPipeline, FlaxStableDiffusionPipeline, FlaxStableDiffusionXLPipeline, + JfluxPipeline, ) try: diff --git a/src/maxdiffusion/checkpointing/base_stable_diffusion_checkpointer.py b/src/maxdiffusion/checkpointing/base_stable_diffusion_checkpointer.py index 72215549c..9c989add5 100644 --- a/src/maxdiffusion/checkpointing/base_stable_diffusion_checkpointer.py +++ b/src/maxdiffusion/checkpointing/base_stable_diffusion_checkpointer.py @@ -35,7 +35,7 @@ from maxdiffusion.transformers import (CLIPTokenizer, FlaxCLIPTextModel, CLIPTextConfig, FlaxCLIPTextModelWithProjection) from maxdiffusion.checkpointing.checkpointing_utils import ( - create_orbax_checkpoint_manager, + create_stable_diffusion_orbax_checkpoint_manager, load_stable_diffusion_configs, ) @@ -43,6 +43,7 @@ STABLE_DIFFUSION_XL_CHECKPOINT = "STABLE_DIFUSSION_XL_CHECKPOINT" _CHECKPOINT_FORMAT_DIFFUSERS = "CHECKPOINT_FORMAT_DIFFUSERS" _CHECKPOINT_FORMAT_ORBAX = "CHECKPOINT_FORMAT_ORBAX" +JFLUX_CHECKPOINT = "JFLUX_CHECKPOINT" class BaseStableDiffusionCheckpointer(ABC): @@ -57,16 +58,11 @@ def __init__(self, config, checkpoint_type): self.mesh = Mesh(devices_array, 
self.config.mesh_axes) self.total_train_batch_size = self.config.total_train_batch_size - self.checkpoint_manager = create_orbax_checkpoint_manager( - self.config.checkpoint_dir, - enable_checkpointing=True, - save_interval_steps=1, - checkpoint_type=checkpoint_type, - dataset_type=config.dataset_type, + self.checkpoint_manager = create_stable_diffusion_orbax_checkpoint_manager( + self.config.checkpoint_dir, enable_checkpointing=True, save_interval_steps=1, checkpoint_type=checkpoint_type ) def _create_optimizer(self, config, learning_rate): - learning_rate_scheduler = max_utils.create_learning_rate_schedule( learning_rate, config.learning_rate_schedule_steps, config.warmup_steps_fraction, config.max_train_steps ) @@ -74,7 +70,6 @@ def _create_optimizer(self, config, learning_rate): return tx, learning_rate_scheduler def create_unet_state(self, pipeline, params, checkpoint_item_name, is_training): - tx, learning_rate_scheduler = None, None if is_training: learning_rate = self.config.learning_rate @@ -96,7 +91,6 @@ def create_unet_state(self, pipeline, params, checkpoint_item_name, is_training) return unet_state, state_mesh_shardings, learning_rate_scheduler def create_vae_state(self, pipeline, params, checkpoint_item_name, is_training=False): - # Currently VAE training is not supported. 
weights_init_fn = functools.partial(pipeline.vae.init_weights, rng=self.rng) return max_utils.setup_initial_state( @@ -112,7 +106,6 @@ def create_vae_state(self, pipeline, params, checkpoint_item_name, is_training=F ) def create_text_encoder_state(self, pipeline, params, checkpoint_item_name, is_training): - tx = None if is_training: learning_rate = self.config.text_encoder_learning_rate @@ -259,11 +252,9 @@ def config_to_json(model_or_config): self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items)) def load_params(self, step=None): - self.checkpoint_format = _CHECKPOINT_FORMAT_ORBAX def load_checkpoint(self, step=None, scheduler_class=None): - pipeline_class = self._get_pipeline_class() self.checkpoint_format = _CHECKPOINT_FORMAT_ORBAX diff --git a/src/maxdiffusion/checkpointing/checkpointing_utils.py b/src/maxdiffusion/checkpointing/checkpointing_utils.py index b8710e1a6..b9ebe4ba8 100644 --- a/src/maxdiffusion/checkpointing/checkpointing_utils.py +++ b/src/maxdiffusion/checkpointing/checkpointing_utils.py @@ -38,11 +38,12 @@ def create_orbax_checkpoint_manager( checkpoint_dir: str, enable_checkpointing: bool, - save_interval_steps, + save_interval_steps: int, checkpoint_type: str, dataset_type: str = "tf", use_async: bool = True, orbax_logger: Optional[abstract_logger.AbstractLogger] = None, + item_names=None, ): """ Returns specified Orbax (async or not) CheckpointManager or None if checkpointing is disabled. 
@@ -56,6 +57,29 @@ def create_orbax_checkpoint_manager( max_logging.log(f"checkpoint dir: {checkpoint_dir}") p = epath.Path(checkpoint_dir) + print("item_names: ", item_names) + + mngr = CheckpointManager( + p, + item_names=item_names, + options=CheckpointManagerOptions( + create=True, save_interval_steps=save_interval_steps, enable_async_checkpointing=use_async + ), + logger=orbax_logger, + ) + + max_logging.log("Checkpoint manager created!") + return mngr + + +def create_stable_diffusion_orbax_checkpoint_manager( + checkpoint_dir: str, + enable_checkpointing: bool, + save_interval_steps: int, + checkpoint_type: str, + use_async: bool = True, + orbax_logger: Optional[abstract_logger.AbstractLogger] = None, +): item_names = ( "unet_config", "vae_config", @@ -74,6 +98,8 @@ def create_orbax_checkpoint_manager( if dataset_type == "grain": item_names += ("iter",) + if override_item_names is not None: + item_names = override_item_names print("item_names: ", item_names) mngr = CheckpointManager( @@ -84,9 +110,9 @@ def create_orbax_checkpoint_manager( ), logger=orbax_logger, ) - - max_logging.log("Checkpoint manager created!") - return mngr + return create_orbax_checkpoint_manager( + checkpoint_dir, enable_checkpointing, save_interval_steps, use_async, orbax_logger, item_names + ) def load_stable_diffusion_configs( @@ -204,11 +230,10 @@ def load_state_if_possible( if latest_step is None: return None else: - max_logging.log(f"restoring from this run's directory latest step {latest_step}") try: - if not enable_single_replica_ckpt_restoring: - item = {checkpoint_item: orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state)} - return checkpoint_manager.restore(latest_step, args=orbax.checkpoint.args.Composite(**item)) + max_logging.log( + f"restoring from this run's directory latest step {latest_step}" + ) def map_to_pspec(data): pspec = data.sharding.spec @@ -227,18 +252,20 @@ def map_to_pspec(data): dtype=data.dtype, ) - array_handler = 
ocp.type_handlers.SingleReplicaArrayHandler( - replica_axis_index=0, - broadcast_memory_limit_bytes=1024 * 1024 * 1000, # 1000 MB limit - ) - ocp.type_handlers.register_type_handler(jax.Array, array_handler, override=True) + if enable_single_replica_ckpt_restoring: + array_handler = ocp.type_handlers.SingleReplicaArrayHandler( + replica_axis_index=0, + broadcast_memory_limit_bytes=1024 * 1024 * 1000, # 1000 MB limit + ) + ocp.type_handlers.register_type_handler(jax.Array, array_handler, override=True) restore_args = jax.tree_util.tree_map( map_to_pspec, abstract_unboxed_pre_state, ) + item = {checkpoint_item: ocp.args.PyTreeRestore(item=abstract_unboxed_pre_state, restore_args=restore_args)} return checkpoint_manager.restore(latest_step, args=orbax.checkpoint.args.Composite(**item)) - except: - max_logging.log(f"could not load {checkpoint_item} from orbax") + except Exception as e: + max_logging.log(f"could not load {checkpoint_item} from orbax: {e}") return None diff --git a/src/maxdiffusion/checkpointing/jflux_checkpointer.py b/src/maxdiffusion/checkpointing/jflux_checkpointer.py new file mode 100644 index 000000000..b86aee375 --- /dev/null +++ b/src/maxdiffusion/checkpointing/jflux_checkpointer.py @@ -0,0 +1,177 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +from abc import ABC +import functools +import jax +from jax.sharding import PartitionSpec +from jax.sharding import Mesh +import orbax.checkpoint as ocp +from maxdiffusion import (max_utils) +from maxdiffusion.pipelines.jflux.pipeline_jflux import JfluxPipeline +from maxdiffusion.models.flux_utils import configs +from maxdiffusion.models.transformers.transformer_flux_flax import FluxTransformer2DModel +from maxdiffusion.models.embeddings_flax import HFEmbedder +from maxdiffusion.models.flux_utils import load_ae +from flax.linen import partitioning as nn_partitioning + +from maxdiffusion.checkpointing.checkpointing_utils import ( + create_orbax_checkpoint_manager, +) + + +def get_device_type(): + """Returns the type of JAX device being used. + + Returns: + str: "gpu", "tpu", or "cpu" + """ + try: + device_kind = jax.devices()[0].device_kind + if "tpu" in device_kind.lower(): + return "tpu" + elif "amd" in device_kind.lower(): + return "rocm" + elif "nvidia" in device_kind.lower(): + return "cuda" + else: + return "cpu" + except IndexError: + return "cpu" # No devices found, likely using CPU + + +class JfluxCheckpointer(ABC): + flux_state_item_name = "flux_state" + config_item_name = "config" + + def __init__(self, config): + self.config = config + + self.rng = jax.random.PRNGKey(self.config.seed) + devices_array = max_utils.create_device_mesh(config) + self.mesh = Mesh(devices_array, self.config.mesh_axes) + self.total_train_batch_size = self.config.total_train_batch_size + + self.checkpoint_manager = create_orbax_checkpoint_manager( + self.config.checkpoint_dir, + enable_checkpointing=True, + save_interval_steps=self.config.save_interval_steps, + checkpoint_type="none", + item_names={JfluxCheckpointer.flux_state_item_name, JfluxCheckpointer.config_item_name}, + ) + + def _create_optimizer(self, config): + learning_rate_scheduler = max_utils.create_learning_rate_schedule(config) + tx = max_utils.create_optimizer(config, learning_rate_scheduler) + return tx, 
learning_rate_scheduler + + def create_flux_state(self, flux, init_flux_weights, params, is_training, use_jit=True): + tx, learning_rate_scheduler = None, None + if is_training: + + tx, learning_rate_scheduler = self._create_optimizer(self.config) + + if init_flux_weights is not None: + weights_init_fn = functools.partial(init_flux_weights, rng=self.rng) + else: + weights_init_fn = None + flux_state, state_mesh_shardings = max_utils.setup_initial_state( + model=flux, + tx=tx, + config=self.config, + mesh=self.mesh, + weights_init_fn=weights_init_fn, + model_params=params.get(JfluxCheckpointer.flux_state_item_name, None) if params is not None else None, + checkpoint_manager=self.checkpoint_manager, + checkpoint_item=JfluxCheckpointer.flux_state_item_name, + training=is_training, + use_jit=use_jit, + ) + + return flux_state, state_mesh_shardings, learning_rate_scheduler + + def _get_pipeline_class(self): + return JfluxPipeline + + def save_checkpoint(self, train_step, pipeline, train_states): + items = { + JfluxCheckpointer.config_item_name: ocp.args.JsonSave({"model_name": self.config.model_name}), + } + + items[JfluxCheckpointer.flux_state_item_name] = ocp.args.PyTreeSave(train_states[JfluxCheckpointer.flux_state_item_name]) + + self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items)) + + def load_pretrained_model(self, model_name): + # This code to generate the safetensors filename may not generalize + # but loading does not work without it + print(f"loading pretrained model {self.config.pretrained_model_name_or_path}") + stname = self.config.pretrained_model_name_or_path.split("/")[1].lower().replace(".", "") + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + flux, weights = FluxTransformer2DModel.from_pretrained( + pretrained_model_name_or_path=self.config.pretrained_model_name_or_path, + subfolder="transformer", + from_pt=True, + filename=f"{stname}.safetensors", + mesh=self.mesh, + ) + weights = 
jax.tree_util.tree_map(lambda x: x.astype(self.config.weights_dtype), weights) + return flux, weights + + def load_checkpoint(self, step=None, scheduler_class=None): + with jax.default_device(jax.devices("cpu")[0]): + t5 = HFEmbedder( + "ariG23498/t5-v1-1-xxl-flax", + max_length=256 if self.config.model_name == "flux-schnell" else 512, + dtype=jax.numpy.bfloat16, + ) + + clip = HFEmbedder( + "ariG23498/clip-vit-large-patch14-text-flax", + max_length=77, + dtype=jax.numpy.bfloat16, + ) + + ae = load_ae(self.config.model_name, "cpu") + + precision = max_utils.get_precision(self.config) + flash_block_sizes = max_utils.get_flash_block_sizes(self.config) + data_sharding = jax.sharding.NamedSharding(self.mesh, PartitionSpec(*self.config.data_sharding)) + # loading from pretrained here causes a crash when trying to compile the model + # Failed to load HSACO: HIP_ERROR_NoBinaryForGpu + model_params = configs[self.config.model_name].params + flux = FluxTransformer2DModel( + num_layers=model_params.depth, + num_single_layers=model_params.depth_single_blocks, + in_channels=model_params.in_channels, + attention_head_dim=int(model_params.hidden_size / model_params.num_heads), + num_attention_heads=model_params.num_heads, + joint_attention_dim=model_params.context_in_dim, + pooled_projection_dim=model_params.vec_in_dim, + mlp_ratio=model_params.mlp_ratio, + qkv_bias=model_params.qkv_bias, + theta=model_params.theta, + guidance_embeds=model_params.guidance_embed, + axes_dims_rope=model_params.axes_dim, + dtype=self.config.activations_dtype, + weights_dtype=self.config.weights_dtype, + attention_kernel=self.config.attention, + flash_block_sizes=flash_block_sizes, + mesh=self.mesh, + precision=precision, + ) + + return JfluxPipeline(t5, clip, flux, ae, dtype=self.config.activations_dtype, sharding=data_sharding, scheduler=None) diff --git a/src/maxdiffusion/common_types.py b/src/maxdiffusion/common_types.py index b27792f4e..936a39d6d 100644 --- a/src/maxdiffusion/common_types.py 
+++ b/src/maxdiffusion/common_types.py @@ -36,8 +36,21 @@ BATCH = "activation_batch" LENGTH = "activation_length" +EMBED = "activation_embed" HEAD = "activation_heads" +KV_BATCH = "activation_kv_batch" +KV_HEAD = "activation_kv_heads" +KV_HEAD_DIM = "activation_kv_head_dim" D_KV = "activation_kv" KEEP_1 = "activation_keep_1" KEEP_2 = "activation_keep_2" CONV_OUT = "activation_conv_out_channels" + +# needed for flash attention +MODEL_MODE_AUTOREGRESSIVE = "autoregressive" +MODEL_MODE_PREFILL = "prefill" +MODEL_MODE_TRAIN = "train" + +# A large negative mask value is used for masking to ensure that the +# softmax function assigns an extremely low probability to the masked positions. +DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max) diff --git a/src/maxdiffusion/configs/base_jflux.yml b/src/maxdiffusion/configs/base_jflux.yml new file mode 100644 index 000000000..ce65590a1 --- /dev/null +++ b/src/maxdiffusion/configs/base_jflux.yml @@ -0,0 +1,260 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This sentinel is a reminder to choose a real run name. +run_name: 'jflux' + +model_name: "flux-schnell" +metrics_file: "" # for testing, local file that stores scalar metrics. If empty, no metrics are written. 
+# If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/ +write_metrics: True +gcs_metrics: False +# If true save config to GCS in {base_output_directory}/{run_name}/ +save_config_to_gcs: False +log_period: 100 +save_interval_steps: -1 + +pretrained_model_name_or_path: 'black-forest-labs/FLUX.1-schnell' +checkpoint_path: "" +checkpoint_step: -1 +# This will convert the weights to this dtype. +# When running inference on TPUv5e, use weights_dtype: 'bfloat16' +weights_dtype: 'bfloat16' +# This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype) +activations_dtype: 'bfloat16' + +guidance: 4.0 +save_final_checkpoint: False +run_inference_after_training: True + +# matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision +# Options are "DEFAULT", "HIGH", "HIGHEST" +# fp32 activations and fp32 weights with HIGHEST will provide the best precision +# at the cost of time. +precision: "DEFAULT" + +# Set true to load weights from pytorch +from_pt: False +split_head_dim: True +attention: 'dot_product' # Supported attention: dot_product, flash +flash_block_sizes: {} +# GroupNorm groups +norm_num_groups: 32 + +# If train_new_flux, flux weights will be randomly initialized to train from scratch +# else they will be loaded from pretrained_model_name_or_path +train_new_flux: False +revision: '' + +# train text_encoder - Currently not supported for SDXL +train_text_encoder: False +text_encoder_learning_rate: 4.25e-6 + +# https://arxiv.org/pdf/2305.08891.pdf +snr_gamma: -1.0 + +timestep_bias: { + # a value of later will increase the frequence of the model's final training steps. + # none, earlier, later, range + strategy: "none", + # multiplier for bias, a value of 2.0 will double the weight of the bias, 0.5 will halve it. + multiplier: 1.0, + # when using strategy=range, the beginning (inclusive) timestep to bias. 
+ begin: 0, + # when using strategy=range, the final step (inclusive) to bias. + end: 1000, + # portion of timesteps to bias. + # 0.5 will bias one half of the timesteps. Value of strategy determines + # whether the biased portions are in the earlier or later timesteps. + portion: 0.25 +} + +# Override parameters from checkpoints's scheduler. +diffusion_scheduler_config: { + _class_name: '', + # values are v_prediction or leave empty to use scheduler's default. + prediction_type: '', + rescale_zero_terminal_snr: False, + timestep_spacing: '' +} + +# Output directory +# Create a GCS bucket, e.g. my-maxtext-outputs and set this to "gs://my-maxtext-outputs/" +base_output_directory: "" + + +# Parallelism +mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'expert', 'autoregressive'] +logical_axis_rules: [ + ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], + ['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']], + # For pipeline parallelism the pre and post decoder layer tensors' batch dimension is sharded by stages. + # Microbatches are sharded by stage, so moving out of and into this sharding should be a local reshape. + # The "stage" needs to be listed first since the microbatch dimension is first before the reshape. 
+ ['activation_embed_and_logits_batch', ['stage', 'data', 'fsdp', 'fsdp_transpose', 'expert']], + ['activation_heads', ['tensor','sequence']], + ['activation_kv_heads', ['tensor','sequence']], + ['activation_length', 'sequence'], + ['activation_embed', ['tensor', 'fsdp_transpose']], + ['activation_mlp', 'tensor'], + ['activation_kv', 'tensor'], + ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], + ['activation_kv_head_dim', 'tensor'], + ['activation_vocab', ['tensor', 'sequence']], + ['activation_stage', 'stage'], + ['activation_exp', 'expert'], + ['mlp', ['fsdp_transpose', 'tensor', 'autoregressive']], + ['vocab', ['tensor', 'autoregressive']], + ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'expert']], + ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence']], + ['norm', 'tensor'], + ['heads', ['tensor', 'autoregressive', 'fsdp_transpose']], + ['layers', 'stage'], + ['kv', []], + ['kv_heads', ['tensor', 'autoregressive']], + ['kv_head_dim', []], + ['cache_batch', []], + ['cache_heads', ['autoregressive', 'tensor']], + ['cache_kv', []], + ['cache_sequence', []], + ['exp', 'expert'], + ] +# Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details +data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'expert', 'autoregressive']] + +# One axis for each parallelism type may hold a placeholder (-1) +# value to auto-shard based on available slices and devices. +# By default, product of the DCN axes should equal number of slices +# and product of the ICI axes should equal number of devices per slice. 
+dcn_data_parallelism: -1 # recommended DCN axis to be auto-sharded +dcn_fsdp_parallelism: 1 +dcn_fsdp_transpose_parallelism: 1 +dcn_sequence_parallelism: 1 # never recommended +dcn_tensor_parallelism: 1 # never recommended +dcn_pipeline_parallelism: 1 +dcn_expert_parallelism: 1 +dcn_autoregressive_parallelism: 1 # never recommended +ici_data_parallelism: 1 +ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded +ici_fsdp_transpose_parallelism: -1 +ici_sequence_parallelism: 1 +ici_tensor_parallelism: 1 +ici_autoregressive_parallelism: 1 +ici_pipeline_parallelism: 1 +ici_expert_parallelism: 1 + +# The number of TPU slices is automatically determined, you should not set this explicitly. For ahead of time compilation, +# you should set compile_toplogy_num_slices, which will in turn set this value. For non-TPU environments this is set to 1. +num_slices: 1 + + +# Dataset +# Replace with dataset path or train_data_dir. One has to be set. +dataset_name: 'diffusers/pokemon-gpt4-captions' +train_split: 'train' +dataset_type: 'tf' +cache_latents_text_encoder_outputs: True +# cache_latents_text_encoder_outputs only apply to dataset_type="tf", +# only apply to small dataset that fits in memory +# prepare image latents and text encoder outputs +# Reduce memory consumption and reduce step time during training +# transformed dataset is saved at dataset_save_location +dataset_save_location: '/tmp/pokemon-gpt4-captions_xl' +train_data_dir: '' +dataset_config_name: '' +jax_cache_dir: '/dev/shm/jax' +hf_data_dir: '' +hf_train_files: '' +hf_access_token: '' +image_column: 'image' +caption_column: 'text' +resolution: 1024 +center_crop: False +random_flip: False +# If cache_latents_text_encoder_outputs is True +# the num_proc is set to 1 +tokenize_captions_num_proc: 4 +transform_images_num_proc: 4 +reuse_example_batch: False +enable_data_shuffling: True + +# checkpoint every number of samples, -1 means don't checkpoint. 
+checkpoint_every: -1 +# enables one replica to read the ckpt then broadcast to the rest +enable_single_replica_ckpt_restoring: False + +# Training loop +learning_rate: 1.e-5 +scale_lr: False +max_train_samples: -1 +# max_train_steps takes priority over num_train_epochs. +max_train_steps: 50 +num_train_epochs: 1 +seed: 102333 +output_dir: '/workspace/runs' +per_device_batch_size: 2 + +warmup_steps_fraction: 0.0 +cosine_learning_rate_final_fraction: 1.0 +learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps. + +# However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before +# dropping fully down. Or you may choose a shorter schedule, where the unspecified steps will have a learning rate of 0. + +# AdamW optimizer parameters +adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients. +adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients. +adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root. +adam_weight_decay: 0 # AdamW Weight decay +max_grad_norm: 1.0 + +enable_profiler: False +# Skip first n steps for profiling, to omit things like compilation and to give +# the iteration time a chance to stabilize. +skip_first_n_steps_for_profiler: 5 +profiler_steps: 10 + +# Generation parameters +prompt: "A confident Grovyle, a grass-type Pokémon, strikes a dynamic pose with its leafy appendages." +negative_prompt: "purple, red" +do_classifier_free_guidance: True +guidance_scale: 4.0 +# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf +guidance_rescale: 0.0 +num_inference_steps: 50 + +# SDXL Lightning parameters +lightning_from_pt: True +# Empty or "ByteDance/SDXL-Lightning" to enable lightning. +lightning_repo: "" +# Empty or "sdxl_lightning_4step_unet.safetensors" to enable lightning. 
+lightning_ckpt: "" +unet_checkpoint: "" # needed in pyconfig + +enable_mllog: False + +#controlnet +controlnet_model_name_or_path: 'diffusers/controlnet-canny-sdxl-1.0' +controlnet_from_pt: True +controlnet_conditioning_scale: 0.5 +controlnet_image: 'https://upload.wikimedia.org/wikipedia/commons/thumb/c/c1/Google_%22G%22_logo.svg/1024px-Google_%22G%22_logo.svg.png' + +# added from maxtext version +hardware: 'gpu' # Supported hardware types are 'tpu', 'gpu', 'gpu_multiprocess' and 'cpu' +compile_topology: '' +compile_topology_num_slices: -1 # Number of target slices, set to a positive integer. + +custom_mesh: "" # Available options: ['hybrid_ring_64x4'] +# Split physical axes for https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.mesh_utils.create_device_mesh.html +allow_split_physical_axes: False diff --git a/src/maxdiffusion/create_jflux_checkpoints.py b/src/maxdiffusion/create_jflux_checkpoints.py new file mode 100644 index 000000000..6bb60146c --- /dev/null +++ b/src/maxdiffusion/create_jflux_checkpoints.py @@ -0,0 +1,27 @@ +from typing import Sequence + +from absl import app +from maxdiffusion import pyconfig + +import maxdiffusion.checkpointing.jflux_checkpointer as jflux_checkpointer + + +def run(config): + checkpointer = jflux_checkpointer.JfluxCheckpointer(config) + flux, flux_state = checkpointer.load_pretrained_model(config.pretrained_model_name_or_path) + # INTERNAL: Failed to load HSACO: HIP_ERROR_NoBinaryForGpu when jitting + state, _, _ = checkpointer.create_flux_state(flux, None, {checkpointer.flux_state_item_name: flux_state}, True, False) + step = config.checkpoint_step + if step is None or step < 0: + step = 0 + checkpointer.save_checkpoint(step, None, {checkpointer.flux_state_item_name: state}) + checkpointer.checkpoint_manager.wait_until_finished() + + +def main(argv: Sequence[str]) -> None: + pyconfig.initialize(argv) + run(pyconfig.config) + + +if __name__ == "__main__": + app.run(main) diff --git 
a/src/maxdiffusion/generate_jflux.py b/src/maxdiffusion/generate_jflux.py new file mode 100644 index 000000000..dc8e00ba0 --- /dev/null +++ b/src/maxdiffusion/generate_jflux.py @@ -0,0 +1,109 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import time +from typing import Sequence + +import numpy as np + +import jax +import jax.numpy as jnp +from absl import app +from maxdiffusion import (pyconfig, max_logging) +from PIL import Image +from flax.linen import partitioning as nn_partitioning + +import maxdiffusion.checkpointing.jflux_checkpointer as jflux_checkpointer + +from einops import rearrange +import os +import re +from glob import iglob + + +def run(config): + device_type = jflux_checkpointer.get_device_type() + max_logging.log(f"Using {device_type} device") + + output_dir = "output" + seed = jax.random.PRNGKey(seed=102333 if config.seed is None else config.seed) + max_logging.log(f"Generating with seed {config.seed}:\n{config.prompt}") + + checkpointer = jflux_checkpointer.JfluxCheckpointer(config) + pipeline = checkpointer.load_checkpoint() + state, _, _ = checkpointer.create_flux_state(pipeline.flux, pipeline.init_flux_weights, None, True) + state = state.params + + with checkpointer.mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + state = jax.device_put(state, pipeline.data_sharding) + + img = pipeline.create_noise(len(jax.devices()), config.resolution, config.resolution, config.activations_dtype, seed) + (txt, txt_ids, 
vec, img) = pipeline.prepare_inputs([config.prompt for _ in range(len(jax.devices()))], img) + + def do_inference(): + return pipeline( + state, + txt, + txt_ids, + vec, + config.num_inference_steps, + config.resolution, + config.resolution, + config.guidance_scale, + img, + shift=config.model_name != "flux-schnell", + ) + + t0 = time.perf_counter() + x = do_inference() + t1 = time.perf_counter() + print(f"Compile time: {t1 - t0:.1f}s.") + # real run + max_logging.log("real inference") + t0 = time.perf_counter() + x = do_inference() + t1 = time.perf_counter() + + output_name = os.path.join(output_dir, "maxdiff_img_{idx}.jpg") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + idx = 0 + else: + fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"maxdiff_img_[0-9]+\.jpg$", fn)] + if len(fns) > 0: + idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1 + else: + idx = 0 + fn = output_name.format(idx=idx) + max_logging.log(f"Done in {t1 - t0:.1f}s. Saving {fn}") + # bring into PIL format and save + x = x.clip(-1, 1) + x = rearrange(x[0], "c h w -> h w c") + + x = 127.5 * (x + 1.0) + x_numpy = np.array(x.astype(jnp.uint8)) + img = Image.fromarray(x_numpy) + + img.save(fn, quality=95, subsampling=0) + + +def main(argv: Sequence[str]) -> None: + pyconfig.initialize(argv) + run(pyconfig.config) + + +if __name__ == "__main__": + app.run(main) diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index 4ac1739cc..f8dc415b9 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -1,93 +1,112 @@ -# ruff: noqa """ - Copyright 2024 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" -# pylint: disable=bare-except, consider-using-generator """ Common Max Utils needed by multiple modules""" -import functools -from functools import reduce -from contextlib import nullcontext -import json -import yaml -import os -from pathlib import Path -import subprocess - import numpy as np - -import flax import jax import jax.numpy as jnp +from jax.experimental import mesh_utils +from maxdiffusion import checkpointing +from maxdiffusion import common_types +import functools +import time import optax +import os +import socket +import subprocess +from etils import epath +from collections.abc import Sequence +import collections +from typing import Any, Tuple, Union, Callable, Set +from functools import reduce from maxdiffusion import max_logging from maxdiffusion.checkpointing import checkpointing_utils from maxdiffusion.models.attention_flax import AttentionOp -from flax import linen as nn -import flax.linen as nn import flax.linen.module as module_lib from flax.linen.summary import _process_inputs from flax.typing import ( PRNGKey, RNGSequences, ) -from flax.linen import partitioning as nn_partitioning + + +import orbax.checkpoint as ocp +import orbax.checkpoint.experimental.emergency.checkpoint_manager as 
emergency_checkpoint_manager + + +import json +import yaml +import flax from flax.training import train_state -from jax.experimental import mesh_utils -from jax.sharding import PositionalSharding -from flax import struct -from typing import ( - Callable, - Any, - Tuple, - Union, - Set, -) -from flax import core -from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel +from flax import linen as nn +from flax.linen import partitioning as nn_partitioning from tensorboardX import writer - from google.cloud import storage -FrozenDict = core.frozen_dict.FrozenDict +# pylint: disable=too-many-positional-arguments -class InferenceState(struct.PyTreeNode): - # pylint: disable=g-bare-generic - apply_fn: Callable = struct.field(pytree_node=False) - params: FrozenDict[str, Any] | None = struct.field(pytree_node=True) +def find_nans_and_infs(pytree): + def finder(x): + return jnp.any(jnp.isinf(x) | jnp.isnan(x)) + + bad_pytree = jax.tree_util.tree_map(finder, pytree) + return jax.tree_util.tree_flatten(bad_pytree) def l2norm_pytree(x): """L2 norm of a pytree of arrays.""" - return jax.tree_util.tree_reduce(lambda x, y: x + jax.numpy.sum(y**2), x, initializer=0.0) ** 0.5 + return jnp.sqrt(jax.tree_util.tree_reduce(lambda x, y: x + jnp.sum(jnp.square(y)), x, initializer=0.0)) -def activate_profiler(config): - if jax.process_index() == 0 and config.enable_profiler: - jax.profiler.start_trace(config.tensorboard_dir) +def calculate_num_params_from_pytree(params): + params_sizes = jax.tree_util.tree_map(jax.numpy.size, params) + total_parameters = jax.tree_util.tree_reduce(lambda x, y: x + y, params_sizes) + assert total_parameters >= 0 + return total_parameters -def deactivate_profiler(config): - if jax.process_index() == 0 and config.enable_profiler: - jax.profiler.stop_trace() +def calculate_total_params_per_chip(params): + """Calculate total paramsper chip.""" + + def calculate_leaf_params_per_chip(arr): + shard = arr.addressable_shards[0] + return 
np.prod(shard.data.shape) + + params_sizes_per_chip = jax.tree_util.tree_map(calculate_leaf_params_per_chip, params) + total_parameters_per_chip = jax.tree_util.tree_reduce(lambda x, y: x + y, params_sizes_per_chip) + return total_parameters_per_chip + + +def calculate_bytes_from_pytree(params): + params_bytes = jax.tree_util.tree_map(lambda x: x.nbytes, params) + total_bytes = jax.tree_util.tree_reduce(lambda x, y: x + y, params_bytes) + return total_bytes + + +def summarize_size_from_pytree(params): + num_params = calculate_num_params_from_pytree(params) + num_bytes = calculate_bytes_from_pytree(params) + return num_params, num_bytes, num_bytes / num_params def initialize_summary_writer(config): - return writer.SummaryWriter(config.tensorboard_dir) if jax.process_index() == 0 else None + summary_writer_path = os.path.join(config.tensorboard_dir, config.run_name) + return writer.SummaryWriter(summary_writer_path) if jax.process_index() == 0 else None def close_summary_writer(summary_writer): @@ -105,7 +124,7 @@ def _prepare_metrics_for_json(metrics, step, run_name): return metrics_dict -def write_metrics_locally(metrics, step, config, file): +def write_metrics_locally(metrics, step, config, file, is_training=True): """Writes metrics locally for testing""" if step == 0: file.truncate(0) @@ -113,7 +132,7 @@ def write_metrics_locally(metrics, step, config, file): metrics_dict = _prepare_metrics_for_json(metrics, step, config.run_name) file.write(str(json.dumps(metrics_dict)) + "\n") - if step == config.max_train_steps - 1: + if is_training and step == config.steps - 1: file.close() @@ -150,6 +169,26 @@ def add_text_to_summary_writer(key, value, summary_writer): summary_writer.add_text(key, value) +def write_metrics_for_gcs(metrics, step, config, running_metrics, is_training=True): + """Writes metrics to gcs""" + metrics_dict_step = _prepare_metrics_for_json(metrics, step, config.run_name) + running_metrics.append(metrics_dict_step) + if is_training and (step + 1) % 
config.log_period == 0 or step == config.steps - 1: + start_step = (step // config.log_period) * config.log_period + metrics_filename = f"metrics_step_{start_step:06}_to_step_{step:06}.txt" + with open(metrics_filename, "w", encoding="utf8") as metrics_for_gcs: + for metrics_step in running_metrics: + metrics_for_gcs.write(str(json.dumps(metrics_step)) + "\n") + + metrics_for_gcs.close() + gcs_filename = os.path.join(config.metrics_dir, metrics_filename) + max_logging.log(f"Moving file {metrics_filename} to GCS...") + upload_blob(gcs_filename, metrics_filename) + max_logging.log(f"File {metrics_filename} moved successfully!") + running_metrics = [] # reset running_metrics to empty list + return running_metrics + + def write_config_raw_keys_for_gcs(raw_keys): """Writes config raw keys to GCS""" if not raw_keys["save_config_to_gcs"] or jax.process_index() != 0: @@ -175,24 +214,6 @@ def parse_gcs_bucket_and_prefix(destination_gcs_name): return bucket, key -def download_blobs(source_gcs_folder, local_destination): - """Downloads a folder to a local location""" - bucket_name, prefix_name = parse_gcs_bucket_and_prefix(source_gcs_folder) - storage_client = storage.Client() - bucket = storage_client.get_bucket(bucket_name) - blobs = bucket.list_blobs(prefix=prefix_name) - for blob in blobs: - file_split = blob.name.split("/") - directory = os.path.join(local_destination, "/".join(file_split[0:-1])) - Path(directory).mkdir(parents=True, exist_ok=True) - if len(file_split[-1]) <= 0: - continue - download_to_filename = os.path.join(directory, file_split[-1]) - if not os.path.isfile(download_to_filename): - blob.download_to_filename(download_to_filename) - return os.path.join(local_destination, prefix_name) - - def upload_blob(destination_gcs_name, source_file_name): """Uploads a file to a GCS location""" bucket_name, prefix_name = parse_gcs_bucket_and_prefix(destination_gcs_name) @@ -202,26 +223,150 @@ def upload_blob(destination_gcs_name, source_file_name): 
blob.upload_from_filename(source_file_name) -def walk_and_upload_blobs(config, output_dir): - user_dir = os.path.expanduser("~") - uploaded_files = set() - for root, _, files in os.walk(os.path.abspath(output_dir)): - for file in files: - file_to_upload = os.path.join(root, file) - if file_to_upload in uploaded_files: - continue - gcs_file_name = os.path.join( - config.base_output_directory, - file_to_upload.replace(user_dir, "").strip("/").replace("maxdiffusion", "").strip("/"), - ) - max_logging.log(f"Moving file {file_to_upload} to {gcs_file_name}") - upload_blob(gcs_file_name, file_to_upload) - uploaded_files.add(file_to_upload) - max_logging.log(f"File {file_to_upload} moved successfully!") +def maybe_initialize_jax_distributed_system(raw_keys): + """The best recipe to initialize the Jax Distributed System has varied over time. We keep a layer of + indirection in MaxText to avoid breaking the call sites unnecessarily. + Currently jax.distributed.initialize() fully works as expected! -def device_put_replicated(x, sharding): - return jax.make_array_from_callback(x.shape, sharding, lambda index: x[index]) + For CPUs, we call jax.distributed.initialize() explicitly, with the specified arguments. 
+ """ + if raw_keys["compile_topology"]: + # Don't initialize jax distributed with AOT compilation + return + if is_gpu_backend(raw_keys): + max_logging.log("Attempting to initialize the jax distributed system for GPU backend...") + initialize_jax_for_gpu() + max_logging.log("Jax distributed system initialized on GPU!") + elif is_cpu_backend(raw_keys): + max_logging.log("Attempting to initialize the jax distributed system for CPU backend...") + initialize_jax_for_cpu() + max_logging.log("Jax distributed system initialized on CPUs!") + elif ( + raw_keys["enable_checkpointing"] + and raw_keys["async_checkpointing"] + and raw_keys["compile_topology_num_slices"] == -1 + and not raw_keys["enable_single_controller"] + ) or raw_keys["hardware"] == "gpu_multiprocess": + max_logging.log("Attempting to initialize the jax distributed system...") + if not raw_keys["enable_emergency_checkpoint"]: + jax.distributed.initialize() + else: + initialize_jax_for_tpu_with_emergency_checkpointing(raw_keys) + max_logging.log("Jax distributed system initialized!") + + +def initialize_jax_for_gpu(): + """Jax distributed initialize for GPUs.""" + if os.environ.get("JAX_COORDINATOR_IP") is not None: + coordinator_ip = str(os.getenv("JAX_COORDINATOR_IP")) + coordinator_port = str(os.getenv("JAX_COORDINATOR_PORT")) + device_list = {os.getenv("CUDA_VISIBLE_DEVICES")} + if len(device_list) == 0: + device_list = None + jax.distributed.initialize( + coordinator_address=f"{coordinator_ip}:{coordinator_port}", + num_processes=int(os.getenv("NNODES")), + process_id=int(os.getenv("NODE_RANK")), + local_device_ids=device_list, + ) + max_logging.log(f"JAX global devices: {jax.devices()}") + + +def initialize_jax_for_cpu(): + if os.environ.get("JAX_COORDINATOR_IP") is not None: + """Jax distributed initialize for CPUs. 
Includes retries until the coordinator is ready.""" + coordinator_ip_address = str(os.getenv("JAX_COORDINATOR_IP")) + coordinator_address = coordinator_ip_address + ":1234" # JAX coordinator port used in XPK + # Env variables to be set in XPK or otherwise + job_index = int(os.environ.get("NODE_RANK")) + # job_completion_index = int(os.environ.get("JOB_COMPLETION_INDEX")) + # processes_in_job = int(os.environ.get("PROCESSES_IN_JOB")) + pid = job_index # * processes_in_job + job_completion_index + max_logging.log(f" Jax process id is {pid} ") + # Explicit initialize is needed only for CPUs + jax.distributed.initialize( + coordinator_address=coordinator_address, + process_id=pid, + num_processes=int(os.environ.get("NNODES")), + ) + + +def initialize_jax_for_tpu_with_emergency_checkpointing(raw_keys): + """Initialize JAX distributed runtime for TPUs when emergency checkpointing is used. + The information required to initialize JAX distributed runtime will be written by GKE to + the local checkpoint directory. This function retrieves that information and initializes + JAX distributed runtime. + """ + process_id, coordinator_address = _retrieve_jax_init_info(raw_keys) + + if process_id != "" and coordinator_address != "": + max_logging.log( + f"Using {process_id} as the process_id and {coordinator_address} as the" + " coordinator_address to initialize JAX distributed runtime..." + ) + jax.distributed.initialize(coordinator_address=coordinator_address, process_id=int(process_id)) + else: + max_logging.log( + "Initializing JAX distributed runtime without args when emergency checkpointing is" + " enabled. This should not happen and your workload may have unexpected behavior." 
+ ) + jax.distributed.initialize() + + ocp.multihost.initialize_runtime_to_distributed_ids() + + +def _retrieve_jax_init_info(raw_keys): + """Retrieve JAX init info from a local file.""" + JAX_INIT_INFO_FILE = "jax-init-info.txt" + local_jax_init_info_file = epath.Path(raw_keys["local_checkpoint_directory"]) / JAX_INIT_INFO_FILE + # Allow time for the JAX init info file to be populated by GKE. This is needed because the file is + # only populated when the worker with process id of 0 is determined. After a disruption, although some + # workers might be up and running, the init info file won't be populated until the node with process id + # of 0 is known and this could take time. Using 900 seconds for now and it needs to be increased if the + # "repair" time is longer. + for i in range(900): + if local_jax_init_info_file.exists(): + return local_jax_init_info_file.read_text().split("\n")[:2] + max_logging.log(f"Unable to locate {JAX_INIT_INFO_FILE} after {i} seconds, sleeping for 1 second before retrying...") + time.sleep(1) + max_logging.log( + f"Unable to locate {JAX_INIT_INFO_FILE} after 900 seconds," "returning empty process id and coordinator address." 
+ ) + return "", "" + + +def is_cpu_backend(raw_keys): + """Determine whether Maxtext is intended to run on a CPU backend.""" + return raw_keys["hardware"] == "cpu" + + +def is_gpu_backend(raw_keys): + """Determine whether Maxtext is intended to run on a GPU backend.""" + return raw_keys["hardware"] == "gpu" + + +def get_coordinator_ip_address(): + """Get coordinator IP Address with retries""" + coordinator_address = "" + coordinator_ip_address = "" + if os.environ.get("JAX_COORDINATOR_ADDRESS") is not None: + coordinator_address = os.environ.get("JAX_COORDINATOR_ADDRESS") + coordinator_found = False + lookup_attempt = 1 + max_coordinator_lookups = 50 + while not coordinator_found and lookup_attempt <= max_coordinator_lookups: + try: + coordinator_ip_address = socket.gethostbyname(coordinator_address) + coordinator_found = True + except socket.gaierror: + max_logging.log( + f"Failed to recognize coordinator address {coordinator_address} on attempt {lookup_attempt}, retrying..." + ) + lookup_attempt += 1 + time.sleep(5) + max_logging.log(f"Coordinator IP address: {coordinator_ip_address}") + return coordinator_ip_address def fill_unspecified_mesh_axes(parallelism_vals, target_product, parallelism_type): @@ -242,7 +387,6 @@ def fill_unspecified_mesh_axes(parallelism_vals, target_product, parallelism_typ parallelism_vals[parallelism_vals.index(-1)] = int(determined_val) target_type = "slices" if parallelism_type == "DCN" else "devices per slice" - assert ( np.prod(parallelism_vals) == target_product ), f"Number of {target_type} {target_product} does not match\ @@ -251,41 +395,124 @@ def fill_unspecified_mesh_axes(parallelism_vals, target_product, parallelism_typ return parallelism_vals -def create_device_mesh(config, devices=None, logging=True): +def create_custom_64x4_device_mesh( + mesh_shape: Sequence[int], + dcn_mesh_shape: Sequence[int], + devices: Sequence[Any], + process_is_granule: bool = False, + should_sort_granules_by_key: bool = True, +) -> np.ndarray: + 
"""Custom device mesh for 64x4 ici parallelism""" + assert len(devices) % 256 == 0, f"This custom mesh is not valid for {len(devices)} devices" + attr = "process_index" if process_is_granule else "slice_index" + if not hasattr(devices[0], attr): + raise ValueError(f"Device {devices[0]} does not have attribute {attr}. See" " `process_is_granule` option.") + granule_dict = collections.defaultdict(list) + for dev in devices: + granule_dict[getattr(dev, attr)].append(dev) + granules = ( + [granule_dict[key] for key in sorted(granule_dict.keys())] if should_sort_granules_by_key else granule_dict.values() + ) + if np.prod(dcn_mesh_shape) != len(granules): + raise ValueError(f"Number of slices {len(granules)} must equal the product of " f"dcn_mesh_shape {dcn_mesh_shape}") + per_granule_meshes = [ + mesh_utils.create_device_mesh( + [16, 16], + granule, + allow_split_physical_axes=False, + ) + for granule in granules + ] + + def reshape_mesh_to_rings(a): + b = [] + for i in range(8): + b.append([]) + for j in range(8): + a_i = i * 2 + a_j = j * 2 + # forms a ring of size 4 + b[i].append([a[a_i, a_j], a[a_i, a_j + 1], a[a_i + 1, a_j + 1], a[a_i + 1, a_j]]) + b = np.array(b) + b = np.reshape(b, (64, 4)) + return b + + per_granule_meshes = [np.reshape(reshape_mesh_to_rings(x), mesh_shape) for x in per_granule_meshes] + # TODO(jekbradbury): handle non-uniform DCN topologies + granule_mesh = np.arange(len(granules)).reshape(dcn_mesh_shape) + blocks = np.vectorize(lambda i: per_granule_meshes[i], otypes=[object])(granule_mesh) + device_mesh = np.block(blocks.tolist()) + return device_mesh + + +def create_device_mesh(config, devices=None): """Creates a device mesh with each slice in its own data parallel group. 
If there is only one slice, uses two replicas""" if devices is None: devices = jax.devices() num_devices = len(devices) - try: - num_slices = 1 + max([d.slice_index for d in devices]) - except: - num_slices = 1 + num_slices = config.num_slices num_devices_per_slice = num_devices // num_slices - max_logging.log(f"Devices: {devices} (num_devices: {num_devices})") multi_slice_env = num_slices > 1 dcn_parallelism = [ config.dcn_data_parallelism, + config.dcn_pipeline_parallelism, config.dcn_fsdp_parallelism, + config.dcn_fsdp_transpose_parallelism, + config.dcn_sequence_parallelism, config.dcn_tensor_parallelism, + config.dcn_expert_parallelism, + config.dcn_autoregressive_parallelism, ] ici_parallelism = [ config.ici_data_parallelism, + config.ici_pipeline_parallelism, config.ici_fsdp_parallelism, + config.ici_fsdp_transpose_parallelism, + config.ici_sequence_parallelism, config.ici_tensor_parallelism, + config.ici_expert_parallelism, + config.ici_autoregressive_parallelism, ] # Find possible unspecified parallelisms ici_parallelism = fill_unspecified_mesh_axes(ici_parallelism, num_devices_per_slice, "ICI") + + allow_split_physical_axes = config.allow_split_physical_axes if config.allow_split_physical_axes else False + if multi_slice_env: dcn_parallelism = fill_unspecified_mesh_axes(dcn_parallelism, num_slices, "DCN") - mesh = mesh_utils.create_hybrid_device_mesh(ici_parallelism, dcn_parallelism, devices) + if config.custom_mesh == "hybrid_ring_64x4": + # asserting on ici parallelism + assert sorted(set(ici_parallelism)) == [ + 1, + 4, + 64, + ], f"Invalid custom_mesh:{config.custom_mesh} chosen for ICI mesh shape {ici_parallelism}" + mesh = create_custom_64x4_device_mesh(ici_parallelism, dcn_parallelism, devices) + else: + mesh = mesh_utils.create_hybrid_device_mesh( + ici_parallelism, + dcn_parallelism, + devices, + allow_split_physical_axes=allow_split_physical_axes, + ) else: - mesh = mesh_utils.create_device_mesh(ici_parallelism, devices) + if 
allow_split_physical_axes: + mesh = mesh_utils.create_device_mesh( + ici_parallelism, + devices, + contiguous_submeshes=False, + allow_split_physical_axes=allow_split_physical_axes, + ) + else: + mesh = mesh_utils.create_device_mesh( + ici_parallelism, + devices, + ) - if logging: - max_logging.log(f"Decided on mesh: {mesh}") + max_logging.log(f"Num_devices: {num_devices}, shape {mesh.shape}") return mesh @@ -330,13 +557,14 @@ def init_train_state(model, tx, weights_init_fn, params=None, training=True, eva return state -def get_abstract_state(model, tx, config, mesh, weights_init_fn, training=True): +def get_abstract_state(model, tx, config, mesh, weights_init_fn, params, training=True): """Get a shaped abstraction of the state (including optimizer)""" init_state_partial = functools.partial( init_train_state, model=model, tx=tx, weights_init_fn=weights_init_fn, + params=params, training=training, eval_only=True, ) @@ -366,6 +594,7 @@ def setup_initial_state( checkpoint_manager=None, checkpoint_item=None, training=True, + use_jit=True, ): """We initialize the model and optimizer state, and optionally load from a checkpoint as necessary. 
@@ -387,9 +616,11 @@ def setup_initial_state( """ # Initialization state = None - unboxed_abstract_state, _, state_mesh_shardings = get_abstract_state(model, tx, config, mesh, weights_init_fn, training) - with nn_partitioning.axis_rules(config.logical_axis_rules): - if checkpoint_manager and checkpoint_item: + unboxed_abstract_state, _, state_mesh_shardings = get_abstract_state( + model, tx, config, mesh, weights_init_fn, model_params, training + ) + with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + if checkpoint_manager is not None and checkpoint_item is not None: max_logging.log(f"setup_initial_state for {checkpoint_item}") state = checkpointing_utils.load_state_if_possible( checkpoint_manager, @@ -401,6 +632,7 @@ def setup_initial_state( state = state[checkpoint_item] if not state: max_logging.log(f"Could not find the item in orbax, creating state...") + init_train_state_partial = functools.partial( init_train_state, model=model, @@ -411,11 +643,16 @@ def setup_initial_state( eval_only=False, ) - state = jax.jit( - init_train_state_partial, - in_shardings=None, - out_shardings=state_mesh_shardings, - )() + if use_jit: + init_train_state_partial = jax.jit( + init_train_state_partial, + in_shardings=None, + out_shardings=state_mesh_shardings, + ) + state = init_train_state_partial() + else: + state = init_train_state_partial() + state = jax.device_put(state, state_mesh_shardings) state = unbox_logicallypartioned_trainstate(state) @@ -426,41 +663,287 @@ def setup_initial_state( # ----------------------------------------------------------------------------- -def create_learning_rate_schedule(learning_rate, learning_rate_schedule_steps, warmup_steps_fraction, max_train_steps): - """Creates a warmup to constant learning rate schedule: - We take inspiration from WarmupHoldPolicy used in stable diffusion - see https://github.com/NVIDIA/NeMo/blob/dbc8a6ee490355bfa0cb1e10b8d199dcc47482e0/nemo/core/optim/lr_scheduler.py#L142 - Learning rate schedule has 
either two parts: +def create_learning_rate_schedule(config): + """Creates a warmup and cosine decay learning rate schedule: + We take inspiration from Llama2's learning rate (LR) schedule, see https://arxiv.org/pdf/2307.09288.pdf section 2.2 + Learning rate schedule has either two or three parts: 1) Linear warmup from 0 to [learning_rate] over steps 0 to [learning_rate_schedule_steps * warmup_steps_fraction] - 2) Constant learning rate of 0 afterwards. + 2) Cosine from [learning_rate] to [learning_rate * cosine_learning_rate_final_fraction] until learning_rate_schedule_steps + 3) Constant learning rate of 0 from learning_rate_schedule_steps to steps. + The zero learning rate section can be used to more accurately measure the fully trained model's performance. """ - lr = learning_rate - warmup_steps = int(learning_rate_schedule_steps * warmup_steps_fraction) - constant_zero_steps = max_train_steps - warmup_steps + def make_cos_schedule(init_lr, final_lr, len_steps): + def schedule(step): + pct = (step) / len_steps + a = 0.5 * (jnp.cos(jnp.pi * pct) + 1) + lr = init_lr * a + final_lr * (1 - a) + return lr + + return schedule + + lr = config.learning_rate + cos_final_lr = lr * config.cosine_learning_rate_final_fraction + + warmup_steps = int(config.learning_rate_schedule_steps * config.warmup_steps_fraction) + cos_steps = config.learning_rate_schedule_steps - warmup_steps + constant_zero_steps = config.max_train_steps - config.learning_rate_schedule_steps warmup_schedule = optax.linear_schedule(init_value=0.0, end_value=lr, transition_steps=warmup_steps) - constant_schedule = optax.constant_schedule(lr) + cos_schedule = make_cos_schedule(lr, cos_final_lr, cos_steps) + constant_schedule = optax.constant_schedule(0.0) - pieces = [warmup_schedule, constant_schedule] + pieces = [warmup_schedule, cos_schedule] boundaries = [ warmup_steps, - warmup_steps + constant_zero_steps, + warmup_steps + cos_steps, ] + if constant_zero_steps > 0: + pieces.append(constant_schedule) + 
boundaries.append(warmup_steps + cos_steps + constant_zero_steps) + return optax.join_schedules(pieces, boundaries) -def create_optimizer(config, learning_rate_scheduler): - return optax.adamw( - learning_rate=learning_rate_scheduler, - b1=config.adam_b1, - b2=config.adam_b2, - eps=config.adam_eps, - weight_decay=config.adam_weight_decay, +# Cross entropy implementation is taken from original T5X codebase: +# https://github.com/google-research/t5x/blob/ace831eea1e2742b4299cd1a9af7e4f302038351/t5x/losses.py#L25-L101 +@jax.custom_vjp +def cross_entropy_with_logits(logits: jnp.ndarray, targets: jnp.ndarray, z_loss: float) -> Tuple[jnp.ndarray, jnp.ndarray]: + """Computes cross entropy loss with stable custom gradient. + Computes a stabilized-gradient version of: + -jnp.sum(targets * nn.log_softmax(logits), axis=-1) + If z_loss > 0, then an auxiliary loss equal to z_loss*log(z)^2 + will be added to the cross entropy loss (z = softmax normalization constant). + The two uses of z_loss are: + 1. To keep the logits from drifting too far from zero, which can cause + unacceptable roundoff errors in bfloat16. + 2. To encourage the logits to be normalized log-probabilities. + Args: + logits: [batch, length, num_classes] float array. + targets: categorical one-hot targets [batch, length, num_classes] float + array. + z_loss: coefficient for auxiliary z-loss loss term. + Returns: + tuple with the total loss and the z_loss, both + float arrays with shape [batch, length]. + """ + logits_sum = jax.scipy.special.logsumexp(logits, axis=-1, keepdims=True) + log_softmax = logits - logits_sum + loss = -jnp.sum(targets * log_softmax, axis=-1) + # Add auxiliary z-loss term. 
+ log_z = jnp.squeeze(logits_sum, axis=-1) + total_z_loss = z_loss * jax.lax.square(log_z) + loss += total_z_loss + return loss, total_z_loss + + +def _cross_entropy_with_logits_fwd(logits: jnp.ndarray, targets: jnp.ndarray, z_loss: float = 0.0) -> Tuple[ + Tuple[jnp.ndarray, jnp.ndarray], + Tuple[ + jnp.ndarray, + jnp.ndarray, + jnp.ndarray, + jnp.ndarray, + jnp.ndarray, + jnp.ndarray, + jnp.ndarray, + ], +]: + """Forward-mode of `cross_entropy_with_logits`.""" + max_logit = logits.max(axis=-1, keepdims=True) + shifted = logits - max_logit + exp_shifted = jnp.exp(shifted) + sum_exp = jnp.sum(exp_shifted, axis=-1, keepdims=True) + log_softmax = shifted - jnp.log(sum_exp) + loss = -jnp.sum(targets * log_softmax, axis=-1) + # Add auxiliary z-loss term. + log_z = jnp.squeeze(jnp.log(sum_exp) + max_logit, axis=-1) + total_z_loss = z_loss * jax.lax.square(log_z) + loss += total_z_loss + return (loss, total_z_loss), ( + logits, + targets, + z_loss, + exp_shifted, + sum_exp, # pytype: disable=bad-return-type #jax-ndarray + log_softmax, + log_z, ) +def _cross_entropy_with_logits_bwd( + res: Tuple[ + jnp.ndarray, + jnp.ndarray, + jnp.ndarray, + jnp.ndarray, + jnp.ndarray, + jnp.ndarray, + jnp.ndarray, + ], + g: Tuple[jnp.ndarray, jnp.ndarray], +) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + """Backward-mode of `cross_entropy_with_logits`.""" + g = g[0] # Ignore z_loss component as that is only used for logging. + logits, targets, z_loss, exp_shifted, sum_exp, log_softmax, log_z = res + # z-loss term adds the (2 * z_loss * log_z) factor. 
+ deriv = jnp.expand_dims(1 + 2 * z_loss * log_z, -1) * exp_shifted / sum_exp - targets + g_logits = jnp.expand_dims(g, axis=-1) * deriv + g_targets = -jnp.expand_dims(g, axis=-1) * log_softmax + return ( + jnp.asarray(g_logits, logits.dtype), + jnp.asarray(g_targets, targets.dtype), + jnp.array(0.0), + ) # sets z-loss coeff gradient to 0 + + +cross_entropy_with_logits.defvjp(_cross_entropy_with_logits_fwd, _cross_entropy_with_logits_bwd) + + +# def get_abstract_state(model, tx, config, mesh, weights_init_fn, training=True): +# """Get a shaped abstraction of the state (including optimizer)""" +# init_state_partial = functools.partial( +# init_train_state, +# model=model, +# tx=tx, +# weights_init_fn=weights_init_fn, +# training=training, +# eval_only=True, +# ) +# with nn_partitioning.axis_rules(config.logical_axis_rules): +# abstract_state = jax.eval_shape(init_state_partial) +# +# state_logical_annotations = nn.get_partition_spec(abstract_state) +# +# state_mesh_shardings = nn.logical_to_mesh_sharding(state_logical_annotations, mesh, config.logical_axis_rules) +# +# abstract_sharded_state = jax.jit(init_state_partial, in_shardings=None, out_shardings=state_mesh_shardings).eval_shape() +# unboxed_sharded_abstract_state = unbox_logicallypartioned_trainstate(abstract_sharded_state) +# +# # Initialization +# with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): +# state_mesh_annotations = nn.logical_to_mesh(state_logical_annotations) +# return unboxed_sharded_abstract_state, state_mesh_annotations, state_mesh_shardings + + +def get_kv_cache_annotations(model, config, rng, mesh): + """Get a shaped abstraction of the state (including optimizer)""" + + def init_kv_cache(model, config): + input_shape = ( + config.micro_batch_size_to_train_on, + config.max_prefill_predict_length, + ) + + model_vars = model.init( + {"params": rng, "dropout": rng, "aqt": rng}, + jnp.ones(input_shape), + jnp.ones(input_shape), + model_mode=common_types.MODEL_MODE_PREFILL, + ) + 
return model_vars["cache"] + + with nn_partitioning.axis_rules(config.logical_axis_rules): + init_kv_cache_partial = functools.partial(init_kv_cache, model, config) + abstract_state = jax.eval_shape(init_kv_cache_partial) + state_logical_annotations = nn.get_partition_spec(abstract_state) + with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + state_mesh_annotations = nn.logical_to_mesh(state_logical_annotations) + return state_mesh_annotations + + +def print_pytree_shape(print_str, ptree): + print("\n") + print(print_str) + print(jax.tree_util.tree_map(lambda x: x.shape, ptree)) + + +def print_model_vars(print_str, model_vars): + for k in model_vars: + print(f"{print_str} key{k}:") + print(f"\t {model_vars[k]}") + + +def get_project(): + """Get project""" + completed_command = subprocess.run(["gcloud", "config", "get", "project"], check=True, capture_output=True) + project_outputs = completed_command.stdout.decode().strip().split("\n") + if len(project_outputs) < 1 or project_outputs[-1] == "": + max_logging.log("You must specify config.vertex_tensorboard_project or set 'gcloud config set project '") + return None + return project_outputs[-1] + + +def delete_pytree(p): + def delete_leaf(leaf): + if isinstance(leaf, jax.Array): + leaf.delete() + del leaf + + jax.tree_util.tree_map(delete_leaf, p) + + +def summarize_pytree_data(params, name="Params", raw=False): + """Generate basic metrics of a given Pytree.""" + num_params, total_param_size, avg_param_size = summarize_size_from_pytree(params) + if not raw: + num_params_in_billions = num_params / 1e9 + total_param_size_in_gb = total_param_size / 1e9 + print( + f"{name} stats: \n" + f"\tTotal number of params: {num_params_in_billions:.3f} billion \n" + f"\tTotal memory usage: {total_param_size_in_gb:.3f} GB \n" + f"\tAvg size: {avg_param_size:.3f} bytes\n" + ) + else: + print( + f"{name} stats: \n" + f"\tTotal number of params: {num_params:.3f} \n" + f"\tTotal memory usage: {total_param_size:.3f} bytes 
\n" + f"\tAvg size: {avg_param_size:.3f} bytes\n" + ) + return num_params, total_param_size, avg_param_size + + +def save_quantized_checkpoint_if_configured(config, params): + assert config.quantization, "quantization must be configured" + if config.save_quantized_params_path: + checkpointing.save_params_to_path(config.save_quantized_params_path, params) + else: + "Skipping saving quantized checkpoint as save_quantized_params_path is null." + + +def print_mem_stats(label: str): + print(f"\nMemstats: {label}:") + try: + for d in jax.local_devices(): + stats = d.memory_stats() + used = round(stats["bytes_in_use"] / 2**30, 2) + limit = round(stats["bytes_limit"] / 2**30, 2) + print(f"\tUsing (GB) {used} / {limit} ({used/limit:%}) on {d}") + except (RuntimeError, KeyError, TypeError) as ex: + print(f"\tMemstats unavailable, error: {ex}") + + +def print_system_information(): + """Print system information of the current environment. + Note that this will initialize the JAX backend.""" + max_logging.log(f"System Information: Jax Version: {jax.__version__}") + max_logging.log(f"System Information: Jaxlib Version: {jax.lib.__version__}") + max_logging.log(f"System Information: Jax Backend: {jax.lib.xla_bridge.get_backend().platform_version}") + + +def activate_profiler(config): + if jax.process_index() == 0 and config.enable_profiler: + jax.profiler.start_trace(config.tensorboard_dir) + + +def deactivate_profiler(config): + if jax.process_index() == 0 and config.enable_profiler: + jax.profiler.stop_trace() + + def get_precision(config): """Get precision from config.""" precision_str = config.precision @@ -476,6 +959,8 @@ def get_flash_block_sizes(config): """Create custom flash attention BlockSizes.""" flash_block_sizes = None if len(config.flash_block_sizes.keys()) > 0: + from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel + flash_block_sizes = splash_attention_kernel.BlockSizes( block_q=config.flash_block_sizes["block_q"], 
block_kv_compute=config.flash_block_sizes["block_kv_compute"], @@ -489,21 +974,6 @@ def get_flash_block_sizes(config): return flash_block_sizes -def delete_pytree(to_delete): - jax.tree_util.tree_map(lambda x: x.delete(), to_delete) - - -def get_memory_allocations(): - devices = jax.local_devices() - gb = 10**9 - for device in devices: - m_stats = device.memory_stats() - max_logging.log( - f"device : {device.process_index}," - f'bytes in use: {m_stats["bytes_in_use"] / gb} / {m_stats["bytes_limit"] / gb} GB' - ) - - # Taking inspiration from flax's https://flax.readthedocs.io/en/v0.5.3/_modules/flax/linen/summary.html#tabulate # to retrieve layer parameters and calculate def calculate_model_tflops(module: module_lib.Module, rngs: Union[PRNGKey, RNGSequences], train, **kwargs): @@ -556,13 +1026,6 @@ def calculate_model_tflops(module: module_lib.Module, rngs: Union[PRNGKey, RNGSe return total_flops -def calculate_num_params_from_pytree(params): - """Calculates number of parameters from a pytree""" - params_sizes = jax.tree_util.tree_map(jax.numpy.size, params) - total_parameters = jax.tree_util.tree_reduce(lambda x, y: x + y, params_sizes) - return total_parameters - - def get_global_batch_size(per_device_batch_size): return per_device_batch_size * jax.device_count() @@ -572,6 +1035,11 @@ def is_gpu_backend(raw_keys): return raw_keys["hardware"] == "gpu" +def is_gpu_backend(raw_keys): + """Determine whether Maxdiffusion is intended to run on a GPU backend.""" + return raw_keys["hardware"] == "gpu" + + def initialize_jax_for_gpu(): """Jax distribute initialize for GPUs.""" if os.environ.get("JAX_COORDINATOR_IP") is not None: @@ -586,9 +1054,156 @@ def initialize_jax_for_gpu(): def maybe_initialize_jax_distributed_system(raw_keys): + """The best recipe to initialize the Jax Distributed System has varied over time. We keep a layer of + indirection in MaxText to avoid breaking the call sites unnecessarily. 
+ + Currently jax.distributed.initialize() fully works as expected! + + For CPUs, we call jax.distributed.initialize() explicitly, with the specified arguments. + """ + if raw_keys["compile_topology"]: + # Don't initialize jax distributed with AOT compilation + return if is_gpu_backend(raw_keys): max_logging.log("Attempting to initialize the jax distributed system for GPU backend...") initialize_jax_for_gpu() max_logging.log("Jax distributed system initialized on GPU!") + elif is_cpu_backend(raw_keys): + max_logging.log("Attempting to initialize the jax distributed system for CPU backend...") + initialize_jax_for_cpu() + max_logging.log("Jax distributed system initialized on CPUs!") + elif ( + raw_keys["enable_checkpointing"] + and raw_keys["async_checkpointing"] + and raw_keys["compile_topology_num_slices"] == -1 + and not raw_keys["enable_single_controller"] + ) or raw_keys["hardware"] == "gpu_multiprocess": + max_logging.log("Attempting to initialize the jax distributed system...") + if not raw_keys["enable_emergency_checkpoint"]: + jax.distributed.initialize() + else: + initialize_jax_for_tpu_with_emergency_checkpointing(raw_keys) + max_logging.log("Jax distributed system initialized!") + + +def initialize_jax_for_gpu(): + """Jax distributed initialize for GPUs.""" + if os.environ.get("JAX_COORDINATOR_IP") is not None: + coordinator_ip = str(os.getenv("JAX_COORDINATOR_IP")) + coordinator_port = str(os.getenv("JAX_COORDINATOR_PORT")) + device_list = {os.getenv("CUDA_VISIBLE_DEVICES")} + if len(device_list) == 0: + device_list = None + jax.distributed.initialize( + coordinator_address=f"{coordinator_ip}:{coordinator_port}", + num_processes=int(os.getenv("NNODES")), + process_id=int(os.getenv("NODE_RANK")), + local_device_ids=device_list, + ) + max_logging.log(f"JAX global devices: {jax.devices()}") + + +def initialize_jax_for_cpu(): + if os.environ.get("JAX_COORDINATOR_IP") is not None: + """Jax distributed initialize for CPUs. 
Includes retries until the coordinator is ready.""" + coordinator_ip_address = str(os.getenv("JAX_COORDINATOR_IP")) + coordinator_address = coordinator_ip_address + ":1234" # JAX coordinator port used in XPK + # Env variables to be set in XPK or otherwise + job_index = int(os.environ.get("NODE_RANK")) + # job_completion_index = int(os.environ.get("JOB_COMPLETION_INDEX")) + # processes_in_job = int(os.environ.get("PROCESSES_IN_JOB")) + pid = job_index # * processes_in_job + job_completion_index + max_logging.log(f" Jax process id is {pid} ") + # Explicit initialize is needed only for CPUs + jax.distributed.initialize( + coordinator_address=coordinator_address, + process_id=pid, + num_processes=int(os.environ.get("NNODES")), + ) + + +def initialize_jax_for_tpu_with_emergency_checkpointing(raw_keys): + """Initialize JAX distributed runtime for TPUs when emergency checkpointing is used. + The information required to initialize JAX distributed runtime will be written by GKE to + the local checkpoint directory. This function retrieves that information and initializes + JAX distributed runtime. + """ + process_id, coordinator_address = _retrieve_jax_init_info(raw_keys) + + if process_id != "" and coordinator_address != "": + max_logging.log( + f"Using {process_id} as the process_id and {coordinator_address} as the" + " coordinator_address to initialize JAX distributed runtime..." + ) + jax.distributed.initialize(coordinator_address=coordinator_address, process_id=int(process_id)) else: + max_logging.log( + "Initializing JAX distributed runtime without args when emergency checkpointing is" + " enabled. This should not happen and your workload may have unexpected behavior." 
+ ) jax.distributed.initialize() + + ocp.multihost.utils.initialize_runtime_to_distributed_ids() + + +def _retrieve_jax_init_info(raw_keys): + """Retrieve JAX init info from a local file.""" + JAX_INIT_INFO_FILE = "jax-init-info.txt" + local_jax_init_info_file = epath.Path(raw_keys["local_checkpoint_directory"]) / JAX_INIT_INFO_FILE + # Allow time for the JAX init info file to be populated by GKE. This is needed because the file is + # only populated when the worker with process id of 0 is determined. After a disruption, although some + # workers might be up and running, the init info file won't be populated until the node with process id + # of 0 is known and this could take time. Using 900 seconds for now and it needs to be increased if the + # "repair" time is longer. + for i in range(900): + if local_jax_init_info_file.exists(): + return local_jax_init_info_file.read_text().split("\n")[:2] + max_logging.log(f"Unable to locate {JAX_INIT_INFO_FILE} after {i} seconds, sleeping for 1 second before retrying...") + time.sleep(1) + max_logging.log( + f"Unable to locate {JAX_INIT_INFO_FILE} after 900 seconds," "returning empty process id and coordinator address." 
+ ) + return "", "" + + +def is_cpu_backend(raw_keys): + """Determine whether Maxtext is intended to run on a CPU backend.""" + return raw_keys["hardware"] == "cpu" + + +def is_gpu_backend(raw_keys): + """Determine whether Maxtext is intended to run on a GPU backend.""" + return raw_keys["hardware"] == "gpu" + + +def get_coordinator_ip_address(): + """Get coordinator IP Address with retries""" + coordinator_address = "" + coordinator_ip_address = "" + if os.environ.get("JAX_COORDINATOR_ADDRESS") is not None: + coordinator_address = os.environ.get("JAX_COORDINATOR_ADDRESS") + coordinator_found = False + lookup_attempt = 1 + max_coordinator_lookups = 50 + while not coordinator_found and lookup_attempt <= max_coordinator_lookups: + try: + coordinator_ip_address = socket.gethostbyname(coordinator_address) + coordinator_found = True + except socket.gaierror: + max_logging.log( + f"Failed to recognize coordinator address {coordinator_address} on attempt {lookup_attempt}, retrying..." + ) + lookup_attempt += 1 + time.sleep(5) + max_logging.log(f"Coordinator IP address: {coordinator_ip_address}") + return coordinator_ip_address + + +def create_optimizer(config, learning_rate_scheduler): + return optax.adamw( + learning_rate=learning_rate_scheduler, + b1=config.adam_b1, + b2=config.adam_b2, + eps=config.adam_eps, + weight_decay=config.adam_weight_decay, + ) diff --git a/src/maxdiffusion/maxdiffusion_utils.py b/src/maxdiffusion/maxdiffusion_utils.py index 60b1730b2..79e9fba3e 100644 --- a/src/maxdiffusion/maxdiffusion_utils.py +++ b/src/maxdiffusion/maxdiffusion_utils.py @@ -24,6 +24,7 @@ from maxdiffusion import ( max_utils, ) +from einops import repeat from .models.modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax @@ -239,6 +240,65 @@ def calculate_unet_tflops(config, pipeline, batch_size, rngs, train): ) +def get_dummy_flux_inputs(config, pipeline, batch_size): + """Returns randomly initialized flux inputs.""" + scale_factor = 16 + img_input_shape = ( 
+ batch_size, + pipeline.flux.in_channels, + 2 * config.resolution // scale_factor, + 2 * config.resolution // scale_factor, + ) + + latents = jax.random.normal(jax.random.PRNGKey(0), shape=img_input_shape, dtype=config.weights_dtype) + + latents, latents_ids, guidance_vec = pipeline.prepare_img_ids(latents, 4.0) + + timesteps = jnp.ones((batch_size,), dtype=config.weights_dtype) + t5_hidden_states_shape = ( + batch_size, + pipeline.t5.max_length, + 4096, + ) + t5_hidden_states = jnp.zeros(t5_hidden_states_shape, dtype=config.weights_dtype) + t5_ids = jnp.zeros((batch_size, t5_hidden_states.shape[1], 3), dtype=config.weights_dtype) + + clip_hidden_states_shape = ( + batch_size, + 768, + ) + clip_hidden_states = jnp.zeros(clip_hidden_states_shape, dtype=config.weights_dtype) + + return (latents, timesteps, latents_ids, guidance_vec, t5_hidden_states, t5_ids, clip_hidden_states) + + +def calculate_flux_tflops(config, pipeline, batch_size, rngs, train): + """ + Calculates jflux tflops. + batch_size should be per_device_batch_size * jax.local_device_count() or attention's shard_map won't + cache the compilation when flash is enabled. 
+ """ + + (latents, timesteps, latents_ids, guidance_vec, t5_hidden_states, t5_ids, clip_hidden_states) = get_dummy_flux_inputs( + config, pipeline, batch_size + ) + return ( + max_utils.calculate_model_tflops( + pipeline.flux, + rngs, + train, + hidden_states=latents, + img_ids=latents_ids, + encoder_hidden_states=t5_hidden_states, + txt_ids=t5_ids, + pooled_projections=clip_hidden_states, + timestep=timesteps, + guidance=guidance_vec, + ) + / jax.local_device_count() + ) + + def tokenize_captions(examples, caption_column, tokenizer, input_ids_key="input_ids", p_encode=None): """Tokenize captions for sd1.x,sd2.x models.""" captions = list(examples[caption_column]) diff --git a/src/maxdiffusion/models/__init__.py b/src/maxdiffusion/models/__init__.py index ec09a5eb1..5e875076b 100644 --- a/src/maxdiffusion/models/__init__.py +++ b/src/maxdiffusion/models/__init__.py @@ -23,10 +23,10 @@ _import_structure["controlnet_flax"] = ["FlaxControlNetModel"] _import_structure["unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"] _import_structure["vae_flax"] = ["FlaxAutoencoderKL"] +_import_structure["normalization_flax"] = ["FlaxAdaLayerNormZeroSingle"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: - from .controlnet_flax import FlaxControlNetModel from .unet_2d_condition_flax import FlaxUNet2DConditionModel from .vae_flax import FlaxAutoencoderKL diff --git a/src/maxdiffusion/models/ae_flux_nnx.py b/src/maxdiffusion/models/ae_flux_nnx.py new file mode 100644 index 000000000..2a32a2bc1 --- /dev/null +++ b/src/maxdiffusion/models/ae_flux_nnx.py @@ -0,0 +1,583 @@ +from dataclasses import dataclass + +import jax +import jax.numpy as jnp +from chex import Array +from einops import rearrange +from flax import nnx +from jax.typing import DTypeLike + + +@dataclass +class AutoEncoderParams: + resolution: int + in_channels: int + ch: int + out_ch: int + ch_mult: list[int] + num_res_blocks: int + z_channels: int + scale_factor: float + shift_factor: float + rngs: nnx.Rngs + 
param_dtype: DTypeLike = jnp.bfloat16 + + +@nnx.jit +def swish(x: Array) -> Array: + return nnx.swish(x) + + +class AttnBlock(nnx.Module): + + def __init__( + self, + in_channels: int, + rngs: nnx.Rngs, + param_dtype: DTypeLike = jnp.bfloat16, + ) -> None: + self.in_channels = in_channels + + self.norm = nnx.GroupNorm( + num_groups=32, + num_features=in_channels, + epsilon=1e-6, + rngs=rngs, + param_dtype=param_dtype, + dtype=param_dtype, + ) + + self.q = nnx.Conv( + in_features=in_channels, + out_features=in_channels, + kernel_size=(1, 1), + rngs=rngs, + param_dtype=param_dtype, + dtype=param_dtype, + ) + self.k = nnx.Conv( + in_features=in_channels, + out_features=in_channels, + kernel_size=(1, 1), + rngs=rngs, + param_dtype=param_dtype, + dtype=param_dtype, + ) + self.v = nnx.Conv( + in_features=in_channels, + out_features=in_channels, + kernel_size=(1, 1), + rngs=rngs, + param_dtype=param_dtype, + dtype=param_dtype, + ) + self.proj_out = nnx.Conv( + in_features=in_channels, + out_features=in_channels, + kernel_size=(1, 1), + rngs=rngs, + param_dtype=param_dtype, + dtype=param_dtype, + ) + + @nnx.jit + def attention(self, h_: Array) -> Array: + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + b, h, w, c = q.shape + q = rearrange(q, "b h w c-> b (h w) 1 c") + k = rearrange(k, "b h w c-> b (h w) 1 c") + v = rearrange(v, "b h w c-> b (h w) 1 c") + + # Calculate Attention + h_ = nnx.dot_product_attention(q, k, v, dtype=jnp.bfloat16) + + return rearrange(h_, "b (h w) 1 c -> b h w c", h=h, w=w, c=c, b=b) + + @nnx.jit + def __call__(self, x: Array) -> Array: + return x + self.proj_out(self.attention(x)) + + +class ResnetBlock(nnx.Module): + + def __init__( + self, + in_channels: int, + out_channels: int, + rngs: nnx.Rngs, + param_dtype: DTypeLike = jnp.bfloat16, + ) -> None: + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + + self.norm1 = nnx.GroupNorm( + num_groups=32, + 
num_features=in_channels, + epsilon=1e-6, + rngs=rngs, + param_dtype=param_dtype, + dtype=param_dtype, + ) + self.conv1 = nnx.Conv( + in_features=in_channels, + out_features=out_channels, + kernel_size=(3, 3), + strides=(1, 1), + padding=(1, 1), + rngs=rngs, + param_dtype=param_dtype, + dtype=param_dtype, + ) + self.norm2 = nnx.GroupNorm( + num_groups=32, + num_features=out_channels, + epsilon=1e-6, + rngs=rngs, + param_dtype=param_dtype, + dtype=param_dtype, + ) + self.conv2 = nnx.Conv( + in_features=out_channels, + out_features=out_channels, + kernel_size=(3, 3), + strides=(1, 1), + padding=(1, 1), + rngs=rngs, + param_dtype=param_dtype, + dtype=param_dtype, + ) + if self.in_channels != self.out_channels: + self.nin_shortcut = nnx.Conv( + in_features=in_channels, + out_features=out_channels, + kernel_size=(1, 1), + strides=(1, 1), + padding=(0, 0), + rngs=rngs, + param_dtype=param_dtype, + dtype=param_dtype, + ) + + @nnx.jit + def __call__(self, x: Array) -> Array: + h = x + h = self.norm1(h) + h = swish(h) + h = self.conv1(h) + + h = self.norm2(h) + h = swish(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + x = self.nin_shortcut(x) + + return x + h + + +class Downsample(nnx.Module): + + def __init__( + self, + in_channels: int, + rngs: nnx.Rngs, + param_dtype: DTypeLike = jnp.bfloat16, + ): + self.conv = nnx.Conv( + in_features=in_channels, + out_features=in_channels, + kernel_size=(3, 3), + strides=(2, 2), + padding=(0, 0), + rngs=rngs, + param_dtype=param_dtype, + dtype=param_dtype, + ) + + @nnx.jit + def __call__(self, x: Array) -> Array: + # no padding for height and channel, padding for height and width + pad_width = ((0, 0), (0, 1), (0, 1), (0, 0)) + x = jnp.pad(array=x, pad_width=pad_width, mode="constant", constant_values=0) + x = self.conv(x) + return x + + +class Upsample(nnx.Module): + + def __init__( + self, + in_channels: int, + rngs: nnx.Rngs, + param_dtype: DTypeLike = jnp.bfloat16, + ): + self.conv = nnx.Conv( + 
in_features=in_channels, + out_features=in_channels, + kernel_size=(3, 3), + strides=(1, 1), + padding=(1, 1), + rngs=rngs, + param_dtype=param_dtype, + dtype=param_dtype, + ) + + @nnx.jit + def __call__(self, x: Array) -> Array: + # Assuming `x` is a 4D tensor with shape (batch, height, width, channels) + scale_factor = 2.0 + b, h, w, c = x.shape + new_height = int(h * scale_factor) + new_width = int(w * scale_factor) + new_shape = (b, new_height, new_width, c) + + # Resize using nearest-neighbor interpolation + x = jax.image.resize(x, new_shape, method="nearest") + x = self.conv(x) + return x + + +class Encoder(nnx.Module): + + def __init__( + self, + resolution: int, + in_channels: int, + ch: int, + ch_mult: list[int], + num_res_blocks: int, + z_channels: int, + rngs: nnx.Rngs, + param_dtype: DTypeLike = jnp.bfloat16, + ) -> None: + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = nnx.Conv( + in_features=in_channels, + out_features=self.ch, + kernel_size=(3, 3), + strides=(1, 1), + padding=(1, 1), + rngs=rngs, + param_dtype=param_dtype, + dtype=param_dtype, + ) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nnx.Sequential() + block_in = self.ch + for i_level in range(self.num_resolutions): + block = nnx.Sequential() + attn = nnx.Sequential() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks): + block.layers.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + rngs=rngs, + param_dtype=param_dtype, + ) + ) + block_in = block_out + down = nnx.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample( + in_channels=block_in, + rngs=rngs, + param_dtype=param_dtype, + ) + curr_res = curr_res // 2 + 
self.down.layers.append(down) + + # middle + self.mid = nnx.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + rngs=rngs, + param_dtype=param_dtype, + ) + self.mid.attn_1 = AttnBlock( + in_channels=block_in, + rngs=rngs, + param_dtype=param_dtype, + ) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + rngs=rngs, + param_dtype=param_dtype, + ) + + # end + self.norm_out = nnx.GroupNorm( + num_groups=32, + num_features=block_in, + epsilon=1e-6, + rngs=rngs, + param_dtype=param_dtype, + dtype=param_dtype, + ) + self.conv_out = nnx.Conv( + in_features=block_in, + out_features=2 * z_channels, + kernel_size=(3, 3), + strides=(1, 1), + padding=(1, 1), + rngs=rngs, + param_dtype=param_dtype, + dtype=param_dtype, + ) + + # @nnx.jit + def __call__(self, x: Array) -> Array: + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down.layers[i_level].block.layers[i_block](hs[-1]) + if len(self.down.layers[i_level].attn.layers) > 0: + h = self.down.layers[i_level].attn.layers[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down.layers[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + # end + h = self.norm_out(h) + h = swish(h) + h = self.conv_out(h) + return h + + +class Decoder(nnx.Module): + + def __init__( + self, + ch: int, + out_ch: int, + ch_mult: list[int], + num_res_blocks: int, + in_channels: int, + resolution: int, + z_channels: int, + rngs: nnx.Rngs, + param_dtype: DTypeLike = jnp.bfloat16, + ): + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.ffactor = 2 ** (self.num_resolutions - 1) + + # compute in_ch_mult, block_in and curr_res at lowest res + block_in = ch * 
ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + + # z to block_in + self.conv_in = nnx.Conv( + in_features=z_channels, + out_features=block_in, + kernel_size=(3, 3), + strides=(1, 1), + padding=(1, 1), + rngs=rngs, + param_dtype=param_dtype, + dtype=param_dtype, + ) + + # middle + self.mid = nnx.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + rngs=rngs, + param_dtype=param_dtype, + ) + self.mid.attn_1 = AttnBlock( + in_channels=block_in, + rngs=rngs, + param_dtype=param_dtype, + ) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + rngs=rngs, + param_dtype=param_dtype, + ) + + # upsampling + self.up = nnx.Sequential() + for i_level in reversed(range(self.num_resolutions)): + block = nnx.Sequential() + attn = nnx.Sequential() + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks + 1): + block.layers.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + rngs=rngs, + param_dtype=param_dtype, + ) + ) + block_in = block_out + up = nnx.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample( + in_channels=block_in, + rngs=rngs, + param_dtype=param_dtype, + ) + curr_res = curr_res * 2 + self.up.layers.insert(0, up) + + # end + self.norm_out = nnx.GroupNorm( + num_groups=32, + num_features=block_in, + epsilon=1e-6, + rngs=rngs, + param_dtype=param_dtype, + dtype=param_dtype, + ) + self.conv_out = nnx.Conv( + in_features=block_in, + out_features=out_ch, + kernel_size=(3, 3), + strides=(1, 1), + padding=(1, 1), + rngs=rngs, + param_dtype=param_dtype, + dtype=param_dtype, + ) + + @nnx.jit + def __call__(self, z: Array) -> Array: + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): 
+ for i_block in range(self.num_res_blocks + 1): + h = self.up.layers[i_level].block.layers[i_block](h) + if len(self.up.layers[i_level].attn.layers) > 0: + h = self.up.layers[i_level].attn.layers[i_block](h) + if i_level != 0: + h = self.up.layers[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = swish(h) + h = self.conv_out(h) + return h + + +class DiagonalGaussian(nnx.Module): + + def __init__( + self, + sample: bool = True, + chunk_dim: int = -1, + key: Array = jax.random.PRNGKey(42), + ): + self.sample = sample + self.chunk_dim = chunk_dim + self.key = key + + # @nnx.jit + def __call__(self, z: Array) -> Array: + mean, logvar = jnp.split(z, 2, axis=self.chunk_dim) + if self.sample: + std = jnp.exp(0.5 * logvar) + return mean + std * jax.random.normal(key=self.key, shape=mean.shape, dtype=z.dtype) + else: + return mean + + +class AutoEncoder(nnx.Module): + + def __init__( + self, + params: AutoEncoderParams, + ): + self.encoder = Encoder( + resolution=params.resolution, + in_channels=params.in_channels, + ch=params.ch, + ch_mult=params.ch_mult, + num_res_blocks=params.num_res_blocks, + z_channels=params.z_channels, + rngs=params.rngs, + param_dtype=params.param_dtype, + ) + self.decoder = Decoder( + resolution=params.resolution, + in_channels=params.in_channels, + ch=params.ch, + out_ch=params.out_ch, + ch_mult=params.ch_mult, + num_res_blocks=params.num_res_blocks, + z_channels=params.z_channels, + rngs=params.rngs, + param_dtype=params.param_dtype, + ) + self.reg = DiagonalGaussian() + + self.scale_factor = params.scale_factor + self.shift_factor = params.shift_factor + + def encode(self, x: Array) -> Array: + # rearrange for jax + x = rearrange(x, "b c h w -> b h w c") + + z = self.reg(self.encoder(x)) + z = self.scale_factor * (z - self.shift_factor) + + # rearrange for jax + z = rearrange(z, "b h w c -> b c h w") + return z + + def decode(self, z: Array) -> Array: + # rearrange for jax + z = rearrange(z, "b c h w -> b h w c") + + z = z / 
self.scale_factor + self.shift_factor + z = self.decoder(z) + + # rearrange for jax + z = rearrange(z, "b h w c -> b c h w") + return z + + @nnx.jit + def __call__(self, x: Array) -> Array: + # x -> (b, c, h, w) + return self.decode(self.encode(x)) diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index 81f774d88..ad76e5408 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -21,6 +21,8 @@ from jax.experimental import shard_map from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel +from einops import rearrange +import enum from .. import common_types, max_logging @@ -30,11 +32,21 @@ BlockSizes = common_types.BlockSizes +class AttentionType(enum.Enum): + GLOBAL = "global" + LOCAL_SLIDING = "local_sliding" + + AxisNames = common_types.AxisNames BATCH = common_types.BATCH +KV_BATCH = common_types.KV_BATCH LENGTH = common_types.LENGTH HEAD = common_types.HEAD +KV_HEAD = common_types.KV_HEAD D_KV = common_types.D_KV +KV_HEAD_DIM = common_types.KV_HEAD_DIM +EMBED = common_types.EMBED +DEFAULT_MASK_VALUE = common_types.DEFAULT_MASK_VALUE class AttentionOp(nn.Module): @@ -47,9 +59,11 @@ class AttentionOp(nn.Module): split_head_dim: bool = False float32_qk_product: bool = True flash_axis_names: AxisNames = (BATCH, HEAD, LENGTH, D_KV) + cudnn_flash_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV) flash_min_seq_length: int = 4096 flash_block_sizes: BlockSizes = None dtype: DType = jnp.float32 + attention_lib = None def check_attention_inputs(self, query: Array, key: Array, value: Array) -> None: """Check attention inputs.""" @@ -324,6 +338,180 @@ def chunk_scanner(chunk_idx, _): return jnp.concatenate(res, axis=-3) # fuse the chunked result back +class FlaxFluxAttention(nn.Module): + query_dim: int + heads: int = 8 + dim_head: int = 64 + dropout: float = 0.0 + 
use_memory_efficient_attention: bool = False + split_head_dim: bool = False + attention_kernel: str = "dot_product" + flash_min_seq_length: int = 4096 + flash_block_sizes: BlockSizes = None + mesh: jax.sharding.Mesh = None + dtype: jnp.dtype = jnp.float32 + weights_dtype: jnp.dtype = jnp.float32 + query_axis_names: AxisNames = (KV_BATCH, KV_HEAD, LENGTH, KV_HEAD_DIM) + key_axis_names: AxisNames = (KV_BATCH, KV_HEAD, LENGTH, KV_HEAD_DIM) + value_axis_names: AxisNames = (KV_BATCH, KV_HEAD, LENGTH, KV_HEAD_DIM) + out_axis_names: AxisNames = (BATCH, LENGTH, EMBED) + precision: jax.lax.Precision = None + qkv_bias: bool = False + + def setup(self): + if self.attention_kernel in {"flash", "cudnn_flash_te"} and self.mesh is None: + raise ValueError(f"The flash attention kernel requires a value for mesh, but mesh is {self.mesh}") + inner_dim = self.dim_head * self.heads + scale = self.dim_head**-0.5 + + self.attention_op = AttentionOp( + mesh=self.mesh, + attention_kernel=self.attention_kernel, + scale=scale, + heads=self.heads, + dim_head=self.dim_head, + flash_min_seq_length=self.flash_min_seq_length, + use_memory_efficient_attention=self.use_memory_efficient_attention, + split_head_dim=self.split_head_dim, + flash_block_sizes=self.flash_block_sizes, + dtype=self.dtype, + float32_qk_product=False, + ) + + kernel_axes = ("embed", "heads") + qkv_init_kernel = nn.with_logical_partitioning(nn.initializers.lecun_normal(), kernel_axes) + + self.qkv = nn.Dense( + inner_dim * 3, + kernel_init=qkv_init_kernel, + use_bias=self.qkv_bias, + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("heads",)), + dtype=self.dtype, + param_dtype=self.weights_dtype, + name="i_qkv", + precision=self.precision, + ) + + self.encoder_qkv = nn.Dense( + inner_dim * 3, + kernel_init=qkv_init_kernel, + use_bias=self.qkv_bias, + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("heads",)), + dtype=self.dtype, + param_dtype=self.weights_dtype, + name="e_qkv", + 
precision=self.precision, + ) + + # kernel_axes_out=("heads", "embed") + self.proj_attn = nn.Dense( + self.query_dim, + kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), kernel_axes), + use_bias=True, + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("heads",)), + dtype=self.dtype, + param_dtype=self.weights_dtype, + name="i_proj", + precision=self.precision, + ) + + self.encoder_proj_attn = nn.Dense( + self.query_dim, + kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), kernel_axes), + use_bias=True, + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("heads",)), + dtype=self.dtype, + param_dtype=self.weights_dtype, + name="e_proj", + precision=self.precision, + ) + + self.query_norm = nn.RMSNorm( + dtype=self.dtype, + scale_init=nn.with_logical_partitioning(nn.initializers.ones, ("heads",)), + param_dtype=self.weights_dtype, + ) + self.key_norm = nn.RMSNorm( + dtype=self.dtype, + scale_init=nn.with_logical_partitioning(nn.initializers.ones, ("heads",)), + param_dtype=self.weights_dtype, + ) + + self.encoder_query_norm = nn.RMSNorm( + dtype=self.dtype, + scale_init=nn.with_logical_partitioning(nn.initializers.ones, ("heads",)), + param_dtype=self.weights_dtype, + ) + self.encoder_key_norm = nn.RMSNorm( + dtype=self.dtype, + scale_init=nn.with_logical_partitioning(nn.initializers.ones, ("heads",)), + param_dtype=self.weights_dtype, + ) + + def apply_rope(self, xq: Array, xk: Array, freqs_cis: Array) -> tuple[Array, Array]: + xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2) + xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2) + + xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] + xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] + + return xq_out.reshape(*xq.shape).astype(xq.dtype), xk_out.reshape(*xk.shape).astype(xk.dtype) + + def __call__(self, hidden_states, encoder_hidden_states=None, attention_mask=None, image_rotary_emb=None): + + qkv_proj = 
self.qkv(hidden_states) + B, L = hidden_states.shape[:2] + H, D, K = self.heads, qkv_proj.shape[-1] // (self.heads * 3), 3 + qkv_proj = qkv_proj.reshape(B, L, K, H, D).transpose(2, 0, 3, 1, 4) + query_proj, key_proj, value_proj = qkv_proj + + query_proj = self.query_norm(query_proj) + + key_proj = self.key_norm(key_proj) + + if encoder_hidden_states is not None: + + encoder_qkv_proj = self.encoder_qkv(encoder_hidden_states) + B, L = encoder_hidden_states.shape[:2] + H, D, K = self.heads, encoder_qkv_proj.shape[-1] // (self.heads * 3), 3 + encoder_qkv_proj = encoder_qkv_proj.reshape(B, L, K, H, D).transpose(2, 0, 3, 1, 4) + encoder_query_proj, encoder_key_proj, encoder_value_proj = encoder_qkv_proj + + encoder_query_proj = self.encoder_query_norm(encoder_query_proj) + + encoder_key_proj = self.encoder_key_norm(encoder_key_proj) + + query_proj = jnp.concatenate((encoder_query_proj, query_proj), axis=2) + key_proj = jnp.concatenate((encoder_key_proj, key_proj), axis=2) + value_proj = jnp.concatenate((encoder_value_proj, value_proj), axis=2) + + query_proj = nn.with_logical_constraint(query_proj, self.query_axis_names) + key_proj = nn.with_logical_constraint(key_proj, self.key_axis_names) + value_proj = nn.with_logical_constraint(value_proj, self.value_axis_names) + + image_rotary_emb = rearrange(image_rotary_emb, "n d (i j) -> n d i j", i=2, j=2) + query_proj, key_proj = self.apply_rope(query_proj, key_proj, image_rotary_emb) + + query_proj = query_proj.transpose(0, 2, 1, 3).reshape(query_proj.shape[0], query_proj.shape[2], -1) + key_proj = key_proj.transpose(0, 2, 1, 3).reshape(key_proj.shape[0], key_proj.shape[2], -1) + value_proj = value_proj.transpose(0, 2, 1, 3).reshape(value_proj.shape[0], value_proj.shape[2], -1) + + attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj) + context_attn_output = None + + if encoder_hidden_states is not None: + context_attn_output, attn_output = ( + attn_output[:, : encoder_hidden_states.shape[1]], + 
attn_output[:, encoder_hidden_states.shape[1] :], + ) + + attn_output = self.proj_attn(attn_output) + + context_attn_output = self.encoder_proj_attn(context_attn_output) + + return attn_output, context_attn_output + + class FlaxAttention(nn.Module): r""" A Flax multi-head attention module as described in: https://arxiv.org/abs/1706.03762 diff --git a/src/maxdiffusion/models/embeddings_flax.py b/src/maxdiffusion/models/embeddings_flax.py index af2e8dfb2..5f0795071 100644 --- a/src/maxdiffusion/models/embeddings_flax.py +++ b/src/maxdiffusion/models/embeddings_flax.py @@ -12,10 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. import math - +from typing import List, Union +import jax import flax.linen as nn import jax.numpy as jnp +from flax import nnx +from transformers import FlaxCLIPTextModel, FlaxT5EncoderModel, AutoTokenizer + def get_sinusoidal_embeddings( timesteps: jnp.ndarray, @@ -72,9 +76,23 @@ class FlaxTimestepEmbedding(nn.Module): @nn.compact def __call__(self, temb): - temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, param_dtype=self.weights_dtype, name="linear_1")(temb) + temb = nn.Dense( + self.time_embed_dim, + dtype=self.dtype, + kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")), + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("mlp",)), + param_dtype=self.weights_dtype, + name="in_layer", + )(temb) temb = nn.silu(temb) - temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, param_dtype=self.weights_dtype, name="linear_2")(temb) + temb = nn.Dense( + self.time_embed_dim, + dtype=self.dtype, + kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")), + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("mlp",)), + param_dtype=self.weights_dtype, + name="out_layer", + )(temb) return temb @@ -96,3 +114,176 @@ def __call__(self, timesteps): return get_sinusoidal_embeddings( timesteps, 
embedding_dim=self.dim, flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.freq_shift ) + + +def get_1d_rotary_pos_embed( + dim: int, pos: Union[jnp.array, int], theta: float = 10000.0, linear_factor=1.0, ntk_factor=1.0, freqs_dtype=jnp.float32 +): + """ + Precompute the frequency tensor for complex exponentials (cis) with given dimensions. + """ + assert dim % 2 == 0 + + if isinstance(pos, int): + pos = jnp.arange(pos) + + theta = theta * ntk_factor + freqs = 1.0 / (theta ** (jnp.arange(0, dim, 2, dtype=freqs_dtype)[: (dim // 2)] / dim)) / linear_factor + freqs = jnp.outer(pos, freqs) + freqs_cos = jnp.cos(freqs) + freqs_sin = jnp.sin(freqs) + out = jnp.stack([freqs_cos, -freqs_sin, freqs_sin, freqs_cos], axis=-1) + + return out + + +class PixArtAlphaTextProjection(nn.Module): + """ + Projects caption embeddings. Also handles dropout for classifier-free guidance. + + Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py + """ + + hidden_size: int + out_features: int = None + act_fn: str = "gelu_tanh" + dtype: jnp.dtype = jnp.float32 + weights_dtype: jnp.dtype = jnp.float32 + precision: jax.lax.Precision = None + + @nn.compact + def __call__(self, caption): + hidden_states = nn.Dense( + self.hidden_size, + kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")), + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("mlp",)), + use_bias=True, + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, + name="in_layer", + )(caption) + if self.act_fn == "gelu_tanh": + act_1 = nn.gelu + elif self.act_fn == "silu": + act_1 = nn.swish + else: + raise ValueError(f"Unknown activation function: {self.act_fn}") + hidden_states = act_1(hidden_states) + + hidden_states = nn.Dense( + self.hidden_size, + kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")), + 
bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("mlp",)), + use_bias=True, + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, + name="out_layer", + )(hidden_states) + return hidden_states + + +class FluxPosEmbed(nn.Module): + theta: int + axes_dim: List[int] + dtype: jnp.dtype = jnp.float32 + + @nn.compact + def __call__(self, ids): + n_axes = ids.shape[-1] + out_freqs = [] + pos = ids.astype(self.dtype) + freqs_dtype = self.dtype + for i in range(n_axes): + out = get_1d_rotary_pos_embed(self.axes_dim[i], pos[..., i], freqs_dtype=freqs_dtype) + out_freqs.append(out) + + out_freqs = jnp.concatenate(out_freqs, axis=1) + return out_freqs + + +class CombinedTimestepTextProjEmbeddings(nn.Module): + embedding_dim: int + pooled_projection_dim: int + dtype: jnp.dtype = jnp.float32 + weights_dtype: jnp.dtype = jnp.float32 + precision: jax.lax.Precision = None + + @nn.compact + def __call__(self, timestep, pooled_projection): + timesteps_proj = timestep + timestep_emb = FlaxTimestepEmbedding( + time_embed_dim=self.embedding_dim, dtype=self.dtype, weights_dtype=self.weights_dtype + )(timesteps_proj) + pooled_projections = PixArtAlphaTextProjection( + self.embedding_dim, + act_fn="silu", + dtype=self.dtype, + weights_dtype=self.weights_dtype, + )(pooled_projection) + + conditioning = timestep_emb + pooled_projections + return conditioning + + +class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module): + embedding_dim: int + pooled_projection_dim: int + dtype: jnp.dtype = jnp.float32 + weights_dtype: jnp.dtype = jnp.float32 + precision: jax.lax.Precision = None + + @nn.compact + def __call__(self, timestep, guidance, pooled_projection): + timesteps_proj = timestep + timestep_emb = FlaxTimestepEmbedding( + time_embed_dim=self.embedding_dim, dtype=self.dtype, weights_dtype=self.weights_dtype + )(timesteps_proj.astype(pooled_projection.dtype)) + + guidance_proj = guidance + guidance_emb = FlaxTimestepEmbedding( + 
time_embed_dim=self.embedding_dim, dtype=self.dtype, weights_dtype=self.weights_dtype + )(guidance_proj.astype(pooled_projection.dtype)) + + time_guidance_emb = timestep_emb + guidance_emb + + pooled_projections = PixArtAlphaTextProjection( + self.embedding_dim, act_fn="silu", dtype=self.dtype, weights_dtype=self.weights_dtype, precision=self.precision + )(pooled_projection) + conditioning = time_guidance_emb + pooled_projections + + return conditioning + + +class HFEmbedder(nnx.Module): + + def __init__(self, version: str, max_length: int, **hf_kwargs): + super().__init__() + self.is_clip = version.split("/")[1].startswith("clip") + self.max_length = max_length + self.output_key = "pooler_output" if self.is_clip else "last_hidden_state" + + if self.is_clip: + self.tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(version, max_length=max_length, use_fast=True) + self.hf_module: FlaxCLIPTextModel = FlaxCLIPTextModel.from_pretrained(version, **hf_kwargs) + else: + self.tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(version, max_length=max_length, use_fast=True) + self.hf_module: FlaxT5EncoderModel = FlaxT5EncoderModel.from_pretrained(version, **hf_kwargs) + + def __call__(self, text: list[str]): + batch_encoding = self.tokenizer( + text, + truncation=True, + max_length=self.max_length, + return_length=False, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="np", + ) + outputs = self.hf_module( + input_ids=batch_encoding["input_ids"], + attention_mask=None, + output_hidden_states=False, + ) + return outputs[self.output_key] diff --git a/src/maxdiffusion/models/flux_utils.py b/src/maxdiffusion/models/flux_utils.py new file mode 100644 index 000000000..aafaa6288 --- /dev/null +++ b/src/maxdiffusion/models/flux_utils.py @@ -0,0 +1,408 @@ +import os +from dataclasses import dataclass + +import jax +import torch # need for torch 2 jax +from chex import Array +from flax import nnx +from huggingface_hub import hf_hub_download +from 
from safetensors import safe_open
from einops import rearrange
from maxdiffusion.models.ae_flux_nnx import AutoEncoder, AutoEncoderParams
from maxdiffusion.models.transformers.transformer_flux_flax import FluxParams

##############################################################################################
# AUTOENCODER MODEL PORTING
##############################################################################################
# All `port_*` helpers copy weights out of a flat torch state-dict (`tensors`,
# keyed by dotted names under `prefix`) into the corresponding nnx modules.
# They mutate their first argument in place and also return it for chaining.


def port_group_norm(group_norm, tensors, prefix):
  """Copy GroupNorm scale/bias from `{prefix}.weight` / `{prefix}.bias`."""
  group_norm.scale.value = tensors[f"{prefix}.weight"]
  group_norm.bias.value = tensors[f"{prefix}.bias"]

  return group_norm


def port_conv(conv, tensors, prefix):
  """Copy a conv layer, transposing the torch kernel layout to flax's (k1, k2, in, out)."""
  conv.kernel.value = rearrange(tensors[f"{prefix}.weight"], "i o k1 k2 -> k1 k2 o i")
  conv.bias.value = tensors[f"{prefix}.bias"]

  return conv


def port_attn_block(attn_block, tensors, prefix):
  """Port one VAE attention block: norm, q/k/v convs and output projection."""
  attn_block.norm = port_group_norm(
      group_norm=attn_block.norm,
      tensors=tensors,
      prefix=f"{prefix}.norm",
  )

  attn_block.k = port_conv(
      conv=attn_block.k,
      tensors=tensors,
      prefix=f"{prefix}.k",
  )

  attn_block.q = port_conv(
      conv=attn_block.q,
      tensors=tensors,
      prefix=f"{prefix}.q",
  )

  attn_block.v = port_conv(
      conv=attn_block.v,
      tensors=tensors,
      prefix=f"{prefix}.v",
  )

  attn_block.proj_out = port_conv(
      conv=attn_block.proj_out,
      tensors=tensors,
      prefix=f"{prefix}.proj_out",
  )

  return attn_block


def port_resent_block(resnet_block, tensors, prefix):
  """Port one VAE resnet block (norms, convs, and the 1x1 shortcut when shapes differ)."""
  resnet_block.norm1 = port_group_norm(
      group_norm=resnet_block.norm1,
      tensors=tensors,
      prefix=f"{prefix}.norm1",
  )
  resnet_block.norm2 = port_group_norm(
      group_norm=resnet_block.norm2,
      tensors=tensors,
      prefix=f"{prefix}.norm2",
  )

  resnet_block.conv1 = port_conv(
      conv=resnet_block.conv1,
      tensors=tensors,
      prefix=f"{prefix}.conv1",
  )
  resnet_block.conv2 = port_conv(
      conv=resnet_block.conv2,
      tensors=tensors,
      prefix=f"{prefix}.conv2",
  )

  # The shortcut conv only exists when the block changes channel count.
  if resnet_block.in_channels != resnet_block.out_channels:
    resnet_block.nin_shortcut = port_conv(
        conv=resnet_block.nin_shortcut,
        tensors=tensors,
        prefix=f"{prefix}.nin_shortcut",
    )

  return resnet_block


def port_downsample(downsample, tensors, prefix):
  """Port a downsampling conv."""
  downsample.conv = port_conv(
      conv=downsample.conv,
      tensors=tensors,
      prefix=f"{prefix}.conv",
  )

  return downsample


def port_upsample(upsample, tensors, prefix):
  """Port an upsampling conv."""
  upsample.conv = port_conv(
      conv=upsample.conv,
      tensors=tensors,
      prefix=f"{prefix}.conv",
  )

  return upsample


def port_encoder(encoder, tensors, prefix):
  """Port the whole VAE encoder: conv_in, down stages, mid, norm_out, conv_out."""
  encoder.conv_in = port_conv(
      conv=encoder.conv_in,
      tensors=tensors,
      prefix=f"{prefix}.conv_in",
  )

  for i, down_layer in enumerate(encoder.down.layers):
    # Porting mutates in place, so rebinding the loop variable is sufficient.
    for j, block_layer in enumerate(down_layer.block.layers):
      block_layer = port_resent_block(
          resnet_block=block_layer,
          tensors=tensors,
          prefix=f"{prefix}.down.{i}.block.{j}",
      )
    for j, attn_layer in enumerate(down_layer.attn.layers):
      attn_layer = port_attn_block(
          attn_block=attn_layer,
          tensors=tensors,
          # BUGFIX: was f"{prefix}.attn.{i}.block.{j}", which scrambles the key
          # order; encoder attention weights live under down.{i}.attn.{j},
          # mirroring the decoder's up.{i}.attn.{j} layout below.
          prefix=f"{prefix}.down.{i}.attn.{j}",
      )

    # Every stage except the last ends in a downsample.
    if i != encoder.num_resolutions - 1:
      downsample = down_layer.downsample
      downsample = port_downsample(
          downsample=downsample,
          tensors=tensors,
          prefix=f"{prefix}.down.{i}.downsample",
      )

  encoder.mid.block_1 = port_resent_block(
      resnet_block=encoder.mid.block_1,
      tensors=tensors,
      prefix=f"{prefix}.mid.block_1",
  )
  encoder.mid.attn_1 = port_attn_block(
      attn_block=encoder.mid.attn_1,
      tensors=tensors,
      prefix=f"{prefix}.mid.attn_1",
  )
  encoder.mid.block_2 = port_resent_block(
      resnet_block=encoder.mid.block_2,
      tensors=tensors,
      prefix=f"{prefix}.mid.block_2",
  )

  encoder.norm_out = port_group_norm(
      group_norm=encoder.norm_out,
      tensors=tensors,
      prefix=f"{prefix}.norm_out",
  )

  encoder.conv_out = port_conv(
      conv=encoder.conv_out,
      tensors=tensors,
      prefix=f"{prefix}.conv_out",
  )

  return encoder


def port_decoder(decoder, tensors, prefix):
  """Port the whole VAE decoder: conv_in, mid, up stages, norm_out, conv_out."""
  decoder.conv_in = port_conv(
      conv=decoder.conv_in,
      tensors=tensors,
      prefix=f"{prefix}.conv_in",
  )

  decoder.mid.block_1 = port_resent_block(
      resnet_block=decoder.mid.block_1,
      tensors=tensors,
      prefix=f"{prefix}.mid.block_1",
  )
  decoder.mid.attn_1 = port_attn_block(
      attn_block=decoder.mid.attn_1,
      tensors=tensors,
      prefix=f"{prefix}.mid.attn_1",
  )
  decoder.mid.block_2 = port_resent_block(
      resnet_block=decoder.mid.block_2,
      tensors=tensors,
      prefix=f"{prefix}.mid.block_2",
  )

  for i, up_layer in enumerate(decoder.up.layers):
    for j, block_layer in enumerate(up_layer.block.layers):
      block_layer = port_resent_block(
          resnet_block=block_layer,
          tensors=tensors,
          prefix=f"{prefix}.up.{i}.block.{j}",
      )

    for j, attn_layer in enumerate(up_layer.attn.layers):
      attn_layer = port_attn_block(
          attn_block=attn_layer,
          tensors=tensors,
          prefix=f"{prefix}.up.{i}.attn.{j}",
      )

    # Every stage except the first has an upsample.
    if i != 0:
      up_layer.upsample = port_upsample(
          upsample=up_layer.upsample,
          tensors=tensors,
          prefix=f"{prefix}.up.{i}.upsample",
      )

  decoder.norm_out = port_group_norm(
      group_norm=decoder.norm_out,
      tensors=tensors,
      prefix=f"{prefix}.norm_out",
  )

  decoder.conv_out = port_conv(
      conv=decoder.conv_out,
      tensors=tensors,
      prefix=f"{prefix}.conv_out",
  )

  return decoder


def port_autoencoder(autoencoder, tensors):
  """Port encoder + decoder of the flux autoencoder from a torch state dict."""
  autoencoder.encoder = port_encoder(
      encoder=autoencoder.encoder,
      tensors=tensors,
      prefix="encoder",
  )
  autoencoder.decoder = port_decoder(
      decoder=autoencoder.decoder,
      tensors=tensors,
      prefix="decoder",
  )
  return autoencoder
def torch2jax(torch_tensor: torch.Tensor) -> Array:
  """Convert a torch tensor to a jax array in bfloat16.

  bfloat16 tensors are first upcast to fp32 because numpy has no bfloat16;
  the final jax array is always bfloat16 regardless of the input dtype.
  """
  is_bfloat16 = torch_tensor.dtype == torch.bfloat16
  if is_bfloat16:
    # upcast the tensor to fp32 for the numpy round-trip
    torch_tensor = torch_tensor.to(dtype=torch.float32)

  # BUGFIX: the original tested `torch.device.type` (a descriptor on the
  # torch.device class, never equal to "cpu") instead of the tensor's own
  # device. Behavior is preserved: non-CPU tensors are moved to host.
  if torch_tensor.device.type != "cpu":
    torch_tensor = torch_tensor.to("cpu")

  numpy_value = torch_tensor.numpy()
  jax_array = jnp.array(numpy_value, dtype=jnp.bfloat16)
  return jax_array


@dataclass
class ModelSpec:
  """Bundle of model/AE hyperparameters plus checkpoint locations for one flux variant."""

  params: FluxParams
  ae_params: AutoEncoderParams
  ckpt_path: str | None  # local flow checkpoint; falls back to the hub when None
  ae_path: str | None    # local autoencoder checkpoint; falls back to the hub when None
  repo_id: str | None
  repo_flow: str | None
  repo_ae: str | None


# Known flux variants. Local checkpoint paths can be pinned via the
# FLUX_DEV / FLUX_SCHNELL / AE environment variables.
configs = {
    "flux-dev": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-dev",
        repo_flow="flux1-dev.safetensors",
        repo_ae="ae.safetensors",
        ckpt_path=os.getenv("FLUX_DEV"),
        params=FluxParams(
            in_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=True,  # dev is guidance-distilled
            rngs=nnx.Rngs(default=42),
            param_dtype=jnp.bfloat16,
        ),
        ae_path=os.getenv("AE"),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
            rngs=nnx.Rngs(default=42),
            param_dtype=jnp.bfloat16,
        ),
    ),
    "flux-schnell": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-schnell",
        repo_flow="flux1-schnell.safetensors",
        repo_ae="ae.safetensors",
        ckpt_path=os.getenv("FLUX_SCHNELL"),
        params=FluxParams(
            in_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=False,  # schnell has no guidance embedding
            rngs=nnx.Rngs(default=42),
            param_dtype=jnp.bfloat16,
        ),
        ae_path=os.getenv("AE"),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
            rngs=nnx.Rngs(default=42),
            param_dtype=jnp.bfloat16,
        ),
    ),
}


def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
  """Report state-dict keys that were missing and/or unexpected during loading."""
  if len(missing) > 0 and len(unexpected) > 0:
    print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
    print("\n" + "-" * 79 + "\n")
    print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
  elif len(missing) > 0:
    print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
  elif len(unexpected) > 0:
    print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))


def load_ae(name: str, device: str, hf_download: bool = True) -> AutoEncoder:
  """Build the flux autoencoder for `name` and port weights onto `device`.

  Weights come from the local `ae_path` when set, otherwise (with
  hf_download=True) from the Hugging Face hub. With no checkpoint at all,
  the randomly initialized model is returned as-is.
  """
  device = jax.devices(device)[0]
  with jax.default_device(device):
    ckpt_path = configs[name].ae_path
    if ckpt_path is None and configs[name].repo_id is not None and configs[name].repo_ae is not None and hf_download:
      ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_ae)

    print(f"Load and port autoencoder on {device}")
    ae = AutoEncoder(params=configs[name].ae_params)

    if ckpt_path is not None:
      tensors = {}
      with safe_open(ckpt_path, framework="pt") as f:
        for k in f.keys():
          tensors[k] = torch2jax(f.get_tensor(k))
      ae = port_autoencoder(autoencoder=ae, tensors=tensors)

      # Free the host-side copies before returning the ported model.
      del tensors
      jax.clear_caches()
  return ae
rename_from: weight_name = pt_tuple_key[-1] @@ -255,6 +259,33 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model, init_key=42): # Need to change some parameters name to match Flax names for pt_key, pt_tensor in pt_state_dict.items(): renamed_pt_key = rename_key(pt_key) + + if "FluxTransformer2DModel" == flax_model.__class__.__name__: + if "double_blocks" in renamed_pt_key: + renamed_pt_key = renamed_pt_key.replace("double_blocks_", "double_blocks.layers_") + renamed_pt_key = renamed_pt_key.replace("img_mlp_", "img_mlp.layers_") + renamed_pt_key = renamed_pt_key.replace("txt_mlp_", "txt_mlp.layers_") + renamed_pt_key = renamed_pt_key.replace("img_mod", "img_norm1") + renamed_pt_key = renamed_pt_key.replace("txt_mod", "txt_norm1") + renamed_pt_key = renamed_pt_key.replace("img_attn.qkv", "attn.i_qkv") + renamed_pt_key = renamed_pt_key.replace("img_attn.proj", "attn.i_proj") + renamed_pt_key = renamed_pt_key.replace("txt_attn.qkv", "attn.e_qkv") + renamed_pt_key = renamed_pt_key.replace("txt_attn.proj", "attn.e_proj") + + renamed_pt_key = renamed_pt_key.replace("img_attn.norm", "attn") + renamed_pt_key = renamed_pt_key.replace("txt_attn.norm.key_norm", "attn.encoder_key_norm") + renamed_pt_key = renamed_pt_key.replace("txt_attn.norm.query_norm", "attn.encoder_query_norm") + elif "single_blocks" in renamed_pt_key: + renamed_pt_key = renamed_pt_key.replace("single_blocks_", "single_blocks.layers_") + renamed_pt_key = renamed_pt_key.replace("modulation", "norm") + renamed_pt_key = renamed_pt_key.replace("norm.key_norm", "attn.key_norm") + renamed_pt_key = renamed_pt_key.replace("norm.query_norm", "attn.query_norm") + elif "vector_in" in renamed_pt_key or "time_in" in renamed_pt_key: + renamed_pt_key = renamed_pt_key.replace("vector_in", "time_text_embed.PixArtAlphaTextProjection_0") + renamed_pt_key = renamed_pt_key.replace("time_in", "time_text_embed.FlaxTimestepEmbedding_0") + elif "final_layer" in renamed_pt_key: + renamed_pt_key = 
renamed_pt_key.replace("final_layer.linear", "proj_out") + renamed_pt_key = renamed_pt_key.replace("final_layer.adaLN_modulation_1", "norm_out.Dense_0") pt_tuple_key = tuple(renamed_pt_key.split(".")) # Correctly rename weight parameters diff --git a/src/maxdiffusion/models/modeling_flax_utils.py b/src/maxdiffusion/models/modeling_flax_utils.py index 3983c224d..4d9f7b311 100644 --- a/src/maxdiffusion/models/modeling_flax_utils.py +++ b/src/maxdiffusion/models/modeling_flax_utils.py @@ -298,6 +298,7 @@ def from_pretrained( use_auth_token = kwargs.pop("use_auth_token", None) revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) + filename = kwargs.pop("filename", None) user_agent = { "maxdiffusion": __version__, @@ -309,6 +310,7 @@ def from_pretrained( cls.mesh = kwargs["mesh"] # Load config if we don't provide one + unused_kwargs = kwargs if config is None: config, unused_kwargs = cls.load_config( pretrained_model_name_or_path, @@ -323,7 +325,6 @@ def from_pretrained( subfolder=subfolder, **kwargs, ) - model, model_kwargs = cls.from_config(config, dtype=dtype, return_unused_kwargs=True, **unused_kwargs) # Load model @@ -353,10 +354,12 @@ def from_pretrained( f"{pretrained_path_with_subfolder}." 
) else: + if filename is None: + filename = FLAX_WEIGHTS_NAME if not from_pt else WEIGHTS_NAME try: model_file = hf_hub_download( pretrained_model_name_or_path, - filename=FLAX_WEIGHTS_NAME if not from_pt else WEIGHTS_NAME, + filename=filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies, @@ -364,7 +367,7 @@ def from_pretrained( local_files_only=local_files_only, use_auth_token=use_auth_token, user_agent=user_agent, - subfolder=subfolder, + subfolder=subfolder if not ".safetensors" in filename else None, revision=revision, ) diff --git a/src/maxdiffusion/models/normalization_flax.py b/src/maxdiffusion/models/normalization_flax.py new file mode 100644 index 000000000..ea3b970d8 --- /dev/null +++ b/src/maxdiffusion/models/normalization_flax.py @@ -0,0 +1,149 @@ +""" + Copyright 2024 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ """ + +import jax +import jax.numpy as jnp +import flax.linen as nn + + +class AdaLayerNormContinuous(nn.Module): + embedding_dim: int + elementwise_affine: bool = True + eps: float = 1e-5 + bias: bool = True + norm_type: str = "layer_norm" + dtype: jnp.dtype = jnp.float32 + weights_dtype: jnp.dtype = jnp.float32 + precision: jax.lax.Precision = None + + @nn.compact + def __call__(self, x, conditioning_embedding): + assert self.norm_type == "layer_norm" + emb = nn.Dense( + self.embedding_dim * 2, + kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")), + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("mlp",)), + use_bias=self.bias, + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, + )(nn.silu(conditioning_embedding)) + shift, scale = jnp.split(emb, 2, axis=1) + shift = nn.with_logical_constraint(shift, ("activation_batch", "activation_embed")) + scale = nn.with_logical_constraint(scale, ("activation_batch", "activation_embed")) + x = nn.LayerNorm(epsilon=self.eps, use_bias=self.elementwise_affine, use_scale=self.elementwise_affine)(x) + x = (1 + scale[:, None, :]) * x + shift[:, None, :] + return x + + +class AdaLayerNormZero(nn.Module): + r""" + Norm layer adaptive layer norm zero (adaLN-Zero). + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + num_embeddings (`int`): The size of the embeddings dictionary. 
+ """ + + embedding_dim: int + norm_type: str = "layer_norm" + bias: bool = True + dtype: jnp.dtype = jnp.float32 + weights_dtype: jnp.dtype = jnp.float32 + precision: jax.lax.Precision = None + + @nn.compact + def __call__(self, x, emb): + emb = nn.Dense( + 6 * self.embedding_dim, + use_bias=self.bias, + kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")), + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("mlp",)), + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, + name="lin", + )(nn.silu(emb)) + (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp) = jnp.split(emb[:, None, :], 6, axis=-1) + shift_msa = nn.with_logical_constraint(shift_msa, ("activation_batch", "activation_embed")) + scale_msa = nn.with_logical_constraint(scale_msa, ("activation_batch", "activation_embed")) + gate_msa = nn.with_logical_constraint(gate_msa, ("activation_batch", "activation_embed")) + shift_mlp = nn.with_logical_constraint(shift_mlp, ("activation_batch", "activation_embed")) + scale_mlp = nn.with_logical_constraint(scale_mlp, ("activation_batch", "activation_embed")) + gate_mlp = nn.with_logical_constraint(gate_mlp, ("activation_batch", "activation_embed")) + + if self.norm_type == "layer_norm": + x = nn.LayerNorm( + epsilon=1e-6, + use_bias=False, + use_scale=False, + dtype=self.dtype, + param_dtype=self.weights_dtype, + )(x) + else: + raise ValueError(f"Unsupported `norm_type` ({self.norm_type}) provided. Supported ones are: 'layer_norm'.") + x = x * (1 + scale_msa) + shift_msa + return x, gate_msa, shift_mlp, scale_mlp, gate_mlp + + +class AdaLayerNormZeroSingle(nn.Module): + r""" + Norm layer adaptive layer norm zero (adaLN-Zero). + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + num_embeddings (`int`): The size of the embeddings dictionary. 
+ """ + + embedding_dim: int + norm_type: str = "layer_norm" + bias: bool = True + dtype: jnp.dtype = jnp.float32 + weights_dtype: jnp.dtype = jnp.float32 + precision: jax.lax.Precision = None + + @nn.compact + def __call__(self, x, emb): + emb = nn.silu(emb) + emb = nn.Dense( + 3 * self.embedding_dim, + use_bias=self.bias, + kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")), + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("mlp",)), + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, + name="lin", + )(emb) + shift_msa, scale_msa, gate_msa = jnp.split(emb[:, None, :], 3, axis=-1) + shift_msa = nn.with_logical_constraint(shift_msa, ("activation_batch", "activation_embed")) + scale_msa = nn.with_logical_constraint(scale_msa, ("activation_batch", "activation_embed")) + gate_msa = nn.with_logical_constraint(gate_msa, ("activation_batch", "activation_embed")) + if self.norm_type == "layer_norm": + x = ( + nn.LayerNorm( + epsilon=1e-6, + use_bias=False, + use_scale=False, + dtype=self.dtype, + param_dtype=self.weights_dtype, + )(x) + * (1 + scale_msa) + + shift_msa + ) + else: + raise ValueError(f"Unsupported `norm_type` ({self.norm_type}) provided. Supported ones are: 'layer_norm'.") + return x, gate_msa diff --git a/src/maxdiffusion/models/transformers/transformer_flux_flax.py b/src/maxdiffusion/models/transformers/transformer_flux_flax.py new file mode 100644 index 000000000..ee5169811 --- /dev/null +++ b/src/maxdiffusion/models/transformers/transformer_flux_flax.py @@ -0,0 +1,621 @@ +""" + Copyright 2024 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + """ + +"""This script is used an example of how to shard the UNET on TPU.""" + +from typing import Any, Dict, Optional, Tuple, Union +import jax +import math +import jax.numpy as jnp +import flax +import flax.linen as nn +from jax.random import PRNGKey +from einops import repeat, rearrange +from ...configuration_utils import ConfigMixin, flax_register_to_config +from ..modeling_flax_utils import FlaxModelMixin +from ..normalization_flax import AdaLayerNormZeroSingle, AdaLayerNormContinuous, AdaLayerNormZero +from ..attention_flax import FlaxFluxAttention +from ..embeddings_flax import (FluxPosEmbed, CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings) +from ... import common_types +from ...common_types import BlockSizes +from ... import max_logging +from ...utils import BaseOutput +from dataclasses import dataclass + +AxisNames = common_types.AxisNames +BATCH = common_types.BATCH +LENGTH = common_types.LENGTH +HEAD = common_types.HEAD +D_KV = common_types.D_KV + + +@dataclass +class FluxParams: + in_channels: int + vec_in_dim: int + context_in_dim: int + hidden_size: int + mlp_ratio: float + num_heads: int + depth: int + depth_single_blocks: int + axes_dim: list[int] + theta: int + qkv_bias: bool + guidance_embed: bool + param_dtype: jnp.bfloat16 + rngs: jax.random.PRNGKey + + +@flax.struct.dataclass +class Transformer2DModelOutput(BaseOutput): + """ + The output of [`FluxTransformer2DModel`]. 
+ + Args: + sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`): + The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. + """ + + sample: jnp.ndarray + + +class FluxSingleTransformerBlock(nn.Module): + r""" + A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3. + + Reference: https://arxiv.org/abs/2403.03206 + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the + processing of `context` conditions. + """ + + dim: int + num_attention_heads: int + attention_head_dim: int + mlp_ratio: int = 4.0 + attention_kernel: str = "dot_product" + flash_min_seq_length: int = 4096 + flash_block_sizes: BlockSizes = None + mesh: jax.sharding.Mesh = None + dtype: jnp.dtype = jnp.float32 + weights_dtype: jnp.dtype = jnp.float32 + precision: jax.lax.Precision = None + + def setup(self): + self.mlp_hidden_dim = int(self.dim * self.mlp_ratio) + + self.norm = AdaLayerNormZeroSingle( + self.dim, dtype=self.dtype, weights_dtype=self.weights_dtype, precision=self.precision + ) + + self.linear1 = nn.Dense( + self.dim * 3 + self.mlp_hidden_dim, + kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")), + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("mlp",)), + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, + ) + + self.mlp_act = nn.gelu + self.linear2 = nn.Dense( + self.dim, + kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("mlp", "embed")), + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("embed",)), + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, + ) + 
self.attn = FlaxFluxAttention( + query_dim=self.dim, + heads=self.num_attention_heads, + dim_head=self.attention_head_dim, + dtype=self.dtype, + weights_dtype=self.weights_dtype, + attention_kernel=self.attention_kernel, + mesh=self.mesh, + ) + + def __call__(self, hidden_states, temb, image_rotary_emb=None): + residual = hidden_states + norm_hidden_states, gate = self.norm(hidden_states, emb=temb) + qkv, mlp = jnp.split(self.linear1(norm_hidden_states), [3 * self.dim], axis=-1) + mlp = nn.with_logical_constraint(mlp, ("activation_batch", "activation_length", "activation_embed")) + qkv = nn.with_logical_constraint(qkv, ("activation_batch", "activation_length", "activation_embed")) + + B, L = hidden_states.shape[:2] + H, D, K = self.num_attention_heads, qkv.shape[-1] // (self.num_attention_heads * 3), 3 + qkv_proj = qkv.reshape(B, L, K, H, D).transpose(2, 0, 3, 1, 4) + q, k, v = qkv_proj + + q = self.attn.query_norm(q) + k = self.attn.key_norm(k) + + if image_rotary_emb is not None: + # since this function returns image_rotary_emb and passes it between layers, + # we do not want to modify it + image_rotary_emb_reordered = rearrange(image_rotary_emb, "n d (i j) -> n d i j", i=2, j=2) + q, k = self.attn.apply_rope(q, k, image_rotary_emb_reordered) + + q = q.transpose(0, 2, 1, 3).reshape(q.shape[0], q.shape[2], -1) + k = k.transpose(0, 2, 1, 3).reshape(k.shape[0], k.shape[2], -1) + v = v.transpose(0, 2, 1, 3).reshape(v.shape[0], v.shape[2], -1) + + attn_output = self.attn.attention_op.apply_attention(q, k, v) + + attn_mlp = jnp.concatenate([attn_output, self.mlp_act(mlp)], axis=2) + attn_mlp = nn.with_logical_constraint(attn_mlp, ("activation_batch", "activation_length", "activation_embed")) + hidden_states = self.linear2(attn_mlp) + hidden_states = gate * hidden_states + hidden_states = residual + hidden_states + if hidden_states.dtype == jnp.float16 or hidden_states.dtype == jnp.bfloat16: + hidden_states = jnp.clip(hidden_states, -65504, 65504) + + return 
hidden_states, temb, image_rotary_emb + + +class FluxTransformerBlock(nn.Module): + r""" + A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3. + + Reference: https://arxiv.org/abs/2403.03206 + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the + processing of `context` conditions. + """ + + dim: int + num_attention_heads: int + attention_head_dim: int + qk_norm: str = "rms_norm" + eps: int = 1e-6 + flash_min_seq_length: int = 4096 + flash_block_sizes: BlockSizes = None + mesh: jax.sharding.Mesh = None + dtype: jnp.dtype = jnp.float32 + weights_dtype: jnp.dtype = jnp.float32 + precision: jax.lax.Precision = None + mlp_ratio: float = 4.0 + qkv_bias: bool = False + attention_kernel: str = "dot_product" + + def setup(self): + + self.img_norm1 = AdaLayerNormZero(self.dim, dtype=self.dtype, weights_dtype=self.weights_dtype, precision=self.precision) + self.txt_norm1 = AdaLayerNormZero(self.dim, dtype=self.dtype, weights_dtype=self.weights_dtype, precision=self.precision) + + self.attn = FlaxFluxAttention( + query_dim=self.dim, + heads=self.num_attention_heads, + dim_head=self.attention_head_dim, + qkv_bias=self.qkv_bias, + dtype=self.dtype, + weights_dtype=self.weights_dtype, + attention_kernel=self.attention_kernel, + mesh=self.mesh, + ) + + self.img_norm2 = nn.LayerNorm( + use_bias=False, + use_scale=False, + epsilon=self.eps, + dtype=self.dtype, + param_dtype=self.weights_dtype, + ) + self.img_mlp = nn.Sequential( + [ + nn.Dense( + int(self.dim * self.mlp_ratio), + use_bias=True, + kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")), + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("mlp",)), + 
dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, + ), + nn.gelu, + nn.Dense( + self.dim, + use_bias=True, + kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")), + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("mlp",)), + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, + ), + ] + ) + + self.txt_norm2 = nn.LayerNorm( + use_bias=False, + use_scale=False, + epsilon=self.eps, + dtype=self.dtype, + param_dtype=self.weights_dtype, + ) + self.txt_mlp = nn.Sequential( + [ + nn.Dense( + int(self.dim * self.mlp_ratio), + use_bias=True, + kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")), + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("mlp",)), + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, + ), + nn.gelu, + nn.Dense( + self.dim, + use_bias=True, + kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")), + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("mlp",)), + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, + ), + ] + ) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def __call__(self, hidden_states, encoder_hidden_states, temb, image_rotary_emb=None): + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.img_norm1(hidden_states, emb=temb) + + norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.txt_norm1( + encoder_hidden_states, emb=temb + ) + + # Attention. 
+ attn_output, context_attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + ) + + attn_output = gate_msa * attn_output + hidden_states = hidden_states + attn_output + norm_hidden_states = self.img_norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + + ff_output = self.img_mlp(norm_hidden_states) + ff_output = gate_mlp * ff_output + + hidden_states = hidden_states + ff_output + # Process attention outputs for the `encoder_hidden_states`. + context_attn_output = c_gate_msa * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output + + norm_encoder_hidden_states = self.txt_norm2(encoder_hidden_states) + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp + + context_ff_output = self.txt_mlp(norm_encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output + if encoder_hidden_states.dtype == jnp.float16 or encoder_hidden_states.dtype == jnp.bfloat16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + return hidden_states, encoder_hidden_states, temb, image_rotary_emb + + +@flax_register_to_config +class FluxTransformer2DModel(nn.Module, FlaxModelMixin, ConfigMixin): + r""" + The Transformer model introduced in Flux. + + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + + This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for its generic methods + implemented for all models (such as downloading or saving). + + This model is also a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) + subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matters related to its + general usage and behavior.
+ + Parameters: + patch_size (`int`): Patch size to turn the input data into small patches. + in_channels (`int`, *optional*, defaults to 64): The number of channels in the input. + num_layers (`int`, *optional*, defaults to 19): The number of layers of MMDiT blocks to use. + num_single_layers (`int`, *optional*, defaults to 38): The number of layers of single DiT blocks to use. + attention_head_dim (`int`, *optional*, defaults to 128): The number of channels in each head. + num_attention_heads (`int`, *optional*, defaults to 24): The number of heads to use for multi-head attention. + joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`. + guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings. + + """ + + patch_size: int = 1 + in_channels: int = 64 + num_layers: int = 19 + num_single_layers: int = 38 + attention_head_dim: int = 128 + num_attention_heads: int = 24 + joint_attention_dim: int = 4096 + pooled_projection_dim: int = 768 + guidance_embeds: bool = False + axes_dims_rope: Tuple[int] = (16, 56, 56) + flash_min_seq_length: int = 4096 + flash_block_sizes: BlockSizes = None + mesh: jax.sharding.Mesh = None + dtype: jnp.dtype = jnp.float32 + weights_dtype: jnp.dtype = jnp.float32 + precision: jax.lax.Precision = None + mlp_ratio: float = 4.0 + qkv_bias: bool = True + theta: int = 1000 + attention_kernel: str = "dot_product" + eps = 1e-6 + rngs = (PRNGKey(seed=42),) + # hidden_state_axis_names: AxisNames = (BATCH, LENGTH, D_KV) + + def setup(self): + self.out_channels = self.in_channels + self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim + + self.pe_embedder = FluxPosEmbed(theta=self.theta, axes_dim=self.axes_dims_rope, dtype=self.dtype) + + text_time_guidance_cls = ( + CombinedTimestepGuidanceTextProjEmbeddings if self.guidance_embeds else
CombinedTimestepTextProjEmbeddings + ) + + self.time_text_embed = text_time_guidance_cls( + embedding_dim=self.inner_dim, + pooled_projection_dim=self.pooled_projection_dim, + dtype=self.dtype, + weights_dtype=self.weights_dtype, + precision=self.precision, + ) + self.txt_in = nn.Dense( + self.inner_dim, + kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), (None, "mlp")), + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("mlp",)), + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, + ) + self.img_in = nn.Dense( + self.inner_dim, + kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), (None, "mlp")), + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("mlp",)), + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, + ) + + self.double_blocks = nn.Sequential( + [ + *[ + FluxTransformerBlock( + dim=self.inner_dim, + num_attention_heads=self.num_attention_heads, + attention_head_dim=self.attention_head_dim, + attention_kernel=self.attention_kernel, + flash_min_seq_length=self.flash_min_seq_length, + flash_block_sizes=self.flash_block_sizes, + mesh=self.mesh, + dtype=self.dtype, + weights_dtype=self.weights_dtype, + precision=self.precision, + mlp_ratio=self.mlp_ratio, + qkv_bias=self.qkv_bias, + ) + for _ in range(self.num_layers) + ] + ] + ) + + self.single_blocks = nn.Sequential( + [ + *[ + FluxSingleTransformerBlock( + dim=self.inner_dim, + num_attention_heads=self.num_attention_heads, + attention_head_dim=self.attention_head_dim, + attention_kernel=self.attention_kernel, + flash_min_seq_length=self.flash_min_seq_length, + flash_block_sizes=self.flash_block_sizes, + mesh=self.mesh, + dtype=self.dtype, + weights_dtype=self.weights_dtype, + precision=self.precision, + mlp_ratio=self.mlp_ratio, + ) + for _ in range(self.num_single_layers) + ] + ] + ) + + self.norm_out = AdaLayerNormContinuous( + self.inner_dim, + elementwise_affine=False, + 
eps=self.eps, + dtype=self.dtype, + weights_dtype=self.weights_dtype, + precision=self.precision, + ) + + self.proj_out = nn.Dense( + self.patch_size**2 * self.out_channels, + kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("mlp", None)), + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)), + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, + use_bias=True, + ) + + def timestep_embedding(self, t: jax.Array, dim: int, max_period=10000, time_factor: float = 1000.0) -> jax.Array: + """ + Generate timestep embeddings. + + Args: + t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + dim: the dimension of the output. + max_period: controls the minimum frequency of the embeddings. + time_factor: Tensor of positional embeddings. + + Returns: + timestep embeddings. + """ + t = time_factor * t + half = dim // 2 + + freqs = jnp.exp(-math.log(max_period) * jnp.arange(start=0, stop=half, dtype=jnp.bfloat16) / half).astype(dtype=t.dtype) + + args = t[:, None].astype(jnp.bfloat16) * freqs[None] + embedding = jnp.concatenate([jnp.cos(args), jnp.sin(args)], axis=-1) + + if dim % 2: + embedding = jnp.concatenate([embedding, jnp.zeros_like(embedding[:, :1])], axis=-1) + + if jnp.issubdtype(t.dtype, jnp.floating): + embedding = embedding.astype(t.dtype) + + return embedding + + def __call__( + self, + hidden_states, + encoder_hidden_states, + pooled_projections, + timestep, + img_ids, + txt_ids, + guidance, + return_dict: bool = True, + train: bool = False, + ): + hidden_states = self.img_in(hidden_states) + timestep = self.timestep_embedding(timestep, 256) + if self.guidance_embeds: + guidance = self.timestep_embedding(guidance, 256) + else: + guidance = None + temb = ( + self.time_text_embed(timestep, pooled_projections) + if guidance is None + else self.time_text_embed(timestep, guidance, pooled_projections) + ) + encoder_hidden_states = self.txt_in(encoder_hidden_states) + if 
txt_ids.ndim == 3: + txt_ids = txt_ids[0] + if img_ids.ndim == 3: + img_ids = img_ids[0] + + ids = jnp.concatenate((txt_ids, img_ids), axis=0) + ids = nn.with_logical_constraint(ids, ("activation_batch", None)) + image_rotary_emb = self.pe_embedder(ids) + image_rotary_emb = nn.with_logical_constraint(image_rotary_emb, ("activation_batch", "activation_embed")) + + hidden_states, encoder_hidden_states, temb, image_rotary_emb = self.double_blocks( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + hidden_states = jnp.concatenate([encoder_hidden_states, hidden_states], axis=1) + hidden_states = nn.with_logical_constraint(hidden_states, ("activation_batch", "activation_length", "activation_embed")) + + hidden_states, temb, image_rotary_emb = self.single_blocks( + hidden_states=hidden_states, temb=temb, image_rotary_emb=image_rotary_emb + ) + hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...] + + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) + + def init_weights(self, rngs, eval_only=True): + scale_factor = 16 + resolution = 1024 + num_devices = len(jax.devices()) + batch_size = 1 * num_devices + batch_image_shape = ( + batch_size, + 16, # 16 to match jflux.get_noise + 2 * resolution // scale_factor, + 2 * resolution // scale_factor, + ) + # bs, encoder_input, seq_length + text_shape = ( + batch_size, + 256, + 4096, # Sequence length of text encoder, how to get this programmatically? + ) + text_ids_shape = ( + batch_size, + 256, + 3, # Hardcoded to match jflux.prepare + ) + vec_shape = ( + batch_size, + 768, # Sequence length of clip, how to get this programmatically? 
+ ) + img = jnp.zeros(batch_image_shape, dtype=self.dtype) + bs, c, h, w = img.shape + img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + img_ids = jnp.zeros((h // 2, w // 2, 3), dtype=self.dtype) + img_ids = img_ids.at[..., 1].set(jnp.arange(h // 2)[:, None]) + img_ids = img_ids.at[..., 2].set(jnp.arange(w // 2)[None, :]) + img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) + + txt = jnp.zeros(text_shape, dtype=self.dtype) + txt_ids = jnp.zeros(text_ids_shape, dtype=self.dtype) + + t_vec = jnp.full(bs, 0, dtype=self.dtype) + + vec = jnp.zeros(vec_shape, dtype=self.dtype) + + guidance_vec = jnp.full(bs, 4.0, dtype=self.dtype) + + if eval_only: + return jax.eval_shape( + self.init, + rngs, + hidden_states=img, + img_ids=img_ids, + encoder_hidden_states=txt, + txt_ids=txt_ids, + pooled_projections=vec, + timestep=t_vec, + guidance=guidance_vec, + )["params"] + else: + return self.init( + rngs, + hidden_states=img, + img_ids=img_ids, + encoder_hidden_states=txt, + txt_ids=txt_ids, + pooled_projections=vec, + timestep=t_vec, + guidance=guidance_vec, + )["params"] diff --git a/src/maxdiffusion/pipelines/__init__.py b/src/maxdiffusion/pipelines/__init__.py index 227784ba6..1093caa95 100644 --- a/src/maxdiffusion/pipelines/__init__.py +++ b/src/maxdiffusion/pipelines/__init__.py @@ -32,7 +32,13 @@ # These modules contain pipelines from multiple libraries/frameworks _dummy_objects = {} -_import_structure = {"stable_diffusion": [], "stable_diffusion_xl": [], "latent_diffusion": [], "controlnet": []} +_import_structure = { + "stable_diffusion": [], + "stable_diffusion_xl": [], + "latent_diffusion": [], + "controlnet": [], + "jflux": [], +} try: if not is_onnx_available(): @@ -94,8 +100,12 @@ "FlaxStableDiffusionXLPipeline", ] ) + _import_structure["jflux"].extend( + [ + "JfluxPipeline", + ] + ) if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: - try: if not is_onnx_available(): raise OptionalDependencyNotAvailable() @@ -140,6 +150,7 @@ 
FlaxStableDiffusionPipeline, ) from .stable_diffusion_xl import FlaxStableDiffusionXLPipeline + from .jflux import JfluxPipeline try: if not (is_torch_available() and is_note_seq_available()): diff --git a/src/maxdiffusion/pipelines/jflux/__init__.py b/src/maxdiffusion/pipelines/jflux/__init__.py new file mode 100644 index 000000000..3cd0411c8 --- /dev/null +++ b/src/maxdiffusion/pipelines/jflux/__init__.py @@ -0,0 +1,5 @@ +_import_structure = { "pipeline_jflux" : "JfluxPipeline" } + +from .pipeline_jflux import ( + JfluxPipeline, +) diff --git a/src/maxdiffusion/pipelines/jflux/pipeline_jflux.py b/src/maxdiffusion/pipelines/jflux/pipeline_jflux.py new file mode 100644 index 000000000..ba7abae45 --- /dev/null +++ b/src/maxdiffusion/pipelines/jflux/pipeline_jflux.py @@ -0,0 +1,239 @@ +# Adapted from pipeline_flax_stable_diffusion.py +from functools import partial +from ..pipeline_flax_utils import FlaxDiffusionPipeline +from maxdiffusion.models.transformers.transformer_flux_flax import FluxTransformer2DModel +from maxdiffusion.models.embeddings_flax import HFEmbedder +from maxdiffusion.models.ae_flux_nnx import AutoEncoder +from jax.sharding import Sharding +from jax.typing import DTypeLike + +import jax +import math +import jax.numpy as jnp +from chex import Array +from einops import rearrange, repeat +from typing import Dict, List, Optional, Union +from ...utils import replace_example_docstring +from flax.core.frozen_dict import FrozenDict +import einops + +from ...schedulers import ( + FlaxDDIMScheduler, + FlaxDPMSolverMultistepScheduler, + FlaxLMSDiscreteScheduler, + FlaxPNDMScheduler, +) + +# Set to True to use python for loop instead of jax.fori_loop for easier FOR_LOOPging +FOR_LOOP = True + +EXAMPLE_DOC_STRING = """ + Examples: COMING SOON +""" + + +class JfluxPipeline(FlaxDiffusionPipeline): + + def __init__( + self, + t5: HFEmbedder, + clip: HFEmbedder, + flux: FluxTransformer2DModel, + ae: AutoEncoder, + scheduler: Union[FlaxDDIMScheduler, 
FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler], + dtype: jnp.dtype, + sharding: Sharding, + ): + super().__init__() + self.dtype = dtype + self.data_sharding = sharding + self.register_modules(t5=t5, clip=clip, flux=flux, ae=ae, scheduler=scheduler) + + @staticmethod + def unpack(x: Array, height: int, width: int) -> Array: + return rearrange( + x, + "b (h w) (c ph pw) -> b c (h ph) (w pw)", + h=math.ceil(height / 16), + w=math.ceil(width / 16), + ph=2, + pw=2, + ) + + @staticmethod + def create_noise( + num_samples: int, + height: int, + width: int, + dtype: DTypeLike, + seed: jax.random.PRNGKey, + ): + return jax.random.normal( + key=seed, + shape=(num_samples, 16, 2 * math.ceil(height / 16), 2 * math.ceil(width / 16)), + dtype=dtype, + ) + + # this is the reverse of the unpack function + @staticmethod + def pack_img(img): + bs, c, h, w = img.shape + return einops.rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + + def prepare_inputs(self, prompt: Union[str, List[str]], img: Array): + if not isinstance(prompt, (str, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if isinstance(prompt, str): + prompt = [prompt] + + bs = len(prompt) + txt = jax.device_put(jnp.asarray(self.t5(prompt), dtype=jnp.bfloat16), self.data_sharding) + + if txt.shape[0] == 1 and bs > 1: + txt = repeat(txt, "1 ... -> bs ...", bs=bs) + txt_ids = jax.device_put(jnp.zeros((bs, txt.shape[1], 3), dtype=txt.dtype), self.data_sharding) + + vec = jax.device_put(self.clip(prompt), self.data_sharding) + + if vec.shape[0] == 1 and bs > 1: + vec = repeat(vec, "1 ... 
-> bs ...", bs=bs) + + return (txt, txt_ids, vec, img) + + def prepare_img_ids(self, img, guidance_scale): + img = jax.device_put(img, self.data_sharding) + batch_size, _, h, w = img.shape + img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + img_ids = jnp.zeros((h // 2, w // 2, 3), dtype=img.dtype) + img_ids = img_ids.at[..., 1].set(jnp.arange(h // 2)[:, None]) + img_ids = img_ids.at[..., 2].set(jnp.arange(w // 2)[None, :]) + img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size) + img_ids = jax.device_put(img_ids, self.data_sharding) + guidance_vec = jnp.full((img.shape[0],), guidance_scale, dtype=img.dtype) + + return img, img_ids, guidance_vec + + def _generate( + self, + params: Union[Dict, FrozenDict], + txt: jnp.array, + txt_ids: jnp.array, + vec: jnp.array, + timesteps: jnp.array, + height: int, + width: int, + guidance_scale: float, + img: Array, + shift: bool = False, + ): + img, img_ids, guidance_vec = self.prepare_img_ids(img, guidance_scale) + + print(f"{len(timesteps) - 1} steps") + + @partial( + jax.jit, + in_shardings=( + self.data_sharding, + self.data_sharding, + self.data_sharding, + self.data_sharding, + self.data_sharding, + self.data_sharding, + None, + None, + None, + ), + out_shardings=(self.data_sharding), + ) + def loop_body(params, img, img_ids, txt, txt_ids, vec, guidance_vec, t_curr, t_prev): + # the order of timesteps is unintuitive... 
+ t_vec = jnp.full((img.shape[0],), t_curr, dtype=img.dtype) + pred = self.flux.apply( + {"params": params}, + hidden_states=img, + img_ids=img_ids, + encoder_hidden_states=txt, + txt_ids=txt_ids, + pooled_projections=vec, + timestep=t_vec, + guidance=guidance_vec, + ) + + img = img + (t_prev - t_curr) * pred.sample + + return img + + c_ts = timesteps[:-1] + p_ts = timesteps[1:] + + for i in range(len(timesteps) - 1): + img = loop_body(params, img, img_ids, txt, txt_ids, vec, guidance_vec, c_ts[i], p_ts[i]) + + # decode latents to pixel space + img = self.unpack(x=img, height=height, width=width) + img = self.ae.decode(img) + return img + + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + params: Union[Dict, FrozenDict], + txt: jnp.array, + txt_ids: jnp.array, + vec: jnp.array, + timesteps: int, + height: int, + width: int, + guidance_scale: float, + img: Optional[jnp.ndarray] = None, + shift: bool = False, + ): + r""" + The call function to the pipeline for generation. + + Args: + txt: jnp.array, + txt_ids: jnp.array, + vec: jnp.array, + num_inference_steps: int, + height: int, + width: int, + guidance_scale: float, + img: Optional[jnp.ndarray] = None, + shift: bool = False, + jit (`bool`, defaults to `False`): + + Examples: + + """ + + if isinstance(timesteps, int): + timesteps = jnp.linspace(1, 0, timesteps + 1) + + images = self._generate( + params, + txt, + txt_ids, + vec, + timesteps, + height, + width, + guidance_scale, + img, + shift, + ) + + images = images + return images + + def init_flux_weights(self, rng: jax.Array, eval_only: bool = False) -> FrozenDict: + return self.flux.init_weights(rng, eval_only) + + +@staticmethod +def unshard(x: jnp.ndarray): + # einops.rearrange(x, 'd b ... 
-> (d b) ...') + num_devices, batch_size = x.shape[:2] + rest = x.shape[2:] + return x.reshape(num_devices * batch_size, *rest) diff --git a/src/maxdiffusion/pipelines/pipeline_flax_utils.py b/src/maxdiffusion/pipelines/pipeline_flax_utils.py index e8e899df7..e4d53f933 100644 --- a/src/maxdiffusion/pipelines/pipeline_flax_utils.py +++ b/src/maxdiffusion/pipelines/pipeline_flax_utils.py @@ -41,7 +41,7 @@ ) -from maxdiffusion.transformers import FlaxPreTrainedModel +from transformers import FlaxPreTrainedModel INDEX_FILE = "diffusion_flax_model.bin" diff --git a/src/maxdiffusion/schedulers/scheduling_euler_discrete_flax.py b/src/maxdiffusion/schedulers/scheduling_euler_discrete_flax.py index bc9013cbf..e45aef06a 100644 --- a/src/maxdiffusion/schedulers/scheduling_euler_discrete_flax.py +++ b/src/maxdiffusion/schedulers/scheduling_euler_discrete_flax.py @@ -17,6 +17,7 @@ import flax import jax.numpy as jnp +from .. import max_logging from ..configuration_utils import ConfigMixin, register_to_config from .scheduling_utils_flax import ( @@ -100,6 +101,7 @@ def __init__( self.dtype = dtype def create_state(self, common: Optional[CommonSchedulerState] = None) -> EulerDiscreteSchedulerState: + max_logging.log("Creating EulerDiscreteSchedulerState") if common is None: common = CommonSchedulerState.create(self) @@ -144,7 +146,7 @@ def scale_model_input(self, state: EulerDiscreteSchedulerState, sample: jnp.ndar return sample def set_timesteps( - self, state: EulerDiscreteSchedulerState, num_inference_steps: int, shape: Tuple = () + self, state: EulerDiscreteSchedulerState, num_inference_steps: int, timestep_spacing: str = "", shape: Tuple = () ) -> EulerDiscreteSchedulerState: """ Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. @@ -155,20 +157,29 @@ def set_timesteps( num_inference_steps (`int`): the number of diffusion steps used when generating samples with a pre-trained model. 
""" + if timestep_spacing == "": + timestep_spacing = self.config.timestep_spacing - if self.config.timestep_spacing == "linspace": + max_logging.log(f"Setting timesteps for {num_inference_steps} steps") + if timestep_spacing == "linspace": + max_logging.log("Using linspace timestep spacing") timesteps = jnp.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=self.dtype) - elif self.config.timestep_spacing == "leading": + elif timestep_spacing == "leading": + max_logging.log("Using leading timestep spacing") step_ratio = self.config.num_train_timesteps // num_inference_steps timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float) timesteps += 1 - elif self.config.timestep_spacing == "trailing": + elif timestep_spacing == "trailing": + max_logging.log("Using trailing timestep spacing") step_ratio = self.config.num_train_timesteps / num_inference_steps timesteps = (jnp.arange(self.config.num_train_timesteps, 0, -step_ratio)).round() timesteps -= 1 + elif timestep_spacing == "flux": + max_logging.log("Using flux timestep spacing") + timesteps = jnp.linspace(1, 0, num_inference_steps + 1) else: raise ValueError( - f"timestep_spacing must be one of ['linspace', 'leading', 'trailing'], got {self.config.timestep_spacing}" + f"timestep_spacing must be one of ['linspace', 'leading', 'trailing', 'flux'], got {self.config.timestep_spacing}" ) sigmas = ((1 - state.common.alphas_cumprod) / state.common.alphas_cumprod) ** 0.5 @@ -250,7 +261,14 @@ def add_noise( original_samples: jnp.ndarray, noise: jnp.ndarray, timesteps: jnp.ndarray, + flux: bool = False, ) -> jnp.ndarray: + if flux: + t = state.timesteps[timesteps] + t = t[:, None, None] + noisy_samples = t * noise + (1 - t) * original_samples + return noisy_samples + sigma = state.sigmas[timesteps].flatten() sigma = broadcast_to_shape_from_left(sigma, noise.shape) diff --git a/src/maxdiffusion/tests/flux_tests.py b/src/maxdiffusion/tests/flux_tests.py new file 
mode 100644 index 000000000..e45306dbd --- /dev/null +++ b/src/maxdiffusion/tests/flux_tests.py @@ -0,0 +1,35 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +'''Unit tests for Flux model layers.''' + +import os +import unittest +from absl.testing import absltest +import jax +from ..models.normalization_flax import FlaxAdaLayerNormZeroSingle + +class FluxTests(unittest.TestCase): + def test_adalayernormzerosingle(self): + ada_layer = FlaxAdaLayerNormZeroSingle(embedding_dim=128) + x = jax.random.normal(jax.random.key(0), (2,128)) + params = ada_layer.init({"params" : jax.random.key(0)}, x, x)["params"] + x, y = ada_layer.apply({"params" : params["params"]}, x, x) + + + +if __name__ == '__main__': + absltest.main() diff --git a/src/maxdiffusion/train_jflux.py b/src/maxdiffusion/train_jflux.py new file mode 100644 index 000000000..c90287f9f --- /dev/null +++ b/src/maxdiffusion/train_jflux.py @@ -0,0 +1,48 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and +limitations under the License. +""" + +from typing import Sequence + +import jax +from absl import app +from maxdiffusion import ( + max_logging, + pyconfig, +) + +from maxdiffusion.train_utils import ( + validate_train_config, +) + + +def train(config): + # import JFluxTrainer here or else Jax gets initialized too early and we can't do jax.initialize_distributed + from maxdiffusion.trainers.jflux_trainer import JFluxTrainer + + trainer = JFluxTrainer(config) + trainer.start_training() + + +def main(argv: Sequence[str]) -> None: + pyconfig.initialize(argv) + config = pyconfig.config + validate_train_config(config) + max_logging.log(f"Found {jax.device_count()} devices.") + train(config) + + +if __name__ == "__main__": + app.run(main) diff --git a/src/maxdiffusion/train_utils.py b/src/maxdiffusion/train_utils.py index 648c10d7e..67bbd2197 100644 --- a/src/maxdiffusion/train_utils.py +++ b/src/maxdiffusion/train_utils.py @@ -111,9 +111,12 @@ def write_metrics_to_tensorboard(writer, metrics, step, config): full_log = step % config.log_period == 0 if jax.process_index() == 0: max_logging.log( - f"completed step: {step}, seconds: {metrics['scalar']['perf/step_time_seconds']:.3f}, " - f"TFLOP/s/device: {metrics['scalar']['perf/per_device_tflops_per_sec']:.3f}, " - f"loss: {metrics['scalar']['learning/loss']:.3f}" + "completed step: {}, seconds: {:.3f}, TFLOP/s/device: {:.3f}, loss: {:.3f}".format( + step, + metrics['scalar']['perf/step_time_seconds'], + metrics['scalar']['perf/per_device_tflops_per_sec'], + float(metrics['scalar']['learning/loss']) + ) ) if full_log and jax.process_index() == 0: diff --git a/src/maxdiffusion/trainers/jflux_trainer.py b/src/maxdiffusion/trainers/jflux_trainer.py new file mode 100644 index 000000000..70f40ebd6 --- /dev/null +++ b/src/maxdiffusion/trainers/jflux_trainer.py @@ -0,0 +1,494 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 
(the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import jax +import jax.numpy as jnp +import datetime +import time +import flax.linen as nn +from functools import partial +from jax.sharding import PartitionSpec as P +from flax.linen import partitioning as nn_partitioning +import optax +from einops import repeat, rearrange + +from maxdiffusion import (FlaxEulerDiscreteScheduler, maxdiffusion_utils, train_utils, max_utils, max_logging) +from maxdiffusion.input_pipeline.input_pipeline_interface import (make_data_iterator) + +from maxdiffusion.checkpointing.jflux_checkpointer import (JfluxCheckpointer) + + +# pulls in code from BaseStableDiffusionTrainer and StableDiffusionTrainer +class JFluxTrainer(JfluxCheckpointer): + + def __init__(self, config): + JfluxCheckpointer.__init__(self, config) + + # sharding + self.data_sharding = None + + self.per_device_tflops = None + + self.writer = max_utils.initialize_summary_writer(config) + + self.p_train_step = None + + def pre_training_steps(self): + pass + + def post_training_steps(self, pipeline, state): + if self.config.run_inference_after_training: + import os + from glob import iglob + import re + import numpy as np + from PIL import Image + + seed = jax.random.PRNGKey(seed=102333) + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + state = jax.device_put(state, pipeline.data_sharding) + img = pipeline.create_noise( + len(jax.devices()), self.config.resolution, self.config.resolution, self.config.activations_dtype, seed + ) + (txt, txt_ids, vec, img) = 
pipeline.prepare_inputs([self.config.prompt for _ in range(len(jax.devices()))], img) + + def do_inference(): + return pipeline( + state, + txt, + txt_ids, + vec, + self.config.num_inference_steps, + self.config.resolution, + self.config.resolution, + self.config.guidance_scale, + img, + shift=self.config.model_name != "flux-schnell", + ) + + max_logging.log("Inference") + t0 = time.perf_counter() + x = do_inference() + t1 = time.perf_counter() + output_dir = "output" + output_name = os.path.join(output_dir, "maxdiff_img_{idx}.jpg") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + idx = 0 + else: + fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"maxdiff_img_[0-9]+\.jpg$", fn)] + if len(fns) > 0: + idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1 + else: + idx = 0 + fn = output_name.format(idx=idx) + max_logging.log(f"Done in {t1 - t0:.1f}s. Saving {fn}") + # bring into PIL format and save + x = x.clip(-1, 1) + x = rearrange(x[0], "c h w -> h w c") + + x = 127.5 * (x + 1.0) + x_numpy = np.array(x.astype(jnp.uint8)) + img = Image.fromarray(x_numpy) + + img.save(fn, quality=95, subsampling=0) + + def calculate_tflops(self, pipeline): + per_device_tflops = maxdiffusion_utils.calculate_flux_tflops( + self.config, pipeline, self.total_train_batch_size, self.rng, train=True + ) + max_logging.log(f"JFLUX per device TFLOPS: {per_device_tflops}") + return per_device_tflops + + def start_training(self): + # Hook + self.pre_training_steps() + # Load checkpoint - will load or create states + pipeline = self.load_checkpoint() + # create train states + train_states = {} + state_shardings = {} + flux_state, flux_state_mesh_shardings, flux_learning_rate_scheduler = self.create_flux_state( + # ambiguous here, but if self.params.get("flux") doesn't exist + # Then its 1 of 2 scenarios: + # 1. flux state will be loaded directly from orbax + # 2. a new flux is being trained from scratch. 
+ flux=pipeline.flux, + init_flux_weights=pipeline.init_flux_weights, + params=None, + is_training=True, + ) + train_states[JfluxCheckpointer.flux_state_item_name] = flux_state + state_shardings["flux_state_shardings"] = flux_state_mesh_shardings + + # Create scheduler + max_logging.log("Creating scheduler") + noise_scheduler, noise_scheduler_state = self.create_scheduler(pipeline, None) + pipeline.scheduler = noise_scheduler + train_states["scheduler"] = noise_scheduler_state + + # Calculate tflops + per_device_tflops = self.calculate_tflops(pipeline) + self.per_device_tflops = per_device_tflops + + # Load dataset + max_logging.log("Loading data set") + data_iterator = self.load_dataset(pipeline) + max_logging.log("Data set loaded") + + data_shardings = self.get_data_shardings() + max_logging.log("Data sharding created") + # Compile train_step + p_train_step = self.compile_train_step(pipeline, train_states, state_shardings, data_shardings) + max_logging.log("Train loop compiled") + # Start training + max_logging.log("Starting training loop") + train_states = self.training_loop(p_train_step, pipeline, train_states, data_iterator, flux_learning_rate_scheduler) + max_logging.log("End training loop") + # 6. save final checkpoint + # Hook + self.post_training_steps(pipeline, train_states[JfluxCheckpointer.flux_state_item_name].params) + + def get_shaped_batch(self, config, pipeline): + """Return the shape of the batch - this is what eval_shape would return for the + output of create_data_iterator_with_tokenizer, but eval_shape doesn't work, see b/306901078. 
+ """ + + if config.dataset_type == "tf" and config.cache_latents_text_encoder_outputs: + scale_factor = 16 # hardcoded in jflux.get_noise + h = config.resolution // scale_factor + w = config.resolution // scale_factor + c = 16 + ph = pw = 2 + batch_image_shape = (self.total_train_batch_size, h * w, c * ph * pw) # b + img_ids_shape = (self.total_train_batch_size, (2 * h // 2) * (2 * w // 2), 3) + text_shape = ( + self.total_train_batch_size, + 256 if config.model_name == "flux-schnell" else 512, + 4096, # Sequence length of text encoder, how to get this programmatically? + ) + text_ids_shape = ( + self.total_train_batch_size, + 256 if config.model_name == "flux-schnell" else 512, + 3, + ) + prompt_embeds_shape = ( + self.total_train_batch_size, + 768, # Sequence length of clip, how to get this programmatically? + ) + input_ids_dtype = self.config.activations_dtype + else: + batch_image_shape = (self.total_train_batch_size, 3, config.resolution, config.resolution) + text_shape = ( + self.total_train_batch_size, + pipeline.t5.max_length, + ) + text_ids_shape = ( + self.total_train_batch_size, + pipeline.t5.max_length, + ) + prompt_embeds_shape = ( + self.total_train_batch_size, + pipeline.clip.max_length, + ) + input_ids_dtype = self.config.activations_dtype + + shaped_batch = {} + shaped_batch["pixel_values"] = jax.ShapeDtypeStruct(batch_image_shape, input_ids_dtype) + shaped_batch["text_embeds"] = jax.ShapeDtypeStruct(text_shape, input_ids_dtype) + shaped_batch["input_ids"] = jax.ShapeDtypeStruct(text_ids_shape, input_ids_dtype) + shaped_batch["prompt_embeds"] = jax.ShapeDtypeStruct(prompt_embeds_shape, input_ids_dtype) + shaped_batch["img_ids"] = jax.ShapeDtypeStruct(img_ids_shape, input_ids_dtype) + return shaped_batch + + def create_scheduler(self, pipeline, params): + noise_scheduler, noise_scheduler_state = FlaxEulerDiscreteScheduler.from_pretrained( + pretrained_model_name_or_path=self.config.pretrained_model_name_or_path, subfolder="scheduler", 
dtype=jnp.float32 + ) + noise_scheduler_state = noise_scheduler.set_timesteps( + state=noise_scheduler_state, num_inference_steps=self.config.num_inference_steps, timestep_spacing="flux" + ) + return noise_scheduler, noise_scheduler_state + + def get_data_shardings(self): + data_sharding = jax.sharding.NamedSharding(self.mesh, P(*self.config.data_sharding)) + data_sharding = { + "text_embeds": data_sharding, + "input_ids": data_sharding, + "prompt_embeds": data_sharding, + "pixel_values": data_sharding, + "img_ids": data_sharding, + } + + return data_sharding + + # adapted from max_utils.tokenize_captions_xl + @staticmethod + def tokenize_captions(examples, caption_column, tokenizer_t5, tokenizer_clip): + prompt = list(examples[caption_column]) + bs = len(prompt) + + text_embeds = tokenizer_t5(prompt) + prompt_embeds = tokenizer_clip(prompt) + + examples["text_embeds"] = jnp.float16(text_embeds) + examples["input_ids"] = jnp.float16(jnp.zeros((bs, text_embeds.shape[1], 3))) + examples["prompt_embeds"] = jnp.float16(prompt_embeds) + + return examples + + @staticmethod + def transform_images( + examples, + image_column, + image_resolution, + encoder, + ): + """Preprocess images to latents.""" + images = list(examples[image_column]) + + images = [ + jax.image.resize( + jnp.asarray(image) / 127.5 - 1.0, [image_resolution, image_resolution, 3], method="bilinear", antialias=True + ) + for image in images + ] + + images = jnp.stack(images, axis=0, dtype=jnp.float16) + batch_size = 8 + num_batches = len(images) // batch_size + int(len(images) % batch_size != 0) + encoded_images = [] + + for i in range(num_batches): + batch_images = images[i * batch_size : (i + 1) * batch_size] + batch_images = rearrange(batch_images, "b h w c -> b c h w") + batch_images = encoder.encode(batch_images) + encoded_images.append(batch_images) + + images = jnp.concatenate(encoded_images, axis=0, dtype=jnp.float16) + + batch_size, _, h, w = images.shape + images = rearrange(images, "b c (h ph) (w 
pw) -> b (h w) (c ph pw)", ph=2, pw=2) + img_ids = jnp.zeros((h // 2, w // 2, 3)) + img_ids = img_ids.at[..., 1].set(jnp.arange(h // 2)[:, None]) + img_ids = img_ids.at[..., 2].set(jnp.arange(w // 2)[None, :]) + img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size) + + examples["pixel_values"] = images + examples["img_ids"] = img_ids + + return examples + + @staticmethod + def encode_image(pipeline, x): + return pipeline.pack_img(pipeline.encode(x)) + + def load_dataset(self, pipeline): + p_encode = None + rng = None + + if self.config.dataset_type == "tf" and self.config.cache_latents_text_encoder_outputs: + ... + + tokenize_fn = partial( + JFluxTrainer.tokenize_captions, + caption_column=self.config.caption_column, + tokenizer_t5=pipeline.t5, + tokenizer_clip=pipeline.clip, + ) + max_logging.log("Creating image transforms") + image_transforms_fn = partial( + JFluxTrainer.transform_images, + image_column=self.config.image_column, + image_resolution=self.config.resolution, + encoder=pipeline.ae, + ) + max_logging.log("Creating data iterator") + data_iterator = make_data_iterator( + self.config, + jax.process_index(), + jax.process_count(), + self.mesh, + self.total_train_batch_size, + tokenize_fn=tokenize_fn, + image_transforms_fn=image_transforms_fn, + ) + return data_iterator + + def compile_train_step(self, pipeline, train_states, state_shardings, data_shardings): + self.rng, train_rngs = jax.random.split(self.rng) + guidance_vec = jnp.full((self.total_train_batch_size,), self.config.guidance, dtype=self.config.activations_dtype) + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + p_train_step = jax.jit( + partial( + _train_step, + guidance_vec=guidance_vec, + pipeline=pipeline, + scheduler=train_states["scheduler"], + config=self.config, + ), + in_shardings=( + state_shardings["flux_state_shardings"], + data_shardings, + None, + ), + out_shardings=(state_shardings["flux_state_shardings"], None, None), + donate_argnums=(0,), + 
) + max_logging.log("Precompiling...") + s = time.time() + dummy_batch = self.get_shaped_batch(self.config, pipeline) + p_train_step = p_train_step.lower(train_states["flux_state"], dummy_batch, train_rngs) + p_train_step = p_train_step.compile() + max_logging.log(f"Compile time: {(time.time() - s )}") + return p_train_step + + def training_loop(self, p_train_step, pipeline, train_states, data_iterator, unet_learning_rate_scheduler): + writer = max_utils.initialize_summary_writer(self.config) + flux_state = train_states["flux_state"] + + num_model_parameters = max_utils.calculate_num_params_from_pytree(flux_state.params) + + max_utils.add_text_to_summary_writer("number_model_parameters", str(num_model_parameters), writer) + max_utils.add_config_to_summary_writer(self.config, writer) + + if jax.process_index() == 0: + max_logging.log("***** Running training *****") + max_logging.log(f" Instantaneous batch size per device = {self.config.per_device_batch_size}") + max_logging.log(f" Total train batch size (w. 
parallel & distributed) = {self.total_train_batch_size}") + max_logging.log(f" Total optimization steps = {self.config.max_train_steps}") + + last_step_completion = datetime.datetime.now() + local_metrics_file = open(self.config.metrics_file, "a", encoding="utf8") if self.config.metrics_file else None + running_gcs_metrics = [] if self.config.gcs_metrics else None + example_batch = None + + first_profiling_step = self.config.skip_first_n_steps_for_profiler + if self.config.enable_profiler and first_profiling_step >= self.config.max_train_steps: + raise ValueError("Profiling requested but initial profiling step set past training final step") + last_profiling_step = jnp.clip( + first_profiling_step + self.config.profiler_steps - 1, first_profiling_step, self.config.max_train_steps - 1 + ) + + start_step = train_utils.get_first_step(train_states["flux_state"]) + _, train_rngs = jax.random.split(self.rng) + times = [] + for step in jnp.arange(start_step, self.config.max_train_steps): + if self.config.enable_profiler and step == first_profiling_step: + max_utils.activate_profiler(self.config) + + example_batch = train_utils.load_next_batch(data_iterator, example_batch, self.config) + example_batch = {key: jnp.asarray(value, dtype=self.config.activations_dtype) for key, value in example_batch.items()} + + with jax.profiler.StepTraceAnnotation("train", step_num=step): + with self.mesh: + flux_state, train_metric, train_rngs = p_train_step(flux_state, example_batch, train_rngs) + samples_count = self.total_train_batch_size * (step + 1) + new_time = datetime.datetime.now() + + train_utils.record_scalar_metrics( + train_metric, new_time - last_step_completion, self.per_device_tflops, unet_learning_rate_scheduler(step) + ) + times.append(new_time - last_step_completion) + if self.config.write_metrics: + train_utils.write_metrics(writer, local_metrics_file, running_gcs_metrics, train_metric, step, self.config) + last_step_completion = new_time + + if step != 0 and 
self.config.checkpoint_every != -1 and samples_count % self.config.checkpoint_every == 0: + max_logging.log(f"Saving checkpoint for step {step}") + train_states["flux_state"] = flux_state + self.save_checkpoint(step, pipeline, train_states) + + if self.config.enable_profiler and step == last_profiling_step: + max_utils.deactivate_profiler(self.config) + + if self.config.write_metrics and start_step < self.config.max_train_steps: + train_utils.write_metrics( + writer, local_metrics_file, running_gcs_metrics, train_metric, self.config.max_train_steps - 1, self.config + ) + + train_states["flux_state"] = flux_state + max_logging.log(f"Average time per step: {sum(times[2:], datetime.timedelta(0)) / len(times[2:])}") + if self.config.save_final_checkpoint: + max_logging.log(f"Saving checkpoint for step {step}") + self.save_checkpoint(step, pipeline, train_states) + self.checkpoint_manager.wait_until_finished() + return train_states + + +def _train_step(flux_state, batch, train_rng, guidance_vec, pipeline, scheduler, config): + _, gen_dummy_rng = jax.random.split(train_rng) + sample_rng, timestep_bias_rng, new_train_rng = jax.random.split(gen_dummy_rng, 3) + state_params = {"flux": flux_state.params} + + def compute_loss(state_params): + latents = batch["pixel_values"] + text_embeds_ids = batch["input_ids"] + text_embeds = batch["text_embeds"] + prompt_embeds = batch["prompt_embeds"] + img_ids = batch["img_ids"] + + # Sample noise that we'll add to the latents + noise_rng, timestep_rng = jax.random.split(sample_rng) + noise = jax.random.normal( + key=noise_rng, + shape=latents.shape, + dtype=latents.dtype, + ) + # Sample a random timestep for each image + bsz = latents.shape[0] + if config.timestep_bias["strategy"] == "none": + timesteps = jax.random.randint(timestep_rng, shape=(bsz,), minval=0, maxval=len(scheduler.timesteps)) + else: + weights = train_utils.generate_timestep_weights(config, pipeline.scheduler.config.num_train_timesteps) + timesteps = 
jax.random.categorical(timestep_bias_rng, logits=jnp.log(weights), shape=(bsz,)) + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = pipeline.scheduler.add_noise(scheduler, latents, noise, timesteps, flux=True) + + # Predict the noise residual and compute loss + # for flux, encoder_hidden_states = (txt, txt_ids, vec) + model_pred = pipeline.flux.apply( + {"params": state_params["flux"]}, + hidden_states=noisy_latents, + img_ids=img_ids, + encoder_hidden_states=text_embeds, + txt_ids=text_embeds_ids, + pooled_projections=prompt_embeds, + timestep=scheduler.timesteps[timesteps], + guidance=guidance_vec, + ).sample + + target = noise - latents + loss = (target - model_pred) ** 2 + + loss = nn.with_logical_constraint(loss, ("activation_embed_and_logits_batch", "activation_length")) + loss = jnp.mean(loss) + + return loss + + grad_fn = jax.value_and_grad(compute_loss) + loss, grad = grad_fn(state_params) + + if config.max_grad_norm > 0: + grad, _ = optax.clip_by_global_norm(config.max_grad_norm).update(grad, flux_state, None) + + new_state = flux_state.apply_gradients(grads=grad["flux"]) + + metrics = {"scalar": {"learning/loss": loss}, "scalars": {}} + + return new_state, metrics, new_train_rng diff --git a/src/maxdiffusion/utils/dynamic_modules_utils.py b/src/maxdiffusion/utils/dynamic_modules_utils.py index f12c0b71f..ec59368e7 100644 --- a/src/maxdiffusion/utils/dynamic_modules_utils.py +++ b/src/maxdiffusion/utils/dynamic_modules_utils.py @@ -25,7 +25,7 @@ from typing import Dict, Optional, Union from urllib import request -from huggingface_hub import HfFolder, cached_download, hf_hub_download, model_info +from huggingface_hub import HfFolder, hf_hub_download, model_info from packaging import version from .. 
import __version__ From aed1f77869d2abb0108ac206e48caa600528d901 Mon Sep 17 00:00:00 2001 From: jcaraban Date: Tue, 28 Jan 2025 10:57:32 +0000 Subject: [PATCH 2/3] split base_jflux.yml into dev/schnell --- src/maxdiffusion/configs/base_jflux_dev.yml | 260 ++++++++++++++++++ ...{base_jflux.yml => base_jflux_schnell.yml} | 4 +- .../models/modeling_flax_pytorch_utils.py | 2 +- 3 files changed, 263 insertions(+), 3 deletions(-) create mode 100644 src/maxdiffusion/configs/base_jflux_dev.yml rename src/maxdiffusion/configs/{base_jflux.yml => base_jflux_schnell.yml} (99%) diff --git a/src/maxdiffusion/configs/base_jflux_dev.yml b/src/maxdiffusion/configs/base_jflux_dev.yml new file mode 100644 index 000000000..0f3d02514 --- /dev/null +++ b/src/maxdiffusion/configs/base_jflux_dev.yml @@ -0,0 +1,260 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This sentinel is a reminder to choose a real run name. +run_name: 'jflux-dev' + +model_name: "flux-dev" +metrics_file: "" # for testing, local file that stores scalar metrics. If empty, no metrics are written. 
+# If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/ +write_metrics: True +gcs_metrics: False +# If true save config to GCS in {base_output_directory}/{run_name}/ +save_config_to_gcs: False +log_period: 100 +save_interval_steps: -1 + +pretrained_model_name_or_path: 'black-forest-labs/FLUX.1-dev' +checkpoint_path: "" +checkpoint_step: -1 +# This will convert the weights to this dtype. +# When running inference on TPUv5e, use weights_dtype: 'bfloat16' +weights_dtype: 'bfloat16' +# This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype) +activations_dtype: 'bfloat16' + +guidance: 4.0 +save_final_checkpoint: False +run_inference_after_training: True + +# matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision +# Options are "DEFAULT", "HIGH", "HIGHEST" +# fp32 activations and fp32 weights with HIGHEST will provide the best precision +# at the cost of time. +precision: "DEFAULT" + +# Set true to load weights from pytorch +from_pt: False +split_head_dim: True +attention: 'dot_product' # Supported attention: dot_product, flash +flash_block_sizes: {} +# GroupNorm groups +norm_num_groups: 32 + +# If train_new_flux, flux weights will be randomly initialized to train from scratch +# else they will be loaded from pretrained_model_name_or_path +train_new_flux: False +revision: '' + +# train text_encoder - Currently not supported for SDXL +train_text_encoder: False +text_encoder_learning_rate: 4.25e-6 + +# https://arxiv.org/pdf/2305.08891.pdf +snr_gamma: -1.0 + +timestep_bias: { + # a value of later will increase the frequence of the model's final training steps. + # none, earlier, later, range + strategy: "none", + # multiplier for bias, a value of 2.0 will double the weight of the bias, 0.5 will halve it. + multiplier: 1.0, + # when using strategy=range, the beginning (inclusive) timestep to bias. 
+ begin: 0, + # when using strategy=range, the final step (inclusive) to bias. + end: 1000, + # portion of timesteps to bias. + # 0.5 will bias one half of the timesteps. Value of strategy determines + # whether the biased portions are in the earlier or later timesteps. + portion: 0.25 +} + +# Override parameters from checkpoints's scheduler. +diffusion_scheduler_config: { + _class_name: '', + # values are v_prediction or leave empty to use scheduler's default. + prediction_type: '', + rescale_zero_terminal_snr: False, + timestep_spacing: '' +} + +# Output directory +# Create a GCS bucket, e.g. my-maxtext-outputs and set this to "gs://my-maxtext-outputs/" +base_output_directory: "" + + +# Parallelism +mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'expert', 'autoregressive'] +logical_axis_rules: [ + ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], + ['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']], + # For pipeline parallelism the pre and post decoder layer tensors' batch dimension is sharded by stages. + # Microbatches are sharded by stage, so moving out of and into this sharding should be a local reshape. + # The "stage" needs to be listed first since the microbatch dimension is first before the reshape. 
+ ['activation_embed_and_logits_batch', ['stage', 'data', 'fsdp', 'fsdp_transpose', 'expert']], + ['activation_heads', ['tensor','sequence']], + ['activation_kv_heads', ['tensor','sequence']], + ['activation_length', 'sequence'], + ['activation_embed', ['tensor', 'fsdp_transpose']], + ['activation_mlp', 'tensor'], + ['activation_kv', 'tensor'], + ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], + ['activation_kv_head_dim', 'tensor'], + ['activation_vocab', ['tensor', 'sequence']], + ['activation_stage', 'stage'], + ['activation_exp', 'expert'], + ['mlp', ['fsdp_transpose', 'tensor', 'autoregressive']], + ['vocab', ['tensor', 'autoregressive']], + ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'expert']], + ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence']], + ['norm', 'tensor'], + ['heads', ['tensor', 'autoregressive', 'fsdp_transpose']], + ['layers', 'stage'], + ['kv', []], + ['kv_heads', ['tensor', 'autoregressive']], + ['kv_head_dim', []], + ['cache_batch', []], + ['cache_heads', ['autoregressive', 'tensor']], + ['cache_kv', []], + ['cache_sequence', []], + ['exp', 'expert'], + ] +# Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details +data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'expert', 'autoregressive']] + +# One axis for each parallelism type may hold a placeholder (-1) +# value to auto-shard based on available slices and devices. +# By default, product of the DCN axes should equal number of slices +# and product of the ICI axes should equal number of devices per slice. 
+dcn_data_parallelism: -1 # recommended DCN axis to be auto-sharded +dcn_fsdp_parallelism: 1 +dcn_fsdp_transpose_parallelism: 1 +dcn_sequence_parallelism: 1 # never recommended +dcn_tensor_parallelism: 1 # never recommended +dcn_pipeline_parallelism: 1 +dcn_expert_parallelism: 1 +dcn_autoregressive_parallelism: 1 # never recommended +ici_data_parallelism: 1 +ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded +ici_fsdp_transpose_parallelism: -1 +ici_sequence_parallelism: 1 +ici_tensor_parallelism: 1 +ici_autoregressive_parallelism: 1 +ici_pipeline_parallelism: 1 +ici_expert_parallelism: 1 + +# The number of TPU slices is automatically determined, you should not set this explicitly. For ahead of time compilation, +# you should set compile_toplogy_num_slices, which will in turn set this value. For non-TPU environments this is set to 1. +num_slices: 1 + + +# Dataset +# Replace with dataset path or train_data_dir. One has to be set. +dataset_name: 'diffusers/pokemon-gpt4-captions' +train_split: 'train' +dataset_type: 'tf' +cache_latents_text_encoder_outputs: True +# cache_latents_text_encoder_outputs only apply to dataset_type="tf", +# only apply to small dataset that fits in memory +# prepare image latents and text encoder outputs +# Reduce memory consumption and reduce step time during training +# transformed dataset is saved at dataset_save_location +dataset_save_location: '/tmp/pokemon-gpt4-captions_dev' +train_data_dir: '' +dataset_config_name: '' +jax_cache_dir: '/dev/shm/jax' +hf_data_dir: '' +hf_train_files: '' +hf_access_token: '' +image_column: 'image' +caption_column: 'text' +resolution: 1024 +center_crop: False +random_flip: False +# If cache_latents_text_encoder_outputs is True +# the num_proc is set to 1 +tokenize_captions_num_proc: 4 +transform_images_num_proc: 4 +reuse_example_batch: False +enable_data_shuffling: True + +# checkpoint every number of samples, -1 means don't checkpoint. 
+checkpoint_every: -1 +# enables one replica to read the ckpt then broadcast to the rest +enable_single_replica_ckpt_restoring: False + +# Training loop +learning_rate: 1.e-5 +scale_lr: False +max_train_samples: -1 +# max_train_steps takes priority over num_train_epochs. +max_train_steps: 50 +num_train_epochs: 1 +seed: 102333 +output_dir: '/workspace/runs' +per_device_batch_size: 2 + +warmup_steps_fraction: 0.0 +cosine_learning_rate_final_fraction: 1.0 +learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps. + +# However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before +# dropping fully down. Or you may choose a shorter schedule, where the unspecified steps will have a learning rate of 0. + +# AdamW optimizer parameters +adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients. +adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients. +adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root. +adam_weight_decay: 0 # AdamW Weight decay +max_grad_norm: 1.0 + +enable_profiler: False +# Skip first n steps for profiling, to omit things like compilation and to give +# the iteration time a chance to stabilize. +skip_first_n_steps_for_profiler: 5 +profiler_steps: 10 + +# Generation parameters +prompt: "A confident Grovyle, a grass-type Pokémon, strikes a dynamic pose with its leafy appendages." +negative_prompt: "purple, red" +do_classifier_free_guidance: True +guidance_scale: 4.0 +# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf +guidance_rescale: 0.0 +num_inference_steps: 50 + +# SDXL Lightning parameters +lightning_from_pt: True +# Empty or "ByteDance/SDXL-Lightning" to enable lightning. +lightning_repo: "" +# Empty or "sdxl_lightning_4step_unet.safetensors" to enable lightning. 
+lightning_ckpt: "" +unet_checkpoint: "" # needed in pyconfig + +enable_mllog: False + +#controlnet +controlnet_model_name_or_path: 'diffusers/controlnet-canny-sdxl-1.0' +controlnet_from_pt: True +controlnet_conditioning_scale: 0.5 +controlnet_image: 'https://upload.wikimedia.org/wikipedia/commons/thumb/c/c1/Google_%22G%22_logo.svg/1024px-Google_%22G%22_logo.svg.png' + +# added from maxtext version +hardware: 'gpu' # Supported hardware types are 'tpu', 'gpu', 'gpu_multiprocess' and 'cpu' +compile_topology: '' +compile_topology_num_slices: -1 # Number of target slices, set to a positive integer. + +custom_mesh: "" # Available options: ['hybrid_ring_64x4'] +# Split physical axes for https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.mesh_utils.create_device_mesh.html +allow_split_physical_axes: False diff --git a/src/maxdiffusion/configs/base_jflux.yml b/src/maxdiffusion/configs/base_jflux_schnell.yml similarity index 99% rename from src/maxdiffusion/configs/base_jflux.yml rename to src/maxdiffusion/configs/base_jflux_schnell.yml index ce65590a1..540702b8e 100644 --- a/src/maxdiffusion/configs/base_jflux.yml +++ b/src/maxdiffusion/configs/base_jflux_schnell.yml @@ -13,7 +13,7 @@ # limitations under the License. # This sentinel is a reminder to choose a real run name. -run_name: 'jflux' +run_name: 'jflux-schnell' model_name: "flux-schnell" metrics_file: "" # for testing, local file that stores scalar metrics. If empty, no metrics are written. 
@@ -170,7 +170,7 @@ cache_latents_text_encoder_outputs: True # prepare image latents and text encoder outputs # Reduce memory consumption and reduce step time during training # transformed dataset is saved at dataset_save_location -dataset_save_location: '/tmp/pokemon-gpt4-captions_xl' +dataset_save_location: '/tmp/pokemon-gpt4-captions_sch' train_data_dir: '' dataset_config_name: '' jax_cache_dir: '/dev/shm/jax' diff --git a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py index fbb8ce992..0bfea42ef 100644 --- a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py +++ b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py @@ -247,7 +247,7 @@ def convert_lora_pytorch_state_dict_to_flax(pt_state_dict, params, network_alpha def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model, init_key=42): # Step 1: Convert pytorch tensor to numpy - pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()} + pt_state_dict = {k: v.float().numpy() for k, v in pt_state_dict.items()} # Step 2: Since the model is stateless, run eval_shape to get the pytree structure random_flax_params = flax_model.init_weights(PRNGKey(init_key), eval_only=True) From 878796f8fc7effdc5fd8a91704f8cc4e89b294eb Mon Sep 17 00:00:00 2001 From: jcaraban Date: Tue, 28 Jan 2025 07:39:34 -0600 Subject: [PATCH 3/3] fix missing 'guidance_in' --> FlaxTimestepEmbedding_1 --- .../base_stable_diffusion_checkpointer.py | 1 - .../checkpointing/jflux_checkpointer.py | 20 ----------------- src/maxdiffusion/generate_jflux.py | 22 ++++++++++++++++++- .../models/modeling_flax_pytorch_utils.py | 5 +++-- 4 files changed, 24 insertions(+), 24 deletions(-) diff --git a/src/maxdiffusion/checkpointing/base_stable_diffusion_checkpointer.py b/src/maxdiffusion/checkpointing/base_stable_diffusion_checkpointer.py index 9c989add5..8f0421b11 100644 --- a/src/maxdiffusion/checkpointing/base_stable_diffusion_checkpointer.py +++ 
b/src/maxdiffusion/checkpointing/base_stable_diffusion_checkpointer.py @@ -43,7 +43,6 @@ STABLE_DIFFUSION_XL_CHECKPOINT = "STABLE_DIFUSSION_XL_CHECKPOINT" _CHECKPOINT_FORMAT_DIFFUSERS = "CHECKPOINT_FORMAT_DIFFUSERS" _CHECKPOINT_FORMAT_ORBAX = "CHECKPOINT_FORMAT_ORBAX" -JFLUX_CHECKPOINT = "JFLUX_CHECKPOINT" class BaseStableDiffusionCheckpointer(ABC): diff --git a/src/maxdiffusion/checkpointing/jflux_checkpointer.py b/src/maxdiffusion/checkpointing/jflux_checkpointer.py index b86aee375..c44282fc0 100644 --- a/src/maxdiffusion/checkpointing/jflux_checkpointer.py +++ b/src/maxdiffusion/checkpointing/jflux_checkpointer.py @@ -33,26 +33,6 @@ ) -def get_device_type(): - """Returns the type of JAX device being used. - - Returns: - str: "gpu", "tpu", or "cpu" - """ - try: - device_kind = jax.devices()[0].device_kind - if "tpu" in device_kind.lower(): - return "tpu" - elif "amd" in device_kind.lower(): - return "rocm" - elif "nvidia" in device_kind.lower(): - return "cuda" - else: - return "cpu" - except IndexError: - return "cpu" # No devices found, likely using CPU - - class JfluxCheckpointer(ABC): flux_state_item_name = "flux_state" config_item_name = "config" diff --git a/src/maxdiffusion/generate_jflux.py b/src/maxdiffusion/generate_jflux.py index dc8e00ba0..3d1f30cd8 100644 --- a/src/maxdiffusion/generate_jflux.py +++ b/src/maxdiffusion/generate_jflux.py @@ -34,8 +34,28 @@ from glob import iglob +def get_device_type(): + """Returns the type of JAX device being used. 
+ + Returns: + str: "gpu", "tpu", or "cpu" + """ + try: + device_kind = jax.devices()[0].device_kind + if "tpu" in device_kind.lower(): + return "tpu" + elif "amd" in device_kind.lower(): + return "rocm" + elif "nvidia" in device_kind.lower(): + return "cuda" + else: + return "cpu" + except IndexError: + return "cpu" # No devices found, likely using CPU + + def run(config): - device_type = jflux_checkpointer.get_device_type() + device_type = get_device_type() max_logging.log(f"Using {device_type} device") output_dir = "output" diff --git a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py index 0bfea42ef..826ac2290 100644 --- a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py +++ b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py @@ -269,12 +269,13 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model, init_key=42): renamed_pt_key = renamed_pt_key.replace("txt_mod", "txt_norm1") renamed_pt_key = renamed_pt_key.replace("img_attn.qkv", "attn.i_qkv") renamed_pt_key = renamed_pt_key.replace("img_attn.proj", "attn.i_proj") + renamed_pt_key = renamed_pt_key.replace("img_attn.norm", "attn") renamed_pt_key = renamed_pt_key.replace("txt_attn.qkv", "attn.e_qkv") renamed_pt_key = renamed_pt_key.replace("txt_attn.proj", "attn.e_proj") - - renamed_pt_key = renamed_pt_key.replace("img_attn.norm", "attn") renamed_pt_key = renamed_pt_key.replace("txt_attn.norm.key_norm", "attn.encoder_key_norm") renamed_pt_key = renamed_pt_key.replace("txt_attn.norm.query_norm", "attn.encoder_query_norm") + elif("guidance_in" in renamed_pt_key): + renamed_pt_key = renamed_pt_key.replace("guidance_in", "time_text_embed.FlaxTimestepEmbedding_1") elif "single_blocks" in renamed_pt_key: renamed_pt_key = renamed_pt_key.replace("single_blocks_", "single_blocks.layers_") renamed_pt_key = renamed_pt_key.replace("modulation", "norm")