Skip to content

Commit 04c2320

Browse files
committed
pyink checks
1 parent 42d9366 commit 04c2320

28 files changed

Lines changed: 377 additions & 326 deletions

src/maxdiffusion/__init__.py

Lines changed: 196 additions & 182 deletions
Large diffs are not rendered by default.

src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py

Lines changed: 46 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -202,27 +202,29 @@ def setup(self):
202202
dtype=self.dtype,
203203
param_dtype=self.weights_dtype,
204204
)
205-
self.img_mlp = nn.Sequential([
206-
nn.Dense(
207-
int(self.dim * self.mlp_ratio),
208-
use_bias=True,
209-
kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")),
210-
bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)),
211-
dtype=self.dtype,
212-
param_dtype=self.weights_dtype,
213-
precision=self.precision,
214-
),
215-
nn.gelu,
216-
nn.Dense(
217-
self.dim,
218-
use_bias=True,
219-
kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("mlp", "embed")),
220-
bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)),
221-
dtype=self.dtype,
222-
param_dtype=self.weights_dtype,
223-
precision=self.precision,
224-
),
225-
])
205+
self.img_mlp = nn.Sequential(
206+
[
207+
nn.Dense(
208+
int(self.dim * self.mlp_ratio),
209+
use_bias=True,
210+
kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")),
211+
bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)),
212+
dtype=self.dtype,
213+
param_dtype=self.weights_dtype,
214+
precision=self.precision,
215+
),
216+
nn.gelu,
217+
nn.Dense(
218+
self.dim,
219+
use_bias=True,
220+
kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("mlp", "embed")),
221+
bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)),
222+
dtype=self.dtype,
223+
param_dtype=self.weights_dtype,
224+
precision=self.precision,
225+
),
226+
]
227+
)
226228

227229
self.txt_norm2 = nn.LayerNorm(
228230
use_bias=False,
@@ -231,27 +233,29 @@ def setup(self):
231233
dtype=self.dtype,
232234
param_dtype=self.weights_dtype,
233235
)
234-
self.txt_mlp = nn.Sequential([
235-
nn.Dense(
236-
int(self.dim * self.mlp_ratio),
237-
use_bias=True,
238-
kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")),
239-
bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)),
240-
dtype=self.dtype,
241-
param_dtype=self.weights_dtype,
242-
precision=self.precision,
243-
),
244-
nn.gelu,
245-
nn.Dense(
246-
self.dim,
247-
use_bias=True,
248-
kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("mlp", "embed")),
249-
bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)),
250-
dtype=self.dtype,
251-
param_dtype=self.weights_dtype,
252-
precision=self.precision,
253-
),
254-
])
236+
self.txt_mlp = nn.Sequential(
237+
[
238+
nn.Dense(
239+
int(self.dim * self.mlp_ratio),
240+
use_bias=True,
241+
kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")),
242+
bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)),
243+
dtype=self.dtype,
244+
param_dtype=self.weights_dtype,
245+
precision=self.precision,
246+
),
247+
nn.gelu,
248+
nn.Dense(
249+
self.dim,
250+
use_bias=True,
251+
kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("mlp", "embed")),
252+
bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)),
253+
dtype=self.dtype,
254+
param_dtype=self.weights_dtype,
255+
precision=self.precision,
256+
),
257+
]
258+
)
255259

256260
# let chunk size default to None
257261
self._chunk_size = None

src/maxdiffusion/models/ltx_video/transformers_pytorch/attention.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
See the License for the specific language governing permissions and
1414
limitations under the License.
1515
"""
16+
1617
import inspect
1718
from importlib import import_module
1819
from typing import Any, Dict, Optional, Tuple

src/maxdiffusion/models/wan/transformers/transformer_wan_vace.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -460,11 +460,13 @@ def __call__(
460460

461461
control_hidden_states = self.vace_patch_embedding(control_hidden_states)
462462
control_hidden_states = jax.lax.collapse(control_hidden_states, 1, -1)
463-
control_hidden_states_padding = jnp.zeros((
464-
batch_size,
465-
control_hidden_states.shape[1],
466-
hidden_states.shape[2] - control_hidden_states.shape[2],
467-
))
463+
control_hidden_states_padding = jnp.zeros(
464+
(
465+
batch_size,
466+
control_hidden_states.shape[1],
467+
hidden_states.shape[2] - control_hidden_states.shape[2],
468+
)
469+
)
468470

469471
control_hidden_states = jnp.concatenate([control_hidden_states, control_hidden_states_padding], axis=2)
470472

src/maxdiffusion/pedagogical_examples/attention_comparison.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
See the License for the specific language governing permissions and
1414
limitations under the License.
1515
"""
16+
1617
import os
1718
import time
1819

src/maxdiffusion/pedagogical_examples/dataset_tf_cache_to_tfrecord.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
See the License for the specific language governing permissions and
1414
limitations under the License.
1515
"""
16+
1617
import os
1718
import argparse
1819
import tensorflow as tf

src/maxdiffusion/pedagogical_examples/parameter_count.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
See the License for the specific language governing permissions and
1414
limitations under the License.
1515
"""
16+
1617
from typing import Sequence
1718
from absl import app
1819
import jax

src/maxdiffusion/pedagogical_examples/to_tfrecords.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,14 @@
5454
dl_manager = tfds.download.DownloadManager(download_dir="/tmp")
5555
tmp_dataset = "dataset"
5656

57-
TRANSFORMS = transforms.Compose([
58-
transforms.ToTensor(),
59-
transforms.Resize(size=512, interpolation=transforms.InterpolationMode.BICUBIC),
60-
transforms.CenterCrop(size=512),
61-
transforms.Normalize([0.5], [0.5]),
62-
])
57+
TRANSFORMS = transforms.Compose(
58+
[
59+
transforms.ToTensor(),
60+
transforms.Resize(size=512, interpolation=transforms.InterpolationMode.BICUBIC),
61+
transforms.CenterCrop(size=512),
62+
transforms.Normalize([0.5], [0.5]),
63+
]
64+
)
6365

6466

6567
def delete_files(path):

src/maxdiffusion/pipelines/__init__.py

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -51,14 +51,16 @@
5151

5252
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_onnx_objects))
5353
else:
54-
_import_structure["stable_diffusion"].extend([
55-
"OnnxStableDiffusionImg2ImgPipeline",
56-
"OnnxStableDiffusionInpaintPipeline",
57-
"OnnxStableDiffusionInpaintPipelineLegacy",
58-
"OnnxStableDiffusionPipeline",
59-
"OnnxStableDiffusionUpscalePipeline",
60-
"StableDiffusionOnnxPipeline",
61-
])
54+
_import_structure["stable_diffusion"].extend(
55+
[
56+
"OnnxStableDiffusionImg2ImgPipeline",
57+
"OnnxStableDiffusionInpaintPipeline",
58+
"OnnxStableDiffusionInpaintPipelineLegacy",
59+
"OnnxStableDiffusionPipeline",
60+
"OnnxStableDiffusionUpscalePipeline",
61+
"StableDiffusionOnnxPipeline",
62+
]
63+
)
6264

6365
try:
6466
if not is_flax_available():
@@ -80,14 +82,18 @@
8082
_import_structure["controlnet"].extend(
8183
["FlaxStableDiffusionControlNetPipeline", "FlaxStableDiffusionXLControlNetPipeline"]
8284
)
83-
_import_structure["stable_diffusion"].extend([
84-
"FlaxStableDiffusionImg2ImgPipeline",
85-
"FlaxStableDiffusionInpaintPipeline",
86-
"FlaxStableDiffusionPipeline",
87-
])
88-
_import_structure["stable_diffusion_xl"].extend([
89-
"FlaxStableDiffusionXLPipeline",
90-
])
85+
_import_structure["stable_diffusion"].extend(
86+
[
87+
"FlaxStableDiffusionImg2ImgPipeline",
88+
"FlaxStableDiffusionInpaintPipeline",
89+
"FlaxStableDiffusionPipeline",
90+
]
91+
)
92+
_import_structure["stable_diffusion_xl"].extend(
93+
[
94+
"FlaxStableDiffusionXLPipeline",
95+
]
96+
)
9197
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
9298
try:
9399
if not is_onnx_available():

src/maxdiffusion/pipelines/controlnet/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
See the License for the specific language governing permissions and
1414
limitations under the License.
1515
"""
16+
1617
from typing import TYPE_CHECKING
1718

1819
from ...utils import (

0 commit comments

Comments
 (0)