diff --git a/src/maxdiffusion/generate_flux.py b/src/maxdiffusion/generate_flux.py index 59564a271..e6e5bcb0e 100644 --- a/src/maxdiffusion/generate_flux.py +++ b/src/maxdiffusion/generate_flux.py @@ -469,7 +469,7 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep max_logging.log(f"Compile time: {t1 - t0:.1f}s.") t0 = time.perf_counter() - with ExitStack() as stack, jax.profiler.trace("/home/jfacevedo/trace/"): + with ExitStack() as stack, jax.profiler.trace("/tmp/trace/"): _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors] imgs = p_run_inference(states).block_until_ready() t1 = time.perf_counter() diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index 2d37e9416..3dff39e3c 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -424,8 +424,6 @@ def setup_initial_state( if model_params: state = state.replace(params=model_params) state = jax.device_put(state, state_mesh_shardings) - if model_params: - state = state.replace(params=model_params) state = unbox_logicallypartioned_trainstate(state)