Skip to content

Commit 8905362

Browse files
Merge branch 'main' into flux_impl
2 parents 587bc6a + 271ce08 commit 8905362

2 files changed

Lines changed: 3 additions & 3 deletions

File tree

src/maxdiffusion/generate_sdxl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ def run(config):
305305
_ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors]
306306
images = p_run_inference(states).block_until_ready()
307307
print("inference time: ", (time.time() - s))
308-
images = jax.experimental.multihost_utils.process_allgather(images)
308+
images = jax.experimental.multihost_utils.process_allgather(images, tiled=True)
309309
numpy_images = np.array(images)
310310
images = VaeImageProcessor.numpy_to_pil(numpy_images)
311311
for i, image in enumerate(images):

tests/schedulers/test_scheduler_flax.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -335,8 +335,8 @@ def test_full_loop_no_noise(self):
335335
result_mean = jnp.mean(jnp.abs(sample))
336336

337337
if jax_device == "tpu":
338-
assert abs(result_sum - 251.26245) < 1e-2
339-
assert abs(result_mean - 0.32716465) < 1e-3
338+
assert abs(result_sum - 257.2727) < 1e-2
339+
assert abs(result_mean - 0.3349905) < 1e-3
340340
else:
341341
assert abs(result_sum - 255.1113) < 1e-2
342342
assert abs(result_mean - 0.332176) < 1e-3

0 commit comments

Comments
 (0)