@@ -75,8 +75,22 @@ def patched_fe_call(self, hidden_states, attention_mask):
7575
# Keep a reference to the original method so the patched version can
# delegate to it after emitting diagnostics.
orig_replace = Embeddings1DConnector._replace_padded_with_learnable_registers


def patched_replace(self, hidden_states, attention_mask):
    """Debug wrapper for ``_replace_padded_with_learnable_registers``.

    Prints diagnostics about the attention mask and the connector's
    learnable registers, then delegates to the original implementation
    with the arguments completely untouched.

    Args:
        self: the ``Embeddings1DConnector`` instance being patched.
        hidden_states: embedding tensor, forwarded unchanged.
        attention_mask: padding mask; either ``[B, T]`` or with a trailing
            singleton dim (squeezed to ``[B, T]`` for inspection only).
            NOTE(review): assumes values > 0.5 mark valid tokens — matches
            the "valid tokens" debug label, confirm against the caller.

    Returns:
        Whatever the original ``_replace_padded_with_learnable_registers``
        returns.
    """
    # Normalize the mask to [B, T] for debugging; the original method still
    # receives the raw ``attention_mask``.
    if attention_mask.ndim == 2:
        mask = attention_mask
    else:
        mask = attention_mask.squeeze(-1)  # [B, T]
    curr_mask = (mask > 0.5).astype(jnp.int32)

    jax.debug.print("\n [MAXDIFFUSION] Mask Debug:")
    jax.debug.print(" Input Attn Mask min/max: {} / {}", jnp.min(attention_mask), jnp.max(attention_mask))
    jax.debug.print(" Curr Mask sum: {} (valid tokens)", jnp.sum(curr_mask))
    jax.debug.print(" Curr Mask start 20 elements: {}", curr_mask[0, :20])

    # FIX: was ``axis=[1]`` — jnp.flip expects an int or tuple of ints;
    # newer JAX versions reject a Python list here. ``axis=1`` is the
    # behaviorally identical, supported spelling.
    flipped = jnp.flip(curr_mask, axis=1)
    jax.debug.print(" Flipped Mask Fwd logic sum: {} (first 20 elements: {})", jnp.sum(flipped), flipped[0, :20])

    regs = self.learnable_registers.value
    jax.debug.print(" [MAXDIFFUSION] Connector Registers std: {std}, mean: {mean}, min: {min}",
                    std=jnp.std(regs), mean=jnp.mean(regs), min=jnp.min(regs))

    # Delegate to the unpatched method; this wrapper only adds prints.
    return orig_replace(self, hidden_states, attention_mask)
0 commit comments