Skip to content

Commit b5539a4

Browse files
committed
debug
1 parent 46cae70 commit b5539a4

4 files changed

Lines changed: 31 additions & 0 deletions

File tree

src/maxdiffusion/models/ltx2/autoencoder_kl_ltx2.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -901,6 +901,7 @@ def __call__(
901901
causal: bool = True,
902902
deterministic: bool = True,
903903
) -> jax.Array:
904+
print(f"[LTX2 XPROF Tracing] Encoder __call__ input shape: {sample.shape}")
904905
# JAX: (B, T, H, W, C)
905906
B, T, H, W, C = sample.shape
906907
p = self.patch_size
@@ -1074,6 +1075,7 @@ def __call__(
10741075
causal: bool = False,
10751076
deterministic: bool = True,
10761077
) -> jax.Array:
1078+
print(f"[LTX2 XPROF Tracing] Decoder __call__ input shape: {sample.shape}")
10771079
if self.timestep_scale_multiplier is not None and temb is not None:
10781080
temb = temb * self.timestep_scale_multiplier.value
10791081

@@ -1556,6 +1558,7 @@ def encode(
15561558
key: Optional[jax.Array] = None,
15571559
causal: Optional[bool] = None,
15581560
) -> Union[FlaxAutoencoderKLOutput, Tuple[jax.Array]]:
1561+
print(f"[LTX2 XPROF Tracing] VAE encode input shape: {sample.shape}")
15591562
causal = self.encoder_causal if causal is None else causal
15601563

15611564
if self.use_slicing and sample.shape[0] > 1:
@@ -1584,6 +1587,7 @@ def decode(
15841587
generator: Optional[jax.Array] = None,
15851588
causal: Optional[bool] = None,
15861589
) -> Union[FlaxDecoderOutput, Tuple[jax.Array]]:
1590+
print(f"[LTX2 XPROF Tracing] VAE decode input shape: {latents.shape}")
15871591
causal = self.decoder_causal if causal is None else causal
15881592
key = generator
15891593

src/maxdiffusion/models/ltx2/autoencoder_kl_ltx2_audio.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -562,6 +562,7 @@ def __init__(
562562
self.conv_out = nnx.Conv(block_in, z_channels, kernel_size=(3, 3), padding="SAME", dtype=dtype, rngs=rngs)
563563

564564
def __call__(self, x, train: bool = False):
565+
print(f"[LTX2 XPROF Tracing] Audio Encoder __call__ input shape: {x.shape}")
565566
h = self.conv_in(x)
566567

567568
for stage in self.down_stages:
@@ -702,6 +703,7 @@ def __init__(
702703
self.conv_out = nnx.Conv(block_in, self.output_channels, kernel_size=(3, 3), padding="SAME", dtype=dtype, rngs=rngs)
703704

704705
def __call__(self, z, target_frames=None, target_mel_bins=None, train: bool = False):
706+
print(f"[LTX2 XPROF Tracing] Audio Decoder __call__ input shape: {z.shape}")
705707
h = self.conv_in(z)
706708

707709
h = self.mid_block1(h, train=train)
@@ -825,6 +827,7 @@ def __init__(
825827
self.latents_std = nnx.Param(jnp.ones((base_channels,), dtype=dtype))
826828

827829
def encode(self, x: jnp.ndarray, return_dict: bool = True, train: bool = False):
830+
print(f"[LTX2 XPROF Tracing] Audio VAE encode input shape: {x.shape}")
828831
h = self.encoder(x, train=train)
829832
posterior = FlaxDiagonalGaussianDistribution(h)
830833

@@ -833,6 +836,7 @@ def encode(self, x: jnp.ndarray, return_dict: bool = True, train: bool = False):
833836
return FlaxAutoencoderKLOutput(latent_dist=posterior)
834837

835838
def decode(self, z: jnp.ndarray, return_dict: bool = True, train: bool = False):
839+
print(f"[LTX2 XPROF Tracing] Audio VAE decode input shape: {z.shape}")
836840
batch, time, freq, channels = z.shape
837841
target_frames = time * self.latent_downsample_factor
838842
if self.causality_axis is not None and self.causality_axis != "none":

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,9 @@ def __call__(
354354
a2v_cross_attention_mask: Optional[jax.Array] = None,
355355
v2a_cross_attention_mask: Optional[jax.Array] = None,
356356
) -> Tuple[jax.Array, jax.Array]:
357+
print(f"[LTX2 XPROF Tracing] Block __call__ inputs:")
358+
print(f"[LTX2 XPROF Tracing] hidden_states shape: {hidden_states.shape}")
359+
print(f"[LTX2 XPROF Tracing] audio_hidden_states shape: {audio_hidden_states.shape}")
357360
batch_size = hidden_states.shape[0]
358361

359362
axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_embed"))
@@ -633,6 +636,10 @@ def __init__(
633636
inner_dim = self.num_attention_heads * self.attention_head_dim
634637
audio_inner_dim = self.audio_num_attention_heads * self.audio_attention_head_dim
635638

639+
print(f"[LTX2 XPROF Config] num_layers: {self.num_layers}")
640+
print(f"[LTX2 XPROF Config] Video: inner_dim={inner_dim}, num_heads={self.num_attention_heads}, head_dim={self.attention_head_dim}")
641+
print(f"[LTX2 XPROF Config] Audio: audio_inner_dim={audio_inner_dim}, num_heads={self.audio_num_attention_heads}, head_dim={self.audio_attention_head_dim}")
642+
636643
# 1. Patchification input projections
637644
self.proj_in = nnx.Linear(
638645
self.in_channels,
@@ -924,6 +931,11 @@ def __call__(
924931

925932
batch_size = hidden_states.shape[0]
926933

934+
print(f"[LTX2 XPROF Tracing] Model __call__ inputs:")
935+
print(f"[LTX2 XPROF Tracing] hidden_states shape: {hidden_states.shape}")
936+
print(f"[LTX2 XPROF Tracing] audio_hidden_states shape: {audio_hidden_states.shape}")
937+
print(f"[LTX2 XPROF Tracing] encoder_hidden_states shape: {encoder_hidden_states.shape}")
938+
927939
# 1. Prepare RoPE positional embeddings
928940
with jax.named_scope("RoPE Preparation"):
929941
if video_coords is None:

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1209,6 +1209,9 @@ def __call__(
12091209
latent_width = width // self.vae_spatial_compression_ratio
12101210
latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
12111211

1212+
max_logging.log(f"[LTX2 XPROF] Input dimensions: height={height}, width={width}, num_frames={num_frames}")
1213+
max_logging.log(f"[LTX2 XPROF] Video Latent dimensions: height={latent_height}, width={latent_width}, num_frames={latent_num_frames}")
1214+
12121215
# 4. Prepare Audio Latents
12131216
audio_channels = (
12141217
self.audio_vae.config.latent_channels
@@ -1222,6 +1225,8 @@ def __call__(
12221225
)
12231226
audio_num_frames = round(duration_s * audio_latents_per_second)
12241227

1228+
max_logging.log(f"[LTX2 XPROF] Audio Latent dimensions: channels={audio_channels}, num_frames={audio_num_frames}")
1229+
12251230
audio_latents = self.prepare_audio_latents(
12261231
batch_size=batch_size,
12271232
num_channels_latents=audio_channels,
@@ -1238,6 +1243,8 @@ def __call__(
12381243
video_sequence_length = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
12391244
video_sequence_length *= (height // self.vae_spatial_compression_ratio) * (width // self.vae_spatial_compression_ratio)
12401245

1246+
max_logging.log(f"[LTX2 XPROF] Video Sequence Length: {video_sequence_length}")
1247+
12411248
mu = calculate_shift(
12421249
video_sequence_length,
12431250
self.scheduler.config.get("base_image_seq_len", 1024),
@@ -1521,6 +1528,10 @@ def transformer_forward_pass(
15211528
audio_num_frames,
15221529
fps,
15231530
):
1531+
print(f"[LTX2 XPROF Tracing] latents shape: {latents.shape}")
1532+
print(f"[LTX2 XPROF Tracing] audio_latents shape: {audio_latents.shape}")
1533+
print(f"[LTX2 XPROF Tracing] encoder_hidden_states shape: {encoder_hidden_states.shape}")
1534+
15241535
transformer = nnx.merge(graphdef, state)
15251536

15261537
# Expand timestep to batch size

0 commit comments

Comments (0)