-
Notifications
You must be signed in to change notification settings - Fork 40
Expand file tree
/
Copy pathmatern.jl
More file actions
130 lines (96 loc) · 3.9 KB
/
matern.jl
File metadata and controls
130 lines (96 loc) · 3.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
"""
    MaternKernel(; ν::Real=1.5, metric=Euclidean())

Matérn kernel of order `ν` with respect to the `metric`.

The order can also be passed through the ASCII keyword `nu`, i.e.
`MaternKernel(; nu=2.5)`; when both `nu` and `ν` are given, `ν` is used.

# Definition

For inputs ``x, x'`` and metric ``d(\\cdot, \\cdot)``, the Matérn kernel of order
``\\nu > 0`` is defined as
```math
k(x,x';\\nu) = \\frac{2^{1-\\nu}}{\\Gamma(\\nu)}\\big(\\sqrt{2\\nu} d(x, x')\\big)^{\\nu} K_\\nu\\big(\\sqrt{2\\nu} d(x, x')\\big),
```
where ``\\Gamma`` is the Gamma function and ``K_{\\nu}`` is the modified Bessel function of
the second kind of order ``\\nu``. At ``d(x, x') = 0`` the kernel takes its limiting value
``k(x, x) = 1``.

By default, ``d`` is the Euclidean metric ``d(x, x') = \\|x - x'\\|_2``.

A Gaussian process with a Matérn kernel is ``\\lceil \\nu \\rceil - 1``-times
differentiable in the mean-square sense.

!!! note
    Differentiation with respect to the order ν is not currently supported.

See also: [`Matern12Kernel`](@ref), [`Matern32Kernel`](@ref), [`Matern52Kernel`](@ref)
"""
struct MaternKernel{Tν<:Real,M} <: SimpleKernel
    # Order ν, stored as a one-element vector rather than a bare scalar.
    # NOTE(review): presumably wrapped so the order is an array leaf for `@functor`
    # traversal / parameter handling — confirm.
    ν::Vector{Tν}
    metric::M

    # Positional constructor; rejects non-positive orders via `@check_args`.
    function MaternKernel(ν::Real, metric)
        @check_args(MaternKernel, ν, ν > zero(ν), "ν > 0")
        return new{typeof(ν),typeof(metric)}([ν], metric)
    end
end
# Keyword constructor. `nu` is an ASCII alias for `ν`; because `ν` defaults to `nu`,
# an explicitly passed `ν` takes precedence. Defaults: order 1.5, Euclidean metric.
function MaternKernel(; nu::Real=1.5, ν::Real=nu, metric=Euclidean())
    return MaternKernel(ν, metric)
end

# Register the struct with `@functor`.
# NOTE(review): presumably so its fields (ν, metric) are reachable by Functors-style
# traversal — confirm.
@functor MaternKernel
# Accessor for the scalar order ν (`only` errors unless exactly one value is stored).
# It exists as a separate function so that the rrule below can attach a
# "not implemented" tangent to ν. This was introduced as a workaround for Zygote; why
# Zygote needs it is not understood, but differentiating w.r.t. ν is officially
# unsupported anyway.
@inline function _get_ν(k::MaternKernel)
    return only(k.ν)
end

# Reverse rule for `_get_ν`: the forward value is the order itself; the pullback gives
# `NoTangent()` for the function, and for `k` a `Tangent` whose `ν` component is
# `@not_implemented` and whose `metric` component is `NoTangent()`.
function ChainRulesCore.rrule(::typeof(_get_ν), k::T) where {T<:MaternKernel}
    order = _get_ν(k)
    function order_pullback(ȳ)
        dorder = ChainRulesCore.@not_implemented(
            "derivatives of `MaternKernel` w.r.t. order `ν` are not implemented."
        )
        return NoTangent(), Tangent{T}(; ν=dorder, metric=NoTangent())
    end
    return order, order_pullback
end
# Matérn kernel value as a function of the distance `d`.
# `_matern` evaluates `log(y)` with `y == 0` when `d == 0`, so its output there is not
# the kernel value (presumably NaN). It is replaced by `one(val)`, the limit at zero
# distance. `ifelse` is kept so both operands are computed unconditionally.
@inline function kappa(k::MaternKernel, d::Real)
    order = _get_ν(k)
    val = _matern(order, d)
    return ifelse(iszero(d), one(val), val)
end
# Matérn covariance of order `ν` at distance `d`, computed in log space:
#   exp((1 - ν) log 2 - log Γ(ν) + ν log z + log K_ν(z)),  with z = sqrt(2ν) d.
# Not meaningful at d == 0 (log(0)); callers such as `kappa` special-case that point.
# NOTE(review): for large ν and very small nonzero d, `besselk(ν, z)` may overflow to Inf,
# making the result Inf instead of ≈ 1 — confirm whether such inputs occur in practice.
function _matern(ν::Real, d::Real)
    z = sqrt(2ν) * d
    logval = (one(d) - ν) * logtwo - loggamma(ν) + ν * log(z) + log(besselk(ν, z))
    return exp(logval)
end
# Distance metric the kernel is evaluated with.
function metric(k::MaternKernel)
    return k.metric
end
# One-line display showing the order ν and the metric.
function Base.show(io::IO, κ::MaternKernel)
    order = only(κ.ν)
    return print(io, "Matern Kernel (ν = ", order, ", metric = ", κ.metric, ")")
end
## Matern12Kernel = ExponentialKernel aliased in exponential.jl
"""
    Matern32Kernel(; metric=Euclidean())

Matérn kernel of order ``3/2`` with respect to the `metric`.

# Definition

For inputs ``x, x'`` and metric ``d(\\cdot, \\cdot)``, the Matérn kernel of order ``3/2`` is
given by
```math
k(x, x') = \\big(1 + \\sqrt{3} d(x, x') \\big) \\exp\\big(- \\sqrt{3} d(x, x') \\big).
```

By default, ``d`` is the Euclidean metric ``d(x, x') = \\|x - x'\\|_2``.

See also: [`MaternKernel`](@ref)
"""
struct Matern32Kernel{M} <: SimpleKernel
    metric::M
end

# Keyword constructor; the metric defaults to Euclidean.
function Matern32Kernel(; metric=Euclidean())
    return Matern32Kernel(metric)
end

# (1 + √3 d) exp(-√3 d); the scaled distance √3 d is computed once and reused.
function kappa(::Matern32Kernel, d::Real)
    scaled = sqrt(3) * d
    return (1 + scaled) * exp(-scaled)
end

# Distance metric the kernel is evaluated with.
function metric(k::Matern32Kernel)
    return k.metric
end

# One-line display showing the metric.
function Base.show(io::IO, kernel::Matern32Kernel)
    return print(io, "Matern 3/2 Kernel (metric = ", kernel.metric, ")")
end
"""
    Matern52Kernel(; metric=Euclidean())

Matérn kernel of order ``5/2`` with respect to the `metric`.

# Definition

For inputs ``x, x'`` and metric ``d(\\cdot, \\cdot)``, the Matérn kernel of order ``5/2`` is
given by
```math
k(x, x') = \\bigg(1 + \\sqrt{5} d(x, x') + \\frac{5}{3} d(x, x')^2\\bigg)
\\exp\\big(- \\sqrt{5} d(x, x') \\big).
```

By default, ``d`` is the Euclidean metric ``d(x, x') = \\|x - x'\\|_2``.

See also: [`MaternKernel`](@ref)
"""
struct Matern52Kernel{M} <: SimpleKernel
    metric::M
end

# Keyword constructor; the metric defaults to Euclidean.
function Matern52Kernel(; metric=Euclidean())
    return Matern52Kernel(metric)
end

# (1 + √5 d + 5d²/3) exp(-√5 d); the scaled distance √5 d is computed once and reused,
# while the quadratic term keeps the original `5 * d^2 / 3` evaluation order.
function kappa(::Matern52Kernel, d::Real)
    scaled = sqrt(5) * d
    return (1 + scaled + 5 * d^2 / 3) * exp(-scaled)
end

# Distance metric the kernel is evaluated with.
function metric(k::Matern52Kernel)
    return k.metric
end

# One-line display showing the metric.
function Base.show(io::IO, kernel::Matern52Kernel)
    return print(io, "Matern 5/2 Kernel (metric = ", kernel.metric, ")")
end