Skip to content

Commit 8c45e6e

Browse files
committed
Add compilation metrics, script to assert metrics, add conditional import of hugginface_hub and update huggingface hub
Signed-off-by: Kunjan patel <kunjanp@google.com>
1 parent 0d9afba commit 8c45e6e

18 files changed

Lines changed: 210 additions & 27 deletions

end_to_end/tpu/eval_assert.py

Lines changed: 64 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,15 +46,76 @@ def test_final_loss(metrics_file, target_loss):
4646
print("Final loss test passed.")
4747

4848

49+
def test_avg_step_time(metrics_file, max_avg_step_time_str, num_samples_str="10"):
50+
"""Tests if the average of the last N step times is below a maximum threshold."""
51+
max_avg_step_time = float(max_avg_step_time_str)
52+
num_samples = int(num_samples_str)
53+
metric_key = "perf/step_time_seconds"
54+
last_n_step_times = get_last_n_data(metrics_file, metric_key, num_samples)
55+
56+
if not last_n_step_times:
57+
raise ValueError(f"Metric '{metric_key}' not found or no data points in {metrics_file}.")
58+
59+
avg_last_n_step_time = sum(last_n_step_times) / len(last_n_step_times)
60+
61+
print(f"Found {len(last_n_step_times)} data points for '{metric_key}'.")
62+
print(f"Mean of last {len(last_n_step_times)} step times is {avg_last_n_step_time:.4f} s")
63+
64+
assert avg_last_n_step_time < max_avg_step_time, \
65+
f"Average step time {avg_last_n_step_time:.4f}s is not less than target {max_avg_step_time}s."
66+
print("Average step time test passed.")
67+
68+
69+
def test_avg_tflops(metrics_file, min_avg_tflops_str, num_samples_str="10"):
70+
"""Tests if the average of the last N TFLOPs/sec values is above a minimum threshold."""
71+
min_avg_tflops = float(min_avg_tflops_str)
72+
num_samples = int(num_samples_str)
73+
metric_key = "perf/per_device_tflops_per_sec"
74+
75+
last_n_tflops = get_last_n_data(metrics_file, metric_key, num_samples)
76+
77+
if not last_n_tflops:
78+
raise ValueError(f"Metric '{metric_key}' not found or no data points in {metrics_file}.")
79+
80+
avg_last_n_tflops = sum(last_n_tflops) / len(last_n_tflops)
81+
82+
print(f"Found {len(last_n_tflops)} data points for '{metric_key}'.")
83+
print(f"Mean of last {len(last_n_tflops)} steps TFLOPs/sec is {avg_last_n_tflops:.2f}")
84+
85+
assert avg_last_n_tflops > min_avg_tflops, \
86+
f"Average TFLOPs/sec {avg_last_n_tflops:.2f} is not greater than target {min_avg_tflops}."
87+
print("Average TFLOPs/sec test passed.")
88+
89+
4990
def main(argv: Sequence[str]) -> None:
91+
if len(argv) < 2:
92+
print("Usage: python script.py <test_scenario> [test_vars...]")
93+
print("Available scenarios: final_loss, avg_step_time, avg_tflops")
94+
raise ValueError("Test scenario not specified.")
5095

5196
_, test_scenario, *test_vars = argv
5297

5398
if test_scenario == "final_loss":
54-
test_final_loss(*test_vars)
99+
if len(test_vars) < 2:
100+
raise ValueError("Usage: final_loss <metrics_file> <target_loss> [num_samples]")
101+
metrics_file, target_loss, *num_samples_opt = test_vars
102+
num_samples = num_samples_opt[0] if num_samples_opt else "10"
103+
test_final_loss(metrics_file, target_loss, num_samples)
104+
elif test_scenario == "avg_step_time":
105+
if len(test_vars) < 2:
106+
raise ValueError("Usage: avg_step_time <metrics_file> <max_avg_step_time> [num_samples]")
107+
metrics_file, max_avg_step_time, *num_samples_opt = test_vars
108+
num_samples = num_samples_opt[0] if num_samples_opt else "10"
109+
test_avg_step_time(metrics_file, max_avg_step_time, num_samples)
110+
elif test_scenario == "avg_tflops":
111+
if len(test_vars) < 2:
112+
raise ValueError("Usage: avg_tflops <metrics_file> <min_avg_tflops> [num_samples]")
113+
metrics_file, min_avg_tflops, *num_samples_opt = test_vars
114+
num_samples = num_samples_opt[0] if num_samples_opt else "10"
115+
test_avg_tflops(metrics_file, min_avg_tflops, num_samples)
55116
else:
56-
raise ValueError(f"Unrecognized test_scenario {test_scenario}")
117+
raise ValueError(f"Unrecognized test_scenario '{test_scenario}'. Available: final_loss, avg_step_time, avg_tflops")
57118

58119

59120
if __name__ == "__main__":
60-
app.run(main)
121+
app.run(main)

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ git+https://github.com/mlperf/logging.git
2626
opencv-python-headless==4.10.0.84
2727
orbax-checkpoint==0.10.3
2828
tokenizers==0.21.0
29-
huggingface_hub==0.24.7
29+
huggingface_hub==0.30.2
3030
transformers==4.48.1
3131
einops==0.8.0
3232
sentencepiece

requirements_with_jax_stable_stack.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ ftfy
88
git+https://github.com/mlperf/logging.git
99
google-cloud-storage==2.17.0
1010
grain-nightly==0.0.10
11-
huggingface_hub==0.24.7
11+
huggingface_hub==0.30.2
1212
jax>=0.4.30
1313
jaxlib>=0.4.30
1414
Jinja2

src/maxdiffusion/configs/base14.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ metrics_file: "" # for testing, local file that stores scalar metrics. If empty,
1919
# If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/
2020
write_metrics: True
2121
gcs_metrics: True
22+
23+
timing_metrics_file: "" # for testing, local file that stores function timing metrics such as state creation, compilation. If empty, no metrics are written.
24+
write_timing_metrics: True
25+
2226
# If true save config to GCS in {base_output_directory}/{run_name}/
2327
save_config_to_gcs: False
2428
log_period: 10000000000 # Flushes Tensorboard

src/maxdiffusion/configs/base21.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ metrics_file: "" # for testing, local file that stores scalar metrics. If empty,
1919
# If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/
2020
write_metrics: True
2121
gcs_metrics: True
22+
23+
timing_metrics_file: "" # for testing, local file that stores function timing metrics such as state creation, compilation. If empty, no metrics are written.
24+
write_timing_metrics: True
25+
2226
# If true save config to GCS in {base_output_directory}/{run_name}/
2327
save_config_to_gcs: False
2428
log_period: 10000000000 # Flushes Tensorboard

src/maxdiffusion/configs/base_2_base.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ metrics_file: "" # for testing, local file that stores scalar metrics. If empty,
1919
# If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/
2020
write_metrics: True
2121
gcs_metrics: False
22+
23+
timing_metrics_file: "" # for testing, local file that stores function timing metrics such as state creation, compilation. If empty, no metrics are written.
24+
write_timing_metrics: True
25+
2226
# If true save config to GCS in {base_output_directory}/{run_name}/
2327
save_config_to_gcs: False
2428
log_period: 10000000000 # Flushes Tensorboard

src/maxdiffusion/configs/base_flux_dev.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@ run_name: ''
1818
metrics_file: "" # for testing, local file that stores scalar metrics. If empty, no metrics are written.
1919
# If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/
2020
write_metrics: True
21+
22+
timing_metrics_file: "" # for testing, local file that stores function timing metrics such as state creation, compilation. If empty, no metrics are written.
23+
write_timing_metrics: True
24+
2125
gcs_metrics: False
2226
# If true save config to GCS in {base_output_directory}/{run_name}/
2327
save_config_to_gcs: False

src/maxdiffusion/configs/base_flux_dev_multi_res.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ metrics_file: "" # for testing, local file that stores scalar metrics. If empty,
1919
# If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/
2020
write_metrics: True
2121
gcs_metrics: False
22+
23+
timing_metrics_file: "" # for testing, local file that stores function timing metrics such as state creation, compilation. If empty, no metrics are written.
24+
write_timing_metrics: True
25+
2226
# If true save config to GCS in {base_output_directory}/{run_name}/
2327
save_config_to_gcs: False
2428
log_period: 100

src/maxdiffusion/configs/base_flux_schnell.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ metrics_file: "" # for testing, local file that stores scalar metrics. If empty,
1919
# If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/
2020
write_metrics: True
2121
gcs_metrics: False
22+
23+
timing_metrics_file: "" # for testing, local file that stores function timing metrics such as state creation, compilation. If empty, no metrics are written.
24+
write_timing_metrics: True
25+
2226
# If true save config to GCS in {base_output_directory}/{run_name}/
2327
save_config_to_gcs: False
2428
log_period: 100

src/maxdiffusion/configs/base_xl.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@ run_name: ''
1818
metrics_file: "" # for testing, local file that stores scalar metrics. If empty, no metrics are written.
1919
# If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/
2020
write_metrics: True
21+
22+
timing_metrics_file: "" # for testing, local file that stores function timing metrics such as state creation, compilation. If empty, no metrics are written.
23+
write_timing_metrics: True
24+
2125
gcs_metrics: False
2226
# If true save config to GCS in {base_output_directory}/{run_name}/
2327
save_config_to_gcs: False

0 commit comments

Comments
 (0)