|
41 | 41 | - name: Print dependencies |
42 | 42 | run: | |
43 | 43 | pip uninstall -y transformer-engine transformer-engine-jax |
44 | | - pip install -U transformer-engine[jax]==2.5.0 |
| 44 | + pip install -U transformer-engine[jax]==2.6.0 |
45 | 45 | pip freeze |
46 | 46 |
|
47 | 47 | - name: Run MaxText Training |
@@ -81,20 +81,20 @@ jobs: |
81 | 81 | # - name: Run MaxDiffusion Training |
82 | 82 | # run: | |
83 | 83 | # # This command is adapted from your DAG for a single-slice configuration. |
84 | | - # NVTE_FUSED_ATTN=1 pip install . && \ |
85 | | - # python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \ |
86 | | - # hardware=gpu \ |
87 | | - # train_new_unet=true \ |
88 | | - # train_text_encoder=false \ |
89 | | - # cache_latents_text_encoder_outputs=true \ |
90 | | - # per_device_batch_size=1 \ |
91 | | - # attention=dot_product \ |
92 | | - # activations_dtype=bfloat16 \ |
93 | | - # weights_dtype=bfloat16 \ |
94 | | - # max_train_steps=200 \ |
95 | | - # enable_profiler=True \ |
96 | | - # run_name=1slice-VGpuVersion.XPK_H100_a3-maxdiffusion-jax-stable-stack-2025-09-26-04-12-02 \ |
97 | | - # output_dir=gs://rbierneni-multipod-dev/${{ github.run_id }} |
| 84 | + # NVTE_FUSED_ATTN=1 pip install . && \ |
| 85 | + # python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \ |
| 86 | + # hardware=gpu \ |
| 87 | + # train_new_unet=true \ |
| 88 | + # train_text_encoder=false \ |
| 89 | + # cache_latents_text_encoder_outputs=true \ |
| 90 | + # per_device_batch_size=1 \ |
| 91 | + # attention=dot_product \ |
| 92 | + # activations_dtype=bfloat16 \ |
| 93 | + # weights_dtype=bfloat16 \ |
| 94 | + # max_train_steps=200 \ |
| 95 | + # enable_profiler=True \ |
| 96 | + # run_name=1slice-VGpuVersion.XPK_H100_a3-maxdiffusion-jax-stable-stack-2025-09-26-04-12-02 \ |
| 97 | + # output_dir=gs://rbierneni-multipod-dev/${{ github.run_id }} |
98 | 98 |
|
99 | 99 | # jobs: |
100 | 100 | # build: |
|
0 commit comments