|
20 | 20 | import flax.linen as nn |
21 | 21 | import numpy as np |
22 | 22 |
|
23 | | -printed_count = 0 |
24 | | -def print_shape(name, tensor): |
25 | | - global printed_count |
26 | | - if printed_count > 1000: |
27 | | - return |
28 | | - if tensor is not None: |
29 | | - def _print_fn(n, t): |
30 | | - t_np = np.array(t, dtype=np.float32) |
31 | | - print(f"[{n}] min: {t_np.min():.5f}, max: {t_np.max():.5f}, mean: {t_np.mean():.5f}, std: {t_np.std():.5f}") |
32 | | - if isinstance(tensor, jax.core.Tracer): |
33 | | - jax.debug.callback(_print_fn, name, tensor) |
34 | | - else: |
35 | | - _print_fn(name, tensor) |
36 | | - printed_count += 1 |
| 23 | + |
37 | 24 |
|
38 | 25 | from maxdiffusion.models.ltx2.attention_ltx2 import LTX2Attention, LTX2RotaryPosEmbed |
39 | 26 | from maxdiffusion.models.attention_flax import NNXSimpleFeedForward |
@@ -358,16 +345,7 @@ def __call__( |
358 | 345 | ) -> Tuple[jax.Array, jax.Array]: |
359 | 346 | batch_size = hidden_states.shape[0] |
360 | 347 |
|
361 | | - print_shape("Block Input hidden_states", hidden_states) |
362 | | - print_shape("Block Input audio_hidden_states", audio_hidden_states) |
363 | | - print_shape("Block Input encoder_hidden_states", encoder_hidden_states) |
364 | | - print_shape("Block Input audio_encoder_hidden_states", audio_encoder_hidden_states) |
365 | | - print_shape("Block Input temb", temb) |
366 | | - print_shape("Block Input temb_audio", temb_audio) |
367 | | - print_shape("Block Input temb_ca_scale_shift", temb_ca_scale_shift) |
368 | | - print_shape("Block Input temb_ca_audio_scale_shift", temb_ca_audio_scale_shift) |
369 | | - print_shape("Block Input temb_ca_gate", temb_ca_gate) |
370 | | - print_shape("Block Input temb_ca_audio_gate", temb_ca_audio_gate) |
| 348 | + |
371 | 349 |
|
372 | 350 | axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_embed")) |
373 | 351 | hidden_states = jax.lax.with_sharding_constraint(hidden_states, axis_names) |
@@ -397,12 +375,8 @@ def __call__( |
397 | 375 | scale_mlp = ada_values[:, :, 4, :] |
398 | 376 | gate_mlp = ada_values[:, :, 5, :] |
399 | 377 |
|
400 | | - print_shape("shift_msa", shift_msa) |
401 | | - print_shape("scale_msa", scale_msa) |
402 | | - print_shape("gate_msa", gate_msa) |
403 | | - print_shape("shift_mlp", shift_mlp) |
404 | | - print_shape("scale_mlp", scale_mlp) |
405 | | - print_shape("gate_mlp", gate_mlp) |
| 378 | + |
| 379 | + |
406 | 380 |
|
407 | 381 | norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa |
408 | 382 |
|
@@ -923,11 +897,7 @@ def __call__( |
923 | 897 | audio_encoder_attention_mask = jnp.expand_dims(audio_encoder_attention_mask, axis=1) |
924 | 898 |
|
925 | 899 | batch_size = hidden_states.shape[0] |
926 | | - print_shape("Model Input hidden_states", hidden_states) |
927 | | - print_shape("Model Input audio_hidden_states", audio_hidden_states) |
928 | | - print_shape("Model Input encoder_hidden_states", encoder_hidden_states) |
929 | | - print_shape("Model Input audio_encoder_hidden_states", audio_encoder_hidden_states) |
930 | | - print_shape("Model Input timestep", timestep) |
| 900 | + |
931 | 901 |
|
932 | 902 | # 1. Prepare RoPE positional embeddings |
933 | 903 | if video_coords is None: |
|
0 commit comments