Skip to content

Commit d1263eb

Browse files
try without te
1 parent 1273c3e commit d1263eb

1 file changed

Lines changed: 3 additions & 3 deletions

File tree

.github/workflows/UnitTests.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ jobs:
6969
# IMPORTANT: Replace with the label for your runner (e.g., v5p-8)
7070
runs-on: ["linux-x86-a3-megagpu-h100-8gpu"]
7171
container:
72-
image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:jax0.7.2_cuda13_te2.6.0
72+
image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:
7373
steps:
7474
- name: Checkout Repository
7575
uses: actions/checkout@v4
@@ -89,7 +89,7 @@ jobs:
8989
9090
- name: Print dependencies
9191
run: |
92-
# pip uninstall -y transformer-engine transformer-engine-jax transformer-engine-cu12
92+
pip uninstall -y transformer-engine transformer-engine-jax transformer-engine-cu12
9393
# pip install transformer_engine[jax]==2.4.0
9494
# pip install -U transformer-engine[jax]==2.6.0
9595
# pip uninstall -y transformer-engine-cu12
@@ -108,7 +108,7 @@ jobs:
108108
train_text_encoder=false \
109109
cache_latents_text_encoder_outputs=true \
110110
per_device_batch_size=1 \
111-
attention=dot_product \
111+
attention=cudnn_flash_te \
112112
activations_dtype=bfloat16 \
113113
weights_dtype=bfloat16 \
114114
max_train_steps=200 \

0 commit comments

Comments
 (0)