1+ """
2+ Copyright 2026 Google LLC
3+
4+ Licensed under the Apache License, Version 2.0 (the "License");
5+ you may not use this file except in compliance with the License.
6+ You may obtain a copy of the License at
7+
8+ https://www.apache.org/licenses/LICENSE-2.0
9+
10+ Unless required by applicable law or agreed to in writing, software
11+ distributed under the License is distributed on an "AS IS" BASIS,
12+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+ See the License for the specific language governing permissions and
14+ limitations under the License.
15+ """
16+
17+ import unittest
18+ from unittest .mock import MagicMock , patch
19+ import jax
20+ import jax .numpy as jnp
21+ import numpy as np
22+
23+ from maxdiffusion .models .ltx2 .ltx2_utils import adain_filter_latent , tone_map_latents
24+ from maxdiffusion .models .ltx2 .latent_upsampler_ltx2 import LTX2LatentUpsamplerModel
25+ from maxdiffusion .pipelines .ltx2 .pipeline_ltx2_latent_upsample import FlaxLTX2LatentUpsamplePipeline
26+
27+
class LTX2LatentUpsamplerTest(unittest.TestCase):
  """Unit tests covering the LTX2 latent upsampler utilities, model, and pipeline."""

  def test_adain_filter_latent(self):
    """AdaIN with factor=1.0 transfers reference statistics; factor=0.0 is a no-op."""
    rng = jax.random.PRNGKey(0)
    content_key, style_key = jax.random.split(rng)

    # Content latents drawn from ~N(0, 1); style reference drawn from ~N(5, 2^2)
    # so the two inputs have clearly different statistics.
    content = jax.random.normal(content_key, (1, 4, 16, 16, 8))
    reference = jax.random.normal(style_key, (1, 4, 16, 16, 8)) * 2.0 + 5.0

    # factor=1.0 means the style statistics are fully adopted.
    styled = adain_filter_latent(content, reference, factor=1.0)
    self.assertEqual(styled.shape, content.shape)

    # After full blending, the per-sample mean/std of the output should match
    # the reference latents' statistics.
    reduce_axes = (1, 2, 3)
    expected_mean = jnp.mean(reference, axis=reduce_axes, keepdims=True)
    expected_std = jnp.std(reference, axis=reduce_axes, keepdims=True)
    actual_mean = jnp.mean(styled, axis=reduce_axes, keepdims=True)
    actual_std = jnp.std(styled, axis=reduce_axes, keepdims=True)
    np.testing.assert_allclose(actual_mean, expected_mean, rtol=1e-4, atol=1e-4)
    np.testing.assert_allclose(actual_std, expected_std, rtol=1e-4, atol=1e-4)

    # factor=0.0 must leave the input untouched.
    identity = adain_filter_latent(content, reference, factor=0.0)
    np.testing.assert_allclose(identity, content, rtol=1e-5)

  def test_tone_map_latents(self):
    """Tone mapping is the identity at compression=0 and shrinks magnitudes otherwise."""
    source = jnp.ones((1, 4, 16, 16, 8)) * 2.0

    # With compression=0.0 the internal scale_factor (compression * 0.75) is 0,
    # so the per-element scales collapse to 1.0 and the output is unchanged.
    untouched = tone_map_latents(source, compression=0.0)
    np.testing.assert_allclose(untouched, source, rtol=1e-5)

    # Any positive compression must strictly reduce every magnitude while
    # preserving the tensor shape.
    squeezed = tone_map_latents(source, compression=1.0)
    self.assertTrue(jnp.all(jnp.abs(squeezed) < jnp.abs(source)))
    self.assertEqual(squeezed.shape, source.shape)

  def test_upsampler_model_forward(self):
    """Forward pass doubles H/W, keeps T, and restores the input channel count."""
    batch, frames, height, width, channels = 2, 3, 16, 16, 8
    rng = jax.random.PRNGKey(0)

    # Small configuration to keep the test cheap. mid_channels must remain a
    # multiple of 32 because the model's GroupNorm layers use num_groups=32.
    model = LTX2LatentUpsamplerModel(
        in_channels=channels,
        mid_channels=32,
        num_blocks_per_stage=1,
        dims=3,
        spatial_upsample=True,
        temporal_upsample=False,
        rational_spatial_scale=2.0,  # requests 2x spatial upscaling
    )

    sample = jax.random.normal(rng, (batch, frames, height, width, channels))
    variables = model.init(rng, sample)
    result = model.apply(variables, sample)

    # Temporal axis unchanged, spatial axes doubled, channels back to `in_channels`.
    self.assertEqual(result.shape, (batch, frames, height * 2, width * 2, channels))

  def test_pipeline_latent_upsample_logic(self):
    """Exercises the pipeline in latent-passthrough and decode modes via mocks."""
    vae = MagicMock()
    # The pipeline reads compression ratios from the VAE config, and the
    # normalization statistics directly off the VAE object.
    vae.config = {"spatial_compression_ratio": 32, "temporal_compression_ratio": 8}
    vae.latents_mean = [0.0] * 8
    vae.latents_std = [1.0] * 8
    vae.dtype = jnp.float32
    # decode() returns a tuple whose first element is the video tensor.
    vae.decode.return_value = (jnp.zeros((1, 1, 32, 32, 3)),)

    upsampler = MagicMock()
    # apply() hands back identically shaped latents, enough to test control flow.
    upsampler.apply = MagicMock(return_value=jnp.ones((1, 4, 16, 16, 8)))

    pipeline = FlaxLTX2LatentUpsamplePipeline(
        vae=vae,
        latent_upsampler=upsampler,
    )
    # Stub out the video post-processing so the test stays isolated from
    # the VideoProcessor dependency.
    pipeline.video_processor.postprocess_video = MagicMock(return_value=np.zeros((1, 3, 1, 32, 32)))

    params = {"latent_upsampler": {}}
    seed = jax.random.PRNGKey(0)
    latents = jnp.zeros((1, 4, 16, 16, 8))

    # output_type="latent" should return the upsampled latents directly.
    latent_result = pipeline(
        params=params,
        prng_seed=seed,
        latents=latents,
        latents_normalized=False,
        adain_factor=1.0,
        tone_map_compression_ratio=0.5,
        output_type="latent",
        return_dict=True,
    )
    self.assertIn("frames", latent_result)
    self.assertEqual(latent_result["frames"].shape, (1, 4, 16, 16, 8))

    # The upsampler network must have been invoked exactly once so far.
    upsampler.apply.assert_called_once()

    # output_type="pil" should route through VAE decoding.
    decoded_result = pipeline(
        params=params,
        prng_seed=seed,
        latents=latents,
        latents_normalized=False,
        output_type="pil",
        return_dict=True,
    )
    vae.decode.assert_called_once()
    self.assertIn("frames", decoded_result)
168+
169+
# Allow running this test module directly as a script.
if __name__ == "__main__":
  unittest.main()