Skip to content

Commit b355811

Browse files
authored
Added setindex! for ColVecs and RowVecs with tests (#196)
1 parent bc5760c commit b355811

File tree

3 files changed

+14
-13
lines changed

3 files changed

+14
-13
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "KernelFunctions"
22
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
3-
version = "0.8.7"
3+
version = "0.8.8"
44

55
[deps]
66
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"

src/utils.jl

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ Base.size(D::ColVecs) = (size(D.X, 2),)
3838
Base.getindex(D::ColVecs, i::Int) = view(D.X, :, i)
3939
Base.getindex(D::ColVecs, i::CartesianIndex{1}) = view(D.X, :, i)
4040
Base.getindex(D::ColVecs, i) = ColVecs(view(D.X, :, i))
41+
Base.setindex!(D::ColVecs, v::AbstractVector, i) = setindex!(D.X, v, :, i)
4142

4243
dim(x::ColVecs) = size(x.X, 1)
4344

@@ -76,6 +77,7 @@ Base.size(D::RowVecs) = (size(D.X, 1),)
7677
Base.getindex(D::RowVecs, i::Int) = view(D.X, i, :)
7778
Base.getindex(D::RowVecs, i::CartesianIndex{1}) = view(D.X, i, :)
7879
Base.getindex(D::RowVecs, i) = RowVecs(view(D.X, i, :))
80+
Base.setindex!(D::RowVecs, v::AbstractVector, i) = setindex!(D.X, v, i, :)
7981

8082
dim(x::RowVecs) = size(x.X, 2)
8183

@@ -94,17 +96,6 @@ function pairwise!(out::AbstractMatrix, d::PreMetric, x::RowVecs, y::RowVecs)
9496
return Distances.pairwise!(out, d, x.X, y.X; dims=1)
9597
end
9698

97-
"""
98-
Will be implemented at some point
99-
```julia
100-
params(k::Kernel)
101-
params(t::Transform)
102-
```
103-
For a kernel return a tuple with parameters of the transform followed by the specific parameters of the kernel
104-
For a transform return its parameters, for a `ChainTransform` return a vector of `params(t)`.
105-
"""
106-
#params
107-
10899
dim(x) = 0 # This is the passes-by-default choice. For a proper check, implement `KernelFunctions.dim` for your datatype.
109100
dim(x::AbstractVector) = dim(first(x))
110101
dim(x::AbstractVector{<:AbstractVector{<:Real}}) = length(first(x))

test/utils.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
@testset "utils" begin
22
using KernelFunctions: vec_of_vecs, ColVecs, RowVecs
33
rng, N, D = MersenneTwister(123456), 10, 4
4-
x, X = randn(rng, N), randn(rng, D, N)
4+
x = randn(rng, N)
5+
X = randn(rng, D, N)
6+
v = randn(rng, D)
7+
w = randn(rng, N)
8+
59
@testset "VecOfVecs" begin
610
@test vec_of_vecs(X, obsdim = 2) == ColVecs(X)
711
@test vec_of_vecs(X, obsdim = 1) == RowVecs(X)
@@ -19,6 +23,9 @@
1923
@test getindex(DX, :) == ColVecs(X)
2024
@test eachindex(DX) == 1:N
2125
@test first(DX) == X[:, 1]
26+
DX[2] = v
27+
@test DX[2] == v
28+
@test X[:, 2] == v
2229

2330
Y = randn(rng, D, N + 1)
2431
DY = ColVecs(Y)
@@ -53,6 +60,9 @@
5360
@test getindex(DX, :) == RowVecs(X)
5461
@test eachindex(DX) == 1:D
5562
@test first(DX) == X[1, :]
63+
DX[2] = w
64+
@test DX[2] == w
65+
@test X[2, :] == w
5666

5767
Y = randn(rng, D + 1, N)
5868
DY = RowVecs(Y)

0 commit comments

Comments
 (0)