Skip to content

Commit b148d5a

Browse files
committed
annotations for VAE
1 parent a325a13 commit b148d5a

1 file changed

Lines changed: 8 additions & 6 deletions

File tree

src/maxdiffusion/models/ltx2/autoencoder_kl_ltx2.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1595,15 +1595,17 @@ def decode(
15951595
keys_slice = jax.random.split(key, latents.shape[0])
15961596
decoded_slices = []
15971597
for i in range(latents.shape[0]):
1598-
z_slice = latents[i : i + 1]
1599-
t_slice = temb[i : i + 1] if temb is not None else None
1600-
subkey = keys_slice[i] if keys_slice is not None else None
1601-
res = self._decode(z_slice, t_slice, key=subkey, causal=causal, return_dict=True)
1602-
decoded_slices.append(res.sample)
1598+
with jax.named_scope(f"Decode Slice {i}"):
1599+
z_slice = latents[i : i + 1]
1600+
t_slice = temb[i : i + 1] if temb is not None else None
1601+
subkey = keys_slice[i] if keys_slice is not None else None
1602+
res = self._decode(z_slice, t_slice, key=subkey, causal=causal, return_dict=True)
1603+
decoded_slices.append(res.sample)
16031604

16041605
dec = jnp.concatenate(decoded_slices, axis=0)
16051606
else:
1606-
dec = self._decode(latents, temb, key=key, causal=causal, return_dict=True).sample
1607+
with jax.named_scope("Decode Full Batch"):
1608+
dec = self._decode(latents, temb, key=key, causal=causal, return_dict=True).sample
16071609

16081610
if not return_dict:
16091611
return (dec,)

0 commit comments

Comments
 (0)