Skip to content

Commit ebb5d62

Browse files
committed
transformer debug
1 parent 6aa4898 commit ebb5d62

1 file changed

Lines changed: 39 additions & 0 deletions

File tree

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,22 @@
1818
import jax.numpy as jnp
1919
from flax import nnx
2020
import flax.linen as nn
21+
import numpy as np
22+
23+
# Number of tensor-statistics lines actually printed so far; caps debug
# output volume at roughly 1000 lines.
printed_count = 0


def print_shape(name, tensor):
    """Print min/max/mean/std statistics of ``tensor``, labeled ``name``.

    Debug helper.  Handles both concrete arrays and traced values inside a
    jitted function: traced values are routed through ``jax.debug.callback``
    so the statistics are computed on the host at execution time.  Silently
    returns when ``tensor`` is None or the print cap has been reached.

    NOTE(review): this relies on a module-level ``import jax`` — only
    ``jax.numpy`` and ``numpy`` imports are visible in this hunk; confirm
    the plain ``import jax`` exists earlier in the file.
    """
    global printed_count
    if printed_count > 1000:
        return
    if tensor is None:
        return

    def _print_fn(n, t):
        # Runs on the host with a concrete array value, even when invoked
        # through jax.debug.callback from traced code.
        global printed_count
        t_np = np.asarray(t, dtype=np.float32)
        if t_np.size == 0:
            # min/max/mean/std would raise on an empty array.
            print(f"[{n}] empty tensor, shape {t_np.shape}")
        else:
            print(f"[{n}] min: {t_np.min():.5f}, max: {t_np.max():.5f}, mean: {t_np.mean():.5f}, std: {t_np.std():.5f}")
        # Fix: increment at execution time, inside the callback.  The
        # original incremented in the outer function, which under jit runs
        # only once per trace — so the 1000-print cap never limited the
        # number of callbacks actually executed.
        printed_count += 1

    if isinstance(tensor, jax.core.Tracer):
        # Traced value: defer the host-side print until the value is concrete.
        jax.debug.callback(_print_fn, name, tensor)
    else:
        _print_fn(name, tensor)
2137

2238
from maxdiffusion.models.ltx2.attention_ltx2 import LTX2Attention, LTX2RotaryPosEmbed
2339
from maxdiffusion.models.attention_flax import NNXSimpleFeedForward
@@ -342,6 +358,17 @@ def __call__(
342358
) -> Tuple[jax.Array, jax.Array]:
343359
batch_size = hidden_states.shape[0]
344360

361+
print_shape("Block Input hidden_states", hidden_states)
362+
print_shape("Block Input audio_hidden_states", audio_hidden_states)
363+
print_shape("Block Input encoder_hidden_states", encoder_hidden_states)
364+
print_shape("Block Input audio_encoder_hidden_states", audio_encoder_hidden_states)
365+
print_shape("Block Input temb", temb)
366+
print_shape("Block Input temb_audio", temb_audio)
367+
print_shape("Block Input temb_ca_scale_shift", temb_ca_scale_shift)
368+
print_shape("Block Input temb_ca_audio_scale_shift", temb_ca_audio_scale_shift)
369+
print_shape("Block Input temb_ca_gate", temb_ca_gate)
370+
print_shape("Block Input temb_ca_audio_gate", temb_ca_audio_gate)
371+
345372
axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_embed"))
346373
hidden_states = jax.lax.with_sharding_constraint(hidden_states, axis_names)
347374
audio_hidden_states = jax.lax.with_sharding_constraint(audio_hidden_states, axis_names)
@@ -370,6 +397,13 @@ def __call__(
370397
scale_mlp = ada_values[:, :, 4, :]
371398
gate_mlp = ada_values[:, :, 5, :]
372399

400+
print_shape("shift_msa", shift_msa)
401+
print_shape("scale_msa", scale_msa)
402+
print_shape("gate_msa", gate_msa)
403+
print_shape("shift_mlp", shift_mlp)
404+
print_shape("scale_mlp", scale_mlp)
405+
print_shape("gate_mlp", gate_mlp)
406+
373407
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
374408

375409
attn_hidden_states = self.attn1(
@@ -889,6 +923,11 @@ def __call__(
889923
audio_encoder_attention_mask = jnp.expand_dims(audio_encoder_attention_mask, axis=1)
890924

891925
batch_size = hidden_states.shape[0]
926+
print_shape("Model Input hidden_states", hidden_states)
927+
print_shape("Model Input audio_hidden_states", audio_hidden_states)
928+
print_shape("Model Input encoder_hidden_states", encoder_hidden_states)
929+
print_shape("Model Input audio_encoder_hidden_states", audio_encoder_hidden_states)
930+
print_shape("Model Input timestep", timestep)
892931

893932
# 1. Prepare RoPE positional embeddings
894933
if video_coords is None:

0 commit comments

Comments
 (0)