
Commit ffd7933

fixing attention from merging main
1 parent 10f2f33 commit ffd7933

1 file changed

Lines changed: 51 additions & 0 deletions

File tree

src/maxdiffusion/models/attention_flax.py

@@ -31,6 +31,7 @@
 from .. import common_types, max_logging

 from . import quantizations
+from .modeling_flax_utils import get_activation


 Array = common_types.Array
@@ -134,6 +135,7 @@ def _reshape_heads_to_head_dim(tensor):
   # This is used to transform the output of flash attention back into the format of other attention outputs
   b, h, s, d = tensor.shape
   tensor = jnp.transpose(tensor, axes=[0, 2, 1, 3])
+  reshaped_tensor = jnp.reshape(tensor, (b, -1, h * d))
   axis_names = nn.logical_to_mesh_axes((BATCH, LENGTH, HEAD))
   return jax.lax.with_sharding_constraint(reshaped_tensor, axis_names)
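A minimal, self-contained sketch (not part of the commit) of the shape round-trip this hunk restores; the sharding constraint from the real function is omitted here because it requires a device mesh, and the sizes are illustrative:

import jax.numpy as jnp

# Flash attention emits (batch, heads, seq, head_dim); the other attention
# paths expect (batch, seq, heads * head_dim).
b, h, s, d = 2, 8, 16, 64
tensor = jnp.zeros((b, h, s, d))

tensor = jnp.transpose(tensor, axes=[0, 2, 1, 3])       # (b, s, h, d)
reshaped_tensor = jnp.reshape(tensor, (b, -1, h * d))    # (b, s, h * d)

assert reshaped_tensor.shape == (2, 16, 8 * 64)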

@@ -693,6 +695,52 @@ def apply_rope(xq: Array, xk: Array, freqs_cis: Array) -> tuple[Array, Array]:
   return xq_out.reshape(*xq.shape).astype(xq.dtype), xk_out.reshape(*xk.shape).astype(xk.dtype)


+class NNXSimpleFeedForward(nnx.Module):
+
+  def __init__(
+      self,
+      rngs: nnx.Rngs,
+      dim: int,
+      dim_out: Optional[int] = None,
+      mult: int = 4,
+      activation_fn: str = "gelu",
+      dtype: jnp.dtype = jnp.float32,
+      weights_dtype: jnp.dtype = jnp.float32,
+      precision: Optional[jax.lax.Precision] = None,
+  ):
+    inner_dim = int(dim * mult)
+    dim_out = dim_out if dim_out is not None else dim
+    self.net_0 = nnx.Linear(
+        dim,
+        inner_dim,
+        rngs=rngs,
+        use_bias=True,
+        dtype=dtype,
+        param_dtype=weights_dtype,
+        precision=precision,
+        kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", None)),
+        bias_init=nnx.with_partitioning(nnx.initializers.zeros, (None,)),
+    )
+    self.act = get_activation(activation_fn)
+    self.net_2 = nnx.Linear(
+        inner_dim,
+        dim_out,
+        rngs=rngs,
+        use_bias=True,
+        dtype=dtype,
+        param_dtype=weights_dtype,
+        precision=precision,
+        kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "mlp")),
+        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
+    )
+
+  def __call__(self, hidden_states: Array) -> Array:
+    hidden_states = self.net_0(hidden_states)
+    hidden_states = self.act(hidden_states)
+    hidden_states = self.net_2(hidden_states)
+    return hidden_states
+
+
 class NNXAttentionOp(nnx.Module):

   def __init__(
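A rough usage sketch for the new NNXSimpleFeedForward module (not from the commit; the import path is assumed from the file location, and the shapes and dtype are illustrative). It shows the net_0 -> activation -> net_2 path with dim_out defaulting to dim:

import jax.numpy as jnp
from flax import nnx

# Assumed import path, derived from src/maxdiffusion/models/attention_flax.py.
from maxdiffusion.models.attention_flax import NNXSimpleFeedForward

ff = NNXSimpleFeedForward(
    rngs=nnx.Rngs(0),
    dim=1024,              # feature size of the incoming hidden states
    mult=4,                # inner_dim = dim * mult = 4096
    activation_fn="gelu",
    dtype=jnp.bfloat16,
)

hidden_states = jnp.ones((2, 128, 1024), dtype=jnp.bfloat16)
out = ff(hidden_states)              # net_0 -> activation -> net_2
assert out.shape == (2, 128, 1024)   # dim_out defaults to dim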
@@ -849,6 +897,8 @@ def __init__(
       mask_padding_tokens: bool = True,
       residual_checkpoint_name: str | None = None,
       enable_jax_named_scopes: bool = False,
+      added_kv_proj_dim: Optional[int] = None,
+      image_seq_len: Optional[int] = None,
   ):
     if attention_kernel == "cudnn_flash_te":
       raise NotImplementedError(f"Wan 2.1 has not been tested with {attention_kernel}")
@@ -1007,6 +1057,7 @@ def __call__(
       hidden_states: jax.Array,
       encoder_hidden_states: jax.Array = None,
       rotary_emb: Optional[jax.Array] = None,
+      encoder_attention_mask: Optional[jax.Array] = None,
       deterministic: bool = True,
       rngs: nnx.Rngs = None,
   ) -> jax.Array:
