|
24 | 24 | workflow_dispatch: |
25 | 25 |
|
26 | 26 | jobs: |
27 | | - maxtext_workload: |
28 | | - name: "Run MaxText Workload" |
29 | | - # IMPORTANT: Replace with the label for your runner (e.g., v5p-8) |
30 | | - runs-on: ["linux-x86-a3-megagpu-h100-8gpu"] |
31 | | - container: |
32 | | - image: gcr.io/tpu-prod-env-multipod/maxtext_stable_stack_candidate_gpu:latest |
33 | | - steps: |
34 | | - - name: Checkout MaxText Repo |
35 | | - uses: actions/checkout@v4 |
36 | | - with: |
37 | | - repository: AI-Hypercomputer/maxtext |
38 | | - path: maxtext |
39 | | - ref: rbierneni-test-gpu-run |
40 | | - |
41 | | - - name: Print dependencies |
42 | | - run: | |
43 | | - pip uninstall -y transformer-engine transformer-engine-jax |
44 | | - pip install -U transformer-engine[jax]==2.6.0 |
45 | | - pip freeze |
46 | | -
|
47 | | - - name: Run MaxText Training |
48 | | - run: | |
49 | | - # This command is adapted from your DAG for a single-slice configuration. |
50 | | - cd maxtext && \ |
51 | | - pip install . --no-dependencies |
52 | | -
|
53 | | - export XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 |
54 | | - export TF_FORCE_GPU_ALLOW_GROWTH=true |
55 | | -
|
56 | | - python3 -m MaxText.train MaxText/configs/base.yml \ |
57 | | - steps=2 \ |
58 | | - enable_checkpointing=false \ |
59 | | - attention=cudnn_flash_te \ |
60 | | - dataset_type=synthetic \ |
61 | | - run_name=rbierneni-test-maxtext-gpu \ |
62 | | - base_output_directory=gs://rbierneni-multipod-dev/maxtext/${{ github.run_id }} |
63 | | -
|
64 | | - # # STAGE 1: PULL MAXDIFFUSION IMAGE AND RUN WORKLOAD |
65 | | - # maxdiffusion_workload: |
66 | | - # name: "Run MaxDiffusion Workload" |
| 27 | + # maxtext_workload: |
| 28 | + # name: "Run MaxText Workload" |
67 | 29 | # # IMPORTANT: Replace with the label for your runner (e.g., v5p-8) |
68 | 30 | # runs-on: ["linux-x86-a3-megagpu-h100-8gpu"] |
69 | 31 | # container: |
70 | | - # image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:jax0.7.2_rev1_gpu |
| 32 | + # image: gcr.io/tpu-prod-env-multipod/maxtext_stable_stack_candidate_gpu:latest |
71 | 33 | # steps: |
72 | | - # - name: Checkout Repository |
| 34 | + # - name: Checkout MaxText Repo |
73 | 35 | # uses: actions/checkout@v4 |
74 | | - |
| 36 | + # with: |
| 37 | + # repository: AI-Hypercomputer/maxtext |
| 38 | + # path: maxtext |
| 39 | + # ref: rbierneni-test-gpu-run |
| 40 | + |
75 | 41 | # - name: Print dependencies |
76 | 42 | # run: | |
77 | | - # # pip uninstall -y transformer-engine transformer-engine-jax |
78 | | - # # pip install -U transformer-engine[pytorch,jax] |
| 43 | + # pip uninstall -y transformer-engine transformer-engine-jax |
| 44 | + # pip install -U transformer-engine[jax]==2.6.0 |
| 45 | + # pip uninstall -y tensorflow |
| 46 | + # pip install tensorflow-cpu |
79 | 47 | # pip freeze |
80 | 48 |
|
81 | | - # - name: Run MaxDiffusion Training |
| 49 | + # - name: Run MaxText Training |
82 | 50 | # run: | |
83 | 51 | # # This command is adapted from your DAG for a single-slice configuration. |
84 | | - # NVTE_FUSED_ATTN=1 pip install . && \ |
85 | | - # python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \ |
86 | | - # hardware=gpu \ |
87 | | - # train_new_unet=true \ |
88 | | - # train_text_encoder=false \ |
89 | | - # cache_latents_text_encoder_outputs=true \ |
90 | | - # per_device_batch_size=1 \ |
91 | | - # attention=dot_product \ |
92 | | - # activations_dtype=bfloat16 \ |
93 | | - # weights_dtype=bfloat16 \ |
94 | | - # max_train_steps=200 \ |
95 | | - # enable_profiler=True \ |
96 | | - # run_name=1slice-VGpuVersion.XPK_H100_a3-maxdiffusion-jax-stable-stack-2025-09-26-04-12-02 \ |
97 | | - # output_dir=gs://rbierneni-multipod-dev/${{ github.run_id }} |
| 52 | + # cd maxtext && \ |
| 53 | + # pip install . --no-dependencies |
| 54 | + |
| 55 | + # export XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 |
| 56 | + # export TF_FORCE_GPU_ALLOW_GROWTH=true |
| 57 | + |
| 58 | + # python3 -m MaxText.train MaxText/configs/base.yml \ |
| 59 | + # steps=2 \ |
| 60 | + # enable_checkpointing=false \ |
| 61 | + # attention=cudnn_flash_te \ |
| 62 | + # dataset_type=synthetic \ |
| 63 | + # run_name=rbierneni-test-maxtext-gpu \ |
| 64 | + # base_output_directory=gs://rbierneni-multipod-dev/maxtext/${{ github.run_id }} |
| 65 | + |
| 66 | + # STAGE 1: PULL MAXDIFFUSION IMAGE AND RUN WORKLOAD |
| 67 | + maxdiffusion_workload: |
| 68 | + name: "Run MaxDiffusion Workload" |
| 69 | + # IMPORTANT: Replace with the label for your runner (e.g., v5p-8) |
| 70 | + runs-on: ["linux-x86-a3-megagpu-h100-8gpu"] |
| 71 | + container: |
| 72 | + image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:jax0.7.2_rev1_gpu |
| 73 | + steps: |
| 74 | + - name: Checkout Repository |
| 75 | + uses: actions/checkout@v4 |
| 76 | + |
| 77 | + - name: Print dependencies |
| 78 | + run: | |
| 79 | + # pip uninstall -y transformer-engine transformer-engine-jax |
| 80 | + # pip install -U transformer-engine[pytorch,jax] |
| 81 | + pip uninstall -y tensorflow |
| 82 | + pip install tensorflow-cpu |
| 83 | + pip freeze |
| 84 | +
|
| 85 | + - name: Run MaxDiffusion Training |
| 86 | + run: | |
| 87 | + # This command is adapted from your DAG for a single-slice configuration. |
| 88 | + NVTE_FUSED_ATTN=1 pip install . && \ |
| 89 | + python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \ |
| 90 | + hardware=gpu \ |
| 91 | + train_new_unet=true \ |
| 92 | + train_text_encoder=false \ |
| 93 | + cache_latents_text_encoder_outputs=true \ |
| 94 | + per_device_batch_size=1 \ |
| 95 | + attention=dot_product \ |
| 96 | + activations_dtype=bfloat16 \ |
| 97 | + weights_dtype=bfloat16 \ |
| 98 | + max_train_steps=200 \ |
| 99 | + enable_profiler=True \ |
| 100 | + run_name=1slice-VGpuVersion.XPK_H100_a3-maxdiffusion-jax-stable-stack-2025-09-26-04-12-02 \ |
| 101 | + output_dir=gs://rbierneni-multipod-dev/${{ github.run_id }} |
98 | 102 |
|
99 | 103 | # jobs: |
100 | 104 | # build: |
|
0 commit comments