
Commit 75db696

feat(ltx2): pass sharding specs to VAE and embeddings connector
1 parent: e4e281c

2 files changed

Lines changed: 17 additions & 3 deletions


src/maxdiffusion/models/ltx2/autoencoder_kl_ltx2.py

Lines changed: 11 additions & 1 deletion
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Tuple, Union, Optional, Sequence
+from typing import Tuple, Union, Optional, Sequence, Any
 
 import jax
 import jax.numpy as jnp
@@ -584,6 +584,7 @@ def __init__(
       dtype: jnp.dtype = jnp.float32,
       weights_dtype: jnp.dtype = jnp.float32,
       precision: jax.lax.Precision = None,
+      sharding_specs: Optional[Any] = None,
   ):
     if timestep_conditioning:
       self.time_embedder = nnx.data(
@@ -594,6 +595,7 @@ def __init__(
               use_additional_conditions=False,
               dtype=dtype,
               weights_dtype=weights_dtype,
+              sharding_specs=sharding_specs,
           )
       )
     else:
@@ -674,6 +676,7 @@ def __init__(
       dtype: jnp.dtype = jnp.float32,
       weights_dtype: jnp.dtype = jnp.float32,
       precision: jax.lax.Precision = None,
+      sharding_specs: Optional[Any] = None,
   ):
     out_channels = out_channels or in_channels
 
@@ -687,6 +690,7 @@ def __init__(
               use_additional_conditions=False,
               dtype=dtype,
               weights_dtype=weights_dtype,
+              sharding_specs=sharding_specs,
           )
       )
 
@@ -960,6 +964,7 @@ def __init__(
       dtype: jnp.dtype = jnp.float32,
       weights_dtype: jnp.dtype = jnp.float32,
       precision: jax.lax.Precision = None,
+      sharding_specs: Optional[Any] = None,
   ):
     self.patch_size = patch_size
     self.patch_size_t = patch_size_t
@@ -999,6 +1004,7 @@ def __init__(
         dtype=dtype,
         weights_dtype=weights_dtype,
         precision=precision,
+        sharding_specs=sharding_specs,
     )
 
     # up blocks
@@ -1026,6 +1032,7 @@ def __init__(
             dtype=dtype,
             weights_dtype=weights_dtype,
             precision=precision,
+            sharding_specs=sharding_specs,
         )
     )
 
@@ -1059,6 +1066,7 @@ def __init__(
               use_additional_conditions=False,
               dtype=dtype,
               weights_dtype=weights_dtype,
+              sharding_specs=sharding_specs,
           )
       )
     else:
@@ -1155,6 +1163,7 @@ def __init__(
       dtype: jnp.dtype = jnp.float32,
       weights_dtype: jnp.dtype = jnp.float32,
       precision: jax.lax.Precision = None,
+      sharding_specs: Optional[Any] = None,
   ):
     self.encoder = LTX2VideoEncoder3d(
         in_channels=in_channels,
@@ -1196,6 +1205,7 @@ def __init__(
         dtype=dtype,
         weights_dtype=weights_dtype,
         precision=precision,
+        sharding_specs=sharding_specs,
     )
 
     self.scaling_factor = scaling_factor
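
Every hunk in this file is the same plumbing move: a constructor grows an optional sharding_specs argument and forwards it verbatim to its children, so only the leaf modules that actually create parameters ever interpret it. Below is a minimal sketch of that pattern in Flax NNX; it assumes sharding_specs is a dict mapping parameter names to jax.sharding.PartitionSpec, and ToyBlock, ToyDecoder, and the "block_kernel" key are illustrative stand-ins, not the actual maxdiffusion classes.

from typing import Any, Optional

from flax import nnx
from jax.sharding import PartitionSpec


class ToyBlock(nnx.Module):
  """Leaf module: the only place the spec is actually interpreted."""

  def __init__(self, dim: int, *, rngs: nnx.Rngs, sharding_specs: Optional[Any] = None):
    init = nnx.initializers.lecun_normal()
    if sharding_specs is not None:
      # Attach the PartitionSpec as metadata on the created Param so a
      # later sharding pass can place it on the mesh.
      init = nnx.with_partitioning(init, sharding_specs["block_kernel"])
    self.linear = nnx.Linear(dim, dim, kernel_init=init, rngs=rngs)


class ToyDecoder(nnx.Module):
  """Parent module: forwards the spec unchanged, like the VAE constructors above."""

  def __init__(self, dim: int, *, rngs: nnx.Rngs, sharding_specs: Optional[Any] = None):
    self.block = ToyBlock(dim, rngs=rngs, sharding_specs=sharding_specs)


specs = {"block_kernel": PartitionSpec(None, "model")}
decoder = ToyDecoder(128, rngs=nnx.Rngs(0), sharding_specs=specs)

Typing the argument as Optional[Any] with a None default keeps existing call sites working unsharded; the concrete structure of the spec remains a private contract between the top-level caller and whichever leaf modules consume it.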

src/maxdiffusion/models/ltx2/text_encoders/embeddings_connector_ltx2.py

Lines changed: 6 additions & 2 deletions
@@ -14,7 +14,7 @@
 limitations under the License.
 """
 
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Any
 import jax
 import jax.numpy as jnp
 from flax import nnx
@@ -37,6 +37,7 @@ def __init__(
       attention_kernel: str = "flash",
       mesh: jax.sharding.Mesh = None,
       rngs: nnx.Rngs = None,
+      sharding_specs: Optional[Any] = None,
   ):
     self.attn1 = LTX2Attention(
         query_dim=dim,
@@ -48,8 +49,9 @@ def __init__(
         attention_kernel=attention_kernel,
         mesh=mesh,
         rngs=rngs,
+        sharding_specs=sharding_specs,
     )
-    self.ff = NNXSimpleFeedForward(rngs=rngs, dim=dim, dim_out=dim, activation_fn="gelu_tanh")
+    self.ff = NNXSimpleFeedForward(rngs=rngs, dim=dim, dim_out=dim, activation_fn="gelu_tanh", sharding_specs=sharding_specs)
     self.norm1 = nnx.RMSNorm(dim, epsilon=1e-6, dtype=jnp.float32, param_dtype=jnp.float32, use_scale=False, rngs=rngs)
     self.norm2 = nnx.RMSNorm(dim, epsilon=1e-6, dtype=jnp.float32, param_dtype=jnp.float32, use_scale=False, rngs=rngs)
 
@@ -92,6 +94,7 @@ def __init__(
      attention_kernel: str = "flash",
      mesh: jax.sharding.Mesh = None,
      rngs: nnx.Rngs = None,
+     sharding_specs: Optional[Any] = None,
   ):
     self.dim = input_dim
     self.heads = heads
@@ -117,6 +120,7 @@ def create_block(rngs):
         attention_kernel=attention_kernel,
         mesh=mesh,
         rngs=rngs,
+        sharding_specs=sharding_specs,
     )
 
     # Call the vmapped constructor
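
Here sharding_specs is also threaded into the create_block factory, so every block built by the vmapped constructor shares a single set of annotations. A rough sketch of that stacked-construction idea, assuming recent Flax NNX lifted transforms; Block, NUM_BLOCKS, and the "ff_kernel" key are hypothetical, not the connector's real names.

from typing import Any, Optional

from flax import nnx
from jax.sharding import PartitionSpec

NUM_BLOCKS = 4


class Block(nnx.Module):

  def __init__(self, dim: int, *, rngs: nnx.Rngs, sharding_specs: Optional[Any] = None):
    init = nnx.initializers.lecun_normal()
    if sharding_specs is not None:
      # Same leaf-level pattern as in the VAE sketch above.
      init = nnx.with_partitioning(init, sharding_specs["ff_kernel"])
    self.ff = nnx.Linear(dim, dim, kernel_init=init, rngs=rngs)


specs = {"ff_kernel": PartitionSpec(None, "model")}


@nnx.split_rngs(splits=NUM_BLOCKS)
@nnx.vmap(in_axes=(0,), out_axes=0)
def create_block(rngs: nnx.Rngs):
  # Only the RNGs vary along the vmapped axis; sharding_specs is closed over.
  return Block(256, rngs=rngs, sharding_specs=specs)


blocks = create_block(nnx.Rngs(0))  # params gain a leading NUM_BLOCKS axis

Because the specs are closed over rather than mapped, each stacked copy of a parameter carries the same PartitionSpec along the new leading axis.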
