Skip to content

Commit 7ff2c23

Browse files
try with cuda 12 on TE cu12
1 parent ab478a1 commit 7ff2c23

1 file changed

Lines changed: 3 additions & 2 deletions

File tree

.github/workflows/UnitTests.yml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ jobs:
6969
# IMPORTANT: Replace with the label for your runner (e.g., v5p-8)
7070
runs-on: ["linux-x86-a3-megagpu-h100-8gpu"]
7171
container:
72-
image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:jax0.7.2_rev4_gpu
72+
image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:jax0.7.2_cuda12_tecu12_gpu
7373
steps:
7474
- name: Checkout Repository
7575
uses: actions/checkout@v4
@@ -89,7 +89,8 @@ jobs:
8989
9090
- name: Print dependencies
9191
run: |
92-
# pip uninstall -y transformer-engine transformer-engine-jax transformer-engine-cu12
92+
pip uninstall -y transformer-engine transformer-engine-jax transformer-engine-cu12
93+
pip install transformer_engine[jax]==2.6.0
9394
# pip install -U transformer-engine[jax]==2.6.0
9495
# pip uninstall -y transformer-engine-cu12
9596
# pip install transformer-engine-cu13

0 commit comments

Comments
 (0)