-
Notifications
You must be signed in to change notification settings - Fork 40
Expand file tree
/
Copy path selecttransform.jl
More file actions
34 lines (23 loc) · 918 Bytes
/
selecttransform.jl
File metadata and controls
34 lines (23 loc) · 918 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
"""
    SelectTransform(dims)

Transformation that selects the dimensions `dims` of the input.

`dims` is stored unchanged in the field `select` and is used as an index into the
input via `view`, so any index accepted by `view` (e.g. a vector of integers) works.
If `dims` is a single integer, applying the transform to a vector returns the selected
scalar rather than a zero-dimensional array.

# Examples

```jldoctest
julia> dims = [1, 3, 5, 6, 7]; t = SelectTransform(dims); X = rand(100, 10);

julia> map(t, ColVecs(X)) == ColVecs(X[dims, :])
true
```
"""
struct SelectTransform{T} <: Transform
    # Indices of the input dimensions to keep; stored as given by the caller.
    select::T
end
# Overwrite the selected dimensions of `t` in place with `dims`.
#
# `SelectTransform` is an immutable struct, so its `select` field cannot be rebound;
# instead `dims` is broadcast into the existing container. This therefore requires
# `t.select` to be a mutable array of compatible shape. Returns the updated container.
function set!(t::SelectTransform, dims)
    return t.select .= dims
end
# Build a copy of `t` from parameters `θ`. The argument `θ` is ignored and the
# very same transform object is returned.
function duplicate(t::SelectTransform, θ)
    return t
end
# Apply the transform to a single input vector: take a non-copying view of the
# selected entries, collapsing a zero-dimensional view (scalar index) to its value.
function (t::SelectTransform)(x::AbstractVector)
    picked = view(x, t.select)
    return _maybe_unwrap(picked)
end
# Fallback: anything that is not a zero-dimensional array is returned as-is.
function _maybe_unwrap(x)
    return x
end

# A zero-dimensional array (e.g. `view(v, i)` with an integer `i`) holds exactly one
# element; return that element instead of the wrapper.
function _maybe_unwrap(x::AbstractArray{<:Any,0})
    return x[]
end
# Apply the transform to every input of a `ColVecs` collection at once: the
# selection acts on the first axis of the underlying matrix `x.X`.
function Base.map(t::SelectTransform, x::ColVecs)
    kept_rows = view(x.X, t.select, :)
    return _wrap(kept_rows, ColVecs)
end

# Apply the transform to every input of a `RowVecs` collection at once: the
# selection acts on the second axis of the underlying matrix `x.X`.
function Base.map(t::SelectTransform, x::RowVecs)
    kept_cols = view(x.X, :, t.select)
    return _wrap(kept_cols, RowVecs)
end
# Rewrap the result of selecting from a `ColVecs`/`RowVecs` matrix.
#
# - A vector arises when a single (integer) index was selected. It already holds one
#   value per input and is returned unchanged; the second argument is ignored.
# - A matrix is rewrapped with the constructor `T` (e.g. `ColVecs` or `RowVecs`).
#
# The element type is intentionally unconstrained. Nothing here depends on it, and
# restricting it to `Real` made `map` fail with a `MethodError` for e.g. complex or
# `Union{Missing,Float64}` inputs.
_wrap(x::AbstractVector, ::Any) = x
_wrap(X::AbstractMatrix, ::Type{T}) where {T} = T(X)
# One-line display, e.g. `Select Transform (dims: [1, 3, 5])`. Every piece is printed
# directly to `io`, so IOContext settings on `io` apply to the printed indices.
function Base.show(io::IO, t::SelectTransform)
    print(io, "Select Transform (dims: ")
    print(io, t.select)
    return print(io, ")")
end