|
| 1 | +# Copyright 2023–2026 Google LLC |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# https://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +"""DeepSeek Manifold-Constrained Hyper Connections (mHC) Layer.""" |
| 16 | + |
| 17 | +import jax |
| 18 | +from jax.sharding import Mesh |
| 19 | + |
| 20 | +import jax.numpy as jnp |
| 21 | +from flax import nnx |
| 22 | +from typing import Callable |
| 23 | +from MaxText.common_types import Config, Array |
| 24 | +from MaxText.layers.normalizations import RMSNorm |
| 25 | +from MaxText.layers.initializers import nd_dense_init, default_bias_init, default_scalar_init |
| 26 | +from MaxText.common_types import HyperConnectionType |
| 27 | + |
| 28 | + |
| 29 | +def get_functions(expansion_rate: int): |
| 30 | + """ |
| 31 | + Creates functions to broadcast a single feature stream into multiple |
| 32 | + parallel paths (expand) and aggregate them back (reduce). |
| 33 | + """ |
| 34 | + |
| 35 | + def expand(x: Array): |
| 36 | + # (batch, length, dim) -> (batch, length, streams, dim) |
| 37 | + return jnp.repeat(jnp.expand_dims(x, axis=2), expansion_rate, axis=2) |
| 38 | + |
| 39 | + def reduce(x: Array): |
| 40 | + # (batch, length, streams, dim) -> (batch, length, dim) |
| 41 | + return jnp.sum(x, axis=2) |
| 42 | + |
| 43 | + return expand, reduce |
| 44 | + |
| 45 | + |
| 46 | +def sinkhorn(t, iters=20): |
| 47 | + """ |
| 48 | + Computes the Sinkhorn normalization of a matrix (rows and columns sum to 1). |
| 49 | + """ |
| 50 | + # Use float32 precision for numerical stability during normalization |
| 51 | + initial_dtype = t.dtype |
| 52 | + t = t.astype(jnp.float32) |
| 53 | + |
| 54 | + # Column-wise normalization (axis=-2) - positive and sum up to 1 across columns |
| 55 | + # Equivalent to t = exp(t) / jnp.sum(jnp.exp(t), axis=-2) |
| 56 | + t = jax.nn.softmax(t, axis=-2) |
| 57 | + |
| 58 | + def body_fun(i, val): |
| 59 | + # L1 Normalization: val / sum(val) with clipping of denominator |
| 60 | + # Normalize rows (axis -1) |
| 61 | + val = val / jnp.clip(jnp.sum(val, axis=-1, keepdims=True), min=1e-12) |
| 62 | + # Normalize columns (axis -2) |
| 63 | + val = val / jnp.clip(jnp.sum(val, axis=-2, keepdims=True), min=1e-12) |
| 64 | + return val |
| 65 | + |
| 66 | + # Use lax.fori_loop for an efficient, JIT-friendly loop |
| 67 | + t = jax.lax.fori_loop(0, iters, body_fun, t) |
| 68 | + return t.astype(initial_dtype) |
| 69 | + |
| 70 | + |
| 71 | +class ManifoldConstrainedHyperConnections(nnx.Module): |
| 72 | + """Implements Manifold-Constrained Hyper-Connections (mHC). |
| 73 | +
|
| 74 | + Reference: https://arxiv.org/pdf/2512.24880 |
| 75 | +
|
| 76 | + Args: |
| 77 | + config: Configuration object containing hyperparameters. |
| 78 | + dim: The feature dimensionality. |
| 79 | + mesh: The hardware mesh for sharding. |
| 80 | + rngs: Random number generation in NNX. |
| 81 | + """ |
| 82 | + |
| 83 | + def __init__( |
| 84 | + self, |
| 85 | + config: Config, |
| 86 | + dim: int, |
| 87 | + mesh: Mesh, |
| 88 | + rngs: nnx.Rngs, |
| 89 | + ): |
| 90 | + self.config = config |
| 91 | + self.sinkhorn_iterations = config.sinkhorn_iterations |
| 92 | + self.k = config.mhc_expansion_rate |
| 93 | + self.dim = dim |
| 94 | + self.rngs = rngs |
| 95 | + self.mesh = mesh |
| 96 | + self.weight_dtype = self.config.weight_dtype |
| 97 | + |
| 98 | + # Norm layer |
| 99 | + self.mhc_norm = RMSNorm( |
| 100 | + num_features=self.k * self.dim, |
| 101 | + dtype=self.config.dtype, |
| 102 | + weight_dtype=self.weight_dtype, |
| 103 | + kernel_axes=("norm",), |
| 104 | + epsilon=self.config.normalization_layer_epsilon, |
| 105 | + rngs=self.rngs, |
| 106 | + ) |
| 107 | + |
| 108 | + # Scalars |
| 109 | + self.res_alpha_scale = nnx.Param( |
| 110 | + default_scalar_init(self.rngs.params(), (1,), self.weight_dtype), |
| 111 | + sharding=(None,), |
| 112 | + ) |
| 113 | + self.pre_alpha_scale = nnx.Param( |
| 114 | + default_scalar_init(self.rngs.params(), (1,), self.weight_dtype), |
| 115 | + sharding=(None,), |
| 116 | + ) |
| 117 | + self.post_alpha_scale = nnx.Param( |
| 118 | + default_scalar_init(self.rngs.params(), (1,), self.weight_dtype), |
| 119 | + sharding=(None,), |
| 120 | + ) |
| 121 | + |
| 122 | + # Weight matrices |
| 123 | + scale_init = nd_dense_init(1.0, "fan_in", "normal") |
| 124 | + in_axis = 0 |
| 125 | + out_axis = 1 |
| 126 | + weight_sharding_axis_name = ("activation_embed", None) |
| 127 | + self.res_alpha = nnx.Param( |
| 128 | + scale_init( |
| 129 | + self.rngs.params(), |
| 130 | + (self.k * self.dim, self.k * self.k), |
| 131 | + self.weight_dtype, |
| 132 | + in_axis=in_axis, |
| 133 | + out_axis=out_axis, |
| 134 | + ), |
| 135 | + sharding=weight_sharding_axis_name, |
| 136 | + ) |
| 137 | + self.pre_alpha = nnx.Param( |
| 138 | + scale_init( |
| 139 | + self.rngs.params(), |
| 140 | + (self.k * self.dim, self.k), |
| 141 | + self.weight_dtype, |
| 142 | + in_axis=in_axis, |
| 143 | + out_axis=out_axis, |
| 144 | + ), |
| 145 | + sharding=weight_sharding_axis_name, |
| 146 | + ) |
| 147 | + self.post_alpha = nnx.Param( |
| 148 | + scale_init( |
| 149 | + self.rngs.params(), |
| 150 | + (self.k * self.dim, self.k), |
| 151 | + self.weight_dtype, |
| 152 | + in_axis=in_axis, |
| 153 | + out_axis=out_axis, |
| 154 | + ), |
| 155 | + sharding=weight_sharding_axis_name, |
| 156 | + ) |
| 157 | + |
| 158 | + # Biases |
| 159 | + self.res_beta = nnx.Param( |
| 160 | + default_bias_init(self.rngs.params(), (self.k, self.k), self.weight_dtype), |
| 161 | + sharding=(None, None), |
| 162 | + ) |
| 163 | + self.pre_beta = nnx.Param( |
| 164 | + default_bias_init(self.rngs.params(), (self.k,), self.weight_dtype), |
| 165 | + sharding=(None, None), |
| 166 | + ) |
| 167 | + self.post_beta = nnx.Param( |
| 168 | + default_bias_init(self.rngs.params(), (self.k,), self.weight_dtype), |
| 169 | + sharding=(None, None), |
| 170 | + ) |
| 171 | + |
| 172 | + def res_mapping(self, x: Array): |
| 173 | + """Helper function for residual mapping.""" |
| 174 | + # Apply projection: (b, s, k*d) @ (k*d, k*k) -> (b, s, k*k) |
| 175 | + h_res = jnp.einsum("bsm,mn -> bsn", x, self.res_alpha[...], precision=self.config.matmul_precision) |
| 176 | + b, s, _ = h_res.shape |
| 177 | + h_res = jnp.reshape(h_res, (b, s, self.k, self.k)) |
| 178 | + intermediate = self.res_alpha_scale * h_res + self.res_beta[...][None, None, :, :] |
| 179 | + output = sinkhorn(intermediate, self.sinkhorn_iterations) |
| 180 | + return output |
| 181 | + |
| 182 | + def mapping(self, x: Array, alpha_scale: Array, alpha: Array, beta: Array, scale: int): |
| 183 | + """Helper function for both pre and post mappings.""" |
| 184 | + # Apply projection: (b, s, k*d) @ (k*d, k) -> (b, s, k) |
| 185 | + h = jnp.einsum("bsm,mk -> bsk", x, alpha, precision=self.config.matmul_precision) |
| 186 | + intermediate = alpha_scale * h + beta[None, None, :] |
| 187 | + output = scale * jax.nn.sigmoid(intermediate) |
| 188 | + return output |
| 189 | + |
| 190 | + def __call__( |
| 191 | + self, |
| 192 | + branch_fn: Callable, |
| 193 | + x: Array, |
| 194 | + mhc_type: HyperConnectionType, |
| 195 | + **kwargs, |
| 196 | + ) -> Array: |
| 197 | + """Applying manifold-constrained hyper connection based on callable function. |
| 198 | +
|
| 199 | + Args: |
| 200 | + branch_fn: The function to be wrapped by the hyper-connection. |
| 201 | + x: Input tensor of shape `(batch..., dim)`. |
| 202 | + mhc_type: The variant of the connection to apply. |
| 203 | + **kwargs: Additional context passed to the branch function. |
| 204 | +
|
| 205 | + Returns: |
| 206 | + The processed tensor, maintaining the shape of `x`. |
| 207 | + """ |
| 208 | + # x shape: [batch, seq, expansion_rate, emb] |
| 209 | + b, s, k, d = x.shape |
| 210 | + |
| 211 | + # 1. Flatten the tensor, and RMS normalization |
| 212 | + norm_x = self.mhc_norm(jnp.reshape(x, (b, s, k * d))) |
| 213 | + |
| 214 | + # 2. Pre mapping |
| 215 | + pre_mapping = self.mapping(norm_x, self.pre_alpha_scale, self.pre_alpha[...], self.pre_beta[...], 1.0) |
| 216 | + layer_input = jnp.einsum("bskd,bsk -> bsd", x, pre_mapping, precision=self.config.matmul_precision) |
| 217 | + |
| 218 | + # 3. Attention or MLP |
| 219 | + if mhc_type == HyperConnectionType.ATTENTION: |
| 220 | + layer_out, _ = branch_fn(inputs_q=layer_input, inputs_kv=layer_input, **kwargs) |
| 221 | + elif mhc_type == HyperConnectionType.MLP_DENSE: |
| 222 | + layer_out = branch_fn(inputs=layer_input, **kwargs) |
| 223 | + elif mhc_type == HyperConnectionType.MLP_MOE: |
| 224 | + layer_out, _, _ = branch_fn(inputs=layer_input, **kwargs) |
| 225 | + else: |
| 226 | + raise ValueError(f"Unsupported type: {mhc_type}") |
| 227 | + |
| 228 | + # 4. Post mapping |
| 229 | + post_mapping = self.mapping(norm_x, self.post_alpha_scale, self.post_alpha[...], self.post_beta[...], 2.0) |
| 230 | + post_out = jnp.einsum("bsd,bsk -> bskd", layer_out, post_mapping, precision=self.config.matmul_precision) |
| 231 | + |
| 232 | + # 5. Residual mapping, res_out shape as [batch, seq, expansion_rate, emb] |
| 233 | + res_mapping = self.res_mapping(norm_x) |
| 234 | + res_out = jnp.einsum("bskd,bskm -> bsmd", x, res_mapping, precision=self.config.matmul_precision) |
| 235 | + return res_out + post_out |
0 commit comments