
Commit 7d4b2a9

moved upsampler
1 parent bb61ecb commit 7d4b2a9

1 file changed: src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py

Lines changed: 45 additions & 95 deletions
@@ -11,37 +11,24 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import argparse
 import math
 import os
-import random
 from jax import Array
-from datetime import datetime
-from pathlib import Path
+from maxdiffusion.models.ltx_video.autoencoders.latent_upsampler import LatentUpsampler
 from diffusers import AutoencoderKL
 from typing import Optional, List, Union, Tuple
 from einops import rearrange
 import torch.nn.functional as F
 from diffusers.utils.torch_utils import randn_tensor
-# from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
-import yaml
-from transformers import (CLIPTokenizer, FlaxCLIPTextModel,
-                          T5EncoderModel, FlaxT5EncoderModel, AutoTokenizer)
-
-
-import imageio
-import json
-import numpy as np
-import torch
-from safetensors import safe_open
-from PIL import Image
 from transformers import (
-    T5EncoderModel,
-    T5Tokenizer,
+    FlaxT5EncoderModel,
+    AutoTokenizer,
     AutoModelForCausalLM,
     AutoProcessor,
-    AutoTokenizer,
-)
+    AutoTokenizer,)
+import json
+import numpy as np
+import torch
 from huggingface_hub import hf_hub_download
 from maxdiffusion.models.ltx_video.autoencoders.causal_video_autoencoder import (
     CausalVideoAutoencoder,
@@ -55,78 +42,29 @@
     normalize_latents,
 )
 from diffusers.image_processor import VaeImageProcessor
-from ltx_video.schedulers.rf import RectifiedFlowScheduler
 from maxdiffusion.models.ltx_video.autoencoders.latent_upsampler import LatentUpsampler
-import ltx_video.pipelines.crf_compressor as crf_compressor
 from maxdiffusion.models.ltx_video.utils.prompt_enhance_utils import generate_cinematic_prompt
 from math import e
 from types import NoneType
 from typing import Any, Dict
 import numpy as np
-import inspect
 
 import jax
 import jax.numpy as jnp
 from jax.sharding import Mesh, PartitionSpec as P
 from typing import Optional, Union, List
-import torch
-from maxdiffusion.checkpointing import checkpointing_utils
-from flax.linen import partitioning as nn_partitioning
 from maxdiffusion.models.ltx_video.transformers.symmetric_patchifier import SymmetricPatchifier
-from maxdiffusion.models.ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
 from ...pyconfig import HyperParameters
-# from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler, UniPCMultistepSchedulerState
 from ...schedulers.scheduling_rectified_flow import FlaxRectifiedFlowMultistepScheduler, RectifiedFlowSchedulerState
 from ...max_utils import (
     create_device_mesh,
     setup_initial_state,
     get_memory_allocations
 )
 from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel
-import os
 import json
 import functools
 import orbax.checkpoint as ocp
-import pickle
-
-
-class PickleCheckpointHandler(ocp.CheckpointHandler):
-  def save(self, directory: str, item, args=None):
-    os.makedirs(directory, exist_ok=True)
-    with open(os.path.join(directory, 'checkpoint.pkl'), 'wb') as f:
-      pickle.dump(item, f)
-
-  def restore(self, directory: str, args=None):
-    with open(os.path.join(directory, 'checkpoint.pkl'), 'rb') as f:
-      return pickle.load(f)
-
-  def structure(self, directory: str):
-    return {} # not needed for simple pickle-based handling
-
-
-def save_tensor_dict(tensor_dict, timestep):
-  base_dir = os.path.dirname(__file__)
-  local_path = os.path.join(base_dir, f"schedulerTest{timestep}")
-
-  try:
-    torch.save(tensor_dict, local_path)
-    print(f"Dictionary of tensors saved to: {local_path}")
-  except Exception as e:
-    print(f"Error saving dictionary: {e}")
-    raise
-
-
-def validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise_cond, segment_ids, encoder_attention_segment_ids):
-  print("prompts_embeds.shape: ", prompt_embeds.shape, prompt_embeds.dtype)
-  print("fractional_coords.shape: ",
-        fractional_coords.shape, fractional_coords.dtype)
-  print("latents.shape: ", latents.shape, latents.dtype)
-  print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype)
-  print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype)
-  # print("segment_ids.shape: ", segment_ids.shape, segment_ids.dtype)
-  print("encoder_attention_segment_ids.shape: ",
-        encoder_attention_segment_ids.shape, encoder_attention_segment_ids.dtype)
-
 
 def prepare_extra_step_kwargs(generator):
   extra_step_kwargs = {}
@@ -817,7 +755,6 @@ def __call__(
       skip_initial_inference_steps: int = 0,
       skip_final_inference_steps: int = 0,
       cfg_star_rescale: bool = False,
-      skip_layer_strategy: Optional[SkipLayerStrategy] = None,
       skip_block_list: Optional[Union[List[List[int]], List[int]]] = None,
      **kwargs,
  ):
@@ -1076,7 +1013,6 @@ def __call__(
         rescaling_scale = rescaling_scale,
         batch_size=batch_size,
         skip_layer_masks = skip_layer_masks,
-        skip_layer_strategy = skip_layer_strategy,
         cfg_star_rescale = cfg_star_rescale
     )
 
@@ -1149,7 +1085,6 @@ def transformer_forward_pass( # need to jit this? wan didnt
     segment_ids,
     encoder_attention_segment_ids,
     skip_layer_mask,
-    skip_layer_strategy,
 ):
   noise_pred = transformer.apply(
       {"params": state.params},
@@ -1160,13 +1095,12 @@ def transformer_forward_pass( # need to jit this? wan didnt
       segment_ids=segment_ids,
       encoder_attention_segment_ids=encoder_attention_segment_ids,
       skip_layer_mask=skip_layer_mask,
-      skip_layer_strategy=skip_layer_strategy
-  ) # need .param here?
+  )
   return noise_pred, state
 
 
 def run_inference(
-    transformer_state, transformer, config, mesh, latents, fractional_cords, prompt_embeds, timestep, num_inference_steps, scheduler, segment_ids, encoder_attention_segment_ids, scheduler_state, do_classifier_free_guidance, num_conds, guidance_scale, do_spatio_temporal_guidance, stg_scale, do_rescaling, rescaling_scale, batch_size, skip_layer_masks, skip_layer_strategy, cfg_star_rescale
+    transformer_state, transformer, config, mesh, latents, fractional_cords, prompt_embeds, timestep, num_inference_steps, scheduler, segment_ids, encoder_attention_segment_ids, scheduler_state, do_classifier_free_guidance, num_conds, guidance_scale, do_spatio_temporal_guidance, stg_scale, do_rescaling, rescaling_scale, batch_size, skip_layer_masks, cfg_star_rescale
 ):
   # do_classifier_free_guidance = guidance_scale > 1.0
   # for step in range(num_inference_steps):
@@ -1206,7 +1140,7 @@ def run_inference(
                 skip_layer_masks[i]
                 if skip_layer_masks is not None
                 else None
-            ), skip_layer_strategy = skip_layer_strategy)
+            ))
     # ValueError: One of pjit outputs with pytree key path result was given the sharding of NamedSharding(mesh=Mesh('data': 4, 'fsdp': 1, 'tensor': 1, 'fsdp_transpose': 1, 'expert': 1, 'tensor_transpose': 1, 'tensor_sequence': 1, 'sequence': 1, axis_types=(Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto)), spec=PartitionSpec(('data', 'fsdp'), None, None), memory_kind=device), which implies that the global size of its dimension 0 should be divisible by 4, but it is equal to 1 (full shape: (1, 1, 128))
 
     # # latents = self.denoising
@@ -1294,6 +1228,31 @@ def adain_filter_latent(
     return result
 
 class LTXMultiScalePipeline:
+
+  @classmethod
+  def load_latent_upsampler(cls, config):
+    spatial_upscaler_model_name_or_path = config.spatial_upscaler_model_path
+
+    if spatial_upscaler_model_name_or_path and not os.path.isfile(
+        spatial_upscaler_model_name_or_path
+    ):
+      spatial_upscaler_model_path = hf_hub_download(
+          repo_id="Lightricks/LTX-Video",
+          filename=spatial_upscaler_model_name_or_path,
+          local_dir="/mnt/disks/diffusionproj",
+          repo_type="model",
+      )
+    else:
+      spatial_upscaler_model_path = spatial_upscaler_model_name_or_path
+    if not config.spatial_upscaler_model_path:
+      raise ValueError(
+          "spatial upscaler model path is missing from pipeline config file and is required for multi-scale rendering"
+      )
+    latent_upsampler = LatentUpsampler.from_pretrained(spatial_upscaler_model_path)
+    latent_upsampler.eval()
+    return latent_upsampler
+
+
   def _upsample_latents(
       self, latest_upsampler: LatentUpsampler, latents: torch.Tensor
   ):
@@ -1309,37 +1268,29 @@ def _upsample_latents(
     return upsampled_latents
 
   def __init__(
-      self, video_pipeline: LTXVideoPipeline, latent_upsampler: LatentUpsampler
+      self, video_pipeline: LTXVideoPipeline
   ):
     self.video_pipeline = video_pipeline
     self.vae = video_pipeline.vae
-    self.latent_upsampler = latent_upsampler
-
+
   def __call__(
       self,
       height,
       width,
       num_frames,
-      is_video,
       output_type,
       generator,
       config
   ) -> Any:
 
+    latent_upsampler = self.load_latent_upsampler(config)
     original_output_type = output_type
-    original_width = width
-    original_height = height
-    x_width = int(width * config.downscale_factor)
-    downscaled_width = x_width - (x_width % self.video_pipeline.vae_scale_factor)
-    x_height = int(height * config.downscale_factor)
-    downscaled_height = x_height - (x_height % self.video_pipeline.vae_scale_factor)
-    #use original height and width here to see
     output_type = 'latent'
-    result = self.video_pipeline(height=original_height, width=original_width, num_frames=num_frames,
+    result = self.video_pipeline(height=height, width=width, num_frames=num_frames,
         is_video=True, output_type=output_type, generator=generator, guidance_scale = config.first_pass["guidance_scale"], stg_scale = config.first_pass["stg_scale"], rescaling_scale = config.first_pass["rescaling_scale"], skip_initial_inference_steps= config.first_pass["skip_initial_inference_steps"], skip_final_inference_steps= config.first_pass["skip_final_inference_steps"],
-        num_inference_steps = config.first_pass["num_inference_steps"], guidance_timesteps = config.first_pass["guidance_timesteps"], cfg_star_rescale = config.first_pass["cfg_star_rescale"], skip_layer_strategy = None, skip_block_list=config.first_pass["skip_block_list"])
+        num_inference_steps = config.first_pass["num_inference_steps"], guidance_timesteps = config.first_pass["guidance_timesteps"], cfg_star_rescale = config.first_pass["cfg_star_rescale"], skip_block_list=config.first_pass["skip_block_list"])
     latents = result
-    upsampled_latents = self._upsample_latents(self.latent_upsampler, latents)
+    upsampled_latents = self._upsample_latents(latent_upsampler, latents)
     upsampled_latents = adain_filter_latent(
         latents=upsampled_latents, reference_latents=latents
     )
@@ -1348,20 +1299,19 @@ def __call__(
 
     latents = upsampled_latents
     output_type = original_output_type
-    width = downscaled_width * 2
-    height = downscaled_height * 2
+
 
-    result = self.video_pipeline(height=original_height*2, width=original_width*2, num_frames=num_frames,
+    result = self.video_pipeline(height=height*2, width=width*2, num_frames=num_frames,
         is_video=True, output_type=output_type, latents = latents, generator=generator, guidance_scale = config.second_pass["guidance_scale"], stg_scale = config.second_pass["stg_scale"], rescaling_scale = config.second_pass["rescaling_scale"], skip_initial_inference_steps= config.second_pass["skip_initial_inference_steps"], skip_final_inference_steps= config.second_pass["skip_final_inference_steps"],
-        num_inference_steps = config.second_pass["num_inference_steps"], guidance_timesteps = config.second_pass["guidance_timesteps"], cfg_star_rescale = config.second_pass["cfg_star_rescale"], skip_layer_strategy = None, skip_block_list=config.second_pass["skip_block_list"])
+        num_inference_steps = config.second_pass["num_inference_steps"], guidance_timesteps = config.second_pass["guidance_timesteps"], cfg_star_rescale = config.second_pass["cfg_star_rescale"], skip_block_list=config.second_pass["skip_block_list"])
 
     if original_output_type != "latent":
       num_frames = result.shape[2]
       videos = rearrange(result, "b c f h w -> (b f) c h w")
 
       videos = F.interpolate(
           videos,
-          size=(original_height, original_width),
+          size=(height, width),
           mode="bilinear",
           align_corners=False,
       )
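
Net effect of the commit: `LTXMultiScalePipeline` no longer receives a `LatentUpsampler` through `__init__`. Instead, `__call__` builds one on demand via the new `load_latent_upsampler` classmethod (fetching the checkpoint from `Lightricks/LTX-Video` with `hf_hub_download` when the configured path is not a local file), the `skip_layer_strategy` plumbing is removed end to end, and the first pass renders at the requested `height`/`width` rather than at a `downscale_factor`-reduced size. A minimal usage sketch of the resulting call pattern follows; `video_pipeline` (an already-constructed `LTXVideoPipeline`) and `config` (an object exposing `spatial_upscaler_model_path`, `first_pass`, and `second_pass`) are assumed stand-ins, not part of this diff:

```python
import torch

from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXMultiScalePipeline

# Before this commit: LTXMultiScalePipeline(video_pipeline, latent_upsampler).
# Now only the video pipeline is passed; the upsampler is resolved inside
# __call__ by load_latent_upsampler(config). is_video is also gone from the
# signature, since both passes hard-code is_video=True.
pipeline = LTXMultiScalePipeline(video_pipeline)

frames = pipeline(
    height=512,  # first pass renders directly at (height, width)
    width=768,
    num_frames=121,
    output_type="pil",
    generator=torch.Generator().manual_seed(0),
    config=config,  # second pass renders at (height*2, width*2); decoded
)                   # frames are then resized back to (height, width)
```

One consequence of resolving the upsampler inside `__call__` is that the checkpoint is re-read (and possibly re-downloaded) on every invocation instead of being held as pipeline state; that is the trade-off this refactor accepts for a simpler constructor.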
