Skip to content

Commit fefe18e

Browse files
committed
prepare latents
1 parent 443243d commit fefe18e

1 file changed

Lines changed: 30 additions & 25 deletions

File tree

src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py

Lines changed: 30 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -892,31 +892,36 @@ def run_inference(
892892
return latents, scheduler_state
893893

894894

895-
def adain_filter_latent(latents: torch.Tensor, reference_latents: torch.Tensor, factor=1.0):
    """
    Applies Adaptive Instance Normalization (AdaIN) to a latent tensor based on
    statistics from a reference latent tensor.

    Every ``[batch, channel]`` slice of ``latents`` is standardized with its own
    mean and standard deviation, then re-scaled and shifted to the mean and
    standard deviation of the matching slice of ``reference_latents``. The
    result is blended with the untouched input via ``torch.lerp``.

    Args:
        latents (torch.Tensor): Input latents to normalize. The first two
            dimensions are treated as batch and channel; statistics are taken
            over all remaining dimensions. The tensor is not modified.
        reference_latents (torch.Tensor): The reference latents providing style
            statistics. Needs at least as many batch entries and channels as
            ``latents``; its remaining dimensions may differ.
        factor (float): Blending factor between original and transformed latent.
            Range: -10.0 to 10.0, Default: 1.0

    Returns:
        torch.Tensor: The transformed latent tensor

    Raises:
        IndexError: If ``reference_latents`` has fewer batch entries or
            channels than ``latents``.
    """
    batch, channels = latents.size(0), latents.size(1)
    if reference_latents.size(0) < batch or reference_latents.size(1) < channels:
        raise IndexError(
            "reference_latents must provide at least as many batch entries and "
            "channels as latents"
        )

    # Statistics are shaped (batch, channels, 1, ..., 1) so that they broadcast
    # against the full latent tensor.
    stat_shape = (batch, channels) + (1,) * (latents.dim() - 2)

    def _slice_stats(tensor):
        # One (std, mean) pair per [batch, channel] slice over all trailing dims.
        tensor = tensor[:batch, :channels]
        flat = tensor.flatten(2) if tensor.dim() > 2 else tensor.unsqueeze(-1)
        sd, mean = torch.std_mean(flat, dim=-1)
        return sd.reshape(stat_shape), mean.reshape(stat_shape)

    r_sd, r_mean = _slice_stats(reference_latents)
    i_sd, i_mean = _slice_stats(latents)

    # A slice whose std is zero (constant) or NaN cannot be standardized;
    # dividing by it would spread NaN/inf through the output. Divide by 1
    # instead so such a slice is only shifted to the reference mean.
    safe_sd = torch.where(i_sd > 0, i_sd, torch.ones_like(i_sd))

    result = ((latents - i_mean) / safe_sd) * r_sd + r_mean
    return torch.lerp(latents, result, factor)
895+
def adain_filter_latent(
    latents: torch.Tensor, reference_latents: torch.Tensor, factor=1.0
):
    """
    Applies Adaptive Instance Normalization (AdaIN) to a latent tensor based on
    statistics from a reference latent tensor.

    Every ``[batch, channel]`` slice of ``latents`` is standardized with its own
    mean and standard deviation, then re-scaled and shifted to the mean and
    standard deviation of the matching slice of ``reference_latents``. The
    result is blended with the untouched input via ``torch.lerp``. All tensor
    work runs inside the ``default_env()`` context.

    Args:
        latents (torch.Tensor): Input latents to normalize. The first two
            dimensions are treated as batch and channel; statistics are taken
            over all remaining dimensions. The tensor is not modified.
        reference_latents (torch.Tensor): The reference latents providing style
            statistics. Needs at least as many batch entries and channels as
            ``latents``; its remaining dimensions may differ.
        factor (float): Blending factor between original and transformed latent.
            Range: -10.0 to 10.0, Default: 1.0

    Returns:
        torch.Tensor: The transformed latent tensor

    Raises:
        IndexError: If ``reference_latents`` has fewer batch entries or
            channels than ``latents``.
    """
    with default_env():
        batch, channels = latents.size(0), latents.size(1)
        if (
            reference_latents.size(0) < batch
            or reference_latents.size(1) < channels
        ):
            raise IndexError(
                "reference_latents must provide at least as many batch entries "
                "and channels as latents"
            )

        # Statistics are shaped (batch, channels, 1, ..., 1) so that they
        # broadcast against the full latent tensor.
        stat_shape = (batch, channels) + (1,) * (latents.dim() - 2)

        def _slice_stats(tensor):
            # One (std, mean) pair per [batch, channel] slice over all
            # trailing dims.
            tensor = tensor[:batch, :channels]
            flat = tensor.flatten(2) if tensor.dim() > 2 else tensor.unsqueeze(-1)
            sd, mean = torch.std_mean(flat, dim=-1)
            return sd.reshape(stat_shape), mean.reshape(stat_shape)

        r_sd, r_mean = _slice_stats(reference_latents)
        i_sd, i_mean = _slice_stats(latents)

        # A slice whose std is zero (constant) or NaN cannot be standardized;
        # dividing by it would spread NaN/inf through the output. Divide by 1
        # instead so such a slice is only shifted to the reference mean.
        safe_sd = torch.where(i_sd > 0, i_sd, torch.ones_like(i_sd))

        result = ((latents - i_mean) / safe_sd) * r_sd + r_mean
        return torch.lerp(latents, result, factor)
920925

921926

922927
class LTXMultiScalePipeline:

0 commit comments

Comments
 (0)