
Commit 15d242e

wip - wan transformer
1 parent 296e956 commit 15d242e

4 files changed

Lines changed: 456 additions & 12 deletions


src/maxdiffusion/models/attention_flax.py

Lines changed: 134 additions & 10 deletions
@@ -383,6 +383,139 @@ def chunk_scanner(chunk_idx, _):
 
   return jnp.concatenate(res, axis=-3)  # fuse the chunked result back
 
+
+def apply_rope(xq: Array, xk: Array, freqs_cis: Array) -> tuple[Array, Array]:
+  xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2)
+  xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2)
+
+  xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
+  xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
+
+  return xq_out.reshape(*xq.shape).astype(xq.dtype), xk_out.reshape(*xk.shape).astype(xk.dtype)
+
+
+class FlaxWanAttention(nn.Module):
+  query_dim: int
+  heads: int = 8
+  dim_head: int = 64
+  dropout: float = 0.0
+  use_memory_efficient_attention: bool = False
+  split_head_dim: bool = False
+  attention_kernel: str = "dot_product"
+  flash_min_seq_length: int = 4096
+  flash_block_sizes: BlockSizes = None
+  mesh: jax.sharding.Mesh = None
+  dtype: jnp.dtype = jnp.float32
+  weights_dtype: jnp.dtype = jnp.float32
+  query_axis_names: AxisNames = (BATCH, LENGTH, HEAD)
+  key_axis_names: AxisNames = (BATCH, LENGTH, HEAD)
+  value_axis_names: AxisNames = (BATCH, LENGTH, HEAD)
+  out_axis_names: AxisNames = (BATCH, LENGTH, EMBED)
+  precision: jax.lax.Precision = None
+  qkv_bias: bool = False
+
+  def setup(self):
+    if self.attention_kernel in {"flash", "cudnn_flash_te"} and self.mesh is None:
+      raise ValueError(f"The flash attention kernel requires a value for mesh, but mesh is {self.mesh}")
+    inner_dim = self.dim_head * self.heads
+    scale = self.dim_head**-0.5
+
+    self.attention_op = AttentionOp(
+        mesh=self.mesh,
+        attention_kernel=self.attention_kernel,
+        scale=scale,
+        heads=self.heads,
+        dim_head=self.dim_head,
+        flash_min_seq_length=self.flash_min_seq_length,
+        use_memory_efficient_attention=self.use_memory_efficient_attention,
+        split_head_dim=self.split_head_dim,
+        flash_block_sizes=self.flash_block_sizes,
+        dtype=self.dtype,
+        float32_qk_product=False,
+    )
+
+    qkv_init_kernel = nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "heads"))
+
+    self.query = nn.Dense(
+        inner_dim,
+        kernel_init=qkv_init_kernel,
+        use_bias=False,
+        dtype=self.dtype,
+        param_dtype=self.weights_dtype,
+        name="to_q",
+        precision=self.precision,
+    )
+
+    self.key = nn.Dense(
+        inner_dim,
+        kernel_init=qkv_init_kernel,
+        use_bias=False,
+        dtype=self.dtype,
+        param_dtype=self.weights_dtype,
+        name="to_k",
+        precision=self.precision,
+    )
+
+    self.value = nn.Dense(
+        inner_dim,
+        kernel_init=qkv_init_kernel,
+        use_bias=False,
+        dtype=self.dtype,
+        param_dtype=self.weights_dtype,
+        name="to_v",
+        precision=self.precision,
+    )
+
+    self.query_norm = nn.RMSNorm(
+        dtype=self.dtype,
+        scale_init=nn.with_logical_partitioning(nn.initializers.ones, ("heads",)),
+        param_dtype=self.weights_dtype,
+    )
+    self.key_norm = nn.RMSNorm(
+        dtype=self.dtype,
+        scale_init=nn.with_logical_partitioning(nn.initializers.ones, ("heads",)),
+        param_dtype=self.weights_dtype,
+    )
+
+    self.proj_attn = nn.Dense(
+        self.query_dim,
+        kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("heads", "embed")),
+        dtype=self.dtype,
+        param_dtype=self.weights_dtype,
+        name="to_out_0",
+        precision=self.precision,
+    )
+    self.dropout_layer = nn.Dropout(rate=self.dropout)
+
+  def __call__(
+      self,
+      hidden_states: Array,
+      encoder_hidden_states: Optional[Array] = None,
+      rotary_emb: Optional[Array] = None,
+      deterministic: bool = True,
+  ):
+    # Self-attention when no encoder states are provided, cross-attention otherwise.
+    encoder_hidden_states = hidden_states if encoder_hidden_states is None else encoder_hidden_states
+
+    query_proj = self.query(hidden_states)
+    key_proj = self.key(encoder_hidden_states)
+    value_proj = self.value(encoder_hidden_states)
+
+    query_proj = self.query_norm(query_proj)
+    key_proj = self.key_norm(key_proj)
+
+    if rotary_emb is not None:
+      query_proj, key_proj = apply_rope(query_proj, key_proj, rotary_emb)
+
+    query_proj = nn.with_logical_constraint(query_proj, self.query_axis_names)
+    key_proj = nn.with_logical_constraint(key_proj, self.key_axis_names)
+    value_proj = nn.with_logical_constraint(value_proj, self.value_axis_names)
+
+    hidden_states = self.attention_op.apply_attention(query_proj, key_proj, value_proj)
+
+    hidden_states = self.proj_attn(hidden_states)
+    hidden_states = nn.with_logical_constraint(hidden_states, (BATCH, LENGTH, HEAD))
+    return self.dropout_layer(hidden_states, deterministic=deterministic)
 
 class FlaxFluxAttention(nn.Module):
   query_dim: int
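
Editor's note (not part of the diff): a rough, untested sketch of driving the new layer in its self-attention form. The dimensions, dtype, and the assumption that the default "dot_product" kernel needs no mesh are illustrative placeholders, not Wan's real configuration.

import jax
import jax.numpy as jnp

# Hypothetical sizes, for illustration only.
attn = FlaxWanAttention(query_dim=1536, heads=12, dim_head=128)
x = jnp.zeros((1, 256, 1536), dtype=jnp.float32)     # (batch, seq_len, query_dim)
params = attn.init(jax.random.PRNGKey(0), x)          # encoder_hidden_states / rotary_emb default to None
y = attn.apply(params, x)                             # -> (1, 256, 1536) after the to_out_0 projection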
@@ -493,15 +626,6 @@ def setup(self):
         param_dtype=self.weights_dtype,
     )
 
-  def apply_rope(self, xq: Array, xk: Array, freqs_cis: Array) -> tuple[Array, Array]:
-    xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2)
-    xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2)
-
-    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
-    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
-
-    return xq_out.reshape(*xq.shape).astype(xq.dtype), xk_out.reshape(*xk.shape).astype(xk.dtype)
-
   def __call__(self, hidden_states, encoder_hidden_states=None, attention_mask=None, image_rotary_emb=None):
 
     qkv_proj = self.qkv(hidden_states)
@@ -535,7 +659,7 @@ def __call__(self, hidden_states, encoder_hidden_states=None, attention_mask=None, image_rotary_emb=None):
     value_proj = nn.with_logical_constraint(value_proj, self.value_axis_names)
 
     image_rotary_emb = rearrange(image_rotary_emb, "n d (i j) -> n d i j", i=2, j=2)
-    query_proj, key_proj = self.apply_rope(query_proj, key_proj, image_rotary_emb)
+    query_proj, key_proj = apply_rope(query_proj, key_proj, image_rotary_emb)
 
     query_proj = query_proj.transpose(0, 2, 1, 3).reshape(query_proj.shape[0], query_proj.shape[2], -1)
     key_proj = key_proj.transpose(0, 2, 1, 3).reshape(key_proj.shape[0], key_proj.shape[2], -1)
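
Editor's note: the shared apply_rope expects freqs_cis to broadcast against the (..., pairs, 1, 2)-shaped query/key views, i.e. one 2x2 rotation per position and per (even, odd) channel pair; this is the layout FlaxFluxAttention produces with the "n d (i j) -> n d i j" rearrange above. A standalone sketch of building such a table, with an assumed helper name and theta:

import jax.numpy as jnp

def make_rope_table(seq_len: int, dim: int, theta: float = 10000.0):
  # One 2x2 rotation matrix per position and per channel pair.
  half = dim // 2
  freqs = 1.0 / (theta ** (jnp.arange(half, dtype=jnp.float32) / half))
  angles = jnp.outer(jnp.arange(seq_len, dtype=jnp.float32), freqs)   # (seq_len, half)
  cos, sin = jnp.cos(angles), jnp.sin(angles)
  row0 = jnp.stack([cos, -sin], axis=-1)   # out_even = cos*x_even - sin*x_odd
  row1 = jnp.stack([sin, cos], axis=-1)    # out_odd  = sin*x_even + cos*x_odd
  return jnp.stack([row0, row1], axis=-2)  # (seq_len, half, 2, 2)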

src/maxdiffusion/models/embeddings_flax.py

Lines changed: 2 additions & 2 deletions
@@ -56,7 +56,6 @@ def get_sinusoidal_embeddings(
   signal = jnp.reshape(signal, [jnp.shape(timesteps)[0], embedding_dim])
   return signal
 
-
 class FlaxTimestepEmbedding(nn.Module):
   r"""
   Time step Embedding Module. Learns embeddings for input time steps.
@@ -91,7 +90,8 @@ class FlaxTimesteps(nn.Module):
 
   dim: int = 32
   flip_sin_to_cos: bool = False
-  freq_shift: float = 1
+  freq_shift: float = 1.0
+  scale: int = 1
 
   @nn.compact
   def __call__(self, timesteps):
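
Editor's note: a rough standalone illustration of what flip_sin_to_cos, freq_shift, and the new scale field control in a sinusoidal timestep embedding. The local function name and defaults are assumptions for illustration, not the module's exact helper:

import jax.numpy as jnp

def sinusoidal_embedding(timesteps, dim=32, freq_shift=1.0, scale=1, flip_sin_to_cos=False, max_period=1.0e4):
  # Geometric frequency ladder; freq_shift tweaks the denominator, scale multiplies the phase.
  half = dim // 2
  exponent = -jnp.log(max_period) * jnp.arange(half, dtype=jnp.float32) / (half - freq_shift)
  angles = scale * timesteps[:, None].astype(jnp.float32) * jnp.exp(exponent)[None, :]
  sin, cos = jnp.sin(angles), jnp.cos(angles)
  return jnp.concatenate([cos, sin] if flip_sin_to_cos else [sin, cos], axis=-1)

emb = sinusoidal_embedding(jnp.array([0, 10, 999]))  # -> shape (3, 32)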
Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
+"""
+Copyright 2025 Google LLC
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""

0 commit comments
