@@ -1595,15 +1595,17 @@ def decode(
15951595 keys_slice = jax .random .split (key , latents .shape [0 ])
15961596 decoded_slices = []
15971597 for i in range (latents .shape [0 ]):
1598- z_slice = latents [i : i + 1 ]
1599- t_slice = temb [i : i + 1 ] if temb is not None else None
1600- subkey = keys_slice [i ] if keys_slice is not None else None
1601- res = self ._decode (z_slice , t_slice , key = subkey , causal = causal , return_dict = True )
1602- decoded_slices .append (res .sample )
1598+ with jax .named_scope (f"Decode Slice { i } " ):
1599+ z_slice = latents [i : i + 1 ]
1600+ t_slice = temb [i : i + 1 ] if temb is not None else None
1601+ subkey = keys_slice [i ] if keys_slice is not None else None
1602+ res = self ._decode (z_slice , t_slice , key = subkey , causal = causal , return_dict = True )
1603+ decoded_slices .append (res .sample )
16031604
16041605 dec = jnp .concatenate (decoded_slices , axis = 0 )
16051606 else :
1606- dec = self ._decode (latents , temb , key = key , causal = causal , return_dict = True ).sample
1607+ with jax .named_scope ("Decode Full Batch" ):
1608+ dec = self ._decode (latents , temb , key = key , causal = causal , return_dict = True ).sample
16071609
16081610 if not return_dict :
16091611 return (dec ,)
0 commit comments