|
| 1 | +"""Copyright 2025 Google LLC |
| 2 | +
|
| 3 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +you may not use this file except in compliance with the License. |
| 5 | +You may obtain a copy of the License at |
| 6 | +
|
| 7 | + https://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +
|
| 9 | +Unless required by applicable law or agreed to in writing, software |
| 10 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +See the License for the specific language governing permissions and |
| 13 | +limitations under the License. |
| 14 | +""" |
| 15 | + |
| 16 | +from typing import Tuple |
| 17 | + |
| 18 | +from flax import nnx |
| 19 | +import jax |
| 20 | +from jax.ad_checkpoint import checkpoint_name |
| 21 | +import jax.numpy as jnp |
| 22 | +from jax.sharding import PartitionSpec |
| 23 | + |
| 24 | +from .... import common_types |
| 25 | +from ...attention_flax import FlaxWanAttention |
| 26 | +from ...normalization_flax import FP32LayerNorm |
| 27 | +from .transformer_wan import WanFeedForward |
| 28 | + |
| 29 | +BlockSizes = common_types.BlockSizes |
| 30 | + |
| 31 | + |
| 32 | +class WanVACETransformerBlock(nnx.Module): |
| 33 | + """Attention block for VACE. |
| 34 | +
|
| 35 | + Processes the conditioning signals and produces latent codes that can be |
| 36 | + summed to the main branch of WAN. |
| 37 | +
|
| 38 | + Based on |
| 39 | + https://github.com/huggingface/diffusers/blob/be3c2a0667493022f17d756ca3dba631d28dfb40/src/diffusers/models/transformers/transformer_wan_vace.py#L41C7-L41C30 |
| 40 | + """ |
| 41 | + |
| 42 | + def __init__( |
| 43 | + self, |
| 44 | + rngs: nnx.Rngs, |
| 45 | + *, |
| 46 | + dim: int, |
| 47 | + ffn_dim: int, |
| 48 | + num_heads: int, |
| 49 | + qk_norm: str = "rms_norm_across_heads", |
| 50 | + cross_attn_norm: bool = False, |
| 51 | + eps: float = 1e-6, |
| 52 | + flash_min_seq_length: int = 4096, |
| 53 | + flash_block_sizes: BlockSizes | None = None, |
| 54 | + mesh: jax.sharding.Mesh | None = None, |
| 55 | + dtype: jnp.dtype = jnp.float32, |
| 56 | + weights_dtype: jnp.dtype = jnp.float32, |
| 57 | + precision: jax.lax.Precision | None = None, |
| 58 | + attention: str = "dot_product", |
| 59 | + dropout: float = 0.0, |
| 60 | + apply_input_projection: bool = False, |
| 61 | + apply_output_projection: bool = False, |
| 62 | + ): |
| 63 | + """Sets up the model. |
| 64 | +
|
| 65 | + Args: |
| 66 | + rngs: Random number generator. |
| 67 | + dim: Internal dimension of the block. |
| 68 | + ffn_dim: Dimension of the feed-forward network. |
| 69 | + num_heads: Number of attention heads. |
| 70 | + qk_norm: Whether to apply RMSNorm to the query and key vectors. |
| 71 | + cross_attn_norm: Whether to apply layer normalization before |
| 72 | + cross-attention (only True supported). |
| 73 | + eps: Epsilon value for normalization. |
| 74 | + flash_min_seq_length: Minimum sequence length for flash attention. |
| 75 | + flash_block_sizes: Block sizes for flash attention. |
| 76 | + mesh: Sharding topology. |
| 77 | + dtype: Data type for the computation. |
| 78 | + weights_dtype: Data type for parameter initializers (see param_dtype in |
| 79 | + nnx.Linear). |
| 80 | + precision: Precision for the computation. |
| 81 | + attention: Type of attention to use. |
| 82 | + dropout: Dropout rate. |
| 83 | + apply_input_projection: Whether to apply a linear projection to the |
| 84 | + inputs. |
| 85 | + apply_output_projection: Whether to apply an output projection before |
| 86 | + outputting the result. |
| 87 | + """ |
| 88 | + |
| 89 | + self.apply_input_projection = apply_input_projection |
| 90 | + self.apply_output_projection = apply_output_projection |
| 91 | + |
| 92 | + # 1. Input projection |
| 93 | + self.proj_in = nnx.data([None]) |
| 94 | + if apply_input_projection: |
| 95 | + self.proj_in = nnx.Linear( |
| 96 | + rngs=rngs, |
| 97 | + in_features=dim, |
| 98 | + out_features=dim, |
| 99 | + dtype=dtype, |
| 100 | + param_dtype=weights_dtype, |
| 101 | + precision=precision, |
| 102 | + kernel_init=nnx.with_partitioning( |
| 103 | + nnx.initializers.xavier_uniform(), ("embed", None) |
| 104 | + ), |
| 105 | + ) |
| 106 | + |
| 107 | + # 2. Self-attention |
| 108 | + self.norm1 = FP32LayerNorm( |
| 109 | + rngs=rngs, dim=dim, eps=eps, elementwise_affine=False |
| 110 | + ) |
| 111 | + self.attn1 = FlaxWanAttention( |
| 112 | + rngs=rngs, |
| 113 | + query_dim=dim, |
| 114 | + heads=num_heads, |
| 115 | + dim_head=dim // num_heads, |
| 116 | + qk_norm=qk_norm, |
| 117 | + eps=eps, |
| 118 | + flash_min_seq_length=flash_min_seq_length, |
| 119 | + flash_block_sizes=flash_block_sizes, |
| 120 | + mesh=mesh, |
| 121 | + dtype=dtype, |
| 122 | + weights_dtype=weights_dtype, |
| 123 | + precision=precision, |
| 124 | + attention_kernel=attention, |
| 125 | + dropout=dropout, |
| 126 | + residual_checkpoint_name="self_attn", |
| 127 | + ) |
| 128 | + |
| 129 | + # 3. Cross-attention |
| 130 | + self.attn2 = FlaxWanAttention( |
| 131 | + rngs=rngs, |
| 132 | + query_dim=dim, |
| 133 | + heads=num_heads, |
| 134 | + dim_head=dim // num_heads, |
| 135 | + qk_norm=qk_norm, |
| 136 | + eps=eps, |
| 137 | + flash_min_seq_length=flash_min_seq_length, |
| 138 | + flash_block_sizes=flash_block_sizes, |
| 139 | + mesh=mesh, |
| 140 | + dtype=dtype, |
| 141 | + weights_dtype=weights_dtype, |
| 142 | + precision=precision, |
| 143 | + attention_kernel=attention, |
| 144 | + dropout=dropout, |
| 145 | + residual_checkpoint_name="cross_attn", |
| 146 | + ) |
| 147 | + assert cross_attn_norm is True, "cross_attn_norm must be True" |
| 148 | + self.norm2 = FP32LayerNorm( |
| 149 | + rngs=rngs, dim=dim, eps=eps, elementwise_affine=True |
| 150 | + ) |
| 151 | + |
| 152 | + # 4. Feed-forward |
| 153 | + self.ffn = WanFeedForward( |
| 154 | + rngs=rngs, |
| 155 | + dim=dim, |
| 156 | + inner_dim=ffn_dim, |
| 157 | + activation_fn="gelu-approximate", |
| 158 | + dtype=dtype, |
| 159 | + weights_dtype=weights_dtype, |
| 160 | + precision=precision, |
| 161 | + dropout=dropout, |
| 162 | + ) |
| 163 | + |
| 164 | + self.norm3 = FP32LayerNorm( |
| 165 | + rngs=rngs, dim=dim, eps=eps, elementwise_affine=False |
| 166 | + ) |
| 167 | + |
| 168 | + # 5. Output projection |
| 169 | + self.proj_out = nnx.data([None]) |
| 170 | + if apply_output_projection: |
| 171 | + self.proj_out = nnx.Linear( |
| 172 | + rngs=rngs, |
| 173 | + in_features=dim, |
| 174 | + out_features=dim, |
| 175 | + dtype=dtype, |
| 176 | + param_dtype=weights_dtype, |
| 177 | + precision=precision, |
| 178 | + kernel_init=nnx.with_partitioning( |
| 179 | + nnx.initializers.xavier_uniform(), ("embed", None) |
| 180 | + ), |
| 181 | + ) |
| 182 | + |
| 183 | + key = rngs.params() |
| 184 | + self.adaln_scale_shift_table = nnx.Param( |
| 185 | + jax.random.normal(key, (1, 6, dim)) / dim**0.5, |
| 186 | + ) |
| 187 | + |
| 188 | + def __call__( |
| 189 | + self, |
| 190 | + *, |
| 191 | + hidden_states: jax.Array, |
| 192 | + encoder_hidden_states: jax.Array, |
| 193 | + control_hidden_states: jax.Array, |
| 194 | + temb: jax.Array, |
| 195 | + rotary_emb: jax.Array, |
| 196 | + deterministic: bool = True, |
| 197 | + rngs: nnx.Rngs | None = None, |
| 198 | + ) -> Tuple[jax.Array, jax.Array]: |
| 199 | + if self.apply_input_projection: |
| 200 | + control_hidden_states = self.proj_in(control_hidden_states) |
| 201 | + control_hidden_states = control_hidden_states + hidden_states |
| 202 | + |
| 203 | + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( |
| 204 | + jnp.split( |
| 205 | + (self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1 |
| 206 | + ) |
| 207 | + ) |
| 208 | + |
| 209 | + control_hidden_states = jax.lax.with_sharding_constraint( |
| 210 | + control_hidden_states, |
| 211 | + PartitionSpec("data", "fsdp", "tensor"), |
| 212 | + ) |
| 213 | + control_hidden_states = checkpoint_name( |
| 214 | + control_hidden_states, "control_hidden_states" |
| 215 | + ) |
| 216 | + encoder_hidden_states = jax.lax.with_sharding_constraint( |
| 217 | + encoder_hidden_states, |
| 218 | + PartitionSpec("data", "fsdp", None), |
| 219 | + ) |
| 220 | + |
| 221 | + # 1. Self-attention |
| 222 | + with jax.named_scope("attn1"): |
| 223 | + norm_hidden_states = ( |
| 224 | + self.norm1(control_hidden_states.astype(jnp.float32)) |
| 225 | + * (1 + scale_msa) |
| 226 | + + shift_msa |
| 227 | + ).astype(control_hidden_states.dtype) |
| 228 | + attn_output = self.attn1( |
| 229 | + hidden_states=norm_hidden_states, |
| 230 | + encoder_hidden_states=norm_hidden_states, |
| 231 | + rotary_emb=rotary_emb, |
| 232 | + deterministic=deterministic, |
| 233 | + rngs=rngs, |
| 234 | + ) |
| 235 | + control_hidden_states = ( |
| 236 | + control_hidden_states.astype(jnp.float32) + attn_output * gate_msa |
| 237 | + ).astype(control_hidden_states.dtype) |
| 238 | + |
| 239 | + # 2. Cross-attention |
| 240 | + with jax.named_scope("attn2"): |
| 241 | + norm_hidden_states = self.norm2( |
| 242 | + control_hidden_states.astype(jnp.float32) |
| 243 | + ).astype(control_hidden_states.dtype) |
| 244 | + attn_output = self.attn2( |
| 245 | + hidden_states=norm_hidden_states, |
| 246 | + encoder_hidden_states=encoder_hidden_states, |
| 247 | + deterministic=deterministic, |
| 248 | + rngs=rngs, |
| 249 | + ) |
| 250 | + control_hidden_states = control_hidden_states + attn_output |
| 251 | + |
| 252 | + # 3. Feed-forward |
| 253 | + with jax.named_scope("ffn"): |
| 254 | + norm_hidden_states = ( |
| 255 | + self.norm3(control_hidden_states.astype(jnp.float32)) |
| 256 | + * (1 + c_scale_msa) |
| 257 | + + c_shift_msa |
| 258 | + ).astype(control_hidden_states.dtype) |
| 259 | + ff_output = self.ffn( |
| 260 | + norm_hidden_states, deterministic=deterministic, rngs=rngs |
| 261 | + ) |
| 262 | + control_hidden_states = ( |
| 263 | + control_hidden_states.astype(jnp.float32) |
| 264 | + + ff_output.astype(jnp.float32) * c_gate_msa |
| 265 | + ).astype(control_hidden_states.dtype) |
| 266 | + conditioning_states = None |
| 267 | + if self.apply_output_projection: |
| 268 | + conditioning_states = self.proj_out(control_hidden_states) |
| 269 | + |
| 270 | + return conditioning_states, control_hidden_states |
0 commit comments