@@ -562,6 +562,7 @@ def __init__(
562562 self .conv_out = nnx .Conv (block_in , z_channels , kernel_size = (3 , 3 ), padding = "SAME" , dtype = dtype , rngs = rngs )
563563
564564 def __call__ (self , x , train : bool = False ):
565+ print (f"[LTX2 XPROF Tracing] Audio Encoder __call__ input shape: { x .shape } " )
565566 h = self .conv_in (x )
566567
567568 for stage in self .down_stages :
@@ -702,6 +703,7 @@ def __init__(
702703 self .conv_out = nnx .Conv (block_in , self .output_channels , kernel_size = (3 , 3 ), padding = "SAME" , dtype = dtype , rngs = rngs )
703704
704705 def __call__ (self , z , target_frames = None , target_mel_bins = None , train : bool = False ):
706+ print (f"[LTX2 XPROF Tracing] Audio Decoder __call__ input shape: { z .shape } " )
705707 h = self .conv_in (z )
706708
707709 h = self .mid_block1 (h , train = train )
@@ -825,6 +827,7 @@ def __init__(
825827 self .latents_std = nnx .Param (jnp .ones ((base_channels ,), dtype = dtype ))
826828
827829 def encode (self , x : jnp .ndarray , return_dict : bool = True , train : bool = False ):
830+ print (f"[LTX2 XPROF Tracing] Audio VAE encode input shape: { x .shape } " )
828831 h = self .encoder (x , train = train )
829832 posterior = FlaxDiagonalGaussianDistribution (h )
830833
@@ -833,6 +836,7 @@ def encode(self, x: jnp.ndarray, return_dict: bool = True, train: bool = False):
833836 return FlaxAutoencoderKLOutput (latent_dist = posterior )
834837
835838 def decode (self , z : jnp .ndarray , return_dict : bool = True , train : bool = False ):
839+ print (f"[LTX2 XPROF Tracing] Audio VAE decode input shape: { z .shape } " )
836840 batch , time , freq , channels = z .shape
837841 target_frames = time * self .latent_downsample_factor
838842 if self .causality_axis is not None and self .causality_axis != "none" :
0 commit comments