Skip to content

Commit b73051c

Browse files
author
Will Tebbutt
committed
Factor out _to_colvecs
1 parent bcb15c3 commit b73051c

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

src/basekernels/nn.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,21 +40,21 @@ end
4040
function kernelmatrix(
4141
k::NeuralNetworkKernel, x::AbstractVector{<:Real}, y::AbstractVector{<:Real}
4242
)
43-
return kernelmatrix(k, ColVecs(reshape(x, 1, :)), ColVecs(reshape(y, 1, :)))
43+
return kernelmatrix(k, _to_colvecs(x), _to_colvecs(y))
4444
end
4545

4646
function kernelmatrix(k::NeuralNetworkKernel, x::AbstractVector{<:Real})
47-
return kernelmatrix(k, ColVecs(reshape(x, 1, :)))
47+
return kernelmatrix(k, _to_colvecs(x))
4848
end
4949

5050
function kernelmatrix_diag(
5151
k::NeuralNetworkKernel, x::AbstractVector{<:Real}, y::AbstractVector{<:Real}
5252
)
53-
return kernelmatrix_diag(k, ColVecs(reshape(x, 1, :)), ColVecs(reshape(y, 1, :)))
53+
return kernelmatrix_diag(k, _to_colvecs(x), _to_colvecs(y))
5454
end
5555

5656
function kernelmatrix_diag(k::NeuralNetworkKernel, x::AbstractVector{<:Real})
57-
return kernelmatrix_diag(k, ColVecs(reshape(x, 1, :)))
57+
return kernelmatrix_diag(k, _to_colvecs(x))
5858
end
5959

6060
function kernelmatrix(::NeuralNetworkKernel, x::ColVecs, y::ColVecs)

src/utils.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,8 @@ Base.zero(x::ColVecs) = ColVecs(zero(x.X))
9797

9898
dim(x::ColVecs) = size(x.X, 1)
9999

100+
_to_colvecs(x::AbstractVector{<:Real}) = ColVecs(reshape(x, 1, :))
101+
100102
pairwise(d::PreMetric, x::ColVecs) = Distances_pairwise(d, x.X; dims=2)
101103
pairwise(d::PreMetric, x::ColVecs, y::ColVecs) = Distances_pairwise(d, x.X, y.X; dims=2)
102104
function pairwise(d::PreMetric, x::AbstractVector, y::ColVecs)

0 commit comments

Comments
 (0)