Skip to content

Commit a461737

Browse files
authored
SelectTransform with single index (#181)
* Fix up select implementation * Bump patch * Uncomment tests * Tidy up a bit * Test vector-of-vectors * Fix up evaluation
1 parent 3d3a6c7 commit a461737

File tree

3 files changed

+21
-4
lines changed

3 files changed

+21
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "KernelFunctions"
22
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
3-
version = "0.8.3"
3+
version = "0.8.4"
44

55
[deps]
66
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"

src/transform/selecttransform.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,15 @@ set!(t::SelectTransform, dims) = t.select .= dims
1717

1818
duplicate(t::SelectTransform,θ) = t
1919

20-
(t::SelectTransform)(x::AbstractVector) = view(x, t.select)
20+
(t::SelectTransform)(x::AbstractVector) = _maybe_unwrap(view(x, t.select))
2121

22-
_map(t::SelectTransform, x::ColVecs) = ColVecs(view(x.X, t.select, :))
23-
_map(t::SelectTransform, x::RowVecs) = RowVecs(view(x.X, :, t.select))
22+
_maybe_unwrap(x) = x
23+
_maybe_unwrap(x::AbstractArray{<:Any, 0}) = x[]
24+
25+
_map(t::SelectTransform, x::ColVecs) = _wrap(view(x.X, t.select, :), ColVecs)
26+
_map(t::SelectTransform, x::RowVecs) = _wrap(view(x.X, :, t.select), RowVecs)
27+
28+
_wrap(x::AbstractVector{<:Real}, ::Any) = x
29+
_wrap(X::AbstractMatrix{<:Real}, ::Type{T}) where {T} = T(X)
2430

2531
Base.show(io::IO, t::SelectTransform) = print(io, "Select Transform (dims: ", t.select, ")")

test/transform/selecttransform.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,4 +109,15 @@
109109
testfunction(ta_col, a, C, 1)
110110
end
111111
end
112+
113+
@testset "single-index" begin
114+
t = SelectTransform(4)
115+
@testset "$(name)" for (name, x) in [
116+
("Vector{<:Vector}", [randn(6) for _ in 1:3]),
117+
("ColVecs", ColVecs(randn(5, 10))),
118+
("RowVecs", RowVecs(randn(11, 4))),
119+
]
120+
@test KernelFunctions._map(t, x) isa AbstractVector{Float64}
121+
end
122+
end
112123
end

0 commit comments

Comments
 (0)