
Commit 2d55e87

committed
clean up, ensuring layers are in fp32
1 parent 8f05f7c commit 2d55e87

3 files changed

Lines changed: 40 additions & 20 deletions


src/maxdiffusion/models/ltx2/attention_ltx2.py

Lines changed: 2 additions & 2 deletions
@@ -398,8 +398,8 @@ def __init__(
 self.to_v = nnx.Linear(kv_dim, self.inner_dim, use_bias=bias, rngs=rngs, dtype=dtype)

 # 2. Normalization (Applied to full inner_dim, NOT per-head)
-self.norm_q = nnx.RMSNorm(self.inner_dim, epsilon=eps, dtype=dtype, use_scale=True, rngs=rngs)
-self.norm_k = nnx.RMSNorm(self.inner_dim, epsilon=eps, dtype=dtype, use_scale=True, rngs=rngs)
+self.norm_q = nnx.RMSNorm(self.inner_dim, epsilon=eps, dtype=jnp.float32, param_dtype=jnp.float32, use_scale=True, rngs=rngs)
+self.norm_k = nnx.RMSNorm(self.inner_dim, epsilon=eps, dtype=jnp.float32, param_dtype=jnp.float32, use_scale=True, rngs=rngs)

 # 3. Output
 self.to_out = nnx.Linear(self.inner_dim, query_dim, use_bias=out_bias, rngs=rngs, dtype=dtype)

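For context on what this change does, here is a minimal sketch (not part of the commit) of the pattern it applies: constructing an nnx.RMSNorm with dtype=jnp.float32 as the compute dtype and param_dtype=jnp.float32 for the scale parameter, so bfloat16 activations are upcast and the normalization runs in full precision. The width and RNG seed below are illustrative, not taken from the model config.

  import jax.numpy as jnp
  from flax import nnx

  inner_dim = 2048  # hypothetical width, for illustration only
  rngs = nnx.Rngs(0)

  # Same pattern as the diff: norm computed and parameterized in fp32.
  norm_q = nnx.RMSNorm(
      inner_dim,
      epsilon=1e-6,
      dtype=jnp.float32,        # compute dtype: bf16 inputs are upcast before normalizing
      param_dtype=jnp.float32,  # scale parameter is stored in fp32
      use_scale=True,
      rngs=rngs,
  )

  x = jnp.ones((1, 128, inner_dim), dtype=jnp.bfloat16)  # bf16 activations from a prior layer
  y = norm_q(x)
  print(y.dtype)  # float32: the reduction and scaling ran in full precision

Previously these norms were built with dtype=dtype (and, in the transformer, param_dtype=weights_dtype), so in a bfloat16 configuration the reduction ran in reduced precision; pinning both arguments keeps it in fp32 regardless of what the surrounding layers use.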
src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 34 additions & 18 deletions
@@ -1,6 +1,17 @@
 """
-This is a test file used for ensuring numerical parity between pytorch and jax implementation of LTX2.
-This is to be ignored and will not be pushed when commiting to main branch.
+Copyright 2025 Google LLC
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+https://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
 """
 from typing import Optional, Tuple, Any, Dict, Union
 import jax
@@ -105,8 +116,9 @@ def __init__(
     epsilon=self.norm_eps,
     use_scale=self.norm_elementwise_affine,
     rngs=rngs,
-    dtype=dtype,
-    param_dtype=weights_dtype,
+    rngs=rngs,
+    dtype=jnp.float32,
+    param_dtype=jnp.float32,
 )
 self.attn1 = LTX2Attention(
     rngs=rngs,
@@ -128,8 +140,9 @@ def __init__(
     epsilon=self.norm_eps,
     use_scale=self.norm_elementwise_affine,
     rngs=rngs,
-    dtype=dtype,
-    param_dtype=weights_dtype,
+    rngs=rngs,
+    dtype=jnp.float32,
+    param_dtype=jnp.float32,
 )
 self.audio_attn1 = LTX2Attention(
     rngs=rngs,
@@ -152,8 +165,9 @@ def __init__(
     epsilon=self.norm_eps,
     use_scale=self.norm_elementwise_affine,
     rngs=rngs,
-    dtype=dtype,
-    param_dtype=weights_dtype,
+    rngs=rngs,
+    dtype=jnp.float32,
+    param_dtype=jnp.float32,
 )
 self.attn2 = LTX2Attention(
     rngs=rngs,
@@ -176,8 +190,9 @@ def __init__(
     epsilon=self.norm_eps,
     use_scale=self.norm_elementwise_affine,
     rngs=rngs,
-    dtype=dtype,
-    param_dtype=weights_dtype,
+    rngs=rngs,
+    dtype=jnp.float32,
+    param_dtype=jnp.float32,
 )
 self.audio_attn2 = LTX2Attention(
     rngs=rngs,
@@ -197,7 +212,7 @@ def __init__(

 # 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention
 self.audio_to_video_norm = nnx.RMSNorm(
-    dim, epsilon=self.norm_eps, use_scale=self.norm_elementwise_affine, rngs=rngs, dtype=dtype, param_dtype=weights_dtype
+    dim, epsilon=self.norm_eps, use_scale=self.norm_elementwise_affine, rngs=rngs, dtype=jnp.float32, param_dtype=jnp.float32
 )
 self.audio_to_video_attn = LTX2Attention(
     rngs=rngs,
@@ -220,8 +235,9 @@ def __init__(
     epsilon=self.norm_eps,
     use_scale=self.norm_elementwise_affine,
     rngs=rngs,
-    dtype=dtype,
-    param_dtype=weights_dtype,
+    rngs=rngs,
+    dtype=jnp.float32,
+    param_dtype=jnp.float32,
 )
 self.video_to_audio_attn = LTX2Attention(
     rngs=rngs,
@@ -241,7 +257,7 @@ def __init__(

 # 4. Feed Forward
 self.norm3 = nnx.RMSNorm(
-    dim, epsilon=self.norm_eps, use_scale=self.norm_elementwise_affine, rngs=rngs, dtype=dtype, param_dtype=weights_dtype
+    dim, epsilon=self.norm_eps, use_scale=self.norm_elementwise_affine, rngs=rngs, dtype=jnp.float32, param_dtype=jnp.float32
 )
 self.ff = NNXSimpleFeedForward(
     rngs=rngs,
@@ -257,8 +273,8 @@ def __init__(
     epsilon=self.norm_eps,
     use_scale=self.norm_elementwise_affine,
     rngs=rngs,
-    dtype=dtype,
-    param_dtype=weights_dtype,
+    dtype=jnp.float32,
+    param_dtype=jnp.float32,
 )
 self.audio_ff = NNXSimpleFeedForward(
     rngs=rngs, dim=audio_dim, dim_out=audio_dim, activation_fn=activation_fn, dtype=dtype, weights_dtype=weights_dtype
@@ -776,7 +792,7 @@ def init_block(rngs):
 # 6. Output layers
 self.gradient_checkpoint = GradientCheckpointType.from_str(remat_policy)
 self.norm_out = nnx.LayerNorm(
-    inner_dim, epsilon=1e-6, use_scale=False, rngs=rngs, dtype=self.dtype, param_dtype=self.weights_dtype
+    inner_dim, epsilon=1e-6, use_scale=False, rngs=rngs, dtype=jnp.float32, param_dtype=jnp.float32
 )
 self.proj_out = nnx.Linear(
     inner_dim,
@@ -789,7 +805,7 @@ def init_block(rngs):
 )

 self.audio_norm_out = nnx.LayerNorm(
-    audio_inner_dim, epsilon=1e-6, use_scale=False, rngs=rngs, dtype=self.dtype, param_dtype=self.weights_dtype
+    audio_inner_dim, epsilon=1e-6, use_scale=False, rngs=rngs, dtype=jnp.float32, param_dtype=jnp.float32
 )
 self.audio_proj_out = nnx.Linear(
     audio_inner_dim,

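Note that the Linear and feed-forward layers in this diff keep dtype=dtype and param_dtype=weights_dtype, so only the normalization layers are promoted. Below is a hypothetical toy block, not the actual LTX2 code, sketching how the two dtypes interact in Flax NNX under that split:

  import jax.numpy as jnp
  from flax import nnx

  class ToyBlock(nnx.Module):
    """Illustrative only: a norm pinned to fp32 feeding a mixed-precision projection."""

    def __init__(self, dim: int, *, dtype, weights_dtype, rngs: nnx.Rngs):
      # Norm stays in fp32, as in this commit ...
      self.norm = nnx.RMSNorm(dim, epsilon=1e-6, dtype=jnp.float32, param_dtype=jnp.float32, rngs=rngs)
      # ... while the projection keeps the model's compute and weight dtypes.
      self.proj = nnx.Linear(dim, dim, dtype=dtype, param_dtype=weights_dtype, rngs=rngs)

    def __call__(self, x):
      h = self.norm(x)     # computed in fp32
      return self.proj(h)  # cast back to the projection's compute dtype (e.g. bf16)

  block = ToyBlock(64, dtype=jnp.bfloat16, weights_dtype=jnp.bfloat16, rngs=nnx.Rngs(0))
  x = jnp.ones((2, 8, 64), dtype=jnp.bfloat16)
  print(block(x).dtype)  # bfloat16: only the normalization ran in full precision

The fp32 output of the norm is immediately cast down by the next layer's compute dtype, so the extra precision is confined to the normalization itself.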
src/maxdiffusion/tests/ltx2_parity_test.py

Lines changed: 4 additions & 0 deletions
@@ -1,3 +1,7 @@
+"""
+This is a test file used for ensuring numerical parity between pytorch and jax implementation of LTX2.
+This is to be ignored and will not be pushed when commiting to main branch.
+"""
 import unittest
 import jax
 import jax.numpy as jnp

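The docstring describes a PyTorch/JAX numerical parity check. As a rough illustration only (file names, the layer under test, and tolerances are placeholders, not from the repo), such a test typically loads reference tensors exported from the PyTorch model and compares them against the JAX module's output:

  import unittest

  import jax.numpy as jnp
  import numpy as np
  from flax import nnx


  class RMSNormParityTest(unittest.TestCase):

    def test_norm_matches_torch_reference(self):
      # Hypothetical reference arrays exported from the PyTorch LTX2 implementation
      # (e.g. with numpy.save); the file names are placeholders.
      x = np.load("reference_input.npy")
      expected = np.load("reference_norm_output.npy")

      # fp32 norm, mirroring the layers pinned to float32 in this commit.
      norm = nnx.RMSNorm(
          x.shape[-1], epsilon=1e-6, dtype=jnp.float32, param_dtype=jnp.float32, rngs=nnx.Rngs(0)
      )
      got = norm(jnp.asarray(x))

      # With the computation kept in fp32 a tight tolerance is realistic;
      # bf16 layers would need a much looser rtol/atol.
      np.testing.assert_allclose(np.asarray(got), expected, rtol=1e-5, atol=1e-5)


  if __name__ == "__main__":
    unittest.main()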