
Commit e44ddbc

adaln.py for nnx: classes NNXPixArtAlphaCombinedTimestepSizeEmbeddings and NNXTimesteps added
1 parent f23746b · commit e44ddbc

2 files changed · 70 additions & 0 deletions

File tree

src/maxdiffusion/models/embeddings_flax.py
adaln.py (new file)

src/maxdiffusion/models/embeddings_flax.py

Lines changed: 35 additions & 0 deletions
@@ -501,3 +501,38 @@ def __call__(self, timestep, guidance, pooled_projection):
    conditioning = time_guidance_emb + pooled_projections

    return conditioning
+
+
+class NNXTimesteps(nnx.Module):
+  def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
+    self.num_channels = num_channels
+    self.flip_sin_to_cos = flip_sin_to_cos
+    self.downscale_freq_shift = downscale_freq_shift
+    self.scale = scale
+
+  def __call__(self, timesteps: jax.Array) -> jax.Array:
+    return get_sinusoidal_embeddings(
+        timesteps=timesteps,
+        embedding_dim=self.num_channels,
+        freq_shift=self.downscale_freq_shift,
+        flip_sin_to_cos=self.flip_sin_to_cos,
+        scale=self.scale
+    )
+
+
+class NNXPixArtAlphaCombinedTimestepSizeEmbeddings(nnx.Module):
+  def __init__(self, rngs: nnx.Rngs, embedding_dim: int, size_emb_dim: int, dtype: jnp.dtype = jnp.float32, weights_dtype: jnp.dtype = jnp.float32):
+    self.outdim = size_emb_dim
+    self.time_proj = NNXTimesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+    self.timestep_embedder = NNXTimestepEmbedding(
+        rngs=rngs,
+        in_channels=256,
+        time_embed_dim=embedding_dim,
+        dtype=dtype,
+        weights_dtype=weights_dtype
+    )
+
+  def __call__(self, timestep: jax.Array, hidden_dtype: jnp.dtype = jnp.float32) -> jax.Array:
+    timesteps_proj = self.time_proj(timestep)
+    timesteps_emb = self.timestep_embedder(timesteps_proj.astype(hidden_dtype))
+    return timesteps_emb
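
For orientation, a minimal usage sketch of the two new embedding classes follows. It is not part of the commit: the embedding_dim of 1152, the nnx.Rngs(0) seed, and the batch of timesteps are illustrative assumptions, and the output shape assumes NNXTimestepEmbedding maps its 256-channel input to time_embed_dim, as its constructor arguments suggest.

# Hypothetical usage sketch, not part of this commit.
import jax.numpy as jnp
from flax import nnx
from maxdiffusion.models.embeddings_flax import NNXPixArtAlphaCombinedTimestepSizeEmbeddings

# Build the combined timestep/size embedder; 1152 is an assumed hidden size.
emb = NNXPixArtAlphaCombinedTimestepSizeEmbeddings(
    rngs=nnx.Rngs(0),
    embedding_dim=1152,
    size_emb_dim=1152 // 3,
)

timesteps = jnp.array([0, 250, 999])  # one diffusion timestep per batch element
t_emb = emb(timesteps, hidden_dtype=jnp.float32)
# NNXTimesteps produces a (3, 256) sinusoidal projection, which the
# timestep embedder then lifts to (3, 1152).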
adaln.py (new file)

Lines changed: 35 additions & 0 deletions

@@ -0,0 +1,35 @@
+
+import jax
+import jax.numpy as jnp
+from flax import nnx
+from typing import Optional, Tuple
+from maxdiffusion.models.embeddings_flax import NNXPixArtAlphaCombinedTimestepSizeEmbeddings
+
+class AdaLayerNormSingle(nnx.Module):
+  """
+  Norm layer adaptive layer norm single (adaLN-single).
+  As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).
+  """
+  def __init__(self, rngs: nnx.Rngs, embedding_dim: int, embedding_coefficient: int = 6, dtype: jnp.dtype = jnp.float32, weights_dtype: jnp.dtype = jnp.float32):
+    self.emb = NNXPixArtAlphaCombinedTimestepSizeEmbeddings(
+        rngs=rngs,
+        embedding_dim=embedding_dim,
+        size_emb_dim=embedding_dim // 3,
+        dtype=dtype,
+        weights_dtype=weights_dtype
+    )
+    self.silu = nnx.silu
+    self.linear = nnx.Linear(
+        rngs=rngs,
+        in_features=embedding_dim,
+        out_features=embedding_coefficient * embedding_dim,
+        use_bias=True,
+        dtype=dtype,
+        param_dtype=weights_dtype,
+        kernel_init=nnx.initializers.zeros,
+        bias_init=nnx.initializers.zeros
+    )
+
+  def __call__(self, timestep: jax.Array, hidden_dtype: Optional[jnp.dtype] = None) -> Tuple[jax.Array, jax.Array]:
+    embedded_timestep = self.emb(timestep, hidden_dtype=hidden_dtype)
+    return self.linear(self.silu(embedded_timestep)), embedded_timestep
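
And a sketch of how a caller might consume AdaLayerNormSingle. This is hypothetical and not part of the commit: the hidden size, the nnx.Rngs(0) seed, and the shift/scale/gate names are assumptions, with the six-way split suggested by the embedding_coefficient = 6 default in the adaLN-single formulation.

# Hypothetical usage sketch, assuming AdaLayerNormSingle from the new adaln.py is in scope.
import jax.numpy as jnp
from flax import nnx

embedding_dim = 1152  # assumed transformer hidden size
adaln = AdaLayerNormSingle(rngs=nnx.Rngs(0), embedding_dim=embedding_dim)

timestep = jnp.array([10, 500])  # batch of two diffusion timesteps
conditioning, embedded_timestep = adaln(timestep, hidden_dtype=jnp.float32)
# conditioning: (2, 6 * embedding_dim); embedded_timestep: (2, embedding_dim).

# A transformer block would then split the conditioning into per-branch
# modulation terms (names assumed from adaLN-single implementations):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = jnp.split(
    conditioning, 6, axis=-1
)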
