Skip to content

Commit 859257d

Browse files
committed
Add option to skip first token during logits comparison
1 parent f44534f commit 859257d

1 file changed

Lines changed: 32 additions & 10 deletions

File tree

tests/utils/forward_pass_logit_checker.py

Lines changed: 32 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -302,12 +302,16 @@ def main(config, test_args): # pylint: disable=W0621
302302
"Comparing up to the smaller vocab size."
303303
)
304304
min_vocab_size = min(full_train_logits.shape[-1], golden_logits.shape[-1])
305+
306+
start_index = 1 if test_args.skip_first_token else 0
305307
# shape [seq_len, vocab_size]
306-
train_logits_slice = full_train_logits[0, :token_size, :min_vocab_size]
307-
golden_logits_slice = golden_logits[:token_size, :min_vocab_size]
308-
max_logging.log("\n[logits: token 2]")
309-
max_logging.log(f"{golden_logits_slice[2]=}")
310-
max_logging.log(f"{train_logits_slice[2]=}")
308+
train_logits_slice = full_train_logits[0, start_index:token_size, :min_vocab_size]
309+
golden_logits_slice = golden_logits[start_index:token_size, :min_vocab_size]
310+
311+
if train_logits_slice.shape[0] > 2:
312+
max_logging.log(f"\n[logits: token {start_index + 2}]")
313+
max_logging.log(f"{golden_logits_slice[2]=}")
314+
max_logging.log(f"{train_logits_slice[2]=}")
311315

312316
# Calculate absolute and relative differences for detailed reporting
313317
abs_diff = jnp.abs(train_logits_slice - golden_logits_slice)
@@ -337,17 +341,18 @@ def main(config, test_args): # pylint: disable=W0621
337341
model_probabilities = jax.nn.softmax(train_logits_slice, axis=-1)
338342
golden_probabilities = jax.nn.softmax(golden_logits_slice, axis=-1)
339343

340-
max_logging.log("\n[probability: token 1]")
341-
max_logging.log(f"{golden_probabilities[1]=}")
342-
max_logging.log(f"{model_probabilities[1]=}")
344+
if golden_probabilities.shape[0] > 1:
345+
max_logging.log(f"\n[probability: token {start_index + 1}]")
346+
max_logging.log(f"{golden_probabilities[1]=}")
347+
max_logging.log(f"{model_probabilities[1]=}")
343348

344349
kl_div = jax.numpy.sum(jax.scipy.special.kl_div(golden_probabilities, model_probabilities), axis=-1)
345350
max_kl_div_val = jax.numpy.max(kl_div)
346351
max_kl_div_idx = jax.numpy.argmax(kl_div)
347352
max_logging.log(
348353
f"\n[KL divergence]\n"
349354
f"KL divergence = {kl_div}, max KL divergence = {max_kl_div_val} at index {max_kl_div_idx}, "
350-
f"the corresponding token id is {ids[0, max_kl_div_idx]}"
355+
f"the corresponding token id is {ids[0, max_kl_div_idx + start_index]}"
351356
)
352357

353358
if jax.process_index() == 0 and test_args.output_logits_path:
@@ -465,7 +470,12 @@ def main(config, test_args): # pylint: disable=W0621
465470

466471
# --- Compare all logits in the sequence (for the first batch item) ---
467472
# Unsqueeze to add batch dimension for check_kl_divergence: [1, seq, vocab]
468-
check_kl_divergence(mt_logits_torch[0].unsqueeze(0), hf_logits_torch[0].unsqueeze(0), atol=test_args.max_kl_div)
473+
start_index = 1 if test_args.skip_first_token else 0
474+
check_kl_divergence(
475+
mt_logits_torch[0, start_index:].unsqueeze(0),
476+
hf_logits_torch[0, start_index:].unsqueeze(0),
477+
atol=test_args.max_kl_div,
478+
)
469479
if jax.process_index() == 0 and test_args.output_logits_path:
470480
data_to_save = {
471481
"mt_logits": mt_logits_torch[0].tolist(),
@@ -504,6 +514,13 @@ def main(config, test_args): # pylint: disable=W0621
504514
parser.add_argument("--output_logits_path", type=str, required=False, default="")
505515
parser.add_argument("--gcs_output_logits_path", type=str, required=False, default="")
506516
parser.add_argument("--clip_logits_epsilon", type=float, required=False, default=None)
517+
parser.add_argument(
518+
"--skip_first_token",
519+
action="store_true",
520+
required=False,
521+
default=False,
522+
help="Skip the first token during comparison to ignore BOS/init mismatches.",
523+
)
507524
test_args, _ = parser.parse_known_args()
508525

509526
# Remove args defined in this test file to avoid error from pyconfig
@@ -519,6 +536,7 @@ def main(config, test_args): # pylint: disable=W0621
519536
"--output_logits_path",
520537
"--gcs_output_logits_path",
521538
"--clip_logits_epsilon",
539+
"--skip_first_token",
522540
]
523541
for arg in to_remove_args:
524542
model_args = [s for s in model_args if not s.startswith(arg)]
@@ -527,6 +545,10 @@ def main(config, test_args): # pylint: disable=W0621
527545
assert (
528546
test_args.atol is not None or test_args.max_kl_div is not None
529547
), "At least one of --atol or --max_kl_div must be specified to define the test criteria."
548+
549+
if test_args.run_hf_model and test_args.clip_logits_epsilon is not None:
550+
raise ValueError("--clip_logits_epsilon is not supported when running HF model on-the-fly (run_hf_model=True).")
551+
530552
if cfg.use_multimodal:
531553
assert not test_args.run_hf_model, (
532554
"Multimodal does not support running hf model on-the-fly, please generate hf golden logits "

0 commit comments

Comments (0)