diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index f7ee2166c..d2e9a5a2c 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -31,9 +31,9 @@ jobs: strategy: fail-fast: false matrix: - tpu-type: ["v4-8"] + tpu-type: ["v5p-8"] name: "TPU test (${{ matrix.tpu-type }})" - runs-on: ["self-hosted", "tpu", "${{ matrix.tpu-type }}"] + runs-on: ["self-hosted","${{ matrix.tpu-type }}"] steps: - uses: actions/checkout@v4 - name: Set up Python 3.12 diff --git a/.github/workflows/XLML.yml b/.github/workflows/XLML.yml index c9a3bf69b..37f320787 100644 --- a/.github/workflows/XLML.yml +++ b/.github/workflows/XLML.yml @@ -2,7 +2,7 @@ name: Add Testgrid Link to PR on: pull_request: - types: [opened, synchronize] + types: [opened] jobs: add_testgrid_link: diff --git a/tests/schedulers/test_scheduler_flax.py b/tests/schedulers/test_scheduler_flax.py index a63e3a966..974ac3ab3 100644 --- a/tests/schedulers/test_scheduler_flax.py +++ b/tests/schedulers/test_scheduler_flax.py @@ -335,8 +335,8 @@ def test_full_loop_no_noise(self): result_mean = jnp.mean(jnp.abs(sample)) if jax_device == "tpu": - assert abs(result_sum - 257.2727) < 1e-2 - assert abs(result_mean - 0.3349905) < 1e-3 + assert abs(result_sum - 257.2727) < 1.5e-2 + assert abs(result_mean - 0.3349905) < 1e-5 else: assert abs(result_sum - 255.1113) < 1e-2 assert abs(result_mean - 0.332176) < 1e-3