-
Notifications
You must be signed in to change notification settings - Fork 40
Expand file tree
/
Copy pathardtransform.jl
More file actions
47 lines (34 loc) · 1.4 KB
/
ardtransform.jl
File metadata and controls
47 lines (34 loc) · 1.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
"""
    ARDTransform(v::AbstractVector)

Transformation that multiplies the input elementwise by `v`.

# Examples

```jldoctest
julia> v = rand(10); t = ARDTransform(v); X = rand(10, 100);

julia> map(t, ColVecs(X)) == ColVecs(v .* X)
true
```
"""
struct ARDTransform{V<:AbstractVector{<:Real}} <: Transform
    # Vector of real scale factors; inputs are multiplied elementwise by it.
    v::V
end
"""
    ARDTransform(s::Real, dims::Integer)

Create an [`ARDTransform`](@ref) with vector `fill(s, dims)`, i.e. the single factor `s`
repeated `dims` times.
"""
function ARDTransform(s::Real, dims::Integer)
    scales = fill(s, dims)
    return ARDTransform(scales)
end
# Flatten an `ARDTransform` into a `Vector` of element type `T` holding `log` of each
# scale, and return it together with a function that rebuilds the transform.
# The rebuilding function applies `exp` elementwise and converts the result back to the
# original vector type `S`, so `exp` guarantees the rebuilt scales are positive.
# Note: `log` of a negative real scale throws a `DomainError` here.
function ParameterHandling.flatten(::Type{T}, t::ARDTransform{S}) where {T<:Real,S}
    function unflatten_to_ardtransform(v::Vector{T})
        scales = map(exp, v)
        return ARDTransform(convert(S, scales))
    end
    log_scales = map(T ∘ log, t.v)
    flat = convert(Vector, log_scales)
    return flat, unflatten_to_ardtransform
end
# Overwrite the scale vector of `t` in place with `ρ` and return `t.v`.
# The signature requires `ρ` to have the same element type `T` as `t.v`.
#
# The length check is an explicit `throw` rather than `@assert`. `@assert` may be disabled
# at some optimisation levels, and without the check a length-1 `ρ` would silently
# broadcast into every entry of `t.v`. An `AssertionError` with the same message as before
# is still thrown, so existing callers see an unchanged exception type.
function set!(t::ARDTransform{<:AbstractVector{T}}, ρ::AbstractVector{T}) where {T<:Real}
    if length(ρ) != dim(t)
        msg = "Trying to set a vector of size $(length(ρ)) " *
              "to ARDTransform of dimension $(dim(t))"
        throw(AssertionError(msg))
    end
    return t.v .= ρ
end
# Dimension of `t`: the length of its scale vector `t.v`.
function dim(t::ARDTransform)
    return length(t.v)
end
# Scalar input: multiply by the single entry of `t.v` (`only` throws if `t.v` does not
# contain exactly one element).
function (t::ARDTransform)(x::Real)
    scale = only(t.v)
    return scale * x
end
# Any other input: broadcasted elementwise multiplication by `t.v`.
function (t::ARDTransform)(x)
    return t.v .* x
end
# Batched application of `t` to a collection of inputs.
#
# Vector of scalar inputs: scale every element by the single entry of `t.v`. This matches
# the scalar method `(t::ARDTransform)(x::Real) = only(t.v) * x`, and `only` throws an
# `ArgumentError` if `t.v` does not have exactly one element. Broadcasting with the scalar
# keeps the result a vector. `t.v' .* x` would instead return an N×1 matrix, or an N×D
# outer product.
_map(t::ARDTransform, x::AbstractVector{<:Real}) = only(t.v) .* x
# ColVecs: scale row i of `x.X` by `t.v[i]`.
_map(t::ARDTransform, x::ColVecs) = ColVecs(t.v .* x.X)
# RowVecs: scale column j of `x.X` by `t.v[j]`.
_map(t::ARDTransform, x::RowVecs) = RowVecs(t.v' .* x.X)
# Two `ARDTransform`s are `isequal` when their scale vectors are `isequal`.
Base.isequal(t::ARDTransform, t2::ARDTransform) = isequal(t.v, t2.v)
# `hash` must be consistent with `isequal` (isequal objects must hash equally), or
# `Dict`/`Set` lookups break. The fallback `hash` is `objectid`-based, and for this struct
# it depends on the identity of the wrapped vector, so hash the scale vector's contents.
Base.hash(t::ARDTransform, h::UInt) = hash(t.v, hash(:ARDTransform, h))
# Compact one-line display, e.g. `ARD Transform (dims: 3)`.
function Base.show(io::IO, t::ARDTransform)
    n = dim(t)
    return print(io, "ARD Transform (dims: ", n, ")")
end