Skip to content

Commit 8af9ac5

Browse files
committed
Added select transform
1 parent 1aacddf commit 8af9ac5

File tree

2 files changed

+39
-1
lines changed

2 files changed

+39
-1
lines changed

src/transform/selecttransform.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
"""
2+
SelectTransform
3+
```
4+
dims = [1,3,5,6,7]
5+
tr = SelectTransform(dims)
6+
X = rand(100,10)
7+
transform(tr,X,obsdim=2) == X[dims,:]
8+
```
9+
Select the dimensions `dims` that the kernel is applied to.
10+
"""
11+
struct SelectTransform{T<:AbstractVector{<:Int}} <: Transform
12+
select::T
13+
dim_max::Int
14+
end
15+
16+
function SelectTransform(dims::AbstractVector{T}) where {T<:Int}
17+
@assert all(dims.>0) "Selective dimensions should all be positive integers"
18+
SelectTransform{T}(dims,maximum(dims))
19+
end
20+
21+
function set!(t::SelectTransform{<:AbstractVector{T}},s::AbstractVector{T}) where {T<:Real}
22+
t.proj .= s
23+
end
24+
25+
Base.maximum(t::SelectTransform) = maximum(t.select)
26+
27+
function transform(t::SelectTransform,X::AbstractMatrix{<:Real},obsdim::Int=defaultobs)
28+
@boundscheck t.dim_max <= size(X,feature_dim(obsdim)) ?
29+
throw(DimensionMismatch("The highest index $(t.dim_max) is higher then the feature dimension of X : $(size(X,feature_dim(obsdim)))")) : nothing
30+
@inbounds _transform(t,X,obsdim)
31+
end
32+
33+
function transform(t::SelectTransform,x::AbstractVector{<:Real},obsdim::Int=defaultobs) #TODO Add test
34+
@assert t.dim_max <= length(x) "The highest index $(t.dim_max) is higher then the vector length : $(length(x))"
35+
return x[t.select]
36+
end
37+
38+
_transform(t::SelectTransform,X::AbstractMatrix{<:Real},obsdim::Int=defaultobs) = obsdim == 2 ? X[t.select,:] : X[:,t.select]

src/transform/transform.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ transform
1919
include("scaletransform.jl")
2020
include("lowranktransform.jl")
2121
include("functiontransform.jl")
22-
22+
include("selecttransform.jl")
2323

2424
"""
2525
Chain a series of transform, here `t1` will be called first

0 commit comments

Comments
 (0)