Skip to content

Commit 4d44e21

Browse files
committed
Added dtype arg to jax.random.normal, refactored blend_v, blend_h, blend_t methods, replaced 1/jnp.sqrt with jax.lax.rsqrt
1 parent fa69f87 commit 4d44e21

1 file changed

Lines changed: 32 additions & 15 deletions

File tree

src/maxdiffusion/models/ltx2/autoencoder_kl_ltx2.py

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,7 @@ def __call__(self, x: jax.Array, channel_dim: Optional[int] = None) -> jax.Array
5151
channel_dim = channel_dim if channel_dim is not None else self.channel_dim
5252
# Compute mean of squared values along channel dimension.
5353
mean_sq = jnp.mean(jnp.square(x), axis=channel_dim, keepdims=True)
54-
rms = jnp.sqrt(mean_sq + self.eps)
55-
return x / rms
54+
return x * jax.lax.rsqrt(mean_sq + self.eps)
5655

5756

5857
class LTX2VideoCausalConv3d(nnx.Module):
@@ -226,7 +225,7 @@ def __init__(
226225
self.per_channel_scale2 = None
227226

228227
if timestep_conditioning:
229-
self.scale_shift_table = nnx.Param(jax.random.normal(rngs.params(), (4, in_channels)) / (in_channels**0.5))
228+
self.scale_shift_table = nnx.Param(jax.random.normal(rngs.params(), (4, in_channels), dtype=dtype) / (in_channels**0.5))
230229
else:
231230
self.scale_shift_table = None
232231

@@ -1261,24 +1260,42 @@ def enable_tiling(
12611260

12621261
def blend_v(self, a: jax.Array, b: jax.Array, blend_extent: int) -> jax.Array:
12631262
blend_extent = min(a.shape[2], b.shape[2], blend_extent)
1264-
for y in range(blend_extent):
1265-
val = a[:, :, -blend_extent + y, :, :] * (1 - y / blend_extent) + b[:, :, y, :, :] * (y / blend_extent)
1266-
b = b.at[:, :, y, :, :].set(val)
1267-
return b
1263+
if blend_extent <= 0:
1264+
return b
1265+
1266+
# Create broadcastable blending weights: (1, 1, blend_extent, 1, 1)
1267+
y = jnp.arange(blend_extent, dtype=a.dtype).reshape(1, 1, -1, 1, 1)
1268+
1269+
val = a[:, :, -blend_extent:, :, :] * (1.0 - y / blend_extent) + \
1270+
b[:, :, :blend_extent, :, :] * (y / blend_extent)
1271+
1272+
return b.at[:, :, :blend_extent, :, :].set(val)
12681273

12691274
def blend_h(self, a: jax.Array, b: jax.Array, blend_extent: int) -> jax.Array:
12701275
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
1271-
for x in range(blend_extent):
1272-
val = a[:, :, :, -blend_extent + x, :] * (1 - x / blend_extent) + b[:, :, :, x, :] * (x / blend_extent)
1273-
b = b.at[:, :, :, x, :].set(val)
1274-
return b
1276+
if blend_extent <= 0:
1277+
return b
1278+
1279+
# Create broadcastable blending weights: (1, 1, 1, blend_extent, 1)
1280+
x = jnp.arange(blend_extent, dtype=a.dtype).reshape(1, 1, 1, -1, 1)
1281+
1282+
val = a[:, :, :, -blend_extent:, :] * (1.0 - x / blend_extent) + \
1283+
b[:, :, :, :blend_extent, :] * (x / blend_extent)
1284+
1285+
return b.at[:, :, :, :blend_extent, :].set(val)
12751286

12761287
def blend_t(self, a: jax.Array, b: jax.Array, blend_extent: int) -> jax.Array:
12771288
blend_extent = min(a.shape[1], b.shape[1], blend_extent)
1278-
for x in range(blend_extent):
1279-
val = a[:, -blend_extent + x, :, :, :] * (1 - x / blend_extent) + b[:, x, :, :, :] * (x / blend_extent)
1280-
b = b.at[:, x, :, :, :].set(val)
1281-
return b
1289+
if blend_extent <= 0:
1290+
return b
1291+
1292+
# Create broadcastable blending weights: (1, blend_extent, 1, 1, 1)
1293+
x = jnp.arange(blend_extent, dtype=a.dtype).reshape(1, -1, 1, 1, 1)
1294+
1295+
val = a[:, -blend_extent:, :, :, :] * (1.0 - x / blend_extent) + \
1296+
b[:, :blend_extent, :, :, :] * (x / blend_extent)
1297+
1298+
return b.at[:, :blend_extent, :, :, :].set(val)
12821299

12831300
def tiled_encode(self, x: jax.Array, key: Optional[jax.Array] = None, causal: Optional[bool] = None) -> jax.Array:
12841301
B, T, H, W, C = x.shape

0 commit comments

Comments
 (0)