Commit 002352e

Author: James (committed)
[Text Pipeline] implement embedding connector
Signed-off-by: James <shyhuanh@google.com>
1 parent 504377f commit 002352e

2 files changed: 315 additions & 0 deletions
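
For a quick sense of the new API before the diffs, here is a minimal usage sketch of the connector this commit adds. The shapes are illustrative, the absolute import path is inferred from the relative import in the test file below, and attention_kernel="dot_product" mirrors the CPU-friendly choice made in the tests.

from flax import nnx
import jax.numpy as jnp

# Illustrative absolute import path, inferred from the relative import in the test file below.
from maxdiffusion.models.ltx2.text_encoders.embeddings_connector_ltx2 import Embeddings1DConnector

connector = Embeddings1DConnector(
    input_dim=64, heads=4, head_dim=16, layers=2,
    num_learnable_registers=8, attention_kernel="dot_product",
    rngs=nnx.Rngs(0),
)
hidden_states = jnp.zeros((2, 16, 64))  # [batch, seq, dim]; seq must be divisible by num_learnable_registers
attention_mask = jnp.ones((2, 16))      # 1 = real token, 0 = padding (padding is swapped for learnable registers)
out = connector(hidden_states, attention_mask)  # -> shape (2, 16, 64)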

Lines changed: 201 additions & 0 deletions
@@ -0,0 +1,201 @@
"""
Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from typing import Optional, Tuple
import jax
import jax.numpy as jnp
from flax import nnx
from maxdiffusion import common_types
from ..attention_ltx2 import LTX2Attention

Array = common_types.Array
DType = common_types.DType

class FeedForward(nnx.Module):
  def __init__(self, dim: int, dim_out: Optional[int] = None, mult: int = 4, dropout: float = 0.0, rngs: nnx.Rngs = None):
    inner_dim = int(dim * mult)
    dim_out = dim_out if dim_out is not None else dim

    self.proj1 = nnx.Linear(dim, inner_dim, rngs=rngs)
    self.proj2 = nnx.Linear(inner_dim, dim_out, rngs=rngs)

  def __call__(self, x: Array) -> Array:
    x = self.proj1(x)
    x = jax.nn.gelu(x)
    x = self.proj2(x)
    return x

class _BasicTransformerBlock1D(nnx.Module):
  def __init__(
      self,
      dim: int,
      heads: int,
      dim_head: int,
      rope_type: str = "interleaved",
      attention_kernel: str = "flash",
      rngs: nnx.Rngs = None,
  ):
    self.attn1 = LTX2Attention(
        query_dim=dim,
        heads=heads,
        dim_head=dim_head,
        rope_type=rope_type,
        bias=True,  # LTX-2 default
        out_bias=True,
        attention_kernel=attention_kernel,
        rngs=rngs,
    )
    self.ff = FeedForward(dim, dim_out=dim, rngs=rngs)
    self.norm1 = nnx.RMSNorm(dim, epsilon=1e-6, dtype=jnp.float32, param_dtype=jnp.float32, use_scale=True, rngs=rngs)
    self.norm2 = nnx.RMSNorm(dim, epsilon=1e-6, dtype=jnp.float32, param_dtype=jnp.float32, use_scale=True, rngs=rngs)

  def __call__(
      self,
      hidden_states: Array,
      attention_mask: Optional[Array] = None,
      rotary_emb: Optional[Tuple[Array, Array]] = None,
  ) -> Array:
    # 1. Norm -> Attention
    normed = self.norm1(hidden_states)
    attn_output = self.attn1(normed, attention_mask=attention_mask, rotary_emb=rotary_emb)
    hidden_states = hidden_states + attn_output

    # 2. Norm -> FeedForward
    normed = self.norm2(hidden_states)
    ff_output = self.ff(normed)
    hidden_states = hidden_states + ff_output

    return hidden_states

class Embeddings1DConnector(nnx.Module):
  """
  Applies 1D transformer processing with Thinking Tokens (Learnable Registers).
  Uses nnx.scan for efficient JAX-idiomatic layer execution.
  """
  def __init__(
      self,
      input_dim: int,
      heads: int = 30,
      head_dim: int = 128,
      layers: int = 2,
      theta: float = 10000.0,
      num_learnable_registers: int = 128,
      rope_type: str = "interleaved",
      attention_kernel: str = "flash",
      rngs: nnx.Rngs = None,
  ):
    self.dim = input_dim
    self.theta = theta
    self.num_learnable_registers = num_learnable_registers
    self.num_layers = layers

    # 1. Initialize Stacked Layers using vmap
    # This creates a single module where parameters have an extra leading dimension [layers, ...]
    # We need to ensure rngs are split for each layer
    @nnx.split_rngs(splits=layers)
    @nnx.vmap(in_axes=0, out_axes=0, axis_size=layers)
    def create_block(rngs):
      return _BasicTransformerBlock1D(
          dim=input_dim,
          heads=heads,
          dim_head=head_dim,
          rope_type=rope_type,
          attention_kernel=attention_kernel,
          rngs=rngs,
      )

    # Call the vmapped constructor
    self.stacked_blocks = create_block(rngs)

    # 2. Thinking Tokens
    if num_learnable_registers > 0:
      key = rngs.params()
      self.learnable_registers = nnx.Param(
          jax.random.uniform(key, (num_learnable_registers, self.dim), dtype=jnp.bfloat16) * 2.0 - 1.0
      )

    self.final_norm = nnx.RMSNorm(self.dim, epsilon=1e-6, dtype=jnp.float32, param_dtype=jnp.float32, use_scale=True, rngs=rngs)

  def _replace_padded_with_learnable_registers(
      self, hidden_states: Array, attention_mask: Array
  ) -> Tuple[Array, Array]:
    b, t, d = hidden_states.shape
    if t % self.num_learnable_registers != 0:
      raise ValueError(f"Sequence length {t} must be divisible by {self.num_learnable_registers}")

    num_duplications = t // self.num_learnable_registers
    registers = jnp.tile(self.learnable_registers[...], (num_duplications, 1))
    registers = jnp.expand_dims(registers, 0)

    if attention_mask.ndim == 2:
      mask = attention_mask[:, :, None]
    else:
      mask = attention_mask

    output = jnp.where(mask > 0.5, hidden_states, registers)
    new_mask = jnp.ones_like(attention_mask)
    return output, new_mask

  def _compute_1d_rope(self, seq_len: int, dtype: DType) -> Tuple[Array, Array]:
    t = jnp.arange(seq_len, dtype=jnp.float32)
    freqs = 1.0 / (self.theta ** (jnp.arange(0, self.dim, 2, dtype=jnp.float32) / self.dim))
    emb = jnp.outer(t, freqs)
    cos = jnp.cos(emb)
    sin = jnp.sin(emb)
    cos = jnp.repeat(cos, 2, axis=-1)
    sin = jnp.repeat(sin, 2, axis=-1)
    return cos[None, ...], sin[None, ...]

  def __call__(
      self,
      hidden_states: Array,
      attention_mask: Optional[Array] = None,
  ) -> Array:
    # 1. Thinking Tokens
    if self.num_learnable_registers > 0 and attention_mask is not None:
      hidden_states, attention_mask = self._replace_padded_with_learnable_registers(
          hidden_states, attention_mask
      )

    # 2. RoPE
    seq_len = hidden_states.shape[1]
    rotary_emb = self._compute_1d_rope(seq_len, hidden_states.dtype)

    # 3. Transformer Blocks (Scan)

    # Scan function signature: (carry, x) -> (carry, y)
    def block_scan_fn(carry, block_module):
      hidden_states = carry
      # block_module is a sliced view of the vmapped module
      hidden_states = block_module(
          hidden_states,
          attention_mask=attention_mask,
          rotary_emb=rotary_emb,
      )
      return hidden_states, None

    # Execute scan
    hidden_states, _ = nnx.scan(
        block_scan_fn,
        length=self.num_layers,
        in_axes=(nnx.Carry, 0),  # Scan over the layers dimension (0) of block_module
        out_axes=(nnx.Carry, 0),
    )(hidden_states, self.stacked_blocks)

    # 4. Final Norm
    hidden_states = self.final_norm(hidden_states)

    return hidden_states
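
The layer-stacking pattern above (parameters created with nnx.vmap, then applied sequentially with nnx.scan) is the main JAX-idiomatic piece of the connector. Below is a minimal, self-contained sketch of the same pattern, using a toy TinyBlock in place of _BasicTransformerBlock1D so it runs without the maxdiffusion attention module; the block, its dimensions, and all names are illustrative only.

import jax.numpy as jnp
from flax import nnx

LAYERS, DIM = 4, 32

class TinyBlock(nnx.Module):
  """Toy stand-in for _BasicTransformerBlock1D: a pre-norm residual projection."""
  def __init__(self, dim: int, rngs: nnx.Rngs):
    self.norm = nnx.RMSNorm(dim, rngs=rngs)
    self.proj = nnx.Linear(dim, dim, rngs=rngs)

  def __call__(self, x):
    return x + self.proj(self.norm(x))

# 1. Stack parameters: split the rngs per layer, then vmap the constructor so the
#    resulting module holds parameters with a leading [LAYERS, ...] axis.
@nnx.split_rngs(splits=LAYERS)
@nnx.vmap(in_axes=0, out_axes=0, axis_size=LAYERS)
def create_block(rngs):
  return TinyBlock(DIM, rngs=rngs)

stacked = create_block(nnx.Rngs(0))

# 2. Apply layers sequentially: nnx.scan carries the activations and feeds one
#    slice of the stacked parameters to each step.
def step(carry, block):
  return block(carry), None

x = jnp.ones((2, 16, DIM))
y, _ = nnx.scan(step, length=LAYERS, in_axes=(nnx.Carry, 0), out_axes=(nnx.Carry, 0))(x, stacked)
print(y.shape)  # (2, 16, 32)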
Lines changed: 114 additions & 0 deletions
@@ -0,0 +1,114 @@
"""
Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import unittest
import jax
import jax.numpy as jnp
import numpy as np
from flax import nnx
from ..models.ltx2.text_encoders.embeddings_connector_ltx2 import Embeddings1DConnector

class Embeddings1DConnectorTest(unittest.TestCase):
  def setUp(self):
    self.rng = nnx.Rngs(0)
    self.B = 2
    self.T = 16  # Must be divisible by num_learnable_registers if we want tiling to work simply
    self.D = 64  # inner_dim

    # Test config
    self.num_learnable_registers = 8
    self.heads = 4
    self.head_dim = 16

    # input dim = heads * head_dim = 64

  def test_thinking_tokens_replacement(self):
    connector = Embeddings1DConnector(
        input_dim=self.D,
        heads=self.heads,
        head_dim=self.head_dim,
        layers=1,
        num_learnable_registers=self.num_learnable_registers,
        rngs=self.rng,
    )

    # Create input [B, T, D]
    hidden_states = jnp.zeros((self.B, self.T, self.D))

    # Create mask [B, T]
    # Batch 0: First 4 valid, rest padding
    # Batch 1: First 8 valid, rest padding
    mask = np.zeros((self.B, self.T), dtype=np.int32)
    mask[0, :4] = 1
    mask[1, :8] = 1

    # Explicitly run replacement method
    output, new_mask = connector._replace_padded_with_learnable_registers(
        hidden_states, jnp.array(mask)
    )

    # 1. Check Mask Reset
    self.assertTrue(jnp.all(new_mask == 1.0), "New mask should be all 1s")

    # 2. Check Valid Tokens (should be 0 as input was 0)
    # Batch 0, 0-3
    valid_b0 = output[0, :4, :]
    self.assertTrue(jnp.all(valid_b0 == 0.0), "Valid tokens should remain unchanged")

    # 3. Check Thinking Tokens (Padding area)
    # Batch 0, 4-15
    thinking_b0 = output[0, 4:, :]

    # The learnable registers should be tiled.
    # Registers shape: [8, 64]
    # T=16, so it's tiled 2 times -> [16, 64]
    # We need to verify that padding positions contain values from registers

    # Get expected registers values
    registers_val = connector.learnable_registers[...]  # [8, 64]
    tiled_regs = jnp.tile(registers_val, (2, 1))  # [16, 64]

    expected_padding = tiled_regs[4:, :]  # corresponding slice

    np.testing.assert_allclose(
        thinking_b0,
        expected_padding,
        err_msg="Padding should be replaced by corresponding register values",
    )
    print("\n[PASS] Thinking Tokens Replacement Logic Verified.")

  def test_forward_shape_and_run(self):
    connector = Embeddings1DConnector(
        input_dim=self.D,
        heads=self.heads,
        head_dim=self.head_dim,
        layers=2,
        num_learnable_registers=self.num_learnable_registers,
        attention_kernel="dot_product",  # Use dot_product for testing on CPU
        rngs=self.rng,
    )

    hidden_states = jnp.array(np.random.randn(self.B, self.T, self.D))
    mask = jnp.ones((self.B, self.T))  # All valid

    output = connector(hidden_states, mask)

    self.assertEqual(output.shape, (self.B, self.T, self.D))
    self.assertFalse(jnp.isnan(output).any(), "Output should not contain NaNs")
    print("\n[PASS] Embeddings1DConnector Forward Pass Verified.")

if __name__ == "__main__":
  unittest.main()
