|
24 | 24 | workflow_dispatch: |
25 | 25 |
|
26 | 26 | jobs: |
27 | | - # STAGE 1: PULL MAXDIFFUSION IMAGE AND RUN WORKLOAD |
28 | | - maxdiffusion_workload: |
29 | | - name: "Run MaxDiffusion Workload" |
| 27 | + maxtext_workload: |
| 28 | + name: "Run MaxText Workload" |
30 | 29 | # IMPORTANT: Replace with the label for your runner (e.g., v5p-8) |
31 | 30 | runs-on: ["linux-x86-a3-megagpu-h100-8gpu"] |
32 | 31 | container: |
33 | | - image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:jax0.7.2_rev1_gpu |
| 32 | + image: gcr.io/tpu-prod-env-multipod/maxtext_stable_stack_candidate_gpu:latest |
34 | 33 | steps: |
35 | | - - name: Checkout Repository |
| 34 | + - name: Checkout MaxText Repo |
36 | 35 | uses: actions/checkout@v4 |
| 36 | + with: |
| 37 | + repository: AI-Hypercomputer/maxtext |
| 38 | + path: maxtext |
37 | 39 |
|
38 | | - - name: Print dependencies |
39 | | - run: | |
40 | | - # pip uninstall -y transformer-engine transformer-engine-jax |
41 | | - # pip install -U transformer-engine[pytorch,jax] |
42 | | - pip freeze |
43 | | -
|
44 | | - - name: Run MaxDiffusion Training |
| 40 | + - name: Run MaxText Training |
45 | 41 | run: | |
46 | 42 | # This command is adapted from your DAG for a single-slice configuration. |
47 | | - NVTE_FUSED_ATTN=1 pip install . && \ |
48 | | - python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \ |
49 | | - hardware=gpu \ |
50 | | - train_new_unet=true \ |
51 | | - train_text_encoder=false \ |
52 | | - cache_latents_text_encoder_outputs=true \ |
53 | | - per_device_batch_size=1 \ |
54 | | - attention=dot_product \ |
55 | | - activations_dtype=bfloat16 \ |
56 | | - weights_dtype=bfloat16 \ |
57 | | - max_train_steps=200 \ |
58 | | - enable_profiler=True \ |
59 | | - run_name=1slice-VGpuVersion.XPK_H100_a3-maxdiffusion-jax-stable-stack-2025-09-26-04-12-02 \ |
60 | | - output_dir=gs://rbierneni-multipod-dev/${{ github.run_id }} |
| 43 | + cd maxtext && \ |
| 44 | + pip install -e . --no-dependencies && \ |
| 45 | + XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 TF_FORCE_GPU_ALLOW_GROWTH=true \ |
| 46 | + python3 -m MaxText.train MaxText/configs/base.yml \ |
| 47 | + steps=2 \ |
| 48 | + enable_checkpointing=false \ |
| 49 | + attention=dot_product \ |
| 50 | + run_name=rbierneni-test-maxtext-gpu \ |
| 51 | + base_output_directory=gs://rbierneni-multipod-dev/maxtext/${{ github.run_id }} |
| 52 | +
|
| 53 | + # # STAGE 1: PULL MAXDIFFUSION IMAGE AND RUN WORKLOAD |
| 54 | + # maxdiffusion_workload: |
| 55 | + # name: "Run MaxDiffusion Workload" |
| 56 | + # # IMPORTANT: Replace with the label for your runner (e.g., v5p-8) |
| 57 | + # runs-on: ["linux-x86-a3-megagpu-h100-8gpu"] |
| 58 | + # container: |
| 59 | + # image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:jax0.7.2_rev1_gpu |
| 60 | + # steps: |
| 61 | + # - name: Checkout Repository |
| 62 | + # uses: actions/checkout@v4 |
| 63 | + |
| 64 | + # - name: Print dependencies |
| 65 | + # run: | |
| 66 | + # # pip uninstall -y transformer-engine transformer-engine-jax |
| 67 | + # # pip install -U transformer-engine[pytorch,jax] |
| 68 | + # pip freeze |
| 69 | + |
| 70 | + # - name: Run MaxDiffusion Training |
| 71 | + # run: | |
| 72 | + # # This command is adapted from your DAG for a single-slice configuration. |
| 73 | + # NVTE_FUSED_ATTN=1 pip install . && \ |
| 74 | + # python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \ |
| 75 | + # hardware=gpu \ |
| 76 | + # train_new_unet=true \ |
| 77 | + # train_text_encoder=false \ |
| 78 | + # cache_latents_text_encoder_outputs=true \ |
| 79 | + # per_device_batch_size=1 \ |
| 80 | + # attention=dot_product \ |
| 81 | + # activations_dtype=bfloat16 \ |
| 82 | + # weights_dtype=bfloat16 \ |
| 83 | + # max_train_steps=200 \ |
| 84 | + # enable_profiler=True \ |
| 85 | + # run_name=1slice-VGpuVersion.XPK_H100_a3-maxdiffusion-jax-stable-stack-2025-09-26-04-12-02 \ |
| 86 | + # output_dir=gs://rbierneni-multipod-dev/${{ github.run_id }} |
61 | 87 |
|
62 | 88 | # jobs: |
63 | 89 | # build: |
|
0 commit comments