|
18 | 18 | import jax.numpy as jnp |
19 | 19 | from flax import nnx |
20 | 20 | import flax.linen as nn |
| 21 | +import numpy as np |
| 22 | + |
printed_count = 0  # Global budget: caps the total number of debug prints per process.


def print_shape(name, tensor):
    """Debug helper: print min/max/mean/std statistics for *tensor*.

    Works both inside and outside of ``jit``: traced values are routed through
    ``jax.debug.callback`` so the statistics are emitted at run time, while
    concrete values are printed immediately. Prints at most 1000 times per
    process to keep logs bounded.

    Args:
        name: Label prefixed to the printed statistics line.
        tensor: Array-like value (or JAX tracer) to summarize. ``None`` is
            silently ignored and does not consume the print budget.
    """
    global printed_count
    # `>=` so the budget is exactly 1000 (the original `> 1000` allowed 1001).
    if printed_count >= 1000:
        return
    if tensor is None:
        return

    def _print_fn(n, t):
        # Cast to float32 so integer / bfloat16 inputs format uniformly.
        t_np = np.asarray(t, dtype=np.float32)
        if t_np.size == 0:
            # min()/max() on an empty array raise ValueError; a debug helper
            # must never crash the run it is inspecting.
            print(f"[{n}] empty tensor")
            return
        print(
            f"[{n}] min: {t_np.min():.5f}, max: {t_np.max():.5f}, "
            f"mean: {t_np.mean():.5f}, std: {t_np.std():.5f}"
        )

    if isinstance(tensor, jax.core.Tracer):
        # Traced value: defer the host-side print until the compiled
        # computation actually runs.
        jax.debug.callback(_print_fn, name, tensor)
    else:
        _print_fn(name, tensor)
    # Count only calls that actually scheduled a print (previously None
    # inputs consumed budget too).
    printed_count += 1
21 | 37 |
|
22 | 38 | from maxdiffusion.models.ltx2.attention_ltx2 import LTX2Attention, LTX2RotaryPosEmbed |
23 | 39 | from maxdiffusion.models.attention_flax import NNXSimpleFeedForward |
@@ -342,6 +358,17 @@ def __call__( |
342 | 358 | ) -> Tuple[jax.Array, jax.Array]: |
343 | 359 | batch_size = hidden_states.shape[0] |
344 | 360 |
|
| 361 | + print_shape("Block Input hidden_states", hidden_states) |
| 362 | + print_shape("Block Input audio_hidden_states", audio_hidden_states) |
| 363 | + print_shape("Block Input encoder_hidden_states", encoder_hidden_states) |
| 364 | + print_shape("Block Input audio_encoder_hidden_states", audio_encoder_hidden_states) |
| 365 | + print_shape("Block Input temb", temb) |
| 366 | + print_shape("Block Input temb_audio", temb_audio) |
| 367 | + print_shape("Block Input temb_ca_scale_shift", temb_ca_scale_shift) |
| 368 | + print_shape("Block Input temb_ca_audio_scale_shift", temb_ca_audio_scale_shift) |
| 369 | + print_shape("Block Input temb_ca_gate", temb_ca_gate) |
| 370 | + print_shape("Block Input temb_ca_audio_gate", temb_ca_audio_gate) |
| 371 | + |
345 | 372 | axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_embed")) |
346 | 373 | hidden_states = jax.lax.with_sharding_constraint(hidden_states, axis_names) |
347 | 374 | audio_hidden_states = jax.lax.with_sharding_constraint(audio_hidden_states, axis_names) |
@@ -370,6 +397,13 @@ def __call__( |
370 | 397 | scale_mlp = ada_values[:, :, 4, :] |
371 | 398 | gate_mlp = ada_values[:, :, 5, :] |
372 | 399 |
|
| 400 | + print_shape("shift_msa", shift_msa) |
| 401 | + print_shape("scale_msa", scale_msa) |
| 402 | + print_shape("gate_msa", gate_msa) |
| 403 | + print_shape("shift_mlp", shift_mlp) |
| 404 | + print_shape("scale_mlp", scale_mlp) |
| 405 | + print_shape("gate_mlp", gate_mlp) |
| 406 | + |
373 | 407 | norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa |
374 | 408 |
|
375 | 409 | attn_hidden_states = self.attn1( |
@@ -889,6 +923,11 @@ def __call__( |
889 | 923 | audio_encoder_attention_mask = jnp.expand_dims(audio_encoder_attention_mask, axis=1) |
890 | 924 |
|
891 | 925 | batch_size = hidden_states.shape[0] |
| 926 | + print_shape("Model Input hidden_states", hidden_states) |
| 927 | + print_shape("Model Input audio_hidden_states", audio_hidden_states) |
| 928 | + print_shape("Model Input encoder_hidden_states", encoder_hidden_states) |
| 929 | + print_shape("Model Input audio_encoder_hidden_states", audio_encoder_hidden_states) |
| 930 | + print_shape("Model Input timestep", timestep) |
892 | 931 |
|
893 | 932 | # 1. Prepare RoPE positional embeddings |
894 | 933 | if video_coords is None: |
|
0 commit comments