
Commit 4243240

feat: Implement feature extractor
Signed-off-by: James Huang <syhuang1201@gmail.com>
1 parent 6e3b58b

3 files changed

Lines changed: 280 additions & 0 deletions

src/maxdiffusion/models/ltx2/text_encoders/__init__.py

Whitespace-only changes.
src/maxdiffusion/models/ltx2/text_encoders/feature_extractor_ltx2.py

Lines changed: 152 additions & 0 deletions

@@ -0,0 +1,152 @@
"""
Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from typing import Tuple, Union

import jax.numpy as jnp
from flax import nnx

from maxdiffusion import common_types

Array = common_types.Array
DType = common_types.DType


def _norm_and_concat_padded_batch(
    encoded_text: Array,
    sequence_lengths: Array,
    padding_side: str = "right",
) -> Array:
  """Normalize and flatten multi-layer hidden states, respecting padding.

  Performs per-batch, per-layer normalization using a masked mean and range,
  then concatenates across the layer dimension.

  Args:
    encoded_text: Hidden states of shape [batch, seq_len, hidden_dim, num_layers].
    sequence_lengths: Number of valid (non-padded) tokens per batch item.
    padding_side: Whether padding is on the "left" or the "right".

  Returns:
    Normalized tensor of shape [batch, seq_len, hidden_dim * num_layers],
    with padded positions zeroed out.
  """
  b, t, d, l = encoded_text.shape

  # Build the validity mask: [B, T] -> [B, T, 1, 1].
  token_indices = jnp.arange(t)[None, :]  # [1, T]

  if padding_side == "right":
    # Valid: indices < lengths.
    mask = token_indices < sequence_lengths[:, None]
  elif padding_side == "left":
    # Valid: indices >= (T - lengths).
    start_indices = t - sequence_lengths[:, None]
    mask = token_indices >= start_indices
  else:
    raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")

  mask = mask[:, :, None, None]  # [B, T, 1, 1]

  eps = 1e-6

  # 1. Masked mean: sum over (T, D) -> [B, 1, 1, L], with padding zeroed
  # out via jnp.where.
  masked_text = jnp.where(mask, encoded_text, 0.0)
  sum_vals = jnp.sum(masked_text, axis=(1, 2), keepdims=True)

  # Denominator: sequence_length * D.
  denom = (sequence_lengths * d).reshape(b, 1, 1, 1)
  mean = sum_vals / (denom + eps)

  # 2. Masked min/max for the range; the +/-inf fills are ignored by min/max.
  safe_text_min = jnp.where(mask, encoded_text, jnp.inf)
  safe_text_max = jnp.where(mask, encoded_text, -jnp.inf)

  x_min = jnp.min(safe_text_min, axis=(1, 2), keepdims=True)
  x_max = jnp.max(safe_text_max, axis=(1, 2), keepdims=True)

  range_val = x_max - x_min

  # 3. Normalize. Padded positions hold garbage here but are masked out below.
  normed = 8.0 * (encoded_text - mean) / (range_val + eps)

  # 4. Concatenate/flatten the layers: [B, T, D, L] -> [B, T, D * L].
  normed = normed.reshape(b, t, -1)

  # 5. Re-apply the mask so padded positions are exactly 0.0.
  output_mask = mask.squeeze(-1).squeeze(-1)[:, :, None]  # [B, T, 1]
  normed = jnp.where(output_mask, normed, 0.0)

  return normed

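For intuition, a tiny worked sketch of the normalization (illustrative only, not part of the commit). Since x_min <= x <= x_max at valid positions and the masked mean lies between them, the scaled values always land in [-8, 8], while padded positions come out exactly zero:

x = jnp.array([[[[1.0], [3.0]],
                [[5.0], [7.0]],
                [[9.0], [9.0]]]])  # [B=1, T=3, D=2, L=1]; the last token is padding
lengths = jnp.array([2])  # only the first two tokens are valid

out = _norm_and_concat_padded_batch(x, lengths)
# Valid values: mean = (1 + 3 + 5 + 7) / 4 = 4, range = 7 - 1 = 6,
# so the first entry becomes 8 * (1 - 4) / 6 = -4.0 (up to eps).
print(out.shape)  # (1, 3, 2)
print(out[0, 2])  # [0. 0.] -- the padded row is zeroed
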
class LTX2GemmaFeatureExtractor(nnx.Module):
  """Feature extractor module for Gemma models in LTX-2.

  Applies mean-centered scaling and a linear projection.
  """

  def __init__(
      self,
      input_dim: int,
      output_dim: int,
      dtype: DType = jnp.float32,
      rngs: nnx.Rngs = None,
  ):
    """
    Args:
      input_dim: Dimension of the flattened hidden states (Gemma dim * num layers).
      output_dim: Target dimension for diffusion conditioning.
    """
    # LTX-2 uses bias=False for the projection.
    self.linear = nnx.Linear(input_dim, output_dim, use_bias=False, dtype=dtype, rngs=rngs)

  def __call__(
      self, hidden_states: Union[Tuple[Array, ...], Array], attention_mask: Array, padding_side: str = "right"
  ) -> Array:
    """
    Args:
      hidden_states: Tuple of arrays from Gemma, each [B, T, D],
        or a pre-stacked array [B, T, D, L].
      attention_mask: Mask [B, T] (1 for valid, 0 for padding).
      padding_side: "right" or "left".

    Returns:
      Projected features [B, T, output_dim].
    """
    # 1. Stack the hidden states if needed: [B, T, D, L].
    if isinstance(hidden_states, (tuple, list)):
      x = jnp.stack(hidden_states, axis=-1)
    else:
      x = hidden_states

    # 2. Count the valid tokens per batch item.
    sequence_lengths = jnp.sum(attention_mask, axis=-1)

    # 3. Normalize and concatenate the layers.
    x_norm = _norm_and_concat_padded_batch(x, sequence_lengths, padding_side=padding_side)

    # 4. Project to the target dimension.
    return self.linear(x_norm)
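
For reference, a minimal usage sketch of the module, assuming hypothetical sizes (3840-dim Gemma hidden states, 3 tapped layers, a 4096-dim conditioning space); none of these numbers come from this commit:

import numpy as np

B, T, D, L = 2, 128, 3840, 3  # hypothetical batch, sequence, hidden dim, tapped layers
extractor = LTX2GemmaFeatureExtractor(input_dim=D * L, output_dim=4096, rngs=nnx.Rngs(0))

# Stand-ins for Gemma encoder outputs: a tuple of per-layer states, each [B, T, D].
hidden_states = tuple(jnp.asarray(np.random.randn(B, T, D), dtype=jnp.float32) for _ in range(L))
attention_mask = jnp.ones((B, T), dtype=jnp.int32)  # no padding in this toy batch

features = extractor(hidden_states, attention_mask)  # [B, T, 4096]
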
Lines changed: 128 additions & 0 deletions
@@ -0,0 +1,128 @@
"""
Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import unittest

import torch
import numpy as np
import jax.numpy as jnp
from flax import nnx

from ..models.ltx2.text_encoders.feature_extractor_ltx2 import LTX2GemmaFeatureExtractor, _norm_and_concat_padded_batch


# ==========================================
# PyTorch Reference Logic
# ==========================================
def pt_norm_and_concat_padded_batch(
    encoded_text: torch.Tensor,
    sequence_lengths: torch.Tensor,
    padding_side: str = "right",
) -> torch.Tensor:
  """Mirrors _norm_and_concat_padded_batch; masked_fill(~mask, v) plays the
  role of jnp.where(mask, x, v)."""
  b, t, d, l = encoded_text.shape
  device = encoded_text.device

  token_indices = torch.arange(t, device=device)[None, :]
  if padding_side == "right":
    mask = token_indices < sequence_lengths[:, None]
  elif padding_side == "left":
    start_indices = t - sequence_lengths[:, None]
    mask = token_indices >= start_indices
  else:
    raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")

  mask = mask[:, :, None, None]  # [B, T, 1, 1]

  eps = 1e-6
  masked = encoded_text.masked_fill(~mask, 0.0)
  denom = (sequence_lengths * d).view(b, 1, 1, 1)
  mean = masked.sum(dim=(1, 2), keepdim=True) / (denom + eps)

  x_min = encoded_text.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True)
  x_max = encoded_text.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True)
  range_ = x_max - x_min

  normed = 8 * (encoded_text - mean) / (range_ + eps)
  normed = normed.reshape(b, t, -1)

  # Zero out padded positions in the flattened output.
  mask_flattened = mask.view(b, t, 1).expand(-1, -1, d * l)
  normed = normed.masked_fill(~mask_flattened, 0.0)

  return normed


class LTX2FeatureExtractorTest(unittest.TestCase):

  def setUp(self):
    self.rng = nnx.Rngs(0)
    self.B = 2
    self.T = 10
    self.D = 8
    self.L = 3
    self.target_dim = 16

  def test_norm_parity(self):
    # Random input with some padding.
    np_input = np.random.randn(self.B, self.T, self.D, self.L).astype(np.float32)

    # Valid lengths per batch item, e.g. [5, 8] out of 10.
    lengths = np.array([5, 8], dtype=np.int32)

    # PyTorch reference.
    pt_input = torch.from_numpy(np_input)
    pt_lengths = torch.from_numpy(lengths)
    pt_out = pt_norm_and_concat_padded_batch(pt_input, pt_lengths)

    # JAX implementation.
    jax_input = jnp.array(np_input)
    jax_lengths = jnp.array(lengths)
    jax_out = _norm_and_concat_padded_batch(jax_input, jax_lengths)

    diff = np.abs(pt_out.numpy() - np.array(jax_out)).max()
    print(f"\n[Norm Parity] Max Diff: {diff:.6f}")

    np.testing.assert_allclose(pt_out.numpy(), np.array(jax_out), atol=1e-5)
    print("[PASS] Normalization Logic Parity Verified.")

  def test_module_forward(self):
    # Test the full module.
    model = LTX2GemmaFeatureExtractor(input_dim=self.D * self.L, output_dim=self.target_dim, rngs=self.rng)

    # Simulate a tuple of Gemma hidden states, each [B, T, D].
    hidden_states = [jnp.array(np.random.randn(self.B, self.T, self.D)) for _ in range(self.L)]

    # Attention mask [B, T]: 1 for valid tokens, 0 for padding.
    mask = np.zeros((self.B, self.T), dtype=np.int32)
    mask[0, :5] = 1
    mask[1, :8] = 1
    jax_mask = jnp.array(mask)

    output = model(tuple(hidden_states), jax_mask)

    expected_shape = (self.B, self.T, self.target_dim)
    self.assertEqual(output.shape, expected_shape)

    # Padded positions (batch 0, indices 5 onward) should be exactly zero.
    padding_val = output[0, 5:, :]
    self.assertTrue(jnp.all(padding_val == 0.0), "Padding region should be zero")

    print("\n[PASS] Feature Extractor Module Forward Pass Verified.")


if __name__ == "__main__":
  unittest.main()
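
Note: because the test file uses a relative import, it cannot run as a standalone script; it has to execute in package context, e.g. via python -m unittest with whatever module path the file actually lands at (the path is not shown in this diff).
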
