diff --git a/code_style.sh b/code_style.sh old mode 100644 new mode 100755 diff --git a/src/maxdiffusion/models/ltx_video/transformers_pytorch/attention.py b/src/maxdiffusion/models/ltx_video/transformers_pytorch/attention.py index ceede9433..93b16851d 100644 --- a/src/maxdiffusion/models/ltx_video/transformers_pytorch/attention.py +++ b/src/maxdiffusion/models/ltx_video/transformers_pytorch/attention.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import inspect from importlib import import_module from typing import Any, Dict, Optional, Tuple diff --git a/src/maxdiffusion/pedagogical_examples/attention_comparison.py b/src/maxdiffusion/pedagogical_examples/attention_comparison.py index 07831550f..6981e092f 100644 --- a/src/maxdiffusion/pedagogical_examples/attention_comparison.py +++ b/src/maxdiffusion/pedagogical_examples/attention_comparison.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import os import time diff --git a/src/maxdiffusion/pedagogical_examples/dataset_tf_cache_to_tfrecord.py b/src/maxdiffusion/pedagogical_examples/dataset_tf_cache_to_tfrecord.py index 2300d0bd1..cc547cc86 100644 --- a/src/maxdiffusion/pedagogical_examples/dataset_tf_cache_to_tfrecord.py +++ b/src/maxdiffusion/pedagogical_examples/dataset_tf_cache_to_tfrecord.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import os import argparse import tensorflow as tf diff --git a/src/maxdiffusion/pedagogical_examples/parameter_count.py b/src/maxdiffusion/pedagogical_examples/parameter_count.py index cf9f8b8dc..e9fe4542c 100644 --- a/src/maxdiffusion/pedagogical_examples/parameter_count.py +++ b/src/maxdiffusion/pedagogical_examples/parameter_count.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + from typing import Sequence from absl import app import jax diff --git a/src/maxdiffusion/pipelines/controlnet/__init__.py b/src/maxdiffusion/pipelines/controlnet/__init__.py index 0cf92cd44..fe7070b07 100644 --- a/src/maxdiffusion/pipelines/controlnet/__init__.py +++ b/src/maxdiffusion/pipelines/controlnet/__init__.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + from typing import TYPE_CHECKING from ...utils import ( diff --git a/src/maxdiffusion/pipelines/flux/__init__.py b/src/maxdiffusion/pipelines/flux/__init__.py index c39cc364d..39ea05b57 100644 --- a/src/maxdiffusion/pipelines/flux/__init__.py +++ b/src/maxdiffusion/pipelines/flux/__init__.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + _import_structure = {"pipeline_jflux": "JfluxPipeline"} from .flux_pipeline import ( diff --git a/src/maxdiffusion/pipelines/stable_diffusion/__init__.py b/src/maxdiffusion/pipelines/stable_diffusion/__init__.py index cbef1d5fa..564b0dfa7 100644 --- a/src/maxdiffusion/pipelines/stable_diffusion/__init__.py +++ b/src/maxdiffusion/pipelines/stable_diffusion/__init__.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + from typing import TYPE_CHECKING from ...utils import ( diff --git a/src/maxdiffusion/pipelines/stable_diffusion_xl/__init__.py b/src/maxdiffusion/pipelines/stable_diffusion_xl/__init__.py index 13e201aea..1ae1b6414 100644 --- a/src/maxdiffusion/pipelines/stable_diffusion_xl/__init__.py +++ b/src/maxdiffusion/pipelines/stable_diffusion_xl/__init__.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + from typing import TYPE_CHECKING from ...utils import ( diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 415bcfeaa..f0621c8ea 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -543,8 +543,9 @@ def prepare_latents_i2v_base( vae_dtype = getattr(self.vae, "dtype", jnp.float32) video_condition = video_condition.astype(vae_dtype) - with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + sharding_spec = P(self.config.mesh_axes[0], None, None, None, None) + video_condition = jax.lax.with_sharding_constraint(video_condition, sharding_spec) encoded_output = self.vae.encode(video_condition, self.vae_cache)[0].mode() # Normalize latents diff --git a/src/maxdiffusion/tests/configuration_utils_test.py b/src/maxdiffusion/tests/configuration_utils_test.py index ee761a383..df46ea755 100644 --- a/src/maxdiffusion/tests/configuration_utils_test.py +++ b/src/maxdiffusion/tests/configuration_utils_test.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import json import os diff --git a/src/maxdiffusion/tests/flop_calculations_test.py b/src/maxdiffusion/tests/flop_calculations_test.py index ca0a5020c..a58d5dcce 100644 --- a/src/maxdiffusion/tests/flop_calculations_test.py +++ b/src/maxdiffusion/tests/flop_calculations_test.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import os import unittest from unittest.mock import Mock diff --git a/src/maxdiffusion/tests/generate_flux_smoke_test.py b/src/maxdiffusion/tests/generate_flux_smoke_test.py index 12bfe77b8..4f1747161 100644 --- a/src/maxdiffusion/tests/generate_flux_smoke_test.py +++ b/src/maxdiffusion/tests/generate_flux_smoke_test.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import os import unittest import pytest diff --git a/src/maxdiffusion/tests/generate_sdxl_smoke_test.py b/src/maxdiffusion/tests/generate_sdxl_smoke_test.py index ff823010e..e2b4d772c 100644 --- a/src/maxdiffusion/tests/generate_sdxl_smoke_test.py +++ b/src/maxdiffusion/tests/generate_sdxl_smoke_test.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import os import unittest import pytest diff --git a/src/maxdiffusion/tests/generate_smoke_test.py b/src/maxdiffusion/tests/generate_smoke_test.py index 2c5b783a8..b6722b3ac 100644 --- a/src/maxdiffusion/tests/generate_smoke_test.py +++ b/src/maxdiffusion/tests/generate_smoke_test.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import os import unittest import pytest diff --git a/src/maxdiffusion/utils/deprecation_utils.py b/src/maxdiffusion/utils/deprecation_utils.py index 265a60b52..a7077ed7f 100644 --- a/src/maxdiffusion/utils/deprecation_utils.py +++ b/src/maxdiffusion/utils/deprecation_utils.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import inspect import warnings from typing import Any, Dict, Optional, Union diff --git a/src/maxdiffusion/utils/export_utils.py b/src/maxdiffusion/utils/export_utils.py index 51b05d30c..fa394129f 100644 --- a/src/maxdiffusion/utils/export_utils.py +++ b/src/maxdiffusion/utils/export_utils.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import io import random import struct diff --git a/src/maxdiffusion/utils/loading_utils.py b/src/maxdiffusion/utils/loading_utils.py index f2b72cbd5..735d22615 100644 --- a/src/maxdiffusion/utils/loading_utils.py +++ b/src/maxdiffusion/utils/loading_utils.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import os from typing import Callable, List, Optional, Union diff --git a/src/maxdiffusion/utils/pil_utils.py b/src/maxdiffusion/utils/pil_utils.py index a05aa47a2..86d07c66a 100644 --- a/src/maxdiffusion/utils/pil_utils.py +++ b/src/maxdiffusion/utils/pil_utils.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + from typing import List import PIL.Image diff --git a/src/maxdiffusion/utils/testing_utils.py b/src/maxdiffusion/utils/testing_utils.py index 6194a03a5..55be62ace 100644 --- a/src/maxdiffusion/utils/testing_utils.py +++ b/src/maxdiffusion/utils/testing_utils.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import functools import importlib import inspect