Skip to content

Commit 4ac11bb

Browse files
Test with te-cu13 package
1 parent 67141da commit 4ac11bb

1 file changed

Lines changed: 36 additions & 34 deletions

File tree

.github/workflows/UnitTests.yml

Lines changed: 36 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -24,44 +24,44 @@ on:
2424
workflow_dispatch:
2525

2626
jobs:
27-
maxtext_workload:
28-
name: "Run MaxText Workload"
29-
# IMPORTANT: Replace with the label for your runner (e.g., v5p-8)
30-
runs-on: ["linux-x86-a3-megagpu-h100-8gpu"]
31-
container:
32-
image: gcr.io/tpu-prod-env-multipod/maxtext_stable_stack_candidate_gpu:latest
33-
steps:
34-
- name: Checkout MaxText Repo
35-
uses: actions/checkout@v4
36-
with:
37-
repository: AI-Hypercomputer/maxtext
38-
path: maxtext
39-
ref: rbierneni-test-gpu-run
27+
# maxtext_workload:
28+
# name: "Run MaxText Workload"
29+
# # IMPORTANT: Replace with the label for your runner (e.g., v5p-8)
30+
# runs-on: ["linux-x86-a3-megagpu-h100-8gpu"]
31+
# container:
32+
# image: gcr.io/tpu-prod-env-multipod/maxtext_stable_stack_candidate_gpu:latest
33+
# steps:
34+
# - name: Checkout MaxText Repo
35+
# uses: actions/checkout@v4
36+
# with:
37+
# repository: AI-Hypercomputer/maxtext
38+
# path: maxtext
39+
# ref: rbierneni-test-gpu-run
4040

41-
- name: Print dependencies
42-
run: |
43-
pip uninstall -y transformer-engine transformer-engine-jax transformer-engine-cu12
44-
pip install -U transformer-engine[jax]==2.6.0
45-
# pip uninstall -y tensorflow
46-
# pip install tensorflow-cpu
47-
pip freeze
41+
# - name: Print dependencies
42+
# run: |
43+
# pip uninstall -y transformer-engine transformer-engine-jax transformer-engine-cu12
44+
# pip install -U transformer-engine[jax]==2.6.0
45+
# # pip uninstall -y tensorflow
46+
# # pip install tensorflow-cpu
47+
# pip freeze
4848

49-
- name: Run MaxText Training
50-
run: |
51-
# This command is adapted from your DAG for a single-slice configuration.
52-
cd maxtext && \
53-
pip install . --no-dependencies
49+
# - name: Run MaxText Training
50+
# run: |
51+
# # This command is adapted from your DAG for a single-slice configuration.
52+
# cd maxtext && \
53+
# pip install . --no-dependencies
5454

55-
export XLA_PYTHON_CLIENT_MEM_FRACTION=0.65
56-
export TF_FORCE_GPU_ALLOW_GROWTH=true
55+
# export XLA_PYTHON_CLIENT_MEM_FRACTION=0.65
56+
# export TF_FORCE_GPU_ALLOW_GROWTH=true
5757

58-
python3 -m MaxText.train MaxText/configs/base.yml \
59-
steps=2 \
60-
enable_checkpointing=false \
61-
attention=cudnn_flash_te \
62-
dataset_type=synthetic \
63-
run_name=rbierneni-test-maxtext-gpu \
64-
base_output_directory=gs://rbierneni-multipod-dev/maxtext/${{ github.run_id }}
58+
# python3 -m MaxText.train MaxText/configs/base.yml \
59+
# steps=2 \
60+
# enable_checkpointing=false \
61+
# attention=cudnn_flash_te \
62+
# dataset_type=synthetic \
63+
# run_name=rbierneni-test-maxtext-gpu \
64+
# base_output_directory=gs://rbierneni-multipod-dev/maxtext/${{ github.run_id }}
6565

6666
# STAGE 1: PULL MAXDIFFUSION IMAGE AND RUN WORKLOAD
6767
maxdiffusion_workload:
@@ -78,6 +78,8 @@ jobs:
7878
run: |
7979
# pip uninstall -y transformer-engine transformer-engine-jax transformer-engine-cu12
8080
pip install -U transformer-engine[jax]==2.6.0
81+
pip uninstall -y transformer-engine-cu12
82+
pip install transformer-engine-cu13
8183
# pip uninstall -y tensorflow
8284
# pip install tensorflow-cpu
8385
pip freeze

0 commit comments

Comments (0)