Skip to content

Commit b5526e7

Browse files
committed
before_transformer parity file
1 parent fde6d7f commit b5526e7

1 file changed

Lines changed: 8 additions & 8 deletions

File tree

before_transformer_parity_maxdiffusion.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -79,18 +79,18 @@ def patched_transformer_forward_pass(*args, **kwargs):
7979
# In Maxdiffusion, args are usually (hidden_states, encoder_hidden_states, timestep, ...)
8080
if "hidden_states" in kwargs:
8181
print_stat("transformer_input_video_latents", kwargs["hidden_states"])
82-
elif len(args) > 0 and args[0] is not None:
83-
print_stat("transformer_input_video_latents", args[0])
82+
elif len(args) > 2 and args[2] is not None:
83+
print_stat("transformer_input_video_latents", args[2])
8484

8585
if "encoder_hidden_states" in kwargs:
8686
print_stat("transformers_encoder_hidden_states", kwargs["encoder_hidden_states"])
87-
elif len(args) > 1 and args[1] is not None:
88-
print_stat("transformers_encoder_hidden_states", args[1])
87+
elif len(args) > 5 and args[5] is not None:
88+
print_stat("transformers_encoder_hidden_states", args[5])
8989

9090
if "timestep" in kwargs:
9191
print_stat("transformer_timestep", kwargs["timestep"])
92-
elif len(args) > 2 and args[2] is not None:
93-
print_stat("transformer_timestep", args[2])
92+
elif len(args) > 4 and args[4] is not None:
93+
print_stat("transformer_timestep", args[4])
9494

9595
if "audio_hidden_states" in kwargs:
9696
print_stat("transformer_input_audio_latents", kwargs["audio_hidden_states"])
@@ -99,8 +99,8 @@ def patched_transformer_forward_pass(*args, **kwargs):
9999

100100
if "audio_encoder_hidden_states" in kwargs:
101101
print_stat("transformers_audio_encoder_hidden_states", kwargs["audio_encoder_hidden_states"])
102-
elif len(args) > 4 and args[4] is not None:
103-
print_stat("transformers_audio_encoder_hidden_states", args[4])
102+
elif len(args) > 6 and args[6] is not None:
103+
print_stat("transformers_audio_encoder_hidden_states", args[6])
104104

105105
print("\n[SUCCESS] Captured all inputs up to Transformer logic. Exiting early to save compute.\n")
106106
import os

0 commit comments

Comments
 (0)