File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 6969 # IMPORTANT: Replace with the label for your runner (e.g., v5p-8)
7070 runs-on : ["linux-x86-a3-megagpu-h100-8gpu"]
7171 container :
72- image : gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:jax0.7.2_cuda13_te2.6.0
72+ image : gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:
7373 steps :
7474 - name : Checkout Repository
7575 uses : actions/checkout@v4
8989
9090 - name : Print dependencies
9191 run : |
92- # pip uninstall -y transformer-engine transformer-engine-jax transformer-engine-cu12
92+ pip uninstall -y transformer-engine transformer-engine-jax transformer-engine-cu12
9393 # pip install transformer_engine[jax]==2.4.0
9494 # pip install -U transformer-engine[jax]==2.6.0
9595 # pip uninstall -y transformer-engine-cu12
@@ -108,7 +108,7 @@ jobs:
108108 train_text_encoder=false \
109109 cache_latents_text_encoder_outputs=true \
110110 per_device_batch_size=1 \
111- attention=dot_product \
111+ attention=cudnn_flash_te \
112112 activations_dtype=bfloat16 \
113113 weights_dtype=bfloat16 \
114114 max_train_steps=200 \
You can’t perform that action at this time.
0 commit comments