Skip to content

Commit 2b162af

Browse files
committed
debug added for learnable registers
1 parent 79b084c commit 2b162af

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

before_transformer_parity_maxdiffusion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def patched_fe_call(self, hidden_states, attention_mask):
7676
orig_replace = Embeddings1DConnector._replace_padded_with_learnable_registers
7777
def patched_replace(self, hidden_states, attention_mask):
7878
regs = self.learnable_registers.value
79-
jax.debug.print("[MAXDIFFUSION] Connector Registers std: {std:.5f}, mean: {mean:.5f}, min: {min:.5f}",
79+
jax.debug.print("[MAXDIFFUSION] Connector Registers std: {std}, mean: {mean}, min: {min}",
8080
std=jnp.std(regs), mean=jnp.mean(regs), min=jnp.min(regs))
8181

8282
return orig_replace(self, hidden_states, attention_mask)

0 commit comments

Comments
 (0)