@@ -302,12 +302,16 @@ def main(config, test_args): # pylint: disable=W0621
302302 "Comparing up to the smaller vocab size."
303303 )
304304 min_vocab_size = min (full_train_logits .shape [- 1 ], golden_logits .shape [- 1 ])
305+
306+ start_index = 1 if test_args .skip_first_token else 0
305307 # shape [seq_len, vocab_size]
306- train_logits_slice = full_train_logits [0 , :token_size , :min_vocab_size ]
307- golden_logits_slice = golden_logits [:token_size , :min_vocab_size ]
308- max_logging .log ("\n [logits: token 2]" )
309- max_logging .log (f"{ golden_logits_slice [2 ]= } " )
310- max_logging .log (f"{ train_logits_slice [2 ]= } " )
308+ train_logits_slice = full_train_logits [0 , start_index :token_size , :min_vocab_size ]
309+ golden_logits_slice = golden_logits [start_index :token_size , :min_vocab_size ]
310+
311+ if train_logits_slice .shape [0 ] > 2 :
312+ max_logging .log (f"\n [logits: token { start_index + 2 } ]" )
313+ max_logging .log (f"{ golden_logits_slice [2 ]= } " )
314+ max_logging .log (f"{ train_logits_slice [2 ]= } " )
311315
312316 # Calculate absolute and relative differences for detailed reporting
313317 abs_diff = jnp .abs (train_logits_slice - golden_logits_slice )
@@ -337,17 +341,18 @@ def main(config, test_args): # pylint: disable=W0621
337341 model_probabilities = jax .nn .softmax (train_logits_slice , axis = - 1 )
338342 golden_probabilities = jax .nn .softmax (golden_logits_slice , axis = - 1 )
339343
340- max_logging .log ("\n [probability: token 1]" )
341- max_logging .log (f"{ golden_probabilities [1 ]= } " )
342- max_logging .log (f"{ model_probabilities [1 ]= } " )
344+ if golden_probabilities .shape [0 ] > 1 :
345+ max_logging .log (f"\n [probability: token { start_index + 1 } ]" )
346+ max_logging .log (f"{ golden_probabilities [1 ]= } " )
347+ max_logging .log (f"{ model_probabilities [1 ]= } " )
343348
344349 kl_div = jax .numpy .sum (jax .scipy .special .kl_div (golden_probabilities , model_probabilities ), axis = - 1 )
345350 max_kl_div_val = jax .numpy .max (kl_div )
346351 max_kl_div_idx = jax .numpy .argmax (kl_div )
347352 max_logging .log (
348353 f"\n [KL divergence]\n "
349354 f"KL divergence = { kl_div } , max KL divergence = { max_kl_div_val } at index { max_kl_div_idx } , "
350- f"the corresponding token id is { ids [0 , max_kl_div_idx ]} "
355+ f"the corresponding token id is { ids [0 , max_kl_div_idx + start_index ]} "
351356 )
352357
353358 if jax .process_index () == 0 and test_args .output_logits_path :
@@ -465,7 +470,12 @@ def main(config, test_args): # pylint: disable=W0621
465470
466471 # --- Compare all logits in the sequence (for the first batch item) ---
467472 # Unsqueeze to add batch dimension for check_kl_divergence: [1, seq, vocab]
468- check_kl_divergence (mt_logits_torch [0 ].unsqueeze (0 ), hf_logits_torch [0 ].unsqueeze (0 ), atol = test_args .max_kl_div )
473+ start_index = 1 if test_args .skip_first_token else 0
474+ check_kl_divergence (
475+ mt_logits_torch [0 , start_index :].unsqueeze (0 ),
476+ hf_logits_torch [0 , start_index :].unsqueeze (0 ),
477+ atol = test_args .max_kl_div ,
478+ )
469479 if jax .process_index () == 0 and test_args .output_logits_path :
470480 data_to_save = {
471481 "mt_logits" : mt_logits_torch [0 ].tolist (),
@@ -504,6 +514,13 @@ def main(config, test_args): # pylint: disable=W0621
504514 parser .add_argument ("--output_logits_path" , type = str , required = False , default = "" )
505515 parser .add_argument ("--gcs_output_logits_path" , type = str , required = False , default = "" )
506516 parser .add_argument ("--clip_logits_epsilon" , type = float , required = False , default = None )
517+ parser .add_argument (
518+ "--skip_first_token" ,
519+ action = "store_true" ,
520+ required = False ,
521+ default = False ,
522+ help = "Skip the first token during comparison to ignore BOS/init mismatches." ,
523+ )
507524 test_args , _ = parser .parse_known_args ()
508525
509526 # Remove args defined in this test file to avoid error from pyconfig
@@ -519,6 +536,7 @@ def main(config, test_args): # pylint: disable=W0621
519536 "--output_logits_path" ,
520537 "--gcs_output_logits_path" ,
521538 "--clip_logits_epsilon" ,
539+ "--skip_first_token" ,
522540 ]
523541 for arg in to_remove_args :
524542 model_args = [s for s in model_args if not s .startswith (arg )]
@@ -527,6 +545,10 @@ def main(config, test_args): # pylint: disable=W0621
527545 assert (
528546 test_args .atol is not None or test_args .max_kl_div is not None
529547 ), "At least one of --atol or --max_kl_div must be specified to define the test criteria."
548+
549+ if test_args .run_hf_model and test_args .clip_logits_epsilon is not None :
550+ raise ValueError ("--clip_logits_epsilon is not supported when running HF model on-the-fly (run_hf_model=True)." )
551+
530552 if cfg .use_multimodal :
531553 assert not test_args .run_hf_model , (
532554 "Multimodal does not support running hf model on-the-fly, please generate hf golden logits "
0 commit comments