Skip to content

Commit 3bedc5d

Browse files
lint.
1 parent cd031b8 commit 3bedc5d

18 files changed

Lines changed: 1814 additions & 1945 deletions

src/maxdiffusion/checkpointing/wan_checkpointer.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,35 +15,35 @@
1515
"""
1616

1717
from abc import ABC
18-
import jax
1918
from flax import nnx
2019
from maxdiffusion.checkpointing.checkpointing_utils import (create_orbax_checkpoint_manager)
2120
from ..pipelines.wan.wan_pipeline import WanPipeline
2221
from .. import max_logging, max_utils
2322

2423
WAN_CHECKPOINT = "WAN_CHECKPOINT"
2524

25+
2626
class WanCheckpointer(ABC):
27+
2728
def __init__(self, config, checkpoint_type):
2829
self.config = config
2930
self.checkpoint_type = checkpoint_type
3031

3132
self.checkpoint_manager = create_orbax_checkpoint_manager(
32-
self.config.checkpoint_dir,
33-
enable_checkpointing=True,
34-
save_interval_steps=1,
35-
checkpoint_type=checkpoint_type,
36-
dataset_type=config.dataset_type
33+
self.config.checkpoint_dir,
34+
enable_checkpointing=True,
35+
save_interval_steps=1,
36+
checkpoint_type=checkpoint_type,
37+
dataset_type=config.dataset_type,
3738
)
38-
39+
3940
def _create_optimizer(self, model, config, learning_rate):
4041
learning_rate_scheduler = max_utils.create_learning_rate_schedule(
41-
learning_rate, config.learning_rate_schedule_steps, config.warmup_steps_fraction, config.max_train_steps
42+
learning_rate, config.learning_rate_schedule_steps, config.warmup_steps_fraction, config.max_train_steps
4243
)
4344
tx = max_utils.create_optimizer(config, learning_rate_scheduler)
4445
return nnx.Optimizer(model, tx), learning_rate_scheduler
4546

46-
4747
def load_wan_configs_from_orbax(self, step):
4848
max_logging.log("Restoring stable diffusion configs")
4949
if step is None:
@@ -59,8 +59,8 @@ def load_checkpoint(self, step=None):
5959
model_configs = self.load_wan_configs_from_orbax(step)
6060

6161
if model_configs:
62-
raise NotImplemented("model configs should not exist in orbax")
62+
raise NotImplementedError("model configs should not exist in orbax")
6363
else:
6464
pipeline = self.load_diffusers_checkpoint()
65-
66-
return pipeline
65+
66+
return pipeline

src/maxdiffusion/generate_wan.py

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -20,38 +20,21 @@
2020
from absl import app
2121
from maxdiffusion.utils import export_to_video
2222

23+
2324
def run(config):
2425
print("seed: ", config.seed)
2526
pipeline = WanPipeline.from_pretrained(config)
2627
s0 = time.perf_counter()
27-
28+
2829
# Skip layer guidance
2930
slg_layers = config.slg_layers
3031
slg_start = config.slg_start
3132
slg_end = config.slg_end
3233

3334
prompt = [config.prompt] * jax.device_count()
34-
negative_prompt= [config.negative_prompt] * jax.device_count()
35-
36-
videos = pipeline(
37-
prompt=prompt,
38-
negative_prompt=negative_prompt,
39-
height=config.height,
40-
width=config.width,
41-
num_frames=config.num_frames,
42-
num_inference_steps=config.num_inference_steps,
43-
guidance_scale=config.guidance_scale,
44-
slg_layers=slg_layers,
45-
slg_start=slg_start,
46-
slg_end=slg_end
47-
)
35+
negative_prompt = [config.negative_prompt] * jax.device_count()
4836

49-
print("compile time: ", (time.perf_counter() - s0))
50-
for i in range(len(videos)):
51-
export_to_video(videos[i], f"wan_output_{config.seed}_{i}.mp4", fps=config.fps)
52-
s0 = time.perf_counter()
53-
with jax.profiler.trace("/tmp/trace/"):
54-
videos = pipeline(
37+
videos = pipeline(
5538
prompt=prompt,
5639
negative_prompt=negative_prompt,
5740
height=config.height,
@@ -61,7 +44,25 @@ def run(config):
6144
guidance_scale=config.guidance_scale,
6245
slg_layers=slg_layers,
6346
slg_start=slg_start,
64-
slg_end=slg_end
47+
slg_end=slg_end,
48+
)
49+
50+
print("compile time: ", (time.perf_counter() - s0))
51+
for i in range(len(videos)):
52+
export_to_video(videos[i], f"wan_output_{config.seed}_{i}.mp4", fps=config.fps)
53+
s0 = time.perf_counter()
54+
with jax.profiler.trace("/tmp/trace/"):
55+
videos = pipeline(
56+
prompt=prompt,
57+
negative_prompt=negative_prompt,
58+
height=config.height,
59+
width=config.width,
60+
num_frames=config.num_frames,
61+
num_inference_steps=config.num_inference_steps,
62+
guidance_scale=config.guidance_scale,
63+
slg_layers=slg_layers,
64+
slg_start=slg_start,
65+
slg_end=slg_end,
6566
)
6667
print("generation time: ", (time.perf_counter() - s0))
6768
for i in range(len(videos)):
@@ -74,4 +75,4 @@ def main(argv: Sequence[str]) -> None:
7475

7576

7677
if __name__ == "__main__":
77-
app.run(main)
78+
app.run(main)

src/maxdiffusion/maxdiffusion_utils.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -286,21 +286,23 @@ def get_dummy_flux_inputs(config, pipeline, batch_size):
286286

287287
return (latents, timesteps, latents_ids, guidance_vec, t5_hidden_states, t5_ids, clip_hidden_states)
288288

289+
289290
def get_dummy_wan_inputs(config, pipeline, batch_size):
290291
latents = pipeline.prepare_latents(
291-
batch_size,
292-
vae_scale_factor_temporal=pipeline.vae_scale_factor_temporal,
293-
vae_scale_factor_spatial=pipeline.vae_scale_factor_spatial,
294-
height=config.height,
295-
width=config.width,
296-
num_frames=config.num_frames,
297-
num_channels_latents=pipeline.transformer.config.in_channels
292+
batch_size,
293+
vae_scale_factor_temporal=pipeline.vae_scale_factor_temporal,
294+
vae_scale_factor_spatial=pipeline.vae_scale_factor_spatial,
295+
height=config.height,
296+
width=config.width,
297+
num_frames=config.num_frames,
298+
num_channels_latents=pipeline.transformer.config.in_channels,
298299
)
299300
bsz = latents.shape[0]
300301
prompt_embeds = jax.random.normal(jax.random.key(config.seed), (batch_size, 512, 4096))
301302
timesteps = jnp.array([0] * bsz, dtype=jnp.int32)
302303
return (latents, prompt_embeds, timesteps)
303304

305+
304306
def calculate_wan_tflops(config, pipeline, batch_size, rngs, train):
305307
"""
306308
Calculates jflux tflops.
@@ -309,10 +311,10 @@ def calculate_wan_tflops(config, pipeline, batch_size, rngs, train):
309311
"""
310312
(latents, prompt_embeds, timesteps) = get_dummy_wan_inputs(config, pipeline, batch_size)
311313
return max_utils.calculate_model_tflops(
312-
pipeline.transformer,
313-
314+
pipeline.transformer,
314315
)
315316

317+
316318
def calculate_flux_tflops(config, pipeline, batch_size, rngs, train):
317319
"""
318320
Calculates jflux tflops.

0 commit comments

Comments
 (0)