Skip to content

Commit 5494644

Browse files
Halves inference time.
1 parent 3bedc5d commit 5494644

4 files changed

Lines changed: 40 additions & 15 deletions

File tree

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,8 @@ mesh_axes: ['data', 'fsdp', 'tensor']
112112
# conv_out : conv.shape[-1] weight
113113
logical_axis_rules: [
114114
['batch', 'data'],
115+
['activation_heads', 'fsdp'],
115116
['activation_batch', ['data','fsdp']],
116-
['activation_heads', 'tensor'],
117117
['activation_kv', 'tensor'],
118118
['mlp','tensor'],
119119
['embed','fsdp'],
@@ -182,6 +182,8 @@ num_train_epochs: 1
182182
seed: 0
183183
output_dir: 'sdxl-model-finetuned'
184184
per_device_batch_size: 1
185+
# If global_batch_size % jax.device_count is not 0, use FSDP sharding.
186+
global_batch_size: 0
185187

186188
warmup_steps_fraction: 0.1
187189
learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps.

src/maxdiffusion/generate_wan.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import jax
1717
import time
1818
from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline
19-
from maxdiffusion import pyconfig
19+
from maxdiffusion import pyconfig, max_logging
2020
from absl import app
2121
from maxdiffusion.utils import export_to_video
2222

@@ -30,9 +30,17 @@ def run(config):
3030
slg_layers = config.slg_layers
3131
slg_start = config.slg_start
3232
slg_end = config.slg_end
33+
# If global_batch_size % jax.device_count is not 0, use FSDP sharding.
34+
global_batch_size = config.global_batch_size
35+
if global_batch_size != 0:
36+
batch_multiplier = global_batch_size
37+
else:
38+
batch_multiplier = jax.device_count() * config.per_device_batch_size
3339

34-
prompt = [config.prompt] * jax.device_count()
35-
negative_prompt = [config.negative_prompt] * jax.device_count()
40+
prompt = [config.prompt] * batch_multiplier
41+
negative_prompt = [config.negative_prompt] * batch_multiplier
42+
43+
max_logging.log(f"Num steps: {config.num_inference_steps}, height: {config.height}, width: {config.width}, frames: {config.num_frames}")
3644

3745
videos = pipeline(
3846
prompt=prompt,
@@ -51,6 +59,23 @@ def run(config):
5159
for i in range(len(videos)):
5260
export_to_video(videos[i], f"wan_output_{config.seed}_{i}.mp4", fps=config.fps)
5361
s0 = time.perf_counter()
62+
videos = pipeline(
63+
prompt=prompt,
64+
negative_prompt=negative_prompt,
65+
height=config.height,
66+
width=config.width,
67+
num_frames=config.num_frames,
68+
num_inference_steps=config.num_inference_steps,
69+
guidance_scale=config.guidance_scale,
70+
slg_layers=slg_layers,
71+
slg_start=slg_start,
72+
slg_end=slg_end,
73+
)
74+
print("generation time: ", (time.perf_counter() - s0))
75+
for i in range(len(videos)):
76+
export_to_video(videos[i], f"wan_output_{config.seed}_{i}.mp4", fps=config.fps)
77+
78+
s0 = time.perf_counter()
5479
with jax.profiler.trace("/tmp/trace/"):
5580
videos = pipeline(
5681
prompt=prompt,
@@ -65,9 +90,6 @@ def run(config):
6590
slg_end=slg_end,
6691
)
6792
print("generation time: ", (time.perf_counter() - s0))
68-
for i in range(len(videos)):
69-
export_to_video(videos[i], f"wan_output_{config.seed}_{i}.mp4", fps=config.fps)
70-
7193

7294
def main(argv: Sequence[str]) -> None:
7395
pyconfig.initialize(argv)

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -311,29 +311,29 @@ def __init__(
311311

312312
def __call__(self, hidden_states: jax.Array, encoder_hidden_states: jax.Array, temb: jax.Array, rotary_emb: jax.Array):
313313
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
314-
(self.scale_shift_table + temb.astype(jnp.float32)), 6, axis=1
314+
(self.scale_shift_table + temb), 6, axis=1
315315
)
316316

317317
# 1. Self-attention
318-
norm_hidden_states = (self.norm1(hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype(
318+
norm_hidden_states = (self.norm1(hidden_states) * (1 + scale_msa) + shift_msa).astype(
319319
hidden_states.dtype
320320
)
321321
attn_output = self.attn1(
322322
hidden_states=norm_hidden_states, encoder_hidden_states=norm_hidden_states, rotary_emb=rotary_emb
323323
)
324-
hidden_states = (hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype(hidden_states.dtype)
324+
hidden_states = (hidden_states + attn_output * gate_msa).astype(hidden_states.dtype)
325325

326326
# 2. Cross-attention
327-
norm_hidden_states = self.norm2(hidden_states.astype(jnp.float32))
327+
norm_hidden_states = self.norm2(hidden_states)
328328
attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
329329
hidden_states = hidden_states + attn_output
330330

331331
# 3. Feed-forward
332-
norm_hidden_states = (self.norm3(hidden_states.astype(jnp.float32)) * (1 + c_scale_msa) + c_shift_msa).astype(
332+
norm_hidden_states = (self.norm3(hidden_states) * (1 + c_scale_msa) + c_shift_msa).astype(
333333
hidden_states.dtype
334334
)
335335
ff_output = self.ffn(norm_hidden_states)
336-
hidden_states = (hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa).astype(
336+
hidden_states = (hidden_states + ff_output * c_gate_msa).astype(
337337
hidden_states.dtype
338338
)
339339
return hidden_states
@@ -485,7 +485,7 @@ def __call__(
485485
)
486486
shift, scale = jnp.split(self.scale_shift_table + jnp.expand_dims(temb, axis=1), 2, axis=1)
487487

488-
hidden_states = (self.norm_out(hidden_states.astype(jnp.float32)) * (1 + scale) + shift).astype(hidden_states.dtype)
488+
hidden_states = (self.norm_out(hidden_states) * (1 + scale) + shift).astype(hidden_states.dtype)
489489
hidden_states = self.proj_out(hidden_states)
490490

491491
hidden_states = hidden_states.reshape(

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import flax
2222
import flax.linen as nn
2323
from flax import nnx
24+
from flax.linen import partitioning as nn_partitioning
2425
from ...pyconfig import HyperParameters
2526
from ... import max_logging
2627
from ... import max_utils
@@ -420,7 +421,7 @@ def __call__(
420421
num_transformer_layers=self.transformer.config.num_layers,
421422
)
422423

423-
with self.mesh:
424+
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
424425
latents = p_run_inference(
425426
graphdef=graphdef,
426427
sharded_state=state,

0 commit comments

Comments (0)