@@ -335,8 +335,8 @@ def test_full_loop_no_noise(self):
335335 result_mean = jnp .mean (jnp .abs (sample ))
336336
337337 if jax_device == "tpu" :
338- assert abs (result_sum - 263.11 ) < 1.5e-2
339- assert abs (result_mean - 0.34259 ) < 2e-5
338+ assert abs (result_sum - 257.32495 ) < 1.5e-2
339+ assert abs (result_mean - 0.335059 ) < 2e-5
340340 else :
341341 assert abs (result_sum - 255.1113 ) < 1e-2
342342 assert abs (result_mean - 0.332176 ) < 1e-3
@@ -621,7 +621,7 @@ def test_full_loop_with_set_alpha_to_one(self):
621621 result_mean = jnp .mean (jnp .abs (sample ))
622622
623623 if jax_device == "tpu" :
624- assert abs (result_sum - 149.8409 ) < 1e-2
624+ assert abs (result_sum - 149.82944 ) < 1e-2
625625 assert abs (result_mean - 0.1951 ) < 1e-3
626626 else :
627627 assert abs (result_sum - 149.8295 ) < 1e-2
@@ -919,7 +919,7 @@ def test_full_loop_with_set_alpha_to_one(self):
919919 result_mean = jnp .mean (jnp .abs (sample ))
920920
921921 if jax_device == "tpu" :
922- assert abs (result_sum - 186.83226 ) < 8e-2
922+ assert abs (result_sum - 186.94574 ) < 8e-2
923923 assert abs (result_mean - 0.24327 ) < 1e-3
924924 else :
925925 assert abs (result_sum - 186.9466 ) < 1e-2
@@ -932,7 +932,7 @@ def test_full_loop_with_no_set_alpha_to_one(self):
932932 result_mean = jnp .mean (jnp .abs (sample ))
933933
934934 if jax_device == "tpu" :
935- assert abs (result_sum - 186.83226 ) < 8e-2
935+ assert abs (result_sum - 186.94574 ) < 8e-2
936936 assert abs (result_mean - 0.24327 ) < 1e-3
937937 else :
938938 assert abs (result_sum - 186.9482 ) < 1e-2
0 commit comments