Skip to content

Commit 9b5ed92

Browse files
committed
embeddings connector changes
1 parent 3f2320c commit 9b5ed92

1 file changed

Lines changed: 4 additions & 0 deletions

File tree

src/maxdiffusion/models/ltx2/text_encoders/embeddings_connector_ltx2.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def __init__(
     attention_kernel: str = "flash",
     mesh: jax.sharding.Mesh = None,
     rngs: nnx.Rngs = None,
+    gated_attn: bool = False,
 ):
     self.attn1 = LTX2Attention(
       query_dim=dim,
@@ -48,6 +49,7 @@ def __init__(
       attention_kernel=attention_kernel,
       mesh=mesh,
       rngs=rngs,
+      gated_attn=gated_attn,
     )
     self.ff = NNXSimpleFeedForward(rngs=rngs, dim=dim, dim_out=dim, activation_fn="gelu_tanh")
     self.norm1 = nnx.RMSNorm(dim, epsilon=1e-6, dtype=jnp.float32, param_dtype=jnp.float32, use_scale=False, rngs=rngs)
@@ -92,6 +94,7 @@ def __init__(
     attention_kernel: str = "flash",
     mesh: jax.sharding.Mesh = None,
     rngs: nnx.Rngs = None,
+    gated_attn: bool = False,
 ):
     self.dim = input_dim
     self.heads = heads
@@ -117,6 +120,7 @@ def create_block(rngs):
       attention_kernel=attention_kernel,
       mesh=mesh,
       rngs=rngs,
+      gated_attn=gated_attn,
     )

   # Call the vmapped constructor

0 commit comments

Comments (0)