Skip to content

Commit c18c0c6

Browse files
committed
Fixed import error
1 parent f23eeef commit c18c0c6

10 files changed

Lines changed: 2000 additions & 28 deletions

File tree

setup.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,4 +110,4 @@ else
110110
fi
111111

112112
# Install maxdiffusion
113-
pip3 install -U . || echo "Failed to install maxdiffusion" >&2
113+
pip3 install -U . || echo "Failed to install maxdiffusion" >&2

src/maxdiffusion/configs/ltx_video.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ activations_dtype: 'bfloat16'
99

1010
run_name: ''
1111
output_dir: ''
12+
config_path: ''
1213
save_config_to_gcs: False
1314

1415
#Checkpoints

src/maxdiffusion/generate_ltx_video.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -97,18 +97,6 @@ def run(config):
9797
pipeline = LTXVideoPipeline.from_pretrained(config, enhance_prompt=enhance_prompt)
9898
if config.pipeline_type == "multi-scale":
9999
pipeline = LTXMultiScalePipeline(pipeline)
100-
# s0 = time.perf_counter()
101-
# images = pipeline(
102-
# height=height_padded,
103-
# width=width_padded,
104-
# num_frames=num_frames_padded,
105-
# is_video=True,
106-
# output_type="pt",
107-
# config=config,
108-
# enhance_prompt=enhance_prompt,
109-
# seed = config.seed
110-
# )
111-
# print("compile time: ", (time.perf_counter() - s0))
112100
s0 = time.perf_counter()
113101
images = pipeline(
114102
height=height_padded,

src/maxdiffusion/models/ltx_video/transformers_pytorch/__init__.py

Whitespace-only changes.

src/maxdiffusion/models/ltx_video/transformers_pytorch/attention.py

Lines changed: 1264 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
# Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/embeddings.py
2+
import math
3+
4+
import numpy as np
5+
import torch
6+
from einops import rearrange
7+
from torch import nn
8+
9+
10+
def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
):
    """Build DDPM-style sinusoidal timestep embeddings.

    Args:
        timesteps: 1-D tensor of N timestep indices (may be fractional).
        embedding_dim: dimension of each output embedding.
        flip_sin_to_cos: if True, output [cos | sin] instead of [sin | cos].
        downscale_freq_shift: offset subtracted from the frequency denominator.
        scale: multiplier applied to the raw angles before sin/cos.
        max_period: controls the minimum frequency of the embeddings.

    Returns:
        An [N, embedding_dim] tensor of positional embeddings.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half = embedding_dim // 2
    exponent = -math.log(max_period) * torch.arange(
        start=0, end=half, dtype=torch.float32, device=timesteps.device
    )
    exponent = exponent / (half - downscale_freq_shift)
    freqs = torch.exp(exponent)

    # Outer product: one row of angles per timestep, then apply the scale.
    angles = scale * (timesteps.float().unsqueeze(-1) * freqs.unsqueeze(0))

    out = torch.cat((angles.sin(), angles.cos()), dim=-1)

    # Optionally swap the two halves so cosine comes first.
    if flip_sin_to_cos:
        out = torch.cat((out[:, half:], out[:, :half]), dim=-1)

    # Odd target dims get one zero column of padding on the right.
    if embedding_dim % 2 == 1:
        out = torch.nn.functional.pad(out, (0, 1, 0, 0))
    return out
51+
52+
53+
def get_3d_sincos_pos_embed(embed_dim, grid, w, h, f):
    """Compute 3D sin/cos positional embeddings for an (f, h, w) token grid.

    Args:
        embed_dim: embedding dimension (must be divisible by 3 downstream).
        grid: coordinate array of shape [3, f*h*w], axis order (f, h, w).
        w, h, f: grid width, height and frame count.

    Returns:
        Array of shape [f*h*w, embed_dim].
    """
    # Unflatten the coordinate list and move to (h, w, f) layout in one step,
    # then reinterpret with a leading singleton axis as the original did.
    grid = rearrange(grid, "c (f h w) -> c h w f", f=f, h=h, w=w)
    grid = grid.reshape([3, 1, w, h, f])
    embedded = get_3d_sincos_pos_embed_from_grid(embed_dim, grid)
    embedded = embedded.transpose(1, 0, 2, 3)
    return rearrange(embedded, "h w f c -> (f h w) c")
64+
65+
66+
def get_3d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Embed a 3-axis coordinate grid; each axis gets embed_dim // 3 channels.

    Args:
        embed_dim: total embedding dimension, split evenly over the 3 axes.
        grid: array whose leading axis indexes the (f, h, w) coordinates.

    Returns:
        Array of per-position embeddings, channels ordered (h, w, f).

    Raises:
        ValueError: if embed_dim is not divisible by 3.
    """
    if embed_dim % 3 != 0:
        raise ValueError("embed_dim must be divisible by 3")

    per_axis = embed_dim // 3
    # grid axis order is (f, h, w); embed each axis with a third of the dims.
    emb_f, emb_h, emb_w = (
        get_1d_sincos_pos_embed_from_grid(per_axis, axis) for axis in grid
    )
    # Channels are concatenated as (h, w, f), not in grid order.
    return np.concatenate([emb_h, emb_w, emb_f], axis=-1)
77+
78+
79+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Sinusoidal embedding of positions into `embed_dim` channels.

    Args:
        embed_dim: output dimension per position (must be even).
        pos: array of positions; after the outer product the result is
            reshaped back to pos's shape and the first axis is peeled off
            (original behavior, kept as-is).

    Returns:
        Array of [sin | cos] features, half of embed_dim each.

    Raises:
        ValueError: if embed_dim is odd.
    """
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    half = embed_dim // 2
    freqs = np.arange(half, dtype=np.float64)
    freqs /= embed_dim / 2.0
    freqs = 1.0 / 10000**freqs  # (D/2,)

    original_shape = pos.shape

    # Outer product of flattened positions with the frequency bank.
    angles = np.einsum("m,d->md", pos.reshape(-1), freqs)  # (M, D/2)
    angles = angles.reshape([*original_shape, -1])[0]

    return np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)
101+
102+
103+
class SinusoidalPositionalEmbedding(nn.Module):
    """Adds fixed sinusoidal position information to a sequence of embeddings.

    The input has shape (batch_size, seq_length, embed_dim); the matching
    prefix of a precomputed sinusoid table is added to it.

    Args:
        embed_dim (int): Dimension of the positional embedding.
        max_seq_length: Maximum sequence length to apply positional embeddings.
    """

    def __init__(self, embed_dim: int, max_seq_length: int = 32):
        super().__init__()
        positions = torch.arange(max_seq_length).unsqueeze(1)
        frequencies = torch.exp(
            torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim)
        )
        # Even channels carry sine, odd channels carry cosine.
        table = torch.zeros(1, max_seq_length, embed_dim)
        table[0, :, 0::2] = torch.sin(positions * frequencies)
        table[0, :, 1::2] = torch.cos(positions * frequencies)
        # Buffer (not a parameter): moves with the module, is not trained.
        self.register_buffer("pe", table)

    def forward(self, x):
        seq_length = x.shape[1]
        return x + self.pe[:, :seq_length]
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
from abc import ABC, abstractmethod
2+
from typing import Tuple
3+
4+
import torch
5+
from diffusers.configuration_utils import ConfigMixin
6+
from einops import rearrange
7+
from torch import Tensor
8+
9+
10+
class Patchifier(ConfigMixin, ABC):
    """Base class converting latent video tensors to/from patch sequences.

    Patch extents are (1, patch_size, patch_size) over (frames, height,
    width): patchification never merges across the frame axis.
    """

    def __init__(self, patch_size: int):
        super().__init__()
        # (frame, height, width) extents; frames are never patched.
        self._patch_size = (1, patch_size, patch_size)

    @abstractmethod
    def patchify(self, latents: Tensor) -> Tuple[Tensor, Tensor]:
        raise NotImplementedError("Patchify method not implemented")

    @abstractmethod
    def unpatchify(
        self,
        latents: Tensor,
        output_height: int,
        output_width: int,
        out_channels: int,
    ) -> Tuple[Tensor, Tensor]:
        pass

    @property
    def patch_size(self):
        """(frames, height, width) patch extents."""
        return self._patch_size

    def get_latent_coords(
        self, latent_num_frames, latent_height, latent_width, batch_size, device
    ):
        """
        Return a tensor of shape [batch_size, 3, num_patches] containing the
        top-left corner latent coordinates of each latent patch.
        The tensor is repeated for each batch element.
        """
        # Explicit indexing="ij" keeps the (f, h, w) axis order and avoids the
        # deprecation warning for relying on torch.meshgrid's implicit default.
        latent_sample_coords = torch.meshgrid(
            torch.arange(0, latent_num_frames, self._patch_size[0], device=device),
            torch.arange(0, latent_height, self._patch_size[1], device=device),
            torch.arange(0, latent_width, self._patch_size[2], device=device),
            indexing="ij",
        )
        latent_sample_coords = torch.stack(latent_sample_coords, dim=0)
        latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
        latent_coords = rearrange(
            latent_coords, "b c f h w -> b c (f h w)", b=batch_size
        )
        return latent_coords
52+
53+
54+
class SymmetricPatchifier(Patchifier):
    """Patchifier that uses the same patch size along height and width."""

    def patchify(self, latents: Tensor) -> Tuple[Tensor, Tensor]:
        """Flatten a [b, c, f, h, w] latent into a patch-token sequence.

        Returns:
            latents: [b, num_patches, c * p1 * p2 * p3] token sequence.
            latent_coords: [b, 3, num_patches] top-left patch coordinates.
        """
        b, _, f, h, w = latents.shape
        latent_coords = self.get_latent_coords(f, h, w, b, latents.device)
        latents = rearrange(
            latents,
            "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
            p1=self._patch_size[0],
            p2=self._patch_size[1],
            p3=self._patch_size[2],
        )
        return latents, latent_coords

    # Fixed: the return annotation was Tuple[Tensor, Tensor], but a single
    # tensor is (and was) returned.
    def unpatchify(
        self,
        latents: Tensor,
        output_height: int,
        output_width: int,
        out_channels: int,
    ) -> Tensor:
        """Reassemble patch tokens into a [b, c, f, h, w] latent.

        Note:
            `out_channels` is unused; the channel count is inferred by the
            rearrange pattern. Kept for interface compatibility.
        """
        output_height = output_height // self._patch_size[1]
        output_width = output_width // self._patch_size[2]
        latents = rearrange(
            latents,
            "b (f h w) (c p q) -> b c f (h p) (w q)",
            h=output_height,
            w=output_width,
            p=self._patch_size[1],
            q=self._patch_size[2],
        )
        return latents

0 commit comments

Comments
 (0)