Skip to content

Commit 7ff54d4

Browse files
committed
feat(ltx2): add attention specs to text connector and pass them
1 parent 77bd6dd commit 7ff54d4

2 files changed

Lines changed: 12 additions & 1 deletion

File tree

src/maxdiffusion/models/ltx2/logical_sharding_ltx2.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,12 @@ class TextConnectorShardingSpecs:
   net_0_bias: tuple = ("mlp",)
   net_2_kernel: tuple = ("mlp", None)
   net_2_bias: tuple = (None,)
+  # Attention specs (defaulting to trillium/safe defaults)
+  qkv_kernel: tuple = ("embed", "heads")
+  out_kernel: tuple = ("heads", "embed")
+  out_bias: tuple = ("embed",)
+  qkv_bias: tuple = ("heads",)
+  norm_scale: tuple = ("norm",)


 @dataclass
@@ -95,7 +101,11 @@ class VAEShardingSpecs:
         use_batched_text_encoder=True,
         text_encoder_kernel=(None, "embed"),
     ),
-    "text_connector": TextConnectorShardingSpecs(),
+    "text_connector": TextConnectorShardingSpecs(
+        qkv_kernel=(None, "heads"),
+        out_kernel=("heads", None),
+        out_bias=(None,),
+    ),
     "vae": VAEShardingSpecs(vae_conv_kernel=("batch", None, None, None)),
 },
 "trillium": {

src/maxdiffusion/models/ltx2/text_encoders/embeddings_connector_ltx2.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def __init__(
         attention_kernel=attention_kernel,
         mesh=mesh,
         rngs=rngs,
+        sharding_specs=sharding_specs,
     )
     self.ff = NNXSimpleFeedForward(rngs=rngs, dim=dim, dim_out=dim, activation_fn="gelu_tanh", sharding_specs=sharding_specs)
     self.norm1 = nnx.RMSNorm(dim, epsilon=1e-6, dtype=jnp.float32, param_dtype=jnp.float32, use_scale=False, rngs=rngs)

0 commit comments

Comments (0)