Skip to content

Commit 25123c6

Browse files
committed
annotations added
1 parent efce9a5 commit 25123c6

3 files changed

Lines changed: 237 additions & 224 deletions

File tree

src/maxdiffusion/models/ltx2/attention_ltx2.py

Lines changed: 39 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from typing import Optional, Tuple
1818
from flax import nnx
19+
import jax
1920
import jax.numpy as jnp
2021
from ... import common_types
2122
from ..attention_flax import NNXAttentionOp
@@ -446,46 +447,48 @@ def __call__(
446447
# Determine context (Self or Cross)
447448
context = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
448449

449-
# 1. Project
450-
query = self.to_q(hidden_states)
451-
key = self.to_k(context)
452-
value = self.to_v(context)
450+
# 1. Project and Norm
451+
with jax.named_scope("QKV Projection and Norm"):
452+
query = self.to_q(hidden_states)
453+
key = self.to_k(context)
454+
value = self.to_v(context)
453455

454-
# 2. Norm (Full Inner Dimension)
455-
query = self.norm_q(query)
456-
key = self.norm_k(key)
456+
query = self.norm_q(query)
457+
key = self.norm_k(key)
457458

458459
# 3. Apply RoPE to tensors of shape [B, S, InnerDim]
459460
# Frequencies are shape [B, S, InnerDim]
460461
# 3. Apply RoPE
461-
if rotary_emb is not None:
462-
if hasattr(self, "rope_type") and self.rope_type == "split":
463-
# Split RoPE: passing full freqs [B, H, S, D//2]
464-
# apply_split_rotary_emb handles reshaping query/key
465-
466-
query = apply_split_rotary_emb(query, rotary_emb)
467-
468-
if k_rotary_emb is not None:
469-
key = apply_split_rotary_emb(key, k_rotary_emb)
470-
elif encoder_hidden_states is None:
471-
key = apply_split_rotary_emb(key, rotary_emb)
472-
473-
else:
474-
# Interleaved (Default)
475-
query = apply_rotary_emb(query, rotary_emb)
476-
if k_rotary_emb is not None:
477-
key = apply_rotary_emb(key, k_rotary_emb)
478-
elif encoder_hidden_states is None:
479-
key = apply_rotary_emb(key, rotary_emb)
480-
481-
# 4. Attention
482-
# NNXAttentionOp expects flattened input [B, S, InnerDim] for flash kernel
483-
attn_output = self.attention_op.apply_attention(query=query, key=key, value=value, attention_mask=attention_mask)
484-
485-
# 7. Output Projection
486-
hidden_states = self.to_out(attn_output)
487-
488-
if self.dropout_layer is not None:
489-
hidden_states = self.dropout_layer(hidden_states)
462+
with jax.named_scope("Apply RoPE"):
463+
if rotary_emb is not None:
464+
if hasattr(self, "rope_type") and self.rope_type == "split":
465+
# Split RoPE: passing full freqs [B, H, S, D//2]
466+
# apply_split_rotary_emb handles reshaping query/key
467+
468+
query = apply_split_rotary_emb(query, rotary_emb)
469+
470+
if k_rotary_emb is not None:
471+
key = apply_split_rotary_emb(key, k_rotary_emb)
472+
elif encoder_hidden_states is None:
473+
key = apply_split_rotary_emb(key, rotary_emb)
474+
475+
else:
476+
# Interleaved (Default)
477+
query = apply_rotary_emb(query, rotary_emb)
478+
if k_rotary_emb is not None:
479+
key = apply_rotary_emb(key, k_rotary_emb)
480+
elif encoder_hidden_states is None:
481+
key = apply_rotary_emb(key, rotary_emb)
482+
483+
with jax.named_scope("Attention and Output Project"):
484+
# 4. Attention
485+
# NNXAttentionOp expects flattened input [B, S, InnerDim] for flash kernel
486+
attn_output = self.attention_op.apply_attention(query=query, key=key, value=value, attention_mask=attention_mask)
487+
488+
# 7. Output Projection
489+
hidden_states = self.to_out(attn_output)
490+
491+
if self.dropout_layer is not None:
492+
hidden_states = self.dropout_layer(hidden_states)
490493

491494
return hidden_states

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 96 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -900,113 +900,79 @@ def __call__(
900900
batch_size = hidden_states.shape[0]
901901

902902
# 1. Prepare RoPE positional embeddings
903-
if video_coords is None:
904-
video_coords = self.rope.prepare_video_coords(batch_size, num_frames, height, width, fps=fps)
905-
if audio_coords is None:
906-
audio_coords = self.audio_rope.prepare_audio_coords(batch_size, audio_num_frames)
903+
with jax.named_scope("RoPE Preparation"):
904+
if video_coords is None:
905+
video_coords = self.rope.prepare_video_coords(batch_size, num_frames, height, width, fps=fps)
906+
if audio_coords is None:
907+
audio_coords = self.audio_rope.prepare_audio_coords(batch_size, audio_num_frames)
907908

908-
video_rotary_emb = self.rope(video_coords)
909-
audio_rotary_emb = self.audio_rope(audio_coords)
909+
video_rotary_emb = self.rope(video_coords)
910+
audio_rotary_emb = self.audio_rope(audio_coords)
910911

911-
video_cross_attn_rotary_emb = self.cross_attn_rope(video_coords[:, 0:1, :])
912-
audio_cross_attn_rotary_emb = self.cross_attn_audio_rope(audio_coords[:, 0:1, :])
912+
video_cross_attn_rotary_emb = self.cross_attn_rope(video_coords[:, 0:1, :])
913+
audio_cross_attn_rotary_emb = self.cross_attn_audio_rope(audio_coords[:, 0:1, :])
913914

914915
# 2. Patchify input projections
915-
hidden_states = self.proj_in(hidden_states)
916-
audio_hidden_states = self.audio_proj_in(audio_hidden_states)
916+
with jax.named_scope("Input Projection"):
917+
hidden_states = self.proj_in(hidden_states)
918+
audio_hidden_states = self.audio_proj_in(audio_hidden_states)
917919

918920
# 3. Prepare timestep embeddings and modulation parameters
919-
timestep_cross_attn_gate_scale_factor = self.cross_attn_timestep_scale_multiplier / self.timestep_scale_multiplier
921+
with jax.named_scope("Timestep and Caption Projection"):
922+
timestep_cross_attn_gate_scale_factor = self.cross_attn_timestep_scale_multiplier / self.timestep_scale_multiplier
920923

921-
temb, embedded_timestep = self.time_embed(
922-
timestep.flatten(),
923-
hidden_dtype=hidden_states.dtype,
924-
)
925-
temb = temb.reshape(batch_size, -1, temb.shape[-1])
926-
embedded_timestep = embedded_timestep.reshape(batch_size, -1, embedded_timestep.shape[-1])
924+
temb, embedded_timestep = self.time_embed(
925+
timestep.flatten(),
926+
hidden_dtype=hidden_states.dtype,
927+
)
928+
temb = temb.reshape(batch_size, -1, temb.shape[-1])
929+
embedded_timestep = embedded_timestep.reshape(batch_size, -1, embedded_timestep.shape[-1])
927930

928-
temb_audio, audio_embedded_timestep = self.audio_time_embed(
929-
audio_timestep.flatten(),
930-
hidden_dtype=audio_hidden_states.dtype,
931-
)
932-
temb_audio = temb_audio.reshape(batch_size, -1, temb_audio.shape[-1])
933-
audio_embedded_timestep = audio_embedded_timestep.reshape(batch_size, -1, audio_embedded_timestep.shape[-1])
931+
temb_audio, audio_embedded_timestep = self.audio_time_embed(
932+
audio_timestep.flatten(),
933+
hidden_dtype=audio_hidden_states.dtype,
934+
)
935+
temb_audio = temb_audio.reshape(batch_size, -1, temb_audio.shape[-1])
936+
audio_embedded_timestep = audio_embedded_timestep.reshape(batch_size, -1, audio_embedded_timestep.shape[-1])
934937

935-
video_cross_attn_scale_shift, _ = self.av_cross_attn_video_scale_shift(
936-
timestep.flatten(),
937-
hidden_dtype=hidden_states.dtype,
938-
)
939-
video_cross_attn_a2v_gate, _ = self.av_cross_attn_video_a2v_gate(
940-
timestep.flatten() * timestep_cross_attn_gate_scale_factor,
941-
hidden_dtype=hidden_states.dtype,
942-
)
943-
video_cross_attn_scale_shift = video_cross_attn_scale_shift.reshape(
944-
batch_size, -1, video_cross_attn_scale_shift.shape[-1]
945-
)
946-
video_cross_attn_a2v_gate = video_cross_attn_a2v_gate.reshape(batch_size, -1, video_cross_attn_a2v_gate.shape[-1])
938+
video_cross_attn_scale_shift, _ = self.av_cross_attn_video_scale_shift(
939+
timestep.flatten(),
940+
hidden_dtype=hidden_states.dtype,
941+
)
942+
video_cross_attn_a2v_gate, _ = self.av_cross_attn_video_a2v_gate(
943+
timestep.flatten() * timestep_cross_attn_gate_scale_factor,
944+
hidden_dtype=hidden_states.dtype,
945+
)
946+
video_cross_attn_scale_shift = video_cross_attn_scale_shift.reshape(
947+
batch_size, -1, video_cross_attn_scale_shift.shape[-1]
948+
)
949+
video_cross_attn_a2v_gate = video_cross_attn_a2v_gate.reshape(batch_size, -1, video_cross_attn_a2v_gate.shape[-1])
947950

948-
audio_cross_attn_scale_shift, _ = self.av_cross_attn_audio_scale_shift(
949-
audio_timestep.flatten(),
950-
hidden_dtype=audio_hidden_states.dtype,
951-
)
952-
audio_cross_attn_v2a_gate, _ = self.av_cross_attn_audio_v2a_gate(
953-
audio_timestep.flatten() * timestep_cross_attn_gate_scale_factor,
954-
hidden_dtype=audio_hidden_states.dtype,
955-
)
956-
audio_cross_attn_scale_shift = audio_cross_attn_scale_shift.reshape(
957-
batch_size, -1, audio_cross_attn_scale_shift.shape[-1]
958-
)
959-
audio_cross_attn_v2a_gate = audio_cross_attn_v2a_gate.reshape(batch_size, -1, audio_cross_attn_v2a_gate.shape[-1])
951+
audio_cross_attn_scale_shift, _ = self.av_cross_attn_audio_scale_shift(
952+
audio_timestep.flatten(),
953+
hidden_dtype=audio_hidden_states.dtype,
954+
)
955+
audio_cross_attn_v2a_gate, _ = self.av_cross_attn_audio_v2a_gate(
956+
audio_timestep.flatten() * timestep_cross_attn_gate_scale_factor,
957+
hidden_dtype=audio_hidden_states.dtype,
958+
)
959+
audio_cross_attn_scale_shift = audio_cross_attn_scale_shift.reshape(
960+
batch_size, -1, audio_cross_attn_scale_shift.shape[-1]
961+
)
962+
audio_cross_attn_v2a_gate = audio_cross_attn_v2a_gate.reshape(batch_size, -1, audio_cross_attn_v2a_gate.shape[-1])
960963

961-
# 4. Prepare prompt embeddings
962-
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
963-
encoder_hidden_states = encoder_hidden_states.reshape(batch_size, -1, hidden_states.shape[-1])
964+
# 4. Prepare prompt embeddings
965+
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
966+
encoder_hidden_states = encoder_hidden_states.reshape(batch_size, -1, hidden_states.shape[-1])
964967

965-
audio_encoder_hidden_states = self.audio_caption_projection(audio_encoder_hidden_states)
966-
audio_encoder_hidden_states = audio_encoder_hidden_states.reshape(batch_size, -1, audio_hidden_states.shape[-1])
968+
audio_encoder_hidden_states = self.audio_caption_projection(audio_encoder_hidden_states)
969+
audio_encoder_hidden_states = audio_encoder_hidden_states.reshape(batch_size, -1, audio_hidden_states.shape[-1])
967970

968971
# 5. Run transformer blocks
969972
def scan_fn(carry, block):
970973
hidden_states, audio_hidden_states, rngs_carry = carry
971-
hidden_states_out, audio_hidden_states_out = block(
972-
hidden_states=hidden_states,
973-
audio_hidden_states=audio_hidden_states,
974-
encoder_hidden_states=encoder_hidden_states,
975-
audio_encoder_hidden_states=audio_encoder_hidden_states,
976-
temb=temb,
977-
temb_audio=temb_audio,
978-
temb_ca_scale_shift=video_cross_attn_scale_shift,
979-
temb_ca_audio_scale_shift=audio_cross_attn_scale_shift,
980-
temb_ca_gate=video_cross_attn_a2v_gate,
981-
temb_ca_audio_gate=audio_cross_attn_v2a_gate,
982-
video_rotary_emb=video_rotary_emb,
983-
audio_rotary_emb=audio_rotary_emb,
984-
ca_video_rotary_emb=video_cross_attn_rotary_emb,
985-
ca_audio_rotary_emb=audio_cross_attn_rotary_emb,
986-
encoder_attention_mask=encoder_attention_mask,
987-
audio_encoder_attention_mask=audio_encoder_attention_mask,
988-
)
989-
return (
990-
hidden_states_out.astype(hidden_states.dtype),
991-
audio_hidden_states_out.astype(audio_hidden_states.dtype),
992-
rngs_carry,
993-
), None
994-
995-
if self.scan_layers:
996-
rematted_scan_fn = self.gradient_checkpoint.apply(
997-
scan_fn, self.names_which_can_be_saved, self.names_which_can_be_offloaded, prevent_cse=not self.scan_layers
998-
)
999-
carry = (hidden_states, audio_hidden_states, nnx.Rngs(0)) # Placeholder RNGs for now if not used in block
1000-
(hidden_states, audio_hidden_states, _), _ = nnx.scan(
1001-
rematted_scan_fn,
1002-
length=self.num_layers,
1003-
in_axes=(nnx.Carry, 0),
1004-
out_axes=(nnx.Carry, 0),
1005-
transform_metadata={nnx.PARTITION_NAME: "layers"},
1006-
)(carry, self.transformer_blocks)
1007-
else:
1008-
for block in self.transformer_blocks:
1009-
hidden_states, audio_hidden_states = block(
974+
with jax.named_scope("Transformer Block i"):
975+
hidden_states_out, audio_hidden_states_out = block(
1010976
hidden_states=hidden_states,
1011977
audio_hidden_states=audio_hidden_states,
1012978
encoder_hidden_states=encoder_hidden_states,
@@ -1024,6 +990,45 @@ def scan_fn(carry, block):
1024990
encoder_attention_mask=encoder_attention_mask,
1025991
audio_encoder_attention_mask=audio_encoder_attention_mask,
1026992
)
993+
return (
994+
hidden_states_out.astype(hidden_states.dtype),
995+
audio_hidden_states_out.astype(audio_hidden_states.dtype),
996+
rngs_carry,
997+
), None
998+
999+
with jax.named_scope("Transformer Blocks"):
1000+
if self.scan_layers:
1001+
rematted_scan_fn = self.gradient_checkpoint.apply(
1002+
scan_fn, self.names_which_can_be_saved, self.names_which_can_be_offloaded, prevent_cse=not self.scan_layers
1003+
)
1004+
carry = (hidden_states, audio_hidden_states, nnx.Rngs(0)) # Placeholder RNGs for now if not used in block
1005+
(hidden_states, audio_hidden_states, _), _ = nnx.scan(
1006+
rematted_scan_fn,
1007+
length=self.num_layers,
1008+
in_axes=(nnx.Carry, 0),
1009+
out_axes=(nnx.Carry, 0),
1010+
transform_metadata={nnx.PARTITION_NAME: "layers"},
1011+
)(carry, self.transformer_blocks)
1012+
else:
1013+
for block in self.transformer_blocks:
1014+
hidden_states, audio_hidden_states = block(
1015+
hidden_states=hidden_states,
1016+
audio_hidden_states=audio_hidden_states,
1017+
encoder_hidden_states=encoder_hidden_states,
1018+
audio_encoder_hidden_states=audio_encoder_hidden_states,
1019+
temb=temb,
1020+
temb_audio=temb_audio,
1021+
temb_ca_scale_shift=video_cross_attn_scale_shift,
1022+
temb_ca_audio_scale_shift=audio_cross_attn_scale_shift,
1023+
temb_ca_gate=video_cross_attn_a2v_gate,
1024+
temb_ca_audio_gate=audio_cross_attn_v2a_gate,
1025+
video_rotary_emb=video_rotary_emb,
1026+
audio_rotary_emb=audio_rotary_emb,
1027+
ca_video_rotary_emb=video_cross_attn_rotary_emb,
1028+
ca_audio_rotary_emb=audio_cross_attn_rotary_emb,
1029+
encoder_attention_mask=encoder_attention_mask,
1030+
audio_encoder_attention_mask=audio_encoder_attention_mask,
1031+
)
10271032

10281033
# 6. Output layers
10291034
scale_shift_values = jnp.expand_dims(self.scale_shift_table, axis=(0, 1)) + jnp.expand_dims(embedded_timestep, axis=2)

0 commit comments

Comments (0)