Skip to content

Commit aaf1dc9

Browse files
authored
Update UnitTests.yml
1 parent 6de4ff2 commit aaf1dc9

1 file changed

Lines changed: 6 additions & 36 deletions

File tree

.github/workflows/UnitTests.yml

Lines changed: 6 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,11 @@
1515
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
1616
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
1717

18-
# This workflow will run a small SDXL training workload on a GPU runner.
1918

2019
# This workflow will run a small MaxText training workload on a GPU runner
21-
# by checking out the MaxText repo inside the MaxDiffusion environment.
20+
# using a custom Docker image with all dependencies pre-installed.
2221

23-
name: MaxText Workload on MaxDiffusion Runner
22+
name: MaxText Custom Image Workload
2423

2524
on:
2625
pull_request:
@@ -33,52 +32,23 @@ jobs:
3332
name: "Run MaxText Training Workload"
3433
runs-on: ["linux-x86-a3-megagpu-h100-8gpu"]
3534
container:
36-
# Using the MaxDiffusion container as requested
37-
image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/gpu:jax0.7.2-cuda12.9-rev1
35+
# Use your newly built custom image
36+
image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/maxtext-gpu-custom:latest
3837

3938
steps:
40-
- name: Checkout MaxText Repository
41-
uses: actions/checkout@v4
42-
with:
43-
repository: 'AI-Hypercomputer/maxtext'
44-
ref: 'main'
45-
path: 'maxtext' # Clone it into a 'maxtext' subdirectory
46-
47-
- name: Install Dependencies
48-
working-directory: ./maxtext # Run all subsequent commands inside the new directory
49-
run: |
50-
# Uninstall full tensorflow to prevent GPU conflicts with JAX
51-
pip uninstall -y tensorflow
52-
# Install the CPU-only version for data loading
53-
pip install tensorflow-cpu
54-
# Install MaxText's dependencies
55-
pip install -r requirements.txt
56-
# Install the MaxText package itself
57-
pip install .
58-
59-
- name: List Installed Libraries
60-
working-directory: ./maxtext
61-
run: |
62-
echo "--- Installed Python packages ---"
63-
pip freeze
64-
6539
- name: Run MaxText Training
66-
working-directory: ./maxtext
6740
env:
68-
# Set the correct framework for Transformer Engine
6941
NVTE_FRAMEWORK: jax
70-
# Prevent TensorFlow from grabbing all GPU memory
7142
TF_FORCE_GPU_ALLOW_GROWTH: "true"
7243
run: |
73-
# Run the main training script with a base configuration
44+
# The working directory is /deps, so this path is correct.
7445
python MaxText/train.py MaxText/configs/base.yml \
7546
run_name="maxtext-ci-test-${{ github.run_id }}" \
7647
steps=5 \
7748
enable_checkpointing=false \
78-
attention='cudnn_flash_te' \
49+
attention='cudnet_flash_te' \
7950
dataset_type='synthetic'
8051
81-
8252
# name: SDXL Workload Training on GPU
8353

8454
# on:

0 commit comments

Comments
 (0)