24 | 24 | workflow_dispatch: |
25 | 25 |
26 | 26 | jobs: |
27 | | - # maxtext_workload: |
28 | | - # name: "Run MaxText Workload" |
29 | | - # # IMPORTANT: Replace with the label for your runner (e.g., v5p-8) |
30 | | - # runs-on: ["linux-x86-a3-megagpu-h100-8gpu"] |
31 | | - # container: |
32 | | - # image: gcr.io/tpu-prod-env-multipod/maxtext_stable_stack_candidate_gpu:latest |
33 | | - # steps: |
34 | | - # - name: Checkout MaxText Repo |
35 | | - # uses: actions/checkout@v4 |
36 | | - # with: |
37 | | - # repository: AI-Hypercomputer/maxtext |
38 | | - # path: maxtext |
39 | | - # ref: rbierneni-test-gpu-run |
40 | | - |
41 | | - # - name: Print dependencies |
42 | | - # run: | |
43 | | - # pip uninstall -y transformer-engine transformer-engine-jax transformer-engine-cu12 |
44 | | - # pip install -U transformer-engine[jax]==2.6.0 |
45 | | - # # pip uninstall -y tensorflow |
46 | | - # # pip install tensorflow-cpu |
47 | | - # pip freeze |
48 | | - |
49 | | - # - name: Run MaxText Training |
50 | | - # run: | |
51 | | - # # This command is adapted from your DAG for a single-slice configuration. |
52 | | - # cd maxtext && \ |
53 | | - # pip install . --no-dependencies |
54 | | - |
55 | | - # export XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 |
56 | | - # export TF_FORCE_GPU_ALLOW_GROWTH=true |
57 | | - |
58 | | - # python3 -m MaxText.train MaxText/configs/base.yml \ |
59 | | - # steps=2 \ |
60 | | - # enable_checkpointing=false \ |
61 | | - # attention=cudnn_flash_te \ |
62 | | - # dataset_type=synthetic \ |
63 | | - # run_name=rbierneni-test-maxtext-gpu \ |
64 | | - # base_output_directory=gs://rbierneni-multipod-dev/maxtext/${{ github.run_id }} |
65 | | - |
66 | | - # STAGE 1: PULL MAXDIFFUSION IMAGE AND RUN WORKLOAD |
67 | | - maxdiffusion_workload: |
68 | | - name: "Run MaxDiffusion Workload" |
| 27 | + maxtext_workload: |
| 28 | + name: "Run MaxText Workload" |
69 | 29 | # IMPORTANT: Replace with the label for your runner (e.g., v5p-8) |
70 | 30 | runs-on: ["linux-x86-a3-megagpu-h100-8gpu"] |
71 | 31 | container: |
72 | | - image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:jax0.7.2_cuda13_te2.6.0 |
| 32 | + image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/maxtext-gpu-custom:latest |
73 | 33 | steps: |
74 | | - - name: Checkout Repository |
| 34 | + - name: Checkout MaxText Repo |
75 | 35 | uses: actions/checkout@v4 |
76 | | - |
77 | | - - name: Check Host CUDA and GPU Environment |
78 | | - run: | |
79 | | - echo "--- Checking NVIDIA driver and supported CUDA version ---" |
80 | | - nvidia-smi || echo "nvidia-smi command not found. No GPU or NVIDIA driver detected." |
81 | | - |
82 | | - echo "" |
83 | | - echo "--- Checking for default CUDA toolkit installation ---" |
84 | | - ls -l /usr/local/ | grep cuda || echo "No default CUDA toolkit found in /usr/local/" |
85 | | - |
86 | | - echo "" |
87 | | - echo "--- Checking dynamic linker library path ---" |
88 | | - echo "LD_LIBRARY_PATH=${LD_LIBRARY_PATH:-'Not Set'}" |
89 | | -
| 36 | + with: |
| 37 | + repository: AI-Hypercomputer/maxtext |
| 38 | + path: maxtext |
| 39 | + ref: rbierneni-test-gpu-run |
| 40 | + |
90 | 41 | - name: Print dependencies |
91 | 42 | run: | |
92 | 43 | pip uninstall -y transformer-engine transformer-engine-jax transformer-engine-cu12 |
93 | | - # pip install transformer_engine[jax]==2.4.0 |
94 | | - # pip install -U transformer-engine[jax]==2.6.0 |
95 | | - # pip uninstall -y transformer-engine-cu12 |
96 | | - # pip install transformer-engine-cu13 |
| 44 | + pip install -U "transformer-engine[jax]==2.6.0" |
97 | 45 | # pip uninstall -y tensorflow |
98 | 46 | # pip install tensorflow-cpu |
99 | 47 | pip freeze |
100 | 48 |
101 | | - - name: Check devices |
102 | | - run: | |
103 | | - python -c "import jax; print(jax.devices())" |
104 | | - |
105 | | - - name: Run Conflict Verification Script |
| 49 | + - name: Run MaxText Training |
106 | 50 | run: | |
107 | | - # This command creates the file inside the runner |
108 | | - cat <<'EOF' > verify_conflict.py |
109 | | - print("--- PyTorch vs. JAX Conflict Test ---") |
110 | | -
111 | | - print("\nStep 1: Attempting to import torch...") |
112 | | - try: |
113 | | - import torch |
114 | | - print(f"Successfully imported torch version: {torch.__version__}") |
115 | | - print(f"Is PyTorch using CUDA? -> {torch.cuda.is_available()}") |
116 | | - except Exception as e: |
117 | | - print(f"Failed to import torch: {e}") |
118 | | -
119 | | - print("\nStep 2: Now, attempting to initialize JAX...") |
120 | | - try: |
121 | | - import jax |
122 | | - devices = jax.devices() |
123 | | - print("\n--- RESULT: SUCCESS ---") |
124 | | - print(f"JAX initialized correctly and found devices: {devices}") |
125 | | - except Exception as e: |
126 | | - print("\n--- RESULT: FAILURE ---") |
127 | | - print("JAX failed to initialize after PyTorch was imported.") |
128 | | - print(f"JAX Error: {e}") |
129 | | - EOF |
130 | | -
131 | | - # Now that the file exists, this command will work |
132 | | - python verify_conflict.py |
133 | | -
134 | | - # - name: Run MaxDiffusion Training |
135 | | - # run: | |
136 | | - # # This command is adapted from your DAG for a single-slice configuration. |
137 | | - # NVTE_FRAMEWORK=JAX NVTE_FUSED_ATTN=1 pip install . && \ |
138 | | - # python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \ |
139 | | - # hardware=gpu \ |
140 | | - # train_new_unet=true \ |
141 | | - # train_text_encoder=false \ |
142 | | - # cache_latents_text_encoder_outputs=true \ |
143 | | - # per_device_batch_size=1 \ |
144 | | - # attention=dot_product \ |
145 | | - # activations_dtype=bfloat16 \ |
146 | | - # weights_dtype=bfloat16 \ |
147 | | - # max_train_steps=200 \ |
148 | | - # enable_profiler=True \ |
149 | | - # run_name=1slice-VGpuVersion.XPK_H100_a3-maxdiffusion-jax-stable-stack-2025-09-26-04-12-02 \ |
150 | | - # output_dir=gs://rbierneni-multipod-dev/${{ github.run_id }} |
151 | | - |
152 | | -# jobs: |
153 | | -# build: |
154 | | -# strategy: |
155 | | -# fail-fast: false |
156 | | -# matrix: |
157 | | -# tpu-type: ["v5p-8"] |
158 | | -# name: "TPU test (${{ matrix.tpu-type }})" |
159 | | -# runs-on: ["self-hosted","${{ matrix.tpu-type }}"] |
160 | | -# steps: |
161 | | -# - uses: actions/checkout@v4 |
162 | | -# - name: Set up Python 3.12 |
163 | | -# uses: actions/setup-python@v5 |
164 | | -# with: |
165 | | -# python-version: '3.12' |
166 | | -# - name: Install dependencies |
167 | | -# run: | |
168 | | -# pip install -e . |
169 | | -# pip uninstall jax jaxlib libtpu-nightly libtpu -y |
170 | | -# bash setup.sh MODE=stable |
171 | | -# export PATH=$PATH:$HOME/.local/bin |
172 | | -# pip install ruff |
173 | | -# pip install isort |
174 | | -# pip install pytest |
175 | | -# - name: Analysing the code with ruff |
176 | | -# run: | |
177 | | -# ruff check . |
178 | | -# - name: version check |
179 | | -# run: | |
180 | | -# python --version |
181 | | -# pip show jax jaxlib flax transformers datasets tensorflow tensorflow_datasets |
182 | | -# - name: PyTest |
183 | | -# run: | #--deselect=src/maxdiffusion/tests/input_pipeline_interface_test.py |
184 | | -# HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py -x |
185 | | -# add_pull_ready: |
186 | | -# if: github.ref != 'refs/heads/main' |
187 | | -# permissions: |
188 | | -# checks: read |
189 | | -# pull-requests: write |
190 | | -# needs: build |
191 | | -# uses: ./.github/workflows/AddLabel.yml |
| 51 | + # This command is adapted from your DAG for a single-slice configuration. |
| 52 | + cd maxtext && \ |
| 53 | + pip install . --no-dependencies  # install the MaxText package without re-resolving its dependencies |
| 54 | +
| 55 | + export XLA_PYTHON_CLIENT_MEM_FRACTION=0.65  # cap JAX GPU memory preallocation at 65% |
| 56 | + export TF_FORCE_GPU_ALLOW_GROWTH=true  # let TF allocate GPU memory incrementally instead of all at once |
| 57 | + export NVTE_FUSED_ATTN=1  # enable Transformer Engine's cuDNN fused attention, used by attention=cudnn_flash_te |
| 58 | +
| 59 | + python3 -m MaxText.train MaxText/configs/base.yml \ |
| 60 | + steps=5 \ |
| 61 | + enable_checkpointing=false \ |
| 62 | + attention=cudnn_flash_te \ |
| 63 | + dataset_type=synthetic \ |
| 64 | + run_name=rbierneni-test-maxtext-gpu \ |
| 65 | + base_output_directory=gs://rbierneni-multipod-dev/maxtext/${{ github.run_id }} |
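Note on the removed diagnostics: the deleted "Check devices" and verify_conflict.py steps are still useful when this job cannot see the GPUs. Below is a minimal standalone sketch of the same sanity check, assuming the container image above ships a CUDA-enabled JAX and that the "Print dependencies" step pinned transformer-engine[jax]==2.6.0; the file name smoke_check.py and the 8-GPU assertion (inferred from the a3-megagpu-h100-8gpu runner label) are illustrative assumptions, not part of the workflow.

```python
# smoke_check.py -- illustrative pre-training sanity check, not part of the workflow.
# Assumes a CUDA-enabled JAX and transformer-engine[jax]==2.6.0, as installed above.
import jax
import transformer_engine

# Fail fast if the transformer-engine pin from "Print dependencies" did not take.
print(f"transformer-engine: {transformer_engine.__version__}")
assert transformer_engine.__version__.startswith("2.6"), "unexpected TE version"

# Fail fast if JAX fell back to CPU. The deleted verify_conflict.py script
# probed exactly this failure mode after an eager PyTorch import.
print(f"JAX backend: {jax.default_backend()}")
devices = jax.devices()
print(f"JAX devices: {devices}")
assert jax.default_backend() == "gpu", "JAX did not initialize the CUDA backend"

# The runner label suggests an a3-megagpu host with 8x H100 (assumption).
assert len(devices) == 8, f"expected 8 GPUs, found {len(devices)}"
```

Running it as its own step before "Run MaxText Training" surfaces environment problems before the training command starts, instead of burying them in the training logs.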