
Commit 03e76b9 (parent: 3e7eba5)

weight loading script shortened, jax profiling annotations added

4 files changed: 172 additions & 289 deletions

src/maxdiffusion/models/ltx2/attention_ltx2.py (21 additions & 10 deletions)
@@ -345,12 +345,14 @@ def __init__(
       dtype: DType = jnp.float32,
       attention_kernel: str = "flash",
       rope_type: str = "interleaved",
+      enable_jax_named_scopes: bool = False,
   ):
     self.heads = heads
     self.rope_type = rope_type
     self.dim_head = dim_head
     self.inner_dim = dim_head * heads
     self.dropout_rate = dropout
+    self.enable_jax_named_scopes = enable_jax_named_scopes

     # 1. Define Partitioned Initializers (Logical Axes)
     # Q, K, V kernels: [in_features (embed), out_features (heads)]
@@ -433,6 +435,11 @@ def __init__(
         axis_names_kv=("batch", "heads", "length", "kv"),
     )

+  def conditional_named_scope(self, name: str):
+    import jax
+    import contextlib
+    return jax.named_scope(name) if getattr(self, "enable_jax_named_scopes", False) else contextlib.nullcontext()
+
   def __call__(
       self,
       hidden_states: Array,
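For reference, the pattern introduced here can be exercised outside the model: jax.named_scope is a standard JAX context manager that tags every operation traced under it, and contextlib.nullcontext() makes the wrapper a no-op when the flag is off. The sketch below is a minimal, self-contained illustration of the same idea; the TinyProjection class and its shapes are placeholders for this example, not part of this commit.

import contextlib

import jax
import jax.numpy as jnp


class TinyProjection:
  """Minimal stand-in (not the LTX2 attention class) for the scoping pattern."""

  def __init__(self, enable_jax_named_scopes: bool = False):
    self.enable_jax_named_scopes = enable_jax_named_scopes

  def conditional_named_scope(self, name: str):
    # jax.named_scope tags everything traced under it with `name`, so the ops
    # show up grouped in profiler timelines; nullcontext() keeps the wrapper
    # a no-op when the flag is disabled.
    return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()

  def __call__(self, x, w):
    with self.conditional_named_scope("proj_in"):
      return x @ w


layer = TinyProjection(enable_jax_named_scopes=True)
out = layer(jnp.ones((4, 8)), jnp.ones((8, 16)))  # shape (4, 16)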
@@ -445,13 +452,15 @@ def __call__(
     context = encoder_hidden_states if encoder_hidden_states is not None else hidden_states

     # 1. Project
-    query = self.to_q(hidden_states)
-    key = self.to_k(context)
-    value = self.to_v(context)
+    with self.conditional_named_scope("proj_in"):
+      query = self.to_q(hidden_states)
+      key = self.to_k(context)
+      value = self.to_v(context)

     # 2. Norm (Full Inner Dimension)
-    query = self.norm_q(query)
-    key = self.norm_k(key)
+    with self.conditional_named_scope("norm"):
+      query = self.norm_q(query)
+      key = self.norm_k(key)

     # 3. Apply RoPE to tensors of shape [B, S, InnerDim]
     # Frequencies are shape [B, S, InnerDim]
@@ -478,12 +487,14 @@ def __call__(

     # 4. Attention
     # NNXAttentionOp expects flattened input [B, S, InnerDim] for flash kernel
-    attn_output = self.attention_op.apply_attention(query=query, key=key, value=value, attention_mask=attention_mask)
+    with self.conditional_named_scope("attention_op"):
+      attn_output = self.attention_op.apply_attention(query=query, key=key, value=value, attention_mask=attention_mask)

     # 7. Output Projection
-    hidden_states = self.to_out(attn_output)
-
-    if self.dropout_layer is not None:
-      hidden_states = self.dropout_layer(hidden_states)
+    with self.conditional_named_scope("proj_out"):
+      hidden_states = self.to_out(attn_output)
+
+      if self.dropout_layer is not None:
+        hidden_states = self.dropout_layer(hidden_states)

     return hidden_states
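To see these annotations in a profile, the run can be wrapped in jax.profiler.trace, which writes a trace viewable in TensorBoard's Profile tab or Perfetto; the scopes added in this commit ("proj_in", "norm", "attention_op", "proj_out") then appear as labeled regions. A hedged, self-contained sketch follows; the toy computation and the /tmp/jax-trace log directory are illustrative placeholders, not part of the commit.

import jax
import jax.numpy as jnp


def toy_step(x):
  # Stand-in for one attention call; only the scope name matters here.
  with jax.named_scope("attention_op"):
    return jnp.tanh(x @ x.T)


x = jnp.ones((128, 128))
# Everything executed inside the trace is written to the log directory and can
# be opened in TensorBoard or Perfetto, with "attention_op" shown as a named span.
with jax.profiler.trace("/tmp/jax-trace"):
  toy_step(x).block_until_ready()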
