@@ -31,6 +31,7 @@
 from .. import common_types, max_logging

 from . import quantizations
+from .modeling_flax_utils import get_activation


 Array = common_types.Array
@@ -134,6 +135,7 @@ def _reshape_heads_to_head_dim(tensor):
   # This is used to transform the output of flash attention back into the format of other attention outputs
   b, h, s, d = tensor.shape
   tensor = jnp.transpose(tensor, axes=[0, 2, 1, 3])
+  reshaped_tensor = jnp.reshape(tensor, (b, -1, h * d))
   axis_names = nn.logical_to_mesh_axes((BATCH, LENGTH, HEAD))
   return jax.lax.with_sharding_constraint(reshaped_tensor, axis_names)

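Note on the hunk above: within the shown function body, reshaped_tensor appears to have been undefined at the return statement, so the added line both avoids a NameError and restores the intended [batch, heads, seq, head_dim] -> [batch, seq, heads * head_dim] merge. A minimal standalone sketch of the same transform, without the sharding constraint and with illustrative shapes, is:

    import jax.numpy as jnp

    def reshape_heads_to_head_dim(tensor):
      # [batch, heads, seq, head_dim] -> [batch, seq, heads, head_dim]
      b, h, s, d = tensor.shape
      tensor = jnp.transpose(tensor, axes=[0, 2, 1, 3])
      # Merge the head axes back into one hidden dimension: [batch, seq, heads * head_dim]
      return jnp.reshape(tensor, (b, -1, h * d))

    x = jnp.ones((2, 8, 16, 64))  # batch=2, heads=8, seq=16, head_dim=64 (illustrative)
    assert reshape_heads_to_head_dim(x).shape == (2, 16, 512)
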
@@ -693,6 +695,52 @@ def apply_rope(xq: Array, xk: Array, freqs_cis: Array) -> tuple[Array, Array]:
   return xq_out.reshape(*xq.shape).astype(xq.dtype), xk_out.reshape(*xk.shape).astype(xk.dtype)


+class NNXSimpleFeedForward(nnx.Module):
+
+  def __init__(
+      self,
+      rngs: nnx.Rngs,
+      dim: int,
+      dim_out: Optional[int] = None,
+      mult: int = 4,
+      activation_fn: str = "gelu",
+      dtype: jnp.dtype = jnp.float32,
+      weights_dtype: jnp.dtype = jnp.float32,
+      precision: Optional[jax.lax.Precision] = None,
+  ):
+    inner_dim = int(dim * mult)
+    dim_out = dim_out if dim_out is not None else dim
+    self.net_0 = nnx.Linear(
+        dim,
+        inner_dim,
+        rngs=rngs,
+        use_bias=True,
+        dtype=dtype,
+        param_dtype=weights_dtype,
+        precision=precision,
+        kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", None)),
+        bias_init=nnx.with_partitioning(nnx.initializers.zeros, (None,)),
+    )
+    self.act = get_activation(activation_fn)
+    self.net_2 = nnx.Linear(
+        inner_dim,
+        dim_out,
+        rngs=rngs,
+        use_bias=True,
+        dtype=dtype,
+        param_dtype=weights_dtype,
+        precision=precision,
+        kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "mlp")),
+        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
+    )
+
+  def __call__(self, hidden_states: Array) -> Array:
+    hidden_states = self.net_0(hidden_states)
+    hidden_states = self.act(hidden_states)
+    hidden_states = self.net_2(hidden_states)
+    return hidden_states
+
+
 class NNXAttentionOp(nnx.Module):

   def __init__(
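For reference, a usage sketch of the new NNXSimpleFeedForward module added above. The input shape and the nnx.Rngs(0) seed are illustrative assumptions, and the logical partitioning axes ("embed", "mlp") only affect placement when the module is materialized under a device mesh:

    import jax.numpy as jnp
    from flax import nnx

    # Hypothetical instantiation; dim must match the trailing axis of the input.
    ff = NNXSimpleFeedForward(rngs=nnx.Rngs(0), dim=128, mult=4, activation_fn="gelu")

    x = jnp.ones((1, 16, 128))      # (batch, seq, dim) -- illustrative shape
    y = ff(x)                       # net_0 -> activation -> net_2
    assert y.shape == (1, 16, 128)  # dim_out defaults to dim when not given
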
@@ -849,6 +897,8 @@ def __init__(
       mask_padding_tokens: bool = True,
       residual_checkpoint_name: str | None = None,
       enable_jax_named_scopes: bool = False,
+      added_kv_proj_dim: Optional[int] = None,
+      image_seq_len: Optional[int] = None,
   ):
     if attention_kernel == "cudnn_flash_te":
       raise NotImplementedError(f"Wan 2.1 has not been tested with {attention_kernel}")
@@ -1007,6 +1057,7 @@ def __call__(
       hidden_states: jax.Array,
       encoder_hidden_states: jax.Array = None,
       rotary_emb: Optional[jax.Array] = None,
+      encoder_attention_mask: Optional[jax.Array] = None,
       deterministic: bool = True,
       rngs: nnx.Rngs = None,
   ) -> jax.Array:
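The new encoder_attention_mask argument indicates that padding positions in the encoder (text) sequence can now be excluded from cross-attention. This excerpt does not show how the mask is consumed; a common pattern, shown here purely as an assumption and not necessarily this PR's implementation, is to broadcast a [batch, kv_seq] keep-mask into an additive bias on the attention logits:

    import jax.numpy as jnp

    def mask_to_bias(encoder_attention_mask, dtype=jnp.float32):
      # encoder_attention_mask: bool/int array of shape [batch, kv_seq], 1 = keep, 0 = pad.
      # Returns an additive bias of shape [batch, 1, 1, kv_seq] that broadcasts over heads
      # and query positions and pushes masked logits toward -inf before the softmax.
      mask = encoder_attention_mask.astype(jnp.bool_)[:, None, None, :]
      return jnp.where(mask, 0.0, jnp.finfo(dtype).min).astype(dtype)

    bias = mask_to_bias(jnp.array([[1, 1, 1, 0]]))  # last key position is padding
    # The bias is added to the logits: logits = q @ k.T / sqrt(d) + bias
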