# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python

# This workflow will run a small FLUX training workload on a GPU runner.

name: FLUX Workload Training on GPU

on:
  pull_request:
  workflow_dispatch:

jobs:
  # Pull the JAX stable-stack GPU image and run a short FLUX training job.
  flux_training_workload:
    name: "Run FLUX Training Workload"
    # IMPORTANT: replace with the label for your specific GPU runner if different.
    runs-on: ["linux-x86-a3-megagpu-h100-8gpu"]
    container:
      # NOTE(review): the captured reference ended in an extra ":latest"
      # (two tag separators), which is not a valid image reference —
      # "jax0.7.2-cuda12.9-rev1" is already the tag. Confirm this is the
      # intended image tag.
      image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/gpu:jax0.7.2-cuda12.9-rev1

    steps:
      - name: Checkout Repository
        uses: actions/checkout@v4

      - name: Install Dependencies
        run: |
          pip install -r requirements.txt
          pip install --upgrade torch torchvision
          # Install the maxdiffusion package to make it available for execution
          pip install .

      - name: List Installed Libraries
        run: |
          echo "--- Installed Python packages ---"
          pip freeze

      - name: Run FLUX Training
        env:
          # Tell Transformer Engine which framework bindings to use.
          NVTE_FRAMEWORK: jax
        run: |
          python src/maxdiffusion/train_flux.py src/maxdiffusion/configs/base_flux_dev.yml \
            run_name="flux-ci-test-${{ github.run_id }}" \
            output_dir="/tmp/flux-output/" \
            max_train_steps=5 \
            hardware=gpu \
            attention="cudnn_flash_te"