1414 limitations under the License.
1515 """
1616
17+ """
18+ Example to run
19+ python end_to_end/tpu/eval_assert.py avg_tflops metrics.txt 100
20+ python end_to_end/tpu/eval_assert.py avg_step_time metrics.txt 0.5 100
21+ python end_to_end/tpu/eval_assert.py avg_step_time metrics.txt 0.5 100
22+ """
23+
24+
25+
1726# pylint: skip-file
1827"""Reads and asserts over target values"""
1928from absl import app
@@ -34,26 +43,89 @@ def get_last_n_data(metrics_file, target, n=10):
3443 return last_n_data
3544
3645
37- def test_final_loss (metrics_file , target_loss ):
46+ def test_final_loss (metrics_file , target_loss , num_samples_str = "10" ):
3847 target_loss = float (target_loss )
48+ num_samples = int (num_samples_str )
3949 with open (metrics_file , "r" , encoding = "utf8" ) as _ :
40- use_last_n_data = 10
41- last_n_data = get_last_n_data (metrics_file , "learning/loss" , use_last_n_data )
50+ last_n_data = get_last_n_data (metrics_file , "learning/loss" ,num_samples )
4251 avg_last_n_data = sum (last_n_data ) / len (last_n_data )
4352 print (f"Mean of last { len (last_n_data )} losses is { avg_last_n_data } " )
4453 print (f"Target loss is { target_loss } " )
4554 assert avg_last_n_data < target_loss
4655 print ("Final loss test passed." )
4756
4857
58+ def test_avg_step_time (metrics_file , max_avg_step_time_str , num_samples_str = "10" ):
59+ """Tests if the average of the last N step times is below a maximum threshold."""
60+ max_avg_step_time = float (max_avg_step_time_str )
61+ num_samples = int (num_samples_str )
62+ metric_key = "perf/step_time_seconds"
63+ last_n_step_times = get_last_n_data (metrics_file , metric_key , num_samples )
64+
65+ if not last_n_step_times :
66+ raise ValueError (f"Metric '{ metric_key } ' not found or no data points in { metrics_file } ." )
67+
68+ avg_last_n_step_time = sum (last_n_step_times ) / len (last_n_step_times )
69+
70+ print (f"Found { len (last_n_step_times )} data points for '{ metric_key } '." )
71+ print (f"Mean of last { len (last_n_step_times )} step times is { avg_last_n_step_time :.4f} s" )
72+
73+ assert (
74+ avg_last_n_step_time < max_avg_step_time
75+ ), f"Average step time { avg_last_n_step_time :.4f} s is not less than target { max_avg_step_time } s."
76+ print ("Average step time test passed." )
77+
78+
79+ def test_avg_tflops (metrics_file , min_avg_tflops_str , num_samples_str = "10" ):
80+ """Tests if the average of the last N TFLOPs/sec values is above a minimum threshold."""
81+ min_avg_tflops = float (min_avg_tflops_str )
82+ num_samples = int (num_samples_str )
83+ metric_key = "perf/per_device_tflops_per_sec"
84+
85+ last_n_tflops = get_last_n_data (metrics_file , metric_key , num_samples )
86+
87+ if not last_n_tflops :
88+ raise ValueError (f"Metric '{ metric_key } ' not found or no data points in { metrics_file } ." )
89+
90+ avg_last_n_tflops = sum (last_n_tflops ) / len (last_n_tflops )
91+
92+ print (f"Found { len (last_n_tflops )} data points for '{ metric_key } '." )
93+ print (f"Mean of last { len (last_n_tflops )} steps TFLOPs/sec is { avg_last_n_tflops :.2f} " )
94+
95+ assert (
96+ avg_last_n_tflops > min_avg_tflops
97+ ), f"Average TFLOPs/sec { avg_last_n_tflops :.2f} is not greater than target { min_avg_tflops } ."
98+ print ("Average TFLOPs/sec test passed." )
99+
100+
49101def main (argv : Sequence [str ]) -> None :
102+ if len (argv ) < 2 :
103+ print ("Usage: python script.py <test_scenario> [test_vars...]" )
104+ print ("Available scenarios: final_loss, avg_step_time, avg_tflops" )
105+ raise ValueError ("Test scenario not specified." )
50106
51107 _ , test_scenario , * test_vars = argv
52108
53109 if test_scenario == "final_loss" :
54- test_final_loss (* test_vars )
110+ if len (test_vars ) < 2 :
111+ raise ValueError ("Usage: final_loss <metrics_file> <target_loss> [num_samples]" )
112+ metrics_file , target_loss , * num_samples_opt = test_vars
113+ num_samples = num_samples_opt [0 ] if num_samples_opt else "10"
114+ test_final_loss (metrics_file , target_loss , num_samples )
115+ elif test_scenario == "avg_step_time" :
116+ if len (test_vars ) < 2 :
117+ raise ValueError ("Usage: avg_step_time <metrics_file> <max_avg_step_time> [num_samples]" )
118+ metrics_file , max_avg_step_time , * num_samples_opt = test_vars
119+ num_samples = num_samples_opt [0 ] if num_samples_opt else "10"
120+ test_avg_step_time (metrics_file , max_avg_step_time , num_samples )
121+ elif test_scenario == "avg_tflops" :
122+ if len (test_vars ) < 2 :
123+ raise ValueError ("Usage: avg_tflops <metrics_file> <min_avg_tflops> [num_samples]" )
124+ metrics_file , min_avg_tflops , * num_samples_opt = test_vars
125+ num_samples = num_samples_opt [0 ] if num_samples_opt else "10"
126+ test_avg_tflops (metrics_file , min_avg_tflops , num_samples )
55127 else :
56- raise ValueError (f"Unrecognized test_scenario { test_scenario } " )
128+ raise ValueError (f"Unrecognized test_scenario ' { test_scenario } '. Available: final_loss, avg_step_time, avg_tflops " )
57129
58130
59131if __name__ == "__main__" :
0 commit comments