Skip to content

Commit 52d415b

Browse files
committed
ruff checks
1 parent f9cc8f8 commit 52d415b

9 files changed

Lines changed: 33 additions & 34 deletions

File tree

src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p1.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -92,4 +92,4 @@ def config_to_json(model_or_config):
9292

9393
# Save the checkpoint
9494
self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))
95-
max_logging.log(f"Checkpoint for step {train_step} saved.")
95+
max_logging.log(f"Checkpoint for step {train_step} saved.")

src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p2.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -111,4 +111,4 @@ def config_to_json(model_or_config):
111111

112112
# Save the checkpoint
113113
self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))
114-
max_logging.log(f"Checkpoint for step {train_step} saved.")
114+
max_logging.log(f"Checkpoint for step {train_step} saved.")

src/maxdiffusion/loaders/lora_conversion_utils.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -391,7 +391,7 @@ def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
391391
ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
392392
if not is_sparse:
393393
# down_weight is copied to each split
394-
ait_sd.update({k: down_weight for k in ait_down_keys})
394+
ait_sd.update(dict.fromkeys(ait_down_keys, down_weight))
395395

396396
# up_weight is split to each split
397397
ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416
@@ -534,7 +534,7 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
534534
ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
535535

536536
# down_weight is copied to each split
537-
ait_sd.update({k: down_weight for k in ait_down_keys})
537+
ait_sd.update(dict.fromkeys(ait_down_keys, down_weight))
538538

539539
# up_weight is split to each split
540540
ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416

src/maxdiffusion/models/attention_flax.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1010,7 +1010,7 @@ def __init__(
10101010
dtype=dtype, param_dtype=weights_dtype, precision=precision,
10111011
bias_init=nnx.with_partitioning(
10121012
nnx.initializers.zeros,
1013-
("embed",),
1013+
("embed",),
10141014
),
10151015
)
10161016
self.add_v_proj = nnx.Linear(
@@ -1129,7 +1129,7 @@ def __call__(
11291129
encoder_hidden_states_img = None
11301130
encoder_hidden_states_text = encoder_hidden_states
11311131
encoder_attention_mask_img = None
1132-
1132+
11331133
if self.qk_norm:
11341134
with self.conditional_named_scope("attn_q_norm"):
11351135
query_proj_text = self.norm_q(query_proj_raw)

src/maxdiffusion/models/wan/wan_utils.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -277,7 +277,7 @@ def load_base_wan_transformer(
277277
if "norm1" in renamed_pt_key or "norm2" in renamed_pt_key:
278278
renamed_pt_key = renamed_pt_key.replace("weight", "scale")
279279
renamed_pt_key = renamed_pt_key.replace("kernel", "scale")
280-
280+
281281
renamed_pt_key = renamed_pt_key.replace("blocks_", "blocks.")
282282
renamed_pt_key = renamed_pt_key.replace(".scale_shift_table", ".adaln_scale_shift_table")
283283
renamed_pt_key = renamed_pt_key.replace("to_out_0", "proj_attn")

src/maxdiffusion/pipelines/pipeline_flax_utils.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -473,7 +473,7 @@ def load_module(name, value):
473473
class_obj = import_flax_or_no_model(pipeline_module, class_name)
474474

475475
importable_classes = ALL_IMPORTABLE_CLASSES
476-
class_candidates = {c: class_obj for c in importable_classes.keys()}
476+
class_candidates = dict.fromkeys(importable_classes.keys(), class_obj)
477477
else:
478478
# else we just import it from the library.
479479

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -399,7 +399,7 @@ def load_scheduler(cls, config):
399399
flow_shift=config.flow_shift, # 5.0 for 720p, 3.0 for 480p
400400
)
401401
return scheduler, scheduler_state
402-
402+
403403
def encode_image(self, image: PipelineImageInput, num_videos_per_prompt: int = 1):
404404
if not isinstance(image, list):
405405
image = [image]
@@ -516,7 +516,7 @@ def prepare_latents_i2v_base(
516516
"""
517517
height, width = image.shape[-2:]
518518
image = image[:, :, jnp.newaxis, :, :] # [B, C, 1, H, W]
519-
519+
520520
if last_image is None:
521521
video_condition = jnp.concatenate(
522522
[image, jnp.zeros((image.shape[0], image.shape[1], num_frames - 1, height, width), dtype=image.dtype)], axis=2
@@ -574,7 +574,7 @@ def _create_common_components(cls, config, vae_only=False, i2v=False):
574574
"vae": wan_vae, "vae_cache": vae_cache,
575575
"devices_array": devices_array, "rngs": rngs, "mesh": mesh,
576576
"tokenizer": None, "text_encoder": None, "scheduler": None, "scheduler_state": None,
577-
"image_processor": None, "image_encoder": None
577+
"image_processor": None, "image_encoder": None
578578
}
579579

580580
if not vae_only:
@@ -621,7 +621,7 @@ def _prepare_model_inputs_i2v(
621621
# 2. Encode Image (only for WAN 2.1 I2V which uses CLIP image embeddings)
622622
# WAN 2.2 I2V does not use CLIP image embeddings, it uses VAE latent conditioning instead
623623
transformer_dtype = self.config.activations_dtype
624-
624+
625625
if self.config.model_name == "wan2.1":
626626
# WAN 2.1 I2V: Use CLIP image encoder
627627
if image_embeds is None:
@@ -635,7 +635,7 @@ def _prepare_model_inputs_i2v(
635635

636636
if batch_size > 1:
637637
image_embeds = jnp.tile(image_embeds, (batch_size, 1, 1))
638-
638+
639639
image_embeds = image_embeds.astype(transformer_dtype)
640640
else:
641641
# WAN 2.2 I2V: No CLIP image embeddings, set to None or empty tensor

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py

Lines changed: 13 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,6 @@
1919
from typing import List, Union, Optional, Tuple
2020
from ...pyconfig import HyperParameters
2121
from functools import partial
22-
import numpy as np
2322
from flax import nnx
2423
from flax.linen import partitioning as nn_partitioning
2524
import jax
@@ -88,7 +87,7 @@ def prepare_latents(
8887
last_image: Optional[jax.Array] = None,
8988
num_videos_per_prompt: int = 1,
9089
) -> Tuple[jax.Array, jax.Array, Optional[jax.Array]]:
91-
90+
9291
if hasattr(image, "detach"):
9392
image = image.detach().cpu().numpy()
9493
image = jnp.array(image)
@@ -97,12 +96,12 @@ def prepare_latents(
9796
if hasattr(last_image, "detach"):
9897
last_image = last_image.detach().cpu().numpy()
9998
last_image = jnp.array(last_image)
100-
99+
101100
if num_videos_per_prompt > 1:
102101
image = jnp.repeat(image, num_videos_per_prompt, axis=0)
103102
if last_image is not None:
104103
last_image = jnp.repeat(last_image, num_videos_per_prompt, axis=0)
105-
104+
106105
num_channels_latents = self.vae.z_dim
107106
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
108107
latent_height = height // self.vae_scale_factor_spatial
@@ -119,16 +118,16 @@ def prepare_latents(
119118
if last_image is None:
120119
mask_lat_size = mask_lat_size.at[:, :, 1:, :, :].set(0)
121120
else:
122-
mask_lat_size = mask_lat_size.at[:, :, 1:-1, :, :].set(0)
121+
mask_lat_size = mask_lat_size.at[:, :, 1:-1, :, :].set(0)
123122
first_frame_mask = mask_lat_size[:, :, 0:1]
124123
first_frame_mask = jnp.repeat(first_frame_mask, self.vae_scale_factor_temporal, axis=2)
125124
mask_lat_size = jnp.concatenate([first_frame_mask, mask_lat_size[:, :, 1:]], axis=2)
126125
mask_lat_size = mask_lat_size.reshape(
127-
batch_size,
126+
batch_size,
128127
1,
129-
num_latent_frames,
130-
self.vae_scale_factor_temporal,
131-
latent_height,
128+
num_latent_frames,
129+
self.vae_scale_factor_temporal,
130+
latent_height,
132131
latent_width
133132
)
134133
mask_lat_size = jnp.transpose(mask_lat_size, (0, 2, 4, 5, 3, 1)).squeeze(-1)
@@ -210,7 +209,7 @@ def _process_image_input(img_input, height, width, num_videos_per_prompt):
210209
scheduler_state = self.scheduler.set_timesteps(
211210
self.scheduler_state, num_inference_steps=num_inference_steps, shape=latents.shape
212211
)
213-
212+
214213
graphdef, state, rest_of_state = nnx.split(self.transformer, nnx.Param, ...)
215214
data_sharding = NamedSharding(self.mesh, P())
216215
if self.config.global_batch_size_to_train_on // self.config.per_device_batch_size == 0:
@@ -234,7 +233,7 @@ def _process_image_input(img_input, height, width, num_videos_per_prompt):
234233
scheduler=self.scheduler,
235234
)
236235

237-
236+
238237
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
239238
latents = p_run_inference(
240239
latents=latents,
@@ -246,7 +245,7 @@ def _process_image_input(img_input, height, width, num_videos_per_prompt):
246245
)
247246
latents = jnp.transpose(latents, (0, 4, 1, 2, 3))
248247
latents = self._denormalize_latents(latents)
249-
248+
250249
if output_type == "latent":
251250
return latents
252251
return self._decode_latents_to_video(latents)
@@ -287,5 +286,5 @@ def run_inference_2_1_i2v(
287286
encoder_hidden_states_image=image_embeds,
288287
)
289288
noise_pred = jnp.transpose(noise_pred, (0, 2, 3, 4, 1))
290-
latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents, return_dict=False)
291-
return latents
289+
latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents, return_dict=False)
290+
return latents

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -87,7 +87,7 @@ def prepare_latents(
8787
last_image: Optional[jax.Array] = None,
8888
num_videos_per_prompt: int = 1,
8989
) -> Tuple[jax.Array, jax.Array, Optional[jax.Array]]:
90-
90+
9191
if hasattr(image, "detach"):
9292
image = image.detach().cpu().numpy()
9393
image = jnp.array(image)
@@ -109,12 +109,12 @@ def prepare_latents(
109109
else:
110110
latents = latents.astype(dtype)
111111

112-
latent_condition, _ = self.prepare_latents_i2v_base(image, num_frames, dtype, last_image)
112+
latent_condition, _ = self.prepare_latents_i2v_base(image, num_frames, dtype, last_image)
113113
mask_lat_size = jnp.ones((batch_size, 1, num_frames, latent_height, latent_width), dtype=dtype)
114114
if last_image is None:
115115
mask_lat_size = mask_lat_size.at[:, :, 1:, :, :].set(0)
116116
else:
117-
mask_lat_size = mask_lat_size.at[:, :, 1:-1, :, :].set(0)
117+
mask_lat_size = mask_lat_size.at[:, :, 1:-1, :, :].set(0)
118118

119119
first_frame_mask = mask_lat_size[:, :, 0:1]
120120
first_frame_mask = jnp.repeat(first_frame_mask, self.vae_scale_factor_temporal, axis=2)
@@ -123,9 +123,9 @@ def prepare_latents(
123123
batch_size, 1, num_latent_frames, self.vae_scale_factor_temporal, latent_height, latent_width
124124
)
125125
mask_lat_size = jnp.transpose(mask_lat_size, (0, 2, 4, 5, 3, 1)).squeeze(-1)
126-
condition = jnp.concatenate([mask_lat_size, latent_condition], axis=-1)
126+
condition = jnp.concatenate([mask_lat_size, latent_condition], axis=-1)
127127
return latents, condition, None
128-
128+
129129
def __call__(
130130
self,
131131
prompt: Union[str, List[str]],
@@ -297,7 +297,7 @@ def low_noise_branch(operands):
297297
latents_input = jnp.concatenate([latents, latents], axis=0)
298298
latent_model_input = jnp.concatenate([latents_input, condition], axis=-1)
299299
timestep = jnp.broadcast_to(t, latents_input.shape[0])
300-
300+
301301
use_high_noise = jnp.greater_equal(t, boundary)
302302
noise_pred, _ = jax.lax.cond(
303303
use_high_noise,
@@ -307,4 +307,4 @@ def low_noise_branch(operands):
307307
)
308308
noise_pred = jnp.transpose(noise_pred, (0, 2, 3, 4, 1))
309309
latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
310-
return latents
310+
return latents

0 commit comments

Comments
 (0)