Skip to content

Commit 4984504

Browse files
committed
pyink check
1 parent 04c2320 commit 4984504

11 files changed

Lines changed: 289 additions & 324 deletions

File tree

code_style.sh

File mode changed: 100644 → 100755.

src/maxdiffusion/__init__.py

Lines changed: 182 additions & 196 deletions
Large diffs are not rendered by default.

src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py

Lines changed: 42 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -202,29 +202,27 @@ def setup(self):
202202
dtype=self.dtype,
203203
param_dtype=self.weights_dtype,
204204
)
205-
self.img_mlp = nn.Sequential(
206-
[
207-
nn.Dense(
208-
int(self.dim * self.mlp_ratio),
209-
use_bias=True,
210-
kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")),
211-
bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)),
212-
dtype=self.dtype,
213-
param_dtype=self.weights_dtype,
214-
precision=self.precision,
215-
),
216-
nn.gelu,
217-
nn.Dense(
218-
self.dim,
219-
use_bias=True,
220-
kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("mlp", "embed")),
221-
bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)),
222-
dtype=self.dtype,
223-
param_dtype=self.weights_dtype,
224-
precision=self.precision,
225-
),
226-
]
227-
)
205+
self.img_mlp = nn.Sequential([
206+
nn.Dense(
207+
int(self.dim * self.mlp_ratio),
208+
use_bias=True,
209+
kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")),
210+
bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)),
211+
dtype=self.dtype,
212+
param_dtype=self.weights_dtype,
213+
precision=self.precision,
214+
),
215+
nn.gelu,
216+
nn.Dense(
217+
self.dim,
218+
use_bias=True,
219+
kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("mlp", "embed")),
220+
bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)),
221+
dtype=self.dtype,
222+
param_dtype=self.weights_dtype,
223+
precision=self.precision,
224+
),
225+
])
228226

229227
self.txt_norm2 = nn.LayerNorm(
230228
use_bias=False,
@@ -233,29 +231,27 @@ def setup(self):
233231
dtype=self.dtype,
234232
param_dtype=self.weights_dtype,
235233
)
236-
self.txt_mlp = nn.Sequential(
237-
[
238-
nn.Dense(
239-
int(self.dim * self.mlp_ratio),
240-
use_bias=True,
241-
kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")),
242-
bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)),
243-
dtype=self.dtype,
244-
param_dtype=self.weights_dtype,
245-
precision=self.precision,
246-
),
247-
nn.gelu,
248-
nn.Dense(
249-
self.dim,
250-
use_bias=True,
251-
kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("mlp", "embed")),
252-
bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)),
253-
dtype=self.dtype,
254-
param_dtype=self.weights_dtype,
255-
precision=self.precision,
256-
),
257-
]
258-
)
234+
self.txt_mlp = nn.Sequential([
235+
nn.Dense(
236+
int(self.dim * self.mlp_ratio),
237+
use_bias=True,
238+
kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")),
239+
bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)),
240+
dtype=self.dtype,
241+
param_dtype=self.weights_dtype,
242+
precision=self.precision,
243+
),
244+
nn.gelu,
245+
nn.Dense(
246+
self.dim,
247+
use_bias=True,
248+
kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("mlp", "embed")),
249+
bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)),
250+
dtype=self.dtype,
251+
param_dtype=self.weights_dtype,
252+
precision=self.precision,
253+
),
254+
])
259255

260256
# let chunk size default to None
261257
self._chunk_size = None

src/maxdiffusion/models/wan/transformers/transformer_wan_vace.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -460,13 +460,11 @@ def __call__(
460460

461461
control_hidden_states = self.vace_patch_embedding(control_hidden_states)
462462
control_hidden_states = jax.lax.collapse(control_hidden_states, 1, -1)
463-
control_hidden_states_padding = jnp.zeros(
464-
(
465-
batch_size,
466-
control_hidden_states.shape[1],
467-
hidden_states.shape[2] - control_hidden_states.shape[2],
468-
)
469-
)
463+
control_hidden_states_padding = jnp.zeros((
464+
batch_size,
465+
control_hidden_states.shape[1],
466+
hidden_states.shape[2] - control_hidden_states.shape[2],
467+
))
470468

471469
control_hidden_states = jnp.concatenate([control_hidden_states, control_hidden_states_padding], axis=2)
472470

src/maxdiffusion/pedagogical_examples/to_tfrecords.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -54,14 +54,12 @@
5454
dl_manager = tfds.download.DownloadManager(download_dir="/tmp")
5555
tmp_dataset = "dataset"
5656

57-
TRANSFORMS = transforms.Compose(
58-
[
59-
transforms.ToTensor(),
60-
transforms.Resize(size=512, interpolation=transforms.InterpolationMode.BICUBIC),
61-
transforms.CenterCrop(size=512),
62-
transforms.Normalize([0.5], [0.5]),
63-
]
64-
)
57+
TRANSFORMS = transforms.Compose([
58+
transforms.ToTensor(),
59+
transforms.Resize(size=512, interpolation=transforms.InterpolationMode.BICUBIC),
60+
transforms.CenterCrop(size=512),
61+
transforms.Normalize([0.5], [0.5]),
62+
])
6563

6664

6765
def delete_files(path):

src/maxdiffusion/pipelines/__init__.py

Lines changed: 16 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -51,16 +51,14 @@
5151

5252
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_onnx_objects))
5353
else:
54-
_import_structure["stable_diffusion"].extend(
55-
[
56-
"OnnxStableDiffusionImg2ImgPipeline",
57-
"OnnxStableDiffusionInpaintPipeline",
58-
"OnnxStableDiffusionInpaintPipelineLegacy",
59-
"OnnxStableDiffusionPipeline",
60-
"OnnxStableDiffusionUpscalePipeline",
61-
"StableDiffusionOnnxPipeline",
62-
]
63-
)
54+
_import_structure["stable_diffusion"].extend([
55+
"OnnxStableDiffusionImg2ImgPipeline",
56+
"OnnxStableDiffusionInpaintPipeline",
57+
"OnnxStableDiffusionInpaintPipelineLegacy",
58+
"OnnxStableDiffusionPipeline",
59+
"OnnxStableDiffusionUpscalePipeline",
60+
"StableDiffusionOnnxPipeline",
61+
])
6462

6563
try:
6664
if not is_flax_available():
@@ -82,18 +80,14 @@
8280
_import_structure["controlnet"].extend(
8381
["FlaxStableDiffusionControlNetPipeline", "FlaxStableDiffusionXLControlNetPipeline"]
8482
)
85-
_import_structure["stable_diffusion"].extend(
86-
[
87-
"FlaxStableDiffusionImg2ImgPipeline",
88-
"FlaxStableDiffusionInpaintPipeline",
89-
"FlaxStableDiffusionPipeline",
90-
]
91-
)
92-
_import_structure["stable_diffusion_xl"].extend(
93-
[
94-
"FlaxStableDiffusionXLPipeline",
95-
]
96-
)
83+
_import_structure["stable_diffusion"].extend([
84+
"FlaxStableDiffusionImg2ImgPipeline",
85+
"FlaxStableDiffusionInpaintPipeline",
86+
"FlaxStableDiffusionPipeline",
87+
])
88+
_import_structure["stable_diffusion_xl"].extend([
89+
"FlaxStableDiffusionXLPipeline",
90+
])
9791
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
9892
try:
9993
if not is_onnx_available():

src/maxdiffusion/pipelines/stable_diffusion/__init__.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -85,13 +85,11 @@
8585
StableDiffusionPix2PixZeroPipeline,
8686
)
8787

88-
_dummy_objects.update(
89-
{
90-
"StableDiffusionDepth2ImgPipeline": StableDiffusionDepth2ImgPipeline,
91-
"StableDiffusionDiffEditPipeline": StableDiffusionDiffEditPipeline,
92-
"StableDiffusionPix2PixZeroPipeline": StableDiffusionPix2PixZeroPipeline,
93-
}
94-
)
88+
_dummy_objects.update({
89+
"StableDiffusionDepth2ImgPipeline": StableDiffusionDepth2ImgPipeline,
90+
"StableDiffusionDiffEditPipeline": StableDiffusionDiffEditPipeline,
91+
"StableDiffusionPix2PixZeroPipeline": StableDiffusionPix2PixZeroPipeline,
92+
})
9593
else:
9694
_import_structure["pipeline_stable_diffusion_depth2img"] = ["StableDiffusionDepth2ImgPipeline"]
9795
_import_structure["pipeline_stable_diffusion_diffedit"] = ["StableDiffusionDiffEditPipeline"]

src/maxdiffusion/schedulers/scheduling_dpmsolver_multistep_flax.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -528,13 +528,11 @@ def step_2(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray:
528528
)
529529

530530
def step_3(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray:
531-
timestep_list = jnp.array(
532-
[
533-
state.timesteps[step_index - 2],
534-
state.timesteps[step_index - 1],
535-
state.timesteps[step_index],
536-
]
537-
)
531+
timestep_list = jnp.array([
532+
state.timesteps[step_index - 2],
533+
state.timesteps[step_index - 1],
534+
state.timesteps[step_index],
535+
])
538536
return self.multistep_dpm_solver_third_order_update(
539537
state,
540538
state.model_outputs,

src/maxdiffusion/schedulers/scheduling_unipc_multistep_flax.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -136,13 +136,11 @@ def __init__(
136136
if self.config.use_beta_sigmas and not is_scipy_available():
137137
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
138138
if (
139-
sum(
140-
[
141-
self.config.use_beta_sigmas,
142-
self.config.use_exponential_sigmas,
143-
self.config.use_karras_sigmas,
144-
]
145-
)
139+
sum([
140+
self.config.use_beta_sigmas,
141+
self.config.use_exponential_sigmas,
142+
self.config.use_karras_sigmas,
143+
])
146144
> 1
147145
):
148146
raise ValueError(

src/maxdiffusion/schedulers/scheduling_utils_flax.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,8 @@ def create(cls, scheduler):
262262
elif config.beta_schedule == "scaled_linear":
263263
# this schedule is very specific to the latent diffusion model.
264264
betas = (
265-
jnp.linspace(config.beta_start**0.5, config.beta_end**0.5, config.num_train_timesteps, dtype=scheduler.dtype) ** 2
265+
jnp.linspace(config.beta_start**0.5, config.beta_end**0.5, config.num_train_timesteps, dtype=scheduler.dtype)
266+
** 2
266267
)
267268
elif config.beta_schedule == "squaredcos_cap_v2":
268269
# Glide cosine schedule

0 commit comments

Comments (0)