We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 77bd6dd · commit 7ff54d4 (Copy full SHA for 7ff54d4)
2 files changed
src/maxdiffusion/models/ltx2/logical_sharding_ltx2.py
@@ -68,6 +68,12 @@ class TextConnectorShardingSpecs:
68
net_0_bias: tuple = ("mlp",)
69
net_2_kernel: tuple = ("mlp", None)
70
net_2_bias: tuple = (None,)
71
+ # Attention specs (defaulting to trillium/safe defaults)
72
+ qkv_kernel: tuple = ("embed", "heads")
73
+ out_kernel: tuple = ("heads", "embed")
74
+ out_bias: tuple = ("embed",)
75
+ qkv_bias: tuple = ("heads",)
76
+ norm_scale: tuple = ("norm",)
77
78
79
@dataclass
@@ -95,7 +101,11 @@ class VAEShardingSpecs:
95
101
use_batched_text_encoder=True,
96
102
text_encoder_kernel=(None, "embed"),
97
103
),
98
- "text_connector": TextConnectorShardingSpecs(),
104
+ "text_connector": TextConnectorShardingSpecs(
105
+ qkv_kernel=(None, "heads"),
106
+ out_kernel=("heads", None),
107
+ out_bias=(None,),
108
+ ),
99
109
"vae": VAEShardingSpecs(vae_conv_kernel=("batch", None, None, None)),
100
110
},
111
"trillium": {
src/maxdiffusion/models/ltx2/text_encoders/embeddings_connector_ltx2.py
@@ -49,6 +49,7 @@ def __init__(
49
attention_kernel=attention_kernel,
50
mesh=mesh,
51
rngs=rngs,
52
+ sharding_specs=sharding_specs,
53
)
54
self.ff = NNXSimpleFeedForward(rngs=rngs, dim=dim, dim_out=dim, activation_fn="gelu_tanh", sharding_specs=sharding_specs)
55
self.norm1 = nnx.RMSNorm(dim, epsilon=1e-6, dtype=jnp.float32, param_dtype=jnp.float32, use_scale=False, rngs=rngs)
0 commit comments