Skip to content

Commit d9749e9

Browse files
Merge branch 'main' into wan_vae_debugging
2 parents 66146b9 + 7284ca0 commit d9749e9

26 files changed

Lines changed: 1094 additions & 46 deletions

.github/workflows/UploadDockerImages.yml

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,16 @@ jobs:
3535
- name: build maxdiffusion jax nightly image
3636
run: |
3737
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_nightly MODE=nightly PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_nightly
38+
39+
build-gpu-image:
40+
runs-on: ["self-hosted", "e2", "cpu"]
41+
steps:
42+
- uses: actions/checkout@v3
43+
- name: Cleanup old docker images
44+
run: docker system prune --all --force
45+
- name: build maxdiffusion jax stable stack gpu image
46+
run: |
47+
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_stable_stack_gpu MODE=stable_stack PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_stable_stack_gpu BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:latest DEVICE=gpu
48+
- name: build maxdiffusion jax nightly image
49+
run: |
50+
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_nightly_gpu MODE=nightly PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_nightly DEVICE=gpu

.github/workflows/build_and_upload_images.sh

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,15 @@ for ARGUMENT in "$@"; do
3434
echo "$KEY"="$VALUE"
3535
done
3636

37+
export DEVICE="${DEVICE:-tpu}"
38+
3739
if [[ ! -v CLOUD_IMAGE_NAME ]] || [[ ! -v PROJECT ]] || [[ ! -v MODE ]] ; then
3840
echo "You must set CLOUD_IMAGE_NAME, PROJECT and MODE"
3941
exit 1
4042
fi
4143

4244
gcloud auth configure-docker us-docker.pkg.dev --quiet
43-
bash docker_build_dependency_image.sh LOCAL_IMAGE_NAME=$LOCAL_IMAGE_NAME MODE=$MODE
45+
bash docker_build_dependency_image.sh LOCAL_IMAGE_NAME=$LOCAL_IMAGE_NAME MODE=$MODE DEVICE=$DEVICE
4446
image_date=$(date +%Y-%m-%d)
4547

4648
# Upload only dependencies image

end_to_end/tpu/eval_assert.py

Lines changed: 77 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,15 @@
1414
limitations under the License.
1515
"""
1616

17+
"""
18+
Example to run
19+
python end_to_end/tpu/eval_assert.py avg_tflops metrics.txt 100
20+
python end_to_end/tpu/eval_assert.py avg_step_time metrics.txt 0.5 100
21+
python end_to_end/tpu/eval_assert.py avg_step_time metrics.txt 0.5 100
22+
"""
23+
24+
25+
1726
# pylint: skip-file
1827
"""Reads and asserts over target values"""
1928
from absl import app
@@ -34,26 +43,89 @@ def get_last_n_data(metrics_file, target, n=10):
3443
return last_n_data
3544

3645

37-
def test_final_loss(metrics_file, target_loss):
46+
def test_final_loss(metrics_file, target_loss, num_samples_str="10"):
3847
target_loss = float(target_loss)
48+
num_samples = int(num_samples_str)
3949
with open(metrics_file, "r", encoding="utf8") as _:
40-
use_last_n_data = 10
41-
last_n_data = get_last_n_data(metrics_file, "learning/loss", use_last_n_data)
50+
last_n_data = get_last_n_data(metrics_file, "learning/loss",num_samples)
4251
avg_last_n_data = sum(last_n_data) / len(last_n_data)
4352
print(f"Mean of last {len(last_n_data)} losses is {avg_last_n_data}")
4453
print(f"Target loss is {target_loss}")
4554
assert avg_last_n_data < target_loss
4655
print("Final loss test passed.")
4756

4857

58+
def test_avg_step_time(metrics_file, max_avg_step_time_str, num_samples_str="10"):
59+
"""Tests if the average of the last N step times is below a maximum threshold."""
60+
max_avg_step_time = float(max_avg_step_time_str)
61+
num_samples = int(num_samples_str)
62+
metric_key = "perf/step_time_seconds"
63+
last_n_step_times = get_last_n_data(metrics_file, metric_key, num_samples)
64+
65+
if not last_n_step_times:
66+
raise ValueError(f"Metric '{metric_key}' not found or no data points in {metrics_file}.")
67+
68+
avg_last_n_step_time = sum(last_n_step_times) / len(last_n_step_times)
69+
70+
print(f"Found {len(last_n_step_times)} data points for '{metric_key}'.")
71+
print(f"Mean of last {len(last_n_step_times)} step times is {avg_last_n_step_time:.4f} s")
72+
73+
assert (
74+
avg_last_n_step_time < max_avg_step_time
75+
), f"Average step time {avg_last_n_step_time:.4f}s is not less than target {max_avg_step_time}s."
76+
print("Average step time test passed.")
77+
78+
79+
def test_avg_tflops(metrics_file, min_avg_tflops_str, num_samples_str="10"):
80+
"""Tests if the average of the last N TFLOPs/sec values is above a minimum threshold."""
81+
min_avg_tflops = float(min_avg_tflops_str)
82+
num_samples = int(num_samples_str)
83+
metric_key = "perf/per_device_tflops_per_sec"
84+
85+
last_n_tflops = get_last_n_data(metrics_file, metric_key, num_samples)
86+
87+
if not last_n_tflops:
88+
raise ValueError(f"Metric '{metric_key}' not found or no data points in {metrics_file}.")
89+
90+
avg_last_n_tflops = sum(last_n_tflops) / len(last_n_tflops)
91+
92+
print(f"Found {len(last_n_tflops)} data points for '{metric_key}'.")
93+
print(f"Mean of last {len(last_n_tflops)} steps TFLOPs/sec is {avg_last_n_tflops:.2f}")
94+
95+
assert (
96+
avg_last_n_tflops > min_avg_tflops
97+
), f"Average TFLOPs/sec {avg_last_n_tflops:.2f} is not greater than target {min_avg_tflops}."
98+
print("Average TFLOPs/sec test passed.")
99+
100+
49101
def main(argv: Sequence[str]) -> None:
102+
if len(argv) < 2:
103+
print("Usage: python script.py <test_scenario> [test_vars...]")
104+
print("Available scenarios: final_loss, avg_step_time, avg_tflops")
105+
raise ValueError("Test scenario not specified.")
50106

51107
_, test_scenario, *test_vars = argv
52108

53109
if test_scenario == "final_loss":
54-
test_final_loss(*test_vars)
110+
if len(test_vars) < 2:
111+
raise ValueError("Usage: final_loss <metrics_file> <target_loss> [num_samples]")
112+
metrics_file, target_loss, *num_samples_opt = test_vars
113+
num_samples = num_samples_opt[0] if num_samples_opt else "10"
114+
test_final_loss(metrics_file, target_loss, num_samples)
115+
elif test_scenario == "avg_step_time":
116+
if len(test_vars) < 2:
117+
raise ValueError("Usage: avg_step_time <metrics_file> <max_avg_step_time> [num_samples]")
118+
metrics_file, max_avg_step_time, *num_samples_opt = test_vars
119+
num_samples = num_samples_opt[0] if num_samples_opt else "10"
120+
test_avg_step_time(metrics_file, max_avg_step_time, num_samples)
121+
elif test_scenario == "avg_tflops":
122+
if len(test_vars) < 2:
123+
raise ValueError("Usage: avg_tflops <metrics_file> <min_avg_tflops> [num_samples]")
124+
metrics_file, min_avg_tflops, *num_samples_opt = test_vars
125+
num_samples = num_samples_opt[0] if num_samples_opt else "10"
126+
test_avg_tflops(metrics_file, min_avg_tflops, num_samples)
55127
else:
56-
raise ValueError(f"Unrecognized test_scenario {test_scenario}")
128+
raise ValueError(f"Unrecognized test_scenario '{test_scenario}'. Available: final_loss, avg_step_time, avg_tflops")
57129

58130

59131
if __name__ == "__main__":

end_to_end/tpu/test_sdxl_training_loss.sh

100644100755
File mode changed.

maxdiffusion_gpu_dependencies.Dockerfile

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,7 @@ RUN apt-get update && apt-get install -y google-cloud-sdk
2222
# Set environment variables for Google Cloud SDK
2323
ENV PATH="/usr/local/google-cloud-sdk/bin:${PATH}"
2424

25-
# Upgrade libcusprase to work with Jax
26-
RUN apt-get update && apt-get install -y libcusparse-12-3
25+
2726

2827
ARG MODE
2928
ENV ENV_MODE=$MODE
@@ -46,5 +45,4 @@ RUN ls .
4645
RUN echo "Running command: bash setup.sh MODE=$ENV_MODE JAX_VERSION=$ENV_JAX_VERSION DEVICE=${ENV_DEVICE}"
4746
RUN --mount=type=cache,target=/root/.cache/pip bash setup.sh MODE=${ENV_MODE} JAX_VERSION=${ENV_JAX_VERSION} DEVICE=${ENV_DEVICE}
4847

49-
5048
WORKDIR /deps

requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,14 @@ Pillow
1919
pylint
2020
pyink
2121
pytest==8.2.2
22-
tensorflow==2.17.0
22+
tensorflow>=2.17.0
2323
tensorflow-datasets>=4.9.6
2424
ruff>=0.1.5,<=0.2
2525
git+https://github.com/mlperf/logging.git
2626
opencv-python-headless==4.10.0.84
2727
orbax-checkpoint==0.10.3
2828
tokenizers==0.21.0
29-
huggingface_hub==0.24.7
29+
huggingface_hub==0.30.2
3030
transformers==4.48.1
3131
einops==0.8.0
3232
sentencepiece

requirements_with_jax_stable_stack.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ ftfy
88
git+https://github.com/mlperf/logging.git
99
google-cloud-storage==2.17.0
1010
grain-nightly==0.0.10
11-
huggingface_hub==0.24.7
11+
huggingface_hub==0.30.2
1212
jax>=0.4.30
1313
jaxlib>=0.4.30
1414
Jinja2

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@
9797
"filelock",
9898
"flax>=0.4.1",
9999
"hf-doc-builder>=0.3.0",
100-
"huggingface-hub==0.24.7",
100+
"huggingface-hub==0.30.0",
101101
"requests-mock==1.10.0",
102102
"importlib_metadata",
103103
"invisible-watermark>=0.2.0",

setup.sh

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ if [[ -n $JAX_VERSION && ! ($MODE == "stable" || -z $MODE) ]]; then
5555
exit 1
5656
fi
5757

58+
# Install dependencies from requirements.txt first
59+
pip3 install -U -r requirements.txt || echo "Failed to install dependencies in the requirements" >&2
60+
5861
# Install JAX and JAXlib based on the specified mode
5962
if [[ "$MODE" == "stable" || ! -v MODE ]]; then
6063
# Stable mode
@@ -78,7 +81,7 @@ if [[ "$MODE" == "stable" || ! -v MODE ]]; then
7881
pip3 install "jax[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
7982
fi
8083
export NVTE_FRAMEWORK=jax
81-
pip3 install git+https://github.com/NVIDIA/TransformerEngine.git@stable
84+
pip3 install transformer_engine[jax]==2.1.0
8285
fi
8386

8487
elif [[ $MODE == "nightly" ]]; then
@@ -106,8 +109,5 @@ else
106109
exit 1
107110
fi
108111

109-
# Install dependencies from requirements.txt
110-
pip3 install -U -r requirements.txt || echo "Failed to install dependencies in the requirements" >&2
111-
112112
# Install maxdiffusion
113113
pip3 install -U . || echo "Failed to install maxdiffusion" >&2

src/maxdiffusion/configs/base14.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ metrics_file: "" # for testing, local file that stores scalar metrics. If empty,
1919
# If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/
2020
write_metrics: True
2121
gcs_metrics: True
22+
23+
timing_metrics_file: "" # for testing, local file that stores function timing metrics such as state creation, compilation. If empty, no metrics are written.
24+
write_timing_metrics: True
25+
2226
# If true save config to GCS in {base_output_directory}/{run_name}/
2327
save_config_to_gcs: False
2428
log_period: 10000000000 # Flushes Tensorboard

0 commit comments

Comments
 (0)