Skip to content

Commit 4243d53

Browse files
Test with tensorflow-cpu
1 parent dfe8fc5 commit 4243d53

1 file changed

Lines changed: 64 additions & 60 deletions

File tree

.github/workflows/UnitTests.yml

Lines changed: 64 additions & 60 deletions
Original file line number · Diff line number · Diff line change
@@ -24,77 +24,81 @@ on:
2424
workflow_dispatch:
2525

2626
jobs:
27-
maxtext_workload:
28-
name: "Run MaxText Workload"
29-
# IMPORTANT: Replace with the label for your runner (e.g., v5p-8)
30-
runs-on: ["linux-x86-a3-megagpu-h100-8gpu"]
31-
container:
32-
image: gcr.io/tpu-prod-env-multipod/maxtext_stable_stack_candidate_gpu:latest
33-
steps:
34-
- name: Checkout MaxText Repo
35-
uses: actions/checkout@v4
36-
with:
37-
repository: AI-Hypercomputer/maxtext
38-
path: maxtext
39-
ref: rbierneni-test-gpu-run
40-
41-
- name: Print dependencies
42-
run: |
43-
pip uninstall -y transformer-engine transformer-engine-jax
44-
pip install -U transformer-engine[jax]==2.6.0
45-
pip freeze
46-
47-
- name: Run MaxText Training
48-
run: |
49-
# This command is adapted from your DAG for a single-slice configuration.
50-
cd maxtext && \
51-
pip install . --no-dependencies
52-
53-
export XLA_PYTHON_CLIENT_MEM_FRACTION=0.65
54-
export TF_FORCE_GPU_ALLOW_GROWTH=true
55-
56-
python3 -m MaxText.train MaxText/configs/base.yml \
57-
steps=2 \
58-
enable_checkpointing=false \
59-
attention=cudnn_flash_te \
60-
dataset_type=synthetic \
61-
run_name=rbierneni-test-maxtext-gpu \
62-
base_output_directory=gs://rbierneni-multipod-dev/maxtext/${{ github.run_id }}
63-
64-
# # STAGE 1: PULL MAXDIFFUSION IMAGE AND RUN WORKLOAD
65-
# maxdiffusion_workload:
66-
# name: "Run MaxDiffusion Workload"
27+
# maxtext_workload:
28+
# name: "Run MaxText Workload"
6729
# # IMPORTANT: Replace with the label for your runner (e.g., v5p-8)
6830
# runs-on: ["linux-x86-a3-megagpu-h100-8gpu"]
6931
# container:
70-
# image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:jax0.7.2_rev1_gpu
32+
# image: gcr.io/tpu-prod-env-multipod/maxtext_stable_stack_candidate_gpu:latest
7133
# steps:
72-
# - name: Checkout Repository
34+
# - name: Checkout MaxText Repo
7335
# uses: actions/checkout@v4
74-
36+
# with:
37+
# repository: AI-Hypercomputer/maxtext
38+
# path: maxtext
39+
# ref: rbierneni-test-gpu-run
40+
7541
# - name: Print dependencies
7642
# run: |
77-
# # pip uninstall -y transformer-engine transformer-engine-jax
78-
# # pip install -U transformer-engine[pytorch,jax]
43+
# pip uninstall -y transformer-engine transformer-engine-jax
44+
# pip install -U transformer-engine[jax]==2.6.0
45+
# pip uninstall -y tensorflow
46+
# pip install tensorflow-cpu
7947
# pip freeze
8048

81-
# - name: Run MaxDiffusion Training
49+
# - name: Run MaxText Training
8250
# run: |
8351
# # This command is adapted from your DAG for a single-slice configuration.
84-
# NVTE_FUSED_ATTN=1 pip install . && \
85-
# python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \
86-
# hardware=gpu \
87-
# train_new_unet=true \
88-
# train_text_encoder=false \
89-
# cache_latents_text_encoder_outputs=true \
90-
# per_device_batch_size=1 \
91-
# attention=dot_product \
92-
# activations_dtype=bfloat16 \
93-
# weights_dtype=bfloat16 \
94-
# max_train_steps=200 \
95-
# enable_profiler=True \
96-
# run_name=1slice-VGpuVersion.XPK_H100_a3-maxdiffusion-jax-stable-stack-2025-09-26-04-12-02 \
97-
# output_dir=gs://rbierneni-multipod-dev/${{ github.run_id }}
52+
# cd maxtext && \
53+
# pip install . --no-dependencies
54+
55+
# export XLA_PYTHON_CLIENT_MEM_FRACTION=0.65
56+
# export TF_FORCE_GPU_ALLOW_GROWTH=true
57+
58+
# python3 -m MaxText.train MaxText/configs/base.yml \
59+
# steps=2 \
60+
# enable_checkpointing=false \
61+
# attention=cudnn_flash_te \
62+
# dataset_type=synthetic \
63+
# run_name=rbierneni-test-maxtext-gpu \
64+
# base_output_directory=gs://rbierneni-multipod-dev/maxtext/${{ github.run_id }}
65+
66+
# STAGE 1: PULL MAXDIFFUSION IMAGE AND RUN WORKLOAD
67+
maxdiffusion_workload:
68+
name: "Run MaxDiffusion Workload"
69+
# IMPORTANT: Replace with the label for your runner (e.g., v5p-8)
70+
runs-on: ["linux-x86-a3-megagpu-h100-8gpu"]
71+
container:
72+
image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:jax0.7.2_rev1_gpu
73+
steps:
74+
- name: Checkout Repository
75+
uses: actions/checkout@v4
76+
77+
- name: Print dependencies
78+
run: |
79+
# pip uninstall -y transformer-engine transformer-engine-jax
80+
# pip install -U transformer-engine[pytorch,jax]
81+
pip uninstall -y tensorflow
82+
pip install tensorflow-cpu
83+
pip freeze
84+
85+
- name: Run MaxDiffusion Training
86+
run: |
87+
# This command is adapted from your DAG for a single-slice configuration.
88+
NVTE_FUSED_ATTN=1 pip install . && \
89+
python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \
90+
hardware=gpu \
91+
train_new_unet=true \
92+
train_text_encoder=false \
93+
cache_latents_text_encoder_outputs=true \
94+
per_device_batch_size=1 \
95+
attention=dot_product \
96+
activations_dtype=bfloat16 \
97+
weights_dtype=bfloat16 \
98+
max_train_steps=200 \
99+
enable_profiler=True \
100+
run_name=1slice-VGpuVersion.XPK_H100_a3-maxdiffusion-jax-stable-stack-2025-09-26-04-12-02 \
101+
output_dir=gs://rbierneni-multipod-dev/${{ github.run_id }}
98102
99103
# jobs:
100104
# build:

0 commit comments

Comments (0)