Skip to content

Commit c029bdd

Browse files
Try with new image that works
1 parent 1b5b8c4 commit c029bdd

1 file changed

Lines changed: 26 additions & 152 deletions

File tree

.github/workflows/UnitTests.yml

Lines changed: 26 additions & 152 deletions
Original file line numberDiff line numberDiff line change
@@ -24,168 +24,42 @@ on:
2424
workflow_dispatch:
2525

2626
jobs:
27-
# maxtext_workload:
28-
# name: "Run MaxText Workload"
29-
# # IMPORTANT: Replace with the label for your runner (e.g., v5p-8)
30-
# runs-on: ["linux-x86-a3-megagpu-h100-8gpu"]
31-
# container:
32-
# image: gcr.io/tpu-prod-env-multipod/maxtext_stable_stack_candidate_gpu:latest
33-
# steps:
34-
# - name: Checkout MaxText Repo
35-
# uses: actions/checkout@v4
36-
# with:
37-
# repository: AI-Hypercomputer/maxtext
38-
# path: maxtext
39-
# ref: rbierneni-test-gpu-run
40-
41-
# - name: Print dependencies
42-
# run: |
43-
# pip uninstall -y transformer-engine transformer-engine-jax transformer-engine-cu12
44-
# pip install -U transformer-engine[jax]==2.6.0
45-
# # pip uninstall -y tensorflow
46-
# # pip install tensorflow-cpu
47-
# pip freeze
48-
49-
# - name: Run MaxText Training
50-
# run: |
51-
# # This command is adapted from your DAG for a single-slice configuration.
52-
# cd maxtext && \
53-
# pip install . --no-dependencies
54-
55-
# export XLA_PYTHON_CLIENT_MEM_FRACTION=0.65
56-
# export TF_FORCE_GPU_ALLOW_GROWTH=true
57-
58-
# python3 -m MaxText.train MaxText/configs/base.yml \
59-
# steps=2 \
60-
# enable_checkpointing=false \
61-
# attention=cudnn_flash_te \
62-
# dataset_type=synthetic \
63-
# run_name=rbierneni-test-maxtext-gpu \
64-
# base_output_directory=gs://rbierneni-multipod-dev/maxtext/${{ github.run_id }}
65-
66-
# STAGE 1: PULL MAXDIFFUSION IMAGE AND RUN WORKLOAD
67-
maxdiffusion_workload:
68-
name: "Run MaxDiffusion Workload"
27+
maxtext_workload:
28+
name: "Run MaxText Workload"
6929
# IMPORTANT: Replace with the label for your runner (e.g., v5p-8)
7030
runs-on: ["linux-x86-a3-megagpu-h100-8gpu"]
7131
container:
72-
image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:jax0.7.2_cuda13_te2.6.0
32+
image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/maxtext-gpu-custom:latest
7333
steps:
74-
- name: Checkout Repository
34+
- name: Checkout MaxText Repo
7535
uses: actions/checkout@v4
76-
77-
- name: Check Host CUDA and GPU Environment
78-
run: |
79-
echo "--- Checking NVIDIA driver and supported CUDA version ---"
80-
nvidia-smi || echo "nvidia-smi command not found. No GPU or NVIDIA driver detected."
81-
82-
echo ""
83-
echo "--- Checking for default CUDA toolkit installation ---"
84-
ls -l /usr/local/ | grep cuda || echo "No default CUDA toolkit found in /usr/local/"
85-
86-
echo ""
87-
echo "--- Checking dynamic linker library path ---"
88-
echo "LD_LIBRARY_PATH=${LD_LIBRARY_PATH:-'Not Set'}"
89-
36+
with:
37+
repository: AI-Hypercomputer/maxtext
38+
path: maxtext
39+
ref: rbierneni-test-gpu-run
40+
9041
- name: Print dependencies
9142
run: |
9243
pip uninstall -y transformer-engine transformer-engine-jax transformer-engine-cu12
93-
# pip install transformer_engine[jax]==2.4.0
94-
# pip install -U transformer-engine[jax]==2.6.0
95-
# pip uninstall -y transformer-engine-cu12
96-
# pip install transformer-engine-cu13
44+
pip install -U transformer-engine[jax]==2.6.0
9745
# pip uninstall -y tensorflow
9846
# pip install tensorflow-cpu
9947
pip freeze
10048
101-
- name: Check devices
102-
run: |
103-
python -c "import jax; print(jax.devices())"
104-
105-
- name: Run Conflict Verification Script
49+
- name: Run MaxText Training
10650
run: |
107-
# This command creates the file inside the runner
108-
cat <<'EOF' > verify_conflict.py
109-
print("--- PyTorch vs. JAX Conflict Test ---")
110-
111-
print("\nStep 1: Attempting to import torch...")
112-
try:
113-
import torch
114-
print(f"Successfully imported torch version: {torch.__version__}")
115-
print(f"Is PyTorch using CUDA? -> {torch.cuda.is_available()}")
116-
except Exception as e:
117-
print(f"Failed to import torch: {e}")
118-
119-
print("\nStep 2: Now, attempting to initialize JAX...")
120-
try:
121-
import jax
122-
devices = jax.devices()
123-
print("\n--- RESULT: SUCCESS ---")
124-
print(f"JAX initialized correctly and found devices: {devices}")
125-
except Exception as e:
126-
print("\n--- RESULT: FAILURE ---")
127-
print("JAX failed to initialize after PyTorch was imported.")
128-
print(f"JAX Error: {e}")
129-
EOF
130-
131-
# Now that the file exists, this command will work
132-
python verify_conflict.py
133-
134-
# - name: Run MaxDiffusion Training
135-
# run: |
136-
# # This command is adapted from your DAG for a single-slice configuration.
137-
# NVTE_FRAMEWORK=JAX NVTE_FUSED_ATTN=1 pip install . && \
138-
# python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \
139-
# hardware=gpu \
140-
# train_new_unet=true \
141-
# train_text_encoder=false \
142-
# cache_latents_text_encoder_outputs=true \
143-
# per_device_batch_size=1 \
144-
# attention=dot_product \
145-
# activations_dtype=bfloat16 \
146-
# weights_dtype=bfloat16 \
147-
# max_train_steps=200 \
148-
# enable_profiler=True \
149-
# run_name=1slice-VGpuVersion.XPK_H100_a3-maxdiffusion-jax-stable-stack-2025-09-26-04-12-02 \
150-
# output_dir=gs://rbierneni-multipod-dev/${{ github.run_id }}
151-
152-
# jobs:
153-
# build:
154-
# strategy:
155-
# fail-fast: false
156-
# matrix:
157-
# tpu-type: ["v5p-8"]
158-
# name: "TPU test (${{ matrix.tpu-type }})"
159-
# runs-on: ["self-hosted","${{ matrix.tpu-type }}"]
160-
# steps:
161-
# - uses: actions/checkout@v4
162-
# - name: Set up Python 3.12
163-
# uses: actions/setup-python@v5
164-
# with:
165-
# python-version: '3.12'
166-
# - name: Install dependencies
167-
# run: |
168-
# pip install -e .
169-
# pip uninstall jax jaxlib libtpu-nightly libtpu -y
170-
# bash setup.sh MODE=stable
171-
# export PATH=$PATH:$HOME/.local/bin
172-
# pip install ruff
173-
# pip install isort
174-
# pip install pytest
175-
# - name: Analysing the code with ruff
176-
# run: |
177-
# ruff check .
178-
# - name: version check
179-
# run: |
180-
# python --version
181-
# pip show jax jaxlib flax transformers datasets tensorflow tensorflow_datasets
182-
# - name: PyTest
183-
# run: | #--deselect=src/maxdiffusion/tests/input_pipeline_interface_test.py
184-
# HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py -x
185-
# add_pull_ready:
186-
# if: github.ref != 'refs/heads/main'
187-
# permissions:
188-
# checks: read
189-
# pull-requests: write
190-
# needs: build
191-
# uses: ./.github/workflows/AddLabel.yml
51+
# This command is adapted from your DAG for a single-slice configuration.
52+
cd maxtext && \
53+
pip install . --no-dependencies
54+
55+
export XLA_PYTHON_CLIENT_MEM_FRACTION=0.65
56+
export TF_FORCE_GPU_ALLOW_GROWTH=true
57+
export NVTE_FUSED_ATTN=1
58+
59+
python3 -m MaxText.train MaxText/configs/base.yml \
60+
steps=5 \
61+
enable_checkpointing=false \
62+
attention=cudnn_flash_te \
63+
dataset_type=synthetic \
64+
run_name=rbierneni-test-maxtext-gpu \
65+
base_output_directory=gs://rbierneni-multipod-dev/maxtext/${{ github.run_id }}

0 commit comments

Comments (0)