Skip to content

Commit dfe8fc5

Browse files
Use TE 2.6.0
1 parent a1da601 commit dfe8fc5

1 file changed

Lines changed: 15 additions & 15 deletions

File tree

.github/workflows/UnitTests.yml

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ jobs:
4141
- name: Print dependencies
4242
run: |
4343
pip uninstall -y transformer-engine transformer-engine-jax
44-
pip install -U transformer-engine[jax]==2.5.0
44+
pip install -U transformer-engine[jax]==2.6.0
4545
pip freeze
4646
4747
- name: Run MaxText Training
@@ -81,20 +81,20 @@ jobs:
8181
# - name: Run MaxDiffusion Training
8282
# run: |
8383
# # This command is adapted from your DAG for a single-slice configuration.
84-
# NVTE_FUSED_ATTN=1 pip install . && \
85-
# python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \
86-
# hardware=gpu \
87-
# train_new_unet=true \
88-
# train_text_encoder=false \
89-
# cache_latents_text_encoder_outputs=true \
90-
# per_device_batch_size=1 \
91-
# attention=dot_product \
92-
# activations_dtype=bfloat16 \
93-
# weights_dtype=bfloat16 \
94-
# max_train_steps=200 \
95-
# enable_profiler=True \
96-
# run_name=1slice-VGpuVersion.XPK_H100_a3-maxdiffusion-jax-stable-stack-2025-09-26-04-12-02 \
97-
# output_dir=gs://rbierneni-multipod-dev/${{ github.run_id }}
84+
# NVTE_FUSED_ATTN=1 pip install . && \
85+
# python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \
86+
# hardware=gpu \
87+
# train_new_unet=true \
88+
# train_text_encoder=false \
89+
# cache_latents_text_encoder_outputs=true \
90+
# per_device_batch_size=1 \
91+
# attention=dot_product \
92+
# activations_dtype=bfloat16 \
93+
# weights_dtype=bfloat16 \
94+
# max_train_steps=200 \
95+
# enable_profiler=True \
96+
# run_name=1slice-VGpuVersion.XPK_H100_a3-maxdiffusion-jax-stable-stack-2025-09-26-04-12-02 \
97+
# output_dir=gs://rbierneni-multipod-dev/${{ github.run_id }}
9898

9999
# jobs:
100100
# build:

0 commit comments

Comments
 (0)