Skip to content

Commit 34199d3

Browse files
committed
enable vae slicing
1 parent 365605c commit 34199d3

2 files changed

Lines changed: 17 additions & 0 deletions

File tree

src/maxdiffusion/generate_ltx2.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,9 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
119119
checkpoint_loader = LTX2Checkpointer(config=config)
120120
pipeline, _, _ = checkpoint_loader.load_checkpoint()
121121

122+
pipeline.enable_vae_slicing()
123+
pipeline.enable_vae_tiling()
124+
122125
s0 = time.perf_counter()
123126

124127
# Using global_batch_size_to_train_on to map prompts

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,20 @@ def _init_dummy_shape(node):
260260
return jnp.zeros(node.shape, dtype=node.dtype)
261261
return node
262262

263+
def enable_vae_slicing(self):
264+
self.vae.use_slicing = True
265+
266+
def disable_vae_slicing(self):
267+
self.vae.use_slicing = False
268+
269+
def enable_vae_tiling(self):
270+
if hasattr(self.vae, "enable_tiling"):
271+
self.vae.enable_tiling()
272+
self.vae.use_tiling = True
273+
274+
def disable_vae_tiling(self):
275+
self.vae.use_tiling = False
276+
263277
@classmethod
264278
def load_tokenizer(cls, config: HyperParameters):
265279
max_logging.log("Loading Gemma Tokenizer...")

0 commit comments

Comments
 (0)