From e1779354e41fed3baab82c1658d3de5b1f12bc2a Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Tue, 18 Mar 2025 03:44:15 +0000 Subject: [PATCH] fixes accidental wrong merge. --- src/maxdiffusion/generate_flux.py | 2 +- src/maxdiffusion/max_utils.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/maxdiffusion/generate_flux.py b/src/maxdiffusion/generate_flux.py index 59564a271..e6e5bcb0e 100644 --- a/src/maxdiffusion/generate_flux.py +++ b/src/maxdiffusion/generate_flux.py @@ -469,7 +469,7 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep max_logging.log(f"Compile time: {t1 - t0:.1f}s.") t0 = time.perf_counter() - with ExitStack() as stack, jax.profiler.trace("/home/jfacevedo/trace/"): + with ExitStack() as stack, jax.profiler.trace("/tmp/trace/"): _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors] imgs = p_run_inference(states).block_until_ready() t1 = time.perf_counter() diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index 2d37e9416..3dff39e3c 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -424,8 +424,6 @@ def setup_initial_state( if model_params: state = state.replace(params=model_params) state = jax.device_put(state, state_mesh_shardings) - if model_params: - state = state.replace(params=model_params) state = unbox_logicallypartioned_trainstate(state)