File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -305,7 +305,7 @@ def run(config):
305305 _ = [stack .enter_context (nn .intercept_methods (interceptor )) for interceptor in lora_interceptors ]
306306 images = p_run_inference (states ).block_until_ready ()
307307 print ("inference time: " , (time .time () - s ))
308- images = jax .experimental .multihost_utils .process_allgather (images )
308+ images = jax .experimental .multihost_utils .process_allgather (images , tiled = True )
309309 numpy_images = np .array (images )
310310 images = VaeImageProcessor .numpy_to_pil (numpy_images )
311311 for i , image in enumerate (images ):
Original file line number Diff line number Diff line change @@ -335,8 +335,8 @@ def test_full_loop_no_noise(self):
335335 result_mean = jnp .mean (jnp .abs (sample ))
336336
337337 if jax_device == "tpu" :
338- assert abs (result_sum - 251.26245 ) < 1e-2
339- assert abs (result_mean - 0.32716465 ) < 1e-3
338+ assert abs (result_sum - 257.2727 ) < 1e-2
339+ assert abs (result_mean - 0.3349905 ) < 1e-3
340340 else :
341341 assert abs (result_sum - 255.1113 ) < 1e-2
342342 assert abs (result_mean - 0.332176 ) < 1e-3
You can’t perform that action at this time.
0 commit comments