@@ -258,22 +258,12 @@ def __call__(
258258 # 1. Video and Audio Self-Attention
259259 norm_hidden_states = self .norm1 (hidden_states )
260260
261- import sys
262-
263261 # Calculate Video AdaLN values
264262 num_ada_params = self .scale_shift_table .shape [0 ]
265263 # table shape: (6, dim) -> (1, 1, 6, dim)
266264 scale_shift_table_reshaped = jnp .expand_dims (self .scale_shift_table , axis = (0 , 1 ))
267265 # temb shape: (batch, temb_dim) -> (batch, 1, 6, dim) (assuming temb_dim is num_ada_params * dim)
268- print (f"DEBUG_BLOCK: scale_shift_table_reshaped shape: { scale_shift_table_reshaped .shape } " )
269- print (f"DEBUG_BLOCK: temb shape before reshape: { temb .shape } " )
270- sys .stdout .flush ()
271-
272266 temb_reshaped = temb .reshape (batch_size , 1 , num_ada_params , - 1 )
273-
274- print (f"DEBUG_BLOCK: temb_reshaped shape: { temb_reshaped .shape } " )
275- sys .stdout .flush ()
276-
277267 ada_values = scale_shift_table_reshaped + temb_reshaped
278268
279269 shift_msa = ada_values [:, :, 0 , :]
@@ -297,15 +287,7 @@ def __call__(
297287
298288 num_audio_ada_params = self .audio_scale_shift_table .shape [0 ]
299289 audio_scale_shift_table_reshaped = jnp .expand_dims (self .audio_scale_shift_table , axis = (0 , 1 ))
300-
301- print (f"DEBUG_BLOCK_AUDIO: audio_scale_shift_table_reshaped shape: { audio_scale_shift_table_reshaped .shape } " )
302- print (f"DEBUG_BLOCK_AUDIO: temb_audio shape before reshape: { temb_audio .shape } " )
303- sys .stdout .flush ()
304-
305290 temb_audio_reshaped = temb_audio .reshape (batch_size , 1 , num_audio_ada_params , - 1 )
306-
307- print (f"DEBUG_BLOCK_AUDIO: temb_audio_reshaped shape: { temb_audio_reshaped .shape } " )
308- sys .stdout .flush ()
309291 audio_ada_values = audio_scale_shift_table_reshaped + temb_audio_reshaped
310292
311293 audio_shift_msa = audio_ada_values [:, :, 0 , :]
@@ -518,10 +500,6 @@ def __init__(
518500 self .audio_caption_projection = NNXPixArtAlphaTextProjection (
519501 rngs = rngs , in_features = self .caption_channels , hidden_size = audio_inner_dim , dtype = self .dtype , weights_dtype = self .weights_dtype
520502 )
521- import sys
522- print (f"DEBUG IN INIT: inner_dim={ inner_dim } , num_attention_heads={ num_attention_heads } , attention_head_dim={ attention_head_dim } " )
523- sys .stdout .flush ()
524-
525503 # 3. Timestep Modulation Params and Embedding
526504 self .time_embed = LTX2AdaLayerNormSingle (
527505 rngs = rngs , embedding_dim = inner_dim , num_mod_params = 6 , use_additional_conditions = False , dtype = self .dtype , weights_dtype = self .weights_dtype
0 commit comments