@@ -892,31 +892,36 @@ def run_inference(
892892 return latents , scheduler_state
893893
894894
895- def adain_filter_latent (latents : torch .Tensor , reference_latents : torch .Tensor , factor = 1.0 ):
896- """
897- Applies Adaptive Instance Normalization (AdaIN) to a latent tensor based on
898- statistics from a reference latent tensor.
899-
900- Args:
901- latent (torch.Tensor): Input latents to normalize
902- reference_latent (torch.Tensor): The reference latents providing style statistics.
903- factor (float): Blending factor between original and transformed latent.
904- Range: -10.0 to 10.0, Default: 1.0
905-
906- Returns:
907- torch.Tensor: The transformed latent tensor
908- """
909- result = latents .clone ()
910-
911- for i in range (latents .size (0 )):
912- for c in range (latents .size (1 )):
913- r_sd , r_mean = torch .std_mean (reference_latents [i , c ], dim = None ) # index by original dim order
914- i_sd , i_mean = torch .std_mean (result [i , c ], dim = None )
915-
916- result [i , c ] = ((result [i , c ] - i_mean ) / i_sd ) * r_sd + r_mean
917-
918- result = torch .lerp (latents , result , factor )
919- return result
895+ def adain_filter_latent (
896+ latents : torch .Tensor , reference_latents : torch .Tensor , factor = 1.0
897+ ):
898+ """
899+ Applies Adaptive Instance Normalization (AdaIN) to a latent tensor based on
900+ statistics from a reference latent tensor.
901+
902+ Args:
903+ latent (torch.Tensor): Input latents to normalize
904+ reference_latent (torch.Tensor): The reference latents providing style statistics.
905+ factor (float): Blending factor between original and transformed latent.
906+ Range: -10.0 to 10.0, Default: 1.0
907+
908+ Returns:
909+ torch.Tensor: The transformed latent tensor
910+ """
911+ with default_env ():
912+ result = latents .clone ()
913+
914+ for i in range (latents .size (0 )):
915+ for c in range (latents .size (1 )):
916+ r_sd , r_mean = torch .std_mean (
917+ reference_latents [i , c ], dim = None
918+ ) # index by original dim order
919+ i_sd , i_mean = torch .std_mean (result [i , c ], dim = None )
920+
921+ result [i , c ] = ((result [i , c ] - i_mean ) / i_sd ) * r_sd + r_mean
922+
923+ result = torch .lerp (latents , result , factor )
924+ return result
920925
921926
922927class LTXMultiScalePipeline :
0 commit comments