11"""
2- This is a test file used for ensuring numerical parity between the PyTorch and JAX implementations of LTX2.
3- This is to be ignored and will not be pushed when committing to the main branch.
2+ Copyright 2025 Google LLC
3+
4+ Licensed under the Apache License, Version 2.0 (the "License");
5+ you may not use this file except in compliance with the License.
6+ You may obtain a copy of the License at
7+
8+ https://www.apache.org/licenses/LICENSE-2.0
9+
10+ Unless required by applicable law or agreed to in writing, software
11+ distributed under the License is distributed on an "AS IS" BASIS,
12+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+ See the License for the specific language governing permissions and
14+ limitations under the License.
415"""
516from typing import Optional , Tuple , Any , Dict , Union
617import jax
@@ -105,8 +116,9 @@ def __init__(
105116 epsilon = self .norm_eps ,
106117 use_scale = self .norm_elementwise_affine ,
107118 rngs = rngs ,
108- dtype = dtype ,
109- param_dtype = weights_dtype ,
119+ rngs = rngs ,
120+ dtype = jnp .float32 ,
121+ param_dtype = jnp .float32 ,
110122 )
111123 self .attn1 = LTX2Attention (
112124 rngs = rngs ,
@@ -128,8 +140,9 @@ def __init__(
128140 epsilon = self .norm_eps ,
129141 use_scale = self .norm_elementwise_affine ,
130142 rngs = rngs ,
131- dtype = dtype ,
132- param_dtype = weights_dtype ,
143+ rngs = rngs ,
144+ dtype = jnp .float32 ,
145+ param_dtype = jnp .float32 ,
133146 )
134147 self .audio_attn1 = LTX2Attention (
135148 rngs = rngs ,
@@ -152,8 +165,9 @@ def __init__(
152165 epsilon = self .norm_eps ,
153166 use_scale = self .norm_elementwise_affine ,
154167 rngs = rngs ,
155- dtype = dtype ,
156- param_dtype = weights_dtype ,
168+ rngs = rngs ,
169+ dtype = jnp .float32 ,
170+ param_dtype = jnp .float32 ,
157171 )
158172 self .attn2 = LTX2Attention (
159173 rngs = rngs ,
@@ -176,8 +190,9 @@ def __init__(
176190 epsilon = self .norm_eps ,
177191 use_scale = self .norm_elementwise_affine ,
178192 rngs = rngs ,
179- dtype = dtype ,
180- param_dtype = weights_dtype ,
193+ rngs = rngs ,
194+ dtype = jnp .float32 ,
195+ param_dtype = jnp .float32 ,
181196 )
182197 self .audio_attn2 = LTX2Attention (
183198 rngs = rngs ,
@@ -197,7 +212,7 @@ def __init__(
197212
198213 # 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention
199214 self .audio_to_video_norm = nnx .RMSNorm (
200- dim , epsilon = self .norm_eps , use_scale = self .norm_elementwise_affine , rngs = rngs , dtype = dtype , param_dtype = weights_dtype
215+ dim , epsilon = self .norm_eps , use_scale = self .norm_elementwise_affine , rngs = rngs , dtype = jnp . float32 , param_dtype = jnp . float32
201216 )
202217 self .audio_to_video_attn = LTX2Attention (
203218 rngs = rngs ,
@@ -220,8 +235,9 @@ def __init__(
220235 epsilon = self .norm_eps ,
221236 use_scale = self .norm_elementwise_affine ,
222237 rngs = rngs ,
223- dtype = dtype ,
224- param_dtype = weights_dtype ,
238+ rngs = rngs ,
239+ dtype = jnp .float32 ,
240+ param_dtype = jnp .float32 ,
225241 )
226242 self .video_to_audio_attn = LTX2Attention (
227243 rngs = rngs ,
@@ -241,7 +257,7 @@ def __init__(
241257
242258 # 4. Feed Forward
243259 self .norm3 = nnx .RMSNorm (
244- dim , epsilon = self .norm_eps , use_scale = self .norm_elementwise_affine , rngs = rngs , dtype = dtype , param_dtype = weights_dtype
260+ dim , epsilon = self .norm_eps , use_scale = self .norm_elementwise_affine , rngs = rngs , dtype = jnp . float32 , param_dtype = jnp . float32
245261 )
246262 self .ff = NNXSimpleFeedForward (
247263 rngs = rngs ,
@@ -257,8 +273,8 @@ def __init__(
257273 epsilon = self .norm_eps ,
258274 use_scale = self .norm_elementwise_affine ,
259275 rngs = rngs ,
260- dtype = dtype ,
261- param_dtype = weights_dtype ,
276+ dtype = jnp . float32 ,
277+ param_dtype = jnp . float32 ,
262278 )
263279 self .audio_ff = NNXSimpleFeedForward (
264280 rngs = rngs , dim = audio_dim , dim_out = audio_dim , activation_fn = activation_fn , dtype = dtype , weights_dtype = weights_dtype
@@ -776,7 +792,7 @@ def init_block(rngs):
776792 # 6. Output layers
777793 self .gradient_checkpoint = GradientCheckpointType .from_str (remat_policy )
778794 self .norm_out = nnx .LayerNorm (
779- inner_dim , epsilon = 1e-6 , use_scale = False , rngs = rngs , dtype = self . dtype , param_dtype = self . weights_dtype
795+ inner_dim , epsilon = 1e-6 , use_scale = False , rngs = rngs , dtype = jnp . float32 , param_dtype = jnp . float32
780796 )
781797 self .proj_out = nnx .Linear (
782798 inner_dim ,
@@ -789,7 +805,7 @@ def init_block(rngs):
789805 )
790806
791807 self .audio_norm_out = nnx .LayerNorm (
792- audio_inner_dim , epsilon = 1e-6 , use_scale = False , rngs = rngs , dtype = self . dtype , param_dtype = self . weights_dtype
808+ audio_inner_dim , epsilon = 1e-6 , use_scale = False , rngs = rngs , dtype = jnp . float32 , param_dtype = jnp . float32
793809 )
794810 self .audio_proj_out = nnx .Linear (
795811 audio_inner_dim ,
0 commit comments