-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy pathlorentz_MLR.py
More file actions
52 lines (40 loc) · 1.59 KB
/
lorentz_MLR.py
File metadata and controls
52 lines (40 loc) · 1.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
"""
Lorentz Multinomial Logistic Regression (MLR) module.
Based on:
- Fully Hyperbolic Convolutional Neural Networks for Computer Vision (https://arxiv.org/abs/2303.15919)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from ...manifolds import Lorentz
class LorentzMLR(nn.Module):
    """Multinomial logistic regression (MLR) in the Lorentz model.

    Every class owns one hyperplane, parameterised by a scalar offset ``a``
    and a direction vector ``z``. The logit of a class is the signed distance
    of the input point to that hyperplane, scaled by ``||z||``.

    Args:
        manifold: Lorentz manifold object. Only its ``c`` attribute is read,
            and it has to be a tensor because ``.sqrt()`` is called on it.
        num_features: Size of the last dimension of the inputs (leading
            component plus ``num_features - 1`` remaining components).
        num_classes: Number of classes, i.e. number of hyperplanes.
    """

    def __init__(
        self,
        manifold: Lorentz,
        num_features: int,
        num_classes: int
    ):
        super(LorentzMLR, self).__init__()
        self.manifold = manifold

        # One scalar offset per class.
        self.a = torch.nn.Parameter(torch.zeros(num_classes,))
        # One direction per class with ``num_features - 1`` entries. The
        # padded 1 only keeps the placeholder away from the zero vector;
        # init_weights() overwrites these values right below.
        self.z = torch.nn.Parameter(
            F.pad(torch.zeros(num_classes, num_features - 2), pad=(1, 0), value=1)
        )  # z should not be (0,0)

        self.init_weights()
        self.c = manifold.c

    def forward(self, x):
        """Return the class logits for the points ``x``.

        Args:
            x: Tensor whose last dimension has size ``num_features``. Index 0
                of that dimension is treated as the time-like component and
                the remaining entries as the spatial components.

        Returns:
            Tensor of shape ``x.shape[:-1] + (num_classes,)`` holding
            ``||z|| * sqrt(c) * asinh(alpha / (sqrt(c) * ||z||))`` per class.
        """
        sqrt_c = self.c.sqrt()
        # NOTE(review): 1/sqrt(c) is used as sqrt(-K), which presumes the
        # manifold's ``c`` relates to the curvature as K = -1/c -- confirm
        # against the Lorentz manifold implementation.
        sqrt_mK = 1 / sqrt_c

        # Split the input into its leading and remaining components.
        x_t = x.narrow(-1, 0, 1)
        x_s = x.narrow(-1, 1, x.shape[-1] - 1)

        # Hyperplane
        scaled_a = sqrt_mK * self.a
        norm_z = torch.norm(self.z, dim=-1)
        w_t = torch.sinh(scaled_a) * norm_z

        # beta = sqrt(-w_t^2 + ||w_s||^2) with w_s = cosh(scaled_a) * z.
        # Because cosh^2 - sinh^2 = 1 this is exactly ||z||. Evaluating the
        # difference literally cancels catastrophically once |scaled_a| grows
        # (float32 breaks down around 8), producing 0 or NaN, so the closed
        # form is used instead.
        beta = norm_z

        alpha = -w_t * x_t + torch.cosh(scaled_a) * torch.inner(x_s, self.z)

        # Signed distance to the hyperplane. asinh is odd and beta, sqrt_c are
        # positive, so this equals sign(alpha) * |distance| without the
        # zero-gradient point that sign()/abs() create at alpha == 0.
        signed_distance = sqrt_c * torch.asinh(sqrt_mK * alpha / beta)
        logits = beta * signed_distance
        return logits

    def init_weights(self):
        """Initialise ``z`` and ``a`` uniformly in ``[-stdv, stdv]``.

        ``stdv`` is ``1 / sqrt(z.size(1))``, i.e. it shrinks with the number
        of columns of ``z``.
        """
        stdv = 1. / math.sqrt(self.z.size(1))
        nn.init.uniform_(self.z, -stdv, stdv)
        nn.init.uniform_(self.a, -stdv, stdv)