Skip to content

Commit 9bee595

Browse files
fixes accidental wrong merge. (#158)
Co-authored-by: Juan Acevedo <jfacevedo@google.com>
1 parent 8b77101 commit 9bee595

2 files changed

Lines changed: 1 addition & 3 deletions

File tree

src/maxdiffusion/generate_flux.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -469,7 +469,7 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
469469
max_logging.log(f"Compile time: {t1 - t0:.1f}s.")
470470

471471
t0 = time.perf_counter()
472-
with ExitStack() as stack, jax.profiler.trace("/home/jfacevedo/trace/"):
472+
with ExitStack() as stack, jax.profiler.trace("/tmp/trace/"):
473473
_ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors]
474474
imgs = p_run_inference(states).block_until_ready()
475475
t1 = time.perf_counter()

src/maxdiffusion/max_utils.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -424,8 +424,6 @@ def setup_initial_state(
424424
if model_params:
425425
state = state.replace(params=model_params)
426426
state = jax.device_put(state, state_mesh_shardings)
427-
if model_params:
428-
state = state.replace(params=model_params)
429427

430428
state = unbox_logicallypartioned_trainstate(state)
431429

0 commit comments

Comments
 (0)