from typing import Union, Iterable, Tuple, Optional, Callable

import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.linen.initializers import lecun_normal


Shape = Tuple[int, ...]
Initializer = Callable[[jax.random.PRNGKey, Shape, jnp.dtype], jax.Array]
InitializerAxis = Union[int, Shape]


def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int, ...]:
    # Convert negative axes to positive indices relative to ndim. A tuple by
    # convention, so len() also gives the rank directly.
    return tuple(ax if ax >= 0 else ndim + ax for ax in axes)
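# For illustration: _normalize_axes((-1, 0), ndim=3) == (2, 0).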


def _canonicalize_tuple(x: Union[Iterable[int], int]) -> Tuple[int, ...]:
    # Wrap a bare int into a 1-tuple; pass iterables through as tuples.
    if isinstance(x, Iterable):
        return tuple(x)
    else:
        return (x,)
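# For illustration: _canonicalize_tuple(3) == (3,); _canonicalize_tuple([1, 2]) == (1, 2).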


NdInitializer = Callable[[jax.random.PRNGKey, Shape, jnp.dtype, InitializerAxis, InitializerAxis], jax.Array]
KernelInitializer = NdInitializer


class DenseGeneral(nn.Module):
    """A linear transformation with flexible axes.

    Adapted from https://github.com/AI-Hypercomputer/maxtext/blob/4bf3beaa5e721745427bfed09938427e369c2aaf/MaxText/layers/linears.py#L86

    Attributes:
      features: int or tuple with the number(s) of output features.
      axis: int or tuple with the axis (or axes) to apply the transformation on.
      weight_dtype: the dtype of the weights (default: float32).
      dtype: the dtype of the computation (default: float32).
      kernel_init: initializer function for the weight matrix.
      kernel_axes: logical axis names used to partition the kernel.
      use_bias: whether to add a bias in the linear transformation.
      matmul_precision: precision passed to jax.lax.dot_general.
      bias_init: initializer function for the bias.
    """

    features: Union[Iterable[int], int]
    axis: Union[Iterable[int], int] = -1
    weight_dtype: jnp.dtype = jnp.float32
    dtype: jnp.dtype = jnp.float32
    kernel_init: Initializer = lecun_normal()
    kernel_axes: Tuple[Optional[str], ...] = ()
    use_bias: bool = False
    matmul_precision: str = "default"

    bias_init: Initializer = jax.nn.initializers.constant(0.0)

    @nn.compact
    def __call__(self, inputs: jax.Array) -> jax.Array:
        """Applies a linear transformation to the inputs along multiple dimensions.

        Args:
          inputs: The nd-array to be transformed.

        Returns:
          The transformed input.
        """

        def compute_dot_general(inputs, kernel, axis, contract_ind):
            """Computes the dot_general with the configured matmul precision."""
            dot_general = jax.lax.dot_general
            matmul_precision = jax.lax.Precision(self.matmul_precision)
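            # dimension_numbers = ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch)):
            # contract the input's `axis` dims against the kernel's leading dims,
            # with no batch dimensions.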
            return dot_general(inputs, kernel, ((axis, contract_ind), ((), ())), precision=matmul_precision)

        features = _canonicalize_tuple(self.features)
        axis = _canonicalize_tuple(self.axis)

        inputs = jnp.asarray(inputs, self.dtype)
        axis = _normalize_axes(axis, inputs.ndim)

        # The kernel contracts the selected input dims and emits the feature dims.
        kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
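        # Annotate the kernel with logical axis names so it can be sharded
        # under flax's logical partitioning rules.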
        kernel = self.param(
            "kernel",
            nn.with_logical_partitioning(self.kernel_init, self.kernel_axes),
            kernel_shape,
            self.weight_dtype,
        )
        kernel = jnp.asarray(kernel, self.dtype)

        contract_ind = tuple(range(len(axis)))
        output = compute_dot_general(inputs, kernel, axis, contract_ind)

        if self.use_bias:
            bias_axes, bias_shape = (
                self.kernel_axes[-len(features):],
                kernel_shape[-len(features):],
            )
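            # The bias covers only the output feature dims (the kernel's
            # trailing axes) and reuses their logical partitioning names.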
            bias = self.param(
                "bias",
                nn.with_logical_partitioning(self.bias_init, bias_axes),
                bias_shape,
                self.weight_dtype,
            )
            bias = jnp.asarray(bias, self.dtype)

            output += bias
        return output
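

# A minimal usage sketch, not part of the adapted MaxText code: project the
# trailing two axes of a rank-3 input onto two output feature axes. All shapes
# and axis names below are illustrative assumptions.
if __name__ == "__main__":
    layer = DenseGeneral(
        features=(8, 4),
        axis=(-2, -1),
        use_bias=True,
        # One logical name per kernel dim; None leaves a dim unsharded.
        kernel_axes=(None, None, None, None),
    )
    x = jnp.ones((2, 3, 5))
    params = layer.init(jax.random.PRNGKey(0), x)
    y = layer.apply(params, x)
    print(y.shape)  # (2, 8, 4): the (3, 5) dims are contracted into (8, 4)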