@@ -51,8 +51,7 @@ def __call__(self, x: jax.Array, channel_dim: Optional[int] = None) -> jax.Array
5151 channel_dim = channel_dim if channel_dim is not None else self .channel_dim
5252 # Compute mean of squared values along channel dimension.
5353 mean_sq = jnp .mean (jnp .square (x ), axis = channel_dim , keepdims = True )
54- rms = jnp .sqrt (mean_sq + self .eps )
55- return x / rms
54+ return x * jax .lax .rsqrt (mean_sq + self .eps )
5655
5756
5857class LTX2VideoCausalConv3d (nnx .Module ):
@@ -226,7 +225,7 @@ def __init__(
226225 self .per_channel_scale2 = None
227226
228227 if timestep_conditioning :
229- self .scale_shift_table = nnx .Param (jax .random .normal (rngs .params (), (4 , in_channels )) / (in_channels ** 0.5 ))
228+ self .scale_shift_table = nnx .Param (jax .random .normal (rngs .params (), (4 , in_channels ), dtype = dtype ) / (in_channels ** 0.5 ))
230229 else :
231230 self .scale_shift_table = None
232231
@@ -1261,24 +1260,42 @@ def enable_tiling(
12611260
12621261 def blend_v (self , a : jax .Array , b : jax .Array , blend_extent : int ) -> jax .Array :
12631262 blend_extent = min (a .shape [2 ], b .shape [2 ], blend_extent )
1264- for y in range (blend_extent ):
1265- val = a [:, :, - blend_extent + y , :, :] * (1 - y / blend_extent ) + b [:, :, y , :, :] * (y / blend_extent )
1266- b = b .at [:, :, y , :, :].set (val )
1267- return b
1263+ if blend_extent <= 0 :
1264+ return b
1265+
1266+ # Create broadcastable blending weights: (1, 1, blend_extent, 1, 1)
1267+ y = jnp .arange (blend_extent , dtype = a .dtype ).reshape (1 , 1 , - 1 , 1 , 1 )
1268+
1269+ val = a [:, :, - blend_extent :, :, :] * (1.0 - y / blend_extent ) + \
1270+ b [:, :, :blend_extent , :, :] * (y / blend_extent )
1271+
1272+ return b .at [:, :, :blend_extent , :, :].set (val )
12681273
12691274 def blend_h (self , a : jax .Array , b : jax .Array , blend_extent : int ) -> jax .Array :
12701275 blend_extent = min (a .shape [3 ], b .shape [3 ], blend_extent )
1271- for x in range (blend_extent ):
1272- val = a [:, :, :, - blend_extent + x , :] * (1 - x / blend_extent ) + b [:, :, :, x , :] * (x / blend_extent )
1273- b = b .at [:, :, :, x , :].set (val )
1274- return b
1276+ if blend_extent <= 0 :
1277+ return b
1278+
1279+ # Create broadcastable blending weights: (1, 1, 1, blend_extent, 1)
1280+ x = jnp .arange (blend_extent , dtype = a .dtype ).reshape (1 , 1 , 1 , - 1 , 1 )
1281+
1282+ val = a [:, :, :, - blend_extent :, :] * (1.0 - x / blend_extent ) + \
1283+ b [:, :, :, :blend_extent , :] * (x / blend_extent )
1284+
1285+ return b .at [:, :, :, :blend_extent , :].set (val )
12751286
12761287 def blend_t (self , a : jax .Array , b : jax .Array , blend_extent : int ) -> jax .Array :
12771288 blend_extent = min (a .shape [1 ], b .shape [1 ], blend_extent )
1278- for x in range (blend_extent ):
1279- val = a [:, - blend_extent + x , :, :, :] * (1 - x / blend_extent ) + b [:, x , :, :, :] * (x / blend_extent )
1280- b = b .at [:, x , :, :, :].set (val )
1281- return b
1289+ if blend_extent <= 0 :
1290+ return b
1291+
1292+ # Create broadcastable blending weights: (1, blend_extent, 1, 1, 1)
1293+ x = jnp .arange (blend_extent , dtype = a .dtype ).reshape (1 , - 1 , 1 , 1 , 1 )
1294+
1295+ val = a [:, - blend_extent :, :, :, :] * (1.0 - x / blend_extent ) + \
1296+ b [:, :blend_extent , :, :, :] * (x / blend_extent )
1297+
1298+ return b .at [:, :blend_extent , :, :, :].set (val )
12821299
12831300 def tiled_encode (self , x : jax .Array , key : Optional [jax .Array ] = None , causal : Optional [bool ] = None ) -> jax .Array :
12841301 B , T , H , W , C = x .shape
0 commit comments