Commit 9aae51d
Integrate Inference Framework and Serving Stack
- Created maxdiffusion.inference module for unified inference logic.
- Extracted loading logic into standalone FluxLoader and SDXLLoader.
- Implemented InferenceLoader facade and DiffusionRunner for standardized execution.
- Ported production serving stack (FastAPI, ZMQ Scheduler/Worker) to maxdiffusion.inference.server.
- Refactored generate_flux.py and generate_sdxl.py to use the new framework.
1 parent f23746b commit 9aae51d

16 files changed

Lines changed: 1674 additions & 766 deletions

src/maxdiffusion/generate_flux.py

Lines changed: 31 additions & 467 deletions
Large diffs are not rendered by default.

src/maxdiffusion/generate_sdxl.py

Lines changed: 27 additions & 299 deletions
```diff
@@ -14,312 +14,40 @@
 limitations under the License.
 """
 
-import functools
-from absl import app
-from contextlib import ExitStack
 from typing import Sequence
+from absl import app
 import time
-
-import numpy as np
-import jax
-import jax.numpy as jnp
-from jax.sharding import PartitionSpec as P
-import flax.linen as nn
-from flax.linen import partitioning as nn_partitioning
-
-from maxdiffusion import pyconfig, max_utils
-from maxdiffusion.image_processor import VaeImageProcessor
-from maxdiffusion.maxdiffusion_utils import (
-    get_add_time_ids,
-    rescale_noise_cfg,
-    load_sdxllightning_unet,
-    maybe_load_sdxl_lora,
-    create_scheduler,
-)
-
-from maxdiffusion.trainers.sdxl_trainer import (StableDiffusionXLTrainer)
-
-from maxdiffusion.checkpointing.checkpointing_utils import load_params_from_path
-
-
-class GenerateSDXL(StableDiffusionXLTrainer):
-
-  def __init__(self, config):
-    super().__init__(config)
-
-
-def loop_body(step, args, model, pipeline, added_cond_kwargs, prompt_embeds, guidance_scale, guidance_rescale, config):
-  latents, scheduler_state, state = args
-
-  if config.do_classifier_free_guidance:
-    latents_input = jnp.concatenate([latents] * 2)
-  else:
-    latents_input = latents
-
-  t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
-  timestep = jnp.broadcast_to(t, latents_input.shape[0])
-
-  latents_input = pipeline.scheduler.scale_model_input(scheduler_state, latents_input, t)
-  noise_pred = model.apply(
-      {"params": state.params},
-      jnp.array(latents_input),
-      jnp.array(timestep, dtype=jnp.int32),
-      encoder_hidden_states=prompt_embeds,
-      added_cond_kwargs=added_cond_kwargs,
-  ).sample
-
-  def apply_classifier_free_guidance(noise_pred, guidance_scale):
-    noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0)
-    noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
-    return noise_pred, noise_prediction_text
-
-  if config.do_classifier_free_guidance:
-    noise_pred, noise_prediction_text = apply_classifier_free_guidance(noise_pred, guidance_scale)
-
-    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
-    # Helps solve overexposure problem when terminal SNR approaches zero.
-    # Empirical values recomended from the paper are guidance_scale=7.5 and guidance_rescale=0.7
-    noise_pred = jax.lax.cond(
-        guidance_rescale[0] > 0,
-        lambda _: rescale_noise_cfg(noise_pred, noise_prediction_text, guidance_rescale),
-        lambda _: noise_pred,
-        operand=None,
-    )
-  latents, scheduler_state = pipeline.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
-
-  return latents, scheduler_state, state
-
-
-def get_embeddings(prompt_ids, pipeline, params):
-  te_1_inputs = prompt_ids[:, 0, :]
-  te_2_inputs = prompt_ids[:, 1, :]
-
-  prompt_embeds = pipeline.text_encoder(te_1_inputs, params=params["text_encoder"], output_hidden_states=True)
-  prompt_embeds = prompt_embeds["hidden_states"][-2]
-  prompt_embeds_2_out = pipeline.text_encoder_2(te_2_inputs, params=params["text_encoder_2"], output_hidden_states=True)
-  prompt_embeds_2 = prompt_embeds_2_out["hidden_states"][-2]
-  text_embeds = prompt_embeds_2_out["text_embeds"]
-  prompt_embeds = jnp.concatenate([prompt_embeds, prompt_embeds_2], axis=-1)
-  return prompt_embeds, text_embeds
-
-
-def tokenize(prompt, pipeline):
-  inputs = []
-  for _tokenizer in [pipeline.tokenizer, pipeline.tokenizer_2]:
-    text_inputs = _tokenizer(
-        prompt, padding="max_length", max_length=_tokenizer.model_max_length, truncation=True, return_tensors="np"
-    )
-    inputs.append(text_inputs.input_ids)
-  inputs = jnp.stack(inputs, axis=1)
-  return inputs
-
-
-def get_unet_inputs(pipeline, params, states, config, rng, mesh, batch_size):
-  data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding))
-
-  vae_scale_factor = 2 ** (len(pipeline.vae.config.block_out_channels) - 1)
-  prompt_ids = [config.prompt] * batch_size
-  prompt_ids = tokenize(prompt_ids, pipeline)
-  negative_prompt_ids = [config.negative_prompt] * batch_size
-  negative_prompt_ids = tokenize(negative_prompt_ids, pipeline)
-  guidance_scale = config.guidance_scale
-  guidance_rescale = config.guidance_rescale
-  num_inference_steps = config.num_inference_steps
-  height = config.resolution
-  width = config.resolution
-  text_encoder_params = {
-      "text_encoder": states["text_encoder_state"].params,
-      "text_encoder_2": states["text_encoder_2_state"].params,
-  }
-  prompt_embeds, pooled_embeds = get_embeddings(prompt_ids, pipeline, text_encoder_params)
-
-  batch_size = prompt_embeds.shape[0]
-  add_time_ids = get_add_time_ids(
-      (height, width), (0, 0), (height, width), prompt_embeds.shape[0], dtype=prompt_embeds.dtype
-  )
-
-  if config.do_classifier_free_guidance:
-    if negative_prompt_ids is None:
-      negative_prompt_embeds = jnp.zeros_like(prompt_embeds)
-      negative_pooled_embeds = jnp.zeros_like(pooled_embeds)
-    else:
-      negative_prompt_embeds, negative_pooled_embeds = get_embeddings(negative_prompt_ids, pipeline, text_encoder_params)
-
-    prompt_embeds = jnp.concatenate([negative_prompt_embeds, prompt_embeds], axis=0)
-    add_text_embeds = jnp.concatenate([negative_pooled_embeds, pooled_embeds], axis=0)
-    add_time_ids = jnp.concatenate([add_time_ids, add_time_ids], axis=0)
-
-  else:
-    add_text_embeds = pooled_embeds
-
-  # Ensure model output will be `float32` before going into the scheduler
-  guidance_scale = jnp.array([guidance_scale], dtype=jnp.float32)
-  guidance_rescale = jnp.array([guidance_rescale], dtype=jnp.float32)
-
-  latents_shape = (
-      batch_size,
-      pipeline.unet.config.in_channels,
-      height // vae_scale_factor,
-      width // vae_scale_factor,
-  )
-
-  latents = jax.random.normal(rng, shape=latents_shape, dtype=jnp.float32)
-
-  scheduler_state = pipeline.scheduler.set_timesteps(
-      params["scheduler"], num_inference_steps=num_inference_steps, shape=latents.shape
-  )
-
-  latents = latents * scheduler_state.init_noise_sigma
-
-  added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
-  latents = jax.device_put(latents, data_sharding)
-  prompt_embeds = jax.device_put(prompt_embeds, data_sharding)
-  added_cond_kwargs["text_embeds"] = jax.device_put(added_cond_kwargs["text_embeds"], data_sharding)
-  added_cond_kwargs["time_ids"] = jax.device_put(added_cond_kwargs["time_ids"], data_sharding)
-
-  return latents, prompt_embeds, added_cond_kwargs, guidance_scale, guidance_rescale, scheduler_state
-
-
-def vae_decode(latents, state, pipeline):
-  latents = 1 / pipeline.vae.config.scaling_factor * latents
-  image = pipeline.vae.apply({"params": state.params}, latents, method=pipeline.vae.decode).sample
-  image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
-  return image
-
-
-def run_inference(states, pipeline, params, config, rng, mesh, batch_size):
-  unet_state = states["unet_state"]
-  vae_state = states["vae_state"]
-
-  (latents, prompt_embeds, added_cond_kwargs, guidance_scale, guidance_rescale, scheduler_state) = get_unet_inputs(
-      pipeline, params, states, config, rng, mesh, batch_size
-  )
-
-  loop_body_p = functools.partial(
-      loop_body,
-      model=pipeline.unet,
-      pipeline=pipeline,
-      added_cond_kwargs=added_cond_kwargs,
-      prompt_embeds=prompt_embeds,
-      guidance_scale=guidance_scale,
-      guidance_rescale=guidance_rescale,
-      config=config,
-  )
-  vae_decode_p = functools.partial(vae_decode, pipeline=pipeline)
-
-  with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
-    latents, _, _ = jax.lax.fori_loop(0, config.num_inference_steps, loop_body_p, (latents, scheduler_state, unet_state))
-    image = vae_decode_p(latents, vae_state)
-  return image
-
+from maxdiffusion import pyconfig, max_logging
+from maxdiffusion.inference.loader import InferenceLoader
+from maxdiffusion.inference.runner import DiffusionRunner
 
 def run(config):
-  checkpoint_loader = GenerateSDXL(config)
-  mesh = checkpoint_loader.mesh
-  with mesh:
-    pipeline, params = checkpoint_loader.load_checkpoint()
-
-    noise_scheduler, noise_scheduler_state = create_scheduler(pipeline.scheduler.config, config)
-
-    weights_init_fn = functools.partial(pipeline.unet.init_weights, rng=checkpoint_loader.rng)
-    unboxed_abstract_state, _, _ = max_utils.get_abstract_state(
-        pipeline.unet, None, config, checkpoint_loader.mesh, weights_init_fn, False
-    )
-
-    # load unet params from orbax checkpoint
-    unet_params = load_params_from_path(
-        config, checkpoint_loader.checkpoint_manager, unboxed_abstract_state.params, "unet_state"
-    )
-    if unet_params:
-      params["unet"] = unet_params
-
-    # maybe load lora and create interceptor
-    params, lora_interceptors = maybe_load_sdxl_lora(config, pipeline, params)
-
-    if config.lightning_repo:
-      pipeline, params = load_sdxllightning_unet(config, pipeline, params)
-
-    # Don't restore the full train state, instead, just restore params
-    # and create an inference state.
-    with ExitStack() as stack:
-      _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors]
-      unet_state, unet_state_shardings = max_utils.setup_initial_state(
-          model=pipeline.unet,
-          tx=None,
-          config=config,
-          mesh=checkpoint_loader.mesh,
-          weights_init_fn=weights_init_fn,
-          model_params=None,
-          training=False,
-      )
-    unet_state = unet_state.replace(params=params.get("unet", None))
-    unet_state = jax.device_put(unet_state, unet_state_shardings)
-
-    vae_state, vae_state_shardings = checkpoint_loader.create_vae_state(
-        pipeline, params, checkpoint_item_name="vae_state", is_training=False
-    )
-    with ExitStack() as stack:
-      _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors]
-      text_encoder_state, text_encoder_state_shardings = checkpoint_loader.create_text_encoder_state(
-          pipeline, params, checkpoint_item_name="text_encoder_state", is_training=False
-      )
-
-      text_encoder_2_state, text_encoder_2_state_shardings = checkpoint_loader.create_text_encoder_2_state(
-          pipeline, params, checkpoint_item_name="text_encoder_2_state", is_training=False
-      )
-    states = {}
-    state_shardings = {}
-
-    state_shardings["vae_state"] = vae_state_shardings
-    state_shardings["unet_state"] = unet_state_shardings
-    state_shardings["text_encoder_state"] = text_encoder_state_shardings
-    state_shardings["text_encoder_2_state"] = text_encoder_2_state_shardings
-
-    states["unet_state"] = unet_state
-    states["vae_state"] = vae_state
-    states["text_encoder_state"] = text_encoder_state
-    states["text_encoder_2_state"] = text_encoder_2_state
-
-    pipeline.scheduler = noise_scheduler
-    params["scheduler"] = noise_scheduler_state
-
-    p_run_inference = jax.jit(
-        functools.partial(
-            run_inference,
-            pipeline=pipeline,
-            params=params,
-            config=config,
-            rng=checkpoint_loader.rng,
-            mesh=checkpoint_loader.mesh,
-            batch_size=checkpoint_loader.total_train_batch_size,
-        ),
-        in_shardings=(state_shardings,),
-        out_shardings=None,
-    )
-
-    s = time.time()
-    with ExitStack() as stack:
-      _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors]
-      p_run_inference(states).block_until_ready()
-    print("compile time: ", (time.time() - s))
-    s = time.time()
-    with ExitStack() as stack:
-      _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors]
-      images = p_run_inference(states).block_until_ready()
-    print("inference time: ", (time.time() - s))
-    images = jax.experimental.multihost_utils.process_allgather(images, tiled=True)
-    numpy_images = np.array(images)
-    images = VaeImageProcessor.numpy_to_pil(numpy_images)
-    for i, image in enumerate(images):
-      image.save(f"image_sdxl_{i}.png")
-
-  return images
-
+  # 1. Load Model
+  max_logging.log("Initializing InferenceLoader...")
+  loaded_model = InferenceLoader.load(config)
+
+  # 2. Initialize Runner
+  max_logging.log("Initializing DiffusionRunner...")
+  runner = DiffusionRunner(loaded_model, config)
+
+  # 3. Run Inference
+  max_logging.log("Starting Inference...")
+  t0 = time.perf_counter()
+  pil_images = runner.run()
+  t1 = time.perf_counter()
+  max_logging.log(f"Inference time: {t1 - t0:.2f}s")
+
+  # 4. Save Images
+  for i, image in enumerate(pil_images):
+    save_path = f"image_sdxl_{i}.png"
+    image.save(save_path)
+    max_logging.log(f"Saved image to {save_path}")
+
+  return pil_images
 
 def main(argv: Sequence[str]) -> None:
   pyconfig.initialize(argv)
   run(pyconfig.config)
 
-
 if __name__ == "__main__":
-  app.run(main)
+  app.run(main)
```
Lines changed: 53 additions & 0 deletions
# MaxDiffusion Inference Framework

This module provides a unified, production-ready inference stack for MaxDiffusion models (Wan, Flux, SDXL). It decouples inference from training dependencies and provides a consistent interface for both offline generation scripts and online serving.

## Components

* **`loader`**: Unified model loader. Handles loading weights from `orbax` (MaxDiffusion checkpoints) or HuggingFace/Diffusers (safetensors) without instantiating training `Trainer` classes (see the sketch after this list).
* **`runner`**: Core inference runner. Encapsulates the JAX/TPU mesh, JIT compilation of inference steps, and the denoising loop.
* **`server`**: A high-performance, decoupled serving stack.
  * **Frontend**: FastAPI server handling HTTP requests.
  * **Backend**: ZeroMQ-based Scheduler and TPU Worker.
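For orientation, here is a minimal sketch of how the `InferenceLoader` facade might dispatch to the model-specific loaders. The `FluxLoader` and `SDXLLoader` names come from the commit message; their module paths, constructor signatures, and the `model_name` config key are assumptions made for illustration only.

```python
# Hypothetical sketch of the InferenceLoader facade. Module paths,
# signatures, and the `model_name` config key are assumed here, not
# taken from the actual implementation.
from maxdiffusion.inference.flux_loader import FluxLoader    # assumed path
from maxdiffusion.inference.sdxl_loader import SDXLLoader    # assumed path


class InferenceLoader:
  """Facade that picks the right model-specific loader from the config."""

  _LOADERS = {
      "flux": FluxLoader,
      "sdxl": SDXLLoader,
  }

  @classmethod
  def load(cls, config):
    loader_cls = cls._LOADERS.get(config.model_name)
    if loader_cls is None:
      raise ValueError(f"No inference loader registered for {config.model_name!r}")
    # Each loader returns the loaded pipeline, params, and sharded inference state.
    return loader_cls(config).load()
```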
## Usage

### 1. Offline Generation (Scripts)

The root-level scripts `generate_flux.py` and `generate_sdxl.py` have been refactored to use this framework.

```bash
python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_dev.yml
```
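The same flow is available programmatically. The snippet below mirrors the refactored `generate_sdxl.py` from this commit (see the diff above); the config path and the argv shape passed to `pyconfig.initialize` are illustrative.

```python
# Programmatic equivalent of the generation scripts, mirroring the
# refactored generate_sdxl.py in this commit. The config path is a
# placeholder; pyconfig.initialize expects argv with the program name
# in position 0, as in the absl entry point.
from maxdiffusion import pyconfig
from maxdiffusion.inference.loader import InferenceLoader
from maxdiffusion.inference.runner import DiffusionRunner

pyconfig.initialize(["", "src/maxdiffusion/configs/base_flux_dev.yml"])
config = pyconfig.config

loaded_model = InferenceLoader.load(config)     # pipeline, params, sharded state
runner = DiffusionRunner(loaded_model, config)  # owns the mesh and JIT'd step
pil_images = runner.run()                       # returns a list of PIL images

for i, image in enumerate(pil_images):
  image.save(f"image_{i}.png")
```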
### 2. Online Serving

To start the serving stack:

**Start Scheduler (Backend)**
```bash
python -m maxdiffusion.inference.server.scheduler src/maxdiffusion/configs/base_flux_dev.yml
```

**Start API (Frontend)**
```bash
python -m maxdiffusion.inference.server.api
```

**Send Request**
```bash
curl -X POST http://localhost:8000/generate \
  -H 'Content-Type: application/json' \
  -d '{"prompt": "A photo of a cat", "num_inference_steps": 20}'
```
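Clients that prefer Python over `curl` can use `requests`. The route and request fields match the example above; the response schema is not documented in this commit, so the snippet simply prints whatever JSON the server returns.

```python
# Minimal Python client for the serving stack. Endpoint and request
# fields come from the curl example above; the response schema is
# assumed to be JSON and is printed as-is.
import requests

resp = requests.post(
    "http://localhost:8000/generate",
    json={"prompt": "A photo of a cat", "num_inference_steps": 20},
    timeout=600,  # generation can take a while, especially on a cold start
)
resp.raise_for_status()
print(resp.json())
```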
## Architecture

```
InferenceLoader -> [Pipeline, Params, State] -> DiffusionRunner -> [Images]
                                                       ^
                                                       |
                                                  TPU Worker
                                                       ^
                                                       | ZMQ
                                                       v
                                        Scheduler <-> API <-> User
```
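To make the ZMQ hop in the diagram concrete, below is a deliberately simplified sketch of a scheduler-to-worker round trip using pyzmq's REQ/REP pattern. The actual `maxdiffusion.inference.server` code may use different socket types, endpoints, and message formats; everything here, including the port and the `run_model` stub, is illustrative only.

```python
# Simplified illustration of the Scheduler <-> TPU Worker hop over ZeroMQ.
# Socket pattern, endpoint, and message format are assumptions, not the
# actual maxdiffusion.inference.server protocol.
import zmq

ENDPOINT = "tcp://127.0.0.1:5555"  # assumed endpoint


def run_model(job: dict) -> list:
  """Stub standing in for the DiffusionRunner call on the TPU."""
  return [f"<image for {job['prompt']!r}>"]


def worker() -> None:
  """TPU worker: block on a job, run the model, send the result back."""
  sock = zmq.Context.instance().socket(zmq.REP)
  sock.bind(ENDPOINT)
  while True:
    job = sock.recv_json()  # e.g. {"prompt": ..., "num_inference_steps": ...}
    sock.send_json({"images": run_model(job)})


def schedule(job: dict) -> dict:
  """Scheduler side: forward one job to the worker and wait for the reply."""
  sock = zmq.Context.instance().socket(zmq.REQ)
  sock.connect(ENDPOINT)
  sock.send_json(job)
  return sock.recv_json()
```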
Lines changed: 3 additions & 0 deletions
```python
"""Inference module for MaxDiffusion."""
from .loader import InferenceLoader
from .runner import DiffusionRunner
```
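Because the package `__init__.py` re-exports both entry points, callers can import them from the package root instead of the submodules:

```python
# Equivalent to the submodule imports used in generate_sdxl.py above.
from maxdiffusion.inference import InferenceLoader, DiffusionRunner
```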
