1515# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
1616# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
1717
18- # This workflow will run a small SDXL training workload on a GPU runner.
1918
2019# This workflow will run a small MaxText training workload on a GPU runner
21- # by checking out the MaxText repo inside the MaxDiffusion environment.
20+ # using a custom Docker image with all dependencies pre-installed.
2221
23- name : MaxText Workload on MaxDiffusion Runner
22+ name: MaxText Custom Image Workload
2423
2524on :
2625 pull_request :
@@ -33,52 +32,23 @@ jobs:
3332 name : " Run MaxText Training Workload"
3433 runs-on : ["linux-x86-a3-megagpu-h100-8gpu"]
3534 container :
36- # Using the MaxDiffusion container as requested
37- image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/gpu:jax0.7.2-cuda12.9-rev1
35+ # Use your newly built custom image
35+ image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/maxtext-gpu-custom:latest
3837
3938 steps :
40- - name : Checkout MaxText Repository
41- uses : actions/checkout@v4
42- with :
43- repository : ' AI-Hypercomputer/maxtext'
44- ref : ' main'
45- path : ' maxtext' # Clone it into a 'maxtext' subdirectory
46-
47- - name : Install Dependencies
48- working-directory : ./maxtext # Run all subsequent commands inside the new directory
49- run : |
50- # Uninstall full tensorflow to prevent GPU conflicts with JAX
51- pip uninstall -y tensorflow
52- # Install the CPU-only version for data loading
53- pip install tensorflow-cpu
54- # Install MaxText's dependencies
55- pip install -r requirements.txt
56- # Install the MaxText package itself
57- pip install .
58-
59- - name : List Installed Libraries
60- working-directory : ./maxtext
61- run : |
62- echo "--- Installed Python packages ---"
63- pip freeze
64-
6539 - name : Run MaxText Training
66- working-directory : ./maxtext
6740 env :
68- # Set the correct framework for Transformer Engine
6941 NVTE_FRAMEWORK : jax
70- # Prevent TensorFlow from grabbing all GPU memory
7142 TF_FORCE_GPU_ALLOW_GROWTH: "true"
7243 run : |
73- # Run the main training script with a base configuration
44+ # The working directory is /deps, so this path is correct.
7445 python MaxText/train.py MaxText/configs/base.yml \
7546 run_name="maxtext-ci-test-${{ github.run_id }}" \
7647 steps=5 \
7748 enable_checkpointing=false \
78- attention='cudnn_flash_te ' \
49+ attention='cudnn_flash_te' \
7950 dataset_type='synthetic'
8051
81-
8252# name: SDXL Workload Training on GPU
8353
8454# on:
0 commit comments