
Commit 9142610

use NNXSimpleFeedForward, add mesh support

Signed-off-by: James Huang <syhuang1201@gmail.com>
1 parent 62d994a

4 files changed: 22 additions & 20 deletions

src/maxdiffusion/models/ltx2/text_encoders/embeddings_connector_ltx2.py
13 additions & 20 deletions
@@ -19,28 +19,12 @@
 import jax.numpy as jnp
 from flax import nnx
 from maxdiffusion import common_types
-from ..attention_ltx2 import LTX2Attention
+from maxdiffusion.models.ltx2.attention_ltx2 import LTX2Attention
+from maxdiffusion.models.attention_flax import NNXSimpleFeedForward
 
 Array = common_types.Array
 DType = common_types.DType
 
-
-class FeedForward(nnx.Module):
-
-  def __init__(self, dim: int, dim_out: Optional[int] = None, mult: int = 4, dropout: float = 0.0, rngs: nnx.Rngs = None):
-    inner_dim = int(dim * mult)
-    dim_out = dim_out if dim_out is not None else dim
-
-    self.proj1 = nnx.Linear(dim, inner_dim, rngs=rngs)
-    self.proj2 = nnx.Linear(inner_dim, dim_out, rngs=rngs)
-
-  def __call__(self, x: Array) -> Array:
-    x = self.proj1(x)
-    x = jax.nn.gelu(x)
-    x = self.proj2(x)
-    return x
-
-
 class _BasicTransformerBlock1D(nnx.Module):
 
   def __init__(
@@ -50,6 +34,7 @@ def __init__(
       dim_head: int,
       rope_type: str = "interleaved",
       attention_kernel: str = "flash",
+      mesh: jax.sharding.Mesh = None,
      rngs: nnx.Rngs = None,
   ):
     self.attn1 = LTX2Attention(
@@ -60,9 +45,10 @@ def __init__(
        bias=True,  # LTX-2 default
        out_bias=True,
        attention_kernel=attention_kernel,
+        mesh=mesh,
        rngs=rngs,
     )
-    self.ff = FeedForward(dim, dim_out=dim, rngs=rngs)
+    self.ff = NNXSimpleFeedForward(rngs=rngs, dim=dim, dim_out=dim)
     self.norm1 = nnx.RMSNorm(dim, epsilon=1e-6, dtype=jnp.float32, param_dtype=jnp.float32, use_scale=True, rngs=rngs)
     self.norm2 = nnx.RMSNorm(dim, epsilon=1e-6, dtype=jnp.float32, param_dtype=jnp.float32, use_scale=True, rngs=rngs)
 
@@ -101,6 +87,7 @@ def __init__(
      num_learnable_registers: int = 128,
      rope_type: str = "interleaved",
      attention_kernel: str = "flash",
+      mesh: jax.sharding.Mesh = None,
      rngs: nnx.Rngs = None,
   ):
     self.dim = input_dim
@@ -115,7 +102,13 @@ def __init__(
     @nnx.vmap(in_axes=0, out_axes=0, axis_size=layers)
     def create_block(rngs):
       return _BasicTransformerBlock1D(
-          dim=input_dim, heads=heads, dim_head=head_dim, rope_type=rope_type, attention_kernel=attention_kernel, rngs=rngs
+          dim=input_dim,
+          heads=heads,
+          dim_head=head_dim,
+          rope_type=rope_type,
+          attention_kernel=attention_kernel,
+          mesh=mesh,
+          rngs=rngs,
       )
 
     # Call the vmapped constructor
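
The local FeedForward removed above was a plain two-layer GELU MLP, and the call site now uses NNXSimpleFeedForward(rngs=rngs, dim=dim, dim_out=dim) from maxdiffusion.models.attention_flax. A minimal sketch of the contract the replacement is assumed to keep, mirroring only the removed code (the class below is illustrative, not the library implementation):

```python
# Illustrative stand-in for the removed FeedForward / the assumed
# NNXSimpleFeedForward contract: Linear(dim -> dim * mult), GELU, Linear -> dim_out.
import jax
import jax.numpy as jnp
from flax import nnx


class FeedForwardSketch(nnx.Module):

  def __init__(self, dim: int, dim_out: int, mult: int = 4, rngs: nnx.Rngs = None):
    inner_dim = int(dim * mult)
    self.proj1 = nnx.Linear(dim, inner_dim, rngs=rngs)
    self.proj2 = nnx.Linear(inner_dim, dim_out, rngs=rngs)

  def __call__(self, x: jax.Array) -> jax.Array:
    # projection -> GELU -> projection, shape-preserving when dim_out == dim
    return self.proj2(jax.nn.gelu(self.proj1(x)))


ff = FeedForwardSketch(dim=64, dim_out=64, rngs=nnx.Rngs(0))
y = ff(jnp.ones((2, 16, 64)))  # (batch, tokens, dim) -> (2, 16, 64)
```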

src/maxdiffusion/models/ltx2/text_encoders/text_encoders_ltx2.py
5 additions & 0 deletions
@@ -46,6 +46,7 @@ def __init__(
      num_thinking_tokens: int = 128,
      dtype: DType = jnp.float32,
      attention_kernel: str = "flash",
+      mesh: jax.sharding.Mesh = None,
      rngs: nnx.Rngs = None,
   ):
     input_dim = gemma_dim * gemma_layers
@@ -65,6 +66,7 @@ def __init__(
        num_learnable_registers=num_thinking_tokens,
        rope_type="interleaved",
        attention_kernel=attention_kernel,
+        mesh=mesh,
        rngs=rngs,
     )
 
@@ -106,6 +108,7 @@ def __init__(
      num_thinking_tokens: int = 128,
      dtype: DType = jnp.float32,
      attention_kernel: str = "flash",
+      mesh: jax.sharding.Mesh = None,
      rngs: nnx.Rngs = None,
   ):
     input_dim = gemma_dim * gemma_layers
@@ -126,6 +129,7 @@ def __init__(
        num_learnable_registers=num_thinking_tokens,
        rope_type="interleaved",
        attention_kernel=attention_kernel,
+        mesh=mesh,
        rngs=rngs,
     )
 
@@ -137,6 +141,7 @@ def __init__(
        num_learnable_registers=num_thinking_tokens,
        rope_type="interleaved",
        attention_kernel=attention_kernel,
+        mesh=mesh,
        rngs=rngs,
     )
 
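
Each encoder constructor in this file now accepts an optional mesh and forwards it to the connector (and from there down to LTX2Attention). A hedged sketch of building such a mesh on the caller side; the axis name and the commented-out constructor call are illustrative, only the mesh keyword itself comes from the diff:

```python
# Caller-side sketch: build a one-axis device mesh and pass it via the new
# `mesh` argument. The "data" axis name is a placeholder, not a maxdiffusion
# config value.
import numpy as np
import jax
from jax.sharding import Mesh

devices = np.array(jax.devices())  # all local TPU/GPU/CPU devices
mesh = Mesh(devices.reshape(-1), axis_names=("data",))

# Hypothetical usage; the encoder class name is not shown in this diff:
# encoder = SomeLTX2TextEncoder(..., attention_kernel="flash", mesh=mesh, rngs=nnx.Rngs(0))
# Passing mesh=None (as the tests do) keeps the previous unsharded behaviour.
```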

src/maxdiffusion/tests/test_embeddings_connector_ltx2.py
2 additions & 0 deletions
@@ -44,6 +44,7 @@ def test_thinking_tokens_replacement(self):
        head_dim=self.head_dim,
        layers=1,
        num_learnable_registers=self.num_learnable_registers,
+        mesh=None,
        rngs=self.rng,
     )
 
@@ -96,6 +97,7 @@ def test_forward_shape_and_run(self):
        layers=2,
        num_learnable_registers=self.num_learnable_registers,
        attention_kernel="dot_product",  # Use dot_product for testing on CPU
+        mesh=None,
        rngs=self.rng,
     )
 

src/maxdiffusion/tests/test_text_encoders_ltx2.py
2 additions & 0 deletions
@@ -47,6 +47,7 @@ def test_video_encoder_forward(self):
        connector_layers=1,
        num_thinking_tokens=8,
        attention_kernel="dot_product",
+        mesh=None,
        rngs=self.rng,
     )
 
@@ -66,6 +67,7 @@ def test_av_encoder_forward(self):
        connector_layers=1,
        num_thinking_tokens=8,
        attention_kernel="dot_product",
+        mesh=None,
        rngs=self.rng,
     )
 
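
The test constructions only gain mesh=None, so the existing CPU runs (with attention_kernel="dot_product") behave as before. One way to exercise the touched tests, assuming they are run with plain pytest; the file paths come from the diff:

```python
# Run the two updated test modules; the pytest.main call is just one way to
# invoke them and assumes pytest is installed in the environment.
import pytest

pytest.main([
    "src/maxdiffusion/tests/test_embeddings_connector_ltx2.py",
    "src/maxdiffusion/tests/test_text_encoders_ltx2.py",
    "-q",
])
```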
