1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15- from typing import Tuple , Union , Optional , Sequence
15+ from typing import Tuple , Union , Optional , Sequence , Any
1616
1717import jax
1818import jax .numpy as jnp
@@ -584,6 +584,7 @@ def __init__(
584584 dtype : jnp .dtype = jnp .float32 ,
585585 weights_dtype : jnp .dtype = jnp .float32 ,
586586 precision : jax .lax .Precision = None ,
587+ sharding_specs : Optional [Any ] = None ,
587588 ):
588589 if timestep_conditioning :
589590 self .time_embedder = nnx .data (
@@ -594,6 +595,7 @@ def __init__(
594595 use_additional_conditions = False ,
595596 dtype = dtype ,
596597 weights_dtype = weights_dtype ,
598+ sharding_specs = sharding_specs ,
597599 )
598600 )
599601 else :
@@ -674,6 +676,7 @@ def __init__(
674676 dtype : jnp .dtype = jnp .float32 ,
675677 weights_dtype : jnp .dtype = jnp .float32 ,
676678 precision : jax .lax .Precision = None ,
679+ sharding_specs : Optional [Any ] = None ,
677680 ):
678681 out_channels = out_channels or in_channels
679682
@@ -687,6 +690,7 @@ def __init__(
687690 use_additional_conditions = False ,
688691 dtype = dtype ,
689692 weights_dtype = weights_dtype ,
693+ sharding_specs = sharding_specs ,
690694 )
691695 )
692696
@@ -960,6 +964,7 @@ def __init__(
960964 dtype : jnp .dtype = jnp .float32 ,
961965 weights_dtype : jnp .dtype = jnp .float32 ,
962966 precision : jax .lax .Precision = None ,
967+ sharding_specs : Optional [Any ] = None ,
963968 ):
964969 self .patch_size = patch_size
965970 self .patch_size_t = patch_size_t
@@ -999,6 +1004,7 @@ def __init__(
9991004 dtype = dtype ,
10001005 weights_dtype = weights_dtype ,
10011006 precision = precision ,
1007+ sharding_specs = sharding_specs ,
10021008 )
10031009
10041010 # up blocks
@@ -1026,6 +1032,7 @@ def __init__(
10261032 dtype = dtype ,
10271033 weights_dtype = weights_dtype ,
10281034 precision = precision ,
1035+ sharding_specs = sharding_specs ,
10291036 )
10301037 )
10311038
@@ -1059,6 +1066,7 @@ def __init__(
10591066 use_additional_conditions = False ,
10601067 dtype = dtype ,
10611068 weights_dtype = weights_dtype ,
1069+ sharding_specs = sharding_specs ,
10621070 )
10631071 )
10641072 else :
@@ -1155,6 +1163,7 @@ def __init__(
11551163 dtype : jnp .dtype = jnp .float32 ,
11561164 weights_dtype : jnp .dtype = jnp .float32 ,
11571165 precision : jax .lax .Precision = None ,
1166+ sharding_specs : Optional [Any ] = None ,
11581167 ):
11591168 self .encoder = LTX2VideoEncoder3d (
11601169 in_channels = in_channels ,
@@ -1196,6 +1205,7 @@ def __init__(
11961205 dtype = dtype ,
11971206 weights_dtype = weights_dtype ,
11981207 precision = precision ,
1208+ sharding_specs = sharding_specs ,
11991209 )
12001210
12011211 self .scaling_factor = scaling_factor
0 commit comments