Skip to content

Commit 7bd49ec

Browse files
committed
Added dummy attention and RoPE placeholders, verified transformer_ltx2.py, and added unit tests for the transformer.
1 parent e44ddbc commit 7bd49ec

6 files changed

Lines changed: 1185 additions & 38 deletions

File tree

src/maxdiffusion/models/embeddings_flax.py

Lines changed: 55 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -521,8 +521,18 @@ def __call__(self, timesteps: jax.Array) -> jax.Array:
521521

522522

523523
class NNXPixArtAlphaCombinedTimestepSizeEmbeddings(nnx.Module):
524-
def __init__(self, rngs: nnx.Rngs, embedding_dim: int, size_emb_dim: int, dtype: jnp.dtype = jnp.float32, weights_dtype: jnp.dtype = jnp.float32):
524+
def __init__(
525+
self,
526+
rngs: nnx.Rngs,
527+
embedding_dim: int,
528+
size_emb_dim: int,
529+
use_additional_conditions: bool = False,
530+
dtype: jnp.dtype = jnp.float32,
531+
weights_dtype: jnp.dtype = jnp.float32
532+
):
525533
self.outdim = size_emb_dim
534+
self.use_additional_conditions = use_additional_conditions
535+
526536
self.time_proj = NNXTimesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
527537
self.timestep_embedder = NNXTimestepEmbedding(
528538
rngs=rngs,
@@ -532,7 +542,49 @@ def __init__(self, rngs: nnx.Rngs, embedding_dim: int, size_emb_dim: int, dtype:
532542
weights_dtype=weights_dtype
533543
)
534544

535-
def __call__(self, timestep: jax.Array, hidden_dtype: jnp.dtype = jnp.float32) -> jax.Array:
545+
if use_additional_conditions:
546+
self.additional_condition_proj = NNXTimesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
547+
self.resolution_embedder = NNXTimestepEmbedding(
548+
rngs=rngs,
549+
in_channels=256,
550+
time_embed_dim=size_emb_dim,
551+
dtype=dtype,
552+
weights_dtype=weights_dtype
553+
)
554+
self.aspect_ratio_embedder = NNXTimestepEmbedding(
555+
rngs=rngs,
556+
in_channels=256,
557+
time_embed_dim=size_emb_dim,
558+
dtype=dtype,
559+
weights_dtype=weights_dtype
560+
)
561+
562+
def __call__(
563+
self,
564+
timestep: jax.Array,
565+
resolution: Optional[jax.Array] = None,
566+
aspect_ratio: Optional[jax.Array] = None,
567+
hidden_dtype: jnp.dtype = jnp.float32
568+
) -> jax.Array:
536569
timesteps_proj = self.time_proj(timestep)
537570
timesteps_emb = self.timestep_embedder(timesteps_proj.astype(hidden_dtype))
538-
return timesteps_emb
571+
572+
if self.use_additional_conditions:
573+
if resolution is None or aspect_ratio is None:
574+
raise ValueError("resolution and aspect_ratio must be provided when use_additional_conditions is True")
575+
576+
resolution_emb = self.additional_condition_proj(resolution.flatten()).astype(hidden_dtype)
577+
resolution_emb = self.resolution_embedder(resolution_emb)
578+
# Reshape to (batch_size, -1) matching PyTorch's reshape(batch_size, -1)
579+
# assuming resolution input was (batch_size, ...) so flatten logic holds.
580+
resolution_emb = resolution_emb.reshape(timestep.shape[0], -1)
581+
582+
aspect_ratio_emb = self.additional_condition_proj(aspect_ratio.flatten()).astype(hidden_dtype)
583+
aspect_ratio_emb = self.aspect_ratio_embedder(aspect_ratio_emb)
584+
aspect_ratio_emb = aspect_ratio_emb.reshape(timestep.shape[0], -1)
585+
586+
conditioning = timesteps_emb + jnp.concatenate([resolution_emb, aspect_ratio_emb], axis=1)
587+
else:
588+
conditioning = timesteps_emb
589+
590+
return conditioning

src/maxdiffusion/models/ltx_2/adaln.py

Lines changed: 0 additions & 35 deletions
This file was deleted.
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
2+
from typing import Optional, Tuple, Any
3+
from flax import nnx
4+
import jax
5+
import jax.numpy as jnp
6+
7+
class Attention(nnx.Module):
  """Stub for the LTX-2 attention layer (self/cross attention, audio/video).

  The real kernel is owned by a separate task. This stand-in keeps the final
  constructor/call signature and echoes the query tensor back unchanged so
  that transformer blocks can be wired up and shape-tested against it.
  """

  def __init__(
      self,
      rngs: nnx.Rngs,
      query_dim: int,
      heads: int = 8,
      kv_heads: int = 8,
      dim_head: int = 64,
      dropout: float = 0.0,
      use_bias: bool = True,
      cross_attention_dim: Optional[int] = None,
      out_bias: bool = True,
      qk_norm: str = "rms_norm_across_heads",
      norm_eps: float = 1e-6,
      rope_type: str = "interleaved",
      dtype: jnp.dtype = jnp.float32,
      param_dtype: jnp.dtype = jnp.float32,
  ):
    # Only the head layout is recorded; all other arguments exist purely for
    # signature compatibility with the eventual implementation.
    self.heads = heads
    self.dim_head = dim_head

  def __call__(
      self,
      hidden_states: jax.Array,
      encoder_hidden_states: Optional[jax.Array] = None,
      attention_mask: Optional[jax.Array] = None,
      query_rotary_emb: Optional[Tuple[jax.Array, jax.Array]] = None,
      key_rotary_emb: Optional[Tuple[jax.Array, jax.Array]] = None,
      deterministic: bool = True
  ) -> jax.Array:
    """Identity forward pass.

    Callers add this output to ``hidden_states`` as a residual, so the
    placeholder must produce a tensor of exactly the input's shape; simply
    returning the input satisfies that contract.
    """
    return hidden_states
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
2+
from flax import nnx
3+
from enum import Enum
4+
import jax
5+
import jax.numpy as jnp
6+
from typing import Tuple, Optional, Union, List
7+
8+
class LTXRopeType(Enum):
  """Channel layout for rotary embeddings: interleaved pairs vs. split halves."""

  INTERLEAVED = "interleaved"
  SPLIT = "split"
11+
12+
class LTX2AudioVideoRotaryPosEmbed(nnx.Module):
  """Stub for the LTX-2 rotary position embedding (3D video / 1D audio).

  The real frequency computation is owned by a separate task. This stand-in
  keeps the final constructor signature and returns zero-valued coordinate
  grids and cos/sin tables so dependent modules can be integrated and
  shape-checked.
  """

  def __init__(
      self,
      dim: int,
      patch_size: int,
      patch_size_t: int,
      base_num_frames: int = 128,
      base_height: int = 2048,
      base_width: int = 2048,
      sampling_rate: int = 16000,
      hop_length: int = 160,
      scale_factors: Union[List[int], Tuple[int, ...]] = (8, 32, 32),
      theta: float = 10000.0,
      causal_offset: int = 1,
      modality: str = "video",
      double_precision: bool = True,
      rope_type: str = "interleaved",
      num_attention_heads: int = 32,
      dtype: jnp.dtype = jnp.float32,
  ):
    # Only the fields the placeholder methods actually consult are stored;
    # the remaining parameters exist for signature compatibility.
    self.dim = dim
    self.rope_type = rope_type
    self.dtype = dtype
    self.modality = modality

  def prepare_video_coords(self, batch_size, num_frames, height, width, fps):
    """Return a dummy (batch_size, 1, 1) video coordinate grid of zeros."""
    return jnp.zeros((batch_size, 1, 1), dtype=self.dtype)

  def prepare_audio_coords(self, batch_size, audio_num_frames):
    """Return a dummy (batch_size, 1, 1) audio coordinate grid of zeros."""
    return jnp.zeros((batch_size, 1, 1), dtype=self.dtype)

  def __call__(
      self,
      coords: jax.Array,
  ) -> Tuple[jax.Array, jax.Array]:
    """Return placeholder (cos, sin) frequency tables.

    Each table is shaped (1, 1, dim) so it broadcasts against the
    (batch, seq, head_dim) layouts the attention layer is expected to use.
    """
    zeros = jnp.zeros((1, 1, self.dim), dtype=self.dtype)
    return zeros, zeros
61+
62+
# Helper placeholders if used by attention
63+
def apply_interleaved_rotary_emb(x, freqs):
  """Placeholder for interleaved RoPE application; currently the identity on ``x``.

  ``freqs`` is accepted for signature compatibility and ignored.
  """
  return x
65+
66+
def apply_split_rotary_emb(x, freqs):
67+
return x

0 commit comments

Comments
 (0)