Skip to content

Commit 3e6499c

Browse files
committed
downloaded files
1 parent b3874f5 commit 3e6499c

18 files changed

Lines changed: 4317 additions & 31 deletions
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
from typing import Tuple, Union
2+
3+
import torch
4+
import torch.nn as nn
5+
6+
7+
class CausalConv3d(nn.Module):
8+
def __init__(
9+
self,
10+
in_channels,
11+
out_channels,
12+
kernel_size: int = 3,
13+
stride: Union[int, Tuple[int]] = 1,
14+
dilation: int = 1,
15+
groups: int = 1,
16+
spatial_padding_mode: str = "zeros",
17+
**kwargs,
18+
):
19+
super().__init__()
20+
21+
self.in_channels = in_channels
22+
self.out_channels = out_channels
23+
24+
kernel_size = (kernel_size, kernel_size, kernel_size)
25+
self.time_kernel_size = kernel_size[0]
26+
27+
dilation = (dilation, 1, 1)
28+
29+
height_pad = kernel_size[1] // 2
30+
width_pad = kernel_size[2] // 2
31+
padding = (0, height_pad, width_pad)
32+
33+
self.conv = nn.Conv3d(
34+
in_channels,
35+
out_channels,
36+
kernel_size,
37+
stride=stride,
38+
dilation=dilation,
39+
padding=padding,
40+
padding_mode=spatial_padding_mode,
41+
groups=groups,
42+
)
43+
44+
def forward(self, x, causal: bool = True):
45+
if causal:
46+
first_frame_pad = x[:, :, :1, :, :].repeat(
47+
(1, 1, self.time_kernel_size - 1, 1, 1)
48+
)
49+
x = torch.concatenate((first_frame_pad, x), dim=2)
50+
else:
51+
first_frame_pad = x[:, :, :1, :, :].repeat(
52+
(1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
53+
)
54+
last_frame_pad = x[:, :, -1:, :, :].repeat(
55+
(1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
56+
)
57+
x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2)
58+
x = self.conv(x)
59+
return x
60+
61+
@property
62+
def weight(self):
63+
return self.conv.weight

0 commit comments

Comments
 (0)