@@ -655,14 +655,16 @@ def __init__(
655655 ):
656656 out_channels = out_channels or in_channels
657657
658- self .time_embedder = nnx .data (NNXPixArtAlphaCombinedTimestepSizeEmbeddings (
659- rngs = rngs ,
660- embedding_dim = in_channels * 4 ,
661- size_emb_dim = 0 ,
662- use_additional_conditions = False ,
663- dtype = dtype ,
664- weights_dtype = weights_dtype
665- ))
658+ self .time_embedder = None
659+ if timestep_conditioning :
660+ self .time_embedder = nnx .data (NNXPixArtAlphaCombinedTimestepSizeEmbeddings (
661+ rngs = rngs ,
662+ embedding_dim = in_channels * 4 ,
663+ size_emb_dim = 0 ,
664+ use_additional_conditions = False ,
665+ dtype = dtype ,
666+ weights_dtype = weights_dtype
667+ ))
666668
667669 if in_channels != out_channels :
668670 self .conv_in = nnx .data (LTX2VideoResnetBlock3d (
@@ -1068,6 +1070,60 @@ def __call__(
10681070 return hidden_states
10691071
10701072
1073+
class LTX2DiagonalGaussianDistribution(nnx.Module):
    """Diagonal Gaussian built from packed ``[mean | logvar]`` moments.

    The last axis of ``parameters`` is split at index ``latent_channels``:
    the leading channels become the mean and the remaining channels the
    log-variance. The original LTX-2 note describes the layout as 128 mean
    channels followed by 1 log-variance channel; in that layout the
    log-variance broadcasts against the mean in every elementwise
    expression below.

    Args:
        parameters: Packed moments with channels on the last axis.
        deterministic: When True, ``std`` and ``var`` are set to zeros shaped
            like the mean, so ``sample`` returns the mean and ``kl`` / ``nll``
            return ``[0.0]``.
        latent_channels: Number of leading channels that hold the mean.
            Defaults to 128, the previously hard-coded split point.
    """

    def __init__(
        self,
        parameters: jax.Array,
        deterministic: bool = False,
        latent_channels: int = 128,
    ):
        self.parameters = parameters
        # Everything before `latent_channels` on the last axis is the mean,
        # everything after it is the log-variance.
        mean, logvar = jnp.split(parameters, [latent_channels], axis=-1)
        self.mean = mean
        # Clip the log-variance to [-30, 20] before it is exponentiated.
        self.logvar = jnp.clip(logvar, -30.0, 20.0)
        self.deterministic = deterministic
        if deterministic:
            # Zero spread: std and var share one zeros array shaped like the mean.
            zeros = jnp.zeros_like(self.mean, dtype=parameters.dtype)
            self.std = zeros
            self.var = zeros
        else:
            self.std = jnp.exp(0.5 * self.logvar)
            self.var = jnp.exp(self.logvar)

    def sample(self, key: jax.Array) -> jax.Array:
        """Draw ``mean + std * eps`` with ``eps`` standard normal.

        Args:
            key: JAX PRNG key used to generate the noise.

        Returns:
            An array with the shape of ``mean``.
        """
        # Noise is generated in the dtype of the packed parameters.
        noise = jax.random.normal(key, self.mean.shape, dtype=self.parameters.dtype)
        return self.mean + self.std * noise

    def kl(
        self,
        other: "LTX2DiagonalGaussianDistribution" = None,
        dims: Tuple[int, ...] = (1, 2, 3),
    ) -> jax.Array:
        """Return the KL divergence from this distribution to ``other``.

        Args:
            other: Distribution to compare against. When None, the reference
                is the standard normal.
            dims: Axes summed over. Defaults to ``(1, 2, 3)``, the axes that
                were previously hard-coded.
                NOTE(review): for inputs with more than four axes the default
                leaves trailing axes unreduced - confirm against callers.

        Returns:
            ``[0.0]`` when deterministic, otherwise the divergence summed
            over ``dims``.
        """
        if self.deterministic:
            return jnp.array([0.0])
        if other is None:
            return 0.5 * jnp.sum(
                jnp.power(self.mean, 2) + self.var - 1.0 - self.logvar,
                axis=dims,
            )
        return 0.5 * jnp.sum(
            jnp.power(self.mean - other.mean, 2) / other.var
            + self.var / other.var
            - 1.0
            - self.logvar
            + other.logvar,
            axis=dims,
        )

    def nll(self, sample: jax.Array, dims: Tuple[int, ...] = (1, 2, 3)) -> jax.Array:
        """Return the Gaussian negative log-likelihood of ``sample``.

        Args:
            sample: Values to score; must broadcast against ``mean``.
            dims: Axes summed over.

        Returns:
            ``[0.0]`` when deterministic, otherwise the negative
            log-likelihood summed over ``dims``.
        """
        if self.deterministic:
            return jnp.array([0.0])
        logtwopi = jnp.log(2.0 * jnp.pi)
        return 0.5 * jnp.sum(
            logtwopi + self.logvar + jnp.power(sample - self.mean, 2) / self.var,
            axis=dims,
        )

    def mode(self) -> jax.Array:
        """Return the mode of the distribution, which is its mean."""
        return self.mean
1125+
1126+
10711127class LTX2VideoAutoencoderKL (nnx .Module , ConfigMixin ):
10721128 _supports_gradient_checkpointing = True
10731129 config_name = "config.json"
@@ -1510,7 +1566,8 @@ def encode(
15101566 else :
15111567 moments = self ._encode (sample , key = key , causal = causal )
15121568
1513- posterior = FlaxDiagonalGaussianDistribution (moments )
1569+
1570+ posterior = LTX2DiagonalGaussianDistribution (moments )
15141571
15151572 if not return_dict :
15161573 return (posterior ,)
0 commit comments