Commit 61398b2

change in rope in embeddings_connector
1 parent 46a68be · commit 61398b2

1 file changed: 57 additions & 10 deletions

File tree

src/maxdiffusion/models/ltx2/text_encoders/embeddings_connector_ltx2.py

@@ -86,14 +86,21 @@ def __init__(
       theta: float = 10000.0,
       num_learnable_registers: int = 128,
       rope_type: str = "interleaved",
+      base_seq_len: int = 4096,
+      double_precision: bool = True,
       attention_kernel: str = "flash",
       mesh: jax.sharding.Mesh = None,
       rngs: nnx.Rngs = None,
   ):
     self.dim = input_dim
+    self.heads = heads
+    self.head_dim = head_dim
     self.theta = theta
     self.num_learnable_registers = num_learnable_registers
     self.num_layers = layers
+    self.rope_type = rope_type
+    self.base_seq_len = base_seq_len
+    self.double_precision = double_precision
 
     # 1. Initialize Stacked Layers using vmap
     # This creates a single module where parameters have an extra leading dimension [layers, ...]
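One caveat worth flagging for the new `double_precision` knob: the `jnp.float64` it selects only takes effect when JAX's x64 mode is enabled; otherwise the computation silently falls back to float32. A minimal sketch of the behavior (the config call is standard JAX, not part of this commit):

import jax
import jax.numpy as jnp

# JAX disables 64-bit types by default; this is conventionally set once at
# program start. Without it, jnp.float64 is silently coerced to float32 and
# double_precision=True changes nothing.
jax.config.update("jax_enable_x64", True)

pow_indices = jnp.linspace(0.0, 1.0, 8, dtype=jnp.float64)
print(pow_indices.dtype)  # float64 with x64 enabled; float32 otherwise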
@@ -165,15 +172,54 @@ def _replace_padded_with_learnable_registers(self, hidden_states: Array, attenti
     new_mask = jnp.ones_like(attention_mask)
     return output, new_mask
 
-  def _compute_1d_rope(self, seq_len: int, dtype: DType) -> Tuple[Array, Array]:
-    t = jnp.arange(seq_len, dtype=jnp.float32)
-    freqs = 1.0 / (self.theta ** (jnp.arange(0, self.dim, 2, dtype=jnp.float32) / self.dim))
-    emb = jnp.outer(t, freqs)
-    cos = jnp.cos(emb)
-    sin = jnp.sin(emb)
-    cos = jnp.repeat(cos, 2, axis=-1)
-    sin = jnp.repeat(sin, 2, axis=-1)
-    return cos[None, ...], sin[None, ...]
+  def _compute_1d_rope(self, batch_size: int, seq_len: int, dtype: DType) -> Tuple[Array, Array]:
+    grid_1d = jnp.arange(seq_len, dtype=jnp.float32)
+    grid_1d = grid_1d / self.base_seq_len
+    grid = jnp.expand_dims(grid_1d, 0)
+    grid = jnp.tile(grid, (batch_size, 1))
+
+    num_rope_elems = 2
+    freqs_dtype = jnp.float64 if self.double_precision else jnp.float32
+    steps = self.dim // num_rope_elems
+    pow_indices = jnp.power(self.theta, jnp.linspace(0.0, 1.0, steps, dtype=freqs_dtype))
+    base_freqs = (pow_indices * jnp.pi / 2.0).astype(jnp.float32)
+
+    freqs = (jnp.expand_dims(grid, -1) * 2.0 - 1.0) * base_freqs
+
+    cos_freqs = jnp.cos(freqs)
+    sin_freqs = jnp.sin(freqs)
+
+    if self.rope_type == "interleaved":
+      cos_freqs = jnp.repeat(cos_freqs, 2, axis=-1)
+      sin_freqs = jnp.repeat(sin_freqs, 2, axis=-1)
+
+      if self.dim % num_rope_elems != 0:
+        curr_dim = cos_freqs.shape[-1]
+        pad_amt = self.dim - curr_dim
+        if pad_amt > 0:
+          cos_padding = jnp.ones((*cos_freqs.shape[:-1], pad_amt), dtype=cos_freqs.dtype)
+          sin_padding = jnp.zeros((*sin_freqs.shape[:-1], pad_amt), dtype=sin_freqs.dtype)
+          cos_freqs = jnp.concatenate([cos_padding, cos_freqs], axis=-1)
+          sin_freqs = jnp.concatenate([sin_padding, sin_freqs], axis=-1)
+
+    elif self.rope_type == "split":
+      expected_freqs = self.dim // 2
+      current_freqs = freqs.shape[-1]
+      pad_size = expected_freqs - current_freqs
+
+      if pad_size > 0:
+        cos_padding = jnp.ones((*cos_freqs.shape[:-1], pad_size), dtype=cos_freqs.dtype)
+        sin_padding = jnp.zeros((*sin_freqs.shape[:-1], pad_size), dtype=sin_freqs.dtype)
+        cos_freqs = jnp.concatenate([cos_padding, cos_freqs], axis=-1)
+        sin_freqs = jnp.concatenate([sin_padding, sin_freqs], axis=-1)
+
+    b = cos_freqs.shape[0]
+    t = cos_freqs.shape[1]
+    h = self.heads
+    cos_freqs = cos_freqs.reshape(b, t, h, -1).transpose(0, 2, 1, 3)
+    sin_freqs = sin_freqs.reshape(b, t, h, -1).transpose(0, 2, 1, 3)
+
+    return cos_freqs, sin_freqs
 
   def __call__(
       self,
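The rewrite changes the schedule itself, not just the shapes. The old helper built the familiar decaying frequencies 1/theta**(2i/dim) from raw integer positions and broadcast a single [1, T, dim] table; the new helper normalizes positions by base_seq_len, maps them from [0, 1] to [-1, 1], scales by a growing ladder theta**linspace(0, 1) * pi/2, and returns batched, per-head tables. A minimal standalone sketch of the interleaved path, assuming an even dim so neither padding branch fires (the function name and literal sizes are illustrative, not part of the commit):

import jax.numpy as jnp

def sketch_1d_rope(batch_size, seq_len, dim, heads, theta=10000.0, base_seq_len=4096):
  # Positions normalized by base_seq_len, tiled per batch: [B, T].
  grid = jnp.tile(jnp.arange(seq_len, dtype=jnp.float32)[None, :] / base_seq_len,
                  (batch_size, 1))
  # Growing geometric ladder theta**0 .. theta**1, scaled by pi/2.
  base_freqs = jnp.power(theta, jnp.linspace(0.0, 1.0, dim // 2)) * jnp.pi / 2.0
  # Map positions from [0, 1] to [-1, 1] before scaling: [B, T, dim // 2].
  freqs = (grid[..., None] * 2.0 - 1.0) * base_freqs
  # Interleaved layout: each angle serves a pair of adjacent channels: [B, T, dim].
  cos = jnp.repeat(jnp.cos(freqs), 2, axis=-1)
  sin = jnp.repeat(jnp.sin(freqs), 2, axis=-1)
  # Split channels across heads and move heads forward: [B, heads, T, dim // heads].
  b, t = cos.shape[:2]
  cos = cos.reshape(b, t, heads, -1).transpose(0, 2, 1, 3)
  sin = sin.reshape(b, t, heads, -1).transpose(0, 2, 1, 3)
  return cos, sin

cos, sin = sketch_1d_rope(batch_size=2, seq_len=8, dim=64, heads=4)
assert cos.shape == (2, 4, 8, 16)

The assert mirrors the reshape/transpose at the end of _compute_1d_rope: the method now returns per-batch, per-head tables of shape [B, heads, T, dim // heads].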
@@ -198,8 +244,9 @@ def __call__(
         mean=jnp.mean(hidden_states), std=jnp.std(hidden_states))
 
     # 2. RoPE
+    batch_size = hidden_states.shape[0]
     seq_len = hidden_states.shape[1]
-    rotary_emb = self._compute_1d_rope(seq_len, hidden_states.dtype)
+    rotary_emb = self._compute_1d_rope(batch_size, seq_len, hidden_states.dtype)
 
     # 3. Transformer Blocks (Scan)
 
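With rope_type="interleaved", each frequency is repeated twice so adjacent channel pairs share one rotation angle. For context, a hedged sketch of how tables with this layout are typically consumed downstream; the applier below is the generic interleaved rotary rotation, not code from this commit:

import jax.numpy as jnp

def apply_interleaved_rope(x, cos, sin):
  # Rotate channel pairs: (x0, x1) -> (x0*cos - x1*sin, x1*cos + x0*sin).
  # Relies on cos/sin being built with jnp.repeat(..., 2, axis=-1), so each
  # pair of channels sees the same angle.
  x_rot = jnp.stack([-x[..., 1::2], x[..., ::2]], axis=-1).reshape(x.shape)
  return x * cos + x_rot * sin

q = jnp.ones((2, 4, 8, 16))   # [B, heads, T, head_dim], illustrative sizes
cos = jnp.ones_like(q)        # stand-ins for the per-head tables computed above
sin = jnp.zeros_like(q)
assert apply_interleaved_rope(q, cos, sin).shape == q.shape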