From bc6cd42f21668f78b3eeb10d490ccde69905993e Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Tue, 14 Jan 2025 02:06:12 +0000 Subject: [PATCH 01/35] add support for flux vae. ~ wip --- src/maxdiffusion/models/embeddings_flax.py | 1 - src/maxdiffusion/models/flux/__init__.py | 2 +- .../models/flux/modules/__init__.py | 15 + .../models/flux/modules/layers.py | 95 +++ .../models/flux/transformers/__init__.py | 2 +- .../transformers/transformer_flux_flax.py | 608 +++--------------- src/maxdiffusion/models/modeling_utils.py | 1 - src/maxdiffusion/models/vae_flax.py | 4 +- src/maxdiffusion/tests/vae_test.py | 45 +- 9 files changed, 200 insertions(+), 573 deletions(-) create mode 100644 src/maxdiffusion/models/flux/modules/__init__.py create mode 100644 src/maxdiffusion/models/flux/modules/layers.py diff --git a/src/maxdiffusion/models/embeddings_flax.py b/src/maxdiffusion/models/embeddings_flax.py index 42ca4b950..a42418f20 100644 --- a/src/maxdiffusion/models/embeddings_flax.py +++ b/src/maxdiffusion/models/embeddings_flax.py @@ -56,7 +56,6 @@ def get_sinusoidal_embeddings( signal = jnp.reshape(signal, [jnp.shape(timesteps)[0], embedding_dim]) return signal - class FlaxTimestepEmbedding(nn.Module): r""" Time step Embedding Module. Learns embeddings for input time steps. diff --git a/src/maxdiffusion/models/flux/__init__.py b/src/maxdiffusion/models/flux/__init__.py index 84dd0f150..6d7590d6d 100644 --- a/src/maxdiffusion/models/flux/__init__.py +++ b/src/maxdiffusion/models/flux/__init__.py @@ -14,4 +14,4 @@ limitations under the License. """ -from .transformers.transformer_flux_flax import FluxTransformer2DModel +from .transformers.transformer_flux_flax import FluxTransformer2DModel \ No newline at end of file diff --git a/src/maxdiffusion/models/flux/modules/__init__.py b/src/maxdiffusion/models/flux/modules/__init__.py new file mode 100644 index 000000000..55bca151a --- /dev/null +++ b/src/maxdiffusion/models/flux/modules/__init__.py @@ -0,0 +1,15 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + """ \ No newline at end of file diff --git a/src/maxdiffusion/models/flux/modules/layers.py b/src/maxdiffusion/models/flux/modules/layers.py new file mode 100644 index 000000000..3e4d5f083 --- /dev/null +++ b/src/maxdiffusion/models/flux/modules/layers.py @@ -0,0 +1,95 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ """ + +import math +from dataclasses import dataclass +import jax +import jax.numpy as jnp +from chex import Array +from jax.typing import DTypeLike +import flax.linen as nn + +def timestep_embedding( + t: Array, dim: int, max_period=10000, time_factor: float = 1000.0 +) -> Array: + """ + Generate timestep embeddings. + + Args: + t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + dim: the dimension of the output. + max_period: controls the minimum frequency of the embeddings. + time_factor: Tensor of positional embeddings. + + Returns: + timestep embeddings. + """ + t = time_factor * t + half = dim // 2 + + freqs = jnp.exp( + -math.log(max_period) * jnp.arange(start=0, stop=half, dtype=jnp.float32) / half + ).astype(dtype=t.dtype) + + args = t[:, None].astype(jnp.float32) * freqs[None] + embedding = jnp.concatenate([jnp.cos(args), jnp.sin(args)], axis=-1) + + if dim % 2: + embedding = jnp.concatenate( + [embedding, jnp.zeros_like(embedding[:, :1])], axis=-1 + ) + + if jnp.issubdtype(t.dtype, jnp.floating): + embedding = embedding.astype(t.dtype) + + return embedding + + +class MLPEmbedder(nn.Module): + hidden_dim: int + dtype: jnp.dtype = jnp.float32 + weights_dtype: jnp.dtype = jnp.float32 + precision: jax.lax.Precision = None + + @nn.compact + def __call__(self, x: Array) -> Array: + + x = nn.Dense( + self.hidden_dim, + use_bias=True, + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, + kernel_init=nn.with_logical_partitioning( + nn.initializers.lecun_normal(), + ("embed", "heads") + ) + )(x) + x = nn.silu(x) + x = nn.Dense( + self.hidden_dim, + use_bias=True, + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, + kernel_init=nn.with_logical_partitioning( + nn.initializers.lecun_normal(), + ("heads", "embed") + ) + )(x) + + return x \ No newline at end of file diff --git a/src/maxdiffusion/models/flux/transformers/__init__.py b/src/maxdiffusion/models/flux/transformers/__init__.py index 7e4185f36..55bca151a 100644 --- a/src/maxdiffusion/models/flux/transformers/__init__.py +++ b/src/maxdiffusion/models/flux/transformers/__init__.py @@ -12,4 +12,4 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. - """ + """ \ No newline at end of file diff --git a/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py b/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py index 5035e36e4..5645a84e4 100644 --- a/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py +++ b/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py @@ -14,318 +14,35 @@ limitations under the License. """ -from typing import Tuple +from typing import Dict, Optional, Tuple, Union + import jax -import math import jax.numpy as jnp -import flax -import flax.linen as nn -from einops import repeat, rearrange -from ....configuration_utils import ConfigMixin, flax_register_to_config +import flax.linen as nn +from chex import Array + +from ..modules.layers import timestep_embedding, MLPEmbedder from ...modeling_flax_utils import FlaxModelMixin -from ...normalization_flax import AdaLayerNormZeroSingle, AdaLayerNormContinuous, AdaLayerNormZero -from ...attention_flax import FlaxFluxAttention -from ...embeddings_flax import (FluxPosEmbed, CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings) -from .... 
import common_types +from ....configuration_utils import ConfigMixin, flax_register_to_config from ....common_types import BlockSizes -from ....utils import BaseOutput - -AxisNames = common_types.AxisNames -BATCH = common_types.BATCH -LENGTH = common_types.LENGTH -HEAD = common_types.HEAD -D_KV = common_types.D_KV - - -@flax.struct.dataclass -class Transformer2DModelOutput(BaseOutput): - """ - The output of [`FluxTransformer2DModel`]. - - Args: - sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`): - The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. - """ - - sample: jnp.ndarray - - -class FluxSingleTransformerBlock(nn.Module): - r""" - A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3. - - Reference: https://arxiv.org/abs/2403.03206 - - Parameters: - dim (`int`): The number of channels in the input and output. - num_attention_heads (`int`): The number of heads to use for multi-head attention. - attention_head_dim (`int`): The number of channels in each head. - context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the - processing of `context` conditions. - """ - - dim: int - num_attention_heads: int - attention_head_dim: int - mlp_ratio: int = 4.0 - attention_kernel: str = "dot_product" - flash_min_seq_length: int = 4096 - flash_block_sizes: BlockSizes = None - mesh: jax.sharding.Mesh = None - dtype: jnp.dtype = jnp.float32 - weights_dtype: jnp.dtype = jnp.float32 - precision: jax.lax.Precision = None - - def setup(self): - self.mlp_hidden_dim = int(self.dim * self.mlp_ratio) - - self.norm = AdaLayerNormZeroSingle( - self.dim, dtype=self.dtype, weights_dtype=self.weights_dtype, precision=self.precision - ) - - self.linear1 = nn.Dense( - self.dim * 3 + self.mlp_hidden_dim, - kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")), - bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("mlp",)), - dtype=self.dtype, - param_dtype=self.weights_dtype, - precision=self.precision, - ) - - self.mlp_act = nn.gelu - self.linear2 = nn.Dense( - self.dim, - kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("mlp", "embed")), - bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("embed",)), - dtype=self.dtype, - param_dtype=self.weights_dtype, - precision=self.precision, - ) - self.attn = FlaxFluxAttention( - query_dim=self.dim, - heads=self.num_attention_heads, - dim_head=self.attention_head_dim, - dtype=self.dtype, - weights_dtype=self.weights_dtype, - attention_kernel=self.attention_kernel, - mesh=self.mesh, - flash_block_sizes=self.flash_block_sizes, - ) - - def __call__(self, hidden_states, temb, image_rotary_emb=None): - residual = hidden_states - norm_hidden_states, gate = self.norm(hidden_states, emb=temb) - qkv, mlp = jnp.split(self.linear1(norm_hidden_states), [3 * self.dim], axis=-1) - mlp = nn.with_logical_constraint(mlp, ("activation_batch", "activation_length", "activation_embed")) - qkv = nn.with_logical_constraint(qkv, ("activation_batch", "activation_length", "activation_embed")) - - B, L = hidden_states.shape[:2] - H, D, K = self.num_attention_heads, qkv.shape[-1] // (self.num_attention_heads * 3), 3 - qkv_proj = qkv.reshape(B, L, K, H, D).transpose(2, 0, 3, 1, 4) - q, k, v = qkv_proj - - q = self.attn.query_norm(q) - k = self.attn.key_norm(k) - - if image_rotary_emb is not None: - # since this function returns image_rotary_emb and passes it between layers, 
- # we do not want to modify it - image_rotary_emb_reordered = rearrange(image_rotary_emb, "n d (i j) -> n d i j", i=2, j=2) - q, k = self.attn.apply_rope(q, k, image_rotary_emb_reordered) - - q = q.transpose(0, 2, 1, 3).reshape(q.shape[0], q.shape[2], -1) - k = k.transpose(0, 2, 1, 3).reshape(k.shape[0], k.shape[2], -1) - v = v.transpose(0, 2, 1, 3).reshape(v.shape[0], v.shape[2], -1) - - attn_output = self.attn.attention_op.apply_attention(q, k, v) - - attn_mlp = jnp.concatenate([attn_output, self.mlp_act(mlp)], axis=2) - attn_mlp = nn.with_logical_constraint(attn_mlp, ("activation_batch", "activation_length", "activation_embed")) - hidden_states = self.linear2(attn_mlp) - hidden_states = gate * hidden_states - hidden_states = residual + hidden_states - if hidden_states.dtype == jnp.float16: - hidden_states = jnp.clip(hidden_states, -65504, 65504) - - return hidden_states, temb, image_rotary_emb - - -class FluxTransformerBlock(nn.Module): - r""" - A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3. - - Reference: https://arxiv.org/abs/2403.03206 - - Parameters: - dim (`int`): The number of channels in the input and output. - num_attention_heads (`int`): The number of heads to use for multi-head attention. - attention_head_dim (`int`): The number of channels in each head. - context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the - processing of `context` conditions. - """ - - dim: int - num_attention_heads: int - attention_head_dim: int - qk_norm: str = "rms_norm" - eps: int = 1e-6 - flash_min_seq_length: int = 4096 - flash_block_sizes: BlockSizes = None - mesh: jax.sharding.Mesh = None - dtype: jnp.dtype = jnp.float32 - weights_dtype: jnp.dtype = jnp.float32 - precision: jax.lax.Precision = None - mlp_ratio: float = 4.0 - qkv_bias: bool = False - attention_kernel: str = "dot_product" - - def setup(self): - - self.img_norm1 = AdaLayerNormZero(self.dim, dtype=self.dtype, weights_dtype=self.weights_dtype, precision=self.precision) - self.txt_norm1 = AdaLayerNormZero(self.dim, dtype=self.dtype, weights_dtype=self.weights_dtype, precision=self.precision) - - self.attn = FlaxFluxAttention( - query_dim=self.dim, - heads=self.num_attention_heads, - dim_head=self.attention_head_dim, - qkv_bias=self.qkv_bias, - dtype=self.dtype, - weights_dtype=self.weights_dtype, - attention_kernel=self.attention_kernel, - mesh=self.mesh, - flash_block_sizes=self.flash_block_sizes, - ) - - self.img_norm2 = nn.LayerNorm( - use_bias=False, - use_scale=False, - epsilon=self.eps, - dtype=self.dtype, - param_dtype=self.weights_dtype, - ) - self.img_mlp = nn.Sequential( - [ - nn.Dense( - int(self.dim * self.mlp_ratio), - use_bias=True, - kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")), - bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("mlp",)), - dtype=self.dtype, - param_dtype=self.weights_dtype, - precision=self.precision, - ), - nn.gelu, - nn.Dense( - self.dim, - use_bias=True, - kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")), - bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("mlp",)), - dtype=self.dtype, - param_dtype=self.weights_dtype, - precision=self.precision, - ), - ] - ) - - self.txt_norm2 = nn.LayerNorm( - use_bias=False, - use_scale=False, - epsilon=self.eps, - dtype=self.dtype, - param_dtype=self.weights_dtype, - ) - self.txt_mlp = nn.Sequential( - [ - nn.Dense( - int(self.dim * self.mlp_ratio), - use_bias=True, - 
kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")), - bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("mlp",)), - dtype=self.dtype, - param_dtype=self.weights_dtype, - precision=self.precision, - ), - nn.gelu, - nn.Dense( - self.dim, - use_bias=True, - kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")), - bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("mlp",)), - dtype=self.dtype, - param_dtype=self.weights_dtype, - precision=self.precision, - ), - ] - ) - - # let chunk size default to None - self._chunk_size = None - self._chunk_dim = 0 - - def __call__(self, hidden_states, encoder_hidden_states, temb, image_rotary_emb=None): - norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.img_norm1(hidden_states, emb=temb) - - norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.txt_norm1( - encoder_hidden_states, emb=temb - ) - - # Attention. - attn_output, context_attn_output = self.attn( - hidden_states=norm_hidden_states, - encoder_hidden_states=norm_encoder_hidden_states, - image_rotary_emb=image_rotary_emb, - ) - - attn_output = gate_msa * attn_output - hidden_states = hidden_states + attn_output - norm_hidden_states = self.img_norm2(hidden_states) - norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp - - ff_output = self.img_mlp(norm_hidden_states) - ff_output = gate_mlp * ff_output - - hidden_states = hidden_states + ff_output - # Process attention outputs for the `encoder_hidden_states`. - context_attn_output = c_gate_msa * context_attn_output - encoder_hidden_states = encoder_hidden_states + context_attn_output - - norm_encoder_hidden_states = self.txt_norm2(encoder_hidden_states) - norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp - context_ff_output = self.txt_mlp(norm_encoder_hidden_states) - encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output - if encoder_hidden_states.dtype == jnp.float16: - encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) - return hidden_states, encoder_hidden_states, temb, image_rotary_emb +class Identity(nn.Module): + def __call__(self, x: Array) -> Array: + return x - -@flax_register_to_config class FluxTransformer2DModel(nn.Module, FlaxModelMixin, ConfigMixin): r""" - The Tranformer model introduced in Flux. - - Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + The Tranformer model introduced in Flux. - This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for it's generic methods - implemented for all models (such as downloading or saving). + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ - This model is also a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) - subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matters related to its - general usage and behavior. - - Parameters: - patch_size (`int`): Patch size to turn the input data into small patches. - in_channels (`int`, *optional*, defaults to 16): The number of channels in the input. - num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use. - num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use. - attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. 
- num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention. - joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. - pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`. - guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings. + This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for it's generic methods + implemented for all models (such as downloading or saving). + This model is also a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) + subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matters related to its + general usage and behavior. """ - patch_size: int = 1 in_channels: int = 64 num_layers: int = 19 @@ -342,253 +59,76 @@ class FluxTransformer2DModel(nn.Module, FlaxModelMixin, ConfigMixin): dtype: jnp.dtype = jnp.float32 weights_dtype: jnp.dtype = jnp.float32 precision: jax.lax.Precision = None - mlp_ratio: float = 4.0 - qkv_bias: bool = True - theta: int = 1000 - attention_kernel: str = "dot_product" - eps = 1e-6 def setup(self): self.out_channels = self.in_channels - self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim + self.inner_dim = self.num_attention_heads * self.attention_head_dim - self.pe_embedder = FluxPosEmbed(theta=self.theta, axes_dim=self.axes_dims_rope, dtype=self.dtype) - - text_time_guidance_cls = ( - CombinedTimestepGuidanceTextProjEmbeddings if self.guidance_embeds else CombinedTimestepTextProjEmbeddings - ) - - self.time_text_embed = text_time_guidance_cls( - embedding_dim=self.inner_dim, - pooled_projection_dim=self.pooled_projection_dim, - dtype=self.dtype, - weights_dtype=self.weights_dtype, - precision=self.precision, - ) - self.txt_in = nn.Dense( - self.inner_dim, - kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), (None, "mlp")), - bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("mlp",)), - dtype=self.dtype, - param_dtype=self.weights_dtype, - precision=self.precision, - ) self.img_in = nn.Dense( - self.inner_dim, - kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), (None, "mlp")), - bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("mlp",)), - dtype=self.dtype, - param_dtype=self.weights_dtype, - precision=self.precision, - ) - - double_blocks = [] - for _ in range(self.num_layers): - double_block = FluxTransformerBlock( - dim=self.inner_dim, - num_attention_heads=self.num_attention_heads, - attention_head_dim=self.attention_head_dim, - attention_kernel=self.attention_kernel, - flash_min_seq_length=self.flash_min_seq_length, - flash_block_sizes=self.flash_block_sizes, - mesh=self.mesh, - dtype=self.dtype, - weights_dtype=self.weights_dtype, - precision=self.precision, - mlp_ratio=self.mlp_ratio, - qkv_bias=self.qkv_bias, + self.inner_dim, + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, + kernel_init=nn.with_logical_partitioning( + nn.initializers.lecun_normal(), + ("embed", "heads") ) - double_blocks.append(double_block) - self.double_blocks = double_blocks + ) - single_blocks = [] - for _ in range(self.num_single_layers): - single_block = FluxSingleTransformerBlock( - dim=self.inner_dim, - num_attention_heads=self.num_attention_heads, - attention_head_dim=self.attention_head_dim, - attention_kernel=self.attention_kernel, - 
flash_min_seq_length=self.flash_min_seq_length, - flash_block_sizes=self.flash_block_sizes, - mesh=self.mesh, - dtype=self.dtype, - weights_dtype=self.weights_dtype, - precision=self.precision, - mlp_ratio=self.mlp_ratio, - ) - single_blocks.append(single_block) + self.time_in = MLPEmbedder( + hidden_dim=self.inner_dim, + dtype=self.dtype, + weights_dtype=self.weights_dtype, + precision=self.precision + ) - self.single_blocks = single_blocks + self.vector_in = MLPEmbedder( + hidden_dim=self.inner_dim, + dtype=self.dtype, + weights_dtype=self.weights_dtype, + precision=self.precision + ) - self.norm_out = AdaLayerNormContinuous( - self.inner_dim, - elementwise_affine=False, - eps=self.eps, + self.guidance_in = ( + MLPEmbedder( + hidden_dim=self.inner_dim, dtype=self.dtype, weights_dtype=self.weights_dtype, - precision=self.precision, + precision=self.precision + ) + if self.guidance_embeds + else Identity() ) - self.proj_out = nn.Dense( - self.patch_size**2 * self.out_channels, - kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("mlp", None)), - bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)), - dtype=self.dtype, - param_dtype=self.weights_dtype, - precision=self.precision, - use_bias=True, + self.txt_in = nn.Dense( + self.inner_dim, + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision ) - - def timestep_embedding(self, t: jax.Array, dim: int, max_period=10000, time_factor: float = 1000.0) -> jax.Array: - """ - Generate timestep embeddings. - - Args: - t: a 1-D Tensor of N indices, one per batch element. - These may be fractional. - dim: the dimension of the output. - max_period: controls the minimum frequency of the embeddings. - time_factor: Tensor of positional embeddings. - - Returns: - timestep embeddings. 
- """ - t = time_factor * t - half = dim // 2 - - freqs = jnp.exp(-math.log(max_period) * jnp.arange(start=0, stop=half, dtype=jnp.bfloat16) / half).astype(dtype=t.dtype) - - args = t[:, None].astype(jnp.bfloat16) * freqs[None] - embedding = jnp.concatenate([jnp.cos(args), jnp.sin(args)], axis=-1) - - if dim % 2: - embedding = jnp.concatenate([embedding, jnp.zeros_like(embedding[:, :1])], axis=-1) - - if jnp.issubdtype(t.dtype, jnp.floating): - embedding = embedding.astype(t.dtype) - - return embedding - + def __call__( - self, - hidden_states, - encoder_hidden_states, - pooled_projections, - timestep, - img_ids, - txt_ids, - guidance, - return_dict: bool = True, - train: bool = False, - ): - hidden_states = self.img_in(hidden_states) - timestep = self.timestep_embedding(timestep, 256) - if self.guidance_embeds: - guidance = self.timestep_embedding(guidance, 256) - else: - guidance = None - temb = ( - self.time_text_embed(timestep, pooled_projections) - if guidance is None - else self.time_text_embed(timestep, guidance, pooled_projections) - ) - encoder_hidden_states = self.txt_in(encoder_hidden_states) - if txt_ids.ndim == 3: - txt_ids = txt_ids[0] - if img_ids.ndim == 3: - img_ids = img_ids[0] - - ids = jnp.concatenate((txt_ids, img_ids), axis=0) - ids = nn.with_logical_constraint(ids, ("activation_batch", None)) - image_rotary_emb = self.pe_embedder(ids) - image_rotary_emb = nn.with_logical_constraint(image_rotary_emb, ("activation_batch", "activation_embed")) + self, + img: Array, + img_ids: Array, + txt: Array, + txt_ids: Array, + timesteps: Array, + y: Array, + guidance: Array | None = None, + return_dict: bool = True, + train: bool = False): + + img = self.img_in(img) + vec = self.time_in(timestep_embedding(timesteps, 256)) - for double_block in self.double_blocks: - hidden_states, encoder_hidden_states, temb, image_rotary_emb = double_block( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - temb=temb, - image_rotary_emb=image_rotary_emb, - ) - hidden_states = jnp.concatenate([encoder_hidden_states, hidden_states], axis=1) - hidden_states = nn.with_logical_constraint(hidden_states, ("activation_batch", "activation_length", "activation_embed")) - for single_block in self.single_blocks: - hidden_states, temb, image_rotary_emb = single_block( - hidden_states=hidden_states, temb=temb, image_rotary_emb=image_rotary_emb - ) - hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...] - - hidden_states = self.norm_out(hidden_states, temb) - output = self.proj_out(hidden_states) - - if not return_dict: - return (output,) - - return Transformer2DModelOutput(sample=output) - - def init_weights(self, rngs, max_sequence_length, eval_only=True): - scale_factor = 16 - resolution = 1024 - num_devices = len(jax.devices()) - batch_size = 1 * num_devices - batch_image_shape = ( - batch_size, - 16, # 16 to match jflux.get_noise - 2 * resolution // scale_factor, - 2 * resolution // scale_factor, - ) - # bs, encoder_input, seq_length - text_shape = ( - batch_size, - max_sequence_length, - 4096, # Sequence length of text encoder, how to get this programmatically? - ) - text_ids_shape = ( - batch_size, - max_sequence_length, - 3, # Hardcoded to match jflux.prepare - ) - vec_shape = ( - batch_size, - 768, # Sequence length of clip, how to get this programmatically? 
-    )
-    img = jnp.zeros(batch_image_shape, dtype=self.dtype)
-    bs, _, h, w = img.shape
-    img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
-    img_ids = jnp.zeros((h // 2, w // 2, 3), dtype=self.dtype)
-    img_ids = img_ids.at[..., 1].set(jnp.arange(h // 2)[:, None])
-    img_ids = img_ids.at[..., 2].set(jnp.arange(w // 2)[None, :])
-    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
-
-    txt = jnp.zeros(text_shape, dtype=self.dtype)
-    txt_ids = jnp.zeros(text_ids_shape, dtype=self.dtype)
-
-    t_vec = jnp.full(bs, 0, dtype=self.dtype)
-
-    vec = jnp.zeros(vec_shape, dtype=self.dtype)
-
-    guidance_vec = jnp.full(bs, 4.0, dtype=self.dtype)
-
-    if eval_only:
-      return jax.eval_shape(
-          self.init,
-          rngs,
-          hidden_states=img,
-          img_ids=img_ids,
-          encoder_hidden_states=txt,
-          txt_ids=txt_ids,
-          pooled_projections=vec,
-          timestep=t_vec,
-          guidance=guidance_vec,
-      )["params"]
-    else:
-      return self.init(
-          rngs,
-          hidden_states=img,
-          img_ids=img_ids,
-          encoder_hidden_states=txt,
-          txt_ids=txt_ids,
-          pooled_projections=vec,
-          timestep=t_vec,
-          guidance=guidance_vec,
-      )["params"]
+    if self.guidance_embeds:
+      if guidance is None:
+        raise ValueError(
+            "Guidance strength is required for a guidance-distilled model."
+        )
+
+      vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
+
+    vec = vec + self.vector_in(y)
+    txt = self.txt_in(txt)
\ No newline at end of file
diff --git a/src/maxdiffusion/models/modeling_utils.py b/src/maxdiffusion/models/modeling_utils.py
index 3bf54107f..8d0ffe5e4 100644
--- a/src/maxdiffusion/models/modeling_utils.py
+++ b/src/maxdiffusion/models/modeling_utils.py
@@ -109,7 +109,6 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
       return torch.load(checkpoint_file, map_location="cpu")
     else:
       from safetensors import torch as safetensors_torch
-
       return safetensors_torch.load_file(checkpoint_file, device="cpu")
   except Exception as e:
     try:
diff --git a/src/maxdiffusion/models/vae_flax.py b/src/maxdiffusion/models/vae_flax.py
index dc9b00630..b7bc3e4d4 100644
--- a/src/maxdiffusion/models/vae_flax.py
+++ b/src/maxdiffusion/models/vae_flax.py
@@ -891,10 +891,10 @@ def encode(self, sample, deterministic: bool = True, return_dict: bool = True):
       sample = jnp.transpose(sample, (0, 2, 3, 1))
 
     hidden_states = self.encoder(sample, deterministic=deterministic)
-    moments = hidden_states
+    moments = None
     if self.use_quant_conv:
       moments = self.quant_conv(hidden_states)
-    posterior = FlaxDiagonalGaussianDistribution(moments)
+    posterior = FlaxDiagonalGaussianDistribution(moments if moments is not None else hidden_states)
 
     if not return_dict:
       return (posterior,)
diff --git a/src/maxdiffusion/tests/vae_test.py b/src/maxdiffusion/tests/vae_test.py
index cf7fb399d..858801248 100644
--- a/src/maxdiffusion/tests/vae_test.py
+++ b/src/maxdiffusion/tests/vae_test.py
@@ -16,7 +16,6 @@
 
 import os
 import unittest
-import pytest
 from absl.testing import absltest
 
 import numpy as np
@@ -24,52 +23,32 @@
 import jax
 import jax.numpy as jnp
 from maxdiffusion import FlaxAutoencoderKL
-from skimage.metrics import structural_similarity as ssim
 
 THIS_DIR = os.path.dirname(os.path.abspath(__file__))
 
-IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true"
-
 
 class VaeTest(unittest.TestCase):
   """Test Vae"""
 
   def setUp(self):
     VaeTest.dummy_data = {}
-
-  @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions")
+
   def test_flux_vae(self):
-
+
     img_url = os.path.join(THIS_DIR, "images", "test_hyper_sdxl.png")
     base_image = np.array(Image.open(img_url)).astype(np.uint8)
-    img_min = np.min(base_image)
-    img_max = np.max(base_image)
-    image = (base_image - img_min) / (img_max - img_min)
-    image = 2.0 * image - 1.0
-    image = np.expand_dims(image, 0)
-    image = np.transpose(image, (0, 3, 1, 2))  # (1, 3, 1024, 1024), BCWH
-
+    base_image = np.expand_dims(base_image, 0)
+    base_image = np.transpose(base_image, (0, 3, 1, 2)) # (1, 3, 1024, 1024), BCWH
+
     vae, vae_params = FlaxAutoencoderKL.from_pretrained(
-        "black-forest-labs/FLUX.1-dev", subfolder="vae", from_pt=True, use_safetensors=True, dtype="bfloat16"
+        "black-forest-labs/FLUX.1-dev",
+        subfolder="vae",
+        from_pt=True,
+        use_safetensors=True,
+        dtype="bfloat16"
     )
 
-    encoded_image = vae.apply({"params": vae_params}, image, deterministic=True, method=vae.encode)
+    encoded_image = vae.apply({"params" : vae_params}, base_image, deterministic=True, method=vae.encode)
     latents = encoded_image[0].sample(jax.random.key(0))
     latents = jnp.transpose(latents, (0, 3, 1, 2))
-    latents = (latents - vae.config.shift_factor) * vae.config.scaling_factor
-
-    assert latents.shape == (1, 16, 128, 128)
-
-    # decode back
-    latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
-    image = vae.apply({"params": vae_params}, latents, deterministic=True, method=vae.decode).sample[0]
-    image = np.array(image)
-    image = (image * 0.5 + 0.5).clip(0, 1)
-    image = np.transpose(image, (1, 2, 0))
-    image = np.uint8(image * 255)
-    ssim_compare = ssim(base_image, image, multichannel=True, channel_axis=-1, data_range=255)
-    assert ssim_compare >= 0.90
-
-
-if __name__ == "__main__":
-  absltest.main()
+    assert latents.shape == (1, 16, 128, 128)
\ No newline at end of file

From 5f5625726d4a0dcaad8a45242c2425af5d2e4e7b Mon Sep 17 00:00:00 2001
From: Juan Acevedo
Date: Tue, 14 Jan 2025 20:43:42 +0000
Subject: [PATCH 02/35] test for flux vae both encoding and decoding.
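Reviewer note: the round trip this test exercises hinges on Flux's latent
(de)normalization around the VAE. A minimal sketch of the two transforms,
assuming the `shift_factor` and `scaling_factor` fields exposed by the
FLUX.1-dev VAE config (the helper names here are illustrative, not part of
the diff):

    import jax.numpy as jnp

    def normalize_latents(latents, shift_factor, scaling_factor):
        # Map raw posterior samples into the scaled latent space the
        # Flux transformer is trained on.
        return (latents - shift_factor) * scaling_factor

    def denormalize_latents(latents, shift_factor, scaling_factor):
        # Exact inverse, applied before handing latents back to vae.decode.
        return latents / scaling_factor + shift_factor

The decoded [-1, 1] image is then mapped back to uint8 via
(x * 0.5 + 0.5) * 255 before the SSIM comparison against the source image.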
---
 src/maxdiffusion/tests/vae_test.py | 28 ++++++++++++++++++++++++----
 1 file changed, 24 insertions(+), 4 deletions(-)

diff --git a/src/maxdiffusion/tests/vae_test.py b/src/maxdiffusion/tests/vae_test.py
index 858801248..8370ff5a3 100644
--- a/src/maxdiffusion/tests/vae_test.py
+++ b/src/maxdiffusion/tests/vae_test.py
@@ -23,6 +23,8 @@
 import jax
 import jax.numpy as jnp
 from maxdiffusion import FlaxAutoencoderKL
+from maxdiffusion.image_processor import VaeImageProcessor
+from skimage.metrics import structural_similarity as ssim
 
 THIS_DIR = os.path.dirname(os.path.abspath(__file__))
 
@@ -36,8 +38,12 @@ def test_flux_vae(self):
 
     img_url = os.path.join(THIS_DIR, "images", "test_hyper_sdxl.png")
     base_image = np.array(Image.open(img_url)).astype(np.uint8)
-    base_image = np.expand_dims(base_image, 0)
-    base_image = np.transpose(base_image, (0, 3, 1, 2)) # (1, 3, 1024, 1024), BCWH
+    img_min = np.min(base_image)
+    img_max = np.max(base_image)
+    image = (base_image - img_min) / (img_max - img_min)
+    image = 2.0 * image - 1.0
+    image = np.expand_dims(image, 0)
+    image = np.transpose(image, (0, 3, 1, 2)) # (1, 3, 1024, 1024), BCWH
 
     vae, vae_params = FlaxAutoencoderKL.from_pretrained(
         "black-forest-labs/FLUX.1-dev",
@@ -47,8 +53,22 @@ def test_flux_vae(self):
         dtype="bfloat16"
     )
 
-    encoded_image = vae.apply({"params" : vae_params}, base_image, deterministic=True, method=vae.encode)
+    encoded_image = vae.apply({"params" : vae_params}, image, deterministic=True, method=vae.encode)
     latents = encoded_image[0].sample(jax.random.key(0))
     latents = jnp.transpose(latents, (0, 3, 1, 2))
 
-    assert latents.shape == (1, 16, 128, 128)
\ No newline at end of file
+    latents = (latents - vae.config.shift_factor) * vae.config.scaling_factor
+
+    assert latents.shape == (1, 16, 128, 128)
+
+    # decode back
+    latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
+    image = vae.apply({"params" : vae_params}, latents, deterministic=True, method=vae.decode).sample[0]
+    image = np.array(image)
+    image = (image * 0.5 + 0.5).clip(0, 1)
+    image = np.transpose(image, (1, 2, 0))
+    image = np.uint8(image * 255)
+    ssim_compare = ssim(base_image, image, multichannel=True, channel_axis=-1, data_range=255)
+    assert ssim_compare >= 0.90
+
+

From c7829d1bb772d12e2269a354562e0dd007902d82 Mon Sep 17 00:00:00 2001
From: Juan Acevedo
Date: Wed, 15 Jan 2025 19:03:17 +0000
Subject: [PATCH 03/35] add clip text encoder test

---
 src/maxdiffusion/tests/text_encoders_test.py | 67 +++++++++++++------
 1 file changed, 43 insertions(+), 24 deletions(-)

diff --git a/src/maxdiffusion/tests/text_encoders_test.py b/src/maxdiffusion/tests/text_encoders_test.py
index e7d3d6ddd..86d3bf612 100644
--- a/src/maxdiffusion/tests/text_encoders_test.py
+++ b/src/maxdiffusion/tests/text_encoders_test.py
@@ -16,45 +16,64 @@
 
 import os
 import unittest
-import pytest
 from absl.testing import absltest
+import numpy as np
+from PIL import Image
+import jax
+import jax.numpy as jnp
 
-from transformers import CLIPTokenizer, FlaxCLIPTextModel
-from transformers import T5TokenizerFast, FlaxT5EncoderModel
+from maxdiffusion.transformers import CLIPTokenizer, FlaxCLIPTextModel
 
-from ..generate_flux import get_clip_prompt_embeds, get_t5_prompt_embeds
-
-IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true"
 THIS_DIR = os.path.dirname(os.path.abspath(__file__))
 
-
 class TextEncoderTest(unittest.TestCase):
   """Test text encoders"""
 
   def setUp(self):
     TextEncoderTest.dummy_data = {}
-
-  @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions")
-  def test_flux_t5_text_encoder(self):
-
-    text_encoder = FlaxT5EncoderModel.from_pretrained("ariG23498/t5-v1-1-xxl-flax")
-
-    tokenizer_2 = T5TokenizerFast.from_pretrained("ariG23498/t5-v1-1-xxl-flax")
-
-    embeds = get_t5_prompt_embeds("A dog on a skateboard", 2, tokenizer_2, text_encoder)
-
-    assert embeds.shape == (2, 512, 4096)
-
-  @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions")
-  def test_flux_clip_text_encoder(self):
+
+  def test_flux_text_encoders(self):
+
+    def get_clip_prompt_embeds(
+        prompt,
+        num_images_per_prompt,
+        tokenizer,
+        text_encoder
+    ):
+      prompt = [prompt] if isinstance(prompt, str) else prompt
+      batch_size = len(prompt)
+
+      text_inputs = tokenizer(
+          prompt,
+          padding="max_length",
+          max_length=tokenizer.model_max_length,
+          truncation=True,
+          return_overflowing_tokens=False,
+          return_length=False,
+          return_tensors="np"
+      )
+
+      text_input_ids = text_inputs.input_ids
+
+      prompt_embeds = text_encoder(text_input_ids, params=text_encoder.params, train=False)
+      prompt_embeds = prompt_embeds.pooler_output
+      prompt_embeds = np.tile(prompt_embeds, (1, num_images_per_prompt))
+      prompt_embeds = np.reshape(prompt_embeds, (batch_size * num_images_per_prompt, -1))
+      return prompt_embeds
 
     text_encoder = FlaxCLIPTextModel.from_pretrained(
-        "black-forest-labs/FLUX.1-dev", subfolder="text_encoder", from_pt=True, dtype="bfloat16"
+        "black-forest-labs/FLUX.1-dev",
+        subfolder="text_encoder",
+        from_pt=True,
+        dtype="bfloat16"
+    )
+    tokenizer = CLIPTokenizer.from_pretrained(
+        "black-forest-labs/FLUX.1-dev",
+        subfolder="tokenizer",
+        dtype="bfloat16"
     )
-    tokenizer = CLIPTokenizer.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="tokenizer", dtype="bfloat16")
 
     embeds = get_clip_prompt_embeds("A cat riding a skateboard", 2, tokenizer, text_encoder)
 
     assert embeds.shape == (2, 768)
 
-if __name__ == "__main__":
-  absltest.main()

From 572f20d88d1efead772b461e2ea70eda48d5a6dd Mon Sep 17 00:00:00 2001
From: Juan Acevedo
Date: Wed, 22 Jan 2025 03:15:56 +0000
Subject: [PATCH 04/35] remove transformers inside maxdiffusion, add transformers dependency. Start creating generation code for flux.
---
 requirements.txt                              |   5 +-
 .../base_stable_diffusion_checkpointer.py     |   5 +-
 src/maxdiffusion/configs/base_flux.yml        | 236 ++++++++
 src/maxdiffusion/generate_flux.py             | 504 +++++------------
 src/maxdiffusion/models/flux/port.py          | 223 ++++++++
 src/maxdiffusion/models/flux/util.py          | 190 +++----
 src/maxdiffusion/models/vae_flax.py           |   4 +-
 src/maxdiffusion/tests/text_encoders_test.py  |  53 +-
 8 files changed, 690 insertions(+), 530 deletions(-)
 create mode 100644 src/maxdiffusion/configs/base_flux.yml
 create mode 100644 src/maxdiffusion/models/flux/port.py

diff --git a/requirements.txt b/requirements.txt
index 09babb20f..e5ac624e4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -25,8 +25,7 @@ ruff>=0.1.5,<=0.2
 git+https://github.com/mlperf/logging.git
 opencv-python-headless==4.10.0.84
 orbax-checkpoint==0.10.2
-tokenizers==0.21.0
+tokenizers==0.20.0
 huggingface_hub==0.24.7
 transformers==4.48.1
-einops==0.8.0
-sentencepiece
\ No newline at end of file
+einops==0.8.0
\ No newline at end of file
diff --git a/src/maxdiffusion/checkpointing/base_stable_diffusion_checkpointer.py b/src/maxdiffusion/checkpointing/base_stable_diffusion_checkpointer.py
index a7b597e36..92c7605d5 100644
--- a/src/maxdiffusion/checkpointing/base_stable_diffusion_checkpointer.py
+++ b/src/maxdiffusion/checkpointing/base_stable_diffusion_checkpointer.py
@@ -336,7 +336,10 @@ def load_checkpoint(self, step=None, scheduler_class=None):
       if self.checkpoint_type == STABLE_DIFFUSION_XL_CHECKPOINT:
         te_pretrained_2_config = CLIPTextConfig(**model_configs[0]["text_encoder_2_config"])
         text_encoder_2 = FlaxCLIPTextModelWithProjection(
-            te_pretrained_2_config, seed=self.config.seed, dtype=self.config.activations_dtype, _do_init=False
+            te_pretrained_2_config,
+            seed=self.config.seed,
+            dtype=self.config.activations_dtype,
+            _do_init=False
         )
         pipeline_kwargs["text_encoder_2"] = text_encoder_2
         # both tokenizers in sdxl are the same.
diff --git a/src/maxdiffusion/configs/base_flux.yml b/src/maxdiffusion/configs/base_flux.yml
new file mode 100644
index 000000000..9fe5ef5a6
--- /dev/null
+++ b/src/maxdiffusion/configs/base_flux.yml
@@ -0,0 +1,236 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This sentinel is a reminder to choose a real run name.
+run_name: ''
+
+metrics_file: "" # for testing, local file that stores scalar metrics. If empty, no metrics are written.
+# If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/
+write_metrics: True
+gcs_metrics: False
+# If true save config to GCS in {base_output_directory}/{run_name}/
+save_config_to_gcs: False
+log_period: 100
+
+pretrained_model_name_or_path: 'black-forest-labs/FLUX.1-dev'
+unet_checkpoint: ''
+revision: 'refs/pr/95'
+# This will convert the weights to this dtype.
+# When running inference on TPUv5e, use weights_dtype: 'bfloat16'
+weights_dtype: 'bfloat16'
+# This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype)
+activations_dtype: 'bfloat16'
+
+# matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision
+# Options are "DEFAULT", "HIGH", "HIGHEST"
+# fp32 activations and fp32 weights with HIGHEST will provide the best precision
+# at the cost of time.
+precision: "DEFAULT"
+
+# Set true to load weights from pytorch
+from_pt: False
+split_head_dim: True
+attention: 'dot_product' # Supported attention: dot_product, flash
+flash_block_sizes: {}
+# GroupNorm groups
+norm_num_groups: 32
+
+# If train_new_unet, unet weights will be randomly initialized to train the unet from scratch
+# else they will be loaded from pretrained_model_name_or_path
+train_new_unet: False
+
+# train text_encoder - Currently not supported for SDXL
+train_text_encoder: False
+text_encoder_learning_rate: 4.25e-6
+
+# https://arxiv.org/pdf/2305.08891.pdf
+snr_gamma: -1.0
+
+timestep_bias: {
+  # a value of later will increase the frequency of the model's final training steps.
+  # none, earlier, later, range
+  strategy: "none",
+  # multiplier for bias, a value of 2.0 will double the weight of the bias, 0.5 will halve it.
+  multiplier: 1.0,
+  # when using strategy=range, the beginning (inclusive) timestep to bias.
+  begin: 0,
+  # when using strategy=range, the final step (inclusive) to bias.
+  end: 1000,
+  # portion of timesteps to bias.
+  # 0.5 will bias one half of the timesteps. Value of strategy determines
+  # whether the biased portions are in the earlier or later timesteps.
+  portion: 0.25
+}
+
+# Override parameters from checkpoint's scheduler.
+diffusion_scheduler_config: {
+  _class_name: 'FlaxEulerDiscreteScheduler',
+  prediction_type: 'epsilon',
+  rescale_zero_terminal_snr: False,
+  timestep_spacing: 'trailing'
+}
+
+# Output directory
+# Create a GCS bucket, e.g. my-maxtext-outputs and set this to "gs://my-maxtext-outputs/"
+base_output_directory: ""
+
+# Hardware
+hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
+
+# Parallelism
+mesh_axes: ['data', 'fsdp', 'tensor']
+
+# batch : batch dimension of data and activations
+# hidden :
+# embed : attention qkv dense layer hidden dim named as embed
+# heads : attention head dim = num_heads * head_dim
+# length : attention sequence length
+# temb_in : dense.shape[0] of resnet dense before conv
+# out_c : dense.shape[1] of resnet dense before conv
+# out_channels : conv.shape[-1] activation
+# keep_1 : conv.shape[0] weight
+# keep_2 : conv.shape[1] weight
+# conv_in : conv.shape[2] weight
+# conv_out : conv.shape[-1] weight
+logical_axis_rules: [
+  ['batch', 'data'],
+  ['activation_batch', ['data','fsdp']],
+  ['activation_heads', 'tensor'],
+  ['activation_kv', 'tensor'],
+  ['embed','fsdp'],
+  ['heads', 'tensor'],
+  ['conv_batch', ['data','fsdp']],
+  ['out_channels', 'tensor'],
+  ['conv_out', 'fsdp'],
+]
+data_sharding: [['data', 'fsdp', 'tensor']]
+
+# One axis for each parallelism type may hold a placeholder (-1)
+# value to auto-shard based on available slices and devices.
+# By default, product of the DCN axes should equal number of slices
+# and product of the ICI axes should equal number of devices per slice.
+dcn_data_parallelism: -1 # recommended DCN axis to be auto-sharded
+dcn_fsdp_parallelism: 1
+dcn_tensor_parallelism: 1
+ici_data_parallelism: -1
+ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
+ici_tensor_parallelism: 1
+
+# Dataset
+# Replace with dataset path or train_data_dir. One has to be set.
+dataset_name: 'diffusers/pokemon-gpt4-captions'
+train_split: 'train'
+dataset_type: 'tf'
+cache_latents_text_encoder_outputs: True
+# cache_latents_text_encoder_outputs only applies to dataset_type="tf",
+# and only to small datasets that fit in memory
+# prepare image latents and text encoder outputs
+# Reduce memory consumption and reduce step time during training
+# transformed dataset is saved at dataset_save_location
+dataset_save_location: '/tmp/pokemon-gpt4-captions_xl'
+train_data_dir: ''
+dataset_config_name: ''
+jax_cache_dir: ''
+hf_data_dir: ''
+hf_train_files: ''
+hf_access_token: ''
+image_column: 'image'
+caption_column: 'text'
+resolution: 1024
+center_crop: False
+random_flip: False
+# If cache_latents_text_encoder_outputs is True
+# the num_proc is set to 1
+tokenize_captions_num_proc: 4
+transform_images_num_proc: 4
+reuse_example_batch: False
+enable_data_shuffling: True
+
+# checkpoint every number of samples, -1 means don't checkpoint.
+checkpoint_every: -1
+# enables one replica to read the ckpt then broadcast to the rest
+enable_single_replica_ckpt_restoring: False
+
+# Training loop
+learning_rate: 4.e-7
+scale_lr: False
+max_train_samples: -1
+# max_train_steps takes priority over num_train_epochs.
+max_train_steps: 200
+num_train_epochs: 1
+seed: 0
+output_dir: 'sdxl-model-finetuned'
+per_device_batch_size: 2
+
+warmup_steps_fraction: 0.0
+learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps.
+
+# However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before
+# dropping fully down. Or you may choose a shorter schedule, where the unspecified steps will have a learning rate of 0.
+
+# AdamW optimizer parameters
+adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients.
+adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
+adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
+adam_weight_decay: 1.e-2 # AdamW Weight decay
+max_grad_norm: 1.0
+
+enable_profiler: False
+# Skip first n steps for profiling, to omit things like compilation and to give
+# the iteration time a chance to stabilize.
+skip_first_n_steps_for_profiler: 5
+profiler_steps: 10
+
+# Generation parameters
+prompt: "A magical castle in the middle of a forest, artistic drawing"
+prompt_2: "A magical castle in the middle of a forest, artistic drawing"
+negative_prompt: "purple, red"
+do_classifier_free_guidance: True
+guidance_scale: 9.0
+# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+guidance_rescale: 0.0
+num_inference_steps: 20
+
+# SDXL Lightning parameters
+lightning_from_pt: True
+# Empty or "ByteDance/SDXL-Lightning" to enable lightning.
+lightning_repo: ""
+# Empty or "sdxl_lightning_4step_unet.safetensors" to enable lightning.
+lightning_ckpt: ""
+
+# LoRA parameters
+# Values are lists to support multiple LoRA loading during inference in the future.
+lora_config: {
+  lora_model_name_or_path: [],
+  weight_name: [],
+  adapter_name: [],
+  scale: [],
+  from_pt: []
+}
+# Ex with values:
+# lora_config : {
+#   lora_model_name_or_path: ["ByteDance/Hyper-SD"],
+#   weight_name: ["Hyper-SDXL-2steps-lora.safetensors"],
+#   adapter_name: ["hyper-sdxl"],
+#   scale: [0.7],
+#   from_pt: [True]
+# }
+
+enable_mllog: False
+
+#controlnet
+controlnet_model_name_or_path: 'diffusers/controlnet-canny-sdxl-1.0'
+controlnet_from_pt: True
+controlnet_conditioning_scale: 0.5
+controlnet_image: 'https://upload.wikimedia.org/wikipedia/commons/thumb/c/c1/Google_%22G%22_logo.svg/1024px-Google_%22G%22_logo.svg.png'
diff --git a/src/maxdiffusion/generate_flux.py b/src/maxdiffusion/generate_flux.py
index 1c221ee0b..609e37e7d 100644
--- a/src/maxdiffusion/generate_flux.py
+++ b/src/maxdiffusion/generate_flux.py
@@ -14,156 +14,48 @@
   limitations under the License.
 """
 
-from typing import Callable, List, Union, Sequence
+from typing import Any, Callable, Dict, List, Optional, Union, Sequence
+
 from absl import app
-from contextlib import ExitStack
-import functools
-import math
-import time
+
 import numpy as np
-from PIL import Image
 import jax
-from jax.sharding import Mesh, PositionalSharding, PartitionSpec as P
 import jax.numpy as jnp
-import flax.linen as nn
 from chex import Array
-from einops import rearrange
-from flax.linen import partitioning as nn_partitioning
-from transformers import (CLIPTokenizer, FlaxCLIPTextModel, T5EncoderModel, FlaxT5EncoderModel, AutoTokenizer)
-
-from maxdiffusion import FlaxAutoencoderKL, pyconfig, max_logging
-from maxdiffusion.models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel
-from maxdiffusion.max_utils import (
-    device_put_replicated,
-    get_memory_allocations,
-    create_device_mesh,
-    get_flash_block_sizes,
-    get_precision,
-    setup_initial_state,
+from transformers import (
+    CLIPTokenizer,
+    FlaxCLIPTextModel,
+    T5TokenizerFast,
+    T5EncoderModel
 )
-from maxdiffusion.loaders.flux_lora_pipeline import FluxLoraLoaderMixin
-
-
-def maybe_load_flux_lora(config, lora_loader, params):
-  def _noop_interceptor(next_fn, args, kwargs, context):
-    return next_fn(*args, **kwargs)
-
-  lora_config = config.lora_config
-  interceptors = [_noop_interceptor]
-  if len(lora_config["lora_model_name_or_path"]) > 0:
-    interceptors = []
-    for i in range(len(lora_config["lora_model_name_or_path"])):
-      params, rank, network_alphas = lora_loader.load_lora_weights(
-          config,
-          lora_config["lora_model_name_or_path"][i],
-          weight_name=lora_config["weight_name"][i],
-          params=params,
-          adapter_name=lora_config["adapter_name"][i],
-      )
-      interceptor = lora_loader.make_lora_interceptor(params, rank, network_alphas, lora_config["adapter_name"][i])
-      interceptors.append(interceptor)
-  return params, interceptors
-
-
-def unpack(x: Array, height: int, width: int) -> Array:
-  return rearrange(
-      x,
-      "b (h w) (c ph pw) -> b c (h ph) (w pw)",
-      h=math.ceil(height / 16),
-      w=math.ceil(width / 16),
-      ph=2,
-      pw=2,
-  )
-
-def vae_decode(latents, vae, state, config):
-  img = unpack(x=latents, height=config.resolution, width=config.resolution)
-  img = img / vae.config.scaling_factor + vae.config.shift_factor
-  img = vae.apply({"params": state.params}, img, deterministic=True, method=vae.decode).sample
-  return img
+from maxdiffusion import FlaxAutoencoderKL
+from maxdiffusion.models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel
+from maxdiffusion import pyconfig
 
-def loop_body(
-    step,
-    args,
-    transformer,
-    latent_image_ids,
-    prompt_embeds,
-    txt_ids,
-    vec,
-    guidance_vec,
-):
-  latents, state, c_ts, p_ts = args
-  latents_dtype = latents.dtype
-  t_curr = c_ts[step]
-  t_prev = p_ts[step]
-  t_vec = jnp.full((latents.shape[0],), t_curr, dtype=latents.dtype)
-  pred = transformer.apply(
-      {"params": state.params},
-      hidden_states=latents,
-      img_ids=latent_image_ids,
-      encoder_hidden_states=prompt_embeds,
-      txt_ids=txt_ids,
-      timestep=t_vec,
-      guidance=guidance_vec,
-      pooled_projections=vec,
-  ).sample
-  latents = latents + (t_prev - t_curr) * pred
-  latents = jnp.array(latents, dtype=latents_dtype)
-  return latents, state, c_ts, p_ts
-
-
-def prepare_latent_image_ids(height, width):
+def prepare_latent_image_ids(height: int, width: int):
   latent_image_ids = jnp.zeros((height, width, 3))
-  latent_image_ids = latent_image_ids.at[..., 1].set(jnp.arange(height)[:, None])
-  latent_image_ids = latent_image_ids.at[..., 2].set(jnp.arange(width)[None, :])
+  latent_image_ids = latent_image_ids.at[..., 1].set(
+      latent_image_ids[..., 1] + jnp.arange(height)[:, None]
+  )
+  latent_image_ids = latent_image_ids.at[..., 2].set(
+      latent_image_ids[..., 2] + jnp.arange(width)[None, :]
+  )
 
   latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
-  latent_image_ids = latent_image_ids.reshape(latent_image_id_height * latent_image_id_width, latent_image_id_channels)
-  return latent_image_ids.astype(jnp.bfloat16)
-
-
-def time_shift(mu: float, sigma: float, t: Array):
-  return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
-
-
-def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
-  m = (y2 - y1) / (x2 - x1)
-  b = y1 - m * x1
-  return lambda x: m * x + b
-
-
-def run_inference(
-    states, transformer, vae, config, mesh, latents, latent_image_ids, prompt_embeds, txt_ids, vec, guidance_vec, c_ts, p_ts
-):
-
-  transformer_state = states["transformer"]
-  vae_state = states["vae"]
-
-  loop_body_p = functools.partial(
-      loop_body,
-      transformer=transformer,
-      latent_image_ids=latent_image_ids,
-      prompt_embeds=prompt_embeds,
-      txt_ids=txt_ids,
-      vec=vec,
-      guidance_vec=guidance_vec,
+  latent_image_ids = latent_image_ids.reshape(
+      latent_image_id_height * latent_image_id_width, latent_image_id_channels
   )
-  vae_decode_p = functools.partial(vae_decode, vae=vae, state=vae_state, config=config)
-
-  with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
-    latents, _, _, _ = jax.lax.fori_loop(0, len(c_ts), loop_body_p, (latents, transformer_state, c_ts, p_ts))
-    image = vae_decode_p(latents)
-  return image
+  return latent_image_ids.astype(jnp.bfloat16)
 
 
 def pack_latents(
-    latents: Array,
-    batch_size: int,
-    num_channels_latents: int,
-    height: int,
-    width: int,
+    latents: Array,
+    batch_size: int,
+    num_channels_latents: int,
+    height: int,
+    width: int,
 ):
   latents = jnp.reshape(latents, (batch_size, num_channels_latents, height // 2, 2, width // 2, 2))
   latents = jnp.permute_dims(latents, (0, 2, 4, 1, 3, 5))
@@ -171,15 +63,19 @@
   latents = jnp.reshape(latents, (batch_size, (height // 2) * (width // 2), num_channels_latents * 4))
   return latents
 
-
 def prepare_latents(
-    batch_size: int, num_channels_latents: int, height: int, width: int, vae_scale_factor: int, dtype: jnp.dtype, rng: Array
+    batch_size: int,
+    num_channels_latents: int,
+    height: int,
+    width: int,
+    vae_scale_factor: int,
+    dtype: jnp.dtype,
+    rng: Array
 ):
-
   # VAE applies 8x compression on images but we must also account for packing which
   # requires latent height and width to be divisible by 2.
   height = 2 * (height // (vae_scale_factor * 2))
   width = 2 * (width // (vae_scale_factor * 2))
 
   shape = (batch_size, num_channels_latents, height, width)
 
   latents = jax.random.normal(rng, shape=shape, dtype=dtype)
 
   # pack latents
   latents = pack_latents(latents, batch_size, num_channels_latents, height, width)
 
-  latent_image_ids = prepare_latent_image_ids(height // 2, width // 2)
-  latent_image_ids = jnp.tile(latent_image_ids, (batch_size, 1, 1))
-
-  return latents, latent_image_ids
-
+  latent_image_ids = prepare_latent_image_ids(height // 2, width // 2)
+
+  return latents, latent_image_ids
 
@@ -187,309 +83,177 @@
 def get_clip_prompt_embeds(
-    prompt: Union[str, List[str]], num_images_per_prompt: int, tokenizer: CLIPTokenizer, text_encoder: FlaxCLIPTextModel
+    prompt: Union[str, List[str]],
+    num_images_per_prompt: int,
+    tokenizer: CLIPTokenizer,
+    text_encoder: FlaxCLIPTextModel
 ):
   prompt = [prompt] if isinstance(prompt, str) else prompt
   batch_size = len(prompt)
 
   text_inputs = tokenizer(
-      prompt,
-      padding="max_length",
-      max_length=tokenizer.model_max_length,
-      truncation=True,
-      return_overflowing_tokens=False,
-      return_length=False,
-      return_tensors="np",
+      prompt,
+      padding="max_length",
+      max_length=tokenizer.model_max_length,
+      truncation=True,
+      return_overflowing_tokens=False,
+      return_length=False,
+      return_tensors="np"
   )
 
   text_input_ids = text_inputs.input_ids
   prompt_embeds = text_encoder(text_input_ids, params=text_encoder.params, train=False)
   prompt_embeds = prompt_embeds.pooler_output
-  prompt_embeds = jnp.tile(prompt_embeds, (batch_size * num_images_per_prompt, 1))
+  prompt_embeds = np.tile(prompt_embeds, (1, num_images_per_prompt))
+  prompt_embeds = np.reshape(prompt_embeds, (batch_size * num_images_per_prompt, -1))
 
   return prompt_embeds
 
 def get_t5_prompt_embeds(
-    prompt: Union[str, List[str]],
-    num_images_per_prompt: int,
-    tokenizer: AutoTokenizer,
-    text_encoder: T5EncoderModel,
-    max_sequence_length: int = 512,
+    prompt: Union[str, List[str]],
+    num_images_per_prompt: int,
+    tokenizer: T5TokenizerFast,
+    text_encoder: T5EncoderModel,
+    max_sequence_length: int = 512
 ):
   prompt = [prompt] if isinstance(prompt, str) else prompt
   batch_size = len(prompt)
 
   text_inputs = tokenizer(
-      prompt,
-      truncation=True,
-      max_length=max_sequence_length,
-      return_length=False,
-      return_overflowing_tokens=False,
-      padding="max_length",
-      return_tensors="np",
+      prompt,
+      padding="max_length",
+      max_length=max_sequence_length,
+      truncation=True,
+      return_length=False,
+      return_overflowing_tokens=False,
+      return_tensors="pt"
   )
   text_input_ids = text_inputs.input_ids
-  prompt_embeds = text_encoder(text_input_ids, attention_mask=None, output_hidden_states=False)["last_hidden_state"]
+
+  prompt_embeds = text_encoder(text_input_ids, output_hidden_states=False)[0]
 
   dtype = text_encoder.dtype
-  prompt_embeds = prompt_embeds.astype(dtype)
+  prompt_embeds = prompt_embeds.to(dtype=dtype)
+
   _, seq_len, _ = prompt_embeds.shape
+
   # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
-  prompt_embeds = jnp.tile(prompt_embeds, (1, num_images_per_prompt, 1))
-  prompt_embeds = jnp.reshape(prompt_embeds, (batch_size * num_images_per_prompt, seq_len, -1))
+  prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+  prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
 
   return prompt_embeds
 
 def encode_prompt(
-    prompt: Union[str, List[str]],
-    prompt_2: Union[str, List[str]],
-    clip_tokenizer: CLIPTokenizer,
-    clip_text_encoder: FlaxCLIPTextModel,
-    t5_tokenizer: AutoTokenizer,
-    t5_text_encoder: T5EncoderModel,
-    num_images_per_prompt: int = 1,
-    max_sequence_length: int = 512,
+    prompt: Union[str, List[str]],
+    prompt_2: Union[str, List[str]],
+    clip_tokenizer: CLIPTokenizer,
+    clip_text_encoder: FlaxCLIPTextModel,
+    t5_tokenizer: T5TokenizerFast,
+    t5_text_encoder: T5EncoderModel,
+    num_images_per_prompt: int = 1,
+    max_sequence_length: int = 512
 ):
-
+
   prompt = [prompt] if isinstance(prompt, str) else prompt
   prompt_2 = prompt_2 or prompt
   prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
 
   pooled_prompt_embeds = get_clip_prompt_embeds(
-      prompt=prompt, num_images_per_prompt=num_images_per_prompt, tokenizer=clip_tokenizer, text_encoder=clip_text_encoder
+      prompt=prompt,
+      num_images_per_prompt=num_images_per_prompt,
+      tokenizer=clip_tokenizer,
+      text_encoder=clip_text_encoder
   )
 
   prompt_embeds = get_t5_prompt_embeds(
-      prompt=prompt_2,
-      num_images_per_prompt=num_images_per_prompt,
-      tokenizer=t5_tokenizer,
-      text_encoder=t5_text_encoder,
-      max_sequence_length=max_sequence_length,
+      prompt=prompt_2,
+      num_images_per_prompt=num_images_per_prompt,
+      tokenizer=t5_tokenizer,
+      text_encoder=t5_text_encoder
   )
+  prompt_embeds = jnp.asarray(prompt_embeds.detach().numpy())
 
   text_ids = jnp.zeros((prompt_embeds.shape[1], 3)).astype(jnp.bfloat16)
 
   return prompt_embeds, pooled_prompt_embeds, text_ids
 
-
 def run(config):
   from maxdiffusion.models.flux.util import load_flow_model
 
-  rng = jax.random.key(config.seed)
-  devices_array = create_device_mesh(config)
-  mesh = Mesh(devices_array, config.mesh_axes)
+  rng = jax.random.PRNGKey(config.seed)
 
-  global_batch_size = config.per_device_batch_size * jax.local_device_count()
+  per_host_number_of_images = config.per_device_batch_size * jax.local_device_count()
 
   # LOAD VAE
   vae, vae_params = FlaxAutoencoderKL.from_pretrained(
-      config.pretrained_model_name_or_path, subfolder="vae", from_pt=True, use_safetensors=True, dtype="bfloat16"
-  )
-
-  weights_init_fn = functools.partial(vae.init_weights, rng=rng)
-  vae_state, vae_state_shardings = setup_initial_state(
-      model=vae,
-      tx=None,
-      config=config,
-      mesh=mesh,
-      weights_init_fn=weights_init_fn,
-      model_params=vae_params,
-      training=False,
+      config.pretrained_model_name_or_path,
+      subfolder="vae",
+      from_pt=True,
+      use_safetensors=True,
+      dtype="bfloat16"
   )
-
   vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
 
-  # LOAD TRANSFORMER
-  flash_block_sizes = get_flash_block_sizes(config)
-  transformer = FluxTransformer2DModel.from_config(
-      config.pretrained_model_name_or_path,
-      subfolder="transformer",
-      mesh=mesh,
-      split_head_dim=config.split_head_dim,
-      attention_kernel=config.attention,
-      flash_block_sizes=flash_block_sizes,
-      dtype=config.activations_dtype,
-      weights_dtype=config.weights_dtype,
-      precision=get_precision(config),
-  )
+  # LOAD TRANSFORMER
+  transformer = FluxTransformer2DModel.from_config(config.pretrained_model_name_or_path, subfolder="transformer")
 
   num_channels_latents = transformer.in_channels // 4
+
   latents, latent_image_ids = prepare_latents(
-      batch_size=global_batch_size,
-      num_channels_latents=num_channels_latents,
-      height=config.resolution,
-      width=config.resolution,
-      dtype=jnp.bfloat16,
-      vae_scale_factor=vae_scale_factor,
-      rng=rng,
+      batch_size=per_host_number_of_images,
+      num_channels_latents=num_channels_latents,
+      height=config.resolution,
+      width=config.resolution,
+      dtype=jnp.bfloat16,
+      vae_scale_factor=vae_scale_factor,
+      rng=rng
   )
 
-  # LOAD TEXT ENCODERS
-  clip_text_encoder = FlaxCLIPTextModel.from_pretrained(
-      config.pretrained_model_name_or_path, subfolder="text_encoder", from_pt=True, dtype=config.weights_dtype
-  )
-  clip_tokenizer = CLIPTokenizer.from_pretrained(
-      config.pretrained_model_name_or_path, subfolder="tokenizer", dtype=config.weights_dtype
-  )
+  load_flow_model("flux-dev", "cpu")
 
-  t5_encoder = FlaxT5EncoderModel.from_pretrained(config.t5xxl_model_name_or_path, dtype=config.weights_dtype)
-  t5_tokenizer = AutoTokenizer.from_pretrained(
-      config.t5xxl_model_name_or_path, max_length=config.max_sequence_length, use_fast=True
+  transformer, params = FluxTransformer2DModel.from_pretrained(
+      config.pretrained_model_name_or_path,
+      subfolder="transformer",
+      from_pt=True,
+      dtype=config.weights_dtype
   )
-  encoders_sharding = PositionalSharding(devices_array).replicate()
-  partial_device_put_replicated = functools.partial(device_put_replicated, sharding=encoders_sharding)
-  clip_text_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), clip_text_encoder.params)
-  clip_text_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, clip_text_encoder.params)
 
-  t5_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), t5_encoder.params)
-  t5_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, t5_encoder.params)
-  prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
-      prompt=config.prompt,
-      prompt_2=config.prompt_2,
-      clip_tokenizer=clip_tokenizer,
-      clip_text_encoder=clip_text_encoder,
-      t5_tokenizer=t5_tokenizer,
-      t5_text_encoder=t5_encoder,
-      num_images_per_prompt=global_batch_size,
-      max_sequence_length=config.max_sequence_length,
+  # Initialize text encoders
+  clip_text_encoder = FlaxCLIPTextModel.from_pretrained(
+      config.pretrained_model_name_or_path,
+      subfolder="text_encoder",
+      from_pt=True,
+      dtype=config.weights_dtype
+  )
+  clip_tokenizer = CLIPTokenizer.from_pretrained(
+      config.pretrained_model_name_or_path,
+      subfolder="tokenizer",
+      dtype=config.weights_dtype
   )
 
-  def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timesteps, guidance, pooled_prompt_embeds):
-    print("latents.shape: ", latents.shape, latents.dtype)
-    print("latent_image_ids.shape: ", latent_image_ids.shape, latent_image_ids.dtype)
-    print("text_ids.shape: ", text_ids.shape, text_ids.dtype)
-    print("prompt_embeds: ", prompt_embeds.shape, prompt_embeds.dtype)
-    print("timesteps.shape: ", timesteps.shape, timesteps.dtype)
-    print("guidance.shape: ", guidance.shape, guidance.dtype)
-    print("pooled_prompt_embeds.shape: ", pooled_prompt_embeds.shape, pooled_prompt_embeds.dtype)
-
-  guidance = jnp.asarray([config.guidance_scale] * global_batch_size, dtype=jnp.bfloat16)
-
-  # move inputs to device and shard
-  data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding))
-  latents = jax.device_put(latents, data_sharding)
-  latent_image_ids = jax.device_put(latent_image_ids)
-  prompt_embeds = jax.device_put(prompt_embeds, data_sharding)
-  text_ids = jax.device_put(text_ids)
-  guidance = jax.device_put(guidance, data_sharding)
-  pooled_prompt_embeds = jax.device_put(pooled_prompt_embeds, data_sharding)
-
-  if config.offload_encoders:
-    cpus = jax.devices("cpu")
-    t5_encoder.params = jax.device_put(t5_encoder.params, device=cpus[0])
-
-  get_memory_allocations()
-  # evaluate shapes
-  transformer_eval_params = transformer.init_weights(
-      rngs=rng, max_sequence_length=config.max_sequence_length, eval_only=True
+  t5_encoder_pt = T5EncoderModel.from_pretrained(
config.pretrained_model_name_or_path, + subfolder="text_encoder_2", ) - # loads pretrained weights - transformer_params = load_flow_model(config.flux_name, transformer_eval_params, "cpu") - params = {} - params["transformer"] = transformer_params - # maybe load lora and create interceptor - lora_loader = FluxLoraLoaderMixin() - params, lora_interceptors = maybe_load_flux_lora(config, lora_loader, params) - transformer_params = params["transformer"] - # create transformer state - weights_init_fn = functools.partial( - transformer.init_weights, rngs=rng, max_sequence_length=config.max_sequence_length, eval_only=False + t5_tokenizer = T5TokenizerFast.from_pretrained( + config.pretrained_model_name_or_path, + subfolder="tokenizer_2", ) - with ExitStack() as stack: - _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors] - transformer_state, transformer_state_shardings = setup_initial_state( - model=transformer, - tx=None, - config=config, - mesh=mesh, - weights_init_fn=weights_init_fn, - model_params=None, - training=False, - ) - transformer_state = transformer_state.replace(params=transformer_params) - transformer_state = jax.device_put(transformer_state, transformer_state_shardings) - get_memory_allocations() - - states = {} - state_shardings = {} - - state_shardings["transformer"] = transformer_state_shardings - state_shardings["vae"] = vae_state_shardings - - states["transformer"] = transformer_state - states["vae"] = vae_state - - # Setup timesteps - timesteps = jnp.linspace(1, 0, config.num_inference_steps + 1) - # shifting the schedule to favor high timesteps for higher signal images - if config.time_shift: - # estimate mu based on linear estimation between two points - lin_function = get_lin_function(x1=config.max_sequence_length, y1=config.base_shift, y2=config.max_shift) - mu = lin_function(latents.shape[1]) - timesteps = time_shift(mu, 1.0, timesteps) - c_ts = timesteps[:-1] - p_ts = timesteps[1:] - - validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timesteps, guidance, pooled_prompt_embeds) - - p_run_inference = jax.jit( - functools.partial( - run_inference, - transformer=transformer, - vae=vae, - config=config, - mesh=mesh, - latents=latents, - latent_image_ids=latent_image_ids, - prompt_embeds=prompt_embeds, - txt_ids=text_ids, - vec=pooled_prompt_embeds, - guidance_vec=guidance, - c_ts=c_ts, - p_ts=p_ts, - ), - in_shardings=(state_shardings,), - out_shardings=None, + + prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt( + prompt=config.prompt, + prompt_2=config.prompt_2, + clip_tokenizer=clip_tokenizer, + clip_text_encoder=clip_text_encoder, + t5_tokenizer=t5_tokenizer, + t5_text_encoder=t5_encoder_pt, ) - t0 = time.perf_counter() - with ExitStack() as stack: - _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors] - p_run_inference(states).block_until_ready() - t1 = time.perf_counter() - max_logging.log(f"Compile time: {t1 - t0:.1f}s.") - - t0 = time.perf_counter() - with ExitStack() as stack, jax.profiler.trace("/home/jfacevedo/trace/"): - _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors] - imgs = p_run_inference(states).block_until_ready() - t1 = time.perf_counter() - max_logging.log(f"Inference time: {t1 - t0:.1f}s.") - - t0 = time.perf_counter() - with ExitStack() as stack: - _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors] - imgs = 
p_run_inference(states).block_until_ready() - imgs = jax.experimental.multihost_utils.process_allgather(imgs, tiled=True) - t1 = time.perf_counter() - max_logging.log(f"Inference time: {t1 - t0:.1f}s.") - imgs = np.array(imgs) - imgs = (imgs * 0.5 + 0.5).clip(0, 1) - imgs = np.transpose(imgs, (0, 2, 3, 1)) - imgs = np.uint8(imgs * 255) - for i, image in enumerate(imgs): - Image.fromarray(image).save(f"flux_{i}.png") - - return imgs def main(argv: Sequence[str]) -> None: @@ -498,4 +262,4 @@ def main(argv: Sequence[str]) -> None: if __name__ == "__main__": - app.run(main) + app.run(main) \ No newline at end of file diff --git a/src/maxdiffusion/models/flux/port.py b/src/maxdiffusion/models/flux/port.py new file mode 100644 index 000000000..1e9744ed3 --- /dev/null +++ b/src/maxdiffusion/models/flux/port.py @@ -0,0 +1,223 @@ +from einops import rearrange + +############################################################################################## +# FLUX MODEL PORTING +############################################################################################## + + +def port_linear(linear, tensors, prefix): + linear.kernel.value = rearrange(tensors[f"{prefix}.weight"], "i o -> o i") + linear.bias.value = tensors[f"{prefix}.bias"] + return linear + + +def port_modulation(modulation, tensors, prefix): + modulation.lin = port_linear( + linear=modulation.lin, tensors=tensors, prefix=f"{prefix}.lin" + ) + return modulation + + +def port_rms_norm(rms_norm, tensors, prefix): + rms_norm.scale.value = tensors[f"{prefix}.scale"] + return rms_norm + + +def port_qk_norm(qk_norm, tensors, prefix): + qk_norm.query_norm = port_rms_norm( + rms_norm=qk_norm.query_norm, + tensors=tensors, + prefix=f"{prefix}.query_norm", + ) + qk_norm.key_norm = port_rms_norm( + rms_norm=qk_norm.key_norm, + tensors=tensors, + prefix=f"{prefix}.key_norm", + ) + return qk_norm + + +def port_self_attention(self_attention, tensors, prefix): + self_attention.qkv = port_linear( + linear=self_attention.qkv, + tensors=tensors, + prefix=f"{prefix}.qkv", + ) + + self_attention.norm = port_qk_norm( + qk_norm=self_attention.norm, + tensors=tensors, + prefix=f"{prefix}.norm", + ) + + self_attention.proj = port_linear( + linear=self_attention.proj, + tensors=tensors, + prefix=f"{prefix}.proj", + ) + + return self_attention + + +def port_double_stream_block(double_stream_block, tensors, prefix): + double_stream_block.img_mod = port_modulation( + modulation=double_stream_block.img_mod, + tensors=tensors, + prefix=f"{prefix}.img_mod", + ) + + # double_stream_block.img_norm1 has no params + + double_stream_block.img_attn = port_self_attention( + self_attention=double_stream_block.img_attn, + tensors=tensors, + prefix=f"{prefix}.img_attn", + ) + + # double_stream_block.img_norm2 has no params + + double_stream_block.img_mlp.layers[0] = port_linear( + linear=double_stream_block.img_mlp.layers[0], + tensors=tensors, + prefix=f"{prefix}.img_mlp.0", + ) + double_stream_block.img_mlp.layers[2] = port_linear( + linear=double_stream_block.img_mlp.layers[2], + tensors=tensors, + prefix=f"{prefix}.img_mlp.2", + ) + + double_stream_block.txt_mod = port_modulation( + modulation=double_stream_block.txt_mod, + tensors=tensors, + prefix=f"{prefix}.txt_mod", + ) + + # double_stream_block.txt_norm1 has no params + + double_stream_block.txt_attn = port_self_attention( + self_attention=double_stream_block.txt_attn, + tensors=tensors, + prefix=f"{prefix}.txt_attn", + ) + + # double_stream_block.txt_norm2 has no params + + double_stream_block.txt_mlp.layers[0] 
= port_linear( + linear=double_stream_block.txt_mlp.layers[0], + tensors=tensors, + prefix=f"{prefix}.txt_mlp.0", + ) + double_stream_block.txt_mlp.layers[2] = port_linear( + linear=double_stream_block.txt_mlp.layers[2], + tensors=tensors, + prefix=f"{prefix}.txt_mlp.2", + ) + + return double_stream_block + + +def port_single_stream_block(single_stream_block, tensors, prefix): + single_stream_block.linear1 = port_linear( + linear=single_stream_block.linear1, tensors=tensors, prefix=f"{prefix}.linear1" + ) + single_stream_block.linear2 = port_linear( + linear=single_stream_block.linear2, tensors=tensors, prefix=f"{prefix}.linear2" + ) + + single_stream_block.norm = port_qk_norm( + qk_norm=single_stream_block.norm, tensors=tensors, prefix=f"{prefix}.norm" + ) + + # single_stream_block.pre_norm has no params + + single_stream_block.modulation = port_modulation( + modulation=single_stream_block.modulation, + tensors=tensors, + prefix=f"{prefix}.modulation", + ) + + return single_stream_block + + +def port_mlp_embedder(mlp_embedder, tensors, prefix): + mlp_embedder.in_layer = port_linear( + linear=mlp_embedder.in_layer, tensors=tensors, prefix=f"{prefix}.in_layer" + ) + + mlp_embedder.out_layer = port_linear( + linear=mlp_embedder.out_layer, tensors=tensors, prefix=f"{prefix}.out_layer" + ) + return mlp_embedder + + +def port_final_layer(final_layer, tensors, prefix): + # last_layer.norm_final has no params + final_layer.linear = port_linear( + linear=final_layer.linear, + tensors=tensors, + prefix=f"{prefix}.linear", + ) + + final_layer.adaLN_modulation.layers[1] = port_linear( + linear=final_layer.adaLN_modulation.layers[1], + tensors=tensors, + prefix=f"{prefix}.adaLN_modulation.1", + ) + + return final_layer + + +def port_flux(flux, tensors): + flux.img_in = port_linear( + linear=flux.img_in, + tensors=tensors, + prefix="img_in", + ) + + flux.time_in = port_mlp_embedder( + mlp_embedder=flux.time_in, + tensors=tensors, + prefix="time_in", + ) + + flux.vector_in = port_mlp_embedder( + mlp_embedder=flux.vector_in, + tensors=tensors, + prefix="vector_in", + ) + + if flux.params.guidance_embed: + flux.guidance_in = port_mlp_embedder( + mlp_embedder=flux.guidance_in, + tensors=tensors, + prefix="guidance_in", + ) + + flux.txt_in = port_linear( + linear=flux.txt_in, + tensors=tensors, + prefix="txt_in", + ) + + for i, layer in enumerate(flux.double_blocks.layers): + layer = port_double_stream_block( + double_stream_block=layer, + tensors=tensors, + prefix=f"double_blocks.{i}", + ) + + for i, layer in enumerate(flux.single_blocks.layers): + layer = port_single_stream_block( + single_stream_block=layer, + tensors=tensors, + prefix=f"single_blocks.{i}", + ) + + flux.final_layer = port_final_layer( + final_layer=flux.final_layer, + tensors=tensors, + prefix="final_layer", + ) + + return flux diff --git a/src/maxdiffusion/models/flux/util.py b/src/maxdiffusion/models/flux/util.py index 362a39171..44ad90389 100644 --- a/src/maxdiffusion/models/flux/util.py +++ b/src/maxdiffusion/models/flux/util.py @@ -1,3 +1,4 @@ + # copied from https://github.com/ml-gde/jflux/blob/main/jflux/util.py import os from dataclasses import dataclass @@ -6,53 +7,51 @@ from jax.typing import DTypeLike import torch # need for torch 2 jax from chex import Array -from flax.traverse_util import flatten_dict, unflatten_dict +from flax import nnx from huggingface_hub import hf_hub_download from jax import numpy as jnp from safetensors import safe_open -from maxdiffusion.models.modeling_flax_pytorch_utils import (rename_key, 
rename_key_and_reshape_tensor) -from maxdiffusion import max_logging - +# from jflux.model import Flux, FluxParams +from .port import port_flux @dataclass class FluxParams: - in_channels: int - vec_in_dim: int - context_in_dim: int - hidden_size: int - mlp_ratio: float - num_heads: int - depth: int - depth_single_blocks: int - axes_dim: list[int] - theta: int - qkv_bias: bool - guidance_embed: bool - rngs: Array - param_dtype: DTypeLike - + in_channels: int + vec_in_dim: int + context_in_dim: int + hidden_size: int + mlp_ratio: float + num_heads: int + depth: int + depth_single_blocks: int + axes_dim: list[int] + theta: int + qkv_bias: bool + guidance_embed: bool + rngs: Array + param_dtype: DTypeLike def torch2jax(torch_tensor: torch.Tensor) -> Array: - is_bfloat16 = torch_tensor.dtype == torch.bfloat16 - if is_bfloat16: - # upcast the tensor to fp32 - torch_tensor = torch_tensor.float() + is_bfloat16 = torch_tensor.dtype == torch.bfloat16 + if is_bfloat16: + # upcast the tensor to fp32 + torch_tensor = torch_tensor.to(dtype=torch.float32) - if torch.device.type != "cpu": - torch_tensor = torch_tensor.to("cpu") + if torch.device.type != "cpu": + torch_tensor = torch_tensor.to("cpu") - numpy_value = torch_tensor.numpy() - jax_array = jnp.array(numpy_value, dtype=jnp.bfloat16 if is_bfloat16 else None) - return jax_array + numpy_value = torch_tensor.numpy() + jax_array = jnp.array(numpy_value, dtype=jnp.bfloat16 if is_bfloat16 else None) + return jax_array @dataclass class ModelSpec: - params: FluxParams - ckpt_path: str | None - repo_id: str | None - repo_flow: str | None + params: FluxParams + ckpt_path: str | None + repo_id: str | None + repo_flow: str | None configs = { @@ -73,6 +72,7 @@ class ModelSpec: theta=10_000, qkv_bias=True, guidance_embed=True, + #rngs=nnx.Rngs(default=42), rngs=jax.random.PRNGKey(42), param_dtype=jnp.bfloat16, ), @@ -94,6 +94,7 @@ class ModelSpec: theta=10_000, qkv_bias=True, guidance_embed=False, + #rngs=nnx.Rngs(default=42), rngs=jax.random.PRNGKey(47), param_dtype=jnp.bfloat16, ), @@ -102,92 +103,39 @@ class ModelSpec: def print_load_warning(missing: list[str], unexpected: list[str]) -> None: - if len(missing) > 0 and len(unexpected) > 0: - max_logging.log(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) - max_logging.log("\n" + "-" * 79 + "\n") - max_logging.log(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) - elif len(missing) > 0: - max_logging.log(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) - elif len(unexpected) > 0: - max_logging.log(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) - - -def validate_flax_state_dict(expected_pytree: dict, new_pytree: dict): - """ - expected_pytree: dict - a pytree that comes from initializing the model. - new_pytree: dict - a pytree that has been created from pytorch weights. 
- """ - expected_pytree = flatten_dict(expected_pytree) - if len(expected_pytree.keys()) != len(new_pytree.keys()): - set1 = set(expected_pytree.keys()) - set2 = set(new_pytree.keys()) - missing_keys = set1 ^ set2 - max_logging.log(f"missing keys : {missing_keys}") - for key in expected_pytree.keys(): - if key in new_pytree.keys(): - try: - expected_pytree_shape = expected_pytree[key].shape - except Exception: - expected_pytree_shape = expected_pytree[key].value.shape - if expected_pytree_shape != new_pytree[key].shape: - max_logging.log( - f"shape mismatch, expected shape of {expected_pytree[key].shape}, but got shape of {new_pytree[key].shape}" - ) - else: - max_logging.log(f"key: {key} not found...") - - -def load_flow_model(name: str, eval_shapes: dict, device: str, hf_download: bool = True): # -> Flux: - device = jax.devices(device)[0] - with jax.default_device(device): - ckpt_path = configs[name].ckpt_path - if ckpt_path is None and configs[name].repo_id is not None and configs[name].repo_flow is not None and hf_download: - ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow) - - max_logging.log(f"Load and port flux on {device}") - - if ckpt_path is not None: - tensors = {} - with safe_open(ckpt_path, framework="pt") as f: - for k in f.keys(): - tensors[k] = torch2jax(f.get_tensor(k)) - flax_state_dict = {} - cpu = jax.local_devices(backend="cpu")[0] - for pt_key, tensor in tensors.items(): - renamed_pt_key = rename_key(pt_key) - if "double_blocks" in renamed_pt_key: - renamed_pt_key = renamed_pt_key.replace("img_mlp_", "img_mlp.layers_") - renamed_pt_key = renamed_pt_key.replace("txt_mlp_", "txt_mlp.layers_") - renamed_pt_key = renamed_pt_key.replace("img_mod", "img_norm1") - renamed_pt_key = renamed_pt_key.replace("txt_mod", "txt_norm1") - renamed_pt_key = renamed_pt_key.replace("img_attn.qkv", "attn.i_qkv") - renamed_pt_key = renamed_pt_key.replace("img_attn.proj", "attn.i_proj") - renamed_pt_key = renamed_pt_key.replace("img_attn.norm", "attn") - renamed_pt_key = renamed_pt_key.replace("txt_attn.qkv", "attn.e_qkv") - renamed_pt_key = renamed_pt_key.replace("txt_attn.proj", "attn.e_proj") - renamed_pt_key = renamed_pt_key.replace("txt_attn.norm.key_norm", "attn.encoder_key_norm") - renamed_pt_key = renamed_pt_key.replace("txt_attn.norm.query_norm", "attn.encoder_query_norm") - elif "guidance_in" in renamed_pt_key: - renamed_pt_key = renamed_pt_key.replace("guidance_in", "time_text_embed.FlaxTimestepEmbedding_1") - renamed_pt_key = renamed_pt_key.replace("in_layer", "linear_1") - renamed_pt_key = renamed_pt_key.replace("out_layer", "linear_2") - elif "single_blocks" in renamed_pt_key: - renamed_pt_key = renamed_pt_key.replace("modulation", "norm") - renamed_pt_key = renamed_pt_key.replace("norm.key_norm", "attn.key_norm") - renamed_pt_key = renamed_pt_key.replace("norm.query_norm", "attn.query_norm") - elif "vector_in" in renamed_pt_key or "time_in" in renamed_pt_key: - renamed_pt_key = renamed_pt_key.replace("vector_in", "time_text_embed.PixArtAlphaTextProjection_0") - renamed_pt_key = renamed_pt_key.replace("time_in", "time_text_embed.FlaxTimestepEmbedding_0") - renamed_pt_key = renamed_pt_key.replace("in_layer", "linear_1") - renamed_pt_key = renamed_pt_key.replace("out_layer", "linear_2") - elif "final_layer" in renamed_pt_key: - renamed_pt_key = renamed_pt_key.replace("final_layer.linear", "proj_out") - renamed_pt_key = renamed_pt_key.replace("final_layer.adaLN_modulation_1", "norm_out.Dense_0") - pt_tuple_key = tuple(renamed_pt_key.split(".")) - 
flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, eval_shapes) - flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu) - validate_flax_state_dict(eval_shapes, flax_state_dict) - flax_state_dict = unflatten_dict(flax_state_dict) - del tensors - jax.clear_caches() - return flax_state_dict + if len(missing) > 0 and len(unexpected) > 0: + print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) + print("\n" + "-" * 79 + "\n") + print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) + elif len(missing) > 0: + print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) + elif len(unexpected) > 0: + print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) + + +def load_flow_model(name: str, device: str, hf_download: bool = True): # -> Flux: + device = jax.devices(device)[0] + with jax.default_device(device): + ckpt_path = configs[name].ckpt_path + if ( + ckpt_path is None + and configs[name].repo_id is not None + and configs[name].repo_flow is not None + and hf_download + ): + ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow) + + print(f"Load and port flux on {device}") + + #model = Flux(params=configs[name].params) + if ckpt_path is not None: + tensors = {} + with safe_open(ckpt_path, framework="pt") as f: + for k in f.keys(): + tensors[k] = torch2jax(f.get_tensor(k)) + breakpoint() + model = port_flux(flux=model, tensors=tensors) + + del tensors + jax.clear_caches() + return model diff --git a/src/maxdiffusion/models/vae_flax.py b/src/maxdiffusion/models/vae_flax.py index b7bc3e4d4..dc9b00630 100644 --- a/src/maxdiffusion/models/vae_flax.py +++ b/src/maxdiffusion/models/vae_flax.py @@ -891,10 +891,10 @@ def encode(self, sample, deterministic: bool = True, return_dict: bool = True): sample = jnp.transpose(sample, (0, 2, 3, 1)) hidden_states = self.encoder(sample, deterministic=deterministic) - moments = None + moments = hidden_states if self.use_quant_conv: moments = self.quant_conv(hidden_states) - posterior = FlaxDiagonalGaussianDistribution(moments if moments else hidden_states) + posterior = FlaxDiagonalGaussianDistribution(moments) if not return_dict: return (posterior,) diff --git a/src/maxdiffusion/tests/text_encoders_test.py b/src/maxdiffusion/tests/text_encoders_test.py index 86d3bf612..cf4ba0c1b 100644 --- a/src/maxdiffusion/tests/text_encoders_test.py +++ b/src/maxdiffusion/tests/text_encoders_test.py @@ -18,12 +18,10 @@ import unittest from absl.testing import absltest -import numpy as np -from PIL import Image -import jax -import jax.numpy as jnp +from transformers import CLIPTokenizer, FlaxCLIPTextModel +from transformers import T5TokenizerFast, T5EncoderModel -from maxdiffusion.transformers import CLIPTokenizer, FlaxCLIPTextModel +from ..generate_flux import get_clip_prompt_embeds, get_t5_prompt_embeds THIS_DIR = os.path.dirname(os.path.abspath(__file__)) @@ -33,34 +31,23 @@ class TextEncoderTest(unittest.TestCase): def setUp(self): TextEncoderTest.dummy_data = {} - def test_flux_text_encoders(self): - - def get_clip_prompt_embeds( - prompt, - num_images_per_prompt, - tokenizer, - text_encoder - ): - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - - text_inputs = tokenizer( - prompt, - padding="max_length", - max_length=tokenizer.model_max_length, - truncation=True, - return_overflowing_tokens=False, - return_length=False, - return_tensors="np" - ) - - text_input_ids = text_inputs.input_ids - - 
prompt_embeds = text_encoder(text_input_ids, params=text_encoder.params, train=False) - prompt_embeds = prompt_embeds.pooler_output - prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=-1) - prompt_embeds = np.reshape(prompt_embeds, (batch_size * num_images_per_prompt, -1)) - return prompt_embeds + def test_flux_t5_text_encoder(self): + + text_encoder_2_pt = T5EncoderModel.from_pretrained( + "black-forest-labs/FLUX.1-dev", + subfolder="text_encoder_2", + ) + + tokenizer_2 = T5TokenizerFast.from_pretrained( + "black-forest-labs/FLUX.1-dev", + subfolder="tokenizer_2", + ) + + embeds = get_t5_prompt_embeds("A dog on a skateboard", 2, tokenizer_2, text_encoder_2_pt) + + assert embeds.shape == (2, 512, 4096) + + def test_flux_clip_text_encoder(self): text_encoder = FlaxCLIPTextModel.from_pretrained( "black-forest-labs/FLUX.1-dev", From ff04543132f768781bd075260165e816d15155e9 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Wed, 22 Jan 2025 19:19:53 +0000 Subject: [PATCH 05/35] add double block to flux --- src/maxdiffusion/generate_flux.py | 45 ++- .../models/flux/modules/layers.py | 284 +++++++++++++++++- .../transformers/transformer_flux_flax.py | 147 +++++---- 3 files changed, 410 insertions(+), 66 deletions(-) diff --git a/src/maxdiffusion/generate_flux.py b/src/maxdiffusion/generate_flux.py index 609e37e7d..bab30902d 100644 --- a/src/maxdiffusion/generate_flux.py +++ b/src/maxdiffusion/generate_flux.py @@ -33,7 +33,7 @@ from maxdiffusion import pyconfig -def prepare_latent_image_ids(): +def prepare_latent_image_ids(height, width): latent_image_ids = jnp.zeros((height, width, 3)) latent_image_ids = latent_image_ids.at[..., 1].set( latent_image_ids[..., 1] + jnp.arange(height)[:, None] @@ -72,10 +72,11 @@ def prepare_latents( dtype: jnp.dtype, rng: Array ): + # VAE applies 8x compression on images but we must also account for packing which # requires latent height and width to be divisibly by 2. 
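    # For orientation, the id grid built below for a hypothetical 2x2 packed
    # latent: column 0 stays zero, columns 1 and 2 carry the row and column
    # index of each packed token, flattened row-major:
    #
    #   prepare_latent_image_ids(2, 2)
    #   # -> [[0, 0, 0],
    #   #     [0, 0, 1],
    #   #     [0, 1, 0],
    #   #     [0, 1, 1]]   (cast to bfloat16)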
height = 2 * (height // (vae_scale_factor * 2)) - width = 2 * (height // (vae_scale_factor * 2)) + width = 2 * (width // (vae_scale_factor * 2)) shape = (batch_size, num_channels_latents, height, width) @@ -83,8 +84,9 @@ def prepare_latents( # pack latents latents = pack_latents(latents, batch_size, num_channels_latents, height, width) - latent_image_ids = prepare_latent_image_ids() - breakpoint() + latent_image_ids = prepare_latent_image_ids(height // 2, width // 2) + + return latents, latent_image_ids def get_clip_prompt_embeds( prompt: Union[str, List[str]], @@ -200,9 +202,12 @@ def run(config): # LOAD UNET - transformer = FluxTransformer2DModel.from_config(config.pretrained_model_name_or_path, subfolder="transformer") + transformer = FluxTransformer2DModel.from_config( + config.pretrained_model_name_or_path, + subfolder="transformer" + ) + num_channels_latents = transformer.in_channels // 4 - latents, latent_image_ids = prepare_latents( batch_size=per_host_number_of_images, num_channels_latents=num_channels_latents, @@ -213,17 +218,17 @@ def run(config): rng=rng ) - load_flow_model("flux-dev", "cpu") + #load_flow_model("flux-dev", "cpu") - transformer, params = FluxTransformer2DModel.from_pretrained( - config.pretrained_model_name_or_path, - subfolder="text_encoder_2", - from_pt=True, - dtype=config.weights_dtype - ) + # transformer, params = FluxTransformer2DModel.from_pretrained( + # config.pretrained_model_name_or_path, + # subfolder="text_encoder_2", + # from_pt=True, + # dtype=config.weights_dtype + # ) - # Initialize text encoders + # LOAD TEXT ENCODERS - t5 on cpu clip_text_encoder = FlaxCLIPTextModel.from_pretrained( config.pretrained_model_name_or_path, subfolder="text_encoder", @@ -255,6 +260,18 @@ def run(config): t5_text_encoder=t5_encoder_pt, ) + transformer_params = transformer.init( + {"params" : rng}, + img=latents, + img_ids=latent_image_ids, + txt=prompt_embeds, + txt_ids=text_ids, + timesteps=[1.0], + y=pooled_prompt_embeds + )["params"] + breakpoint() + + def main(argv: Sequence[str]) -> None: pyconfig.initialize(argv) diff --git a/src/maxdiffusion/models/flux/modules/layers.py b/src/maxdiffusion/models/flux/modules/layers.py index 3e4d5f083..95ff52097 100644 --- a/src/maxdiffusion/models/flux/modules/layers.py +++ b/src/maxdiffusion/models/flux/modules/layers.py @@ -16,11 +16,54 @@ import math from dataclasses import dataclass +from einops import rearrange import jax import jax.numpy as jnp from chex import Array from jax.typing import DTypeLike import flax.linen as nn +from ...attention_flax import AttentionOp +from .... 
import common_types + +BlockSizes = common_types.BlockSizes + +def rope(pos: Array, dim: int, theta: int) -> Array: + assert dim % 2 == 0 + scale = jnp.arange(0, dim, 2, dtype=jnp.float32) / dim + omega = 1.0 / (theta ** scale) + out = jnp.einsum("...n,d->...nd", pos, omega) + out = jnp.stack([jnp.cos(out), -jnp.sin(out), jnp.sin(out), jnp.cos(out)], axis=-1) + out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) + return out.astype(jnp.float32) + +class QKNorm(nn.Module): + dtype: DTypeLike = jnp.bfloat16 + weights_dtype: DTypeLike = jnp.bfloat16 + + @nn.compact + def __call__(self, q: Array, k: Array, v: Array) -> tuple[Array, Array]: + q = nn.RMSNorm( + dtype=self.dtype, + param_dtype=self.weights_dtype + )(q) + k = nn.RMSNorm( + dtype=self.dtype, + param_dtype=self.weights_dtype + )(k) + return q, k + +class EmbedND(nn.Module): + dim: int + theta: int + axes_dim: list[int] + + def __call__(self, ids: Array): + n_axes = ids.shape[-1] + emb = jnp.concatenate( + [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], axis=-3, + ) + + return jnp.expand_dims(emb, axis=1) def timestep_embedding( t: Array, dim: int, max_period=10000, time_factor: float = 1000.0 @@ -38,6 +81,7 @@ def timestep_embedding( Returns: timestep embeddings. """ + breakpoint() t = time_factor * t half = dim // 2 @@ -92,4 +136,242 @@ def __call__(self, x: Array) -> Array: ) )(x) - return x \ No newline at end of file + return x + +@dataclass +class ModulationOut: + shift: Array + scale: Array + gate: Array + +class Modulation(nn.Module): + dim: int + double: bool + dtype: DTypeLike = jnp.bfloat16 + weights_dtype: DTypeLike = jnp.bfloat16 + precision: jax.lax.Precision = None + + @nn.compact + def __call__(self, vec: Array) -> tuple[ModulationOut, ModulationOut | None]: + multiplier = 6 if self.double else 3 + lin = nn.Dense( + multiplier * self.dim, + use_bias=True, + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, + kernel_init=nn.with_logical_partitioning( + nn.initializers.lecun_normal(), + ("embed", "heads") + ) + )(nn.silu(vec)) + out = jnp.split(lin[:, None, :], multiplier, axis=-1) + + return ( + ModulationOut(*out[:3]), + ModulationOut(*out[3:] if self.double else None) + ) + +class DoubleStreamBlock(nn.Module): + hidden_size: int + num_heads: int + mlp_ratio: float + attention_head_dim: int = 128 + flash_min_seq_length: int = 4096 + flash_block_sizes: BlockSizes = None + mesh: jax.sharding.Mesh = None + dtype: jnp.dtype = jnp.float32 + weights_dtype: jnp.dtype = jnp.float32 + precision: jax.lax.Precision = None + qkv_bias: bool = False + attention_kernel: str = "dot_product" + + @nn.compact + def __call__(self, img: Array, txt: Array, vec: Array, pe: Array) -> tuple[Array, Array]: + + mlp_hidden_dim = int(self.hidden_size * self.mlp_ratio) + + img_mod1, img_mod2 = Modulation( + self.hidden_size, + double=True, + dtype=self.dtype, + weights_dtype=self.weights_dtype, + precision=self.precision + )(vec) + + txt_mod1, txt_mod2 = Modulation( + self.hidden_size, + double=True, + dtype=self.dtype, + weights_dtype=self.weights_dtype, + precision=self.precision + )(vec) + + # prepare image for attention + img_modulated = nn.LayerNorm( + use_scale=False, + use_bias=False, + epsilon=1e-6, + dtype=self.dtype, + param_dtype=self.weights_dtype + )(img) + img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift + img_qkv = nn.Dense( + self.hidden_size * 3, + use_bias=self.qkv_bias, + dtype=self.dtype, + param_dtype=self.weights_dtype, + 
precision=self.precision, + kernel_init=nn.with_logical_partitioning( + nn.initializers.lecun_normal(), + ("embed", "heads") + ) + )(img_modulated) + img_q, img_k, img_v = rearrange( + img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads + ) + img_q, img_k = QKNorm( + dtype=self.dtype, + weights_dtype=self.weights_dtype + )(img_q, img_k, img_v) + + # prepare text for attention + txt_modulated = nn.LayerNorm( + use_scale=False, + use_bias=False, + epsilon=1e-6, + dtype=self.dtype, + param_dtype=self.weights_dtype + )(txt) + txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift + txt_qkv = nn.Dense( + self.hidden_size * 3, + use_bias=self.qkv_bias, + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, + kernel_init=nn.with_logical_partitioning( + nn.initializers.lecun_normal(), + ("embed", "heads") + ) + )(txt_modulated) + txt_q, txt_k, txt_v = rearrange( + txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads + ) + txt_q, txt_k = QKNorm( + dtype=self.dtype, + weights_dtype=self.weights_dtype + )(txt_q, txt_k, txt_v) + + # run actual attention + q = jnp.concatenate((txt_q, img_q), axis=2) + k = jnp.concatenate((txt_k, img_k), axis=2) + v = jnp.concatenate((txt_v, img_v), axis=2) + + attn = AttentionOp( + mesh=self.mesh, + attention_kernel=self.attention_kernel, + scale=self.attention_head_dim**-0.5, + heads=self.num_heads, + dim_head=self.attention_head_dim, + flash_min_seq_length=self.flash_min_seq_length, + use_memory_efficient_attention=False, + split_head_dim=True, + flash_block_sizes=self.flash_block_sizes, + dtype=self.dtype + )(q, k, v) + + txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] + + #calculate the img blocks + img = img + img_mod1.gate * nn.Dense( + self.hidden_size, + use_bias=True, + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, + kernel_init=nn.with_logical_partitioning( + nn.initializers.lecun_normal(), + ("heads", "embed") + ), + )(img_attn) + img = img + img_mod2.gate * nn.Sequential( + [ + nn.Dense( + mlp_hidden_dim, + use_bias=True, + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, + kernel_init=nn.with_logical_partitioning( + nn.initializers.lecun_normal(), + ("embed", "heads") + ) + ), + nn.gelu, + nn.Dense( + self.hidden_size, + use_bias=True, + param_dtype=self.weights_dtype, + precision=self.precision, + kernel_init=nn.with_logical_partitioning( + nn.initializers.lecun_normal(), + ("heads", "embed") + ) + ) + ] + )( + (1 + img_mod2.scale) * nn.LayerNorm( + use_scale=False, + use_bias=False, + param_dtype=self.weights_dtype + )(img) + img_mod2.shift + ) + + # calculate the txt blocks + txt = txt + txt_mod1.gate * nn.Dense( + self.hidden_size, + use_bias=True, + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, + kernel_init=nn.with_logical_partitioning( + nn.initializers.lecun_normal(), + ("heads", "embed") + ), + )(txt_attn) + txt = txt + txt_mod2.gate * nn.Sequential( + [ + nn.Dense( + mlp_hidden_dim, + use_bias=True, + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, + kernel_init=nn.with_logical_partitioning( + nn.initializers.lecun_normal(), + ("embed", "heads") + ) + ), + nn.gelu, + nn.Dense( + self.hidden_size, + use_bias=True, + param_dtype=self.weights_dtype, + precision=self.precision, + kernel_init=nn.with_logical_partitioning( + nn.initializers.lecun_normal(), + ("heads", "embed") + ) + ) + ] + )( + (1 + txt_mod2.scale) * nn.LayerNorm( + 
use_scale=False, + use_bias=False, + param_dtype=self.weights_dtype + )(txt) + txt_mod2.shift + ) + + return img, txt + diff --git a/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py b/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py index 5645a84e4..4ea33882b 100644 --- a/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py +++ b/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py @@ -21,7 +21,7 @@ import flax.linen as nn from chex import Array -from ..modules.layers import timestep_embedding, MLPEmbedder +from ..modules.layers import timestep_embedding, MLPEmbedder, EmbedND, DoubleStreamBlock from ...modeling_flax_utils import FlaxModelMixin from ....configuration_utils import ConfigMixin, flax_register_to_config from ....common_types import BlockSizes @@ -51,21 +51,38 @@ class FluxTransformer2DModel(nn.Module, FlaxModelMixin, ConfigMixin): num_attention_heads: int = 24 joint_attention_dim: int = 4096 pooled_projection_dim: int = 768 + mlp_ratio: int = 4 + qkv_bias: bool = True guidance_embeds: bool = False axes_dims_rope: Tuple[int] = (16, 56, 56) flash_min_seq_length: int = 4096 flash_block_sizes: BlockSizes = None + attention_kernel: str = "dot_product" mesh: jax.sharding.Mesh = None dtype: jnp.dtype = jnp.float32 weights_dtype: jnp.dtype = jnp.float32 precision: jax.lax.Precision = None + + @nn.compact + def __call__( + self, + img: Array, + img_ids: Array, + txt: Array, + txt_ids: Array, + timesteps: Array, + y: Array, + guidance: Array | None = None, + return_dict: bool = True, + train: bool = False): - def setup(self): - self.out_channels = self.in_channels - self.inner_dim = self.num_attention_heads * self.attention_head_dim + out_channels = self.in_channels + inner_dim = self.num_attention_heads * self.attention_head_dim + pe_dim = inner_dim // self.num_attention_heads - self.img_in = nn.Dense( - self.inner_dim, + #img = self.img_in(img) + img = nn.Dense( + inner_dim, dtype=self.dtype, param_dtype=self.weights_dtype, precision=self.precision, @@ -73,62 +90,90 @@ def setup(self): nn.initializers.lecun_normal(), ("embed", "heads") ) - ) + )(img) - self.time_in = MLPEmbedder( - hidden_dim=self.inner_dim, + #vec = self.time_in(timestep_embedding(timesteps, 256)) + vec = MLPEmbedder( + hidden_dim=inner_dim, dtype=self.dtype, weights_dtype=self.weights_dtype, precision=self.precision - ) + )(timestep_embedding(timesteps, 256)) - self.vector_in = MLPEmbedder( - hidden_dim=self.inner_dim, - dtype=self.dtype, - weights_dtype=self.weights_dtype, - precision=self.precision - ) - - self.guidance_in = ( - MLPEmbedder( - hidden_dim=self.inner_dim, + if self.guidance_embeds: + if guidance is None: + raise ValueError( + "Didn't get guidance strength for guidance distrilled model." 
+ ) + + #vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) + vec = vec + MLPEmbedder( + hidden_dim=inner_dim, dtype=self.dtype, weights_dtype=self.weights_dtype, precision=self.precision - ) - if self.guidance_embeds - else Identity() - ) + )(timestep_embedding(guidance, 256)) + + #vec = vec + self.vector_in(y) + vec = vec + MLPEmbedder( + hidden_dim=inner_dim, + dtype=self.dtype, + weights_dtype=self.weights_dtype, + precision=self.precision + )(y) - self.txt_in = nn.Dense( - self.inner_dim, + #txt = self.txt_in(txt) + txt = nn.Dense( + inner_dim, dtype=self.dtype, param_dtype=self.weights_dtype, precision=self.precision - ) - - def __call__( - self, - img: Array, - img_ids: Array, - txt: Array, - txt_ids: Array, - timesteps: Array, - y: Array, - guidance: Array | None = None, - return_dict: bool = True, - train: bool = False): + )(txt) + + ids = jnp.concatenate((txt_ids, img_ids), axis=1) + + #pe_embedder + pe = EmbedND( + dim=pe_dim, + theta=10000, + axoes_dim=self.axes_dims_rope + )(ids) + + img, text = nn.scan( + DoubleStreamBlock, + variable_broadcast='params', + in_axes=0, out_axes=0, + split_rngs={'params' : False} + )( + hidden_size=self.hidden_size, + num_heads=self.num_attention_heads, + mlp_ratio=self.mlp_ratio, + attention_head_dim=self.attention_head_dim, + flash_min_seq_length=self.flash_min_seq_length, + flash_block_sizes=self.flash_block_sizes, + mesh=self.mesh, + dtype=self.dtype, + weights_dtype=self.weights_dtype, + precision=self.precision, + qkv_bias=self.qkv_bias, + attention_kernel=self.attention_kernel, + + )(img=img, txt=txt, vec=vec, pe=pe) - img = self.img_in(img) - vec = self.time_in(timestep_embedding(timesteps, 256)) + return img, text - if self.guidance_embeds: - if guidance is None: - raise ValueError( - "Didn't get guidance strength for guidance distrilled model." - ) - - vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) - - vec = vec + self.vector_in(y) - txt = self.txt_in(txt) \ No newline at end of file + # img = jnp.concatenate((txt, img), axis=1) + + # img = nn.scan( + # SingleStreamBlock, + # variable_broadcast='params', + # in_axes=0, out_axes=0, + # split_rngs={'params' : False} + # )(img, vec=vec, pe=pe) + + # img = img[:, txt.shape[1] :, ...] + + # img = LastLayer( + + # )(img, vec) # (N, T, patch_size ** 2 * out_channels) + # return img \ No newline at end of file From 8a0ede489fcc1dc790ec9d516537b50a230ec582 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Wed, 22 Jan 2025 22:04:37 +0000 Subject: [PATCH 06/35] forward pass for single double block. 
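
A quick reference for the rotary geometry this forward pass relies on (a
sketch; the (1, 24, 3) ids below are assumed shapes, not values from this
change): rope() emits one 2x2 rotation matrix per (position, frequency pair)
and EmbedND concatenates one bank per id axis, so axes_dims_rope = (16, 56, 56)
covers the full 128-dim attention head:

    ids = jnp.zeros((1, 24, 3))  # assumed (batch, tokens, id-axes) shape
    pe = EmbedND(dim=128, theta=10_000, axes_dim=[16, 56, 56])(ids)
    pe.shape  # (1, 1, 24, 64, 2, 2): 8 + 28 + 28 frequency pairs per token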
--- src/maxdiffusion/generate_flux.py | 30 +++++++++++++-- .../models/flux/modules/layers.py | 7 ++-- .../transformers/transformer_flux_flax.py | 38 ++++++++++++------- 3 files changed, 55 insertions(+), 20 deletions(-) diff --git a/src/maxdiffusion/generate_flux.py b/src/maxdiffusion/generate_flux.py index bab30902d..995a86d49 100644 --- a/src/maxdiffusion/generate_flux.py +++ b/src/maxdiffusion/generate_flux.py @@ -85,6 +85,7 @@ def prepare_latents( latents = pack_latents(latents, batch_size, num_channels_latents, height, width) latent_image_ids = prepare_latent_image_ids(height // 2, width // 2) + latent_image_ids = jnp.tile(latent_image_ids, (batch_size, 1, 1)) return latents, latent_image_ids @@ -179,7 +180,7 @@ def encode_prompt( ) prompt_embeds = jnp.asarray(prompt_embeds.detach().numpy()) - text_ids = jnp.zeros((prompt_embeds.shape[1], 3)).astype(jnp.bfloat16) + text_ids = jnp.zeros((prompt_embeds.shape[0], prompt_embeds.shape[1], 3)).astype(jnp.bfloat16) return prompt_embeds, pooled_prompt_embeds, text_ids def run(config): @@ -187,7 +188,7 @@ def run(config): rng = jax.random.PRNGKey(config.seed) - per_host_number_of_images = config.per_device_batch_size * jax.local_device_count() + per_host_number_of_images = 1#config.per_device_batch_size * jax.local_device_count() # LOAD VAE @@ -258,15 +259,38 @@ def run(config): clip_text_encoder=clip_text_encoder, t5_tokenizer=t5_tokenizer, t5_text_encoder=t5_encoder_pt, + num_images_per_prompt=per_host_number_of_images ) + def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timesteps, guidance, pooled_prompt_embeds): + print("latents.shape: ", latents.shape) + print("latent_image_ids.shape: ", latent_image_ids.shape) + print("text_ids.shape: ", text_ids.shape) + print("prompt_embeds: ", prompt_embeds.shape) + print("timesteps.shape: ", timesteps.shape) + print("guidance.shape: ", guidance.shape) + print("pooled_prompt_embeds.shape: ", pooled_prompt_embeds.shape) + + timesteps = jnp.asarray([1.0], dtype=jnp.bfloat16) + guidance = jnp.asarray([3.5], dtype=jnp.bfloat16) + validate_inputs( + latents, + latent_image_ids, + prompt_embeds, + text_ids, + timesteps, + guidance, + pooled_prompt_embeds + ) + transformer_params = transformer.init( {"params" : rng}, img=latents, img_ids=latent_image_ids, txt=prompt_embeds, txt_ids=text_ids, - timesteps=[1.0], + timesteps=timesteps, + guidance=guidance, y=pooled_prompt_embeds )["params"] breakpoint() diff --git a/src/maxdiffusion/models/flux/modules/layers.py b/src/maxdiffusion/models/flux/modules/layers.py index 95ff52097..4444ff669 100644 --- a/src/maxdiffusion/models/flux/modules/layers.py +++ b/src/maxdiffusion/models/flux/modules/layers.py @@ -81,7 +81,6 @@ def timestep_embedding( Returns: timestep embeddings. 
""" - breakpoint() t = time_factor * t half = dim // 2 @@ -188,7 +187,7 @@ class DoubleStreamBlock(nn.Module): @nn.compact def __call__(self, img: Array, txt: Array, vec: Array, pe: Array) -> tuple[Array, Array]: - + mlp_hidden_dim = int(self.hidden_size * self.mlp_ratio) img_mod1, img_mod2 = Modulation( @@ -267,7 +266,7 @@ def __call__(self, img: Array, txt: Array, vec: Array, pe: Array) -> tuple[Array q = jnp.concatenate((txt_q, img_q), axis=2) k = jnp.concatenate((txt_k, img_k), axis=2) v = jnp.concatenate((txt_v, img_v), axis=2) - + attn = AttentionOp( mesh=self.mesh, attention_kernel=self.attention_kernel, @@ -279,7 +278,7 @@ def __call__(self, img: Array, txt: Array, vec: Array, pe: Array) -> tuple[Array split_head_dim=True, flash_block_sizes=self.flash_block_sizes, dtype=self.dtype - )(q, k, v) + ).apply_attention(q, k, v) txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] diff --git a/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py b/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py index 4ea33882b..5b88fe41a 100644 --- a/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py +++ b/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py @@ -92,7 +92,6 @@ def __call__( ) )(img) - #vec = self.time_in(timestep_embedding(timesteps, 256)) vec = MLPEmbedder( hidden_dim=inner_dim, dtype=self.dtype, @@ -106,7 +105,6 @@ def __call__( "Didn't get guidance strength for guidance distrilled model." ) - #vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) vec = vec + MLPEmbedder( hidden_dim=inner_dim, dtype=self.dtype, @@ -114,7 +112,6 @@ def __call__( precision=self.precision )(timestep_embedding(guidance, 256)) - #vec = vec + self.vector_in(y) vec = vec + MLPEmbedder( hidden_dim=inner_dim, dtype=self.dtype, @@ -122,7 +119,6 @@ def __call__( precision=self.precision )(y) - #txt = self.txt_in(txt) txt = nn.Dense( inner_dim, dtype=self.dtype, @@ -136,16 +132,11 @@ def __call__( pe = EmbedND( dim=pe_dim, theta=10000, - axoes_dim=self.axes_dims_rope + axes_dim=self.axes_dims_rope )(ids) - img, text = nn.scan( - DoubleStreamBlock, - variable_broadcast='params', - in_axes=0, out_axes=0, - split_rngs={'params' : False} - )( - hidden_size=self.hidden_size, + img, text = DoubleStreamBlock( + hidden_size=inner_dim, num_heads=self.num_attention_heads, mlp_ratio=self.mlp_ratio, attention_head_dim=self.attention_head_dim, @@ -157,9 +148,30 @@ def __call__( precision=self.precision, qkv_bias=self.qkv_bias, attention_kernel=self.attention_kernel, - )(img=img, txt=txt, vec=vec, pe=pe) + # img, text = nn.scan( + # DoubleStreamBlock, + # variable_broadcast='params', + # in_axes=0, out_axes=0, + # split_rngs={'params' : False}, + # length=self.num_layers + # )( + # hidden_size=inner_dim, + # num_heads=self.num_attention_heads, + # mlp_ratio=self.mlp_ratio, + # attention_head_dim=self.attention_head_dim, + # flash_min_seq_length=self.flash_min_seq_length, + # flash_block_sizes=self.flash_block_sizes, + # mesh=self.mesh, + # dtype=self.dtype, + # weights_dtype=self.weights_dtype, + # precision=self.precision, + # qkv_bias=self.qkv_bias, + # attention_kernel=self.attention_kernel, + + # )(img=img, txt=txt, vec=vec, pe=pe) + return img, text # img = jnp.concatenate((txt, img), axis=1) From 9fe42ba1724d736c209cc9dba429a34c6da8d86b Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Thu, 23 Jan 2025 22:08:45 +0000 Subject: [PATCH 07/35] trying to use scan. 
--- .../models/flux/modules/layers.py | 24 +++- .../transformers/transformer_flux_flax.py | 122 +++++++++++++----- 2 files changed, 111 insertions(+), 35 deletions(-) diff --git a/src/maxdiffusion/models/flux/modules/layers.py b/src/maxdiffusion/models/flux/modules/layers.py index 4444ff669..ed593d7ed 100644 --- a/src/maxdiffusion/models/flux/modules/layers.py +++ b/src/maxdiffusion/models/flux/modules/layers.py @@ -36,6 +36,15 @@ def rope(pos: Array, dim: int, theta: int) -> Array: out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) return out.astype(jnp.float32) +def apply_rope(xq: Array, xk: Array, freqs_cis: Array) -> tuple[Array, Array]: + xq_ = xq.astype(jnp.float32).reshape(*xq.shape[:-1], -1, 1, 2) + xk_ = xk.astype(jnp.float32).reshape(*xk.shape[:-1], -1, 1, 2) + xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] + xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] + return xq_out.reshape(*xq.shape).astype(xq.dtype), xk_out.reshape(*xk.shape).astype( + xk.dtype + ) + class QKNorm(nn.Module): dtype: DTypeLike = jnp.bfloat16 weights_dtype: DTypeLike = jnp.bfloat16 @@ -120,7 +129,8 @@ def __call__(self, x: Array) -> Array: kernel_init=nn.with_logical_partitioning( nn.initializers.lecun_normal(), ("embed", "heads") - ) + ), + name="in_layer" )(x) x = nn.silu(x) x = nn.Dense( @@ -132,7 +142,8 @@ def __call__(self, x: Array) -> Array: kernel_init=nn.with_logical_partitioning( nn.initializers.lecun_normal(), ("heads", "embed") - ) + ), + name="out_layer" )(x) return x @@ -164,6 +175,7 @@ def __call__(self, vec: Array) -> tuple[ModulationOut, ModulationOut | None]: ("embed", "heads") ) )(nn.silu(vec)) + out = jnp.split(lin[:, None, :], multiplier, axis=-1) return ( @@ -186,7 +198,7 @@ class DoubleStreamBlock(nn.Module): attention_kernel: str = "dot_product" @nn.compact - def __call__(self, img: Array, txt: Array, vec: Array, pe: Array) -> tuple[Array, Array]: + def __call__(self, img: Array, txt: Array, vec: Array, pe: Array): mlp_hidden_dim = int(self.hidden_size * self.mlp_ratio) @@ -266,7 +278,8 @@ def __call__(self, img: Array, txt: Array, vec: Array, pe: Array) -> tuple[Array q = jnp.concatenate((txt_q, img_q), axis=2) k = jnp.concatenate((txt_k, img_k), axis=2) v = jnp.concatenate((txt_v, img_v), axis=2) - + q, k = apply_rope(q, k, pe) + attn = AttentionOp( mesh=self.mesh, attention_kernel=self.attention_kernel, @@ -372,5 +385,4 @@ def __call__(self, img: Array, txt: Array, vec: Array, pe: Array) -> tuple[Array )(txt) + txt_mod2.shift ) - return img, txt - + return img, txt, None \ No newline at end of file diff --git a/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py b/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py index 5b88fe41a..c9aa8dad8 100644 --- a/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py +++ b/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py @@ -30,6 +30,51 @@ class Identity(nn.Module): def __call__(self, x: Array) -> Array: return x +def scan_double_block_layers( + inner_dim, + num_heads, + mlp_ratio, + attention_head_dim, + flash_min_seq_length, + flash_block_sizes, + mesh, + dtype, + weights_dtype, + precision, + qkv_bias, + attention_kernel: str, + num_layers: int): + + scan_fn = nn.scan( + DoubleStreamBlock, + variable_broadcast='params', + in_axes=( + nn.broadcast, + nn.broadcast, + nn.broadcast, + nn.broadcast + ), + out_axes=( + nn.broadcast, + nn.broadcast, + ), + split_rngs={'params' : False}, + length=num_layers + ) + return 
scan_fn( + hidden_size=inner_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + attention_head_dim=attention_head_dim, + flash_min_seq_length=flash_min_seq_length, + flash_block_sizes=flash_block_sizes, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision, + qkv_bias=qkv_bias, + attention_kernel=attention_kernel) + class FluxTransformer2DModel(nn.Module, FlaxModelMixin, ConfigMixin): r""" The Tranformer model introduced in Flux. @@ -80,7 +125,6 @@ def __call__( inner_dim = self.num_attention_heads * self.attention_head_dim pe_dim = inner_dim // self.num_attention_heads - #img = self.img_in(img) img = nn.Dense( inner_dim, dtype=self.dtype, @@ -89,14 +133,16 @@ def __call__( kernel_init=nn.with_logical_partitioning( nn.initializers.lecun_normal(), ("embed", "heads") - ) + ), + name="img_in" )(img) vec = MLPEmbedder( hidden_dim=inner_dim, dtype=self.dtype, weights_dtype=self.weights_dtype, - precision=self.precision + precision=self.precision, + name="time_in" )(timestep_embedding(timesteps, 256)) if self.guidance_embeds: @@ -109,21 +155,28 @@ def __call__( hidden_dim=inner_dim, dtype=self.dtype, weights_dtype=self.weights_dtype, - precision=self.precision + precision=self.precision, + name="guidance_in" )(timestep_embedding(guidance, 256)) vec = vec + MLPEmbedder( hidden_dim=inner_dim, dtype=self.dtype, weights_dtype=self.weights_dtype, - precision=self.precision + precision=self.precision, + name="vector_in" )(y) txt = nn.Dense( inner_dim, dtype=self.dtype, param_dtype=self.weights_dtype, - precision=self.precision + precision=self.precision, + kernel_init=nn.with_logical_partitioning( + nn.initializers.lecun_normal(), + ("embed", "heads") + ), + name="text_in" )(txt) ids = jnp.concatenate((txt_ids, img_ids), axis=1) @@ -135,28 +188,7 @@ def __call__( axes_dim=self.axes_dims_rope )(ids) - img, text = DoubleStreamBlock( - hidden_size=inner_dim, - num_heads=self.num_attention_heads, - mlp_ratio=self.mlp_ratio, - attention_head_dim=self.attention_head_dim, - flash_min_seq_length=self.flash_min_seq_length, - flash_block_sizes=self.flash_block_sizes, - mesh=self.mesh, - dtype=self.dtype, - weights_dtype=self.weights_dtype, - precision=self.precision, - qkv_bias=self.qkv_bias, - attention_kernel=self.attention_kernel, - )(img=img, txt=txt, vec=vec, pe=pe) - - # img, text = nn.scan( - # DoubleStreamBlock, - # variable_broadcast='params', - # in_axes=0, out_axes=0, - # split_rngs={'params' : False}, - # length=self.num_layers - # )( + # img, text = DoubleStreamBlock( # hidden_size=inner_dim, # num_heads=self.num_attention_heads, # mlp_ratio=self.mlp_ratio, @@ -169,10 +201,42 @@ def __call__( # precision=self.precision, # qkv_bias=self.qkv_bias, # attention_kernel=self.attention_kernel, + # name="double_blocks_0" + # )(img=img, txt=txt, vec=vec, pe=pe) + # img, text = DoubleStreamBlock( + # hidden_size=inner_dim, + # num_heads=self.num_attention_heads, + # mlp_ratio=self.mlp_ratio, + # attention_head_dim=self.attention_head_dim, + # flash_min_seq_length=self.flash_min_seq_length, + # flash_block_sizes=self.flash_block_sizes, + # mesh=self.mesh, + # dtype=self.dtype, + # weights_dtype=self.weights_dtype, + # precision=self.precision, + # qkv_bias=self.qkv_bias, + # attention_kernel=self.attention_kernel, + # name="double_blocks_1" # )(img=img, txt=txt, vec=vec, pe=pe) - return img, text + img, txt, _ = scan_double_block_layers( + inner_dim=inner_dim, + num_heads=self.num_attention_heads, + mlp_ratio=self.mlp_ratio, + attention_head_dim=self.attention_head_dim, + 
flash_min_seq_length=self.flash_min_seq_length, + flash_block_sizes=self.flash_block_sizes, + mesh=self.mesh, + dtype=self.dtype, + weights_dtype=self.weights_dtype, + precision=self.precision, + qkv_bias=self.qkv_bias, + attention_kernel=self.attention_kernel, + num_layers=self.num_layers + )(img, txt, vec, pe) + + return img, txt # img = jnp.concatenate((txt, img), axis=1) From 7e79e05c22a058d03589fc4eca5daa1dcdce3847 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Fri, 24 Jan 2025 01:39:53 +0000 Subject: [PATCH 08/35] add single stream block --- .../models/flux/modules/layers.py | 83 ++++++++++++++++++- .../transformers/transformer_flux_flax.py | 43 +++++++--- 2 files changed, 113 insertions(+), 13 deletions(-) diff --git a/src/maxdiffusion/models/flux/modules/layers.py b/src/maxdiffusion/models/flux/modules/layers.py index ed593d7ed..3fa6d44e2 100644 --- a/src/maxdiffusion/models/flux/modules/layers.py +++ b/src/maxdiffusion/models/flux/modules/layers.py @@ -183,6 +183,87 @@ def __call__(self, vec: Array) -> tuple[ModulationOut, ModulationOut | None]: ModulationOut(*out[3:] if self.double else None) ) +class SingleStreamBlock(nn.Module): + hidden_size: int + num_heads: int + mlp_ratio: float + qk_scale: float | None = None + attention_head_dim: int = 128 + flash_min_seq_length: int = 4096 + flash_block_sizes: BlockSizes = None + mesh: jax.sharding.Mesh = None + dtype: jnp.dtype = jnp.float32 + weights_dtype: jnp.dtype = jnp.float32 + precision: jax.lax.Precision = None + attention_kernel: str = "dot_product" + + @nn.compact + def __call__(self, x: Array, vec: Array, pe: Array) -> Array: + mlp_hidden_dim = int(self.hidden_size * self.mlp_ratio) + + mod, _ = Modulation( + self.hidden_size, + double=False, + dtype=self.dtype, + weights_dtype=self.weights_dtype, + precision=self.precision + )(vec) + x_mod = (1 + mod.scale) * nn.LayerNorm( + self.hidden_size, + use_scale=False, + use_bias=False, + epsilon=1e-6, + dtype=self.dtype, + param_dtype=self.weights_dtype + )(x) + mod.shift + + x_mod = nn.Dense( + self.hidden_size * 3 + mlp_hidden_dim, + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, + kernel_init=nn.with_logical_partitioning( + nn.initializers.lecun_normal(), + ("embed", "heads") + ), + name="linear1" + )(x_mod) + + qkv, mlp = jnp.split(x_mod, [3 * self.hidden_size], axis=-1) + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + q, k = QKNorm( + dtype=self.dtype, + weights_dtype=self.weights_dtype + )(q, k, v) + + q, k = apply_rope(q, k, pe) + #compute attention + attn = AttentionOp( + mesh=self.mesh, + attention_kernel=self.attention_kernel, + scale=self.attention_head_dim**-0.5, + heads=self.num_heads, + dim_head=self.attention_head_dim, + flash_min_seq_length=self.flash_min_seq_length, + use_memory_efficient_attention=False, + split_head_dim=True, + flash_block_sizes=self.flash_block_sizes, + dtype=self.dtype + ).apply_attention(q, k, v) + + output = nn.Dense( + self.hidden_size, + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, + kernel_init=nn.with_logical_partitioning( + nn.initializers.lecun_normal(), + ("embed", "heads") + ), + name="linear2" + )(jnp.concatenate((attn, nn.genu(mlp)), 2)) + return x + mod.gate * output + class DoubleStreamBlock(nn.Module): hidden_size: int num_heads: int @@ -385,4 +466,4 @@ def __call__(self, img: Array, txt: Array, vec: Array, pe: Array): )(txt) + txt_mod2.shift ) - return img, txt, None \ No newline at end of file + return img, txt \ No newline at 
end of file diff --git a/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py b/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py index c9aa8dad8..bd052a170 100644 --- a/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py +++ b/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py @@ -21,7 +21,13 @@ import flax.linen as nn from chex import Array -from ..modules.layers import timestep_embedding, MLPEmbedder, EmbedND, DoubleStreamBlock +from ..modules.layers import ( + timestep_embedding, + MLPEmbedder, + EmbedND, + DoubleStreamBlock, + SingleStreamBlock +) from ...modeling_flax_utils import FlaxModelMixin from ....configuration_utils import ConfigMixin, flax_register_to_config from ....common_types import BlockSizes @@ -49,15 +55,11 @@ def scan_double_block_layers( DoubleStreamBlock, variable_broadcast='params', in_axes=( - nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast ), - out_axes=( - nn.broadcast, - nn.broadcast, - ), + out_axes=nn.broadcast, split_rngs={'params' : False}, length=num_layers ) @@ -187,8 +189,8 @@ def __call__( theta=10000, axes_dim=self.axes_dims_rope )(ids) - - # img, text = DoubleStreamBlock( + # breakpoint() + # img, txt = DoubleStreamBlock( # hidden_size=inner_dim, # num_heads=self.num_attention_heads, # mlp_ratio=self.mlp_ratio, @@ -203,8 +205,8 @@ def __call__( # attention_kernel=self.attention_kernel, # name="double_blocks_0" # )(img=img, txt=txt, vec=vec, pe=pe) - - # img, text = DoubleStreamBlock( + # breakpoint() + # img, txt = DoubleStreamBlock( # hidden_size=inner_dim, # num_heads=self.num_attention_heads, # mlp_ratio=self.mlp_ratio, @@ -219,8 +221,8 @@ def __call__( # attention_kernel=self.attention_kernel, # name="double_blocks_1" # )(img=img, txt=txt, vec=vec, pe=pe) - - img, txt, _ = scan_double_block_layers( + # breakpoint() + img, txt = scan_double_block_layers( inner_dim=inner_dim, num_heads=self.num_attention_heads, mlp_ratio=self.mlp_ratio, @@ -236,6 +238,23 @@ def __call__( num_layers=self.num_layers )(img, txt, vec, pe) + # SingleStreamBlock( + # hidden_size=inner_dim, + # num_heads=self.num_attention_heads, + # mlp_ratio=self.mlp_ratio, + # attention_head_dim=self.attention_head_dim, + # flash_min_seq_length=self.flash_min_seq_length, + # flash_block_sizes=self.flash_block_sizes, + # mesh=self.mesh, + # dtype=self.dtype, + # weights_dtype=self.weights_dtype, + # precision=self.precision, + # attention_kernel=self.attention_kernel + # )(img, vec, pe) + + img = img[:, txt.shape[1] :, ...] 
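+        # the reference flux flow packs the txt tokens in front of img before the
+        # single-stream blocks; this slice assumes that layout and keeps only the image tokens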
+ + return img, txt # img = jnp.concatenate((txt, img), axis=1) From 6641dda43f594ede042d5de39ac77bbe5aa52212 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Wed, 29 Jan 2025 18:41:07 +0000 Subject: [PATCH 09/35] finish transformer --- src/maxdiffusion/configs/base_flux.yml | 7 +- src/maxdiffusion/generate_flux.py | 96 ++++++---- .../models/flux/modules/layers.py | 64 ++++++- .../transformers/transformer_flux_flax.py | 178 +++++++++++++----- 4 files changed, 252 insertions(+), 93 deletions(-) diff --git a/src/maxdiffusion/configs/base_flux.yml b/src/maxdiffusion/configs/base_flux.yml index 9fe5ef5a6..e6d9f15a9 100644 --- a/src/maxdiffusion/configs/base_flux.yml +++ b/src/maxdiffusion/configs/base_flux.yml @@ -24,6 +24,9 @@ save_config_to_gcs: False log_period: 100 pretrained_model_name_or_path: 'black-forest-labs/FLUX.1-dev' +clip_model_name_or_path: 'ariG23498/clip-vit-large-patch14-text-flax' +t5xxl_model_name_or_path: 'ariG23498/t5-v1-1-xxl-flax' + unet_checkpoint: '' revision: 'refs/pr/95' # This will convert the weights to this dtype. @@ -41,7 +44,7 @@ precision: "DEFAULT" # Set true to load weights from pytorch from_pt: False split_head_dim: True -attention: 'dot_product' # Supported attention: dot_product, flash +attention: 'flash' # Supported attention: dot_product, flash flash_block_sizes: {} # GroupNorm groups norm_num_groups: 32 @@ -171,7 +174,7 @@ max_train_steps: 200 num_train_epochs: 1 seed: 0 output_dir: 'sdxl-model-finetuned' -per_device_batch_size: 2 +per_device_batch_size: 1 warmup_steps_fraction: 0.0 learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps. diff --git a/src/maxdiffusion/generate_flux.py b/src/maxdiffusion/generate_flux.py index 995a86d49..4d183c20f 100644 --- a/src/maxdiffusion/generate_flux.py +++ b/src/maxdiffusion/generate_flux.py @@ -16,22 +16,30 @@ from typing import Any, Callable, Dict, List, Optional, Union, Sequence from absl import app - +import functools import numpy as np import jax +from jax.sharding import Mesh, PositionalSharding import jax.numpy as jnp from chex import Array from transformers import ( CLIPTokenizer, FlaxCLIPTextModel, T5TokenizerFast, - T5EncoderModel + T5EncoderModel, + FlaxT5EncoderModel ) from maxdiffusion import FlaxAutoencoderKL from maxdiffusion.models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel - from maxdiffusion import pyconfig +from max_utils import ( + device_put_replicated, + get_memory_allocations, + create_device_mesh, + get_flash_block_sizes, + get_precision +) def prepare_latent_image_ids(height, width): latent_image_ids = jnp.zeros((height, width, 3)) @@ -133,19 +141,17 @@ def get_t5_prompt_embeds( truncation=True, return_length=False, return_overflowing_tokens=False, - return_tensors="pt" + return_tensors="np" ) text_input_ids = text_inputs.input_ids - prompt_embeds = text_encoder(text_input_ids, output_hidden_states=False)[0] dtype = text_encoder.dtype - prompt_embeds = prompt_embeds.to(dtype=dtype) + prompt_embeds = prompt_embeds.astype(dtype) _, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_embeds = jnp.tile(prompt_embeds, (1, num_images_per_prompt, 1)) + prompt_embeds = jnp.reshape(prompt_embeds, (batch_size * num_images_per_prompt, seq_len, -1)) return prompt_embeds @@ -178,7 
+184,6 @@ def encode_prompt( tokenizer=t5_tokenizer, text_encoder=t5_text_encoder ) - prompt_embeds = jnp.asarray(prompt_embeds.detach().numpy()) text_ids = jnp.zeros((prompt_embeds.shape[0], prompt_embeds.shape[1], 3)).astype(jnp.bfloat16) return prompt_embeds, pooled_prompt_embeds, text_ids @@ -186,7 +191,9 @@ def encode_prompt( def run(config): from maxdiffusion.models.flux.util import load_flow_model - rng = jax.random.PRNGKey(config.seed) + rng = jax.random.key(config.seed) + devices_array = create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) per_host_number_of_images = 1#config.per_device_batch_size * jax.local_device_count() @@ -201,11 +208,18 @@ def run(config): ) vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1) - # LOAD UNET - + # LOAD TRANSFORMER + flash_block_sizes = get_flash_block_sizes(config) transformer = FluxTransformer2DModel.from_config( config.pretrained_model_name_or_path, - subfolder="transformer" + subfolder="transformer", + mesh=mesh, + split_head_dim=config.split_head_dim, + attention_kernel=config.attention, + flash_block_sizes=flash_block_sizes, + dtype=config.activations_dtype, + weights_dtype=config.weights_dtype, + precision=get_precision(config) ) num_channels_latents = transformer.in_channels // 4 @@ -242,34 +256,40 @@ def run(config): dtype=config.weights_dtype ) - t5_encoder_pt = T5EncoderModel.from_pretrained( - config.pretrained_model_name_or_path, - subfolder="text_encoder_2", + t5_encoder = FlaxT5EncoderModel.from_pretrained( + config.clip_model_name_or_path, + dtype=config.weights_dtype ) - t5_tokenizer = T5TokenizerFast.from_pretrained( config.pretrained_model_name_or_path, subfolder="tokenizer_2", ) + encoders_sharding = PositionalSharding(devices_array).replicate() + partial_device_put_replicated = functools.partial(device_put_replicated, sharding=encoders_sharding) + clip_text_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), clip_text_encoder.params) + clip_text_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, clip_text_encoder.params) + t5_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), t5_encoder.params) + t5_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, t5_encoder.params) + prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt( prompt=config.prompt, prompt_2=config.prompt_2, clip_tokenizer=clip_tokenizer, clip_text_encoder=clip_text_encoder, t5_tokenizer=t5_tokenizer, - t5_text_encoder=t5_encoder_pt, + t5_text_encoder=t5_encoder, num_images_per_prompt=per_host_number_of_images ) def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timesteps, guidance, pooled_prompt_embeds): - print("latents.shape: ", latents.shape) - print("latent_image_ids.shape: ", latent_image_ids.shape) - print("text_ids.shape: ", text_ids.shape) - print("prompt_embeds: ", prompt_embeds.shape) - print("timesteps.shape: ", timesteps.shape) - print("guidance.shape: ", guidance.shape) - print("pooled_prompt_embeds.shape: ", pooled_prompt_embeds.shape) + print("latents.shape: ", latents.shape, latents.dtype) + print("latent_image_ids.shape: ", latent_image_ids.shape, latent_image_ids.dtype) + print("text_ids.shape: ", text_ids.shape, text_ids.dtype) + print("prompt_embeds: ", prompt_embeds.shape, prompt_embeds.dtype) + print("timesteps.shape: ", timesteps.shape, timesteps.dtype) + print("guidance.shape: ", guidance.shape, guidance.dtype) + print("pooled_prompt_embeds.shape: ", pooled_prompt_embeds.shape, 
pooled_prompt_embeds.dtype) timesteps = jnp.asarray([1.0], dtype=jnp.bfloat16) guidance = jnp.asarray([3.5], dtype=jnp.bfloat16) @@ -282,17 +302,19 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep guidance, pooled_prompt_embeds ) - - transformer_params = transformer.init( - {"params" : rng}, - img=latents, - img_ids=latent_image_ids, - txt=prompt_embeds, - txt_ids=text_ids, - timesteps=timesteps, - guidance=guidance, - y=pooled_prompt_embeds - )["params"] + get_memory_allocations() + transformer_params = transformer.init_weights(rng, True) + # transformer_params = transformer.init( + # {"params" : rng}, + # img=latents, + # img_ids=latent_image_ids, + # txt=prompt_embeds, + # txt_ids=text_ids, + # timesteps=timesteps, + # guidance=guidance, + # y=pooled_prompt_embeds + # )["params"] + get_memory_allocations() breakpoint() diff --git a/src/maxdiffusion/models/flux/modules/layers.py b/src/maxdiffusion/models/flux/modules/layers.py index 3fa6d44e2..574b8a747 100644 --- a/src/maxdiffusion/models/flux/modules/layers.py +++ b/src/maxdiffusion/models/flux/modules/layers.py @@ -177,10 +177,9 @@ def __call__(self, vec: Array) -> tuple[ModulationOut, ModulationOut | None]: )(nn.silu(vec)) out = jnp.split(lin[:, None, :], multiplier, axis=-1) - return ( ModulationOut(*out[:3]), - ModulationOut(*out[3:] if self.double else None) + ModulationOut(*out[3:]) if self.double else None ) class SingleStreamBlock(nn.Module): @@ -209,7 +208,6 @@ def __call__(self, x: Array, vec: Array, pe: Array) -> Array: precision=self.precision )(vec) x_mod = (1 + mod.scale) * nn.LayerNorm( - self.hidden_size, use_scale=False, use_bias=False, epsilon=1e-6, @@ -261,7 +259,7 @@ def __call__(self, x: Array, vec: Array, pe: Array) -> Array: ("embed", "heads") ), name="linear2" - )(jnp.concatenate((attn, nn.genu(mlp)), 2)) + )(jnp.concatenate((attn, nn.gelu(mlp)), 2)) return x + mod.gate * output class DoubleStreamBlock(nn.Module): @@ -279,7 +277,7 @@ class DoubleStreamBlock(nn.Module): attention_kernel: str = "dot_product" @nn.compact - def __call__(self, img: Array, txt: Array, vec: Array, pe: Array): + def __call__(self, img: Array, txt: Array, vec: Array, pe: Array) -> tuple[Array, Array]: mlp_hidden_dim = int(self.hidden_size * self.mlp_ratio) @@ -422,7 +420,7 @@ def __call__(self, img: Array, txt: Array, vec: Array, pe: Array): ) # calculate the txt blocks - txt = txt + txt_mod1.gate * nn.Dense( + txt_proj = nn.Dense( self.hidden_size, use_bias=True, dtype=self.dtype, @@ -433,6 +431,8 @@ def __call__(self, img: Array, txt: Array, vec: Array, pe: Array): ("heads", "embed") ), )(txt_attn) + txt = txt + txt_mod1.gate * txt_proj + txt = txt + txt_mod2.gate * nn.Sequential( [ nn.Dense( @@ -466,4 +466,54 @@ def __call__(self, img: Array, txt: Array, vec: Array, pe: Array): )(txt) + txt_mod2.shift ) - return img, txt \ No newline at end of file + return img, txt + +class LastLayer(nn.Module): + hidden_size: int + patch_size: int + out_channels: int + dtype: jnp.dtype = jnp.float32 + weights_dtype: jnp.dtype = jnp.float32 + precision: jax.lax.Precision = None + + @nn.compact + def __call__(self, x: Array, vec: Array) -> Array: + shift, scale = jnp.split( + nn.Sequential( + [ + nn.silu, + nn.Dense( + 2 * self.hidden_size, + use_bias=True, + param_dtype=self.weights_dtype, + dtype=self.dtype, + precision=self.precision, + kernel_init=nn.with_logical_partitioning( + nn.initializers.lecun_normal(), + ("embed", "heads") + ) + ) + ] + )(vec), 2, axis=1 + ) + norm_final = nn.LayerNorm( + epsilon=1e-6, + 
use_scale=False, + use_bias=False, + param_dtype=self.weights_dtype, + name="norm_final" + )(x) + x = (1 + scale[:, None, :]) * norm_final + shift[:, None, :] + x = nn.Dense( + self.patch_size * self.patch_size * self.out_channels, + use_bias=True, + param_dtype=self.weights_dtype, + dtype=self.dtype, + precision=self.precision, + kernel_init=nn.with_logical_partitioning( + nn.initializers.lecun_normal(), + ("heads", "embed") + ), + name="linear" + ) + return x \ No newline at end of file diff --git a/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py b/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py index bd052a170..fe54bb6d3 100644 --- a/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py +++ b/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py @@ -16,6 +16,7 @@ from typing import Dict, Optional, Tuple, Union +from einops import repeat, rearrange import jax import jax.numpy as jnp import flax.linen as nn @@ -26,7 +27,8 @@ MLPEmbedder, EmbedND, DoubleStreamBlock, - SingleStreamBlock + SingleStreamBlock, + LastLayer ) from ...modeling_flax_utils import FlaxModelMixin from ....configuration_utils import ConfigMixin, flax_register_to_config @@ -205,9 +207,46 @@ def __call__( # attention_kernel=self.attention_kernel, # name="double_blocks_0" # )(img=img, txt=txt, vec=vec, pe=pe) + # # breakpoint() + for _ in range(self.num_layers): + img, txt = DoubleStreamBlock( + hidden_size=inner_dim, + num_heads=self.num_attention_heads, + mlp_ratio=self.mlp_ratio, + attention_head_dim=self.attention_head_dim, + flash_min_seq_length=self.flash_min_seq_length, + flash_block_sizes=self.flash_block_sizes, + mesh=self.mesh, + dtype=self.dtype, + weights_dtype=self.weights_dtype, + precision=self.precision, + qkv_bias=self.qkv_bias, + attention_kernel=self.attention_kernel, + )(img=img, txt=txt, vec=vec, pe=pe) + # img, txt = nn.Sequential( + # [ + # *[ + # DoubleStreamBlock( + # hidden_size=inner_dim, + # num_heads=self.num_attention_heads, + # mlp_ratio=self.mlp_ratio, + # attention_head_dim=self.attention_head_dim, + # flash_min_seq_length=self.flash_min_seq_length, + # flash_block_sizes=self.flash_block_sizes, + # mesh=self.mesh, + # dtype=self.dtype, + # weights_dtype=self.weights_dtype, + # precision=self.precision, + # qkv_bias=self.qkv_bias, + # attention_kernel=self.attention_kernel, + # )(img=img, txt=txt, vec=vec, pe=pe) + # for _ in range(2) + # ] + # ] + # ) # breakpoint() - # img, txt = DoubleStreamBlock( - # hidden_size=inner_dim, + # img, txt = scan_double_block_layers( + # inner_dim=inner_dim, # num_heads=self.num_attention_heads, # mlp_ratio=self.mlp_ratio, # attention_head_dim=self.attention_head_dim, @@ -219,56 +258,101 @@ def __call__( # precision=self.precision, # qkv_bias=self.qkv_bias, # attention_kernel=self.attention_kernel, - # name="double_blocks_1" - # )(img=img, txt=txt, vec=vec, pe=pe) - # breakpoint() - img, txt = scan_double_block_layers( - inner_dim=inner_dim, - num_heads=self.num_attention_heads, - mlp_ratio=self.mlp_ratio, - attention_head_dim=self.attention_head_dim, - flash_min_seq_length=self.flash_min_seq_length, - flash_block_sizes=self.flash_block_sizes, - mesh=self.mesh, - dtype=self.dtype, - weights_dtype=self.weights_dtype, - precision=self.precision, - qkv_bias=self.qkv_bias, - attention_kernel=self.attention_kernel, - num_layers=self.num_layers - )(img, txt, vec, pe) - - # SingleStreamBlock( - # hidden_size=inner_dim, - # num_heads=self.num_attention_heads, - # mlp_ratio=self.mlp_ratio, - # 
attention_head_dim=self.attention_head_dim, - # flash_min_seq_length=self.flash_min_seq_length, - # flash_block_sizes=self.flash_block_sizes, - # mesh=self.mesh, - # dtype=self.dtype, - # weights_dtype=self.weights_dtype, - # precision=self.precision, - # attention_kernel=self.attention_kernel - # )(img, vec, pe) + # num_layers=self.num_layers + # )(img, txt, vec, pe) + img = jnp.concatenate((txt, img), axis=1) + for _ in range(self.num_single_layers): + img, SingleStreamBlock( + hidden_size=inner_dim, + num_heads=self.num_attention_heads, + mlp_ratio=self.mlp_ratio, + attention_head_dim=self.attention_head_dim, + flash_min_seq_length=self.flash_min_seq_length, + flash_block_sizes=self.flash_block_sizes, + mesh=self.mesh, + dtype=self.dtype, + weights_dtype=self.weights_dtype, + precision=self.precision, + attention_kernel=self.attention_kernel + )(img, vec, pe) img = img[:, txt.shape[1] :, ...] + LastLayer( + hidden_size=inner_dim, + patch_size=1, + out_channels=out_channels, + dtype=self.dtype, + weights_dtype=self.weights_dtype, + precision=self.precision, + name="final_layer" + ) - return img, txt + return img + + def init_weights(self, rngs, eval_only=True): + scale_factor = 16 + resolution = 1024 + num_devices = len(jax.devices()) + batch_size = 1 * num_devices + batch_image_shape = ( + batch_size, + 16, # 16 to match jflux.get_noise + 2 * resolution // scale_factor, + 2 * resolution // scale_factor, + ) + # bs, encoder_input, seq_length + text_shape = ( + batch_size, + 256, + 4096, # Sequence length of text encoder, how to get this programmatically? + ) + text_ids_shape = ( + batch_size, + 256, + 3, # Hardcoded to match jflux.prepare + ) + vec_shape = ( + batch_size, + 768, # Sequence length of clip, how to get this programmatically? + ) + img = jnp.zeros(batch_image_shape, dtype=self.dtype) + bs, c, h, w = img.shape + img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + img_ids = jnp.zeros((h // 2, w // 2, 3), dtype=self.dtype) + img_ids = img_ids.at[..., 1].set(jnp.arange(h // 2)[:, None]) + img_ids = img_ids.at[..., 2].set(jnp.arange(w // 2)[None, :]) + img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) - # img = jnp.concatenate((txt, img), axis=1) + txt = jnp.zeros(text_shape, dtype=self.dtype) + txt_ids = jnp.zeros(text_ids_shape, dtype=self.dtype) - # img = nn.scan( - # SingleStreamBlock, - # variable_broadcast='params', - # in_axes=0, out_axes=0, - # split_rngs={'params' : False} - # )(img, vec=vec, pe=pe) + t_vec = jnp.full(bs, 0, dtype=self.dtype) - # img = img[:, txt.shape[1] :, ...] + vec = jnp.zeros(vec_shape, dtype=self.dtype) - # img = LastLayer( + guidance_vec = jnp.full(bs, 4.0, dtype=self.dtype) - # )(img, vec) # (N, T, patch_size ** 2 * out_channels) - # return img \ No newline at end of file + if eval_only: + return jax.eval_shape( + self.init, + rngs, + img=img, + img_ids=img_ids, + txt=txt, + txt_ids=txt_ids, + y=vec, + timesteps=t_vec, + guidance=guidance_vec, + )["params"] + else: + return self.init( + rngs, + hidden_states=img, + img_ids=img_ids, + encoder_hidden_states=txt, + txt_ids=txt_ids, + y=vec, + timestep=t_vec, + guidance=guidance_vec, + )["params"] \ No newline at end of file From d37a278ae6388e8b00d658cdeb95ed1bb2062bea Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Thu, 30 Jan 2025 00:50:44 +0000 Subject: [PATCH 10/35] convert pt weights to flax and load transformer state. 
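The port walks the safetensors checkpoint and renames each PyTorch key onto the Flax module tree before reshaping the tensor against the expected shapes. A minimal sketch of the renaming idea, assuming dot-separated PyTorch keys (`to_flax_path` is an illustrative helper, not the code below; the actual conversion goes through `rename_key` and `rename_key_and_reshape_tensor`, which also handle kernel transposition):

    # sketch: "double_blocks.0.img_attn.qkv.weight" -> ("double_blocks_0", "img_attn_qkv", "kernel")
    def to_flax_path(pt_key: str) -> tuple:
        key = pt_key
        for prefix in ("double_blocks", "single_blocks"):
            key = key.replace(prefix + ".", prefix + "_")  # block index joins the module name
        for mod in ("img_attn", "img_mlp", "txt_attn", "txt_mlp"):
            key = key.replace(mod + ".", mod + "_")  # match the name=... given to each submodule
        key = key.replace(".weight", ".kernel")  # Flax Dense calls its weight "kernel"
        return tuple(key.split("."))

    assert to_flax_path("double_blocks.0.img_attn.qkv.weight") == (
        "double_blocks_0", "img_attn_qkv", "kernel")

The expected shapes come from init_weights(..., eval_only=True), which goes through jax.eval_shape, so no device buffers are allocated just to learn the pytree layout.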
--- src/maxdiffusion/generate_flux.py | 55 +++++++------ .../models/flux/modules/layers.py | 77 ++++++++++++------- .../transformers/transformer_flux_flax.py | 28 +++---- src/maxdiffusion/models/flux/util.py | 74 ++++++++++++++---- 4 files changed, 154 insertions(+), 80 deletions(-) diff --git a/src/maxdiffusion/generate_flux.py b/src/maxdiffusion/generate_flux.py index 4d183c20f..6e95deb20 100644 --- a/src/maxdiffusion/generate_flux.py +++ b/src/maxdiffusion/generate_flux.py @@ -38,7 +38,8 @@ get_memory_allocations, create_device_mesh, get_flash_block_sizes, - get_precision + get_precision, + setup_initial_state ) def prepare_latent_image_ids(height, width): @@ -195,7 +196,7 @@ def run(config): devices_array = create_device_mesh(config) mesh = Mesh(devices_array, config.mesh_axes) - per_host_number_of_images = 1#config.per_device_batch_size * jax.local_device_count() + per_host_number_of_images = config.per_device_batch_size * jax.local_device_count() # LOAD VAE @@ -233,16 +234,6 @@ def run(config): rng=rng ) - #load_flow_model("flux-dev", "cpu") - - # transformer, params = FluxTransformer2DModel.from_pretrained( - # config.pretrained_model_name_or_path, - # subfolder="text_encoder_2", - # from_pt=True, - # dtype=config.weights_dtype - # ) - - # LOAD TEXT ENCODERS - t5 on cpu clip_text_encoder = FlaxCLIPTextModel.from_pretrained( config.pretrained_model_name_or_path, @@ -303,17 +294,35 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep pooled_prompt_embeds ) get_memory_allocations() - transformer_params = transformer.init_weights(rng, True) - # transformer_params = transformer.init( - # {"params" : rng}, - # img=latents, - # img_ids=latent_image_ids, - # txt=prompt_embeds, - # txt_ids=text_ids, - # timesteps=timesteps, - # guidance=guidance, - # y=pooled_prompt_embeds - # )["params"] + # evaluate shapes + transformer_eval_params = transformer.init_weights(rngs=rng, max_sequence_length=512, eval_only=True) + + # loads pretrained weights + transformer_params = load_flow_model("flux-dev", transformer_eval_params, "cpu") + get_memory_allocations() + # create transformer state + weights_init_fn = functools.partial(transformer.init_weights, rngs=rng, max_sequence_length=512, eval_only=False) + transformer_state, transformer_state_shardings = setup_initial_state( + model=transformer, + tx=None, + config=config, + mesh=mesh, + weights_init_fn=weights_init_fn, + model_params=None, + training=False + ) + breakpoint() + transformer_state = transformer_state.replace(params=transformer_params) + img = transformer.apply( + {"params" : transformer_state.params}, + img=latents, + img_ids=latent_image_ids, + txt=prompt_embeds, + txt_ids=text_ids, + timesteps=timesteps, + guidance=guidance, + y=pooled_prompt_embeds + ) get_memory_allocations() breakpoint() diff --git a/src/maxdiffusion/models/flux/modules/layers.py b/src/maxdiffusion/models/flux/modules/layers.py index 574b8a747..91562d406 100644 --- a/src/maxdiffusion/models/flux/modules/layers.py +++ b/src/maxdiffusion/models/flux/modules/layers.py @@ -53,11 +53,13 @@ class QKNorm(nn.Module): def __call__(self, q: Array, k: Array, v: Array) -> tuple[Array, Array]: q = nn.RMSNorm( dtype=self.dtype, - param_dtype=self.weights_dtype + param_dtype=self.weights_dtype, + name="query_norm" )(q) k = nn.RMSNorm( dtype=self.dtype, - param_dtype=self.weights_dtype + param_dtype=self.weights_dtype, + name="key_norm" )(k) return q, k @@ -173,7 +175,8 @@ def __call__(self, vec: Array) -> tuple[ModulationOut, ModulationOut | None]: 
kernel_init=nn.with_logical_partitioning( nn.initializers.lecun_normal(), ("embed", "heads") - ) + ), + name="lin" )(nn.silu(vec)) out = jnp.split(lin[:, None, :], multiplier, axis=-1) @@ -205,14 +208,16 @@ def __call__(self, x: Array, vec: Array, pe: Array) -> Array: double=False, dtype=self.dtype, weights_dtype=self.weights_dtype, - precision=self.precision + precision=self.precision, + name="modulation" )(vec) x_mod = (1 + mod.scale) * nn.LayerNorm( use_scale=False, use_bias=False, epsilon=1e-6, dtype=self.dtype, - param_dtype=self.weights_dtype + param_dtype=self.weights_dtype, + name="pre_norm" )(x) + mod.shift x_mod = nn.Dense( @@ -231,7 +236,8 @@ def __call__(self, x: Array, vec: Array, pe: Array) -> Array: q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) q, k = QKNorm( dtype=self.dtype, - weights_dtype=self.weights_dtype + weights_dtype=self.weights_dtype, + name="norm" )(q, k, v) q, k = apply_rope(q, k, pe) @@ -286,7 +292,8 @@ def __call__(self, img: Array, txt: Array, vec: Array, pe: Array) -> tuple[Array double=True, dtype=self.dtype, weights_dtype=self.weights_dtype, - precision=self.precision + precision=self.precision, + name="img_mod" )(vec) txt_mod1, txt_mod2 = Modulation( @@ -294,7 +301,8 @@ def __call__(self, img: Array, txt: Array, vec: Array, pe: Array) -> tuple[Array double=True, dtype=self.dtype, weights_dtype=self.weights_dtype, - precision=self.precision + precision=self.precision, + name="txt_mod" )(vec) # prepare image for attention @@ -303,7 +311,8 @@ def __call__(self, img: Array, txt: Array, vec: Array, pe: Array) -> tuple[Array use_bias=False, epsilon=1e-6, dtype=self.dtype, - param_dtype=self.weights_dtype + param_dtype=self.weights_dtype, + name="img_norm1" )(img) img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift img_qkv = nn.Dense( @@ -315,14 +324,16 @@ def __call__(self, img: Array, txt: Array, vec: Array, pe: Array) -> tuple[Array kernel_init=nn.with_logical_partitioning( nn.initializers.lecun_normal(), ("embed", "heads") - ) + ), + name="img_attn_qkv" )(img_modulated) img_q, img_k, img_v = rearrange( img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads ) img_q, img_k = QKNorm( dtype=self.dtype, - weights_dtype=self.weights_dtype + weights_dtype=self.weights_dtype, + name="img_attn_norm" )(img_q, img_k, img_v) # prepare text for attention @@ -331,7 +342,8 @@ def __call__(self, img: Array, txt: Array, vec: Array, pe: Array) -> tuple[Array use_bias=False, epsilon=1e-6, dtype=self.dtype, - param_dtype=self.weights_dtype + param_dtype=self.weights_dtype, + name="txt_norm1" )(txt) txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift txt_qkv = nn.Dense( @@ -343,14 +355,16 @@ def __call__(self, img: Array, txt: Array, vec: Array, pe: Array) -> tuple[Array kernel_init=nn.with_logical_partitioning( nn.initializers.lecun_normal(), ("embed", "heads") - ) + ), + name="txt_attn_qkv" )(txt_modulated) txt_q, txt_k, txt_v = rearrange( txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads ) txt_q, txt_k = QKNorm( dtype=self.dtype, - weights_dtype=self.weights_dtype + weights_dtype=self.weights_dtype, + name="txt_attn_norm" )(txt_q, txt_k, txt_v) # run actual attention @@ -385,6 +399,7 @@ def __call__(self, img: Array, txt: Array, vec: Array, pe: Array) -> tuple[Array nn.initializers.lecun_normal(), ("heads", "embed") ), + name="img_attn_proj" )(img_attn) img = img + img_mod2.gate * nn.Sequential( [ @@ -397,7 +412,8 @@ def __call__(self, img: Array, txt: Array, vec: Array, pe: Array) -> tuple[Array 
kernel_init=nn.with_logical_partitioning( nn.initializers.lecun_normal(), ("embed", "heads") - ) + ), + name="img_mlp_0" ), nn.gelu, nn.Dense( @@ -408,14 +424,16 @@ def __call__(self, img: Array, txt: Array, vec: Array, pe: Array) -> tuple[Array kernel_init=nn.with_logical_partitioning( nn.initializers.lecun_normal(), ("heads", "embed") - ) - ) - ] + ), + name="img_mlp_2" + ), + ], )( (1 + img_mod2.scale) * nn.LayerNorm( use_scale=False, use_bias=False, - param_dtype=self.weights_dtype + param_dtype=self.weights_dtype, + name="img_norm2" )(img) + img_mod2.shift ) @@ -430,6 +448,7 @@ def __call__(self, img: Array, txt: Array, vec: Array, pe: Array) -> tuple[Array nn.initializers.lecun_normal(), ("heads", "embed") ), + name="txt_attn_proj" )(txt_attn) txt = txt + txt_mod1.gate * txt_proj @@ -444,7 +463,8 @@ def __call__(self, img: Array, txt: Array, vec: Array, pe: Array) -> tuple[Array kernel_init=nn.with_logical_partitioning( nn.initializers.lecun_normal(), ("embed", "heads") - ) + ), + name="txt_mlp_0" ), nn.gelu, nn.Dense( @@ -455,14 +475,16 @@ def __call__(self, img: Array, txt: Array, vec: Array, pe: Array) -> tuple[Array kernel_init=nn.with_logical_partitioning( nn.initializers.lecun_normal(), ("heads", "embed") - ) - ) - ] + ), + name="txt_mlp_2" + ), + ], )( (1 + txt_mod2.scale) * nn.LayerNorm( use_scale=False, use_bias=False, - param_dtype=self.weights_dtype + param_dtype=self.weights_dtype, + name="txt_norm2" )(txt) + txt_mod2.shift ) @@ -491,8 +513,9 @@ def __call__(self, x: Array, vec: Array) -> Array: kernel_init=nn.with_logical_partitioning( nn.initializers.lecun_normal(), ("embed", "heads") - ) - ) + ), + name="adaLN_modulation_1" + ), ] )(vec), 2, axis=1 ) @@ -515,5 +538,5 @@ def __call__(self, x: Array, vec: Array) -> Array: ("heads", "embed") ), name="linear" - ) + )(x) return x \ No newline at end of file diff --git a/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py b/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py index fe54bb6d3..7baa1775e 100644 --- a/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py +++ b/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py @@ -180,7 +180,7 @@ def __call__( nn.initializers.lecun_normal(), ("embed", "heads") ), - name="text_in" + name="txt_in" )(txt) ids = jnp.concatenate((txt_ids, img_ids), axis=1) @@ -208,7 +208,7 @@ def __call__( # name="double_blocks_0" # )(img=img, txt=txt, vec=vec, pe=pe) # # breakpoint() - for _ in range(self.num_layers): + for i in range(self.num_layers): img, txt = DoubleStreamBlock( hidden_size=inner_dim, num_heads=self.num_attention_heads, @@ -222,6 +222,7 @@ def __call__( precision=self.precision, qkv_bias=self.qkv_bias, attention_kernel=self.attention_kernel, + name=f"double_blocks_{i}" )(img=img, txt=txt, vec=vec, pe=pe) # img, txt = nn.Sequential( # [ @@ -261,7 +262,7 @@ def __call__( # num_layers=self.num_layers # )(img, txt, vec, pe) img = jnp.concatenate((txt, img), axis=1) - for _ in range(self.num_single_layers): + for i in range(self.num_single_layers): img, SingleStreamBlock( hidden_size=inner_dim, num_heads=self.num_attention_heads, @@ -273,12 +274,13 @@ def __call__( dtype=self.dtype, weights_dtype=self.weights_dtype, precision=self.precision, - attention_kernel=self.attention_kernel + attention_kernel=self.attention_kernel, + name=f"single_blocks_{i}" )(img, vec, pe) img = img[:, txt.shape[1] :, ...] 
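        # final_layer below projects the remaining image tokens back to
        # patch_size**2 * out_channels values per token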
- LastLayer( + img = LastLayer( hidden_size=inner_dim, patch_size=1, out_channels=out_channels, @@ -286,14 +288,14 @@ def __call__( weights_dtype=self.weights_dtype, precision=self.precision, name="final_layer" - ) + )(img, vec) return img - def init_weights(self, rngs, eval_only=True): + def init_weights(self, rngs, max_sequence_length, eval_only=True): scale_factor = 16 resolution = 1024 - num_devices = len(jax.devices()) + num_devices = jax.local_device_count() batch_size = 1 * num_devices batch_image_shape = ( batch_size, @@ -304,12 +306,12 @@ def init_weights(self, rngs, eval_only=True): # bs, encoder_input, seq_length text_shape = ( batch_size, - 256, + max_sequence_length, 4096, # Sequence length of text encoder, how to get this programmatically? ) text_ids_shape = ( batch_size, - 256, + max_sequence_length, 3, # Hardcoded to match jflux.prepare ) vec_shape = ( @@ -348,11 +350,11 @@ def init_weights(self, rngs, eval_only=True): else: return self.init( rngs, - hidden_states=img, + img=img, img_ids=img_ids, - encoder_hidden_states=txt, + txt=txt, txt_ids=txt_ids, y=vec, - timestep=t_vec, + timesteps=t_vec, guidance=guidance_vec, )["params"] \ No newline at end of file diff --git a/src/maxdiffusion/models/flux/util.py b/src/maxdiffusion/models/flux/util.py index 44ad90389..fc8e02165 100644 --- a/src/maxdiffusion/models/flux/util.py +++ b/src/maxdiffusion/models/flux/util.py @@ -7,11 +7,17 @@ from jax.typing import DTypeLike import torch # need for torch 2 jax from chex import Array -from flax import nnx +from flax.traverse_util import flatten_dict, unflatten_dict from huggingface_hub import hf_hub_download from jax import numpy as jnp from safetensors import safe_open +from maxdiffusion.models.modeling_flax_pytorch_utils import ( + rename_key, + rename_key_and_reshape_tensor +) +from maxdiffusion import max_logging + # from jflux.model import Flux, FluxParams from .port import port_flux @@ -72,7 +78,6 @@ class ModelSpec: theta=10_000, qkv_bias=True, guidance_embed=True, - #rngs=nnx.Rngs(default=42), rngs=jax.random.PRNGKey(42), param_dtype=jnp.bfloat16, ), @@ -94,7 +99,6 @@ class ModelSpec: theta=10_000, qkv_bias=True, guidance_embed=False, - #rngs=nnx.Rngs(default=42), rngs=jax.random.PRNGKey(47), param_dtype=jnp.bfloat16, ), @@ -104,16 +108,37 @@ class ModelSpec: def print_load_warning(missing: list[str], unexpected: list[str]) -> None: if len(missing) > 0 and len(unexpected) > 0: - print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) - print("\n" + "-" * 79 + "\n") - print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) + max_logging.log(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) + max_logging.log("\n" + "-" * 79 + "\n") + max_logging.log(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) elif len(missing) > 0: - print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) + max_logging.log(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) elif len(unexpected) > 0: - print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) - - -def load_flow_model(name: str, device: str, hf_download: bool = True): # -> Flux: + max_logging.log(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) + +def validate_flax_state_dict(expected_pytree: dict, new_pytree: dict): + """ + expected_pytree: dict - a pytree that comes from initializing the model. + new_pytree: dict - a pytree that has been created from pytorch weights. 
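+  Missing keys and shape mismatches are logged via max_logging rather than raised.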
+ """ + expected_pytree = flatten_dict(expected_pytree) + if len(expected_pytree.keys()) != len(new_pytree.keys()): + set1 = set(expected_pytree.keys()) + set2 = set(new_pytree.keys()) + missing_keys = set1 ^ set2 + max_logging.log(f"missing keys : {missing_keys}") + for key in expected_pytree.keys(): + if key in new_pytree.keys(): + try: + expected_pytree_shape = expected_pytree[key].shape + except: + expected_pytree_shape= expected_pytree[key].value.shape + if expected_pytree_shape != new_pytree[key].shape: + max_logging.log(f"shape mismatch, expected shape of {expected_pytree[key].shape}, but got shape of {new_pytree[key].shape}") + else: + max_logging.log(f"key: {key} not found...") + +def load_flow_model(name: str, eval_shapes: dict, device: str, hf_download: bool = True): # -> Flux: device = jax.devices(device)[0] with jax.default_device(device): ckpt_path = configs[name].ckpt_path @@ -125,17 +150,32 @@ def load_flow_model(name: str, device: str, hf_download: bool = True): # -> Flux ): ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow) - print(f"Load and port flux on {device}") + max_logging.log(f"Load and port flux on {device}") - #model = Flux(params=configs[name].params) if ckpt_path is not None: tensors = {} with safe_open(ckpt_path, framework="pt") as f: for k in f.keys(): tensors[k] = torch2jax(f.get_tensor(k)) - breakpoint() - model = port_flux(flux=model, tensors=tensors) - + flax_state_dict = {} + cpu = jax.local_devices(backend="cpu")[0] + for pt_key, tensor in tensors.items(): + renamed_pt_key = rename_key(pt_key) + if "double_blocks" in renamed_pt_key: + renamed_pt_key = renamed_pt_key.replace("double_blocks.", "double_blocks_") + renamed_pt_key = renamed_pt_key.replace("img_attn.", "img_attn_") + renamed_pt_key = renamed_pt_key.replace("img_mlp.", "img_mlp_") + renamed_pt_key = renamed_pt_key.replace("txt_attn.", "txt_attn_") + renamed_pt_key = renamed_pt_key.replace("txt_mlp.", "txt_mlp_") + + if "single_blocks" in renamed_pt_key: + renamed_pt_key = renamed_pt_key.replace("single_blocks.", "single_blocks_") + + pt_tuple_key = tuple(renamed_pt_key.split(".")) + flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, eval_shapes) + flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu) + validate_flax_state_dict(eval_shapes, flax_state_dict) + flax_state_dict = unflatten_dict(flax_state_dict) del tensors jax.clear_caches() - return model + return flax_state_dict From bb91e8e3d4a9b5de597d363d5daae92207d61ef0 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Thu, 30 Jan 2025 02:10:07 +0000 Subject: [PATCH 11/35] apply fsdp sharding, do one forward pass in the transformer. --- src/maxdiffusion/configs/base_flux.yml | 8 ++-- src/maxdiffusion/generate_flux.py | 66 ++++++++++++++++++-------- 2 files changed, 49 insertions(+), 25 deletions(-) diff --git a/src/maxdiffusion/configs/base_flux.yml b/src/maxdiffusion/configs/base_flux.yml index e6d9f15a9..e2a831527 100644 --- a/src/maxdiffusion/configs/base_flux.yml +++ b/src/maxdiffusion/configs/base_flux.yml @@ -123,11 +123,11 @@ data_sharding: [['data', 'fsdp', 'tensor']] # value to auto-shard based on available slices and devices. # By default, product of the DCN axes should equal number of slices # and product of the ICI axes should equal number of devices per slice. 
-dcn_data_parallelism: -1 # recommended DCN axis to be auto-sharded -dcn_fsdp_parallelism: 1 +dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded +dcn_fsdp_parallelism: -1 dcn_tensor_parallelism: 1 -ici_data_parallelism: -1 -ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded +ici_data_parallelism: 1 +ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded ici_tensor_parallelism: 1 # Dataset diff --git a/src/maxdiffusion/generate_flux.py b/src/maxdiffusion/generate_flux.py index 6e95deb20..98fad6525 100644 --- a/src/maxdiffusion/generate_flux.py +++ b/src/maxdiffusion/generate_flux.py @@ -19,7 +19,7 @@ import functools import numpy as np import jax -from jax.sharding import Mesh, PositionalSharding +from jax.sharding import Mesh, PositionalSharding, PartitionSpec as P import jax.numpy as jnp from chex import Array from transformers import ( @@ -196,7 +196,7 @@ def run(config): devices_array = create_device_mesh(config) mesh = Mesh(devices_array, config.mesh_axes) - per_host_number_of_images = config.per_device_batch_size * jax.local_device_count() + global_batch_size = config.per_device_batch_size * jax.local_device_count() # LOAD VAE @@ -225,7 +225,7 @@ def run(config): num_channels_latents = transformer.in_channels // 4 latents, latent_image_ids = prepare_latents( - batch_size=per_host_number_of_images, + batch_size=global_batch_size, num_channels_latents=num_channels_latents, height=config.resolution, width=config.resolution, @@ -270,7 +270,7 @@ def run(config): clip_text_encoder=clip_text_encoder, t5_tokenizer=t5_tokenizer, t5_text_encoder=t5_encoder, - num_images_per_prompt=per_host_number_of_images + num_images_per_prompt=global_batch_size ) def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timesteps, guidance, pooled_prompt_embeds): @@ -282,8 +282,8 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep print("guidance.shape: ", guidance.shape, guidance.dtype) print("pooled_prompt_embeds.shape: ", pooled_prompt_embeds.shape, pooled_prompt_embeds.dtype) - timesteps = jnp.asarray([1.0], dtype=jnp.bfloat16) - guidance = jnp.asarray([3.5], dtype=jnp.bfloat16) + timesteps = jnp.asarray([1.0] * global_batch_size, dtype=jnp.bfloat16) + guidance = jnp.asarray([3.5] * global_batch_size, dtype=jnp.bfloat16) validate_inputs( latents, latent_image_ids, @@ -293,13 +293,26 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep guidance, pooled_prompt_embeds ) + + # TODO - remove this later and figure out why t5x is returning wrong shape + prompt_embeds = jnp.ones((global_batch_size, 512, 4096)) + + # move inputs to device and shard + data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding)) + latents = jax.device_put(latents, data_sharding) + latent_image_ids = jax.device_put(latent_image_ids, data_sharding) + prompt_embeds = jax.device_put(prompt_embeds, data_sharding) + text_ids = jax.device_put(text_ids, data_sharding) + timesteps = jax.device_put(timesteps, data_sharding) + guidance = jax.device_put(guidance, data_sharding) + pooled_prompt_embeds = jax.device_put(pooled_prompt_embeds, data_sharding) + get_memory_allocations() # evaluate shapes transformer_eval_params = transformer.init_weights(rngs=rng, max_sequence_length=512, eval_only=True) # loads pretrained weights transformer_params = load_flow_model("flux-dev", transformer_eval_params, "cpu") - get_memory_allocations() # create transformer state weights_init_fn = 
functools.partial(transformer.init_weights, rngs=rng, max_sequence_length=512, eval_only=False) transformer_state, transformer_state_shardings = setup_initial_state( @@ -308,24 +321,35 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep config=config, mesh=mesh, weights_init_fn=weights_init_fn, - model_params=None, + model_params=transformer_params, training=False ) - breakpoint() - transformer_state = transformer_state.replace(params=transformer_params) - img = transformer.apply( - {"params" : transformer_state.params}, - img=latents, - img_ids=latent_image_ids, - txt=prompt_embeds, - txt_ids=text_ids, - timesteps=timesteps, - guidance=guidance, - y=pooled_prompt_embeds - ) + #transformer_state = transformer_state.replace(params=transformer_params) get_memory_allocations() - breakpoint() + def run_inference(state, transformer): + img = transformer.apply( + {"params" : state.params}, + img=latents, + img_ids=latent_image_ids, + txt=prompt_embeds, + txt_ids=text_ids, + timesteps=timesteps, + guidance=guidance, + y=pooled_prompt_embeds + ) + return img + + p_run_inference = jax.jit( + functools.partial( + run_inference, + transformer=transformer + ), + in_shardings=(transformer_state_shardings,), + out_shardings=None + ) + img = p_run_inference(transformer_state) + print("img.shape: ", img.shape) def main(argv: Sequence[str]) -> None: From dfe1089ec7c5ed83b72d0bde30def9187d2f8b1e Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Thu, 30 Jan 2025 06:59:11 +0000 Subject: [PATCH 12/35] wip - generate fn --- src/maxdiffusion/configs/base_flux.yml | 2 +- src/maxdiffusion/generate_flux.py | 170 ++++++++++++++++++++++--- 2 files changed, 153 insertions(+), 19 deletions(-) diff --git a/src/maxdiffusion/configs/base_flux.yml b/src/maxdiffusion/configs/base_flux.yml index e2a831527..b4e38ceef 100644 --- a/src/maxdiffusion/configs/base_flux.yml +++ b/src/maxdiffusion/configs/base_flux.yml @@ -200,7 +200,7 @@ prompt: "A magical castle in the middle of a forest, artistic drawing" prompt_2: "A magical castle in the middle of a forest, artistic drawing" negative_prompt: "purple, red" do_classifier_free_guidance: True -guidance_scale: 9.0 +guidance_scale: 3.5 # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf guidance_rescale: 0.0 num_inference_steps: 20 diff --git a/src/maxdiffusion/generate_flux.py b/src/maxdiffusion/generate_flux.py index 98fad6525..9dcb33e88 100644 --- a/src/maxdiffusion/generate_flux.py +++ b/src/maxdiffusion/generate_flux.py @@ -17,11 +17,14 @@ from typing import Any, Callable, Dict, List, Optional, Union, Sequence from absl import app import functools +import math import numpy as np import jax from jax.sharding import Mesh, PositionalSharding, PartitionSpec as P import jax.numpy as jnp from chex import Array +from einops import rearrange +from flax.linen import partitioning as nn_partitioning from transformers import ( CLIPTokenizer, FlaxCLIPTextModel, @@ -42,6 +45,51 @@ setup_initial_state ) +def unpack(x: Array, height: int, width: int) -> Array: + return rearrange( + x, + "b (h w) (c ph pw) -> b c (h ph) (w pw)", + h=math.ceil(height / 16), + w=math.ceil(width / 16), + ph=2, + pw=2, + ) + +def vae_decode(latents, vae, state, config): + img = unpack(x=latents, height=config.resolution, width=config.resolution) + img = vae.apply({"params": state.params}, img, deterministic=True, method=vae.decode).sample[0] + breakpoint() + return img + +def loop_body( + step, + args, + transformer, + latent_image_ids, + prompt_embeds, + txt_ids, + vec, + guidance_vec, +): + latents, state, c_ts, p_ts = args + latents_dtype = latents.dtype + t_curr = c_ts[step] + t_prev = p_ts[step] + t_vec = jnp.full((latents.shape[0], ), t_curr, dtype=latents.dtype) + pred = transformer.apply( + {"params" : state.params}, + img=latents, + img_ids=latent_image_ids, + txt=prompt_embeds, + txt_ids=txt_ids, + timesteps=t_vec, + guidance=guidance_vec, + y=vec + ) + latents = latents + (t_prev - t_curr) * pred + latents = jnp.array(latents, dtype=latents_dtype) + return latents, state, c_ts, p_ts + def prepare_latent_image_ids(height, width): latent_image_ids = jnp.zeros((height, width, 3)) latent_image_ids = latent_image_ids.at[..., 1].set( @@ -59,6 +107,45 @@ def prepare_latent_image_ids(height, width): return latent_image_ids.astype(jnp.bfloat16) +def run_inference( + states, + transformer, + vae, + config, + mesh, + latents, + latent_image_ids, + prompt_embeds, + txt_ids, + vec, + guidance_vec, +): + timesteps = jnp.linspace(1, 0, config.num_inference_steps + 1) + c_ts = timesteps[:-1] + p_ts = timesteps[1:] + + transformer_state = states["transformer"] + vae_state = states["vae"] + + loop_body_p = functools.partial( + loop_body, + transformer=transformer, + latent_image_ids=latent_image_ids, + prompt_embeds=prompt_embeds, + txt_ids=txt_ids, + vec=vec, + guidance_vec=guidance_vec, + ) + + vae_decode_p = functools.partial(vae_decode, vae=vae, state=vae_state, config=config) + + with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + latents, _, _, _ = jax.lax.fori_loop(0, config.num_inference_steps, loop_body_p, (latents, transformer_state, c_ts, p_ts)) + image = vae_decode_p(latents) + breakpoint() + return image + + def pack_latents( latents: Array, batch_size: int, @@ -207,6 +294,18 @@ def run(config): use_safetensors=True, dtype="bfloat16" ) + + weights_init_fn = functools.partial(vae.init_weights, rng=rng) + vae_state, vae_state_shardings = setup_initial_state( + model=vae, + tx=None, + config=config, + mesh=mesh, + weights_init_fn=weights_init_fn, + model_params=vae_params, + training=False, + ) + vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1) # LOAD TRANSFORMER @@ -283,7 +382,7 @@ def validate_inputs(latents, latent_image_ids, 
prompt_embeds, text_ids, timestep print("pooled_prompt_embeds.shape: ", pooled_prompt_embeds.shape, pooled_prompt_embeds.dtype) timesteps = jnp.asarray([1.0] * global_batch_size, dtype=jnp.bfloat16) - guidance = jnp.asarray([3.5] * global_batch_size, dtype=jnp.bfloat16) + guidance = jnp.asarray([config.guidance_scale] * global_batch_size, dtype=jnp.bfloat16) validate_inputs( latents, latent_image_ids, @@ -321,34 +420,69 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep config=config, mesh=mesh, weights_init_fn=weights_init_fn, - model_params=transformer_params, + #model_params=transformer_params, + model_params=None, training=False ) - #transformer_state = transformer_state.replace(params=transformer_params) + transformer_state = transformer_state.replace(params=transformer_params) get_memory_allocations() - def run_inference(state, transformer): - img = transformer.apply( - {"params" : state.params}, - img=latents, - img_ids=latent_image_ids, - txt=prompt_embeds, - txt_ids=text_ids, - timesteps=timesteps, - guidance=guidance, - y=pooled_prompt_embeds - ) - return img + + states = {} + state_shardings = {} + + state_shardings["transformer"] = transformer_state_shardings + state_shardings["vae"] = vae_state_shardings + + states["transformer"] = transformer_state + states["vae"] = vae_state p_run_inference = jax.jit( functools.partial( run_inference, - transformer=transformer + transformer=transformer, + vae=vae, + config=config, + mesh=mesh, + latents=latents, + latent_image_ids=latent_image_ids, + prompt_embeds=prompt_embeds, + txt_ids=text_ids, + vec=pooled_prompt_embeds, + guidance_vec=guidance, ), - in_shardings=(transformer_state_shardings,), - out_shardings=None + in_shardings=(state_shardings,), + out_shardings=None, ) + img = p_run_inference(states) + + + + + # def run_inference(state, transformer): + # img = transformer.apply( + # {"params" : state.params}, + # img=latents, + # img_ids=latent_image_ids, + # txt=prompt_embeds, + # txt_ids=text_ids, + # timesteps=timesteps, + # guidance=guidance, + # y=pooled_prompt_embeds + # ) + # return img + + # p_run_inference = jax.jit( + # functools.partial( + # run_inference, + # transformer=transformer, + # ), + # in_shardings=(transformer_state_shardings,), + # out_shardings=None + # ) + img = p_run_inference(transformer_state) + breakpoint() print("img.shape: ", img.shape) From cbc772389f267baec6cb59fda2c3db06c0e9e61e Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Thu, 30 Jan 2025 19:05:13 +0000 Subject: [PATCH 13/35] working loop, bad generation --- src/maxdiffusion/configs/base_flux_dev.yml | 26 +---- .../{base_flux.yml => base_fux_schnell.yml} | 6 ++ src/maxdiffusion/generate_flux.py | 96 +++++++++++-------- .../transformers/transformer_flux_flax.py | 7 +- 4 files changed, 71 insertions(+), 64 deletions(-) rename src/maxdiffusion/configs/{base_flux.yml => base_fux_schnell.yml} (98%) diff --git a/src/maxdiffusion/configs/base_flux_dev.yml b/src/maxdiffusion/configs/base_flux_dev.yml index 3077e3b56..e8495efda 100644 --- a/src/maxdiffusion/configs/base_flux_dev.yml +++ b/src/maxdiffusion/configs/base_flux_dev.yml @@ -28,13 +28,10 @@ clip_model_name_or_path: 'ariG23498/clip-vit-large-patch14-text-flax' t5xxl_model_name_or_path: 'ariG23498/t5-v1-1-xxl-flax' # Flux params -flux_name: "flux-dev" max_sequence_length: 512 -time_shift: True +time_shift: False base_shift: 0.5 max_shift: 1.15 -# offloads t5 encoder after text encoding to save memory. 
-offload_encoders: True unet_checkpoint: '' @@ -52,22 +49,10 @@ activations_dtype: 'bfloat16' precision: "DEFAULT" # Set true to load weights from pytorch -from_pt: True +from_pt: False split_head_dim: True attention: 'flash' # Supported attention: dot_product, flash - flash_block_sizes: {} -# Use the following flash_block_sizes on v6e (Trillium) due to larger vmem. -# flash_block_sizes: { -# "block_q" : 1536, -# "block_kv_compute" : 1536, -# "block_kv" : 1536, -# "block_q_dkv" : 1536, -# "block_kv_dkv" : 1536, -# "block_kv_dkv_compute" : 1536, -# "block_q_dq" : 1536, -# "block_kv_dq" : 1536 -# } # GroupNorm groups norm_num_groups: 32 @@ -133,7 +118,6 @@ logical_axis_rules: [ ['activation_batch', ['data','fsdp']], ['activation_heads', 'tensor'], ['activation_kv', 'tensor'], - ['mlp','tensor'], ['embed','fsdp'], ['heads', 'tensor'], ['conv_batch', ['data','fsdp']], @@ -149,8 +133,8 @@ data_sharding: [['data', 'fsdp', 'tensor']] dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded dcn_fsdp_parallelism: -1 dcn_tensor_parallelism: 1 -ici_data_parallelism: -1 -ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded +ici_data_parallelism: 1 +ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded ici_tensor_parallelism: 1 # Dataset @@ -226,7 +210,7 @@ do_classifier_free_guidance: True guidance_scale: 3.5 # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf guidance_rescale: 0.0 -num_inference_steps: 50 +num_inference_steps: 20 # SDXL Lightning parameters lightning_from_pt: True diff --git a/src/maxdiffusion/configs/base_flux.yml b/src/maxdiffusion/configs/base_fux_schnell.yml similarity index 98% rename from src/maxdiffusion/configs/base_flux.yml rename to src/maxdiffusion/configs/base_fux_schnell.yml index b4e38ceef..3f2ebff12 100644 --- a/src/maxdiffusion/configs/base_flux.yml +++ b/src/maxdiffusion/configs/base_fux_schnell.yml @@ -27,6 +27,12 @@ pretrained_model_name_or_path: 'black-forest-labs/FLUX.1-dev' clip_model_name_or_path: 'ariG23498/clip-vit-large-patch14-text-flax' t5xxl_model_name_or_path: 'ariG23498/t5-v1-1-xxl-flax' +# Flux params +max_sequence_length: 256 +time_shift: False +base_shift: 0.5 +max_shift: 1.15 + unet_checkpoint: '' revision: 'refs/pr/95' # This will convert the weights to this dtype. 
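The sampling loop below integrates the rectified-flow ODE with plain Euler steps over a descending timestep grid, and `time_shift` can warp that grid toward the high-noise end for larger images (`get_lin_function` interpolates the shift parameter mu between (256, base_shift) and (4096, max_shift) as a function of the packed sequence length). A scalar toy version of the update, using the same formulas as the diff that follows:

    import math

    def time_shift(mu: float, sigma: float, t: float) -> float:
        # warp t in (0, 1]; larger mu concentrates steps near t = 1 (high noise)
        return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

    def euler_step(latent: float, pred: float, t_curr: float, t_prev: float) -> float:
        # one denoising step: follow the predicted velocity from t_curr to t_prev
        return latent + (t_prev - t_curr) * pred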
diff --git a/src/maxdiffusion/generate_flux.py b/src/maxdiffusion/generate_flux.py index 9dcb33e88..1778e979c 100644 --- a/src/maxdiffusion/generate_flux.py +++ b/src/maxdiffusion/generate_flux.py @@ -18,7 +18,9 @@ from absl import app import functools import math +import time import numpy as np +from PIL import Image import jax from jax.sharding import Mesh, PositionalSharding, PartitionSpec as P import jax.numpy as jnp @@ -33,9 +35,8 @@ FlaxT5EncoderModel ) -from maxdiffusion import FlaxAutoencoderKL +from maxdiffusion import FlaxAutoencoderKL, pyconfig, max_logging from maxdiffusion.models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel -from maxdiffusion import pyconfig from max_utils import ( device_put_replicated, get_memory_allocations, @@ -57,8 +58,8 @@ def unpack(x: Array, height: int, width: int) -> Array: def vae_decode(latents, vae, state, config): img = unpack(x=latents, height=config.resolution, width=config.resolution) - img = vae.apply({"params": state.params}, img, deterministic=True, method=vae.decode).sample[0] - breakpoint() + img = img / vae.config.scaling_factor + vae.config.shift_factor + img = vae.apply({"params": state.params}, img, deterministic=True, method=vae.decode).sample return img def loop_body( @@ -107,6 +108,19 @@ def prepare_latent_image_ids(height, width): return latent_image_ids.astype(jnp.bfloat16) +def time_shift(mu: float, sigma: float, t: Array): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + +def get_lin_function( + x1: float = 256, + y1: float = 0.5, + x2: float = 4096, + y2: float = 1.15 +) -> Callable[[float], float]: + m = (y2 - y1) / (x2 - x1) + b = y1 - m * x1 + return lambda x: m * x + b + def run_inference( states, transformer, @@ -120,10 +134,18 @@ def run_inference( vec, guidance_vec, ): + timesteps = jnp.linspace(1, 0, config.num_inference_steps + 1) + # shifting the schedule to favor high timesteps for higher signal images + if config.time_shift: + # estimate mu based on linear estimation between two points + lin_function = get_lin_function(y1=config.base_shift, y2=config.max_shift) + mu = lin_function(latents.shape[1]) + timesteps = time_shift(mu, 1.0, timesteps).tolist() c_ts = timesteps[:-1] p_ts = timesteps[1:] + transformer_state = states["transformer"] vae_state = states["vae"] @@ -142,7 +164,6 @@ def run_inference( with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): latents, _, _, _ = jax.lax.fori_loop(0, config.num_inference_steps, loop_body_p, (latents, transformer_state, c_ts, p_ts)) image = vae_decode_p(latents) - breakpoint() return image @@ -383,6 +404,10 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep timesteps = jnp.asarray([1.0] * global_batch_size, dtype=jnp.bfloat16) guidance = jnp.asarray([config.guidance_scale] * global_batch_size, dtype=jnp.bfloat16) + + # TODO - remove this later and figure out why t5x is returning wrong shape + prompt_embeds = jnp.ones((global_batch_size, 512, 4096)) + validate_inputs( latents, latent_image_ids, @@ -393,8 +418,7 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep pooled_prompt_embeds ) - # TODO - remove this later and figure out why t5x is returning wrong shape - prompt_embeds = jnp.ones((global_batch_size, 512, 4096)) + # move inputs to device and shard data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding)) @@ -420,11 +444,11 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep config=config, mesh=mesh, 
weights_init_fn=weights_init_fn, - #model_params=transformer_params, - model_params=None, + model_params=transformer_params, + #model_params=None, training=False ) - transformer_state = transformer_state.replace(params=transformer_params) + #transformer_state = transformer_state.replace(params=transformer_params) get_memory_allocations() states = {} @@ -453,37 +477,27 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep in_shardings=(state_shardings,), out_shardings=None, ) - - img = p_run_inference(states) - - - - - # def run_inference(state, transformer): - # img = transformer.apply( - # {"params" : state.params}, - # img=latents, - # img_ids=latent_image_ids, - # txt=prompt_embeds, - # txt_ids=text_ids, - # timesteps=timesteps, - # guidance=guidance, - # y=pooled_prompt_embeds - # ) - # return img - - # p_run_inference = jax.jit( - # functools.partial( - # run_inference, - # transformer=transformer, - # ), - # in_shardings=(transformer_state_shardings,), - # out_shardings=None - # ) - - img = p_run_inference(transformer_state) - breakpoint() - print("img.shape: ", img.shape) + t0 = time.perf_counter() + p_run_inference(states).block_until_ready() + t1 = time.perf_counter() + max_logging.log(f"Compile time: {t1 - t0:.1f}s.") + + t0 = time.perf_counter() + imgs = p_run_inference(states).block_until_ready() + t1 = time.perf_counter() + max_logging.log(f"Inference time: {t1 - t0:.1f}s.") + + t0 = time.perf_counter() + imgs = p_run_inference(states).block_until_ready() + imgs = jax.experimental.multihost_utils.process_allgather(imgs, tiled=True) + t1 = time.perf_counter() + max_logging.log(f"Inference time: {t1 - t0:.1f}s.") + imgs = np.array(imgs) + imgs = (imgs * 0.5 + 0.5).clip(0, 1) + imgs = np.transpose(imgs, (0, 2, 3, 1)) + imgs = np.uint8(imgs * 255) + for i, image in enumerate(imgs): + Image.fromarray(image).save(f"flux_{i}.png") def main(argv: Sequence[str]) -> None: diff --git a/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py b/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py index 7baa1775e..3e04dcd41 100644 --- a/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py +++ b/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py @@ -154,14 +154,17 @@ def __call__( raise ValueError( "Didn't get guidance strength for guidance distrilled model." ) - - vec = vec + MLPEmbedder( + guidance_in = MLPEmbedder( hidden_dim=inner_dim, dtype=self.dtype, weights_dtype=self.weights_dtype, precision=self.precision, name="guidance_in" )(timestep_embedding(guidance, 256)) + else: + guidance_in = Identity(timestep_embedding(guidance, 256)) + + vec = vec + guidance_in vec = vec + MLPEmbedder( hidden_dim=inner_dim, From ac14a4bab0a8b539eada95dcba0655999296a561 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Thu, 30 Jan 2025 22:40:34 +0000 Subject: [PATCH 14/35] e2e, encoder offloading. --- src/maxdiffusion/configs/base_flux_dev.yml | 4 +- src/maxdiffusion/configs/base_fux_schnell.yml | 2 + src/maxdiffusion/generate_flux.py | 40 +++++++++---------- 3 files changed, 24 insertions(+), 22 deletions(-) diff --git a/src/maxdiffusion/configs/base_flux_dev.yml b/src/maxdiffusion/configs/base_flux_dev.yml index e8495efda..35d54a78e 100644 --- a/src/maxdiffusion/configs/base_flux_dev.yml +++ b/src/maxdiffusion/configs/base_flux_dev.yml @@ -32,6 +32,8 @@ max_sequence_length: 512 time_shift: False base_shift: 0.5 max_shift: 1.15 +# offloads t5 encoder after text encoding to save memory. 
+offload_encoders: True unet_checkpoint: '' @@ -210,7 +212,7 @@ do_classifier_free_guidance: True guidance_scale: 3.5 # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf guidance_rescale: 0.0 -num_inference_steps: 20 +num_inference_steps: 50 # SDXL Lightning parameters lightning_from_pt: True diff --git a/src/maxdiffusion/configs/base_fux_schnell.yml b/src/maxdiffusion/configs/base_fux_schnell.yml index 3f2ebff12..a65781548 100644 --- a/src/maxdiffusion/configs/base_fux_schnell.yml +++ b/src/maxdiffusion/configs/base_fux_schnell.yml @@ -32,6 +32,8 @@ max_sequence_length: 256 time_shift: False base_shift: 0.5 max_shift: 1.15 +# offloads t5 encoder after text encoding to save memory. +offload_encoders: True unet_checkpoint: '' revision: 'refs/pr/95' diff --git a/src/maxdiffusion/generate_flux.py b/src/maxdiffusion/generate_flux.py index 1778e979c..00bd4844e 100644 --- a/src/maxdiffusion/generate_flux.py +++ b/src/maxdiffusion/generate_flux.py @@ -30,9 +30,9 @@ from transformers import ( CLIPTokenizer, FlaxCLIPTextModel, - T5TokenizerFast, T5EncoderModel, - FlaxT5EncoderModel + FlaxT5EncoderModel, + AutoTokenizer ) from maxdiffusion import FlaxAutoencoderKL, pyconfig, max_logging @@ -235,7 +235,7 @@ def get_clip_prompt_embeds( def get_t5_prompt_embeds( prompt: Union[str, List[str]], num_images_per_prompt: int, - tokenizer: T5TokenizerFast, + tokenizer: AutoTokenizer, text_encoder: T5EncoderModel, max_sequence_length: int = 512 ): @@ -245,18 +245,20 @@ def get_t5_prompt_embeds( text_inputs = tokenizer( prompt, - padding="max_length", - max_length=max_sequence_length, truncation=True, + max_length=max_sequence_length, return_length=False, return_overflowing_tokens=False, + padding="max_length", return_tensors="np" ) text_input_ids = text_inputs.input_ids - prompt_embeds = text_encoder(text_input_ids, output_hidden_states=False)[0] + prompt_embeds = text_encoder( + text_input_ids, + attention_mask=None, + output_hidden_states=False)["last_hidden_state"] dtype = text_encoder.dtype prompt_embeds = prompt_embeds.astype(dtype) - _, seq_len, _ = prompt_embeds.shape # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method prompt_embeds = jnp.tile(prompt_embeds, (1, num_images_per_prompt, 1)) @@ -270,7 +272,7 @@ def encode_prompt( prompt_2: Union[str, List[str]], clip_tokenizer: CLIPTokenizer, clip_text_encoder: FlaxCLIPTextModel, - t5_tokenizer: T5TokenizerFast, + t5_tokenizer: AutoTokenizer, t5_text_encoder: T5EncoderModel, num_images_per_prompt: int = 1, max_sequence_length: int = 512 @@ -368,13 +370,10 @@ def run(config): ) t5_encoder = FlaxT5EncoderModel.from_pretrained( - config.clip_model_name_or_path, + config.t5xxl_model_name_or_path, dtype=config.weights_dtype ) - t5_tokenizer = T5TokenizerFast.from_pretrained( - config.pretrained_model_name_or_path, - subfolder="tokenizer_2", - ) + t5_tokenizer = AutoTokenizer.from_pretrained(config.t5xxl_model_name_or_path, max_length=config.max_sequence_length, use_fast=True) encoders_sharding = PositionalSharding(devices_array).replicate() partial_device_put_replicated = functools.partial(device_put_replicated, sharding=encoders_sharding) @@ -405,9 +404,6 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep timesteps = jnp.asarray([1.0] * global_batch_size, dtype=jnp.bfloat16) guidance = jnp.asarray([config.guidance_scale] * global_batch_size, dtype=jnp.bfloat16) - # TODO - remove this later and figure out why t5x is returning wrong shape - prompt_embeds = 
jnp.ones((global_batch_size, 512, 4096)) - validate_inputs( latents, latent_image_ids, @@ -418,8 +414,6 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep pooled_prompt_embeds ) - - # move inputs to device and shard data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding)) latents = jax.device_put(latents, data_sharding) @@ -430,6 +424,10 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep guidance = jax.device_put(guidance, data_sharding) pooled_prompt_embeds = jax.device_put(pooled_prompt_embeds, data_sharding) + if config.offload_encoders: + cpus = jax.devices("cpu") + t5_encoder.params = jax.device_put(t5_encoder.params, device=cpus[0]) + get_memory_allocations() # evaluate shapes transformer_eval_params = transformer.init_weights(rngs=rng, max_sequence_length=512, eval_only=True) @@ -444,11 +442,11 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep config=config, mesh=mesh, weights_init_fn=weights_init_fn, - model_params=transformer_params, - #model_params=None, + model_params=None, training=False ) - #transformer_state = transformer_state.replace(params=transformer_params) + transformer_state = transformer_state.replace(params=transformer_params) + transformer_state = jax.device_put(transformer_state, transformer_state_shardings) get_memory_allocations() states = {} From 1c8ed7b938d0ae2b00f4d8e5a92d336c9aa720d0 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Sat, 1 Feb 2025 00:26:34 +0000 Subject: [PATCH 15/35] support both dev and schnell loading. Images still incorrect. --- src/maxdiffusion/configs/base_flux_dev.yml | 1 + .../configs/base_flux_schnell.yml | 34 +-- src/maxdiffusion/configs/base_fux_schnell.yml | 247 ------------------ src/maxdiffusion/generate_flux.py | 22 +- .../models/flux/modules/layers.py | 42 +++ .../transformers/transformer_flux_flax.py | 49 +++- src/maxdiffusion/models/flux/util.py | 5 +- 7 files changed, 101 insertions(+), 299 deletions(-) delete mode 100644 src/maxdiffusion/configs/base_fux_schnell.yml diff --git a/src/maxdiffusion/configs/base_flux_dev.yml b/src/maxdiffusion/configs/base_flux_dev.yml index 35d54a78e..fc4a3c2da 100644 --- a/src/maxdiffusion/configs/base_flux_dev.yml +++ b/src/maxdiffusion/configs/base_flux_dev.yml @@ -28,6 +28,7 @@ clip_model_name_or_path: 'ariG23498/clip-vit-large-patch14-text-flax' t5xxl_model_name_or_path: 'ariG23498/t5-v1-1-xxl-flax' # Flux params +flux_name: "flux-dev" max_sequence_length: 512 time_shift: False base_shift: 0.5 diff --git a/src/maxdiffusion/configs/base_flux_schnell.yml b/src/maxdiffusion/configs/base_flux_schnell.yml index ee9db566a..316fa753d 100644 --- a/src/maxdiffusion/configs/base_flux_schnell.yml +++ b/src/maxdiffusion/configs/base_flux_schnell.yml @@ -51,31 +51,10 @@ activations_dtype: 'bfloat16' precision: "DEFAULT" # Set true to load weights from pytorch -from_pt: True +from_pt: False split_head_dim: True attention: 'flash' # Supported attention: dot_product, flash -flash_block_sizes: { - "block_q" : 256, - "block_kv_compute" : 256, - "block_kv" : 256, - "block_q_dkv" : 256, - "block_kv_dkv" : 256, - "block_kv_dkv_compute" : 256, - "block_q_dq" : 256, - "block_kv_dq" : 256 -} - -# Use the following flash_block_sizes on v6e (Trillium). 
-# flash_block_sizes: { -# "block_q" : 2176, -# "block_kv_compute" : 2176, -# "block_kv" : 2176, -# "block_q_dkv" : 2176, -# "block_kv_dkv" : 2176, -# "block_kv_dkv_compute" : 2176, -# "block_q_dq" : 2176, -# "block_kv_dq" : 2176 -# } +flash_block_sizes: {} # GroupNorm groups norm_num_groups: 32 @@ -141,7 +120,6 @@ logical_axis_rules: [ ['activation_batch', ['data','fsdp']], ['activation_heads', 'tensor'], ['activation_kv', 'tensor'], - ['mlp','tensor'], ['embed','fsdp'], ['heads', 'tensor'], ['conv_batch', ['data','fsdp']], @@ -154,11 +132,11 @@ data_sharding: [['data', 'fsdp', 'tensor']] # value to auto-shard based on available slices and devices. # By default, product of the DCN axes should equal number of slices # and product of the ICI axes should equal number of devices per slice. -dcn_data_parallelism: -1 # recommended DCN axis to be auto-sharded -dcn_fsdp_parallelism: 1 +dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded +dcn_fsdp_parallelism: -1 dcn_tensor_parallelism: 1 -ici_data_parallelism: -1 -ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded +ici_data_parallelism: 1 +ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded ici_tensor_parallelism: 1 # Dataset diff --git a/src/maxdiffusion/configs/base_fux_schnell.yml b/src/maxdiffusion/configs/base_fux_schnell.yml deleted file mode 100644 index a65781548..000000000 --- a/src/maxdiffusion/configs/base_fux_schnell.yml +++ /dev/null @@ -1,247 +0,0 @@ -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This sentinel is a reminder to choose a real run name. -run_name: '' - -metrics_file: "" # for testing, local file that stores scalar metrics. If empty, no metrics are written. -# If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/ -write_metrics: True -gcs_metrics: False -# If true save config to GCS in {base_output_directory}/{run_name}/ -save_config_to_gcs: False -log_period: 100 - -pretrained_model_name_or_path: 'black-forest-labs/FLUX.1-dev' -clip_model_name_or_path: 'ariG23498/clip-vit-large-patch14-text-flax' -t5xxl_model_name_or_path: 'ariG23498/t5-v1-1-xxl-flax' - -# Flux params -max_sequence_length: 256 -time_shift: False -base_shift: 0.5 -max_shift: 1.15 -# offloads t5 encoder after text encoding to save memory. -offload_encoders: True - -unet_checkpoint: '' -revision: 'refs/pr/95' -# This will convert the weights to this dtype. -# When running inference on TPUv5e, use weights_dtype: 'bfloat16' -weights_dtype: 'bfloat16' -# This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype) -activations_dtype: 'bfloat16' - -# matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision -# Options are "DEFAULT", "HIGH", "HIGHEST" -# fp32 activations and fp32 weights with HIGHEST will provide the best precision -# at the cost of time. 
-precision: "DEFAULT" - -# Set true to load weights from pytorch -from_pt: False -split_head_dim: True -attention: 'flash' # Supported attention: dot_product, flash -flash_block_sizes: {} -# GroupNorm groups -norm_num_groups: 32 - -# If train_new_unet, unet weights will be randomly initialized to train the unet from scratch -# else they will be loaded from pretrained_model_name_or_path -train_new_unet: False - -# train text_encoder - Currently not supported for SDXL -train_text_encoder: False -text_encoder_learning_rate: 4.25e-6 - -# https://arxiv.org/pdf/2305.08891.pdf -snr_gamma: -1.0 - -timestep_bias: { - # a value of later will increase the frequence of the model's final training steps. - # none, earlier, later, range - strategy: "none", - # multiplier for bias, a value of 2.0 will double the weight of the bias, 0.5 will halve it. - multiplier: 1.0, - # when using strategy=range, the beginning (inclusive) timestep to bias. - begin: 0, - # when using strategy=range, the final step (inclusive) to bias. - end: 1000, - # portion of timesteps to bias. - # 0.5 will bias one half of the timesteps. Value of strategy determines - # whether the biased portions are in the earlier or later timesteps. - portion: 0.25 -} - -# Override parameters from checkpoints's scheduler. -diffusion_scheduler_config: { - _class_name: 'FlaxEulerDiscreteScheduler', - prediction_type: 'epsilon', - rescale_zero_terminal_snr: False, - timestep_spacing: 'trailing' -} - -# Output directory -# Create a GCS bucket, e.g. my-maxtext-outputs and set this to "gs://my-maxtext-outputs/" -base_output_directory: "" - -# Hardware -hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' - -# Parallelism -mesh_axes: ['data', 'fsdp', 'tensor'] - -# batch : batch dimension of data and activations -# hidden : -# embed : attention qkv dense layer hidden dim named as embed -# heads : attention head dim = num_heads * head_dim -# length : attention sequence length -# temb_in : dense.shape[0] of resnet dense before conv -# out_c : dense.shape[1] of resnet dense before conv -# out_channels : conv.shape[-1] activation -# keep_1 : conv.shape[0] weight -# keep_2 : conv.shape[1] weight -# conv_in : conv.shape[2] weight -# conv_out : conv.shape[-1] weight -logical_axis_rules: [ - ['batch', 'data'], - ['activation_batch', ['data','fsdp']], - ['activation_heads', 'tensor'], - ['activation_kv', 'tensor'], - ['embed','fsdp'], - ['heads', 'tensor'], - ['conv_batch', ['data','fsdp']], - ['out_channels', 'tensor'], - ['conv_out', 'fsdp'], - ] -data_sharding: [['data', 'fsdp', 'tensor']] - -# One axis for each parallelism type may hold a placeholder (-1) -# value to auto-shard based on available slices and devices. -# By default, product of the DCN axes should equal number of slices -# and product of the ICI axes should equal number of devices per slice. -dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded -dcn_fsdp_parallelism: -1 -dcn_tensor_parallelism: 1 -ici_data_parallelism: 1 -ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded -ici_tensor_parallelism: 1 - -# Dataset -# Replace with dataset path or train_data_dir. One has to be set. 
-dataset_name: 'diffusers/pokemon-gpt4-captions' -train_split: 'train' -dataset_type: 'tf' -cache_latents_text_encoder_outputs: True -# cache_latents_text_encoder_outputs only apply to dataset_type="tf", -# only apply to small dataset that fits in memory -# prepare image latents and text encoder outputs -# Reduce memory consumption and reduce step time during training -# transformed dataset is saved at dataset_save_location -dataset_save_location: '/tmp/pokemon-gpt4-captions_xl' -train_data_dir: '' -dataset_config_name: '' -jax_cache_dir: '' -hf_data_dir: '' -hf_train_files: '' -hf_access_token: '' -image_column: 'image' -caption_column: 'text' -resolution: 1024 -center_crop: False -random_flip: False -# If cache_latents_text_encoder_outputs is True -# the num_proc is set to 1 -tokenize_captions_num_proc: 4 -transform_images_num_proc: 4 -reuse_example_batch: False -enable_data_shuffling: True - -# checkpoint every number of samples, -1 means don't checkpoint. -checkpoint_every: -1 -# enables one replica to read the ckpt then broadcast to the rest -enable_single_replica_ckpt_restoring: False - -# Training loop -learning_rate: 4.e-7 -scale_lr: False -max_train_samples: -1 -# max_train_steps takes priority over num_train_epochs. -max_train_steps: 200 -num_train_epochs: 1 -seed: 0 -output_dir: 'sdxl-model-finetuned' -per_device_batch_size: 1 - -warmup_steps_fraction: 0.0 -learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps. - -# However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before -# dropping fully down. Or you may choose a shorter schedule, where the unspecified steps will have a learning rate of 0. - -# AdamW optimizer parameters -adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients. -adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients. -adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root. -adam_weight_decay: 1.e-2 # AdamW Weight decay -max_grad_norm: 1.0 - -enable_profiler: False -# Skip first n steps for profiling, to omit things like compilation and to give -# the iteration time a chance to stabilize. -skip_first_n_steps_for_profiler: 5 -profiler_steps: 10 - -# Generation parameters -prompt: "A magical castle in the middle of a forest, artistic drawing" -prompt_2: "A magical castle in the middle of a forest, artistic drawing" -negative_prompt: "purple, red" -do_classifier_free_guidance: True -guidance_scale: 3.5 -# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf -guidance_rescale: 0.0 -num_inference_steps: 20 - -# SDXL Lightning parameters -lightning_from_pt: True -# Empty or "ByteDance/SDXL-Lightning" to enable lightning. -lightning_repo: "" -# Empty or "sdxl_lightning_4step_unet.safetensors" to enable lightning. -lightning_ckpt: "" - -# LoRA parameters -# Values are lists to support multiple LoRA loading during inference in the future. 
-lora_config: { - lora_model_name_or_path: [], - weight_name: [], - adapter_name: [], - scale: [], - from_pt: [] -} -# Ex with values: -# lora_config : { -# lora_model_name_or_path: ["ByteDance/Hyper-SD"], -# weight_name: ["Hyper-SDXL-2steps-lora.safetensors"], -# adapter_name: ["hyper-sdxl"], -# scale: [0.7], -# from_pt: [True] -# } - -enable_mllog: False - -#controlnet -controlnet_model_name_or_path: 'diffusers/controlnet-canny-sdxl-1.0' -controlnet_from_pt: True -controlnet_conditioning_scale: 0.5 -controlnet_image: 'https://upload.wikimedia.org/wikipedia/commons/thumb/c/c1/Google_%22G%22_logo.svg/1024px-Google_%22G%22_logo.svg.png' diff --git a/src/maxdiffusion/generate_flux.py b/src/maxdiffusion/generate_flux.py index 00bd4844e..0d542bb95 100644 --- a/src/maxdiffusion/generate_flux.py +++ b/src/maxdiffusion/generate_flux.py @@ -55,7 +55,7 @@ def unpack(x: Array, height: int, width: int) -> Array: ph=2, pw=2, ) - +from einops import rearrange def vae_decode(latents, vae, state, config): img = unpack(x=latents, height=config.resolution, width=config.resolution) img = img / vae.config.scaling_factor + vae.config.shift_factor @@ -87,6 +87,8 @@ def loop_body( guidance=guidance_vec, y=vec ) + jax.debug.print("*****pred max: {x}", x=np.max(pred)) + jax.debug.print("*****pred min: {x}", x=np.min(pred)) latents = latents + (t_prev - t_curr) * pred latents = jnp.array(latents, dtype=latents_dtype) return latents, state, c_ts, p_ts @@ -144,6 +146,8 @@ def run_inference( timesteps = time_shift(mu, 1.0, timesteps).tolist() c_ts = timesteps[:-1] p_ts = timesteps[1:] + # jax.debug.print("c_ts: {x}", x=c_ts) + # jax.debug.print("p_ts: {x}", x=p_ts) transformer_state = states["transformer"] @@ -162,7 +166,7 @@ def run_inference( vae_decode_p = functools.partial(vae_decode, vae=vae, state=vae_state, config=config) with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): - latents, _, _, _ = jax.lax.fori_loop(0, config.num_inference_steps, loop_body_p, (latents, transformer_state, c_ts, p_ts)) + latents, _, _, _ = jax.lax.fori_loop(0, len(timesteps) - 1, loop_body_p, (latents, transformer_state, c_ts, p_ts)) image = vae_decode_p(latents) return image @@ -293,7 +297,8 @@ def encode_prompt( prompt=prompt_2, num_images_per_prompt=num_images_per_prompt, tokenizer=t5_tokenizer, - text_encoder=t5_text_encoder + text_encoder=t5_text_encoder, + max_sequence_length=max_sequence_length ) text_ids = jnp.zeros((prompt_embeds.shape[0], prompt_embeds.shape[1], 3)).astype(jnp.bfloat16) @@ -356,7 +361,7 @@ def run(config): rng=rng ) - # LOAD TEXT ENCODERS - t5 on cpu + # LOAD TEXT ENCODERS clip_text_encoder = FlaxCLIPTextModel.from_pretrained( config.pretrained_model_name_or_path, subfolder="text_encoder", @@ -389,7 +394,8 @@ def run(config): clip_text_encoder=clip_text_encoder, t5_tokenizer=t5_tokenizer, t5_text_encoder=t5_encoder, - num_images_per_prompt=global_batch_size + num_images_per_prompt=global_batch_size, + max_sequence_length=config.max_sequence_length ) def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timesteps, guidance, pooled_prompt_embeds): @@ -430,12 +436,12 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep get_memory_allocations() # evaluate shapes - transformer_eval_params = transformer.init_weights(rngs=rng, max_sequence_length=512, eval_only=True) + transformer_eval_params = transformer.init_weights(rngs=rng, max_sequence_length=config.max_sequence_length, eval_only=True) # loads pretrained weights - transformer_params = 
load_flow_model("flux-dev", transformer_eval_params, "cpu") + transformer_params = load_flow_model(config.flux_name, transformer_eval_params, "cpu") # create transformer state - weights_init_fn = functools.partial(transformer.init_weights, rngs=rng, max_sequence_length=512, eval_only=False) + weights_init_fn = functools.partial(transformer.init_weights, rngs=rng, max_sequence_length=config.max_sequence_length, eval_only=False) transformer_state, transformer_state_shardings = setup_initial_state( model=transformer, tx=None, diff --git a/src/maxdiffusion/models/flux/modules/layers.py b/src/maxdiffusion/models/flux/modules/layers.py index 91562d406..d4c3d5534 100644 --- a/src/maxdiffusion/models/flux/modules/layers.py +++ b/src/maxdiffusion/models/flux/modules/layers.py @@ -111,7 +111,49 @@ def timestep_embedding( embedding = embedding.astype(t.dtype) return embedding +import numpy as np +class PixArtAlphaTextProjection(nn.Module): + hidden_dim: int + dtype: jnp.dtype = jnp.float32 + weights_dtype: jnp.dtype = jnp.float32 + precision: jax.lax.Precision = None + + @nn.compact + def __call__(self, x: Array) -> Array: + + hidden_states = nn.Dense( + self.hidden_dim, + use_bias=True, + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, + kernel_init=nn.with_logical_partitioning( + nn.initializers.lecun_normal(), + ("embed", "heads") + ), + name="in_layer" + )(x) + jax.debug.print("PixArtAlphaTextProjection, in_layer min: {x}", x=np.min(hidden_states)) + jax.debug.print("PixArtAlphaTextProjection, in_layer max: {x}", x=np.max(hidden_states)) + hidden_states = nn.swish(hidden_states) + jax.debug.print("PixArtAlphaTextProjection, act min: {x}", x=np.min(hidden_states)) + jax.debug.print("PixArtAlphaTextProjection, act max: {x}", x=np.max(hidden_states)) + hidden_states = nn.Dense( + self.hidden_dim, + use_bias=True, + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, + kernel_init=nn.with_logical_partitioning( + nn.initializers.lecun_normal(), + ("heads", "embed") + ), + name="out_layer" + )(hidden_states) + jax.debug.print("PixArtAlphaTextProjection, out min: {x}", x=np.min(hidden_states)) + jax.debug.print("PixArtAlphaTextProjection, out max: {x}", x=np.max(hidden_states)) + return hidden_states class MLPEmbedder(nn.Module): hidden_dim: int diff --git a/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py b/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py index 3e04dcd41..71fa03d0a 100644 --- a/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py +++ b/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py @@ -17,6 +17,7 @@ from typing import Dict, Optional, Tuple, Union from einops import repeat, rearrange +import numpy as np import jax import jax.numpy as jnp import flax.linen as nn @@ -28,7 +29,8 @@ EmbedND, DoubleStreamBlock, SingleStreamBlock, - LastLayer + LastLayer, + PixArtAlphaTextProjection ) from ...modeling_flax_utils import FlaxModelMixin from ....configuration_utils import ConfigMixin, flax_register_to_config @@ -129,6 +131,9 @@ def __call__( inner_dim = self.num_attention_heads * self.attention_head_dim pe_dim = inner_dim // self.num_attention_heads + jax.debug.print("pooled_projections value min: {x}", x=np.min(y)) + jax.debug.print("pooled_projections value max: {x}", x=np.max(y)) + img = nn.Dense( inner_dim, dtype=self.dtype, @@ -140,39 +145,57 @@ def __call__( ), name="img_in" )(img) - + jax.debug.print("img.min: {x}", x=np.min(img)) + jax.debug.print("img.max: 
{x}", x=np.max(img)) + timestep = timestep_embedding(timesteps, 256) + jax.debug.print("timestep.min: {x}", x=np.min(timestep)) + jax.debug.print("timestep.max: {x}", x=np.max(timestep)) vec = MLPEmbedder( hidden_dim=inner_dim, dtype=self.dtype, weights_dtype=self.weights_dtype, precision=self.precision, name="time_in" - )(timestep_embedding(timesteps, 256)) - + )(timestep) + jax.debug.print("timestep.vec min: {x}", x=np.min(vec)) + jax.debug.print("timestep.vec max: {x}", x=np.max(vec)) + print(f"guidance_embeds? {self.guidance_embeds}") if self.guidance_embeds: if guidance is None: raise ValueError( "Didn't get guidance strength for guidance distrilled model." ) + guidance_in = timestep_embedding(guidance, 256) + + jax.debug.print("guidance_in.min: {x}", x=np.min(guidance_in)) + jax.debug.print("guidance_in.max: {x}", x=np.max(guidance_in)) guidance_in = MLPEmbedder( hidden_dim=inner_dim, dtype=self.dtype, weights_dtype=self.weights_dtype, precision=self.precision, name="guidance_in" - )(timestep_embedding(guidance, 256)) - else: - guidance_in = Identity(timestep_embedding(guidance, 256)) - + )(guidance_in) + jax.debug.print("guidance.vec min: {x}", x=np.min(guidance_in)) + jax.debug.print("guidance.vec max: {x}", x=np.max(guidance_in)) vec = vec + guidance_in - - vec = vec + MLPEmbedder( + jax.debug.print("timestep_guidance.vec min: {x}", x=np.min(vec)) + jax.debug.print("timestep_guidance.vec max: {x}", x=np.max(vec)) + # else: + # guidance_in = Identity()(timestep_embedding(guidance, 256)) + + pooled_projections = PixArtAlphaTextProjection( hidden_dim=inner_dim, dtype=self.dtype, weights_dtype=self.weights_dtype, precision=self.precision, name="vector_in" )(y) + jax.debug.print("pooled_projections.min: {x}", x=np.min(pooled_projections)) + jax.debug.print("pooled_projections.max: {x}", x=np.max(pooled_projections)) + vec = vec + pooled_projections + jax.debug.print("temb.min: {x}", x=np.min(vec)) + jax.debug.print("temb.max: {x}", x=np.max(vec)) txt = nn.Dense( inner_dim, @@ -185,7 +208,8 @@ def __call__( ), name="txt_in" )(txt) - + jax.debug.print("txt.min: {x}", x=np.min(txt)) + jax.debug.print("txt.max: {x}", x=np.max(txt)) ids = jnp.concatenate((txt_ids, img_ids), axis=1) #pe_embedder @@ -194,7 +218,8 @@ def __call__( theta=10000, axes_dim=self.axes_dims_rope )(ids) - # breakpoint() + jax.debug.print("pe.min: {x}", x=np.min(pe)) + jax.debug.print("pe.max: {x}", x=np.max(pe)) # img, txt = DoubleStreamBlock( # hidden_size=inner_dim, # num_heads=self.num_attention_heads, diff --git a/src/maxdiffusion/models/flux/util.py b/src/maxdiffusion/models/flux/util.py index fc8e02165..59ff4509d 100644 --- a/src/maxdiffusion/models/flux/util.py +++ b/src/maxdiffusion/models/flux/util.py @@ -18,9 +18,6 @@ ) from maxdiffusion import max_logging -# from jflux.model import Flux, FluxParams -from .port import port_flux - @dataclass class FluxParams: in_channels: int @@ -42,7 +39,7 @@ def torch2jax(torch_tensor: torch.Tensor) -> Array: is_bfloat16 = torch_tensor.dtype == torch.bfloat16 if is_bfloat16: # upcast the tensor to fp32 - torch_tensor = torch_tensor.to(dtype=torch.float32) + torch_tensor = torch_tensor.float() if torch.device.type != "cpu": torch_tensor = torch_tensor.to("cpu") From c8196edd544cddb927713ebfaacbab1a11a996b1 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Mon, 3 Feb 2025 21:23:43 +0000 Subject: [PATCH 16/35] flux schnell working --- src/maxdiffusion/configs/base_flux_dev.yml | 13 +- .../configs/base_flux_schnell.yml | 13 +- src/maxdiffusion/generate_flux.py | 12 +- 
src/maxdiffusion/models/attention_flax.py | 171 ++++ src/maxdiffusion/models/embeddings_flax.py | 44 +- .../transformers/transformer_flux_flax.py | 789 ++++++++++++------ src/maxdiffusion/models/flux/util.py | 35 +- src/maxdiffusion/models/normalization_flax.py | 2 +- 8 files changed, 777 insertions(+), 302 deletions(-) diff --git a/src/maxdiffusion/configs/base_flux_dev.yml b/src/maxdiffusion/configs/base_flux_dev.yml index fc4a3c2da..987101672 100644 --- a/src/maxdiffusion/configs/base_flux_dev.yml +++ b/src/maxdiffusion/configs/base_flux_dev.yml @@ -52,10 +52,19 @@ activations_dtype: 'bfloat16' precision: "DEFAULT" # Set true to load weights from pytorch -from_pt: False +from_pt: True split_head_dim: True attention: 'flash' # Supported attention: dot_product, flash -flash_block_sizes: {} +flash_block_sizes: { + "block_q" : 128, + "block_kv" : 128, + "block_kv_compute" : 128, + "block_q_dkv" : 128, + "block_kv_dkv" : 128, + "block_kv_dkv_compute" : 128, + "block_q_dq" : 128, + "block_kv_dq" : 128 +} # GroupNorm groups norm_num_groups: 32 diff --git a/src/maxdiffusion/configs/base_flux_schnell.yml b/src/maxdiffusion/configs/base_flux_schnell.yml index 316fa753d..fd80c5f28 100644 --- a/src/maxdiffusion/configs/base_flux_schnell.yml +++ b/src/maxdiffusion/configs/base_flux_schnell.yml @@ -51,10 +51,19 @@ activations_dtype: 'bfloat16' precision: "DEFAULT" # Set true to load weights from pytorch -from_pt: False +from_pt: True split_head_dim: True attention: 'flash' # Supported attention: dot_product, flash -flash_block_sizes: {} +flash_block_sizes: { + "block_q" : 128, + "block_kv" : 128, + "block_kv_compute" : 128, + "block_q_dkv" : 128, + "block_kv_dkv" : 128, + "block_kv_dkv_compute" : 128, + "block_q_dq" : 128, + "block_kv_dq" : 128 +} # GroupNorm groups norm_num_groups: 32 diff --git a/src/maxdiffusion/generate_flux.py b/src/maxdiffusion/generate_flux.py index 0d542bb95..606b15051 100644 --- a/src/maxdiffusion/generate_flux.py +++ b/src/maxdiffusion/generate_flux.py @@ -79,16 +79,14 @@ def loop_body( t_vec = jnp.full((latents.shape[0], ), t_curr, dtype=latents.dtype) pred = transformer.apply( {"params" : state.params}, - img=latents, + hidden_states=latents, img_ids=latent_image_ids, - txt=prompt_embeds, + encoder_hidden_states=prompt_embeds, txt_ids=txt_ids, - timesteps=t_vec, + timestep=t_vec, guidance=guidance_vec, - y=vec - ) - jax.debug.print("*****pred max: {x}", x=np.max(pred)) - jax.debug.print("*****pred min: {x}", x=np.min(pred)) + pooled_projections=vec + ).sample latents = latents + (t_prev - t_curr) * pred latents = jnp.array(latents, dtype=latents_dtype) return latents, state, c_ts, p_ts diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index 884b6d688..2ac2182fc 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -322,6 +322,177 @@ def chunk_scanner(chunk_idx, _): return jnp.concatenate(res, axis=-3) # fuse the chunked result back +class FlaxFluxAttention(nn.Module): + query_dim: int + heads: int = 8 + dim_head: int = 64 + dropout: float = 0.0 + use_memory_efficient_attention: bool = False + split_head_dim: bool = False + attention_kernel: str = "dot_product" + flash_min_seq_length: int = 4096 + flash_block_sizes: BlockSizes = None + mesh: jax.sharding.Mesh = None + dtype: jnp.dtype = jnp.float32 + weights_dtype: jnp.dtype = jnp.float32 + query_axis_names: AxisNames = (BATCH, LENGTH, HEAD) + key_axis_names: AxisNames = (BATCH, LENGTH, HEAD) + value_axis_names: AxisNames = 
(BATCH, LENGTH, HEAD) + out_axis_names: AxisNames = (BATCH, LENGTH, EMBED) + precision: jax.lax.Precision = None + qkv_bias: bool = False + + def setup(self): + if self.attention_kernel in {"flash", "cudnn_flash_te"} and self.mesh is None: + raise ValueError(f"The flash attention kernel requires a value for mesh, but mesh is {self.mesh}") + inner_dim = self.dim_head * self.heads + scale = self.dim_head**-0.5 + + self.attention_op = AttentionOp( + mesh=self.mesh, + attention_kernel=self.attention_kernel, + scale=scale, + heads=self.heads, + dim_head=self.dim_head, + flash_min_seq_length=self.flash_min_seq_length, + use_memory_efficient_attention=self.use_memory_efficient_attention, + split_head_dim=self.split_head_dim, + flash_block_sizes=self.flash_block_sizes, + dtype=self.dtype, + float32_qk_product=False, + ) + + kernel_axes = ("embed", "heads") + qkv_init_kernel = nn.with_logical_partitioning(nn.initializers.lecun_normal(), kernel_axes) + + self.qkv = nn.Dense( + inner_dim * 3, + kernel_init=qkv_init_kernel, + use_bias=self.qkv_bias, + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("heads",)), + dtype=self.dtype, + param_dtype=self.weights_dtype, + name="i_qkv", + precision=self.precision, + ) + + self.encoder_qkv = nn.Dense( + inner_dim * 3, + kernel_init=qkv_init_kernel, + use_bias=self.qkv_bias, + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("heads",)), + dtype=self.dtype, + param_dtype=self.weights_dtype, + name="e_qkv", + precision=self.precision, + ) + + self.proj_attn = nn.Dense( + self.query_dim, + kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), kernel_axes), + use_bias=True, + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("heads",)), + dtype=self.dtype, + param_dtype=self.weights_dtype, + name="i_proj", + precision=self.precision, + ) + + self.encoder_proj_attn = nn.Dense( + self.query_dim, + kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), kernel_axes), + use_bias=True, + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("heads",)), + dtype=self.dtype, + param_dtype=self.weights_dtype, + name="e_proj", + precision=self.precision, + ) + + self.query_norm = nn.RMSNorm( + dtype=self.dtype, + scale_init=nn.with_logical_partitioning(nn.initializers.ones, ("heads",)), + param_dtype=self.weights_dtype, + ) + self.key_norm = nn.RMSNorm( + dtype=self.dtype, + scale_init=nn.with_logical_partitioning(nn.initializers.ones, ("heads",)), + param_dtype=self.weights_dtype, + ) + + self.encoder_query_norm = nn.RMSNorm( + dtype=self.dtype, + scale_init=nn.with_logical_partitioning(nn.initializers.ones, ("heads",)), + param_dtype=self.weights_dtype, + ) + self.encoder_key_norm = nn.RMSNorm( + dtype=self.dtype, + scale_init=nn.with_logical_partitioning(nn.initializers.ones, ("heads",)), + param_dtype=self.weights_dtype, + ) + + def apply_rope(self, xq: Array, xk: Array, freqs_cis: Array) -> tuple[Array, Array]: + xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2) + xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2) + + xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] + xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] + + return xq_out.reshape(*xq.shape).astype(xq.dtype), xk_out.reshape(*xk.shape).astype(xk.dtype) + + def __call__(self, hidden_states, encoder_hidden_states=None, attention_mask=None, image_rotary_emb=None): + + qkv_proj = self.qkv(hidden_states) + B, L = hidden_states.shape[:2] + H, D, K = self.heads, qkv_proj.shape[-1] // (self.heads 
* 3), 3 + qkv_proj = qkv_proj.reshape(B, L, K, H, D).transpose(2, 0, 3, 1, 4) + query_proj, key_proj, value_proj = qkv_proj + + query_proj = self.query_norm(query_proj) + + key_proj = self.key_norm(key_proj) + + if encoder_hidden_states is not None: + + encoder_qkv_proj = self.encoder_qkv(encoder_hidden_states) + B, L = encoder_hidden_states.shape[:2] + H, D, K = self.heads, encoder_qkv_proj.shape[-1] // (self.heads * 3), 3 + encoder_qkv_proj = encoder_qkv_proj.reshape(B, L, K, H, D).transpose(2, 0, 3, 1, 4) + encoder_query_proj, encoder_key_proj, encoder_value_proj = encoder_qkv_proj + + encoder_query_proj = self.encoder_query_norm(encoder_query_proj) + + encoder_key_proj = self.encoder_key_norm(encoder_key_proj) + + query_proj = jnp.concatenate((encoder_query_proj, query_proj), axis=2) + key_proj = jnp.concatenate((encoder_key_proj, key_proj), axis=2) + value_proj = jnp.concatenate((encoder_value_proj, value_proj), axis=2) + + query_proj = nn.with_logical_constraint(query_proj, self.query_axis_names) + key_proj = nn.with_logical_constraint(key_proj, self.key_axis_names) + value_proj = nn.with_logical_constraint(value_proj, self.value_axis_names) + + image_rotary_emb = rearrange(image_rotary_emb, "n d (i j) -> n d i j", i=2, j=2) + query_proj, key_proj = self.apply_rope(query_proj, key_proj, image_rotary_emb) + + query_proj = query_proj.transpose(0, 2, 1, 3).reshape(query_proj.shape[0], query_proj.shape[2], -1) + key_proj = key_proj.transpose(0, 2, 1, 3).reshape(key_proj.shape[0], key_proj.shape[2], -1) + value_proj = value_proj.transpose(0, 2, 1, 3).reshape(value_proj.shape[0], value_proj.shape[2], -1) + + attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj) + context_attn_output = None + + if encoder_hidden_states is not None: + context_attn_output, attn_output = ( + attn_output[:, : encoder_hidden_states.shape[1]], + attn_output[:, encoder_hidden_states.shape[1] :], + ) + + attn_output = self.proj_attn(attn_output) + + context_attn_output = self.encoder_proj_attn(context_attn_output) + + return attn_output, context_attn_output class FlaxFluxAttention(nn.Module): query_dim: int diff --git a/src/maxdiffusion/models/embeddings_flax.py b/src/maxdiffusion/models/embeddings_flax.py index a42418f20..52a70d97d 100644 --- a/src/maxdiffusion/models/embeddings_flax.py +++ b/src/maxdiffusion/models/embeddings_flax.py @@ -73,9 +73,14 @@ class FlaxTimestepEmbedding(nn.Module): @nn.compact def __call__(self, temb): - temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, param_dtype=self.weights_dtype, name="linear_1")(temb) + temb = nn.Dense(self.time_embed_dim, + dtype=self.dtype, + param_dtype=self.weights_dtype, + name="linear_1")(temb) temb = nn.silu(temb) - temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, param_dtype=self.weights_dtype, name="linear_2")(temb) + temb = nn.Dense(self.time_embed_dim, + dtype=self.dtype, + param_dtype=self.weights_dtype, name="linear_2")(temb) return temb @@ -98,7 +103,6 @@ def __call__(self, timesteps): timesteps, embedding_dim=self.dim, flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.freq_shift ) - def get_1d_rotary_pos_embed( dim: int, pos: Union[jnp.array, int], theta: float = 10000.0, linear_factor=1.0, ntk_factor=1.0, freqs_dtype=jnp.float32 ): @@ -119,7 +123,6 @@ def get_1d_rotary_pos_embed( return out - class PixArtAlphaTextProjection(nn.Module): """ Projects caption embeddings. Also handles dropout for classifier-free guidance. 
@@ -236,3 +239,36 @@ def __call__(self, timestep, guidance, pooled_projection): conditioning = time_guidance_emb + pooled_projections return conditioning + + +# class HFEmbedder(nnx.Module): + +# def __init__(self, version: str, max_length: int, **hf_kwargs): +# super().__init__() +# self.is_clip = version.split("/")[1].startswith("clip") +# self.max_length = max_length +# self.output_key = "pooler_output" if self.is_clip else "last_hidden_state" + +# if self.is_clip: +# self.tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(version, max_length=max_length, use_fast=True) +# self.hf_module: FlaxCLIPTextModel = FlaxCLIPTextModel.from_pretrained(version, **hf_kwargs) +# else: +# self.tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(version, max_length=max_length, use_fast=True) +# self.hf_module: FlaxT5EncoderModel = FlaxT5EncoderModel.from_pretrained(version, **hf_kwargs) + +# def __call__(self, text: list[str]): +# batch_encoding = self.tokenizer( +# text, +# truncation=True, +# max_length=self.max_length, +# return_length=False, +# return_overflowing_tokens=False, +# padding="max_length", +# return_tensors="np", +# ) +# outputs = self.hf_module( +# input_ids=batch_encoding["input_ids"], +# attention_mask=None, +# output_hidden_states=False, +# ) +# return outputs[self.output_key] \ No newline at end of file diff --git a/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py b/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py index 71fa03d0a..d9f398f6a 100644 --- a/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py +++ b/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py @@ -14,86 +14,341 @@ limitations under the License. """ -from typing import Dict, Optional, Tuple, Union +"""This script is used an example of how to shard the UNET on TPU.""" -from einops import repeat, rearrange -import numpy as np +from typing import Any, Dict, Optional, Tuple, Union import jax +import math import jax.numpy as jnp -import flax.linen as nn -from chex import Array - -from ..modules.layers import ( - timestep_embedding, - MLPEmbedder, - EmbedND, - DoubleStreamBlock, - SingleStreamBlock, - LastLayer, - PixArtAlphaTextProjection -) -from ...modeling_flax_utils import FlaxModelMixin +import flax +import flax.linen as nn +from jax.random import PRNGKey +from einops import repeat, rearrange from ....configuration_utils import ConfigMixin, flax_register_to_config +from ...modeling_flax_utils import FlaxModelMixin +from ...normalization_flax import AdaLayerNormZeroSingle, AdaLayerNormContinuous, AdaLayerNormZero +from ...attention_flax import FlaxFluxAttention +from ...embeddings_flax import (FluxPosEmbed, CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings) +from .... import common_types from ....common_types import BlockSizes +from .... import max_logging +from ....utils import BaseOutput +from dataclasses import dataclass + +AxisNames = common_types.AxisNames +BATCH = common_types.BATCH +LENGTH = common_types.LENGTH +HEAD = common_types.HEAD +D_KV = common_types.D_KV + + +@dataclass +class FluxParams: + in_channels: int + vec_in_dim: int + context_in_dim: int + hidden_size: int + mlp_ratio: float + num_heads: int + depth: int + depth_single_blocks: int + axes_dim: list[int] + theta: int + qkv_bias: bool + guidance_embed: bool + param_dtype: jnp.bfloat16 + rngs: jax.random.PRNGKey + + +@flax.struct.dataclass +class Transformer2DModelOutput(BaseOutput): + """ + The output of [`FluxTransformer2DModel`]. 
+ + Args: + sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`): + The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. + """ + + sample: jnp.ndarray + + +class FluxSingleTransformerBlock(nn.Module): + r""" + A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3. + + Reference: https://arxiv.org/abs/2403.03206 + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the + processing of `context` conditions. + """ + + dim: int + num_attention_heads: int + attention_head_dim: int + mlp_ratio: int = 4.0 + attention_kernel: str = "dot_product" + flash_min_seq_length: int = 4096 + flash_block_sizes: BlockSizes = None + mesh: jax.sharding.Mesh = None + dtype: jnp.dtype = jnp.float32 + weights_dtype: jnp.dtype = jnp.float32 + precision: jax.lax.Precision = None + + def setup(self): + self.mlp_hidden_dim = int(self.dim * self.mlp_ratio) + + self.norm = AdaLayerNormZeroSingle( + self.dim, dtype=self.dtype, weights_dtype=self.weights_dtype, precision=self.precision + ) + + self.linear1 = nn.Dense( + self.dim * 3 + self.mlp_hidden_dim, + kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")), + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("mlp",)), + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, + ) + + self.mlp_act = nn.gelu + self.linear2 = nn.Dense( + self.dim, + kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("mlp", "embed")), + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("embed",)), + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, + ) + self.attn = FlaxFluxAttention( + query_dim=self.dim, + heads=self.num_attention_heads, + dim_head=self.attention_head_dim, + dtype=self.dtype, + weights_dtype=self.weights_dtype, + attention_kernel=self.attention_kernel, + mesh=self.mesh, + flash_block_sizes=self.flash_block_sizes + ) + + def __call__(self, hidden_states, temb, image_rotary_emb=None): + residual = hidden_states + norm_hidden_states, gate = self.norm(hidden_states, emb=temb) + qkv, mlp = jnp.split(self.linear1(norm_hidden_states), [3 * self.dim], axis=-1) + mlp = nn.with_logical_constraint(mlp, ("activation_batch", "activation_length", "activation_embed")) + qkv = nn.with_logical_constraint(qkv, ("activation_batch", "activation_length", "activation_embed")) + + B, L = hidden_states.shape[:2] + H, D, K = self.num_attention_heads, qkv.shape[-1] // (self.num_attention_heads * 3), 3 + qkv_proj = qkv.reshape(B, L, K, H, D).transpose(2, 0, 3, 1, 4) + q, k, v = qkv_proj + + q = self.attn.query_norm(q) + k = self.attn.key_norm(k) + + if image_rotary_emb is not None: + # since this function returns image_rotary_emb and passes it between layers, + # we do not want to modify it + image_rotary_emb_reordered = rearrange(image_rotary_emb, "n d (i j) -> n d i j", i=2, j=2) + q, k = self.attn.apply_rope(q, k, image_rotary_emb_reordered) + + q = q.transpose(0, 2, 1, 3).reshape(q.shape[0], q.shape[2], -1) + k = k.transpose(0, 2, 1, 3).reshape(k.shape[0], k.shape[2], -1) + v = v.transpose(0, 2, 1, 3).reshape(v.shape[0], v.shape[2], -1) + + attn_output = 
self.attn.attention_op.apply_attention(q, k, v) + + attn_mlp = jnp.concatenate([attn_output, self.mlp_act(mlp)], axis=2) + attn_mlp = nn.with_logical_constraint(attn_mlp, ("activation_batch", "activation_length", "activation_embed")) + hidden_states = self.linear2(attn_mlp) + hidden_states = gate * hidden_states + hidden_states = residual + hidden_states + if hidden_states.dtype == jnp.float16 or hidden_states.dtype == jnp.bfloat16: + hidden_states = jnp.clip(hidden_states, -65504, 65504) + + return hidden_states, temb, image_rotary_emb + + +class FluxTransformerBlock(nn.Module): + r""" + A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3. + + Reference: https://arxiv.org/abs/2403.03206 + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the + processing of `context` conditions. + """ + + dim: int + num_attention_heads: int + attention_head_dim: int + qk_norm: str = "rms_norm" + eps: int = 1e-6 + flash_min_seq_length: int = 4096 + flash_block_sizes: BlockSizes = None + mesh: jax.sharding.Mesh = None + dtype: jnp.dtype = jnp.float32 + weights_dtype: jnp.dtype = jnp.float32 + precision: jax.lax.Precision = None + mlp_ratio: float = 4.0 + qkv_bias: bool = False + attention_kernel: str = "dot_product" + + def setup(self): + + self.img_norm1 = AdaLayerNormZero(self.dim, dtype=self.dtype, weights_dtype=self.weights_dtype, precision=self.precision) + self.txt_norm1 = AdaLayerNormZero(self.dim, dtype=self.dtype, weights_dtype=self.weights_dtype, precision=self.precision) + + self.attn = FlaxFluxAttention( + query_dim=self.dim, + heads=self.num_attention_heads, + dim_head=self.attention_head_dim, + qkv_bias=self.qkv_bias, + dtype=self.dtype, + weights_dtype=self.weights_dtype, + attention_kernel=self.attention_kernel, + mesh=self.mesh, + flash_block_sizes=self.flash_block_sizes + ) + + self.img_norm2 = nn.LayerNorm( + use_bias=False, + use_scale=False, + epsilon=self.eps, + dtype=self.dtype, + param_dtype=self.weights_dtype, + ) + self.img_mlp = nn.Sequential( + [ + nn.Dense( + int(self.dim * self.mlp_ratio), + use_bias=True, + kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")), + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("mlp",)), + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, + ), + nn.gelu, + nn.Dense( + self.dim, + use_bias=True, + kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")), + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("mlp",)), + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, + ), + ] + ) + + self.txt_norm2 = nn.LayerNorm( + use_bias=False, + use_scale=False, + epsilon=self.eps, + dtype=self.dtype, + param_dtype=self.weights_dtype, + ) + self.txt_mlp = nn.Sequential( + [ + nn.Dense( + int(self.dim * self.mlp_ratio), + use_bias=True, + kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")), + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("mlp",)), + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, + ), + nn.gelu, + nn.Dense( + self.dim, + use_bias=True, + 
kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")), + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("mlp",)), + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, + ), + ] + ) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 -class Identity(nn.Module): - def __call__(self, x: Array) -> Array: - return x - -def scan_double_block_layers( - inner_dim, - num_heads, - mlp_ratio, - attention_head_dim, - flash_min_seq_length, - flash_block_sizes, - mesh, - dtype, - weights_dtype, - precision, - qkv_bias, - attention_kernel: str, - num_layers: int): - - scan_fn = nn.scan( - DoubleStreamBlock, - variable_broadcast='params', - in_axes=( - nn.broadcast, - nn.broadcast, - nn.broadcast - ), - out_axes=nn.broadcast, - split_rngs={'params' : False}, - length=num_layers - ) - return scan_fn( - hidden_size=inner_dim, - num_heads=num_heads, - mlp_ratio=mlp_ratio, - attention_head_dim=attention_head_dim, - flash_min_seq_length=flash_min_seq_length, - flash_block_sizes=flash_block_sizes, - mesh=mesh, - dtype=dtype, - weights_dtype=weights_dtype, - precision=precision, - qkv_bias=qkv_bias, - attention_kernel=attention_kernel) + def __call__(self, hidden_states, encoder_hidden_states, temb, image_rotary_emb=None): + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.img_norm1(hidden_states, emb=temb) + + norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.txt_norm1( + encoder_hidden_states, emb=temb + ) + # Attention. + attn_output, context_attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + ) + + attn_output = gate_msa * attn_output + hidden_states = hidden_states + attn_output + norm_hidden_states = self.img_norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + + ff_output = self.img_mlp(norm_hidden_states) + ff_output = gate_mlp * ff_output + + hidden_states = hidden_states + ff_output + # Process attention outputs for the `encoder_hidden_states`. + context_attn_output = c_gate_msa * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output + + norm_encoder_hidden_states = self.txt_norm2(encoder_hidden_states) + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp + + context_ff_output = self.txt_mlp(norm_encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output + if encoder_hidden_states.dtype == jnp.float16 or encoder_hidden_states.dtype == jnp.bfloat16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + return hidden_states, encoder_hidden_states, temb, image_rotary_emb + + +@flax_register_to_config class FluxTransformer2DModel(nn.Module, FlaxModelMixin, ConfigMixin): r""" - The Tranformer model introduced in Flux. + The Tranformer model introduced in Flux. + + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ - Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for it's generic methods + implemented for all models (such as downloading or saving). - This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for it's generic methods - implemented for all models (such as downloading or saving). 
+ This model is also a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) + subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matters related to its + general usage and behavior. + + Parameters: + patch_size (`int`): Patch size to turn the input data into small patches. + in_channels (`int`, *optional*, defaults to 16): The number of channels in the input. + num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use. + num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use. + attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. + num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention. + joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`. + guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings. - This model is also a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) - subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matters related to its - general usage and behavior. """ + patch_size: int = 1 in_channels: int = 64 num_layers: int = 19 @@ -102,228 +357,206 @@ class FluxTransformer2DModel(nn.Module, FlaxModelMixin, ConfigMixin): num_attention_heads: int = 24 joint_attention_dim: int = 4096 pooled_projection_dim: int = 768 - mlp_ratio: int = 4 - qkv_bias: bool = True guidance_embeds: bool = False axes_dims_rope: Tuple[int] = (16, 56, 56) flash_min_seq_length: int = 4096 flash_block_sizes: BlockSizes = None - attention_kernel: str = "dot_product" mesh: jax.sharding.Mesh = None dtype: jnp.dtype = jnp.float32 weights_dtype: jnp.dtype = jnp.float32 precision: jax.lax.Precision = None - - @nn.compact - def __call__( - self, - img: Array, - img_ids: Array, - txt: Array, - txt_ids: Array, - timesteps: Array, - y: Array, - guidance: Array | None = None, - return_dict: bool = True, - train: bool = False): - - out_channels = self.in_channels - inner_dim = self.num_attention_heads * self.attention_head_dim - pe_dim = inner_dim // self.num_attention_heads - - jax.debug.print("pooled_projections value min: {x}", x=np.min(y)) - jax.debug.print("pooled_projections value max: {x}", x=np.max(y)) - - img = nn.Dense( - inner_dim, - dtype=self.dtype, - param_dtype=self.weights_dtype, - precision=self.precision, - kernel_init=nn.with_logical_partitioning( - nn.initializers.lecun_normal(), - ("embed", "heads") - ), - name="img_in" - )(img) - jax.debug.print("img.min: {x}", x=np.min(img)) - jax.debug.print("img.max: {x}", x=np.max(img)) - timestep = timestep_embedding(timesteps, 256) - jax.debug.print("timestep.min: {x}", x=np.min(timestep)) - jax.debug.print("timestep.max: {x}", x=np.max(timestep)) - vec = MLPEmbedder( - hidden_dim=inner_dim, - dtype=self.dtype, - weights_dtype=self.weights_dtype, - precision=self.precision, - name="time_in" - )(timestep) - jax.debug.print("timestep.vec min: {x}", x=np.min(vec)) - jax.debug.print("timestep.vec max: {x}", x=np.max(vec)) - print(f"guidance_embeds? {self.guidance_embeds}") - if self.guidance_embeds: - if guidance is None: - raise ValueError( - "Didn't get guidance strength for guidance distrilled model." 
- ) - guidance_in = timestep_embedding(guidance, 256) - - jax.debug.print("guidance_in.min: {x}", x=np.min(guidance_in)) - jax.debug.print("guidance_in.max: {x}", x=np.max(guidance_in)) - guidance_in = MLPEmbedder( - hidden_dim=inner_dim, + mlp_ratio: float = 4.0 + qkv_bias: bool = True + theta: int = 1000 + attention_kernel: str = "dot_product" + eps = 1e-6 + + def setup(self): + self.out_channels = self.in_channels + self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim + + self.pe_embedder = FluxPosEmbed(theta=self.theta, axes_dim=self.axes_dims_rope, dtype=self.dtype) + + text_time_guidance_cls = ( + CombinedTimestepGuidanceTextProjEmbeddings if self.guidance_embeds else CombinedTimestepTextProjEmbeddings + ) + + self.time_text_embed = text_time_guidance_cls( + embedding_dim=self.inner_dim, + pooled_projection_dim=self.pooled_projection_dim, dtype=self.dtype, weights_dtype=self.weights_dtype, precision=self.precision, - name="guidance_in" - )(guidance_in) - jax.debug.print("guidance.vec min: {x}", x=np.min(guidance_in)) - jax.debug.print("guidance.vec max: {x}", x=np.max(guidance_in)) - vec = vec + guidance_in - jax.debug.print("timestep_guidance.vec min: {x}", x=np.min(vec)) - jax.debug.print("timestep_guidance.vec max: {x}", x=np.max(vec)) - # else: - # guidance_in = Identity()(timestep_embedding(guidance, 256)) - - pooled_projections = PixArtAlphaTextProjection( - hidden_dim=inner_dim, - dtype=self.dtype, - weights_dtype=self.weights_dtype, - precision=self.precision, - name="vector_in" - )(y) - jax.debug.print("pooled_projections.min: {x}", x=np.min(pooled_projections)) - jax.debug.print("pooled_projections.max: {x}", x=np.max(pooled_projections)) - vec = vec + pooled_projections - jax.debug.print("temb.min: {x}", x=np.min(vec)) - jax.debug.print("temb.max: {x}", x=np.max(vec)) - - txt = nn.Dense( - inner_dim, - dtype=self.dtype, - param_dtype=self.weights_dtype, - precision=self.precision, - kernel_init=nn.with_logical_partitioning( - nn.initializers.lecun_normal(), - ("embed", "heads") - ), - name="txt_in" - )(txt) - jax.debug.print("txt.min: {x}", x=np.min(txt)) - jax.debug.print("txt.max: {x}", x=np.max(txt)) - ids = jnp.concatenate((txt_ids, img_ids), axis=1) - - #pe_embedder - pe = EmbedND( - dim=pe_dim, - theta=10000, - axes_dim=self.axes_dims_rope - )(ids) - jax.debug.print("pe.min: {x}", x=np.min(pe)) - jax.debug.print("pe.max: {x}", x=np.max(pe)) - # img, txt = DoubleStreamBlock( - # hidden_size=inner_dim, - # num_heads=self.num_attention_heads, - # mlp_ratio=self.mlp_ratio, - # attention_head_dim=self.attention_head_dim, - # flash_min_seq_length=self.flash_min_seq_length, - # flash_block_sizes=self.flash_block_sizes, - # mesh=self.mesh, - # dtype=self.dtype, - # weights_dtype=self.weights_dtype, - # precision=self.precision, - # qkv_bias=self.qkv_bias, - # attention_kernel=self.attention_kernel, - # name="double_blocks_0" - # )(img=img, txt=txt, vec=vec, pe=pe) - # # breakpoint() - for i in range(self.num_layers): - img, txt = DoubleStreamBlock( - hidden_size=inner_dim, - num_heads=self.num_attention_heads, - mlp_ratio=self.mlp_ratio, - attention_head_dim=self.attention_head_dim, - flash_min_seq_length=self.flash_min_seq_length, - flash_block_sizes=self.flash_block_sizes, - mesh=self.mesh, + ) + self.txt_in = nn.Dense( + self.inner_dim, + kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), (None, "mlp")), + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("mlp",)), dtype=self.dtype, - 
weights_dtype=self.weights_dtype, + param_dtype=self.weights_dtype, precision=self.precision, - qkv_bias=self.qkv_bias, - attention_kernel=self.attention_kernel, - name=f"double_blocks_{i}" - )(img=img, txt=txt, vec=vec, pe=pe) - # img, txt = nn.Sequential( - # [ - # *[ - # DoubleStreamBlock( - # hidden_size=inner_dim, - # num_heads=self.num_attention_heads, - # mlp_ratio=self.mlp_ratio, - # attention_head_dim=self.attention_head_dim, - # flash_min_seq_length=self.flash_min_seq_length, - # flash_block_sizes=self.flash_block_sizes, - # mesh=self.mesh, - # dtype=self.dtype, - # weights_dtype=self.weights_dtype, - # precision=self.precision, - # qkv_bias=self.qkv_bias, - # attention_kernel=self.attention_kernel, - # )(img=img, txt=txt, vec=vec, pe=pe) - # for _ in range(2) - # ] - # ] - # ) - # breakpoint() - # img, txt = scan_double_block_layers( - # inner_dim=inner_dim, - # num_heads=self.num_attention_heads, - # mlp_ratio=self.mlp_ratio, - # attention_head_dim=self.attention_head_dim, - # flash_min_seq_length=self.flash_min_seq_length, - # flash_block_sizes=self.flash_block_sizes, - # mesh=self.mesh, - # dtype=self.dtype, - # weights_dtype=self.weights_dtype, - # precision=self.precision, - # qkv_bias=self.qkv_bias, - # attention_kernel=self.attention_kernel, - # num_layers=self.num_layers - # )(img, txt, vec, pe) - img = jnp.concatenate((txt, img), axis=1) - for i in range(self.num_single_layers): - img, SingleStreamBlock( - hidden_size=inner_dim, - num_heads=self.num_attention_heads, - mlp_ratio=self.mlp_ratio, - attention_head_dim=self.attention_head_dim, - flash_min_seq_length=self.flash_min_seq_length, - flash_block_sizes=self.flash_block_sizes, - mesh=self.mesh, + ) + self.img_in = nn.Dense( + self.inner_dim, + kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), (None, "mlp")), + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("mlp",)), + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, + ) + + self.double_blocks = nn.Sequential( + [ + *[ + FluxTransformerBlock( + dim=self.inner_dim, + num_attention_heads=self.num_attention_heads, + attention_head_dim=self.attention_head_dim, + attention_kernel=self.attention_kernel, + flash_min_seq_length=self.flash_min_seq_length, + flash_block_sizes=self.flash_block_sizes, + mesh=self.mesh, + dtype=self.dtype, + weights_dtype=self.weights_dtype, + precision=self.precision, + mlp_ratio=self.mlp_ratio, + qkv_bias=self.qkv_bias, + ) + for _ in range(self.num_layers) + ] + ] + ) + + self.single_blocks = nn.Sequential( + [ + *[ + FluxSingleTransformerBlock( + dim=self.inner_dim, + num_attention_heads=self.num_attention_heads, + attention_head_dim=self.attention_head_dim, + attention_kernel=self.attention_kernel, + flash_min_seq_length=self.flash_min_seq_length, + flash_block_sizes=self.flash_block_sizes, + mesh=self.mesh, + dtype=self.dtype, + weights_dtype=self.weights_dtype, + precision=self.precision, + mlp_ratio=self.mlp_ratio, + ) + for _ in range(self.num_single_layers) + ] + ] + ) + + self.norm_out = AdaLayerNormContinuous( + self.inner_dim, + elementwise_affine=False, + eps=self.eps, dtype=self.dtype, weights_dtype=self.weights_dtype, precision=self.precision, - attention_kernel=self.attention_kernel, - name=f"single_blocks_{i}" - )(img, vec, pe) - - img = img[:, txt.shape[1] :, ...] 
-
-    img = LastLayer(
-        hidden_size=inner_dim,
-        patch_size=1,
-        out_channels=out_channels,
-        dtype=self.dtype,
-        weights_dtype=self.weights_dtype,
-        precision=self.precision,
-        name="final_layer"
-    )(img, vec)
-
-    return img
-
+    )
+
+    self.proj_out = nn.Dense(
+        self.patch_size**2 * self.out_channels,
+        kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("mlp", None)),
+        bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)),
+        dtype=self.dtype,
+        param_dtype=self.weights_dtype,
+        precision=self.precision,
+        use_bias=True,
+    )
+
+  def timestep_embedding(self, t: jax.Array, dim: int, max_period=10000, time_factor: float = 1000.0) -> jax.Array:
+    """
+    Generate timestep embeddings.
+
+    Args:
+      t: a 1-D Tensor of N indices, one per batch element.
+        These may be fractional.
+      dim: the dimension of the output.
+      max_period: controls the minimum frequency of the embeddings.
+      time_factor: scaling factor applied to `t` before computing the embedding.
+
+    Returns:
+      timestep embeddings.
+    """
+    t = time_factor * t
+    half = dim // 2
+
+    freqs = jnp.exp(-math.log(max_period) * jnp.arange(start=0, stop=half, dtype=jnp.bfloat16) / half).astype(dtype=t.dtype)
+
+    args = t[:, None].astype(jnp.bfloat16) * freqs[None]
+    embedding = jnp.concatenate([jnp.cos(args), jnp.sin(args)], axis=-1)
+
+    if dim % 2:
+      embedding = jnp.concatenate([embedding, jnp.zeros_like(embedding[:, :1])], axis=-1)
+
+    if jnp.issubdtype(t.dtype, jnp.floating):
+      embedding = embedding.astype(t.dtype)
+
+    return embedding
+
+  def __call__(
+      self,
+      hidden_states,
+      encoder_hidden_states,
+      pooled_projections,
+      timestep,
+      img_ids,
+      txt_ids,
+      guidance,
+      return_dict: bool = True,
+      train: bool = False,
+  ):
+    hidden_states = self.img_in(hidden_states)
+    timestep = self.timestep_embedding(timestep, 256)
+    if self.guidance_embeds:
+      guidance = self.timestep_embedding(guidance, 256)
+    else:
+      guidance = None
+    temb = (
+        self.time_text_embed(timestep, pooled_projections)
+        if guidance is None
+        else self.time_text_embed(timestep, guidance, pooled_projections)
+    )
+    encoder_hidden_states = self.txt_in(encoder_hidden_states)
+    if txt_ids.ndim == 3:
+      txt_ids = txt_ids[0]
+    if img_ids.ndim == 3:
+      img_ids = img_ids[0]
+
+    ids = jnp.concatenate((txt_ids, img_ids), axis=0)
+    ids = nn.with_logical_constraint(ids, ("activation_batch", None))
+    image_rotary_emb = self.pe_embedder(ids)
+    image_rotary_emb = nn.with_logical_constraint(image_rotary_emb, ("activation_batch", "activation_embed"))
+
+    hidden_states, encoder_hidden_states, temb, image_rotary_emb = self.double_blocks(
+        hidden_states=hidden_states,
+        encoder_hidden_states=encoder_hidden_states,
+        temb=temb,
+        image_rotary_emb=image_rotary_emb,
+    )
+    hidden_states = jnp.concatenate([encoder_hidden_states, hidden_states], axis=1)
+    hidden_states = nn.with_logical_constraint(hidden_states, ("activation_batch", "activation_length", "activation_embed"))
+
+    hidden_states, temb, image_rotary_emb = self.single_blocks(
+        hidden_states=hidden_states, temb=temb, image_rotary_emb=image_rotary_emb
+    )
+    hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
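The `timestep_embedding` method above follows the standard sinusoidal recipe; as a quick standalone sanity check of the same computation (a minimal sketch, not part of the patch; float32 throughout and an even `dim` assumed):

    import math
    import jax.numpy as jnp

    def sinusoidal_embedding(t, dim, max_period=10000, time_factor=1000.0):
      # same recipe as the method above: scale t, build half geometrically
      # spaced frequencies, then concatenate cosine and sine features
      t = time_factor * t
      half = dim // 2
      freqs = jnp.exp(-math.log(max_period) * jnp.arange(half, dtype=jnp.float32) / half)
      args = t[:, None] * freqs[None]
      return jnp.concatenate([jnp.cos(args), jnp.sin(args)], axis=-1)

    emb = sinusoidal_embedding(jnp.array([0.5, 1.0]), dim=256)
    assert emb.shape == (2, 256)  # one 256-wide embedding per timestep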
+ + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) + def init_weights(self, rngs, max_sequence_length, eval_only=True): scale_factor = 16 resolution = 1024 - num_devices = jax.local_device_count() + num_devices = len(jax.devices()) batch_size = 1 * num_devices batch_image_shape = ( batch_size, @@ -367,22 +600,22 @@ def init_weights(self, rngs, max_sequence_length, eval_only=True): return jax.eval_shape( self.init, rngs, - img=img, + hidden_states=img, img_ids=img_ids, - txt=txt, + encoder_hidden_states=txt, txt_ids=txt_ids, - y=vec, - timesteps=t_vec, + pooled_projections=vec, + timestep=t_vec, guidance=guidance_vec, )["params"] else: return self.init( rngs, - img=img, + hidden_states=img, img_ids=img_ids, - txt=txt, + encoder_hidden_states=txt, txt_ids=txt_ids, - y=vec, - timesteps=t_vec, + pooled_projections=vec, + timestep=t_vec, guidance=guidance_vec, )["params"] \ No newline at end of file diff --git a/src/maxdiffusion/models/flux/util.py b/src/maxdiffusion/models/flux/util.py index 59ff4509d..b6b6dc372 100644 --- a/src/maxdiffusion/models/flux/util.py +++ b/src/maxdiffusion/models/flux/util.py @@ -159,14 +159,33 @@ def load_flow_model(name: str, eval_shapes: dict, device: str, hf_download: bool for pt_key, tensor in tensors.items(): renamed_pt_key = rename_key(pt_key) if "double_blocks" in renamed_pt_key: - renamed_pt_key = renamed_pt_key.replace("double_blocks.", "double_blocks_") - renamed_pt_key = renamed_pt_key.replace("img_attn.", "img_attn_") - renamed_pt_key = renamed_pt_key.replace("img_mlp.", "img_mlp_") - renamed_pt_key = renamed_pt_key.replace("txt_attn.", "txt_attn_") - renamed_pt_key = renamed_pt_key.replace("txt_mlp.", "txt_mlp_") - - if "single_blocks" in renamed_pt_key: - renamed_pt_key = renamed_pt_key.replace("single_blocks.", "single_blocks_") + renamed_pt_key = renamed_pt_key.replace("double_blocks_", "double_blocks.layers_") + renamed_pt_key = renamed_pt_key.replace("img_mlp_", "img_mlp.layers_") + renamed_pt_key = renamed_pt_key.replace("txt_mlp_", "txt_mlp.layers_") + renamed_pt_key = renamed_pt_key.replace("img_mod", "img_norm1") + renamed_pt_key = renamed_pt_key.replace("txt_mod", "txt_norm1") + renamed_pt_key = renamed_pt_key.replace("img_attn.qkv", "attn.i_qkv") + renamed_pt_key = renamed_pt_key.replace("img_attn.proj", "attn.i_proj") + renamed_pt_key = renamed_pt_key.replace("img_attn.norm", "attn") + renamed_pt_key = renamed_pt_key.replace("txt_attn.qkv", "attn.e_qkv") + renamed_pt_key = renamed_pt_key.replace("txt_attn.proj", "attn.e_proj") + renamed_pt_key = renamed_pt_key.replace("txt_attn.norm.key_norm", "attn.encoder_key_norm") + renamed_pt_key = renamed_pt_key.replace("txt_attn.norm.query_norm", "attn.encoder_query_norm") + elif("guidance_in" in renamed_pt_key): + renamed_pt_key = renamed_pt_key.replace("guidance_in", "time_text_embed.FlaxTimestepEmbedding_1") + elif "single_blocks" in renamed_pt_key: + renamed_pt_key = renamed_pt_key.replace("single_blocks_", "single_blocks.layers_") + renamed_pt_key = renamed_pt_key.replace("modulation", "norm") + renamed_pt_key = renamed_pt_key.replace("norm.key_norm", "attn.key_norm") + renamed_pt_key = renamed_pt_key.replace("norm.query_norm", "attn.query_norm") + elif "vector_in" in renamed_pt_key or "time_in" in renamed_pt_key: + renamed_pt_key = renamed_pt_key.replace("vector_in", "time_text_embed.PixArtAlphaTextProjection_0") + renamed_pt_key = 
renamed_pt_key.replace("time_in", "time_text_embed.FlaxTimestepEmbedding_0") + renamed_pt_key = renamed_pt_key.replace("in_layer", "linear_1") + renamed_pt_key = renamed_pt_key.replace("out_layer", "linear_2") + elif "final_layer" in renamed_pt_key: + renamed_pt_key = renamed_pt_key.replace("final_layer.linear", "proj_out") + renamed_pt_key = renamed_pt_key.replace("final_layer.adaLN_modulation_1", "norm_out.Dense_0") pt_tuple_key = tuple(renamed_pt_key.split(".")) flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, eval_shapes) diff --git a/src/maxdiffusion/models/normalization_flax.py b/src/maxdiffusion/models/normalization_flax.py index ea3b970d8..b91433144 100644 --- a/src/maxdiffusion/models/normalization_flax.py +++ b/src/maxdiffusion/models/normalization_flax.py @@ -146,4 +146,4 @@ def __call__(self, x, emb): ) else: raise ValueError(f"Unsupported `norm_type` ({self.norm_type}) provided. Supported ones are: 'layer_norm'.") - return x, gate_msa + return x, gate_msa \ No newline at end of file From 1f1475d804f2c81a04bba77d67e06bb12556230f Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Mon, 3 Feb 2025 22:51:04 +0000 Subject: [PATCH 17/35] removed unused code. --- src/maxdiffusion/configs/base_flux_dev.yml | 1 + .../configs/base_flux_schnell.yml | 1 + .../models/flux/modules/__init__.py | 15 - .../models/flux/modules/layers.py | 584 ------------------ .../transformers/transformer_flux_flax.py | 26 +- 5 files changed, 4 insertions(+), 623 deletions(-) delete mode 100644 src/maxdiffusion/models/flux/modules/__init__.py delete mode 100644 src/maxdiffusion/models/flux/modules/layers.py diff --git a/src/maxdiffusion/configs/base_flux_dev.yml b/src/maxdiffusion/configs/base_flux_dev.yml index 987101672..cfb428913 100644 --- a/src/maxdiffusion/configs/base_flux_dev.yml +++ b/src/maxdiffusion/configs/base_flux_dev.yml @@ -130,6 +130,7 @@ logical_axis_rules: [ ['activation_batch', ['data','fsdp']], ['activation_heads', 'tensor'], ['activation_kv', 'tensor'], + ['mlp','tensor'], ['embed','fsdp'], ['heads', 'tensor'], ['conv_batch', ['data','fsdp']], diff --git a/src/maxdiffusion/configs/base_flux_schnell.yml b/src/maxdiffusion/configs/base_flux_schnell.yml index fd80c5f28..c30fa82e8 100644 --- a/src/maxdiffusion/configs/base_flux_schnell.yml +++ b/src/maxdiffusion/configs/base_flux_schnell.yml @@ -129,6 +129,7 @@ logical_axis_rules: [ ['activation_batch', ['data','fsdp']], ['activation_heads', 'tensor'], ['activation_kv', 'tensor'], + ['mlp','tensor'], ['embed','fsdp'], ['heads', 'tensor'], ['conv_batch', ['data','fsdp']], diff --git a/src/maxdiffusion/models/flux/modules/__init__.py b/src/maxdiffusion/models/flux/modules/__init__.py deleted file mode 100644 index 55bca151a..000000000 --- a/src/maxdiffusion/models/flux/modules/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -""" - Copyright 2025 Google LLC - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - https://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- """ \ No newline at end of file diff --git a/src/maxdiffusion/models/flux/modules/layers.py b/src/maxdiffusion/models/flux/modules/layers.py deleted file mode 100644 index d4c3d5534..000000000 --- a/src/maxdiffusion/models/flux/modules/layers.py +++ /dev/null @@ -1,584 +0,0 @@ -""" - Copyright 2025 Google LLC - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - https://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ - -import math -from dataclasses import dataclass -from einops import rearrange -import jax -import jax.numpy as jnp -from chex import Array -from jax.typing import DTypeLike -import flax.linen as nn -from ...attention_flax import AttentionOp -from .... import common_types - -BlockSizes = common_types.BlockSizes - -def rope(pos: Array, dim: int, theta: int) -> Array: - assert dim % 2 == 0 - scale = jnp.arange(0, dim, 2, dtype=jnp.float32) / dim - omega = 1.0 / (theta ** scale) - out = jnp.einsum("...n,d->...nd", pos, omega) - out = jnp.stack([jnp.cos(out), -jnp.sin(out), jnp.sin(out), jnp.cos(out)], axis=-1) - out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) - return out.astype(jnp.float32) - -def apply_rope(xq: Array, xk: Array, freqs_cis: Array) -> tuple[Array, Array]: - xq_ = xq.astype(jnp.float32).reshape(*xq.shape[:-1], -1, 1, 2) - xk_ = xk.astype(jnp.float32).reshape(*xk.shape[:-1], -1, 1, 2) - xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] - xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] - return xq_out.reshape(*xq.shape).astype(xq.dtype), xk_out.reshape(*xk.shape).astype( - xk.dtype - ) - -class QKNorm(nn.Module): - dtype: DTypeLike = jnp.bfloat16 - weights_dtype: DTypeLike = jnp.bfloat16 - - @nn.compact - def __call__(self, q: Array, k: Array, v: Array) -> tuple[Array, Array]: - q = nn.RMSNorm( - dtype=self.dtype, - param_dtype=self.weights_dtype, - name="query_norm" - )(q) - k = nn.RMSNorm( - dtype=self.dtype, - param_dtype=self.weights_dtype, - name="key_norm" - )(k) - return q, k - -class EmbedND(nn.Module): - dim: int - theta: int - axes_dim: list[int] - - def __call__(self, ids: Array): - n_axes = ids.shape[-1] - emb = jnp.concatenate( - [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], axis=-3, - ) - - return jnp.expand_dims(emb, axis=1) - -def timestep_embedding( - t: Array, dim: int, max_period=10000, time_factor: float = 1000.0 -) -> Array: - """ - Generate timestep embeddings. - - Args: - t: a 1-D Tensor of N indices, one per batch element. - These may be fractional. - dim: the dimension of the output. - max_period: controls the minimum frequency of the embeddings. - time_factor: Tensor of positional embeddings. - - Returns: - timestep embeddings. 
- """ - t = time_factor * t - half = dim // 2 - - freqs = jnp.exp( - -math.log(max_period) * jnp.arange(start=0, stop=half, dtype=jnp.float32) / half - ).astype(dtype=t.dtype) - - args = t[:, None].astype(jnp.float32) * freqs[None] - embedding = jnp.concatenate([jnp.cos(args), jnp.sin(args)], axis=-1) - - if dim % 2: - embedding = jnp.concatenate( - [embedding, jnp.zeros_like(embedding[:, :1])], axis=-1 - ) - - if jnp.issubdtype(t.dtype, jnp.floating): - embedding = embedding.astype(t.dtype) - - return embedding -import numpy as np -class PixArtAlphaTextProjection(nn.Module): - hidden_dim: int - dtype: jnp.dtype = jnp.float32 - weights_dtype: jnp.dtype = jnp.float32 - precision: jax.lax.Precision = None - - @nn.compact - def __call__(self, x: Array) -> Array: - - hidden_states = nn.Dense( - self.hidden_dim, - use_bias=True, - dtype=self.dtype, - param_dtype=self.weights_dtype, - precision=self.precision, - kernel_init=nn.with_logical_partitioning( - nn.initializers.lecun_normal(), - ("embed", "heads") - ), - name="in_layer" - )(x) - jax.debug.print("PixArtAlphaTextProjection, in_layer min: {x}", x=np.min(hidden_states)) - jax.debug.print("PixArtAlphaTextProjection, in_layer max: {x}", x=np.max(hidden_states)) - hidden_states = nn.swish(hidden_states) - jax.debug.print("PixArtAlphaTextProjection, act min: {x}", x=np.min(hidden_states)) - jax.debug.print("PixArtAlphaTextProjection, act max: {x}", x=np.max(hidden_states)) - hidden_states = nn.Dense( - self.hidden_dim, - use_bias=True, - dtype=self.dtype, - param_dtype=self.weights_dtype, - precision=self.precision, - kernel_init=nn.with_logical_partitioning( - nn.initializers.lecun_normal(), - ("heads", "embed") - ), - name="out_layer" - )(hidden_states) - jax.debug.print("PixArtAlphaTextProjection, out min: {x}", x=np.min(hidden_states)) - jax.debug.print("PixArtAlphaTextProjection, out max: {x}", x=np.max(hidden_states)) - - return hidden_states - -class MLPEmbedder(nn.Module): - hidden_dim: int - dtype: jnp.dtype = jnp.float32 - weights_dtype: jnp.dtype = jnp.float32 - precision: jax.lax.Precision = None - - @nn.compact - def __call__(self, x: Array) -> Array: - - x = nn.Dense( - self.hidden_dim, - use_bias=True, - dtype=self.dtype, - param_dtype=self.weights_dtype, - precision=self.precision, - kernel_init=nn.with_logical_partitioning( - nn.initializers.lecun_normal(), - ("embed", "heads") - ), - name="in_layer" - )(x) - x = nn.silu(x) - x = nn.Dense( - self.hidden_dim, - use_bias=True, - dtype=self.dtype, - param_dtype=self.weights_dtype, - precision=self.precision, - kernel_init=nn.with_logical_partitioning( - nn.initializers.lecun_normal(), - ("heads", "embed") - ), - name="out_layer" - )(x) - - return x - -@dataclass -class ModulationOut: - shift: Array - scale: Array - gate: Array - -class Modulation(nn.Module): - dim: int - double: bool - dtype: DTypeLike = jnp.bfloat16 - weights_dtype: DTypeLike = jnp.bfloat16 - precision: jax.lax.Precision = None - - @nn.compact - def __call__(self, vec: Array) -> tuple[ModulationOut, ModulationOut | None]: - multiplier = 6 if self.double else 3 - lin = nn.Dense( - multiplier * self.dim, - use_bias=True, - dtype=self.dtype, - param_dtype=self.weights_dtype, - precision=self.precision, - kernel_init=nn.with_logical_partitioning( - nn.initializers.lecun_normal(), - ("embed", "heads") - ), - name="lin" - )(nn.silu(vec)) - - out = jnp.split(lin[:, None, :], multiplier, axis=-1) - return ( - ModulationOut(*out[:3]), - ModulationOut(*out[3:]) if self.double else None - ) - -class 
SingleStreamBlock(nn.Module): - hidden_size: int - num_heads: int - mlp_ratio: float - qk_scale: float | None = None - attention_head_dim: int = 128 - flash_min_seq_length: int = 4096 - flash_block_sizes: BlockSizes = None - mesh: jax.sharding.Mesh = None - dtype: jnp.dtype = jnp.float32 - weights_dtype: jnp.dtype = jnp.float32 - precision: jax.lax.Precision = None - attention_kernel: str = "dot_product" - - @nn.compact - def __call__(self, x: Array, vec: Array, pe: Array) -> Array: - mlp_hidden_dim = int(self.hidden_size * self.mlp_ratio) - - mod, _ = Modulation( - self.hidden_size, - double=False, - dtype=self.dtype, - weights_dtype=self.weights_dtype, - precision=self.precision, - name="modulation" - )(vec) - x_mod = (1 + mod.scale) * nn.LayerNorm( - use_scale=False, - use_bias=False, - epsilon=1e-6, - dtype=self.dtype, - param_dtype=self.weights_dtype, - name="pre_norm" - )(x) + mod.shift - - x_mod = nn.Dense( - self.hidden_size * 3 + mlp_hidden_dim, - dtype=self.dtype, - param_dtype=self.weights_dtype, - precision=self.precision, - kernel_init=nn.with_logical_partitioning( - nn.initializers.lecun_normal(), - ("embed", "heads") - ), - name="linear1" - )(x_mod) - - qkv, mlp = jnp.split(x_mod, [3 * self.hidden_size], axis=-1) - q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) - q, k = QKNorm( - dtype=self.dtype, - weights_dtype=self.weights_dtype, - name="norm" - )(q, k, v) - - q, k = apply_rope(q, k, pe) - #compute attention - attn = AttentionOp( - mesh=self.mesh, - attention_kernel=self.attention_kernel, - scale=self.attention_head_dim**-0.5, - heads=self.num_heads, - dim_head=self.attention_head_dim, - flash_min_seq_length=self.flash_min_seq_length, - use_memory_efficient_attention=False, - split_head_dim=True, - flash_block_sizes=self.flash_block_sizes, - dtype=self.dtype - ).apply_attention(q, k, v) - - output = nn.Dense( - self.hidden_size, - dtype=self.dtype, - param_dtype=self.weights_dtype, - precision=self.precision, - kernel_init=nn.with_logical_partitioning( - nn.initializers.lecun_normal(), - ("embed", "heads") - ), - name="linear2" - )(jnp.concatenate((attn, nn.gelu(mlp)), 2)) - return x + mod.gate * output - -class DoubleStreamBlock(nn.Module): - hidden_size: int - num_heads: int - mlp_ratio: float - attention_head_dim: int = 128 - flash_min_seq_length: int = 4096 - flash_block_sizes: BlockSizes = None - mesh: jax.sharding.Mesh = None - dtype: jnp.dtype = jnp.float32 - weights_dtype: jnp.dtype = jnp.float32 - precision: jax.lax.Precision = None - qkv_bias: bool = False - attention_kernel: str = "dot_product" - - @nn.compact - def __call__(self, img: Array, txt: Array, vec: Array, pe: Array) -> tuple[Array, Array]: - - mlp_hidden_dim = int(self.hidden_size * self.mlp_ratio) - - img_mod1, img_mod2 = Modulation( - self.hidden_size, - double=True, - dtype=self.dtype, - weights_dtype=self.weights_dtype, - precision=self.precision, - name="img_mod" - )(vec) - - txt_mod1, txt_mod2 = Modulation( - self.hidden_size, - double=True, - dtype=self.dtype, - weights_dtype=self.weights_dtype, - precision=self.precision, - name="txt_mod" - )(vec) - - # prepare image for attention - img_modulated = nn.LayerNorm( - use_scale=False, - use_bias=False, - epsilon=1e-6, - dtype=self.dtype, - param_dtype=self.weights_dtype, - name="img_norm1" - )(img) - img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift - img_qkv = nn.Dense( - self.hidden_size * 3, - use_bias=self.qkv_bias, - dtype=self.dtype, - param_dtype=self.weights_dtype, - 
precision=self.precision, - kernel_init=nn.with_logical_partitioning( - nn.initializers.lecun_normal(), - ("embed", "heads") - ), - name="img_attn_qkv" - )(img_modulated) - img_q, img_k, img_v = rearrange( - img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads - ) - img_q, img_k = QKNorm( - dtype=self.dtype, - weights_dtype=self.weights_dtype, - name="img_attn_norm" - )(img_q, img_k, img_v) - - # prepare text for attention - txt_modulated = nn.LayerNorm( - use_scale=False, - use_bias=False, - epsilon=1e-6, - dtype=self.dtype, - param_dtype=self.weights_dtype, - name="txt_norm1" - )(txt) - txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift - txt_qkv = nn.Dense( - self.hidden_size * 3, - use_bias=self.qkv_bias, - dtype=self.dtype, - param_dtype=self.weights_dtype, - precision=self.precision, - kernel_init=nn.with_logical_partitioning( - nn.initializers.lecun_normal(), - ("embed", "heads") - ), - name="txt_attn_qkv" - )(txt_modulated) - txt_q, txt_k, txt_v = rearrange( - txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads - ) - txt_q, txt_k = QKNorm( - dtype=self.dtype, - weights_dtype=self.weights_dtype, - name="txt_attn_norm" - )(txt_q, txt_k, txt_v) - - # run actual attention - q = jnp.concatenate((txt_q, img_q), axis=2) - k = jnp.concatenate((txt_k, img_k), axis=2) - v = jnp.concatenate((txt_v, img_v), axis=2) - q, k = apply_rope(q, k, pe) - - attn = AttentionOp( - mesh=self.mesh, - attention_kernel=self.attention_kernel, - scale=self.attention_head_dim**-0.5, - heads=self.num_heads, - dim_head=self.attention_head_dim, - flash_min_seq_length=self.flash_min_seq_length, - use_memory_efficient_attention=False, - split_head_dim=True, - flash_block_sizes=self.flash_block_sizes, - dtype=self.dtype - ).apply_attention(q, k, v) - - txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] - - #calculate the img blocks - img = img + img_mod1.gate * nn.Dense( - self.hidden_size, - use_bias=True, - dtype=self.dtype, - param_dtype=self.weights_dtype, - precision=self.precision, - kernel_init=nn.with_logical_partitioning( - nn.initializers.lecun_normal(), - ("heads", "embed") - ), - name="img_attn_proj" - )(img_attn) - img = img + img_mod2.gate * nn.Sequential( - [ - nn.Dense( - mlp_hidden_dim, - use_bias=True, - dtype=self.dtype, - param_dtype=self.weights_dtype, - precision=self.precision, - kernel_init=nn.with_logical_partitioning( - nn.initializers.lecun_normal(), - ("embed", "heads") - ), - name="img_mlp_0" - ), - nn.gelu, - nn.Dense( - self.hidden_size, - use_bias=True, - param_dtype=self.weights_dtype, - precision=self.precision, - kernel_init=nn.with_logical_partitioning( - nn.initializers.lecun_normal(), - ("heads", "embed") - ), - name="img_mlp_2" - ), - ], - )( - (1 + img_mod2.scale) * nn.LayerNorm( - use_scale=False, - use_bias=False, - param_dtype=self.weights_dtype, - name="img_norm2" - )(img) + img_mod2.shift - ) - - # calculate the txt blocks - txt_proj = nn.Dense( - self.hidden_size, - use_bias=True, - dtype=self.dtype, - param_dtype=self.weights_dtype, - precision=self.precision, - kernel_init=nn.with_logical_partitioning( - nn.initializers.lecun_normal(), - ("heads", "embed") - ), - name="txt_attn_proj" - )(txt_attn) - txt = txt + txt_mod1.gate * txt_proj - - txt = txt + txt_mod2.gate * nn.Sequential( - [ - nn.Dense( - mlp_hidden_dim, - use_bias=True, - dtype=self.dtype, - param_dtype=self.weights_dtype, - precision=self.precision, - kernel_init=nn.with_logical_partitioning( - nn.initializers.lecun_normal(), - ("embed", "heads") - ), - 
name="txt_mlp_0" - ), - nn.gelu, - nn.Dense( - self.hidden_size, - use_bias=True, - param_dtype=self.weights_dtype, - precision=self.precision, - kernel_init=nn.with_logical_partitioning( - nn.initializers.lecun_normal(), - ("heads", "embed") - ), - name="txt_mlp_2" - ), - ], - )( - (1 + txt_mod2.scale) * nn.LayerNorm( - use_scale=False, - use_bias=False, - param_dtype=self.weights_dtype, - name="txt_norm2" - )(txt) + txt_mod2.shift - ) - - return img, txt - -class LastLayer(nn.Module): - hidden_size: int - patch_size: int - out_channels: int - dtype: jnp.dtype = jnp.float32 - weights_dtype: jnp.dtype = jnp.float32 - precision: jax.lax.Precision = None - - @nn.compact - def __call__(self, x: Array, vec: Array) -> Array: - shift, scale = jnp.split( - nn.Sequential( - [ - nn.silu, - nn.Dense( - 2 * self.hidden_size, - use_bias=True, - param_dtype=self.weights_dtype, - dtype=self.dtype, - precision=self.precision, - kernel_init=nn.with_logical_partitioning( - nn.initializers.lecun_normal(), - ("embed", "heads") - ), - name="adaLN_modulation_1" - ), - ] - )(vec), 2, axis=1 - ) - norm_final = nn.LayerNorm( - epsilon=1e-6, - use_scale=False, - use_bias=False, - param_dtype=self.weights_dtype, - name="norm_final" - )(x) - x = (1 + scale[:, None, :]) * norm_final + shift[:, None, :] - x = nn.Dense( - self.patch_size * self.patch_size * self.out_channels, - use_bias=True, - param_dtype=self.weights_dtype, - dtype=self.dtype, - precision=self.precision, - kernel_init=nn.with_logical_partitioning( - nn.initializers.lecun_normal(), - ("heads", "embed") - ), - name="linear" - )(x) - return x \ No newline at end of file diff --git a/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py b/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py index d9f398f6a..8f293b814 100644 --- a/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py +++ b/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py @@ -14,15 +14,12 @@ limitations under the License. """ -"""This script is used an example of how to shard the UNET on TPU.""" - -from typing import Any, Dict, Optional, Tuple, Union +from typing import Tuple import jax import math import jax.numpy as jnp import flax import flax.linen as nn -from jax.random import PRNGKey from einops import repeat, rearrange from ....configuration_utils import ConfigMixin, flax_register_to_config from ...modeling_flax_utils import FlaxModelMixin @@ -41,25 +38,6 @@ HEAD = common_types.HEAD D_KV = common_types.D_KV - -@dataclass -class FluxParams: - in_channels: int - vec_in_dim: int - context_in_dim: int - hidden_size: int - mlp_ratio: float - num_heads: int - depth: int - depth_single_blocks: int - axes_dim: list[int] - theta: int - qkv_bias: bool - guidance_embed: bool - param_dtype: jnp.bfloat16 - rngs: jax.random.PRNGKey - - @flax.struct.dataclass class Transformer2DModelOutput(BaseOutput): """ @@ -580,7 +558,7 @@ def init_weights(self, rngs, max_sequence_length, eval_only=True): 768, # Sequence length of clip, how to get this programmatically? 
) img = jnp.zeros(batch_image_shape, dtype=self.dtype) - bs, c, h, w = img.shape + bs, _, h, w = img.shape img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) img_ids = jnp.zeros((h // 2, w // 2, 3), dtype=self.dtype) img_ids = img_ids.at[..., 1].set(jnp.arange(h // 2)[:, None]) From b49695a998907a394a918c8b626ee31dc84f7a1d Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Mon, 3 Feb 2025 23:32:12 +0000 Subject: [PATCH 18/35] support dev --- src/maxdiffusion/models/flux/util.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/maxdiffusion/models/flux/util.py b/src/maxdiffusion/models/flux/util.py index b6b6dc372..b84e3e6fc 100644 --- a/src/maxdiffusion/models/flux/util.py +++ b/src/maxdiffusion/models/flux/util.py @@ -173,6 +173,8 @@ def load_flow_model(name: str, eval_shapes: dict, device: str, hf_download: bool renamed_pt_key = renamed_pt_key.replace("txt_attn.norm.query_norm", "attn.encoder_query_norm") elif("guidance_in" in renamed_pt_key): renamed_pt_key = renamed_pt_key.replace("guidance_in", "time_text_embed.FlaxTimestepEmbedding_1") + renamed_pt_key = renamed_pt_key.replace("in_layer", "linear_1") + renamed_pt_key = renamed_pt_key.replace("out_layer", "linear_2") elif "single_blocks" in renamed_pt_key: renamed_pt_key = renamed_pt_key.replace("single_blocks_", "single_blocks.layers_") renamed_pt_key = renamed_pt_key.replace("modulation", "norm") From 04377df7857ad5919309268cf4b35b78229f8615 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Tue, 4 Feb 2025 00:10:24 +0000 Subject: [PATCH 19/35] add sentencepiece requirement --- requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index e5ac624e4..9d2f6f338 100644 --- a/requirements.txt +++ b/requirements.txt @@ -28,4 +28,5 @@ orbax-checkpoint==0.10.2 tokenizers==0.20.0 huggingface_hub==0.24.7 transformers==4.48.1 -einops==0.8.0 \ No newline at end of file +einops==0.8.0 +sentencepiece \ No newline at end of file From f6c25e43ee4bfb4a997d874dbba1e70814ad0a9d Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Tue, 4 Feb 2025 03:07:30 +0000 Subject: [PATCH 20/35] fix repeated double and single blocks. 
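
Storing the blocks in plain Python lists inside setup() gives every layer its own
parameter collection (double_blocks_0, double_blocks_1, ..., which matches the
renamed checkpoint keys in util.py below) instead of routing all layers through a
single nn.Sequential. A toy sketch of the list-of-submodules pattern (hypothetical
module and sizes, not code from this repo):

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    class Stack(nn.Module):
      num_layers: int = 3

      def setup(self):
        # a list attribute yields distinct submodules named blocks_0..blocks_2
        self.blocks = [nn.Dense(8) for _ in range(self.num_layers)]

      def __call__(self, x):
        for block in self.blocks:
          x = block(x)
        return x

    params = Stack().init(jax.random.PRNGKey(0), jnp.ones((1, 8)))["params"]
    # sorted(params) -> ['blocks_0', 'blocks_1', 'blocks_2']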
--- src/maxdiffusion/configs/base_flux_dev.yml | 11 +- .../configs/base_flux_schnell.yml | 11 +- .../transformers/transformer_flux_flax.py | 100 +++++++++--------- src/maxdiffusion/models/flux/util.py | 3 - 4 files changed, 50 insertions(+), 75 deletions(-) diff --git a/src/maxdiffusion/configs/base_flux_dev.yml b/src/maxdiffusion/configs/base_flux_dev.yml index cfb428913..8430a1a8d 100644 --- a/src/maxdiffusion/configs/base_flux_dev.yml +++ b/src/maxdiffusion/configs/base_flux_dev.yml @@ -55,16 +55,7 @@ precision: "DEFAULT" from_pt: True split_head_dim: True attention: 'flash' # Supported attention: dot_product, flash -flash_block_sizes: { - "block_q" : 128, - "block_kv" : 128, - "block_kv_compute" : 128, - "block_q_dkv" : 128, - "block_kv_dkv" : 128, - "block_kv_dkv_compute" : 128, - "block_q_dq" : 128, - "block_kv_dq" : 128 -} +flash_block_sizes: {} # GroupNorm groups norm_num_groups: 32 diff --git a/src/maxdiffusion/configs/base_flux_schnell.yml b/src/maxdiffusion/configs/base_flux_schnell.yml index c30fa82e8..2a2f177fc 100644 --- a/src/maxdiffusion/configs/base_flux_schnell.yml +++ b/src/maxdiffusion/configs/base_flux_schnell.yml @@ -54,16 +54,7 @@ precision: "DEFAULT" from_pt: True split_head_dim: True attention: 'flash' # Supported attention: dot_product, flash -flash_block_sizes: { - "block_q" : 128, - "block_kv" : 128, - "block_kv_compute" : 128, - "block_q_dkv" : 128, - "block_kv_dkv" : 128, - "block_kv_dkv_compute" : 128, - "block_q_dq" : 128, - "block_kv_dq" : 128 -} +flash_block_sizes: {} # GroupNorm groups norm_num_groups: 32 diff --git a/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py b/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py index 8f293b814..0dbdf66de 100644 --- a/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py +++ b/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py @@ -383,48 +383,43 @@ def setup(self): precision=self.precision, ) - self.double_blocks = nn.Sequential( - [ - *[ - FluxTransformerBlock( - dim=self.inner_dim, - num_attention_heads=self.num_attention_heads, - attention_head_dim=self.attention_head_dim, - attention_kernel=self.attention_kernel, - flash_min_seq_length=self.flash_min_seq_length, - flash_block_sizes=self.flash_block_sizes, - mesh=self.mesh, - dtype=self.dtype, - weights_dtype=self.weights_dtype, - precision=self.precision, - mlp_ratio=self.mlp_ratio, - qkv_bias=self.qkv_bias, - ) - for _ in range(self.num_layers) - ] - ] - ) - - self.single_blocks = nn.Sequential( - [ - *[ - FluxSingleTransformerBlock( - dim=self.inner_dim, - num_attention_heads=self.num_attention_heads, - attention_head_dim=self.attention_head_dim, - attention_kernel=self.attention_kernel, - flash_min_seq_length=self.flash_min_seq_length, - flash_block_sizes=self.flash_block_sizes, - mesh=self.mesh, - dtype=self.dtype, - weights_dtype=self.weights_dtype, - precision=self.precision, - mlp_ratio=self.mlp_ratio, - ) - for _ in range(self.num_single_layers) - ] - ] - ) + double_blocks = [] + for _ in range(self.num_layers): + double_block = FluxTransformerBlock( + dim=self.inner_dim, + num_attention_heads=self.num_attention_heads, + attention_head_dim=self.attention_head_dim, + attention_kernel=self.attention_kernel, + flash_min_seq_length=self.flash_min_seq_length, + flash_block_sizes=self.flash_block_sizes, + mesh=self.mesh, + dtype=self.dtype, + weights_dtype=self.weights_dtype, + precision=self.precision, + mlp_ratio=self.mlp_ratio, + qkv_bias=self.qkv_bias, + ) + double_blocks.append(double_block) + 
self.double_blocks = double_blocks + + single_blocks = [] + for _ in range(self.num_single_layers): + single_block = FluxSingleTransformerBlock( + dim=self.inner_dim, + num_attention_heads=self.num_attention_heads, + attention_head_dim=self.attention_head_dim, + attention_kernel=self.attention_kernel, + flash_min_seq_length=self.flash_min_seq_length, + flash_block_sizes=self.flash_block_sizes, + mesh=self.mesh, + dtype=self.dtype, + weights_dtype=self.weights_dtype, + precision=self.precision, + mlp_ratio=self.mlp_ratio, + ) + single_blocks.append(single_block) + + self.single_blocks = single_blocks self.norm_out = AdaLayerNormContinuous( self.inner_dim, @@ -509,18 +504,19 @@ def __call__( image_rotary_emb = self.pe_embedder(ids) image_rotary_emb = nn.with_logical_constraint(image_rotary_emb, ("activation_batch", "activation_embed")) - hidden_states, encoder_hidden_states, temb, image_rotary_emb = self.double_blocks( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - temb=temb, - image_rotary_emb=image_rotary_emb, - ) + for double_block in self.double_blocks: + hidden_states, encoder_hidden_states, temb, image_rotary_emb = double_block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) hidden_states = jnp.concatenate([encoder_hidden_states, hidden_states], axis=1) hidden_states = nn.with_logical_constraint(hidden_states, ("activation_batch", "activation_length", "activation_embed")) - - hidden_states, temb, image_rotary_emb = self.single_blocks( - hidden_states=hidden_states, temb=temb, image_rotary_emb=image_rotary_emb - ) + for single_block in self.single_blocks: + hidden_states, temb, image_rotary_emb = single_block( + hidden_states=hidden_states, temb=temb, image_rotary_emb=image_rotary_emb + ) hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...] 
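The concatenate/slice pair around the single-stream blocks is pure token bookkeeping; a shape-only sketch (example sizes: 512 text tokens, a 1024px image packed to 4096 tokens, inner_dim 3072):

    import jax.numpy as jnp

    txt = jnp.zeros((1, 512, 3072))   # encoder_hidden_states: 512 text tokens
    img = jnp.zeros((1, 4096, 3072))  # packed 64x64 latent -> 4096 image tokens

    joint = jnp.concatenate([txt, img], axis=1)  # (1, 4608, 3072), text first
    # ... single-stream blocks run over the joint sequence ...
    img_only = joint[:, txt.shape[1]:, ...]      # drop the text tokens again
    assert img_only.shape == img.shape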
hidden_states = self.norm_out(hidden_states, temb) diff --git a/src/maxdiffusion/models/flux/util.py b/src/maxdiffusion/models/flux/util.py index b84e3e6fc..7fae4250a 100644 --- a/src/maxdiffusion/models/flux/util.py +++ b/src/maxdiffusion/models/flux/util.py @@ -159,7 +159,6 @@ def load_flow_model(name: str, eval_shapes: dict, device: str, hf_download: bool for pt_key, tensor in tensors.items(): renamed_pt_key = rename_key(pt_key) if "double_blocks" in renamed_pt_key: - renamed_pt_key = renamed_pt_key.replace("double_blocks_", "double_blocks.layers_") renamed_pt_key = renamed_pt_key.replace("img_mlp_", "img_mlp.layers_") renamed_pt_key = renamed_pt_key.replace("txt_mlp_", "txt_mlp.layers_") renamed_pt_key = renamed_pt_key.replace("img_mod", "img_norm1") @@ -176,7 +175,6 @@ def load_flow_model(name: str, eval_shapes: dict, device: str, hf_download: bool renamed_pt_key = renamed_pt_key.replace("in_layer", "linear_1") renamed_pt_key = renamed_pt_key.replace("out_layer", "linear_2") elif "single_blocks" in renamed_pt_key: - renamed_pt_key = renamed_pt_key.replace("single_blocks_", "single_blocks.layers_") renamed_pt_key = renamed_pt_key.replace("modulation", "norm") renamed_pt_key = renamed_pt_key.replace("norm.key_norm", "attn.key_norm") renamed_pt_key = renamed_pt_key.replace("norm.query_norm", "attn.query_norm") @@ -188,7 +186,6 @@ def load_flow_model(name: str, eval_shapes: dict, device: str, hf_download: bool elif "final_layer" in renamed_pt_key: renamed_pt_key = renamed_pt_key.replace("final_layer.linear", "proj_out") renamed_pt_key = renamed_pt_key.replace("final_layer.adaLN_modulation_1", "norm_out.Dense_0") - pt_tuple_key = tuple(renamed_pt_key.split(".")) flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, eval_shapes) flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu) From ff24ee156411adf6295d7b8645131fa2b739cac8 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Tue, 4 Feb 2025 04:33:15 +0000 Subject: [PATCH 21/35] optimized flash block sizes for trillium. --- src/maxdiffusion/configs/base_flux_dev.yml | 16 ++++++++++++++-- .../configs/base_flux_schnell.yml | 19 +++++++++++++++---- 2 files changed, 29 insertions(+), 6 deletions(-) diff --git a/src/maxdiffusion/configs/base_flux_dev.yml b/src/maxdiffusion/configs/base_flux_dev.yml index 8430a1a8d..67b7307f9 100644 --- a/src/maxdiffusion/configs/base_flux_dev.yml +++ b/src/maxdiffusion/configs/base_flux_dev.yml @@ -55,7 +55,19 @@ precision: "DEFAULT" from_pt: True split_head_dim: True attention: 'flash' # Supported attention: dot_product, flash + flash_block_sizes: {} +# Use the following flash_block_sizes on v6e (Trillium) due to larger vmem. 
+# flash_block_sizes: {
+#   "block_q" : 1536,
+#   "block_kv_compute" : 1536,
+#   "block_kv" : 1536,
+#   "block_q_dkv" : 1536,
+#   "block_kv_dkv" : 1536,
+#   "block_kv_dkv_compute" : 1536,
+#   "block_q_dq" : 1536,
+#   "block_kv_dq" : 1536
+# }
 
 # GroupNorm groups
 norm_num_groups: 32
@@ -137,8 +149,8 @@ data_sharding: [['data', 'fsdp', 'tensor']]
 dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
 dcn_fsdp_parallelism: -1
 dcn_tensor_parallelism: 1
-ici_data_parallelism: 1
-ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
+ici_data_parallelism: -1 # recommended ICI axis to be auto-sharded
+ici_fsdp_parallelism: 1
 ici_tensor_parallelism: 1
 
 # Dataset
diff --git a/src/maxdiffusion/configs/base_flux_schnell.yml b/src/maxdiffusion/configs/base_flux_schnell.yml
index 2a2f177fc..4c22edb73 100644
--- a/src/maxdiffusion/configs/base_flux_schnell.yml
+++ b/src/maxdiffusion/configs/base_flux_schnell.yml
@@ -55,6 +55,17 @@ from_pt: True
 split_head_dim: True
 attention: 'flash' # Supported attention: dot_product, flash
 flash_block_sizes: {}
+# Use the following flash_block_sizes on v6e (Trillium) due to larger vmem.
+# flash_block_sizes: {
+#   "block_q" : 1536,
+#   "block_kv_compute" : 1536,
+#   "block_kv" : 1536,
+#   "block_q_dkv" : 1536,
+#   "block_kv_dkv" : 1536,
+#   "block_kv_dkv_compute" : 1536,
+#   "block_q_dq" : 1536,
+#   "block_kv_dq" : 1536
+# }
 
 # GroupNorm groups
 norm_num_groups: 32
@@ -133,11 +144,11 @@ data_sharding: [['data', 'fsdp', 'tensor']]
 # value to auto-shard based on available slices and devices.
 # By default, product of the DCN axes should equal number of slices
 # and product of the ICI axes should equal number of devices per slice.
-dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
-dcn_fsdp_parallelism: -1
+dcn_data_parallelism: -1 # recommended DCN axis to be auto-sharded
+dcn_fsdp_parallelism: 1
 dcn_tensor_parallelism: 1
-ici_data_parallelism: 1
-ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
+ici_data_parallelism: -1 # recommended ICI axis to be auto-sharded
+ici_fsdp_parallelism: 1
 ici_tensor_parallelism: 1
 
 # Dataset

From 18250c5161c65e1ee332812738d6817f361d4e7c Mon Sep 17 00:00:00 2001
From: Juan Acevedo
Date: Tue, 4 Feb 2025 18:17:43 +0000
Subject: [PATCH 22/35] clean up code and lint

---
 requirements.txt                              |   2 +-
 .../base_stable_diffusion_checkpointer.py     |   5 +-
 .../configs/base_flux_schnell.yml             |  30 +-
 src/maxdiffusion/generate_flux.py             | 385 ++++++++----------
 src/maxdiffusion/models/attention_flax.py     | 171 --------
 src/maxdiffusion/models/embeddings_flax.py    |  45 +-
 src/maxdiffusion/models/flux/__init__.py      |   2 +-
 src/maxdiffusion/models/flux/port.py          | 223 ----------
 .../models/flux/transformers/__init__.py      |   2 +-
 .../transformers/transformer_flux_flax.py     |  55 +--
 src/maxdiffusion/models/flux/util.py          | 199 +++++----
 src/maxdiffusion/models/modeling_utils.py     |   1 +
 src/maxdiffusion/models/normalization_flax.py |   2 +-
 src/maxdiffusion/tests/text_encoders_test.py  |  24 +-
 src/maxdiffusion/tests/vae_test.py            |  21 +-
 15 files changed, 349 insertions(+), 818 deletions(-)
 delete mode 100644 src/maxdiffusion/models/flux/port.py

diff --git a/requirements.txt b/requirements.txt
index 9d2f6f338..09babb20f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -25,7 +25,7 @@ ruff>=0.1.5,<=0.2
 git+https://github.com/mlperf/logging.git
 opencv-python-headless==4.10.0.84
 orbax-checkpoint==0.10.2
-tokenizers==0.20.0
+tokenizers==0.21.0
 huggingface_hub==0.24.7
 transformers==4.48.1
 einops==0.8.0
 sentencepiece
\ No newline at end of file
diff --git 
a/src/maxdiffusion/checkpointing/base_stable_diffusion_checkpointer.py b/src/maxdiffusion/checkpointing/base_stable_diffusion_checkpointer.py index 92c7605d5..a7b597e36 100644 --- a/src/maxdiffusion/checkpointing/base_stable_diffusion_checkpointer.py +++ b/src/maxdiffusion/checkpointing/base_stable_diffusion_checkpointer.py @@ -336,10 +336,7 @@ def load_checkpoint(self, step=None, scheduler_class=None): if self.checkpoint_type == STABLE_DIFFUSION_XL_CHECKPOINT: te_pretrained_2_config = CLIPTextConfig(**model_configs[0]["text_encoder_2_config"]) text_encoder_2 = FlaxCLIPTextModelWithProjection( - te_pretrained_2_config, - seed=self.config.seed, - dtype=self.config.activations_dtype, - _do_init=False + te_pretrained_2_config, seed=self.config.seed, dtype=self.config.activations_dtype, _do_init=False ) pipeline_kwargs["text_encoder_2"] = text_encoder_2 # both tokenizers in sdxl are the same. diff --git a/src/maxdiffusion/configs/base_flux_schnell.yml b/src/maxdiffusion/configs/base_flux_schnell.yml index 4c22edb73..ee9db566a 100644 --- a/src/maxdiffusion/configs/base_flux_schnell.yml +++ b/src/maxdiffusion/configs/base_flux_schnell.yml @@ -54,17 +54,27 @@ precision: "DEFAULT" from_pt: True split_head_dim: True attention: 'flash' # Supported attention: dot_product, flash -flash_block_sizes: {} -# Use the following flash_block_sizes on v6e (Trillium) due to larger vmem. +flash_block_sizes: { + "block_q" : 256, + "block_kv_compute" : 256, + "block_kv" : 256, + "block_q_dkv" : 256, + "block_kv_dkv" : 256, + "block_kv_dkv_compute" : 256, + "block_q_dq" : 256, + "block_kv_dq" : 256 +} + +# Use the following flash_block_sizes on v6e (Trillium). # flash_block_sizes: { -# "block_q" : 1536, -# "block_kv_compute" : 1536, -# "block_kv" : 1536, -# "block_q_dkv" : 1536, -# "block_kv_dkv" : 1536, -# "block_kv_dkv_compute" : 1536, -# "block_q_dq" : 1536, -# "block_kv_dq" : 1536 +# "block_q" : 2176, +# "block_kv_compute" : 2176, +# "block_kv" : 2176, +# "block_q_dkv" : 2176, +# "block_kv_dkv" : 2176, +# "block_kv_dkv_compute" : 2176, +# "block_q_dq" : 2176, +# "block_kv_dq" : 2176 # } # GroupNorm groups norm_num_groups: 32 diff --git a/src/maxdiffusion/generate_flux.py b/src/maxdiffusion/generate_flux.py index 606b15051..482aac539 100644 --- a/src/maxdiffusion/generate_flux.py +++ b/src/maxdiffusion/generate_flux.py @@ -27,114 +27,107 @@ from chex import Array from einops import rearrange from flax.linen import partitioning as nn_partitioning -from transformers import ( - CLIPTokenizer, - FlaxCLIPTextModel, - T5EncoderModel, - FlaxT5EncoderModel, - AutoTokenizer -) +from transformers import (CLIPTokenizer, FlaxCLIPTextModel, T5EncoderModel, FlaxT5EncoderModel, AutoTokenizer) from maxdiffusion import FlaxAutoencoderKL, pyconfig, max_logging from maxdiffusion.models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel from max_utils import ( - device_put_replicated, - get_memory_allocations, - create_device_mesh, - get_flash_block_sizes, - get_precision, - setup_initial_state + device_put_replicated, + get_memory_allocations, + create_device_mesh, + get_flash_block_sizes, + get_precision, + setup_initial_state, ) + def unpack(x: Array, height: int, width: int) -> Array: - return rearrange( + return rearrange( x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=math.ceil(height / 16), w=math.ceil(width / 16), ph=2, pw=2, - ) + ) + + from einops import rearrange + + def vae_decode(latents, vae, state, config): img = unpack(x=latents, height=config.resolution, width=config.resolution) img = img / 
vae.config.scaling_factor + vae.config.shift_factor img = vae.apply({"params": state.params}, img, deterministic=True, method=vae.decode).sample return img + def loop_body( - step, - args, - transformer, - latent_image_ids, - prompt_embeds, - txt_ids, - vec, - guidance_vec, + step, + args, + transformer, + latent_image_ids, + prompt_embeds, + txt_ids, + vec, + guidance_vec, ): latents, state, c_ts, p_ts = args latents_dtype = latents.dtype t_curr = c_ts[step] t_prev = p_ts[step] - t_vec = jnp.full((latents.shape[0], ), t_curr, dtype=latents.dtype) + t_vec = jnp.full((latents.shape[0],), t_curr, dtype=latents.dtype) pred = transformer.apply( - {"params" : state.params}, - hidden_states=latents, - img_ids=latent_image_ids, - encoder_hidden_states=prompt_embeds, - txt_ids=txt_ids, - timestep=t_vec, - guidance=guidance_vec, - pooled_projections=vec + {"params": state.params}, + hidden_states=latents, + img_ids=latent_image_ids, + encoder_hidden_states=prompt_embeds, + txt_ids=txt_ids, + timestep=t_vec, + guidance=guidance_vec, + pooled_projections=vec, ).sample latents = latents + (t_prev - t_curr) * pred latents = jnp.array(latents, dtype=latents_dtype) return latents, state, c_ts, p_ts + def prepare_latent_image_ids(height, width): latent_image_ids = jnp.zeros((height, width, 3)) - latent_image_ids = latent_image_ids.at[..., 1].set( - latent_image_ids[..., 1] + jnp.arange(height)[:, None] - ) - latent_image_ids = latent_image_ids.at[..., 2].set( - latent_image_ids[..., 2] + jnp.arange(width)[None, :] - ) + latent_image_ids = latent_image_ids.at[..., 1].set(latent_image_ids[..., 1] + jnp.arange(height)[:, None]) + latent_image_ids = latent_image_ids.at[..., 2].set(latent_image_ids[..., 2] + jnp.arange(width)[None, :]) latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape - latent_image_ids = latent_image_ids.reshape( - latent_image_id_height * latent_image_id_width, latent_image_id_channels - ) + latent_image_ids = latent_image_ids.reshape(latent_image_id_height * latent_image_id_width, latent_image_id_channels) return latent_image_ids.astype(jnp.bfloat16) + def time_shift(mu: float, sigma: float, t: Array): return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) -def get_lin_function( - x1: float = 256, - y1: float = 0.5, - x2: float = 4096, - y2: float = 1.15 -) -> Callable[[float], float]: + +def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]: m = (y2 - y1) / (x2 - x1) b = y1 - m * x1 return lambda x: m * x + b + def run_inference( - states, - transformer, - vae, - config, - mesh, - latents, - latent_image_ids, - prompt_embeds, - txt_ids, - vec, - guidance_vec, + states, + transformer, + vae, + config, + mesh, + latents, + latent_image_ids, + prompt_embeds, + txt_ids, + vec, + guidance_vec, ): - + timesteps = jnp.linspace(1, 0, config.num_inference_steps + 1) # shifting the schedule to favor high timesteps for higher signal images if config.time_shift: @@ -147,18 +140,17 @@ def run_inference( # jax.debug.print("c_ts: {x}", x=c_ts) # jax.debug.print("p_ts: {x}", x=p_ts) - transformer_state = states["transformer"] vae_state = states["vae"] loop_body_p = functools.partial( - loop_body, - transformer=transformer, - latent_image_ids=latent_image_ids, - prompt_embeds=prompt_embeds, - txt_ids=txt_ids, - vec=vec, - guidance_vec=guidance_vec, + loop_body, + transformer=transformer, + latent_image_ids=latent_image_ids, + prompt_embeds=prompt_embeds, + txt_ids=txt_ids, + vec=vec, + 
guidance_vec=guidance_vec, ) vae_decode_p = functools.partial(vae_decode, vae=vae, state=vae_state, config=config) @@ -170,11 +162,11 @@ def run_inference( def pack_latents( - latents: Array, - batch_size: int, - num_channels_latents: int, - height: int, - width: int, + latents: Array, + batch_size: int, + num_channels_latents: int, + height: int, + width: int, ): latents = jnp.reshape(latents, (batch_size, num_channels_latents, height // 2, 2, width // 2, 2)) latents = jnp.permute_dims(latents, (0, 2, 4, 1, 3, 5)) @@ -182,14 +174,9 @@ def pack_latents( return latents + def prepare_latents( - batch_size: int, - num_channels_latents: int, - height: int, - width: int, - vae_scale_factor: int, - dtype: jnp.dtype, - rng: Array + batch_size: int, num_channels_latents: int, height: int, width: int, vae_scale_factor: int, dtype: jnp.dtype, rng: Array ): # VAE applies 8x compression on images but we must also account for packing which @@ -208,22 +195,20 @@ def prepare_latents( return latents, latent_image_ids + def get_clip_prompt_embeds( - prompt: Union[str, List[str]], - num_images_per_prompt : int, - tokenizer: CLIPTokenizer, - text_encoder : FlaxCLIPTextModel + prompt: Union[str, List[str]], num_images_per_prompt: int, tokenizer: CLIPTokenizer, text_encoder: FlaxCLIPTextModel ): prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) text_inputs = tokenizer( - prompt, - padding="max_length", - max_length=tokenizer.model_max_length, - truncation=True, - return_overflowing_tokens=False, - return_length=False, - return_tensors="np" + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="np", ) text_input_ids = text_inputs.input_ids @@ -234,31 +219,29 @@ def get_clip_prompt_embeds( prompt_embeds = np.reshape(prompt_embeds, (batch_size * num_images_per_prompt, -1)) return prompt_embeds + def get_t5_prompt_embeds( - prompt: Union[str, List[str]], - num_images_per_prompt: int, - tokenizer: AutoTokenizer, - text_encoder: T5EncoderModel, - max_sequence_length: int = 512 + prompt: Union[str, List[str]], + num_images_per_prompt: int, + tokenizer: AutoTokenizer, + text_encoder: T5EncoderModel, + max_sequence_length: int = 512, ): prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) text_inputs = tokenizer( - prompt, - truncation=True, - max_length=max_sequence_length, - return_length=False, - return_overflowing_tokens=False, - padding="max_length", - return_tensors="np" + prompt, + truncation=True, + max_length=max_sequence_length, + return_length=False, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="np", ) text_input_ids = text_inputs.input_ids - prompt_embeds = text_encoder( - text_input_ids, - attention_mask=None, - output_hidden_states=False)["last_hidden_state"] + prompt_embeds = text_encoder(text_input_ids, attention_mask=None, output_hidden_states=False)["last_hidden_state"] dtype = text_encoder.dtype prompt_embeds = prompt_embeds.astype(dtype) _, seq_len, _ = prompt_embeds.shape @@ -270,38 +253,36 @@ def get_t5_prompt_embeds( def encode_prompt( - prompt: Union[str, List[str]], - prompt_2: Union[str, List[str]], - clip_tokenizer: CLIPTokenizer, - clip_text_encoder: FlaxCLIPTextModel, - t5_tokenizer: AutoTokenizer, - t5_text_encoder: T5EncoderModel, - num_images_per_prompt: int = 1, - max_sequence_length: int = 512 + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]], + clip_tokenizer: 
CLIPTokenizer, + clip_text_encoder: FlaxCLIPTextModel, + t5_tokenizer: AutoTokenizer, + t5_text_encoder: T5EncoderModel, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, ): - + prompt = [prompt] if isinstance(prompt, str) else prompt prompt_2 = prompt or prompt_2 prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 pooled_prompt_embeds = get_clip_prompt_embeds( - prompt=prompt, - num_images_per_prompt=num_images_per_prompt, - tokenizer=clip_tokenizer, - text_encoder=clip_text_encoder + prompt=prompt, num_images_per_prompt=num_images_per_prompt, tokenizer=clip_tokenizer, text_encoder=clip_text_encoder ) prompt_embeds = get_t5_prompt_embeds( - prompt=prompt_2, - num_images_per_prompt=num_images_per_prompt, - tokenizer=t5_tokenizer, - text_encoder=t5_text_encoder, - max_sequence_length=max_sequence_length + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + tokenizer=t5_tokenizer, + text_encoder=t5_text_encoder, + max_sequence_length=max_sequence_length, ) text_ids = jnp.zeros((prompt_embeds.shape[0], prompt_embeds.shape[1], 3)).astype(jnp.bfloat16) return prompt_embeds, pooled_prompt_embeds, text_ids + def run(config): from maxdiffusion.models.flux.util import load_flow_model @@ -314,22 +295,18 @@ def run(config): # LOAD VAE vae, vae_params = FlaxAutoencoderKL.from_pretrained( - config.pretrained_model_name_or_path, - subfolder="vae", - from_pt=True, - use_safetensors=True, - dtype="bfloat16" + config.pretrained_model_name_or_path, subfolder="vae", from_pt=True, use_safetensors=True, dtype="bfloat16" ) weights_init_fn = functools.partial(vae.init_weights, rng=rng) vae_state, vae_state_shardings = setup_initial_state( - model=vae, - tx=None, - config=config, - mesh=mesh, - weights_init_fn=weights_init_fn, - model_params=vae_params, - training=False, + model=vae, + tx=None, + config=config, + mesh=mesh, + weights_init_fn=weights_init_fn, + model_params=vae_params, + training=False, ) vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1) @@ -337,46 +314,40 @@ def run(config): # LOAD TRANSFORMER flash_block_sizes = get_flash_block_sizes(config) transformer = FluxTransformer2DModel.from_config( - config.pretrained_model_name_or_path, - subfolder="transformer", - mesh=mesh, - split_head_dim=config.split_head_dim, - attention_kernel=config.attention, - flash_block_sizes=flash_block_sizes, - dtype=config.activations_dtype, - weights_dtype=config.weights_dtype, - precision=get_precision(config) + config.pretrained_model_name_or_path, + subfolder="transformer", + mesh=mesh, + split_head_dim=config.split_head_dim, + attention_kernel=config.attention, + flash_block_sizes=flash_block_sizes, + dtype=config.activations_dtype, + weights_dtype=config.weights_dtype, + precision=get_precision(config), ) - + num_channels_latents = transformer.in_channels // 4 latents, latent_image_ids = prepare_latents( - batch_size=global_batch_size, - num_channels_latents=num_channels_latents, - height=config.resolution, - width=config.resolution, - dtype=jnp.bfloat16, - vae_scale_factor=vae_scale_factor, - rng=rng + batch_size=global_batch_size, + num_channels_latents=num_channels_latents, + height=config.resolution, + width=config.resolution, + dtype=jnp.bfloat16, + vae_scale_factor=vae_scale_factor, + rng=rng, ) # LOAD TEXT ENCODERS clip_text_encoder = FlaxCLIPTextModel.from_pretrained( - config.pretrained_model_name_or_path, - subfolder="text_encoder", - from_pt=True, - dtype=config.weights_dtype + config.pretrained_model_name_or_path, subfolder="text_encoder", 
from_pt=True, dtype=config.weights_dtype ) clip_tokenizer = CLIPTokenizer.from_pretrained( - config.pretrained_model_name_or_path, - subfolder="tokenizer", - dtype=config.weights_dtype + config.pretrained_model_name_or_path, subfolder="tokenizer", dtype=config.weights_dtype ) - t5_encoder = FlaxT5EncoderModel.from_pretrained( - config.t5xxl_model_name_or_path, - dtype=config.weights_dtype + t5_encoder = FlaxT5EncoderModel.from_pretrained(config.t5xxl_model_name_or_path, dtype=config.weights_dtype) + t5_tokenizer = AutoTokenizer.from_pretrained( + config.t5xxl_model_name_or_path, max_length=config.max_sequence_length, use_fast=True ) - t5_tokenizer = AutoTokenizer.from_pretrained(config.t5xxl_model_name_or_path, max_length=config.max_sequence_length, use_fast=True) encoders_sharding = PositionalSharding(devices_array).replicate() partial_device_put_replicated = functools.partial(device_put_replicated, sharding=encoders_sharding) @@ -386,14 +357,14 @@ def run(config): t5_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, t5_encoder.params) prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt( - prompt=config.prompt, - prompt_2=config.prompt_2, - clip_tokenizer=clip_tokenizer, - clip_text_encoder=clip_text_encoder, - t5_tokenizer=t5_tokenizer, - t5_text_encoder=t5_encoder, - num_images_per_prompt=global_batch_size, - max_sequence_length=config.max_sequence_length + prompt=config.prompt, + prompt_2=config.prompt_2, + clip_tokenizer=clip_tokenizer, + clip_text_encoder=clip_text_encoder, + t5_tokenizer=t5_tokenizer, + t5_text_encoder=t5_encoder, + num_images_per_prompt=global_batch_size, + max_sequence_length=config.max_sequence_length, ) def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timesteps, guidance, pooled_prompt_embeds): @@ -404,19 +375,11 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep print("timesteps.shape: ", timesteps.shape, timesteps.dtype) print("guidance.shape: ", guidance.shape, guidance.dtype) print("pooled_prompt_embeds.shape: ", pooled_prompt_embeds.shape, pooled_prompt_embeds.dtype) - + timesteps = jnp.asarray([1.0] * global_batch_size, dtype=jnp.bfloat16) guidance = jnp.asarray([config.guidance_scale] * global_batch_size, dtype=jnp.bfloat16) - - validate_inputs( - latents, - latent_image_ids, - prompt_embeds, - text_ids, - timesteps, - guidance, - pooled_prompt_embeds - ) + + validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timesteps, guidance, pooled_prompt_embeds) # move inputs to device and shard data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding)) @@ -434,20 +397,24 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep get_memory_allocations() # evaluate shapes - transformer_eval_params = transformer.init_weights(rngs=rng, max_sequence_length=config.max_sequence_length, eval_only=True) - + transformer_eval_params = transformer.init_weights( + rngs=rng, max_sequence_length=config.max_sequence_length, eval_only=True + ) + # loads pretrained weights transformer_params = load_flow_model(config.flux_name, transformer_eval_params, "cpu") # create transformer state - weights_init_fn = functools.partial(transformer.init_weights, rngs=rng, max_sequence_length=config.max_sequence_length, eval_only=False) + weights_init_fn = functools.partial( + transformer.init_weights, rngs=rng, max_sequence_length=config.max_sequence_length, eval_only=False + ) transformer_state, transformer_state_shardings = setup_initial_state( - 
model=transformer, - tx=None, - config=config, - mesh=mesh, - weights_init_fn=weights_init_fn, - model_params=None, - training=False + model=transformer, + tx=None, + config=config, + mesh=mesh, + weights_init_fn=weights_init_fn, + model_params=None, + training=False, ) transformer_state = transformer_state.replace(params=transformer_params) transformer_state = jax.device_put(transformer_state, transformer_state_shardings) @@ -463,21 +430,21 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep states["vae"] = vae_state p_run_inference = jax.jit( - functools.partial( - run_inference, - transformer=transformer, - vae=vae, - config=config, - mesh=mesh, - latents=latents, - latent_image_ids=latent_image_ids, - prompt_embeds=prompt_embeds, - txt_ids=text_ids, - vec=pooled_prompt_embeds, - guidance_vec=guidance, - ), - in_shardings=(state_shardings,), - out_shardings=None, + functools.partial( + run_inference, + transformer=transformer, + vae=vae, + config=config, + mesh=mesh, + latents=latents, + latent_image_ids=latent_image_ids, + prompt_embeds=prompt_embeds, + txt_ids=text_ids, + vec=pooled_prompt_embeds, + guidance_vec=guidance, + ), + in_shardings=(state_shardings,), + out_shardings=None, ) t0 = time.perf_counter() p_run_inference(states).block_until_ready() @@ -508,4 +475,4 @@ def main(argv: Sequence[str]) -> None: if __name__ == "__main__": - app.run(main) \ No newline at end of file + app.run(main) diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index 2ac2182fc..884b6d688 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -322,177 +322,6 @@ def chunk_scanner(chunk_idx, _): return jnp.concatenate(res, axis=-3) # fuse the chunked result back -class FlaxFluxAttention(nn.Module): - query_dim: int - heads: int = 8 - dim_head: int = 64 - dropout: float = 0.0 - use_memory_efficient_attention: bool = False - split_head_dim: bool = False - attention_kernel: str = "dot_product" - flash_min_seq_length: int = 4096 - flash_block_sizes: BlockSizes = None - mesh: jax.sharding.Mesh = None - dtype: jnp.dtype = jnp.float32 - weights_dtype: jnp.dtype = jnp.float32 - query_axis_names: AxisNames = (BATCH, LENGTH, HEAD) - key_axis_names: AxisNames = (BATCH, LENGTH, HEAD) - value_axis_names: AxisNames = (BATCH, LENGTH, HEAD) - out_axis_names: AxisNames = (BATCH, LENGTH, EMBED) - precision: jax.lax.Precision = None - qkv_bias: bool = False - - def setup(self): - if self.attention_kernel in {"flash", "cudnn_flash_te"} and self.mesh is None: - raise ValueError(f"The flash attention kernel requires a value for mesh, but mesh is {self.mesh}") - inner_dim = self.dim_head * self.heads - scale = self.dim_head**-0.5 - - self.attention_op = AttentionOp( - mesh=self.mesh, - attention_kernel=self.attention_kernel, - scale=scale, - heads=self.heads, - dim_head=self.dim_head, - flash_min_seq_length=self.flash_min_seq_length, - use_memory_efficient_attention=self.use_memory_efficient_attention, - split_head_dim=self.split_head_dim, - flash_block_sizes=self.flash_block_sizes, - dtype=self.dtype, - float32_qk_product=False, - ) - - kernel_axes = ("embed", "heads") - qkv_init_kernel = nn.with_logical_partitioning(nn.initializers.lecun_normal(), kernel_axes) - - self.qkv = nn.Dense( - inner_dim * 3, - kernel_init=qkv_init_kernel, - use_bias=self.qkv_bias, - bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("heads",)), - dtype=self.dtype, - param_dtype=self.weights_dtype, - 
name="i_qkv", - precision=self.precision, - ) - - self.encoder_qkv = nn.Dense( - inner_dim * 3, - kernel_init=qkv_init_kernel, - use_bias=self.qkv_bias, - bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("heads",)), - dtype=self.dtype, - param_dtype=self.weights_dtype, - name="e_qkv", - precision=self.precision, - ) - - self.proj_attn = nn.Dense( - self.query_dim, - kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), kernel_axes), - use_bias=True, - bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("heads",)), - dtype=self.dtype, - param_dtype=self.weights_dtype, - name="i_proj", - precision=self.precision, - ) - - self.encoder_proj_attn = nn.Dense( - self.query_dim, - kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), kernel_axes), - use_bias=True, - bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("heads",)), - dtype=self.dtype, - param_dtype=self.weights_dtype, - name="e_proj", - precision=self.precision, - ) - - self.query_norm = nn.RMSNorm( - dtype=self.dtype, - scale_init=nn.with_logical_partitioning(nn.initializers.ones, ("heads",)), - param_dtype=self.weights_dtype, - ) - self.key_norm = nn.RMSNorm( - dtype=self.dtype, - scale_init=nn.with_logical_partitioning(nn.initializers.ones, ("heads",)), - param_dtype=self.weights_dtype, - ) - - self.encoder_query_norm = nn.RMSNorm( - dtype=self.dtype, - scale_init=nn.with_logical_partitioning(nn.initializers.ones, ("heads",)), - param_dtype=self.weights_dtype, - ) - self.encoder_key_norm = nn.RMSNorm( - dtype=self.dtype, - scale_init=nn.with_logical_partitioning(nn.initializers.ones, ("heads",)), - param_dtype=self.weights_dtype, - ) - - def apply_rope(self, xq: Array, xk: Array, freqs_cis: Array) -> tuple[Array, Array]: - xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2) - xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2) - - xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] - xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] - - return xq_out.reshape(*xq.shape).astype(xq.dtype), xk_out.reshape(*xk.shape).astype(xk.dtype) - - def __call__(self, hidden_states, encoder_hidden_states=None, attention_mask=None, image_rotary_emb=None): - - qkv_proj = self.qkv(hidden_states) - B, L = hidden_states.shape[:2] - H, D, K = self.heads, qkv_proj.shape[-1] // (self.heads * 3), 3 - qkv_proj = qkv_proj.reshape(B, L, K, H, D).transpose(2, 0, 3, 1, 4) - query_proj, key_proj, value_proj = qkv_proj - - query_proj = self.query_norm(query_proj) - - key_proj = self.key_norm(key_proj) - - if encoder_hidden_states is not None: - - encoder_qkv_proj = self.encoder_qkv(encoder_hidden_states) - B, L = encoder_hidden_states.shape[:2] - H, D, K = self.heads, encoder_qkv_proj.shape[-1] // (self.heads * 3), 3 - encoder_qkv_proj = encoder_qkv_proj.reshape(B, L, K, H, D).transpose(2, 0, 3, 1, 4) - encoder_query_proj, encoder_key_proj, encoder_value_proj = encoder_qkv_proj - - encoder_query_proj = self.encoder_query_norm(encoder_query_proj) - - encoder_key_proj = self.encoder_key_norm(encoder_key_proj) - - query_proj = jnp.concatenate((encoder_query_proj, query_proj), axis=2) - key_proj = jnp.concatenate((encoder_key_proj, key_proj), axis=2) - value_proj = jnp.concatenate((encoder_value_proj, value_proj), axis=2) - - query_proj = nn.with_logical_constraint(query_proj, self.query_axis_names) - key_proj = nn.with_logical_constraint(key_proj, self.key_axis_names) - value_proj = nn.with_logical_constraint(value_proj, self.value_axis_names) - - image_rotary_emb 
= rearrange(image_rotary_emb, "n d (i j) -> n d i j", i=2, j=2) - query_proj, key_proj = self.apply_rope(query_proj, key_proj, image_rotary_emb) - - query_proj = query_proj.transpose(0, 2, 1, 3).reshape(query_proj.shape[0], query_proj.shape[2], -1) - key_proj = key_proj.transpose(0, 2, 1, 3).reshape(key_proj.shape[0], key_proj.shape[2], -1) - value_proj = value_proj.transpose(0, 2, 1, 3).reshape(value_proj.shape[0], value_proj.shape[2], -1) - - attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj) - context_attn_output = None - - if encoder_hidden_states is not None: - context_attn_output, attn_output = ( - attn_output[:, : encoder_hidden_states.shape[1]], - attn_output[:, encoder_hidden_states.shape[1] :], - ) - - attn_output = self.proj_attn(attn_output) - - context_attn_output = self.encoder_proj_attn(context_attn_output) - - return attn_output, context_attn_output class FlaxFluxAttention(nn.Module): query_dim: int diff --git a/src/maxdiffusion/models/embeddings_flax.py b/src/maxdiffusion/models/embeddings_flax.py index 52a70d97d..42ca4b950 100644 --- a/src/maxdiffusion/models/embeddings_flax.py +++ b/src/maxdiffusion/models/embeddings_flax.py @@ -56,6 +56,7 @@ def get_sinusoidal_embeddings( signal = jnp.reshape(signal, [jnp.shape(timesteps)[0], embedding_dim]) return signal + class FlaxTimestepEmbedding(nn.Module): r""" Time step Embedding Module. Learns embeddings for input time steps. @@ -73,14 +74,9 @@ class FlaxTimestepEmbedding(nn.Module): @nn.compact def __call__(self, temb): - temb = nn.Dense(self.time_embed_dim, - dtype=self.dtype, - param_dtype=self.weights_dtype, - name="linear_1")(temb) + temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, param_dtype=self.weights_dtype, name="linear_1")(temb) temb = nn.silu(temb) - temb = nn.Dense(self.time_embed_dim, - dtype=self.dtype, - param_dtype=self.weights_dtype, name="linear_2")(temb) + temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, param_dtype=self.weights_dtype, name="linear_2")(temb) return temb @@ -103,6 +99,7 @@ def __call__(self, timesteps): timesteps, embedding_dim=self.dim, flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.freq_shift ) + def get_1d_rotary_pos_embed( dim: int, pos: Union[jnp.array, int], theta: float = 10000.0, linear_factor=1.0, ntk_factor=1.0, freqs_dtype=jnp.float32 ): @@ -123,6 +120,7 @@ def get_1d_rotary_pos_embed( return out + class PixArtAlphaTextProjection(nn.Module): """ Projects caption embeddings. Also handles dropout for classifier-free guidance. 
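[Reviewer note] `get_1d_rotary_pos_embed` above is the rotary-frequency helper behind Flux's `FluxPosEmbed`. A minimal, self-contained sketch of the cos/sin tables it computes, assuming only the signature shown in the hunk; the real helper additionally applies `linear_factor`/`ntk_factor` and packs its output differently, and the names here are illustrative:

```python
import jax.numpy as jnp

def rope_1d_sketch(dim: int, pos: jnp.ndarray, theta: float = 10000.0):
  # One frequency per feature pair; frequencies decay geometrically in theta.
  freqs = 1.0 / (theta ** (jnp.arange(0, dim, 2, dtype=jnp.float32) / dim))
  # Outer product: one rotation angle per (position, frequency) pair.
  angles = jnp.outer(pos.astype(jnp.float32), freqs)
  # Callers rotate (q, k) feature pairs using these cos/sin tables.
  return jnp.cos(angles), jnp.sin(angles)

cos_tab, sin_tab = rope_1d_sketch(dim=64, pos=jnp.arange(128))
assert cos_tab.shape == (128, 32)  # (positions, dim // 2)
```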
@@ -239,36 +237,3 @@ def __call__(self, timestep, guidance, pooled_projection): conditioning = time_guidance_emb + pooled_projections return conditioning - - -# class HFEmbedder(nnx.Module): - -# def __init__(self, version: str, max_length: int, **hf_kwargs): -# super().__init__() -# self.is_clip = version.split("/")[1].startswith("clip") -# self.max_length = max_length -# self.output_key = "pooler_output" if self.is_clip else "last_hidden_state" - -# if self.is_clip: -# self.tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(version, max_length=max_length, use_fast=True) -# self.hf_module: FlaxCLIPTextModel = FlaxCLIPTextModel.from_pretrained(version, **hf_kwargs) -# else: -# self.tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(version, max_length=max_length, use_fast=True) -# self.hf_module: FlaxT5EncoderModel = FlaxT5EncoderModel.from_pretrained(version, **hf_kwargs) - -# def __call__(self, text: list[str]): -# batch_encoding = self.tokenizer( -# text, -# truncation=True, -# max_length=self.max_length, -# return_length=False, -# return_overflowing_tokens=False, -# padding="max_length", -# return_tensors="np", -# ) -# outputs = self.hf_module( -# input_ids=batch_encoding["input_ids"], -# attention_mask=None, -# output_hidden_states=False, -# ) -# return outputs[self.output_key] \ No newline at end of file diff --git a/src/maxdiffusion/models/flux/__init__.py b/src/maxdiffusion/models/flux/__init__.py index 6d7590d6d..84dd0f150 100644 --- a/src/maxdiffusion/models/flux/__init__.py +++ b/src/maxdiffusion/models/flux/__init__.py @@ -14,4 +14,4 @@ limitations under the License. """ -from .transformers.transformer_flux_flax import FluxTransformer2DModel \ No newline at end of file +from .transformers.transformer_flux_flax import FluxTransformer2DModel diff --git a/src/maxdiffusion/models/flux/port.py b/src/maxdiffusion/models/flux/port.py deleted file mode 100644 index 1e9744ed3..000000000 --- a/src/maxdiffusion/models/flux/port.py +++ /dev/null @@ -1,223 +0,0 @@ -from einops import rearrange - -############################################################################################## -# FLUX MODEL PORTING -############################################################################################## - - -def port_linear(linear, tensors, prefix): - linear.kernel.value = rearrange(tensors[f"{prefix}.weight"], "i o -> o i") - linear.bias.value = tensors[f"{prefix}.bias"] - return linear - - -def port_modulation(modulation, tensors, prefix): - modulation.lin = port_linear( - linear=modulation.lin, tensors=tensors, prefix=f"{prefix}.lin" - ) - return modulation - - -def port_rms_norm(rms_norm, tensors, prefix): - rms_norm.scale.value = tensors[f"{prefix}.scale"] - return rms_norm - - -def port_qk_norm(qk_norm, tensors, prefix): - qk_norm.query_norm = port_rms_norm( - rms_norm=qk_norm.query_norm, - tensors=tensors, - prefix=f"{prefix}.query_norm", - ) - qk_norm.key_norm = port_rms_norm( - rms_norm=qk_norm.key_norm, - tensors=tensors, - prefix=f"{prefix}.key_norm", - ) - return qk_norm - - -def port_self_attention(self_attention, tensors, prefix): - self_attention.qkv = port_linear( - linear=self_attention.qkv, - tensors=tensors, - prefix=f"{prefix}.qkv", - ) - - self_attention.norm = port_qk_norm( - qk_norm=self_attention.norm, - tensors=tensors, - prefix=f"{prefix}.norm", - ) - - self_attention.proj = port_linear( - linear=self_attention.proj, - tensors=tensors, - prefix=f"{prefix}.proj", - ) - - return self_attention - - -def port_double_stream_block(double_stream_block, 
tensors, prefix): - double_stream_block.img_mod = port_modulation( - modulation=double_stream_block.img_mod, - tensors=tensors, - prefix=f"{prefix}.img_mod", - ) - - # double_stream_block.img_norm1 has no params - - double_stream_block.img_attn = port_self_attention( - self_attention=double_stream_block.img_attn, - tensors=tensors, - prefix=f"{prefix}.img_attn", - ) - - # double_stream_block.img_norm2 has no params - - double_stream_block.img_mlp.layers[0] = port_linear( - linear=double_stream_block.img_mlp.layers[0], - tensors=tensors, - prefix=f"{prefix}.img_mlp.0", - ) - double_stream_block.img_mlp.layers[2] = port_linear( - linear=double_stream_block.img_mlp.layers[2], - tensors=tensors, - prefix=f"{prefix}.img_mlp.2", - ) - - double_stream_block.txt_mod = port_modulation( - modulation=double_stream_block.txt_mod, - tensors=tensors, - prefix=f"{prefix}.txt_mod", - ) - - # double_stream_block.txt_norm1 has no params - - double_stream_block.txt_attn = port_self_attention( - self_attention=double_stream_block.txt_attn, - tensors=tensors, - prefix=f"{prefix}.txt_attn", - ) - - # double_stream_block.txt_norm2 has no params - - double_stream_block.txt_mlp.layers[0] = port_linear( - linear=double_stream_block.txt_mlp.layers[0], - tensors=tensors, - prefix=f"{prefix}.txt_mlp.0", - ) - double_stream_block.txt_mlp.layers[2] = port_linear( - linear=double_stream_block.txt_mlp.layers[2], - tensors=tensors, - prefix=f"{prefix}.txt_mlp.2", - ) - - return double_stream_block - - -def port_single_stream_block(single_stream_block, tensors, prefix): - single_stream_block.linear1 = port_linear( - linear=single_stream_block.linear1, tensors=tensors, prefix=f"{prefix}.linear1" - ) - single_stream_block.linear2 = port_linear( - linear=single_stream_block.linear2, tensors=tensors, prefix=f"{prefix}.linear2" - ) - - single_stream_block.norm = port_qk_norm( - qk_norm=single_stream_block.norm, tensors=tensors, prefix=f"{prefix}.norm" - ) - - # single_stream_block.pre_norm has no params - - single_stream_block.modulation = port_modulation( - modulation=single_stream_block.modulation, - tensors=tensors, - prefix=f"{prefix}.modulation", - ) - - return single_stream_block - - -def port_mlp_embedder(mlp_embedder, tensors, prefix): - mlp_embedder.in_layer = port_linear( - linear=mlp_embedder.in_layer, tensors=tensors, prefix=f"{prefix}.in_layer" - ) - - mlp_embedder.out_layer = port_linear( - linear=mlp_embedder.out_layer, tensors=tensors, prefix=f"{prefix}.out_layer" - ) - return mlp_embedder - - -def port_final_layer(final_layer, tensors, prefix): - # last_layer.norm_final has no params - final_layer.linear = port_linear( - linear=final_layer.linear, - tensors=tensors, - prefix=f"{prefix}.linear", - ) - - final_layer.adaLN_modulation.layers[1] = port_linear( - linear=final_layer.adaLN_modulation.layers[1], - tensors=tensors, - prefix=f"{prefix}.adaLN_modulation.1", - ) - - return final_layer - - -def port_flux(flux, tensors): - flux.img_in = port_linear( - linear=flux.img_in, - tensors=tensors, - prefix="img_in", - ) - - flux.time_in = port_mlp_embedder( - mlp_embedder=flux.time_in, - tensors=tensors, - prefix="time_in", - ) - - flux.vector_in = port_mlp_embedder( - mlp_embedder=flux.vector_in, - tensors=tensors, - prefix="vector_in", - ) - - if flux.params.guidance_embed: - flux.guidance_in = port_mlp_embedder( - mlp_embedder=flux.guidance_in, - tensors=tensors, - prefix="guidance_in", - ) - - flux.txt_in = port_linear( - linear=flux.txt_in, - tensors=tensors, - prefix="txt_in", - ) - - for i, layer in 
enumerate(flux.double_blocks.layers): - layer = port_double_stream_block( - double_stream_block=layer, - tensors=tensors, - prefix=f"double_blocks.{i}", - ) - - for i, layer in enumerate(flux.single_blocks.layers): - layer = port_single_stream_block( - single_stream_block=layer, - tensors=tensors, - prefix=f"single_blocks.{i}", - ) - - flux.final_layer = port_final_layer( - final_layer=flux.final_layer, - tensors=tensors, - prefix="final_layer", - ) - - return flux diff --git a/src/maxdiffusion/models/flux/transformers/__init__.py b/src/maxdiffusion/models/flux/transformers/__init__.py index 55bca151a..7e4185f36 100644 --- a/src/maxdiffusion/models/flux/transformers/__init__.py +++ b/src/maxdiffusion/models/flux/transformers/__init__.py @@ -12,4 +12,4 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. - """ \ No newline at end of file + """ diff --git a/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py b/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py index 0dbdf66de..8425d3990 100644 --- a/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py +++ b/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py @@ -38,6 +38,7 @@ HEAD = common_types.HEAD D_KV = common_types.D_KV + @flax.struct.dataclass class Transformer2DModelOutput(BaseOutput): """ @@ -110,7 +111,7 @@ def setup(self): weights_dtype=self.weights_dtype, attention_kernel=self.attention_kernel, mesh=self.mesh, - flash_block_sizes=self.flash_block_sizes + flash_block_sizes=self.flash_block_sizes, ) def __call__(self, hidden_states, temb, image_rotary_emb=None): @@ -194,7 +195,7 @@ def setup(self): weights_dtype=self.weights_dtype, attention_kernel=self.attention_kernel, mesh=self.mesh, - flash_block_sizes=self.flash_block_sizes + flash_block_sizes=self.flash_block_sizes, ) self.img_norm2 = nn.LayerNorm( @@ -386,18 +387,18 @@ def setup(self): double_blocks = [] for _ in range(self.num_layers): double_block = FluxTransformerBlock( - dim=self.inner_dim, - num_attention_heads=self.num_attention_heads, - attention_head_dim=self.attention_head_dim, - attention_kernel=self.attention_kernel, - flash_min_seq_length=self.flash_min_seq_length, - flash_block_sizes=self.flash_block_sizes, - mesh=self.mesh, - dtype=self.dtype, - weights_dtype=self.weights_dtype, - precision=self.precision, - mlp_ratio=self.mlp_ratio, - qkv_bias=self.qkv_bias, + dim=self.inner_dim, + num_attention_heads=self.num_attention_heads, + attention_head_dim=self.attention_head_dim, + attention_kernel=self.attention_kernel, + flash_min_seq_length=self.flash_min_seq_length, + flash_block_sizes=self.flash_block_sizes, + mesh=self.mesh, + dtype=self.dtype, + weights_dtype=self.weights_dtype, + precision=self.precision, + mlp_ratio=self.mlp_ratio, + qkv_bias=self.qkv_bias, ) double_blocks.append(double_block) self.double_blocks = double_blocks @@ -405,20 +406,20 @@ def setup(self): single_blocks = [] for _ in range(self.num_single_layers): single_block = FluxSingleTransformerBlock( - dim=self.inner_dim, - num_attention_heads=self.num_attention_heads, - attention_head_dim=self.attention_head_dim, - attention_kernel=self.attention_kernel, - flash_min_seq_length=self.flash_min_seq_length, - flash_block_sizes=self.flash_block_sizes, - mesh=self.mesh, - dtype=self.dtype, - weights_dtype=self.weights_dtype, - precision=self.precision, - mlp_ratio=self.mlp_ratio, + dim=self.inner_dim, + 
num_attention_heads=self.num_attention_heads, + attention_head_dim=self.attention_head_dim, + attention_kernel=self.attention_kernel, + flash_min_seq_length=self.flash_min_seq_length, + flash_block_sizes=self.flash_block_sizes, + mesh=self.mesh, + dtype=self.dtype, + weights_dtype=self.weights_dtype, + precision=self.precision, + mlp_ratio=self.mlp_ratio, ) single_blocks.append(single_block) - + self.single_blocks = single_blocks self.norm_out = AdaLayerNormContinuous( @@ -592,4 +593,4 @@ def init_weights(self, rngs, max_sequence_length, eval_only=True): pooled_projections=vec, timestep=t_vec, guidance=guidance_vec, - )["params"] \ No newline at end of file + )["params"] diff --git a/src/maxdiffusion/models/flux/util.py b/src/maxdiffusion/models/flux/util.py index 7fae4250a..ef5fd5b47 100644 --- a/src/maxdiffusion/models/flux/util.py +++ b/src/maxdiffusion/models/flux/util.py @@ -1,4 +1,3 @@ - # copied from https://github.com/ml-gde/jflux/blob/main/jflux/util.py import os from dataclasses import dataclass @@ -12,49 +11,48 @@ from jax import numpy as jnp from safetensors import safe_open -from maxdiffusion.models.modeling_flax_pytorch_utils import ( - rename_key, - rename_key_and_reshape_tensor -) +from maxdiffusion.models.modeling_flax_pytorch_utils import (rename_key, rename_key_and_reshape_tensor) from maxdiffusion import max_logging + @dataclass class FluxParams: - in_channels: int - vec_in_dim: int - context_in_dim: int - hidden_size: int - mlp_ratio: float - num_heads: int - depth: int - depth_single_blocks: int - axes_dim: list[int] - theta: int - qkv_bias: bool - guidance_embed: bool - rngs: Array - param_dtype: DTypeLike + in_channels: int + vec_in_dim: int + context_in_dim: int + hidden_size: int + mlp_ratio: float + num_heads: int + depth: int + depth_single_blocks: int + axes_dim: list[int] + theta: int + qkv_bias: bool + guidance_embed: bool + rngs: Array + param_dtype: DTypeLike + def torch2jax(torch_tensor: torch.Tensor) -> Array: - is_bfloat16 = torch_tensor.dtype == torch.bfloat16 - if is_bfloat16: - # upcast the tensor to fp32 - torch_tensor = torch_tensor.float() + is_bfloat16 = torch_tensor.dtype == torch.bfloat16 + if is_bfloat16: + # upcast the tensor to fp32 + torch_tensor = torch_tensor.float() - if torch.device.type != "cpu": - torch_tensor = torch_tensor.to("cpu") + if torch.device.type != "cpu": + torch_tensor = torch_tensor.to("cpu") - numpy_value = torch_tensor.numpy() - jax_array = jnp.array(numpy_value, dtype=jnp.bfloat16 if is_bfloat16 else None) - return jax_array + numpy_value = torch_tensor.numpy() + jax_array = jnp.array(numpy_value, dtype=jnp.bfloat16 if is_bfloat16 else None) + return jax_array @dataclass class ModelSpec: - params: FluxParams - ckpt_path: str | None - repo_id: str | None - repo_flow: str | None + params: FluxParams + ckpt_path: str | None + repo_id: str | None + repo_flow: str | None configs = { @@ -104,14 +102,15 @@ class ModelSpec: def print_load_warning(missing: list[str], unexpected: list[str]) -> None: - if len(missing) > 0 and len(unexpected) > 0: - max_logging.log(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) - max_logging.log("\n" + "-" * 79 + "\n") - max_logging.log(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) - elif len(missing) > 0: - max_logging.log(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) - elif len(unexpected) > 0: - max_logging.log(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) + if len(missing) > 0 and len(unexpected) > 0: + 
max_logging.log(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) + max_logging.log("\n" + "-" * 79 + "\n") + max_logging.log(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) + elif len(missing) > 0: + max_logging.log(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) + elif len(unexpected) > 0: + max_logging.log(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) + def validate_flax_state_dict(expected_pytree: dict, new_pytree: dict): """ @@ -129,68 +128,66 @@ def validate_flax_state_dict(expected_pytree: dict, new_pytree: dict): try: expected_pytree_shape = expected_pytree[key].shape except: - expected_pytree_shape= expected_pytree[key].value.shape + expected_pytree_shape = expected_pytree[key].value.shape if expected_pytree_shape != new_pytree[key].shape: - max_logging.log(f"shape mismatch, expected shape of {expected_pytree[key].shape}, but got shape of {new_pytree[key].shape}") + max_logging.log( + f"shape mismatch, expected shape of {expected_pytree[key].shape}, but got shape of {new_pytree[key].shape}" + ) else: max_logging.log(f"key: {key} not found...") -def load_flow_model(name: str, eval_shapes: dict, device: str, hf_download: bool = True): # -> Flux: - device = jax.devices(device)[0] - with jax.default_device(device): - ckpt_path = configs[name].ckpt_path - if ( - ckpt_path is None - and configs[name].repo_id is not None - and configs[name].repo_flow is not None - and hf_download - ): - ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow) - - max_logging.log(f"Load and port flux on {device}") - - if ckpt_path is not None: - tensors = {} - with safe_open(ckpt_path, framework="pt") as f: - for k in f.keys(): - tensors[k] = torch2jax(f.get_tensor(k)) - flax_state_dict = {} - cpu = jax.local_devices(backend="cpu")[0] - for pt_key, tensor in tensors.items(): - renamed_pt_key = rename_key(pt_key) - if "double_blocks" in renamed_pt_key: - renamed_pt_key = renamed_pt_key.replace("img_mlp_", "img_mlp.layers_") - renamed_pt_key = renamed_pt_key.replace("txt_mlp_", "txt_mlp.layers_") - renamed_pt_key = renamed_pt_key.replace("img_mod", "img_norm1") - renamed_pt_key = renamed_pt_key.replace("txt_mod", "txt_norm1") - renamed_pt_key = renamed_pt_key.replace("img_attn.qkv", "attn.i_qkv") - renamed_pt_key = renamed_pt_key.replace("img_attn.proj", "attn.i_proj") - renamed_pt_key = renamed_pt_key.replace("img_attn.norm", "attn") - renamed_pt_key = renamed_pt_key.replace("txt_attn.qkv", "attn.e_qkv") - renamed_pt_key = renamed_pt_key.replace("txt_attn.proj", "attn.e_proj") - renamed_pt_key = renamed_pt_key.replace("txt_attn.norm.key_norm", "attn.encoder_key_norm") - renamed_pt_key = renamed_pt_key.replace("txt_attn.norm.query_norm", "attn.encoder_query_norm") - elif("guidance_in" in renamed_pt_key): - renamed_pt_key = renamed_pt_key.replace("guidance_in", "time_text_embed.FlaxTimestepEmbedding_1") - renamed_pt_key = renamed_pt_key.replace("in_layer", "linear_1") - renamed_pt_key = renamed_pt_key.replace("out_layer", "linear_2") - elif "single_blocks" in renamed_pt_key: - renamed_pt_key = renamed_pt_key.replace("modulation", "norm") - renamed_pt_key = renamed_pt_key.replace("norm.key_norm", "attn.key_norm") - renamed_pt_key = renamed_pt_key.replace("norm.query_norm", "attn.query_norm") - elif "vector_in" in renamed_pt_key or "time_in" in renamed_pt_key: - renamed_pt_key = renamed_pt_key.replace("vector_in", "time_text_embed.PixArtAlphaTextProjection_0") - renamed_pt_key = 
renamed_pt_key.replace("time_in", "time_text_embed.FlaxTimestepEmbedding_0") - renamed_pt_key = renamed_pt_key.replace("in_layer", "linear_1") - renamed_pt_key = renamed_pt_key.replace("out_layer", "linear_2") - elif "final_layer" in renamed_pt_key: - renamed_pt_key = renamed_pt_key.replace("final_layer.linear", "proj_out") - renamed_pt_key = renamed_pt_key.replace("final_layer.adaLN_modulation_1", "norm_out.Dense_0") - pt_tuple_key = tuple(renamed_pt_key.split(".")) - flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, eval_shapes) - flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu) - validate_flax_state_dict(eval_shapes, flax_state_dict) - flax_state_dict = unflatten_dict(flax_state_dict) - del tensors - jax.clear_caches() - return flax_state_dict + +def load_flow_model(name: str, eval_shapes: dict, device: str, hf_download: bool = True): # -> Flux: + device = jax.devices(device)[0] + with jax.default_device(device): + ckpt_path = configs[name].ckpt_path + if ckpt_path is None and configs[name].repo_id is not None and configs[name].repo_flow is not None and hf_download: + ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow) + + max_logging.log(f"Load and port flux on {device}") + + if ckpt_path is not None: + tensors = {} + with safe_open(ckpt_path, framework="pt") as f: + for k in f.keys(): + tensors[k] = torch2jax(f.get_tensor(k)) + flax_state_dict = {} + cpu = jax.local_devices(backend="cpu")[0] + for pt_key, tensor in tensors.items(): + renamed_pt_key = rename_key(pt_key) + if "double_blocks" in renamed_pt_key: + renamed_pt_key = renamed_pt_key.replace("img_mlp_", "img_mlp.layers_") + renamed_pt_key = renamed_pt_key.replace("txt_mlp_", "txt_mlp.layers_") + renamed_pt_key = renamed_pt_key.replace("img_mod", "img_norm1") + renamed_pt_key = renamed_pt_key.replace("txt_mod", "txt_norm1") + renamed_pt_key = renamed_pt_key.replace("img_attn.qkv", "attn.i_qkv") + renamed_pt_key = renamed_pt_key.replace("img_attn.proj", "attn.i_proj") + renamed_pt_key = renamed_pt_key.replace("img_attn.norm", "attn") + renamed_pt_key = renamed_pt_key.replace("txt_attn.qkv", "attn.e_qkv") + renamed_pt_key = renamed_pt_key.replace("txt_attn.proj", "attn.e_proj") + renamed_pt_key = renamed_pt_key.replace("txt_attn.norm.key_norm", "attn.encoder_key_norm") + renamed_pt_key = renamed_pt_key.replace("txt_attn.norm.query_norm", "attn.encoder_query_norm") + elif "guidance_in" in renamed_pt_key: + renamed_pt_key = renamed_pt_key.replace("guidance_in", "time_text_embed.FlaxTimestepEmbedding_1") + renamed_pt_key = renamed_pt_key.replace("in_layer", "linear_1") + renamed_pt_key = renamed_pt_key.replace("out_layer", "linear_2") + elif "single_blocks" in renamed_pt_key: + renamed_pt_key = renamed_pt_key.replace("modulation", "norm") + renamed_pt_key = renamed_pt_key.replace("norm.key_norm", "attn.key_norm") + renamed_pt_key = renamed_pt_key.replace("norm.query_norm", "attn.query_norm") + elif "vector_in" in renamed_pt_key or "time_in" in renamed_pt_key: + renamed_pt_key = renamed_pt_key.replace("vector_in", "time_text_embed.PixArtAlphaTextProjection_0") + renamed_pt_key = renamed_pt_key.replace("time_in", "time_text_embed.FlaxTimestepEmbedding_0") + renamed_pt_key = renamed_pt_key.replace("in_layer", "linear_1") + renamed_pt_key = renamed_pt_key.replace("out_layer", "linear_2") + elif "final_layer" in renamed_pt_key: + renamed_pt_key = renamed_pt_key.replace("final_layer.linear", "proj_out") + renamed_pt_key = 
renamed_pt_key.replace("final_layer.adaLN_modulation_1", "norm_out.Dense_0") + pt_tuple_key = tuple(renamed_pt_key.split(".")) + flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, eval_shapes) + flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu) + validate_flax_state_dict(eval_shapes, flax_state_dict) + flax_state_dict = unflatten_dict(flax_state_dict) + del tensors + jax.clear_caches() + return flax_state_dict diff --git a/src/maxdiffusion/models/modeling_utils.py b/src/maxdiffusion/models/modeling_utils.py index 8d0ffe5e4..3bf54107f 100644 --- a/src/maxdiffusion/models/modeling_utils.py +++ b/src/maxdiffusion/models/modeling_utils.py @@ -109,6 +109,7 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[ return torch.load(checkpoint_file, map_location="cpu") else: from safetensors import torch as safetensors_torch + return safetensors_torch.load_file(checkpoint_file, device="cpu") except Exception as e: try: diff --git a/src/maxdiffusion/models/normalization_flax.py b/src/maxdiffusion/models/normalization_flax.py index b91433144..ea3b970d8 100644 --- a/src/maxdiffusion/models/normalization_flax.py +++ b/src/maxdiffusion/models/normalization_flax.py @@ -146,4 +146,4 @@ def __call__(self, x, emb): ) else: raise ValueError(f"Unsupported `norm_type` ({self.norm_type}) provided. Supported ones are: 'layer_norm'.") - return x, gate_msa \ No newline at end of file + return x, gate_msa diff --git a/src/maxdiffusion/tests/text_encoders_test.py b/src/maxdiffusion/tests/text_encoders_test.py index cf4ba0c1b..65888ca09 100644 --- a/src/maxdiffusion/tests/text_encoders_test.py +++ b/src/maxdiffusion/tests/text_encoders_test.py @@ -25,22 +25,23 @@ THIS_DIR = os.path.dirname(os.path.abspath(__file__)) + class TextEncoderTest(unittest.TestCase): """Test text encoders""" def setUp(self): TextEncoderTest.dummy_data = {} - + def test_flux_t5_text_encoder(self): text_encoder_2_pt = T5EncoderModel.from_pretrained( - "black-forest-labs/FLUX.1-dev", - subfolder="text_encoder_2", + "black-forest-labs/FLUX.1-dev", + subfolder="text_encoder_2", ) tokenizer_2 = T5TokenizerFast.from_pretrained( - "black-forest-labs/FLUX.1-dev", - subfolder="tokenizer_2", + "black-forest-labs/FLUX.1-dev", + subfolder="tokenizer_2", ) embeds = get_t5_prompt_embeds("A dog on a skateboard", 2, tokenizer_2, text_encoder_2_pt) @@ -50,17 +51,8 @@ def test_flux_t5_text_encoder(self): def test_flux_clip_text_encoder(self): text_encoder = FlaxCLIPTextModel.from_pretrained( - "black-forest-labs/FLUX.1-dev", - subfolder="text_encoder", - from_pt=True, - dtype="bfloat16" - ) - tokenizer = CLIPTokenizer.from_pretrained( - "black-forest-labs/FLUX.1-dev", - subfolder="tokenizer", - dtype="bfloat16" + "black-forest-labs/FLUX.1-dev", subfolder="text_encoder", from_pt=True, dtype="bfloat16" ) + tokenizer = CLIPTokenizer.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="tokenizer", dtype="bfloat16") embeds = get_clip_prompt_embeds("A cat riding a skateboard", 2, tokenizer, text_encoder) assert embeds.shape == (2, 768) - - diff --git a/src/maxdiffusion/tests/vae_test.py b/src/maxdiffusion/tests/vae_test.py index 8370ff5a3..7673642a8 100644 --- a/src/maxdiffusion/tests/vae_test.py +++ b/src/maxdiffusion/tests/vae_test.py @@ -28,14 +28,15 @@ THIS_DIR = os.path.dirname(os.path.abspath(__file__)) + class VaeTest(unittest.TestCase): """Test Vae""" def setUp(self): VaeTest.dummy_data = {} - + def test_flux_vae(self): - + img_url = os.path.join(THIS_DIR, "images", 
"test_hyper_sdxl.png") base_image = np.array(Image.open(img_url)).astype(np.uint8) img_min = np.min(base_image) @@ -43,17 +44,13 @@ def test_flux_vae(self): image = (base_image - img_min) / (img_max - img_min) image = 2.0 * image - 1.0 image = np.expand_dims(image, 0) - image = np.transpose(image, (0, 3, 1, 2)) # (1, 3, 1024, 1024), BCWH - + image = np.transpose(image, (0, 3, 1, 2)) # (1, 3, 1024, 1024), BCWH + vae, vae_params = FlaxAutoencoderKL.from_pretrained( - "black-forest-labs/FLUX.1-dev", - subfolder="vae", - from_pt=True, - use_safetensors=True, - dtype="bfloat16" + "black-forest-labs/FLUX.1-dev", subfolder="vae", from_pt=True, use_safetensors=True, dtype="bfloat16" ) - encoded_image = vae.apply({"params" : vae_params}, image, deterministic=True, method=vae.encode) + encoded_image = vae.apply({"params": vae_params}, image, deterministic=True, method=vae.encode) latents = encoded_image[0].sample(jax.random.key(0)) latents = jnp.transpose(latents, (0, 3, 1, 2)) @@ -63,12 +60,10 @@ def test_flux_vae(self): # decode back latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor - image = vae.apply({"params" : vae_params}, latents, deterministic=True, method=vae.decode).sample[0] + image = vae.apply({"params": vae_params}, latents, deterministic=True, method=vae.decode).sample[0] image = np.array(image) image = (image * 0.5 + 0.5).clip(0, 1) image = np.transpose(image, (1, 2, 0)) image = np.uint8(image * 255) ssim_compare = ssim(base_image, image, multichannel=True, channel_axis=-1, data_range=255) assert ssim_compare >= 0.90 - - From 303e82ab18c517e36bce7999c3a644403ab28b56 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Wed, 5 Feb 2025 02:10:32 +0000 Subject: [PATCH 23/35] fix sdxl generate smoke tests. --- src/maxdiffusion/generate_flux.py | 5 +---- src/maxdiffusion/max_utils.py | 5 ++++- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/maxdiffusion/generate_flux.py b/src/maxdiffusion/generate_flux.py index 482aac539..6361502d7 100644 --- a/src/maxdiffusion/generate_flux.py +++ b/src/maxdiffusion/generate_flux.py @@ -31,7 +31,7 @@ from maxdiffusion import FlaxAutoencoderKL, pyconfig, max_logging from maxdiffusion.models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel -from max_utils import ( +from maxdiffusion.max_utils import ( device_put_replicated, get_memory_allocations, create_device_mesh, @@ -52,9 +52,6 @@ def unpack(x: Array, height: int, width: int) -> Array: ) -from einops import rearrange - - def vae_decode(latents, vae, state, config): img = unpack(x=latents, height=config.resolution, width=config.resolution) img = img / vae.config.scaling_factor + vae.config.shift_factor diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index 93acb03e7..a75876d81 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -46,7 +46,10 @@ from flax.linen import partitioning as nn_partitioning from flax.training import train_state from jax.experimental import mesh_utils -from transformers import (FlaxCLIPTextModel, FlaxCLIPTextPreTrainedModel) +from transformers import ( + FlaxCLIPTextModel, + FlaxCLIPTextPreTrainedModel +) from flax import struct from typing import ( Callable, From 5df1f3ca891021a26e7e3c7d1b2f0bd061b9a488 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Wed, 5 Feb 2025 17:08:26 +0000 Subject: [PATCH 24/35] fix rest of unit tests. 
--- src/maxdiffusion/max_utils.py | 5 +---- src/maxdiffusion/tests/text_encoders_test.py | 23 +++++++++++--------- src/maxdiffusion/tests/vae_test.py | 7 ++++++ 3 files changed, 21 insertions(+), 14 deletions(-) diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index a75876d81..93acb03e7 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -46,10 +46,7 @@ from flax.linen import partitioning as nn_partitioning from flax.training import train_state from jax.experimental import mesh_utils -from transformers import ( - FlaxCLIPTextModel, - FlaxCLIPTextPreTrainedModel -) +from transformers import (FlaxCLIPTextModel, FlaxCLIPTextPreTrainedModel) from flax import struct from typing import ( Callable, diff --git a/src/maxdiffusion/tests/text_encoders_test.py b/src/maxdiffusion/tests/text_encoders_test.py index 65888ca09..c6a463010 100644 --- a/src/maxdiffusion/tests/text_encoders_test.py +++ b/src/maxdiffusion/tests/text_encoders_test.py @@ -16,13 +16,16 @@ import os import unittest +import pytest +import jax.numpy as jnp from absl.testing import absltest from transformers import CLIPTokenizer, FlaxCLIPTextModel -from transformers import T5TokenizerFast, T5EncoderModel +from transformers import T5TokenizerFast, FlaxT5EncoderModel from ..generate_flux import get_clip_prompt_embeds, get_t5_prompt_embeds +IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true" THIS_DIR = os.path.dirname(os.path.abspath(__file__)) @@ -32,22 +35,18 @@ class TextEncoderTest(unittest.TestCase): def setUp(self): TextEncoderTest.dummy_data = {} + @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions") def test_flux_t5_text_encoder(self): - text_encoder_2_pt = T5EncoderModel.from_pretrained( - "black-forest-labs/FLUX.1-dev", - subfolder="text_encoder_2", - ) + text_encoder = FlaxT5EncoderModel.from_pretrained("ariG23498/t5-v1-1-xxl-flax") - tokenizer_2 = T5TokenizerFast.from_pretrained( - "black-forest-labs/FLUX.1-dev", - subfolder="tokenizer_2", - ) + tokenizer_2 = T5TokenizerFast.from_pretrained("ariG23498/t5-v1-1-xxl-flax") - embeds = get_t5_prompt_embeds("A dog on a skateboard", 2, tokenizer_2, text_encoder_2_pt) + embeds = get_t5_prompt_embeds("A dog on a skateboard", 2, tokenizer_2, text_encoder) assert embeds.shape == (2, 512, 4096) + @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions") def test_flux_clip_text_encoder(self): text_encoder = FlaxCLIPTextModel.from_pretrained( @@ -56,3 +55,7 @@ def test_flux_clip_text_encoder(self): tokenizer = CLIPTokenizer.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="tokenizer", dtype="bfloat16") embeds = get_clip_prompt_embeds("A cat riding a skateboard", 2, tokenizer, text_encoder) assert embeds.shape == (2, 768) + + +if __name__ == "__main__": + absltest.main() diff --git a/src/maxdiffusion/tests/vae_test.py b/src/maxdiffusion/tests/vae_test.py index 7673642a8..0200f78e2 100644 --- a/src/maxdiffusion/tests/vae_test.py +++ b/src/maxdiffusion/tests/vae_test.py @@ -16,6 +16,7 @@ import os import unittest +import pytest from absl.testing import absltest import numpy as np @@ -27,6 +28,7 @@ from skimage.metrics import structural_similarity as ssim THIS_DIR = os.path.dirname(os.path.abspath(__file__)) +IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true" class VaeTest(unittest.TestCase): @@ -35,6 +37,7 @@ class VaeTest(unittest.TestCase): def setUp(self): VaeTest.dummy_data = {} + @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke 
tests on Github Actions")
   def test_flux_vae(self):

     img_url = os.path.join(THIS_DIR, "images", "test_hyper_sdxl.png")
@@ -67,3 +70,7 @@ def test_flux_vae(self):
     image = np.uint8(image * 255)
     ssim_compare = ssim(base_image, image, multichannel=True, channel_axis=-1, data_range=255)
     assert ssim_compare >= 0.90
+
+
+if __name__ == "__main__":
+  absltest.main()

From 1ec459df40330036693a9b7f67825c366e060ced Mon Sep 17 00:00:00 2001
From: Juan Acevedo
Date: Wed, 5 Feb 2025 19:36:27 +0000
Subject: [PATCH 25/35] update readme and some dependencies.

---
 README.md | 24 ++----------------------
 1 file changed, 2 insertions(+), 22 deletions(-)

diff --git a/README.md b/README.md
index f9be0f3fc..cdf451050 100644
--- a/README.md
+++ b/README.md
@@ -17,8 +17,7 @@
 [![Unit Tests](https://github.com/google/maxtext/actions/workflows/UnitTests.yml/badge.svg)](https://github.com/google/maxdiffusion/actions/workflows/UnitTests.yml)

 # What's new?
-- **`2025/02/12`**: Flux LoRA for inference.
-- **`2025/02/08`**: Flux schnell & dev inference.
+- **`2025/02/08`**: Flux schnell & dev inference.
 - **`2024/12/12`**: Load multiple LoRAs for inference.
 - **`2024/10/22`**: LoRA support for Hyper SDXL.
 - **`2024/8/1`**: Orbax is the new default checkpointer. You can still use `pipeline.save_pretrained` after training to save in diffusers format.
@@ -48,8 +47,7 @@ MaxDiffusion supports
    * [Training](#training)
      * [Dreambooth](#dreambooth)
    * [Inference](#inference)
-     * [Flux](#flux)
-     * [Flux LoRA](#flux-lora)
+     * [Flux](#flux)
      * [Hyper-SD XL LoRA](#hyper-sdxl-lora)
      * [Load Multiple LoRA](#load-multiple-lora)
      * [SDXL Lightning](#sdxl-lightning)
@@ -171,24 +169,6 @@ To generate images, run the following command:
 python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_schnell.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt="photograph of an electronics chip in the shape of a race car with trillium written on its side" per_device_batch_size=1 ici_data_parallelism=1 ici_fsdp_parallelism=-1 offload_encoders=False
 ```

- ## Flux LoRA
-
- Disclaimer: not all LoRA formats have been tested. If there is a specific LoRA that doesn't load, please let us know.
-
- Tested with [Amateur Photography](https://civitai.com/models/652699/amateur-photography-flux-dev) and [XLabs-AI](https://huggingface.co/XLabs-AI/flux-lora-collection/tree/main) LoRA collection.
-
- First download the LoRA file to a local directory, for example, `/home/jfacevedo/anime_lora.safetensors`.
Then run as follows: - - ```bash - python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_dev.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt='A cute corgi lives in a house made out of sushi, anime' num_inference_steps=28 ici_data_parallelism=1 ici_fsdp_parallelism=-1 split_head_dim=True lora_config='{"lora_model_name_or_path" : ["/home/jfacevedo/anime_lora.safetensors"], "weight_name" : ["anime_lora.safetensors"], "adapter_name" : ["anime"], "scale": [0.8], "from_pt": ["true"]}' - ``` - - Loading multiple LoRAs is supported as follows: - - ```bash - python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_dev.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt='A cute corgi lives in a house made out of sushi, anime' num_inference_steps=28 ici_data_parallelism=1 ici_fsdp_parallelism=-1 split_head_dim=True lora_config='{"lora_model_name_or_path" : ["/home/jfacevedo/anime_lora.safetensors", "/home/jfacevedo/amateurphoto-v6-forcu.safetensors"], "weight_name" : ["anime_lora.safetensors","amateurphoto-v6-forcu.safetensors"], "adapter_name" : ["anime","realistic"], "scale": [0.6, 0.6], "from_pt": ["true","true"]}' - ``` - ## Hyper SDXL LoRA Supports Hyper-SDXL models from [ByteDance](https://huggingface.co/ByteDance/Hyper-SD) From 1f28cb584fef81d6bc8f0d203060250d15fb3351 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Wed, 5 Feb 2025 19:40:45 +0000 Subject: [PATCH 26/35] remove unused dependencies. --- src/maxdiffusion/generate_flux.py | 2 +- .../models/flux/transformers/transformer_flux_flax.py | 2 -- src/maxdiffusion/models/flux/util.py | 2 +- src/maxdiffusion/tests/text_encoders_test.py | 1 - src/maxdiffusion/tests/vae_test.py | 1 - 5 files changed, 2 insertions(+), 6 deletions(-) diff --git a/src/maxdiffusion/generate_flux.py b/src/maxdiffusion/generate_flux.py index 6361502d7..fb21384a6 100644 --- a/src/maxdiffusion/generate_flux.py +++ b/src/maxdiffusion/generate_flux.py @@ -14,7 +14,7 @@ limitations under the License. """ -from typing import Any, Callable, Dict, List, Optional, Union, Sequence +from typing import Callable, List, Union, Sequence from absl import app import functools import math diff --git a/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py b/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py index 8425d3990..bff07988a 100644 --- a/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py +++ b/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py @@ -28,9 +28,7 @@ from ...embeddings_flax import (FluxPosEmbed, CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings) from .... import common_types from ....common_types import BlockSizes -from .... 
import max_logging from ....utils import BaseOutput -from dataclasses import dataclass AxisNames = common_types.AxisNames BATCH = common_types.BATCH diff --git a/src/maxdiffusion/models/flux/util.py b/src/maxdiffusion/models/flux/util.py index ef5fd5b47..362a39171 100644 --- a/src/maxdiffusion/models/flux/util.py +++ b/src/maxdiffusion/models/flux/util.py @@ -127,7 +127,7 @@ def validate_flax_state_dict(expected_pytree: dict, new_pytree: dict): if key in new_pytree.keys(): try: expected_pytree_shape = expected_pytree[key].shape - except: + except Exception: expected_pytree_shape = expected_pytree[key].value.shape if expected_pytree_shape != new_pytree[key].shape: max_logging.log( diff --git a/src/maxdiffusion/tests/text_encoders_test.py b/src/maxdiffusion/tests/text_encoders_test.py index c6a463010..e7d3d6ddd 100644 --- a/src/maxdiffusion/tests/text_encoders_test.py +++ b/src/maxdiffusion/tests/text_encoders_test.py @@ -17,7 +17,6 @@ import os import unittest import pytest -import jax.numpy as jnp from absl.testing import absltest from transformers import CLIPTokenizer, FlaxCLIPTextModel diff --git a/src/maxdiffusion/tests/vae_test.py b/src/maxdiffusion/tests/vae_test.py index 0200f78e2..cf7fb399d 100644 --- a/src/maxdiffusion/tests/vae_test.py +++ b/src/maxdiffusion/tests/vae_test.py @@ -24,7 +24,6 @@ import jax import jax.numpy as jnp from maxdiffusion import FlaxAutoencoderKL -from maxdiffusion.image_processor import VaeImageProcessor from skimage.metrics import structural_similarity as ssim THIS_DIR = os.path.dirname(os.path.abspath(__file__)) From ff16ba6d43f28646661d44a1eb4be2815c6c9904 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Thu, 6 Feb 2025 23:49:53 +0000 Subject: [PATCH 27/35] initial lora implementation for flux --- .../configs/base_flux_schnell.yml | 30 +++------ src/maxdiffusion/generate_flux.py | 66 +++++++++++++++---- src/maxdiffusion/loaders/__init__.py | 2 +- .../loaders/flux_lora_pipeline.py | 27 ++++---- .../models/modeling_flax_pytorch_utils.py | 45 +++++++++++++ 5 files changed, 122 insertions(+), 48 deletions(-) diff --git a/src/maxdiffusion/configs/base_flux_schnell.yml b/src/maxdiffusion/configs/base_flux_schnell.yml index ee9db566a..4c22edb73 100644 --- a/src/maxdiffusion/configs/base_flux_schnell.yml +++ b/src/maxdiffusion/configs/base_flux_schnell.yml @@ -54,27 +54,17 @@ precision: "DEFAULT" from_pt: True split_head_dim: True attention: 'flash' # Supported attention: dot_product, flash -flash_block_sizes: { - "block_q" : 256, - "block_kv_compute" : 256, - "block_kv" : 256, - "block_q_dkv" : 256, - "block_kv_dkv" : 256, - "block_kv_dkv_compute" : 256, - "block_q_dq" : 256, - "block_kv_dq" : 256 -} - -# Use the following flash_block_sizes on v6e (Trillium). +flash_block_sizes: {} +# Use the following flash_block_sizes on v6e (Trillium) due to larger vmem. 
# flash_block_sizes: { -# "block_q" : 2176, -# "block_kv_compute" : 2176, -# "block_kv" : 2176, -# "block_q_dkv" : 2176, -# "block_kv_dkv" : 2176, -# "block_kv_dkv_compute" : 2176, -# "block_q_dq" : 2176, -# "block_kv_dq" : 2176 +# "block_q" : 1536, +# "block_kv_compute" : 1536, +# "block_kv" : 1536, +# "block_q_dkv" : 1536, +# "block_kv_dkv" : 1536, +# "block_kv_dkv_compute" : 1536, +# "block_q_dq" : 1536, +# "block_kv_dq" : 1536 # } # GroupNorm groups norm_num_groups: 32 diff --git a/src/maxdiffusion/generate_flux.py b/src/maxdiffusion/generate_flux.py index fb21384a6..fc9184575 100644 --- a/src/maxdiffusion/generate_flux.py +++ b/src/maxdiffusion/generate_flux.py @@ -16,6 +16,7 @@ from typing import Callable, List, Union, Sequence from absl import app +from contextlib import ExitStack import functools import math import time @@ -24,6 +25,7 @@ import jax from jax.sharding import Mesh, PositionalSharding, PartitionSpec as P import jax.numpy as jnp +import flax.linen as nn from chex import Array from einops import rearrange from flax.linen import partitioning as nn_partitioning @@ -39,6 +41,28 @@ get_precision, setup_initial_state, ) +from maxdiffusion.loaders.flux_lora_pipeline import FluxLoraLoaderMixin + +def maybe_load_flux_lora(config, lora_loader, params): + def _noop_interceptor(next_fn, args, kwargs, context): + return next_fn(*args, **kwargs) + + lora_config = config.lora_config + interceptors= [_noop_interceptor] + if len(lora_config["lora_model_name_or_path"]) > 0: + interceptors = [] + for i in range(len(lora_config["lora_model_name_or_path"])): + params, rank, network_alphas = lora_loader.load_lora_weights( + config, + lora_config["lora_model_name_or_path"][i], + weight_name=lora_config["weight_name"][i], + params=params, + adapter_name=lora_config["adapter_name"][i], + ) + interceptor = lora_loader.make_lora_interceptor(params, rank, network_alphas, lora_config["adapter_name"][i]) + interceptors.append(interceptor) + + return params, interceptors def unpack(x: Array, height: int, width: int) -> Array: @@ -400,21 +424,29 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep # loads pretrained weights transformer_params = load_flow_model(config.flux_name, transformer_eval_params, "cpu") + params = {} + params["transformer"] = transformer_params + # maybe load lora and create interceptor + lora_loader = FluxLoraLoaderMixin() + params, lora_interceptors = maybe_load_flux_lora(config, lora_loader, params) + transformer_params = params["transformer"] # create transformer state weights_init_fn = functools.partial( transformer.init_weights, rngs=rng, max_sequence_length=config.max_sequence_length, eval_only=False ) - transformer_state, transformer_state_shardings = setup_initial_state( - model=transformer, - tx=None, - config=config, - mesh=mesh, - weights_init_fn=weights_init_fn, - model_params=None, - training=False, - ) - transformer_state = transformer_state.replace(params=transformer_params) - transformer_state = jax.device_put(transformer_state, transformer_state_shardings) + with ExitStack() as stack: + _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors] + transformer_state, transformer_state_shardings = setup_initial_state( + model=transformer, + tx=None, + config=config, + mesh=mesh, + weights_init_fn=weights_init_fn, + model_params=None, + training=False, + ) + transformer_state = transformer_state.replace(params=transformer_params) + transformer_state = jax.device_put(transformer_state, 
transformer_state_shardings) get_memory_allocations() states = {} @@ -444,17 +476,23 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep out_shardings=None, ) t0 = time.perf_counter() - p_run_inference(states).block_until_ready() + with ExitStack() as stack: + _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors] + p_run_inference(states).block_until_ready() t1 = time.perf_counter() max_logging.log(f"Compile time: {t1 - t0:.1f}s.") t0 = time.perf_counter() - imgs = p_run_inference(states).block_until_ready() + with ExitStack() as stack, jax.profiler.trace("/home/jfacevedo/trace/"): + _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors] + imgs = p_run_inference(states).block_until_ready() t1 = time.perf_counter() max_logging.log(f"Inference time: {t1 - t0:.1f}s.") t0 = time.perf_counter() - imgs = p_run_inference(states).block_until_ready() + with ExitStack() as stack: + _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors] + imgs = p_run_inference(states).block_until_ready() imgs = jax.experimental.multihost_utils.process_allgather(imgs, tiled=True) t1 = time.perf_counter() max_logging.log(f"Inference time: {t1 - t0:.1f}s.") diff --git a/src/maxdiffusion/loaders/__init__.py b/src/maxdiffusion/loaders/__init__.py index 2c9e973d1..172738681 100644 --- a/src/maxdiffusion/loaders/__init__.py +++ b/src/maxdiffusion/loaders/__init__.py @@ -13,4 +13,4 @@ # limitations under the License. from .lora_pipeline import StableDiffusionLoraLoaderMixin -from .flux_lora_pipeline import FluxLoraLoaderMixin +from .flux_lora_pipeline import FluxLoraLoaderMixin \ No newline at end of file diff --git a/src/maxdiffusion/loaders/flux_lora_pipeline.py b/src/maxdiffusion/loaders/flux_lora_pipeline.py index 5f449ee9a..36fc64b3c 100644 --- a/src/maxdiffusion/loaders/flux_lora_pipeline.py +++ b/src/maxdiffusion/loaders/flux_lora_pipeline.py @@ -16,30 +16,30 @@ from .lora_base import LoRABaseMixin from ..models.lora import LoRALinearLayer, BaseLoRALayer import jax.numpy as jnp -from flax.traverse_util import flatten_dict +from flax.traverse_util import flatten_dict, unflatten_dict +from flax.core.frozen_dict import unfreeze from ..models.modeling_flax_pytorch_utils import convert_flux_lora_pytorch_state_dict_to_flax from huggingface_hub.utils import validate_hf_hub_args - - +from maxdiffusion.models.modeling_flax_pytorch_utils import (rename_key, rename_key_and_reshape_tensor) class FluxLoraLoaderMixin(LoRABaseMixin): _lora_lodable_modules = ["transformer", "text_encoder"] - + def load_lora_weights( self, config, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, jnp.ndarray]], params, adapter_name=None, - **kwargs, + **kwargs ): state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) params, rank, network_alphas = self.load_lora( - config, - state_dict, - params=params, - adapter_name=adapter_name, + config, + state_dict, + params=params, + adapter_name=adapter_name, ) return params, rank, network_alphas @@ -53,7 +53,7 @@ def rename_for_interceptor(params_keys, network_alphas, adapter_name): new_layer_lora = layer_lora[: layer_lora.index(lora_name)] if new_layer_lora not in new_params_keys: new_params_keys.append(new_layer_lora) - network_alpha = network_alphas.get(layer_lora, None) + network_alpha = network_alphas[layer_lora] new_network_alphas[new_layer_lora] = network_alpha return new_params_keys, new_network_alphas 
@@ -64,7 +64,7 @@ def make_lora_interceptor(cls, params, rank, network_alphas, adapter_name): transformer_keys = flatten_dict(params["transformer"]).keys() lora_keys, transformer_alphas = cls.rename_for_interceptor(transformer_keys, network_alphas, adapter_name) network_alphas_for_interceptor.update(transformer_alphas) - + def _intercept(next_fn, args, kwargs, context): mod = context.module while mod is not None: @@ -107,6 +107,7 @@ def lora_state_dict(cls, pretrained_model_name_or_path: str, **kwargs): revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) + unet_config = kwargs.pop("unet_config", None) use_safetensors = kwargs.pop("use_safetensors", None) resume_download = kwargs.pop("resume_download", False) @@ -137,8 +138,8 @@ def lora_state_dict(cls, pretrained_model_name_or_path: str, **kwargs): ) return state_dict - + @classmethod def load_lora(cls, config, state_dict, params, adapter_name=None): params, rank, network_alphas = convert_flux_lora_pytorch_state_dict_to_flax(config, state_dict, params, adapter_name) - return params, rank, network_alphas + return params, rank, network_alphas \ No newline at end of file diff --git a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py index 9552c69f1..c13c94f74 100644 --- a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py +++ b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py @@ -222,6 +222,51 @@ def create_flax_params_from_pytorch_state( renamed_network_alphas[tuple(flax_key_list)] = network_alpha_value return unet_state_dict, text_encoder_state_dict, text_encoder_2_state_dict, rank, renamed_network_alphas +def convert_flux_lora_pytorch_state_dict_to_flax(config, pt_state_dict, params, adapter_name): + pt_state_dict = {k: v.float().numpy() for k, v in pt_state_dict.items()} + transformer_params = flatten_dict(unfreeze(params["transformer"])) + network_alphas = {} + rank = None + for pt_key, tensor in pt_state_dict.items(): + renamed_pt_key = rename_key(pt_key) + print("renamed_pt_key:", renamed_pt_key) + renamed_pt_key = renamed_pt_key.replace("lora_unet_", "") + renamed_pt_key = renamed_pt_key.replace("lora_down", f"lora-{adapter_name}.down") + renamed_pt_key = renamed_pt_key.replace("lora_up", f"lora-{adapter_name}.up") + + if "double_blocks" in renamed_pt_key: + renamed_pt_key = renamed_pt_key.replace("_img_attn_proj", ".attn.i_proj") + renamed_pt_key = renamed_pt_key.replace("_img_attn_qkv", ".attn.i_qkv") + renamed_pt_key = renamed_pt_key.replace("_img_mlp_0", ".img_mlp.layers_0") + renamed_pt_key = renamed_pt_key.replace("_img_mlp_2", ".img_mlp.layers_2") + renamed_pt_key = renamed_pt_key.replace("_img_mod_lin", ".img_norm1.lin") + renamed_pt_key = renamed_pt_key.replace("_txt_attn_proj", ".attn.e_proj") + renamed_pt_key = renamed_pt_key.replace("_txt_attn_qkv", ".attn.e_qkv") + renamed_pt_key = renamed_pt_key.replace("_txt_mlp_0", ".txt_mlp.layers_0") + renamed_pt_key = renamed_pt_key.replace("_txt_mlp_2", ".txt_mlp.layers_2") + renamed_pt_key = renamed_pt_key.replace("_txt_mod_lin", ".txt_norm1.lin") + elif "single_blocks" in renamed_pt_key: + renamed_pt_key = renamed_pt_key.replace("_linear1", ".linear1") + renamed_pt_key = renamed_pt_key.replace("_linear2", ".linear2") + renamed_pt_key = renamed_pt_key.replace("_modulation_lin", ".norm.lin") + + renamed_pt_key = renamed_pt_key.replace("weight", "kernel") + + pt_tuple_key = tuple(renamed_pt_key.split(".")) + if "alpha" in pt_tuple_key: + 
pt_tuple_key = pt_tuple_key[:-1] + (f"lora-{adapter_name}", 'down', 'kernel') + network_alphas[tuple([*pt_tuple_key])] = tensor.item() + pt_tuple_key = pt_tuple_key[:-1] + (f"lora-{adapter_name}", 'up', 'kernel') + network_alphas[tuple([*pt_tuple_key])] = tensor.item() + else: + if pt_tuple_key[-2] == "up": + rank = tensor.shape[1] + transformer_params[tuple([*pt_tuple_key])] = jnp.asarray(tensor.T, dtype=config.weights_dtype) + + params["transformer"] = unflatten_dict(transformer_params) + + return params, rank, network_alphas + def convert_flux_lora_pytorch_state_dict_to_flax(config, pt_state_dict, params, adapter_name): pt_state_dict = {k: v.float().numpy() for k, v in pt_state_dict.items()} From 719e6dbd1cf80b1b8836c3a6fe3011cbbeec5a3a Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Wed, 12 Feb 2025 21:38:47 +0000 Subject: [PATCH 28/35] adding another format lora support. --- src/maxdiffusion/configs/base_flux_dev.yml | 2 +- src/maxdiffusion/generate_flux.py | 52 +++++++++---------- .../loaders/flux_lora_pipeline.py | 2 +- .../transformers/transformer_flux_flax.py | 12 ++--- .../models/modeling_flax_pytorch_utils.py | 11 +++- 5 files changed, 43 insertions(+), 36 deletions(-) diff --git a/src/maxdiffusion/configs/base_flux_dev.yml b/src/maxdiffusion/configs/base_flux_dev.yml index 67b7307f9..3077e3b56 100644 --- a/src/maxdiffusion/configs/base_flux_dev.yml +++ b/src/maxdiffusion/configs/base_flux_dev.yml @@ -30,7 +30,7 @@ t5xxl_model_name_or_path: 'ariG23498/t5-v1-1-xxl-flax' # Flux params flux_name: "flux-dev" max_sequence_length: 512 -time_shift: False +time_shift: True base_shift: 0.5 max_shift: 1.15 # offloads t5 encoder after text encoding to save memory. diff --git a/src/maxdiffusion/generate_flux.py b/src/maxdiffusion/generate_flux.py index fc9184575..bd358a122 100644 --- a/src/maxdiffusion/generate_flux.py +++ b/src/maxdiffusion/generate_flux.py @@ -77,7 +77,7 @@ def unpack(x: Array, height: int, width: int) -> Array: def vae_decode(latents, vae, state, config): - img = unpack(x=latents, height=config.resolution, width=config.resolution) + img = unpack(x=latents.astype(jnp.float32), height=config.resolution, width=config.resolution) img = img / vae.config.scaling_factor + vae.config.shift_factor img = vae.apply({"params": state.params}, img, deterministic=True, method=vae.decode).sample return img @@ -115,13 +115,12 @@ def loop_body( def prepare_latent_image_ids(height, width): latent_image_ids = jnp.zeros((height, width, 3)) - latent_image_ids = latent_image_ids.at[..., 1].set(latent_image_ids[..., 1] + jnp.arange(height)[:, None]) - latent_image_ids = latent_image_ids.at[..., 2].set(latent_image_ids[..., 2] + jnp.arange(width)[None, :]) + latent_image_ids = latent_image_ids.at[..., 1].set(jnp.arange(height)[:, None]) + latent_image_ids = latent_image_ids.at[..., 2].set(jnp.arange(width)[None, :]) latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape latent_image_ids = latent_image_ids.reshape(latent_image_id_height * latent_image_id_width, latent_image_id_channels) - return latent_image_ids.astype(jnp.bfloat16) @@ -147,20 +146,10 @@ def run_inference( txt_ids, vec, guidance_vec, + c_ts, + p_ts ): - timesteps = jnp.linspace(1, 0, config.num_inference_steps + 1) - # shifting the schedule to favor high timesteps for higher signal images - if config.time_shift: - # estimate mu based on linear estimation between two points - lin_function = get_lin_function(y1=config.base_shift, y2=config.max_shift) - mu = lin_function(latents.shape[1]) 
- timesteps = time_shift(mu, 1.0, timesteps).tolist() - c_ts = timesteps[:-1] - p_ts = timesteps[1:] - # jax.debug.print("c_ts: {x}", x=c_ts) - # jax.debug.print("p_ts: {x}", x=p_ts) - transformer_state = states["transformer"] vae_state = states["vae"] @@ -173,11 +162,10 @@ def run_inference( vec=vec, guidance_vec=guidance_vec, ) - vae_decode_p = functools.partial(vae_decode, vae=vae, state=vae_state, config=config) with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): - latents, _, _, _ = jax.lax.fori_loop(0, len(timesteps) - 1, loop_body_p, (latents, transformer_state, c_ts, p_ts)) + latents, _, _, _ = jax.lax.fori_loop(0, len(c_ts), loop_body_p, (latents, transformer_state, c_ts, p_ts)) image = vae_decode_p(latents) return image @@ -236,8 +224,7 @@ def get_clip_prompt_embeds( prompt_embeds = text_encoder(text_input_ids, params=text_encoder.params, train=False) prompt_embeds = prompt_embeds.pooler_output - prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=-1) - prompt_embeds = np.reshape(prompt_embeds, (batch_size * num_images_per_prompt, -1)) + prompt_embeds = jnp.tile(prompt_embeds, (batch_size * num_images_per_prompt, 1)) return prompt_embeds @@ -300,7 +287,7 @@ def encode_prompt( max_sequence_length=max_sequence_length, ) - text_ids = jnp.zeros((prompt_embeds.shape[0], prompt_embeds.shape[1], 3)).astype(jnp.bfloat16) + text_ids = jnp.zeros((prompt_embeds.shape[1], 3)).astype(jnp.bfloat16) return prompt_embeds, pooled_prompt_embeds, text_ids @@ -397,18 +384,14 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep print("guidance.shape: ", guidance.shape, guidance.dtype) print("pooled_prompt_embeds.shape: ", pooled_prompt_embeds.shape, pooled_prompt_embeds.dtype) - timesteps = jnp.asarray([1.0] * global_batch_size, dtype=jnp.bfloat16) guidance = jnp.asarray([config.guidance_scale] * global_batch_size, dtype=jnp.bfloat16) - validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timesteps, guidance, pooled_prompt_embeds) - # move inputs to device and shard data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding)) latents = jax.device_put(latents, data_sharding) - latent_image_ids = jax.device_put(latent_image_ids, data_sharding) + latent_image_ids = jax.device_put(latent_image_ids) prompt_embeds = jax.device_put(prompt_embeds, data_sharding) - text_ids = jax.device_put(text_ids, data_sharding) - timesteps = jax.device_put(timesteps, data_sharding) + text_ids = jax.device_put(text_ids) guidance = jax.device_put(guidance, data_sharding) pooled_prompt_embeds = jax.device_put(pooled_prompt_embeds, data_sharding) @@ -458,6 +441,19 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep states["transformer"] = transformer_state states["vae"] = vae_state + # Setup timesteps + timesteps = jnp.linspace(1, 0, config.num_inference_steps + 1) + # shifting the schedule to favor high timesteps for higher signal images + if config.time_shift: + # estimate mu based on linear estimation between two points + lin_function = get_lin_function(x1=config.max_sequence_length, y1=config.base_shift, y2=config.max_shift) + mu = lin_function(latents.shape[1]) + timesteps = time_shift(mu, 1.0, timesteps) + c_ts = timesteps[:-1] + p_ts = timesteps[1:] + + validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timesteps, guidance, pooled_prompt_embeds) + p_run_inference = jax.jit( functools.partial( run_inference, @@ -471,6 +467,8 @@ def validate_inputs(latents, latent_image_ids, 
prompt_embeds, text_ids, timestep txt_ids=text_ids, vec=pooled_prompt_embeds, guidance_vec=guidance, + c_ts=c_ts, + p_ts=p_ts ), in_shardings=(state_shardings,), out_shardings=None, diff --git a/src/maxdiffusion/loaders/flux_lora_pipeline.py b/src/maxdiffusion/loaders/flux_lora_pipeline.py index 36fc64b3c..6de0fa3f0 100644 --- a/src/maxdiffusion/loaders/flux_lora_pipeline.py +++ b/src/maxdiffusion/loaders/flux_lora_pipeline.py @@ -53,7 +53,7 @@ def rename_for_interceptor(params_keys, network_alphas, adapter_name): new_layer_lora = layer_lora[: layer_lora.index(lora_name)] if new_layer_lora not in new_params_keys: new_params_keys.append(new_layer_lora) - network_alpha = network_alphas[layer_lora] + network_alpha = network_alphas.get(layer_lora, None) new_network_alphas[new_layer_lora] = network_alpha return new_params_keys, new_network_alphas diff --git a/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py b/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py index bff07988a..40c95de48 100644 --- a/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py +++ b/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py @@ -144,10 +144,10 @@ def __call__(self, hidden_states, temb, image_rotary_emb=None): hidden_states = self.linear2(attn_mlp) hidden_states = gate * hidden_states hidden_states = residual + hidden_states - if hidden_states.dtype == jnp.float16 or hidden_states.dtype == jnp.bfloat16: + if hidden_states.dtype == jnp.float16: hidden_states = jnp.clip(hidden_states, -65504, 65504) - return hidden_states, temb, image_rotary_emb + return hidden_states class FluxTransformerBlock(nn.Module): @@ -294,9 +294,9 @@ def __call__(self, hidden_states, encoder_hidden_states, temb, image_rotary_emb= context_ff_output = self.txt_mlp(norm_encoder_hidden_states) encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output - if encoder_hidden_states.dtype == jnp.float16 or encoder_hidden_states.dtype == jnp.bfloat16: + if encoder_hidden_states.dtype == jnp.float16: encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) - return hidden_states, encoder_hidden_states, temb, image_rotary_emb + return hidden_states, encoder_hidden_states @flax_register_to_config @@ -504,7 +504,7 @@ def __call__( image_rotary_emb = nn.with_logical_constraint(image_rotary_emb, ("activation_batch", "activation_embed")) for double_block in self.double_blocks: - hidden_states, encoder_hidden_states, temb, image_rotary_emb = double_block( + hidden_states, encoder_hidden_states = double_block( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb, @@ -513,7 +513,7 @@ def __call__( hidden_states = jnp.concatenate([encoder_hidden_states, hidden_states], axis=1) hidden_states = nn.with_logical_constraint(hidden_states, ("activation_batch", "activation_length", "activation_embed")) for single_block in self.single_blocks: - hidden_states, temb, image_rotary_emb = single_block( + hidden_states = single_block( hidden_states=hidden_states, temb=temb, image_rotary_emb=image_rotary_emb ) hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...] 
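Patch 28 above moves the shifted timestep schedule out of the jitted `run_inference` and passes it in precomputed as `c_ts`/`p_ts`. For readers following that change, here is a minimal sketch of what the precomputation does, assuming `time_shift` uses the standard Flux exponential shift formula; the `seq_len` value is illustrative, and `y1=0.5`/`y2=1.15` mirror the `base_shift`/`max_shift` config values shown in this series:

```python
# Minimal sketch of the shifted timestep schedule (assumes the standard
# Flux shift formula; names mirror get_lin_function/time_shift in
# generate_flux.py, and seq_len is an illustrative value).
import math
import jax.numpy as jnp


def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15):
  # Line through (x1, y1) and (x2, y2): longer sequences get a larger mu.
  m = (y2 - y1) / (x2 - x1)
  b = y1 - m * x1
  return lambda x: m * x + b


def time_shift(mu: float, sigma: float, t: jnp.ndarray) -> jnp.ndarray:
  # Pushes the schedule toward t ~ 1, i.e. spends more steps at high noise.
  return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)


num_inference_steps = 28  # flux-dev default used in the README commands
seq_len = 4096            # illustrative packed-latent sequence length
timesteps = jnp.linspace(1, 0, num_inference_steps + 1)
mu = get_lin_function(y1=0.5, y2=1.15)(seq_len)
timesteps = time_shift(mu, 1.0, timesteps)
c_ts, p_ts = timesteps[:-1], timesteps[1:]  # per-step (current, previous) pairs
```

The loop body then integrates from `c_ts[i]` to `p_ts[i]`, which is why the `fori_loop` bound changes from `len(timesteps) - 1` to `len(c_ts)`.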
diff --git a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py index c13c94f74..7ade84df9 100644 --- a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py +++ b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py @@ -229,12 +229,21 @@ def convert_flux_lora_pytorch_state_dict_to_flax(config, pt_state_dict, params, rank = None for pt_key, tensor in pt_state_dict.items(): renamed_pt_key = rename_key(pt_key) - print("renamed_pt_key:", renamed_pt_key) renamed_pt_key = renamed_pt_key.replace("lora_unet_", "") renamed_pt_key = renamed_pt_key.replace("lora_down", f"lora-{adapter_name}.down") renamed_pt_key = renamed_pt_key.replace("lora_up", f"lora-{adapter_name}.up") if "double_blocks" in renamed_pt_key: + renamed_pt_key = renamed_pt_key.replace("double_blocks.", "double_blocks_") + renamed_pt_key = renamed_pt_key.replace("processor.proj_lora1.down", f"attn.i_proj.lora-{adapter_name}.down") + renamed_pt_key = renamed_pt_key.replace("processor.proj_lora1.up", f"attn.i_proj.lora-{adapter_name}.up") + renamed_pt_key = renamed_pt_key.replace("processor.proj_lora2.down", f"attn.e_proj.lora-{adapter_name}.down") + renamed_pt_key = renamed_pt_key.replace("processor.proj_lora2.up", f"attn.e_proj.lora-{adapter_name}.up") + renamed_pt_key = renamed_pt_key.replace("processor.qkv_lora1.down", f"attn.i_qkv.lora-{adapter_name}.down") + renamed_pt_key = renamed_pt_key.replace("processor.qkv_lora1.up", f"attn.i_qkv.lora-{adapter_name}.up") + renamed_pt_key = renamed_pt_key.replace("processor.qkv_lora2.down", f"attn.e_qkv.lora-{adapter_name}.down") + renamed_pt_key = renamed_pt_key.replace("processor.qkv_lora2.up", f"attn.e_qkv.lora-{adapter_name}.up") + renamed_pt_key = renamed_pt_key.replace("_img_attn_proj", ".attn.i_proj") renamed_pt_key = renamed_pt_key.replace("_img_attn_qkv", ".attn.i_qkv") renamed_pt_key = renamed_pt_key.replace("_img_mlp_0", ".img_mlp.layers_0") From 91d7f5c38dcaa493a93e56e6b558daa53aaa90c0 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Thu, 13 Feb 2025 01:26:10 +0000 Subject: [PATCH 29/35] Support other format loras. update readme. Run code_style. --- README.md | 24 +++++++++++++++++-- src/maxdiffusion/generate_flux.py | 16 +++++++------ src/maxdiffusion/loaders/__init__.py | 2 +- .../loaders/flux_lora_pipeline.py | 20 +++++++++------- .../models/modeling_flax_pytorch_utils.py | 13 +++++----- .../tests/generate_flux_smoke_test.py | 1 + 6 files changed, 51 insertions(+), 25 deletions(-) diff --git a/README.md b/README.md index cdf451050..f9be0f3fc 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,8 @@ [![Unit Tests](https://github.com/google/maxtext/actions/workflows/UnitTests.yml/badge.svg)](https://github.com/google/maxdiffusion/actions/workflows/UnitTests.yml) # What's new? -- **`2025/02/08**: Flux schnell & dev inference. +- **`2025/02/12`**: Flux LoRA for inference. +- **`2025/02/08`**: Flux schnell & dev inference. - **`2024/12/12`**: Load multiple LoRAs for inference. - **`2024/10/22`**: LoRA support for Hyper SDXL. - **`2024/8/1`**: Orbax is the new default checkpointer. You can still use `pipeline.save_pretrained` after training to save in diffusers format. 
@@ -47,7 +48,8 @@ MaxDiffusion supports * [Training](#training) * [Dreambooth](#dreambooth) * [Inference](#inference) - * [Flux](#flux) + * [Flux](#flux) + * [Flux LoRA](#flux-lora) * [Hyper-SD XL LoRA](#hyper-sdxl-lora) * [Load Multiple LoRA](#load-multiple-lora) * [SDXL Lightning](#sdxl-lightning) @@ -169,6 +171,24 @@ To generate images, run the following command: python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_schnell.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt="photograph of an electronics chip in the shape of a race car with trillium written on its side" per_device_batch_size=1 ici_data_parallelism=1 ici_fsdp_parallelism=-1 offload_encoders=False ``` + ## Flux LoRA + + Disclaimer: not all LoRA formats have been tested. If there is a specific LoRA that doesn't load, please let us know. + + Tested with [Amateur Photography](https://civitai.com/models/652699/amateur-photography-flux-dev) and [XLabs-AI](https://huggingface.co/XLabs-AI/flux-lora-collection/tree/main) LoRA collection. + + First download the LoRA file to a local directory, for example, `/home/jfacevedo/anime_lora.safetensors`. Then run as follows: + + ```bash + python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_dev.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt='A cute corgi lives in a house made out of sushi, anime' num_inference_steps=28 ici_data_parallelism=1 ici_fsdp_parallelism=-1 split_head_dim=True lora_config='{"lora_model_name_or_path" : ["/home/jfacevedo/anime_lora.safetensors"], "weight_name" : ["anime_lora.safetensors"], "adapter_name" : ["anime"], "scale": [0.8], "from_pt": ["true"]}' + ``` + + Loading multiple LoRAs is supported as follows: + + ```bash + python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_dev.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt='A cute corgi lives in a house made out of sushi, anime' num_inference_steps=28 ici_data_parallelism=1 ici_fsdp_parallelism=-1 split_head_dim=True lora_config='{"lora_model_name_or_path" : ["/home/jfacevedo/anime_lora.safetensors", "/home/jfacevedo/amateurphoto-v6-forcu.safetensors"], "weight_name" : ["anime_lora.safetensors","amateurphoto-v6-forcu.safetensors"], "adapter_name" : ["anime","realistic"], "scale": [0.6, 0.6], "from_pt": ["true","true"]}' + ``` + ## Hyper SDXL LoRA Supports Hyper-SDXL models from [ByteDance](https://huggingface.co/ByteDance/Hyper-SD) diff --git a/src/maxdiffusion/generate_flux.py b/src/maxdiffusion/generate_flux.py index bd358a122..d2604d862 100644 --- a/src/maxdiffusion/generate_flux.py +++ b/src/maxdiffusion/generate_flux.py @@ -43,25 +43,25 @@ ) from maxdiffusion.loaders.flux_lora_pipeline import FluxLoraLoaderMixin + def maybe_load_flux_lora(config, lora_loader, params): def _noop_interceptor(next_fn, args, kwargs, context): return next_fn(*args, **kwargs) lora_config = config.lora_config - interceptors= [_noop_interceptor] + interceptors = [_noop_interceptor] if len(lora_config["lora_model_name_or_path"]) > 0: interceptors = [] for i in range(len(lora_config["lora_model_name_or_path"])): params, rank, network_alphas = lora_loader.load_lora_weights( - config, - lora_config["lora_model_name_or_path"][i], - weight_name=lora_config["weight_name"][i], - params=params, - adapter_name=lora_config["adapter_name"][i], + config, + lora_config["lora_model_name_or_path"][i], + weight_name=lora_config["weight_name"][i], + params=params, + 
adapter_name=lora_config["adapter_name"][i], ) interceptor = lora_loader.make_lora_interceptor(params, rank, network_alphas, lora_config["adapter_name"][i]) interceptors.append(interceptor) - return params, interceptors @@ -501,6 +501,8 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep for i, image in enumerate(imgs): Image.fromarray(image).save(f"flux_{i}.png") + return imgs + def main(argv: Sequence[str]) -> None: pyconfig.initialize(argv) diff --git a/src/maxdiffusion/loaders/__init__.py b/src/maxdiffusion/loaders/__init__.py index 172738681..2c9e973d1 100644 --- a/src/maxdiffusion/loaders/__init__.py +++ b/src/maxdiffusion/loaders/__init__.py @@ -13,4 +13,4 @@ # limitations under the License. from .lora_pipeline import StableDiffusionLoraLoaderMixin -from .flux_lora_pipeline import FluxLoraLoaderMixin \ No newline at end of file +from .flux_lora_pipeline import FluxLoraLoaderMixin diff --git a/src/maxdiffusion/loaders/flux_lora_pipeline.py b/src/maxdiffusion/loaders/flux_lora_pipeline.py index 6de0fa3f0..14786c9f5 100644 --- a/src/maxdiffusion/loaders/flux_lora_pipeline.py +++ b/src/maxdiffusion/loaders/flux_lora_pipeline.py @@ -21,25 +21,27 @@ from ..models.modeling_flax_pytorch_utils import convert_flux_lora_pytorch_state_dict_to_flax from huggingface_hub.utils import validate_hf_hub_args from maxdiffusion.models.modeling_flax_pytorch_utils import (rename_key, rename_key_and_reshape_tensor) + + class FluxLoraLoaderMixin(LoRABaseMixin): _lora_lodable_modules = ["transformer", "text_encoder"] - + def load_lora_weights( self, config, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, jnp.ndarray]], params, adapter_name=None, - **kwargs + **kwargs, ): state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) params, rank, network_alphas = self.load_lora( - config, - state_dict, - params=params, - adapter_name=adapter_name, + config, + state_dict, + params=params, + adapter_name=adapter_name, ) return params, rank, network_alphas @@ -64,7 +66,7 @@ def make_lora_interceptor(cls, params, rank, network_alphas, adapter_name): transformer_keys = flatten_dict(params["transformer"]).keys() lora_keys, transformer_alphas = cls.rename_for_interceptor(transformer_keys, network_alphas, adapter_name) network_alphas_for_interceptor.update(transformer_alphas) - + def _intercept(next_fn, args, kwargs, context): mod = context.module while mod is not None: @@ -138,8 +140,8 @@ def lora_state_dict(cls, pretrained_model_name_or_path: str, **kwargs): ) return state_dict - + @classmethod def load_lora(cls, config, state_dict, params, adapter_name=None): params, rank, network_alphas = convert_flux_lora_pytorch_state_dict_to_flax(config, state_dict, params, adapter_name) - return params, rank, network_alphas \ No newline at end of file + return params, rank, network_alphas diff --git a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py index 7ade84df9..c3a73370a 100644 --- a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py +++ b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py @@ -222,6 +222,7 @@ def create_flax_params_from_pytorch_state( renamed_network_alphas[tuple(flax_key_list)] = network_alpha_value return unet_state_dict, text_encoder_state_dict, text_encoder_2_state_dict, rank, renamed_network_alphas + def convert_flux_lora_pytorch_state_dict_to_flax(config, pt_state_dict, params, adapter_name): pt_state_dict = {k: v.float().numpy() for k, v in pt_state_dict.items()} 
transformer_params = flatten_dict(unfreeze(params["transformer"])) @@ -243,7 +244,7 @@ def convert_flux_lora_pytorch_state_dict_to_flax(config, pt_state_dict, params, renamed_pt_key = renamed_pt_key.replace("processor.qkv_lora1.up", f"attn.i_qkv.lora-{adapter_name}.up") renamed_pt_key = renamed_pt_key.replace("processor.qkv_lora2.down", f"attn.e_qkv.lora-{adapter_name}.down") renamed_pt_key = renamed_pt_key.replace("processor.qkv_lora2.up", f"attn.e_qkv.lora-{adapter_name}.up") - + renamed_pt_key = renamed_pt_key.replace("_img_attn_proj", ".attn.i_proj") renamed_pt_key = renamed_pt_key.replace("_img_attn_qkv", ".attn.i_qkv") renamed_pt_key = renamed_pt_key.replace("_img_mlp_0", ".img_mlp.layers_0") @@ -258,20 +259,20 @@ def convert_flux_lora_pytorch_state_dict_to_flax(config, pt_state_dict, params, renamed_pt_key = renamed_pt_key.replace("_linear1", ".linear1") renamed_pt_key = renamed_pt_key.replace("_linear2", ".linear2") renamed_pt_key = renamed_pt_key.replace("_modulation_lin", ".norm.lin") - + renamed_pt_key = renamed_pt_key.replace("weight", "kernel") - + pt_tuple_key = tuple(renamed_pt_key.split(".")) if "alpha" in pt_tuple_key: - pt_tuple_key = pt_tuple_key[:-1] + (f"lora-{adapter_name}", 'down', 'kernel') + pt_tuple_key = pt_tuple_key[:-1] + (f"lora-{adapter_name}", "down", "kernel") network_alphas[tuple([*pt_tuple_key])] = tensor.item() - pt_tuple_key = pt_tuple_key[:-1] + (f"lora-{adapter_name}", 'up', 'kernel') + pt_tuple_key = pt_tuple_key[:-1] + (f"lora-{adapter_name}", "up", "kernel") network_alphas[tuple([*pt_tuple_key])] = tensor.item() else: if pt_tuple_key[-2] == "up": rank = tensor.shape[1] transformer_params[tuple([*pt_tuple_key])] = jnp.asarray(tensor.T, dtype=config.weights_dtype) - + params["transformer"] = unflatten_dict(transformer_params) return params, rank, network_alphas diff --git a/src/maxdiffusion/tests/generate_flux_smoke_test.py b/src/maxdiffusion/tests/generate_flux_smoke_test.py index b8ee06f9c..d4ab287df 100644 --- a/src/maxdiffusion/tests/generate_flux_smoke_test.py +++ b/src/maxdiffusion/tests/generate_flux_smoke_test.py @@ -22,6 +22,7 @@ def download_blob(gcs_file, local_file): gcs_dir_arr = gcs_file.replace("gs://", "").split("/") storage_client = storage.Client() bucket = storage_client.get_bucket(gcs_dir_arr[0]) + blob_loc = "/".join(gcs_dir_arr[1:]) blob = bucket.blob("/".join(gcs_dir_arr[1:])) blob.download_to_filename(local_file) From 1e01c678a13c7b157a3b9ca72ce9c166d4567329 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Thu, 13 Feb 2025 01:32:43 +0000 Subject: [PATCH 30/35] ruff --- src/maxdiffusion/loaders/flux_lora_pipeline.py | 5 +---- src/maxdiffusion/models/modeling_flax_pytorch_utils.py | 6 +++--- src/maxdiffusion/tests/generate_flux_smoke_test.py | 1 - 3 files changed, 4 insertions(+), 8 deletions(-) diff --git a/src/maxdiffusion/loaders/flux_lora_pipeline.py b/src/maxdiffusion/loaders/flux_lora_pipeline.py index 14786c9f5..5f449ee9a 100644 --- a/src/maxdiffusion/loaders/flux_lora_pipeline.py +++ b/src/maxdiffusion/loaders/flux_lora_pipeline.py @@ -16,11 +16,9 @@ from .lora_base import LoRABaseMixin from ..models.lora import LoRALinearLayer, BaseLoRALayer import jax.numpy as jnp -from flax.traverse_util import flatten_dict, unflatten_dict -from flax.core.frozen_dict import unfreeze +from flax.traverse_util import flatten_dict from ..models.modeling_flax_pytorch_utils import convert_flux_lora_pytorch_state_dict_to_flax from huggingface_hub.utils import validate_hf_hub_args -from maxdiffusion.models.modeling_flax_pytorch_utils import 
(rename_key, rename_key_and_reshape_tensor) class FluxLoraLoaderMixin(LoRABaseMixin): @@ -109,7 +107,6 @@ def lora_state_dict(cls, pretrained_model_name_or_path: str, **kwargs): revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) - unet_config = kwargs.pop("unet_config", None) use_safetensors = kwargs.pop("use_safetensors", None) resume_download = kwargs.pop("resume_download", False) diff --git a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py index c3a73370a..336f01404 100644 --- a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py +++ b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py @@ -265,13 +265,13 @@ def convert_flux_lora_pytorch_state_dict_to_flax(config, pt_state_dict, params, pt_tuple_key = tuple(renamed_pt_key.split(".")) if "alpha" in pt_tuple_key: pt_tuple_key = pt_tuple_key[:-1] + (f"lora-{adapter_name}", "down", "kernel") - network_alphas[tuple([*pt_tuple_key])] = tensor.item() + network_alphas[tuple([*pt_tuple_key])] = tensor.item() # noqa: C409 pt_tuple_key = pt_tuple_key[:-1] + (f"lora-{adapter_name}", "up", "kernel") - network_alphas[tuple([*pt_tuple_key])] = tensor.item() + network_alphas[tuple([*pt_tuple_key])] = tensor.item() # noqa: C409 else: if pt_tuple_key[-2] == "up": rank = tensor.shape[1] - transformer_params[tuple([*pt_tuple_key])] = jnp.asarray(tensor.T, dtype=config.weights_dtype) + transformer_params[tuple([*pt_tuple_key])] = jnp.asarray(tensor.T, dtype=config.weights_dtype) # noqa: C409 params["transformer"] = unflatten_dict(transformer_params) diff --git a/src/maxdiffusion/tests/generate_flux_smoke_test.py b/src/maxdiffusion/tests/generate_flux_smoke_test.py index d4ab287df..b8ee06f9c 100644 --- a/src/maxdiffusion/tests/generate_flux_smoke_test.py +++ b/src/maxdiffusion/tests/generate_flux_smoke_test.py @@ -22,7 +22,6 @@ def download_blob(gcs_file, local_file): gcs_dir_arr = gcs_file.replace("gs://", "").split("/") storage_client = storage.Client() bucket = storage_client.get_bucket(gcs_dir_arr[0]) - blob_loc = "/".join(gcs_dir_arr[1:]) blob = bucket.blob("/".join(gcs_dir_arr[1:])) blob.download_to_filename(local_file) From 19e1b8a4c872ab9cf5e2265ff0953c0e45a083e5 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Wed, 5 Feb 2025 21:59:56 +0000 Subject: [PATCH 31/35] fix typo in readme. 
--- README.md | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index f9be0f3fc..ea7e1fe13 100644 --- a/README.md +++ b/README.md @@ -44,18 +44,22 @@ MaxDiffusion supports # Table of Contents -* [Getting Started](#getting-started) - * [Training](#training) - * [Dreambooth](#dreambooth) - * [Inference](#inference) - * [Flux](#flux) - * [Flux LoRA](#flux-lora) - * [Hyper-SD XL LoRA](#hyper-sdxl-lora) - * [Load Multiple LoRA](#load-multiple-lora) - * [SDXL Lightning](#sdxl-lightning) - * [ControlNet](#controlnet) -* [Comparison To Alternatives](#comparison-to-alternatives) -* [Development](#development) +- [What's new?](#whats-new) +- [Overview](#overview) +- [Table of Contents](#table-of-contents) +- [Getting Started](#getting-started) + - [Getting Started:](#getting-started-1) + - [Training](#training) + - [Dreambooth](#dreambooth) + - [Inference](#inference) + - [Flux](#flux) + - [Hyper SDXL LoRA](#hyper-sdxl-lora) + - [Load Multiple LoRA](#load-multiple-lora) + - [SDXL Lightning](#sdxl-lightning) + - [ControlNet](#controlnet) + - [Getting Started: Multihost development](#getting-started-multihost-development) +- [Comparison to Alternatives](#comparison-to-alternatives) +- [Development](#development) # Getting Started From f05a7bec84e9dab9d2909a160c6603e7b7ecd885 Mon Sep 17 00:00:00 2001 From: ksikiric Date: Thu, 13 Feb 2025 13:02:27 +0000 Subject: [PATCH 32/35] Added FA support for GPUs --- README.md | 4 ++ src/maxdiffusion/configs/base_flux_dev.yml | 4 +- .../configs/base_flux_schnell.yml | 16 ++++- src/maxdiffusion/models/attention_flax.py | 70 +++++++++++++++++-- 4 files changed, 84 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index ea7e1fe13..01c2faf73 100644 --- a/README.md +++ b/README.md @@ -53,6 +53,7 @@ MaxDiffusion supports - [Dreambooth](#dreambooth) - [Inference](#inference) - [Flux](#flux) + - [Flash Attention for GPU:](#flash-attention-for-gpu) - [Hyper SDXL LoRA](#hyper-sdxl-lora) - [Load Multiple LoRA](#load-multiple-lora) - [SDXL Lightning](#sdxl-lightning) @@ -175,6 +176,9 @@ To generate images, run the following command: python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_schnell.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt="photograph of an electronics chip in the shape of a race car with trillium written on its side" per_device_batch_size=1 ici_data_parallelism=1 ici_fsdp_parallelism=-1 offload_encoders=False ``` + ### Flash Attention for GPU: + Flash Attention for GPU is supported via TransformerEngine, make sure it is installed and then specify attention=cudnn_flash_te when running the above commands. + ## Flux LoRA Disclaimer: not all LoRA formats have been tested. If there is a specific LoRA that doesn't load, please let us know. diff --git a/src/maxdiffusion/configs/base_flux_dev.yml b/src/maxdiffusion/configs/base_flux_dev.yml index 3077e3b56..3f9c28a8a 100644 --- a/src/maxdiffusion/configs/base_flux_dev.yml +++ b/src/maxdiffusion/configs/base_flux_dev.yml @@ -54,7 +54,7 @@ precision: "DEFAULT" # Set true to load weights from pytorch from_pt: True split_head_dim: True -attention: 'flash' # Supported attention: dot_product, flash +attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te flash_block_sizes: {} # Use the following flash_block_sizes on v6e (Trillium) due to larger vmem. 
@@ -197,7 +197,7 @@ max_train_steps: 200 num_train_epochs: 1 seed: 0 output_dir: 'sdxl-model-finetuned' -per_device_batch_size: 1 +per_device_batch_size: 8 warmup_steps_fraction: 0.0 learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps. diff --git a/src/maxdiffusion/configs/base_flux_schnell.yml b/src/maxdiffusion/configs/base_flux_schnell.yml index 4c22edb73..6bf38e84d 100644 --- a/src/maxdiffusion/configs/base_flux_schnell.yml +++ b/src/maxdiffusion/configs/base_flux_schnell.yml @@ -53,9 +53,19 @@ precision: "DEFAULT" # Set true to load weights from pytorch from_pt: True split_head_dim: True -attention: 'flash' # Supported attention: dot_product, flash -flash_block_sizes: {} -# Use the following flash_block_sizes on v6e (Trillium) due to larger vmem. +attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te +flash_block_sizes: { + "block_q" : 256, + "block_kv_compute" : 256, + "block_kv" : 256, + "block_q_dkv" : 256, + "block_kv_dkv" : 256, + "block_kv_dkv_compute" : 256, + "block_q_dq" : 256, + "block_kv_dq" : 256 +} + +# Use the following flash_block_sizes on v6e (Trillium). # flash_block_sizes: { # "block_q" : 1536, # "block_kv_compute" : 1536, diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index 884b6d688..76bee3e57 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -52,6 +52,24 @@ class AttentionOp(nn.Module): flash_block_sizes: BlockSizes = None dtype: DType = jnp.float32 + def setup(self): + if self.attention_kernel == "cudnn_flash_te": + from transformer_engine.jax.flax.transformer import DotProductAttention # pytype: disable=import-error + self.dpa_layer = DotProductAttention( + head_dim=self.dim_head, + num_attention_heads=self.heads, + num_gqa_groups=self.heads, + attn_mask_type="no_mask", # 'no_mask', 'padding', 'causal', or 'padding_causal' + attn_bias_type="NO_BIAS", # 'no_bias', 'pre_scale_bias' or 'post_scale_bias' + # attention_dropout=self.dropout_rate, + dropout_rng_name="aqt", + dtype=self.dtype, + # float32_logits=self.float32_logits, + qkv_layout="BSHD_BSHD_BSHD", # 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD' + scale_factor=self.scale, + transpose_batch_sequence=False, + ) + def check_attention_inputs(self, query: Array, key: Array, value: Array) -> None: """Check attention inputs.""" @@ -64,16 +82,22 @@ def check_attention_inputs(self, query: Array, key: Array, value: Array) -> None def apply_attention(self, query: Array, key: Array, value: Array): """Routes to different attention kernels.""" self.check_attention_inputs(query, key, value) - can_use_flash_attention = ( - query.shape[1] >= self.flash_min_seq_length - and key.shape[1] >= self.flash_min_seq_length - and value.shape[1] >= self.flash_min_seq_length - ) + + if self.attention_kernel == "flash": + can_use_flash_attention = ( + query.shape[1] >= self.flash_min_seq_length + and key.shape[1] >= self.flash_min_seq_length + and value.shape[1] >= self.flash_min_seq_length + ) + else: + can_use_flash_attention = True if self.attention_kernel == "dot_product" or self.use_memory_efficient_attention or not can_use_flash_attention: return self.apply_attention_dot(query, key, value) elif self.attention_kernel == "flash": return self.tpu_flash_attention(query, key * self.scale, value) + elif self.attention_kernel == "cudnn_flash_te": + return self.cudnn_flash_attention(query, key, value) else: raise ValueError(f"Unexpected attention kernel 
{self.attention_kernel=}.") @@ -132,6 +156,32 @@ def wrap_flash_attention(query, key, value): return x + def cudnn_flash_attention( + self, + query: Array, + key: Array, + value: Array, + ) -> Array: + """CUDNN Flash Attention with Transformer Engine. + 1. Stable API, supports GQA + 2. Supports head_dim till 128; head_dim=256 support will be added soon + """ + # These imports are only meant to work in a GPU build. + # copied from tpu_flash_attention + query = self.reshape_data_for_cudnn_flash(query) + key = self.reshape_data_for_cudnn_flash(key) + value = self.reshape_data_for_cudnn_flash(value) + + cudnn_flash_axis_names = (BATCH, LENGTH, HEAD, D_KV) + axis_names = nn.logical_to_mesh_axes(cudnn_flash_axis_names) + + query = nn.with_logical_constraint(query, axis_names) + key = nn.with_logical_constraint(key, axis_names) + value = nn.with_logical_constraint(value, axis_names) + + out = self.dpa_layer(query, key, value, mask=None) + return self.reshape_data_from_cudnn_flash(out) + def apply_attention_dot(self, query: Array, key: Array, value: Array): """Apply Attention.""" if self.split_head_dim: @@ -209,6 +259,16 @@ def reshape_batch_dim_to_heads(self, tensor): tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size) return tensor + def reshape_data_for_cudnn_flash(self, tensor): + # reshapes from [b, s, h * d] to [b, s, h, d] (input format to flash format) + batch, seq, heads_and_dim_head = tensor.shape + tensor = tensor.reshape(batch, seq, self.heads, heads_and_dim_head // self.heads) + return tensor + + def reshape_data_from_cudnn_flash(self, tensor): + # reshapes from [b, s, h, d] back to [b, s, h * d] + return tensor.reshape(tensor.shape[0], tensor.shape[1], -1) + def reshape_data_for_flash(self, tensor): # reshapes from [b, s, h * d] to [b, h, s, d] (input format to flash format) batch, seq, heads_and_dim_head = tensor.shape From 3141d695dcf4e683aab04f26fb314cb3b927a576 Mon Sep 17 00:00:00 2001 From: ksikiric Date: Tue, 18 Feb 2025 09:52:18 +0100 Subject: [PATCH 33/35] ruff and code_style --- src/maxdiffusion/generate_flux.py | 16 +----- src/maxdiffusion/models/attention_flax.py | 25 +++++---- .../transformers/transformer_flux_flax.py | 4 +- .../models/modeling_flax_pytorch_utils.py | 55 ------------------- 4 files changed, 16 insertions(+), 84 deletions(-) diff --git a/src/maxdiffusion/generate_flux.py b/src/maxdiffusion/generate_flux.py index d2604d862..59564a271 100644 --- a/src/maxdiffusion/generate_flux.py +++ b/src/maxdiffusion/generate_flux.py @@ -135,19 +135,7 @@ def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: flo def run_inference( - states, - transformer, - vae, - config, - mesh, - latents, - latent_image_ids, - prompt_embeds, - txt_ids, - vec, - guidance_vec, - c_ts, - p_ts + states, transformer, vae, config, mesh, latents, latent_image_ids, prompt_embeds, txt_ids, vec, guidance_vec, c_ts, p_ts ): transformer_state = states["transformer"] @@ -468,7 +456,7 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep vec=pooled_prompt_embeds, guidance_vec=guidance, c_ts=c_ts, - p_ts=p_ts + p_ts=p_ts, ), in_shardings=(state_shardings,), out_shardings=None, diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index 76bee3e57..db8626984 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -55,19 +55,20 @@ class AttentionOp(nn.Module): def setup(self): if self.attention_kernel == "cudnn_flash_te": from 
transformer_engine.jax.flax.transformer import DotProductAttention # pytype: disable=import-error + self.dpa_layer = DotProductAttention( - head_dim=self.dim_head, - num_attention_heads=self.heads, - num_gqa_groups=self.heads, - attn_mask_type="no_mask", # 'no_mask', 'padding', 'causal', or 'padding_causal' - attn_bias_type="NO_BIAS", # 'no_bias', 'pre_scale_bias' or 'post_scale_bias' - # attention_dropout=self.dropout_rate, - dropout_rng_name="aqt", - dtype=self.dtype, - # float32_logits=self.float32_logits, - qkv_layout="BSHD_BSHD_BSHD", # 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD' - scale_factor=self.scale, - transpose_batch_sequence=False, + head_dim=self.dim_head, + num_attention_heads=self.heads, + num_gqa_groups=self.heads, + attn_mask_type="no_mask", # 'no_mask', 'padding', 'causal', or 'padding_causal' + attn_bias_type="NO_BIAS", # 'no_bias', 'pre_scale_bias' or 'post_scale_bias' + # attention_dropout=self.dropout_rate, + dropout_rng_name="aqt", + dtype=self.dtype, + # float32_logits=self.float32_logits, + qkv_layout="BSHD_BSHD_BSHD", # 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD' + scale_factor=self.scale, + transpose_batch_sequence=False, ) def check_attention_inputs(self, query: Array, key: Array, value: Array) -> None: diff --git a/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py b/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py index 40c95de48..96207468e 100644 --- a/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py +++ b/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py @@ -513,9 +513,7 @@ def __call__( hidden_states = jnp.concatenate([encoder_hidden_states, hidden_states], axis=1) hidden_states = nn.with_logical_constraint(hidden_states, ("activation_batch", "activation_length", "activation_embed")) for single_block in self.single_blocks: - hidden_states = single_block( - hidden_states=hidden_states, temb=temb, image_rotary_emb=image_rotary_emb - ) + hidden_states = single_block(hidden_states=hidden_states, temb=temb, image_rotary_emb=image_rotary_emb) hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...] 
hidden_states = self.norm_out(hidden_states, temb) diff --git a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py index 336f01404..9552c69f1 100644 --- a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py +++ b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py @@ -278,61 +278,6 @@ def convert_flux_lora_pytorch_state_dict_to_flax(config, pt_state_dict, params, return params, rank, network_alphas -def convert_flux_lora_pytorch_state_dict_to_flax(config, pt_state_dict, params, adapter_name): - pt_state_dict = {k: v.float().numpy() for k, v in pt_state_dict.items()} - transformer_params = flatten_dict(unfreeze(params["transformer"])) - network_alphas = {} - rank = None - for pt_key, tensor in pt_state_dict.items(): - renamed_pt_key = rename_key(pt_key) - renamed_pt_key = renamed_pt_key.replace("lora_unet_", "") - renamed_pt_key = renamed_pt_key.replace("lora_down", f"lora-{adapter_name}.down") - renamed_pt_key = renamed_pt_key.replace("lora_up", f"lora-{adapter_name}.up") - - if "double_blocks" in renamed_pt_key: - renamed_pt_key = renamed_pt_key.replace("double_blocks.", "double_blocks_") - renamed_pt_key = renamed_pt_key.replace("processor.proj_lora1.down", f"attn.i_proj.lora-{adapter_name}.down") - renamed_pt_key = renamed_pt_key.replace("processor.proj_lora1.up", f"attn.i_proj.lora-{adapter_name}.up") - renamed_pt_key = renamed_pt_key.replace("processor.proj_lora2.down", f"attn.e_proj.lora-{adapter_name}.down") - renamed_pt_key = renamed_pt_key.replace("processor.proj_lora2.up", f"attn.e_proj.lora-{adapter_name}.up") - renamed_pt_key = renamed_pt_key.replace("processor.qkv_lora1.down", f"attn.i_qkv.lora-{adapter_name}.down") - renamed_pt_key = renamed_pt_key.replace("processor.qkv_lora1.up", f"attn.i_qkv.lora-{adapter_name}.up") - renamed_pt_key = renamed_pt_key.replace("processor.qkv_lora2.down", f"attn.e_qkv.lora-{adapter_name}.down") - renamed_pt_key = renamed_pt_key.replace("processor.qkv_lora2.up", f"attn.e_qkv.lora-{adapter_name}.up") - - renamed_pt_key = renamed_pt_key.replace("_img_attn_proj", ".attn.i_proj") - renamed_pt_key = renamed_pt_key.replace("_img_attn_qkv", ".attn.i_qkv") - renamed_pt_key = renamed_pt_key.replace("_img_mlp_0", ".img_mlp.layers_0") - renamed_pt_key = renamed_pt_key.replace("_img_mlp_2", ".img_mlp.layers_2") - renamed_pt_key = renamed_pt_key.replace("_img_mod_lin", ".img_norm1.lin") - renamed_pt_key = renamed_pt_key.replace("_txt_attn_proj", ".attn.e_proj") - renamed_pt_key = renamed_pt_key.replace("_txt_attn_qkv", ".attn.e_qkv") - renamed_pt_key = renamed_pt_key.replace("_txt_mlp_0", ".txt_mlp.layers_0") - renamed_pt_key = renamed_pt_key.replace("_txt_mlp_2", ".txt_mlp.layers_2") - renamed_pt_key = renamed_pt_key.replace("_txt_mod_lin", ".txt_norm1.lin") - elif "single_blocks" in renamed_pt_key: - renamed_pt_key = renamed_pt_key.replace("_linear1", ".linear1") - renamed_pt_key = renamed_pt_key.replace("_linear2", ".linear2") - renamed_pt_key = renamed_pt_key.replace("_modulation_lin", ".norm.lin") - - renamed_pt_key = renamed_pt_key.replace("weight", "kernel") - - pt_tuple_key = tuple(renamed_pt_key.split(".")) - if "alpha" in pt_tuple_key: - pt_tuple_key = pt_tuple_key[:-1] + (f"lora-{adapter_name}", "down", "kernel") - network_alphas[tuple([*pt_tuple_key])] = tensor.item() # noqa: C409 - pt_tuple_key = pt_tuple_key[:-1] + (f"lora-{adapter_name}", "up", "kernel") - network_alphas[tuple([*pt_tuple_key])] = tensor.item() # noqa: C409 - else: - if pt_tuple_key[-2] == "up": - rank = 
tensor.shape[1]
-        transformer_params[tuple([*pt_tuple_key])] = jnp.asarray(tensor.T, dtype=config.weights_dtype)  # noqa: C409
-
-  params["transformer"] = unflatten_dict(transformer_params)
-
-  return params, rank, network_alphas
-
-
 def convert_lora_pytorch_state_dict_to_flax(pt_state_dict, params, network_alphas, adapter_name):
   # Step 1: Convert pytorch tensor to numpy
   # sometimes we load weights in bf16 and numpy doesn't support it

From 2072f7ebcf055e5b36d5f926142437f52d224273 Mon Sep 17 00:00:00 2001
From: ksikiric
Date: Thu, 20 Feb 2025 08:13:06 +0100
Subject: [PATCH 34/35] fixed final comments

---
 README.md | 19 +++++++++++++++++--
 src/maxdiffusion/configs/base_flux_dev.yml | 18 +++++++++---------
 .../configs/base_flux_schnell.yml | 16 ++++++++--------
 3 files changed, 34 insertions(+), 19 deletions(-)

diff --git a/README.md b/README.md
index 01c2faf73..09492f921 100644
--- a/README.md
+++ b/README.md
@@ -176,8 +176,23 @@ To generate images, run the following command:
 python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_schnell.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt="photograph of an electronics chip in the shape of a race car with trillium written on its side" per_device_batch_size=1 ici_data_parallelism=1 ici_fsdp_parallelism=-1 offload_encoders=False
 ```

- ### Flash Attention for GPU:
- Flash Attention for GPU is supported via TransformerEngine, make sure it is installed and then specify attention=cudnn_flash_te when running the above commands.
+ ## Flash Attention for GPU:
+ Flash Attention for GPU is supported via TransformerEngine. Installation instructions:
+
+ ```bash
+ cd maxdiffusion
+ pip install -U "jax[cuda12]"
+ pip install -r requirements.txt
+ pip install --upgrade torch torchvision
+ pip install "transformer_engine[jax]"
+ pip install .
+ ```
+
+ Now run the command:
+
+ ```bash
+ NVTE_FUSED_ATTN=1 HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_dev.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt='A cute corgi lives in a house made out of sushi, anime' num_inference_steps=28 split_head_dim=True per_device_batch_size=1 attention="cudnn_flash_te" hardware=gpu
+ ```

 ## Flux LoRA

diff --git a/src/maxdiffusion/configs/base_flux_dev.yml b/src/maxdiffusion/configs/base_flux_dev.yml
index 3f9c28a8a..7a917824a 100644
--- a/src/maxdiffusion/configs/base_flux_dev.yml
+++ b/src/maxdiffusion/configs/base_flux_dev.yml
@@ -59,14 +59,14 @@ attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te
 flash_block_sizes: {}
 # Use the following flash_block_sizes on v6e (Trillium) due to larger vmem.
 # flash_block_sizes: {
-# "block_q" : 1536,
-# "block_kv_compute" : 1536,
-# "block_kv" : 1536,
-# "block_q_dkv" : 1536,
-# "block_kv_dkv" : 1536,
-# "block_kv_dkv_compute" : 1536,
-# "block_q_dq" : 1536,
-# "block_kv_dq" : 1536
+# "block_q" : 2176,
+# "block_kv_compute" : 2176,
+# "block_kv" : 2176,
+# "block_q_dkv" : 2176,
+# "block_kv_dkv" : 2176,
+# "block_kv_dkv_compute" : 2176,
+# "block_q_dq" : 2176,
+# "block_kv_dq" : 2176
 # }
 # GroupNorm groups
 norm_num_groups: 32
@@ -197,7 +197,7 @@ max_train_steps: 200
 num_train_epochs: 1
 seed: 0
 output_dir: 'sdxl-model-finetuned'
-per_device_batch_size: 8
+per_device_batch_size: 1
 warmup_steps_fraction: 0.0
 learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps.
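Since this commit documents the `cudnn_flash_te` attention path, a small sketch of the layout round trip that path performs may help; it mirrors `reshape_data_for_cudnn_flash` and `reshape_data_from_cudnn_flash` from the attention patch above, with illustrative shapes:

```python
# TransformerEngine's DotProductAttention is configured with
# qkv_layout="BSHD_BSHD_BSHD", so the fused [b, s, h*d] activations are
# unpacked to [b, s, h, d] before the call and repacked afterwards.
import jax.numpy as jnp

batch, seq, heads, dim_head = 1, 4096, 24, 128  # illustrative sizes
x = jnp.zeros((batch, seq, heads * dim_head))

# reshape_data_for_cudnn_flash: [b, s, h * d] -> [b, s, h, d]
bshd = x.reshape(batch, seq, heads, x.shape[-1] // heads)

# reshape_data_from_cudnn_flash: [b, s, h, d] -> [b, s, h * d]
packed = bshd.reshape(bshd.shape[0], bshd.shape[1], -1)
assert packed.shape == x.shape
```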
diff --git a/src/maxdiffusion/configs/base_flux_schnell.yml b/src/maxdiffusion/configs/base_flux_schnell.yml
index 6bf38e84d..1b2bc28b3 100644
--- a/src/maxdiffusion/configs/base_flux_schnell.yml
+++ b/src/maxdiffusion/configs/base_flux_schnell.yml
@@ -67,14 +67,14 @@ flash_block_sizes: {

 # Use the following flash_block_sizes on v6e (Trillium).
 # flash_block_sizes: {
-# "block_q" : 1536,
-# "block_kv_compute" : 1536,
-# "block_kv" : 1536,
-# "block_q_dkv" : 1536,
-# "block_kv_dkv" : 1536,
-# "block_kv_dkv_compute" : 1536,
-# "block_q_dq" : 1536,
-# "block_kv_dq" : 1536
+# "block_q" : 2176,
+# "block_kv_compute" : 2176,
+# "block_kv" : 2176,
+# "block_q_dkv" : 2176,
+# "block_kv_dkv" : 2176,
+# "block_kv_dkv_compute" : 2176,
+# "block_q_dq" : 2176,
+# "block_kv_dq" : 2176
 # }
 # GroupNorm groups
 norm_num_groups: 32

From 56771dd5683c4277f3a5a135e0e31af940a1e8a8 Mon Sep 17 00:00:00 2001
From: ksikiric
Date: Thu, 20 Feb 2025 17:55:58 +0100
Subject: [PATCH 35/35] Correcting small mistake due to misunderstanding

---
 src/maxdiffusion/configs/base_flux_dev.yml | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/src/maxdiffusion/configs/base_flux_dev.yml b/src/maxdiffusion/configs/base_flux_dev.yml
index 7a917824a..53a8cf8e6 100644
--- a/src/maxdiffusion/configs/base_flux_dev.yml
+++ b/src/maxdiffusion/configs/base_flux_dev.yml
@@ -59,14 +59,14 @@ attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te
 flash_block_sizes: {}
 # Use the following flash_block_sizes on v6e (Trillium) due to larger vmem.
 # flash_block_sizes: {
-# "block_q" : 2176,
-# "block_kv_compute" : 2176,
-# "block_kv" : 2176,
-# "block_q_dkv" : 2176,
-# "block_kv_dkv" : 2176,
-# "block_kv_dkv_compute" : 2176,
-# "block_q_dq" : 2176,
-# "block_kv_dq" : 2176
+# "block_q" : 1536,
+# "block_kv_compute" : 1536,
+# "block_kv" : 1536,
+# "block_q_dkv" : 1536,
+# "block_kv_dkv" : 1536,
+# "block_kv_dkv_compute" : 1536,
+# "block_q_dq" : 1536,
+# "block_kv_dq" : 1536
 # }
 # GroupNorm groups
 norm_num_groups: 32
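The LoRA patches in this series all hinge on one Flax mechanism, `nn.intercept_methods`. Below is a condensed sketch of how `generate_flux.py` activates the interceptors returned by `FluxLoraLoaderMixin.make_lora_interceptor`; the `run_with_loras` helper is illustrative and not part of the source:

```python
# Condensed from the generate_flux.py changes in patch 27: every jitted
# call is wrapped so the LoRA interceptors can rewrite matching module
# calls while they are on the context stack. run_with_loras itself is an
# illustrative helper, not a function in the source.
from contextlib import ExitStack
import flax.linen as nn


def run_with_loras(lora_interceptors, fn, *args):
  with ExitStack() as stack:
    for interceptor in lora_interceptors:
      stack.enter_context(nn.intercept_methods(interceptor))
    # Inside this context, calls on LoRA-targeted modules are redirected
    # through the interceptor (base output plus the LoRA delta).
    return fn(*args)


# Usage mirroring the patch: compile, then run, under the same interceptors.
# imgs = run_with_loras(lora_interceptors, p_run_inference, states)
```

When no LoRA is configured, `maybe_load_flux_lora` returns a single no-op interceptor, so the same wrapping code runs unchanged.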