|
24 | 24 | workflow_dispatch: |
25 | 25 |
|
26 | 26 | jobs: |
27 | | - # maxtext_workload: |
28 | | - # name: "Run MaxText Workload" |
29 | | - # # IMPORTANT: Replace with the label for your runner (e.g., v5p-8) |
30 | | - # runs-on: ["linux-x86-a3-megagpu-h100-8gpu"] |
31 | | - # container: |
32 | | - # image: gcr.io/tpu-prod-env-multipod/maxtext_stable_stack_candidate_gpu:latest |
33 | | - # steps: |
34 | | - # - name: Checkout MaxText Repo |
35 | | - # uses: actions/checkout@v4 |
36 | | - # with: |
37 | | - # repository: AI-Hypercomputer/maxtext |
38 | | - # path: maxtext |
39 | | - # ref: rbierneni-test-gpu-run |
| 27 | + maxtext_workload: |
| 28 | + name: "Run MaxText Workload" |
| 29 | + # IMPORTANT: Replace with the label for your runner (e.g., v5p-8) |
| 30 | + runs-on: ["linux-x86-a3-megagpu-h100-8gpu"] |
| 31 | + container: |
| 32 | + image: gcr.io/tpu-prod-env-multipod/maxtext_stable_stack_candidate_gpu:latest |
| 33 | + steps: |
| 34 | + - name: Checkout MaxText Repo |
| 35 | + uses: actions/checkout@v4 |
| 36 | + with: |
| 37 | + repository: AI-Hypercomputer/maxtext |
| 38 | + path: maxtext |
| 39 | + ref: rbierneni-test-gpu-run |
40 | 40 |
|
41 | | - # - name: Print dependencies |
42 | | - # run: | |
43 | | - # pip uninstall -y transformer-engine transformer-engine-jax |
44 | | - # pip install -U transformer-engine[jax]==2.6.0 |
45 | | - # pip uninstall -y tensorflow |
46 | | - # pip install tensorflow-cpu |
47 | | - # pip freeze |
| 41 | + - name: Print dependencies |
| 42 | + run: | |
| 43 | + pip uninstall -y transformer-engine transformer-engine-jax |
| 44 | + pip install -U transformer-engine[jax]==2.6.0 |
| 45 | + # pip uninstall -y tensorflow |
| 46 | + # pip install tensorflow-cpu |
| 47 | + pip freeze |
48 | 48 |
|
49 | | - # - name: Run MaxText Training |
50 | | - # run: | |
51 | | - # # This command is adapted from your DAG for a single-slice configuration. |
52 | | - # cd maxtext && \ |
53 | | - # pip install . --no-dependencies |
| 49 | + - name: Run MaxText Training |
| 50 | + run: | |
| 51 | + # This command is adapted from your DAG for a single-slice configuration. |
| 52 | + cd maxtext && \ |
| 53 | + pip install . --no-dependencies |
54 | 54 |
|
55 | | - # export XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 |
56 | | - # export TF_FORCE_GPU_ALLOW_GROWTH=true |
| 55 | + export XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 |
| 56 | + export TF_FORCE_GPU_ALLOW_GROWTH=true |
57 | 57 |
|
58 | | - # python3 -m MaxText.train MaxText/configs/base.yml \ |
59 | | - # steps=2 \ |
60 | | - # enable_checkpointing=false \ |
61 | | - # attention=cudnn_flash_te \ |
62 | | - # dataset_type=synthetic \ |
63 | | - # run_name=rbierneni-test-maxtext-gpu \ |
64 | | - # base_output_directory=gs://rbierneni-multipod-dev/maxtext/${{ github.run_id }} |
| 58 | + python3 -m MaxText.train MaxText/configs/base.yml \ |
| 59 | + steps=2 \ |
| 60 | + enable_checkpointing=false \ |
| 61 | + attention=cudnn_flash_te \ |
| 62 | + dataset_type=synthetic \ |
| 63 | + run_name=rbierneni-test-maxtext-gpu \ |
| 64 | + base_output_directory=gs://rbierneni-multipod-dev/maxtext/${{ github.run_id }} |
65 | 65 |
|
66 | 66 | # STAGE 1: PULL MAXDIFFUSION IMAGE AND RUN WORKLOAD |
67 | 67 | maxdiffusion_workload: |
68 | 68 | name: "Run MaxDiffusion Workload" |
69 | 69 | # IMPORTANT: Replace with the label for your runner (e.g., v5p-8) |
70 | 70 | runs-on: ["linux-x86-a3-megagpu-h100-8gpu"] |
71 | 71 | container: |
72 | | - image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:jax0.7.2_rev1_gpu |
| 72 | + image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:jax0.7.2_rev2_gpu |
73 | 73 | steps: |
74 | 74 | - name: Checkout Repository |
75 | 75 | uses: actions/checkout@v4 |
76 | 76 |
|
77 | 77 | - name: Print dependencies |
78 | 78 | run: | |
79 | 79 | # pip uninstall -y transformer-engine transformer-engine-jax |
80 | | - # pip install -U transformer-engine[pytorch,jax] |
81 | | - pip uninstall -y tensorflow |
82 | | - pip install tensorflow-cpu |
| 80 | + pip install -U transformer-engine[jax]==2.6.0 |
| 81 | + # pip uninstall -y tensorflow |
| 82 | + # pip install tensorflow-cpu |
83 | 83 | pip freeze |
84 | 84 |
|
85 | 85 | - name: Run MaxDiffusion Training |
|
92 | 92 | train_text_encoder=false \ |
93 | 93 | cache_latents_text_encoder_outputs=true \ |
94 | 94 | per_device_batch_size=1 \ |
95 | | - attention=dot_product \ |
| 95 | + attention=cudnn_flash_te \ |
96 | 96 | activations_dtype=bfloat16 \ |
97 | 97 | weights_dtype=bfloat16 \ |
98 | 98 | max_train_steps=200 \ |
|
0 commit comments