@@ -79,18 +79,18 @@ def patched_transformer_forward_pass(*args, **kwargs):
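7979 # In Maxdiffusion, when inputs are passed positionally, the fallbacks below read hidden_states from args[2], timestep from args[4], encoder_hidden_states from args[5] and audio_encoder_hidden_states from args[6]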
8080 if "hidden_states" in kwargs :
8181 print_stat ("transformer_input_video_latents" , kwargs ["hidden_states" ])
82- elif len (args ) > 0 and args [0 ] is not None :
83- print_stat ("transformer_input_video_latents" , args [0 ])
82+ elif len (args ) > 2 and args [2 ] is not None :
83+ print_stat ("transformer_input_video_latents" , args [2 ])
8484
8585 if "encoder_hidden_states" in kwargs :
8686 print_stat ("transformers_encoder_hidden_states" , kwargs ["encoder_hidden_states" ])
87- elif len (args ) > 1 and args [1 ] is not None :
88- print_stat ("transformers_encoder_hidden_states" , args [1 ])
87+ elif len (args ) > 5 and args [5 ] is not None :
88+ print_stat ("transformers_encoder_hidden_states" , args [5 ])
8989
9090 if "timestep" in kwargs :
9191 print_stat ("transformer_timestep" , kwargs ["timestep" ])
92- elif len (args ) > 2 and args [2 ] is not None :
93- print_stat ("transformer_timestep" , args [2 ])
92+ elif len (args ) > 4 and args [4 ] is not None :
93+ print_stat ("transformer_timestep" , args [4 ])
9494
9595 if "audio_hidden_states" in kwargs :
9696 print_stat ("transformer_input_audio_latents" , kwargs ["audio_hidden_states" ])
@@ -99,8 +99,8 @@ def patched_transformer_forward_pass(*args, **kwargs):
9999
100100 if "audio_encoder_hidden_states" in kwargs :
101101 print_stat ("transformers_audio_encoder_hidden_states" , kwargs ["audio_encoder_hidden_states" ])
102- elif len (args ) > 4 and args [4 ] is not None :
103- print_stat ("transformers_audio_encoder_hidden_states" , args [4 ])
102+ elif len (args ) > 6 and args [6 ] is not None :
103+ print_stat ("transformers_audio_encoder_hidden_states" , args [6 ])
104104
105105 print ("\n [SUCCESS] Captured all inputs up to Transformer logic. Exiting early to save compute.\n " )
106106 import os
0 commit comments