Skip to content

Commit 64f30f3

Browse files
Test if maxtext has same gpu issue
1 parent 37524bb commit 64f30f3

1 file changed

Lines changed: 52 additions & 26 deletions

File tree

.github/workflows/UnitTests.yml

Lines changed: 52 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -24,40 +24,66 @@ on:
2424
workflow_dispatch:
2525

2626
jobs:
27-
# STAGE 1: PULL MAXDIFFUSION IMAGE AND RUN WORKLOAD
28-
maxdiffusion_workload:
29-
name: "Run MaxDiffusion Workload"
27+
maxtext_workload:
28+
name: "Run MaxText Workload"
3029
# IMPORTANT: Replace with the label for your runner (e.g., v5p-8)
3130
runs-on: ["linux-x86-a3-megagpu-h100-8gpu"]
3231
container:
33-
image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:jax0.7.2_rev1_gpu
32+
image: gcr.io/tpu-prod-env-multipod/maxtext_stable_stack_candidate_gpu:latest
3433
steps:
35-
- name: Checkout Repository
34+
- name: Checkout MaxText Repo
3635
uses: actions/checkout@v4
36+
with:
37+
repository: AI-Hypercomputer/maxtext
38+
path: maxtext
3739

38-
- name: Print dependencies
39-
run: |
40-
# pip uninstall -y transformer-engine transformer-engine-jax
41-
# pip install -U transformer-engine[pytorch,jax]
42-
pip freeze
43-
44-
- name: Run MaxDiffusion Training
40+
- name: Run MaxText Training
4541
run: |
4642
# This command is adapted from your DAG for a single-slice configuration.
47-
NVTE_FUSED_ATTN=1 pip install . && \
48-
python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \
49-
hardware=gpu \
50-
train_new_unet=true \
51-
train_text_encoder=false \
52-
cache_latents_text_encoder_outputs=true \
53-
per_device_batch_size=1 \
54-
attention=dot_product \
55-
activations_dtype=bfloat16 \
56-
weights_dtype=bfloat16 \
57-
max_train_steps=200 \
58-
enable_profiler=True \
59-
run_name=1slice-VGpuVersion.XPK_H100_a3-maxdiffusion-jax-stable-stack-2025-09-26-04-12-02 \
60-
output_dir=gs://rbierneni-multipod-dev/${{ github.run_id }}
43+
cd maxtext && \
44+
pip install -e . --no-deps && \
45+
XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 TF_FORCE_GPU_ALLOW_GROWTH=true \
46+
python3 -m MaxText.train MaxText/configs/base.yml \
47+
steps=2 \
48+
enable_checkpointing=false \
49+
attention=dot_product \
50+
run_name=rbierneni-test-maxtext-gpu \
51+
base_output_directory=gs://rbierneni-multipod-dev/maxtext/${{ github.run_id }}
52+
53+
# # STAGE 1: PULL MAXDIFFUSION IMAGE AND RUN WORKLOAD
54+
# maxdiffusion_workload:
55+
# name: "Run MaxDiffusion Workload"
56+
# # IMPORTANT: Replace with the label for your runner (e.g., v5p-8)
57+
# runs-on: ["linux-x86-a3-megagpu-h100-8gpu"]
58+
# container:
59+
# image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:jax0.7.2_rev1_gpu
60+
# steps:
61+
# - name: Checkout Repository
62+
# uses: actions/checkout@v4
63+
64+
# - name: Print dependencies
65+
# run: |
66+
# # pip uninstall -y transformer-engine transformer-engine-jax
67+
# # pip install -U transformer-engine[pytorch,jax]
68+
# pip freeze
69+
70+
# - name: Run MaxDiffusion Training
71+
# run: |
72+
# # This command is adapted from your DAG for a single-slice configuration.
73+
# NVTE_FUSED_ATTN=1 pip install . && \
74+
# python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \
75+
# hardware=gpu \
76+
# train_new_unet=true \
77+
# train_text_encoder=false \
78+
# cache_latents_text_encoder_outputs=true \
79+
# per_device_batch_size=1 \
80+
# attention=dot_product \
81+
# activations_dtype=bfloat16 \
82+
# weights_dtype=bfloat16 \
83+
# max_train_steps=200 \
84+
# enable_profiler=True \
85+
# run_name=1slice-VGpuVersion.XPK_H100_a3-maxdiffusion-jax-stable-stack-2025-09-26-04-12-02 \
86+
# output_dir=gs://rbierneni-multipod-dev/${{ github.run_id }}
6187

6288
# jobs:
6389
# build:

0 commit comments

Comments (0)