Skip to content

Commit f9b6ff9

Browse files
authored
Shard video_condition to prevent OOM in WAN 2.2 I2V (#313)
* sharding before call to vae encoder * sharding before call to vae encoder * ruff check * pyink checks * pyink check
1 parent 042932e commit f9b6ff9

20 files changed

Lines changed: 20 additions & 1 deletion

code_style.sh

100644100755
File mode changed.

src/maxdiffusion/models/ltx_video/transformers_pytorch/attention.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
See the License for the specific language governing permissions and
1414
limitations under the License.
1515
"""
16+
1617
import inspect
1718
from importlib import import_module
1819
from typing import Any, Dict, Optional, Tuple

src/maxdiffusion/pedagogical_examples/attention_comparison.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
See the License for the specific language governing permissions and
1414
limitations under the License.
1515
"""
16+
1617
import os
1718
import time
1819

src/maxdiffusion/pedagogical_examples/dataset_tf_cache_to_tfrecord.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
See the License for the specific language governing permissions and
1414
limitations under the License.
1515
"""
16+
1617
import os
1718
import argparse
1819
import tensorflow as tf

src/maxdiffusion/pedagogical_examples/parameter_count.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
See the License for the specific language governing permissions and
1414
limitations under the License.
1515
"""
16+
1617
from typing import Sequence
1718
from absl import app
1819
import jax

src/maxdiffusion/pipelines/controlnet/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
See the License for the specific language governing permissions and
1414
limitations under the License.
1515
"""
16+
1617
from typing import TYPE_CHECKING
1718

1819
from ...utils import (

src/maxdiffusion/pipelines/flux/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
See the License for the specific language governing permissions and
1414
limitations under the License.
1515
"""
16+
1617
_import_structure = {"pipeline_jflux": "JfluxPipeline"}
1718

1819
from .flux_pipeline import (

src/maxdiffusion/pipelines/stable_diffusion/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
See the License for the specific language governing permissions and
1414
limitations under the License.
1515
"""
16+
1617
from typing import TYPE_CHECKING
1718

1819
from ...utils import (

src/maxdiffusion/pipelines/stable_diffusion_xl/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
See the License for the specific language governing permissions and
1414
limitations under the License.
1515
"""
16+
1617
from typing import TYPE_CHECKING
1718

1819
from ...utils import (

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -543,8 +543,9 @@ def prepare_latents_i2v_base(
543543

544544
vae_dtype = getattr(self.vae, "dtype", jnp.float32)
545545
video_condition = video_condition.astype(vae_dtype)
546-
547546
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
547+
sharding_spec = P(self.config.mesh_axes[0], None, None, None, None)
548+
video_condition = jax.lax.with_sharding_constraint(video_condition, sharding_spec)
548549
encoded_output = self.vae.encode(video_condition, self.vae_cache)[0].mode()
549550

550551
# Normalize latents

0 commit comments

Comments
 (0)