Skip to content

Commit d37e0dd

Browse files
Test with new cuda13 images and TE 2.6.0
1 parent 4243d53 commit d37e0dd

1 file changed

Lines changed: 39 additions & 39 deletions

File tree

.github/workflows/UnitTests.yml

Lines changed: 39 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -24,62 +24,62 @@ on:
2424
workflow_dispatch:
2525

2626
jobs:
27-
# maxtext_workload:
28-
# name: "Run MaxText Workload"
29-
# # IMPORTANT: Replace with the label for your runner (e.g., v5p-8)
30-
# runs-on: ["linux-x86-a3-megagpu-h100-8gpu"]
31-
# container:
32-
# image: gcr.io/tpu-prod-env-multipod/maxtext_stable_stack_candidate_gpu:latest
33-
# steps:
34-
# - name: Checkout MaxText Repo
35-
# uses: actions/checkout@v4
36-
# with:
37-
# repository: AI-Hypercomputer/maxtext
38-
# path: maxtext
39-
# ref: rbierneni-test-gpu-run
27+
maxtext_workload:
28+
name: "Run MaxText Workload"
29+
# IMPORTANT: Replace with the label for your runner (e.g., v5p-8)
30+
runs-on: ["linux-x86-a3-megagpu-h100-8gpu"]
31+
container:
32+
image: gcr.io/tpu-prod-env-multipod/maxtext_stable_stack_candidate_gpu:latest
33+
steps:
34+
- name: Checkout MaxText Repo
35+
uses: actions/checkout@v4
36+
with:
37+
repository: AI-Hypercomputer/maxtext
38+
path: maxtext
39+
ref: rbierneni-test-gpu-run
4040

41-
# - name: Print dependencies
42-
# run: |
43-
# pip uninstall -y transformer-engine transformer-engine-jax
44-
# pip install -U transformer-engine[jax]==2.6.0
45-
# pip uninstall -y tensorflow
46-
# pip install tensorflow-cpu
47-
# pip freeze
41+
- name: Print dependencies
42+
run: |
43+
pip uninstall -y transformer-engine transformer-engine-jax
44+
pip install -U transformer-engine[jax]==2.6.0
45+
# pip uninstall -y tensorflow
46+
# pip install tensorflow-cpu
47+
pip freeze
4848
49-
# - name: Run MaxText Training
50-
# run: |
51-
# # This command is adapted from your DAG for a single-slice configuration.
52-
# cd maxtext && \
53-
# pip install . --no-dependencies
49+
- name: Run MaxText Training
50+
run: |
51+
# This command is adapted from your DAG for a single-slice configuration.
52+
cd maxtext && \
53+
pip install . --no-dependencies
5454
55-
# export XLA_PYTHON_CLIENT_MEM_FRACTION=0.65
56-
# export TF_FORCE_GPU_ALLOW_GROWTH=true
55+
export XLA_PYTHON_CLIENT_MEM_FRACTION=0.65
56+
export TF_FORCE_GPU_ALLOW_GROWTH=true
5757
58-
# python3 -m MaxText.train MaxText/configs/base.yml \
59-
# steps=2 \
60-
# enable_checkpointing=false \
61-
# attention=cudnn_flash_te \
62-
# dataset_type=synthetic \
63-
# run_name=rbierneni-test-maxtext-gpu \
64-
# base_output_directory=gs://rbierneni-multipod-dev/maxtext/${{ github.run_id }}
58+
python3 -m MaxText.train MaxText/configs/base.yml \
59+
steps=2 \
60+
enable_checkpointing=false \
61+
attention=cudnn_flash_te \
62+
dataset_type=synthetic \
63+
run_name=rbierneni-test-maxtext-gpu \
64+
base_output_directory=gs://rbierneni-multipod-dev/maxtext/${{ github.run_id }}
6565
6666
# STAGE 1: PULL MAXDIFFUSION IMAGE AND RUN WORKLOAD
6767
maxdiffusion_workload:
6868
name: "Run MaxDiffusion Workload"
6969
# IMPORTANT: Replace with the label for your runner (e.g., v5p-8)
7070
runs-on: ["linux-x86-a3-megagpu-h100-8gpu"]
7171
container:
72-
image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:jax0.7.2_rev1_gpu
72+
image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:jax0.7.2_rev2_gpu
7373
steps:
7474
- name: Checkout Repository
7575
uses: actions/checkout@v4
7676

7777
- name: Print dependencies
7878
run: |
7979
# pip uninstall -y transformer-engine transformer-engine-jax
80-
# pip install -U transformer-engine[pytorch,jax]
81-
pip uninstall -y tensorflow
82-
pip install tensorflow-cpu
80+
pip install -U transformer-engine[jax]==2.6.0
81+
# pip uninstall -y tensorflow
82+
# pip install tensorflow-cpu
8383
pip freeze
8484
8585
- name: Run MaxDiffusion Training
@@ -92,7 +92,7 @@ jobs:
9292
train_text_encoder=false \
9393
cache_latents_text_encoder_outputs=true \
9494
per_device_batch_size=1 \
95-
attention=dot_product \
95+
attention=cudnn_flash_te \
9696
activations_dtype=bfloat16 \
9797
weights_dtype=bfloat16 \
9898
max_train_steps=200 \

0 commit comments

Comments
 (0)