import jax
import jax.numpy as jnp
from flax import nnx

from maxdiffusion import common_types
from maxdiffusion.models.attention_flax import NNXSimpleFeedForward
from maxdiffusion.models.ltx2.attention_ltx2 import LTX2Attention
# Module-local shorthand for the shared type aliases defined in common_types.
Array = common_types.Array
DType = common_types.DType
2627
27-
class FeedForward(nnx.Module):
  """Position-wise two-layer MLP: Linear(dim -> dim * mult) -> GELU -> Linear(-> dim_out).

  NOTE(review): the `dropout` argument is accepted but never applied anywhere in
  this block — flagged here rather than silently changed, to preserve parity
  with the original implementation (default 0.0 makes this a no-op anyway).
  """

  def __init__(self, dim: int, dim_out: int | None = None, mult: int = 4, dropout: float = 0.0, rngs: nnx.Rngs = None):
    # `int | None` replaces the original `Optional[int]`: `typing.Optional` was
    # never imported in this file, so the original annotation raised NameError
    # at class-definition time.
    inner_dim = int(dim * mult)  # hidden width of the MLP
    dim_out = dim_out if dim_out is not None else dim  # default: preserve input width

    self.proj1 = nnx.Linear(dim, inner_dim, rngs=rngs)
    self.proj2 = nnx.Linear(inner_dim, dim_out, rngs=rngs)

  def __call__(self, x: Array) -> Array:
    """Apply the MLP to `x`; last axis must equal `dim`, output last axis is `dim_out`."""
    x = self.proj1(x)
    # Requires the top-level `jax` name: `import jax.numpy as jnp` alone does
    # not bind `jax`, so `jax.nn.gelu` needed the import added at file top.
    x = jax.nn.gelu(x)
    x = self.proj2(x)
    return x
42-
43-
4428class _BasicTransformerBlock1D (nnx .Module ):
4529
4630 def __init__ (
@@ -50,6 +34,7 @@ def __init__(
5034 dim_head : int ,
5135 rope_type : str = "interleaved" ,
5236 attention_kernel : str = "flash" ,
37+ mesh : jax .sharding .Mesh = None ,
5338 rngs : nnx .Rngs = None ,
5439 ):
5540 self .attn1 = LTX2Attention (
@@ -60,9 +45,10 @@ def __init__(
6045 bias = True , # LTX-2 default
6146 out_bias = True ,
6247 attention_kernel = attention_kernel ,
48+ mesh = mesh ,
6349 rngs = rngs ,
6450 )
65- self .ff = FeedForward ( dim , dim_out = dim , rngs = rngs )
51+ self .ff = NNXSimpleFeedForward ( rngs = rngs , dim = dim , dim_out = dim )
6652 self .norm1 = nnx .RMSNorm (dim , epsilon = 1e-6 , dtype = jnp .float32 , param_dtype = jnp .float32 , use_scale = True , rngs = rngs )
6753 self .norm2 = nnx .RMSNorm (dim , epsilon = 1e-6 , dtype = jnp .float32 , param_dtype = jnp .float32 , use_scale = True , rngs = rngs )
6854
@@ -101,6 +87,7 @@ def __init__(
10187 num_learnable_registers : int = 128 ,
10288 rope_type : str = "interleaved" ,
10389 attention_kernel : str = "flash" ,
90+ mesh : jax .sharding .Mesh = None ,
10491 rngs : nnx .Rngs = None ,
10592 ):
10693 self .dim = input_dim
@@ -115,7 +102,13 @@ def __init__(
115102 @nnx .vmap (in_axes = 0 , out_axes = 0 , axis_size = layers )
116103 def create_block (rngs ):
117104 return _BasicTransformerBlock1D (
118- dim = input_dim , heads = heads , dim_head = head_dim , rope_type = rope_type , attention_kernel = attention_kernel , rngs = rngs
105+ dim = input_dim ,
106+ heads = heads ,
107+ dim_head = head_dim ,
108+ rope_type = rope_type ,
109+ attention_kernel = attention_kernel ,
110+ mesh = mesh ,
111+ rngs = rngs ,
119112 )
120113
121114 # Call the vmapped constructor
# (end of excerpt — "0 commit comments" was GitHub page chrome, not source code)