Skip to content

Commit d168a76

Browse files
willtebbuttgithub-actions[bot]theogf
authored
AD Performance (#467)
* Tuples rather than vectors * Testing * Bump patch * Formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Docstring line lengths * Data generation * Add StableRNGs to test Project * Add ad tests to test utils * Test constant kernels ad perf * Improve ad perf testing * Fix rational performance issues * Improve FBM AD performance * Fix FBM AD performance * Test matern kernels peformance * Optimise NeuralNetworkKernel * Enable all nn tests * linear and polynomial performance * Uncomment nn test * Add performance tests to basekernels * Uncomment tests * Finish up merge conflict resolution * Test non-base-kernels * Check transform performance * Update runtests * 1.6-compatible test_broken * Failing tests on LTS * Formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Fix formatting * Formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Improve docstrings * Apply suggestions from code review Co-authored-by: Théo Galy-Fajou <[email protected]> * Remove redundant only calls * Tidy up docstring * __example_inputs -> example_inputs * Eltype in ones * Sort ard promotion * Bump patch * Fix typo * Factor out _to_colvecs * Use struct rather than closue * Test all rationalquadratickernel methods * Formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Uncomment tests Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Théo Galy-Fajou <[email protected]>
1 parent 1831cc6 commit d168a76

37 files changed

+476
-49
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "KernelFunctions"
22
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
3-
version = "0.10.42"
3+
version = "0.10.43"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/TestUtils.jl

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,11 +149,13 @@ end
149149
"""
150150
test_interface([rng::AbstractRNG], k::Kernel, ::Type{T}; kwargs...) where {T<:Real}
151151
152-
Run the [`test_interface`](@ref) tests for randomly generated inputs of types `Vector{T}`, `Vector{Vector{T}}`, `ColVecs{T}`, and `RowVecs{T}`.
152+
Run the [`test_interface`](@ref) tests for randomly generated inputs of types `Vector{T}`,
153+
`Vector{Vector{T}}`, `ColVecs{T}`, and `RowVecs{T}`.
153154
154155
For other input types, please provide the data manually.
155156
156-
The keyword arguments are forwarded to the invocations of [`test_interface`](@ref) with the randomly generated inputs.
157+
The keyword arguments are forwarded to the invocations of [`test_interface`](@ref) with the
158+
randomly generated inputs.
157159
"""
158160
function test_interface(rng::AbstractRNG, k::Kernel, ::Type{T}; kwargs...) where {T<:Real}
159161
@testset "Vector{$T}" begin
@@ -174,4 +176,27 @@ function test_interface(k::Kernel, T::Type{<:Real}=Float64; kwargs...)
174176
return test_interface(Random.GLOBAL_RNG, k, T; kwargs...)
175177
end
176178

179+
"""
180+
example_inputs(rng::AbstractRNG, type)
181+
182+
Return a tuple of 4 inputs of type `type`. See `methods(example_inputs)` for information
183+
around supported types. It is recommended that you utilise `StableRNGs.jl` for `rng` here
184+
to ensure consistency across Julia versions.
185+
"""
186+
function example_inputs(rng::AbstractRNG, ::Type{Vector{Float64}})
187+
return map(n -> randn(rng, Float64, n), (1, 2, 3, 4))
188+
end
189+
190+
function example_inputs(
191+
rng::AbstractRNG, ::Type{ColVecs{Float64,Matrix{Float64}}}; dim::Int=2
192+
)
193+
return map(n -> ColVecs(randn(rng, dim, n)), (1, 2, 3, 4))
194+
end
195+
196+
function example_inputs(
197+
rng::AbstractRNG, ::Type{RowVecs{Float64,Matrix{Float64}}}; dim::Int=2
198+
)
199+
return map(n -> RowVecs(randn(rng, n, dim)), (1, 2, 3, 4))
200+
end
201+
177202
end # module

src/basekernels/fbm.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,9 @@ _mod(x::RowVecs) = vec(sum(abs2, x.X; dims=2))
5050

5151
function kernelmatrix(κ::FBMKernel, x::AbstractVector)
5252
modx = _mod(x)
53+
modx_wide = modx * ones(eltype(modx), 1, length(modx)) # ad perf hack -- is unit tested
5354
modxx = pairwise(SqEuclidean(), x)
54-
return _fbm.(modx, modx', modxx, κ.h)
55+
return _fbm.(modx_wide, modx_wide', modxx, only(κ.h))
5556
end
5657

5758
function kernelmatrix!(K::AbstractMatrix, κ::FBMKernel, x::AbstractVector)
@@ -63,7 +64,9 @@ end
6364

6465
function kernelmatrix(κ::FBMKernel, x::AbstractVector, y::AbstractVector)
6566
modxy = pairwise(SqEuclidean(), x, y)
66-
return _fbm.(_mod(x), _mod(y)', modxy, κ.h)
67+
modx_wide = _mod(x) * ones(eltype(modxy), 1, length(y)) # ad perf hack -- is unit tested
68+
mody_wide = _mod(y) * ones(eltype(modxy), 1, length(x)) # ad perf hack -- is unit tested
69+
return _fbm.(modx_wide, mody_wide', modxy, only(κ.h))
6770
end
6871

6972
function kernelmatrix!(
@@ -77,10 +80,10 @@ end
7780
function kernelmatrix_diag(κ::FBMKernel, x::AbstractVector)
7881
modx = _mod(x)
7982
modxx = colwise(SqEuclidean(), x)
80-
return _fbm.(modx, modx, modxx, κ.h)
83+
return _fbm.(modx, modx, modxx, only(κ.h))
8184
end
8285

8386
function kernelmatrix_diag(κ::FBMKernel, x::AbstractVector, y::AbstractVector)
8487
modxy = colwise(SqEuclidean(), x, y)
85-
return _fbm.(_mod(x), _mod(y), modxy, κ.h)
88+
return _fbm.(_mod(x), _mod(y), modxy, only(κ.h))
8689
end

src/basekernels/nn.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,26 @@ function (κ::NeuralNetworkKernel)(x, y)
3737
return asin(dot(x, y) / sqrt((1 + sum(abs2, x)) * (1 + sum(abs2, y))))
3838
end
3939

40+
function kernelmatrix(
41+
k::NeuralNetworkKernel, x::AbstractVector{<:Real}, y::AbstractVector{<:Real}
42+
)
43+
return kernelmatrix(k, _to_colvecs(x), _to_colvecs(y))
44+
end
45+
46+
function kernelmatrix(k::NeuralNetworkKernel, x::AbstractVector{<:Real})
47+
return kernelmatrix(k, _to_colvecs(x))
48+
end
49+
50+
function kernelmatrix_diag(
51+
k::NeuralNetworkKernel, x::AbstractVector{<:Real}, y::AbstractVector{<:Real}
52+
)
53+
return kernelmatrix_diag(k, _to_colvecs(x), _to_colvecs(y))
54+
end
55+
56+
function kernelmatrix_diag(k::NeuralNetworkKernel, x::AbstractVector{<:Real})
57+
return kernelmatrix_diag(k, _to_colvecs(x))
58+
end
59+
4060
function kernelmatrix(::NeuralNetworkKernel, x::ColVecs, y::ColVecs)
4161
validate_inputs(x, y)
4262
X_2 = sum(x.X .* x.X; dims=1)

src/basekernels/polynomial.jl

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,28 @@ LinearKernel(; c::Real=0.0) = LinearKernel(c)
2626

2727
@functor LinearKernel
2828

29-
kappa(κ::LinearKernel, xᵀy::Real) = xᵀy + only(κ.c)
29+
__linear_kappa(c::Real, xᵀy::Real) = xᵀy + c
30+
31+
kappa(κ::LinearKernel, xᵀy::Real) = __linear_kappa(only(κ.c), xᵀy)
3032

3133
metric(::LinearKernel) = DotProduct()
3234

35+
function kernelmatrix(k::LinearKernel, x::AbstractVector, y::AbstractVector)
36+
return __linear_kappa.(only(k.c), pairwise(metric(k), x, y))
37+
end
38+
39+
function kernelmatrix(k::LinearKernel, x::AbstractVector)
40+
return __linear_kappa.(only(k.c), pairwise(metric(k), x))
41+
end
42+
43+
function kernelmatrix_diag(k::LinearKernel, x::AbstractVector, y::AbstractVector)
44+
return __linear_kappa.(only(k.c), colwise(metric(k), x, y))
45+
end
46+
47+
function kernelmatrix_diag(k::LinearKernel, x::AbstractVector)
48+
return __linear_kappa.(only(k.c), colwise(metric(k), x))
49+
end
50+
3351
Base.show(io::IO, κ::LinearKernel) = print(io, "Linear Kernel (c = ", only(κ.c), ")")
3452

3553
"""
@@ -68,10 +86,32 @@ function Functors.functor(::Type{<:PolynomialKernel}, x)
6886
return (c=x.c,), reconstruct_polynomialkernel
6987
end
7088

71-
kappa(κ::PolynomialKernel, xᵀy::Real) = (xᵀy + only(κ.c))^κ.degree
89+
struct _PolynomialKappa
90+
degree::Int
91+
end
92+
93+
(κ::_PolynomialKappa)(c::Real, xᵀy::Real) = (xᵀy + c)^κ.degree
94+
95+
kappa(κ::PolynomialKernel, xᵀy::Real) = _PolynomialKappa(κ.degree)(only(κ.c), xᵀy)
7296

7397
metric(::PolynomialKernel) = DotProduct()
7498

99+
function kernelmatrix(k::PolynomialKernel, x::AbstractVector, y::AbstractVector)
100+
return _PolynomialKappa(k.degree).(only(k.c), pairwise(metric(k), x, y))
101+
end
102+
103+
function kernelmatrix(k::PolynomialKernel, x::AbstractVector)
104+
return _PolynomialKappa(k.degree).(only(k.c), pairwise(metric(k), x))
105+
end
106+
107+
function kernelmatrix_diag(k::PolynomialKernel, x::AbstractVector, y::AbstractVector)
108+
return _PolynomialKappa(k.degree).(only(k.c), colwise(metric(k), x, y))
109+
end
110+
111+
function kernelmatrix_diag(k::PolynomialKernel, x::AbstractVector)
112+
return _PolynomialKappa(k.degree).(only(k.c), colwise(metric(k), x))
113+
end
114+
75115
function Base.show(io::IO, κ::PolynomialKernel)
76116
return print(io, "Polynomial Kernel (c = ", only(κ.c), ", degree = ", κ.degree, ")")
77117
end

src/basekernels/rational.jl

Lines changed: 93 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,32 @@ end
3131

3232
@functor RationalKernel
3333

34-
function kappa(κ::RationalKernel, d::Real)
35-
return (one(d) + d / only(κ.α))^(-only(κ.α))
36-
end
34+
__rational_kappa(α::Real, d::Real) = (one(d) + d / α)^(-α)
35+
36+
kappa(κ::RationalKernel, d::Real) = __rational_kappa(only(κ.α), d)
3737

3838
metric(k::RationalKernel) = k.metric
3939

40+
# AD-performance optimisation. Is unit tested.
41+
function kernelmatrix(k::RationalKernel, x::AbstractVector, y::AbstractVector)
42+
return __rational_kappa.(only(k.α), pairwise(metric(k), x, y))
43+
end
44+
45+
# AD-performance optimisation. Is unit tested.
46+
function kernelmatrix(k::RationalKernel, x::AbstractVector)
47+
return __rational_kappa.(only(k.α), pairwise(metric(k), x))
48+
end
49+
50+
# AD-performance optimisation. Is unit tested.
51+
function kernelmatrix_diag(k::RationalKernel, x::AbstractVector, y::AbstractVector)
52+
return __rational_kappa.(only(k.α), colwise(metric(k), x, y))
53+
end
54+
55+
# AD-performance optimisation. Is unit tested.
56+
function kernelmatrix_diag(k::RationalKernel, x::AbstractVector)
57+
return __rational_kappa.(only(k.α), colwise(metric(k), x))
58+
end
59+
4060
function Base.show(io::IO, κ::RationalKernel)
4161
return print(io, "Rational Kernel (α = ", only(κ.α), ", metric = ", κ.metric, ")")
4262
end
@@ -69,18 +89,59 @@ struct RationalQuadraticKernel{Tα<:Real,M} <: SimpleKernel
6989
end
7090
end
7191

92+
const _RQ_Euclidean = RationalQuadraticKernel{<:Real,<:Euclidean}
93+
7294
@functor RationalQuadraticKernel
7395

74-
function kappa(κ::RationalQuadraticKernel, d::Real)
75-
return (one(d) + d^2 / (2 * only(κ.α)))^(-only(κ.α))
76-
end
77-
function kappa(κ::RationalQuadraticKernel{<:Real,<:Euclidean}, d²::Real)
78-
return (one(d²) + d² / (2 * only(κ.α)))^(-only(κ.α))
79-
end
96+
__rq_kappa(α::Real, d::Real) = (one(d) + d^2 / (2 * α))^(-α)
97+
__rq_kappa_euclidean(α::Real, d²::Real) = (one(d²) + d² / (2 * α))^(-α)
98+
99+
kappa(κ::RationalQuadraticKernel, d::Real) = __rq_kappa(only(κ.α), d)
100+
kappa(κ::_RQ_Euclidean, d²::Real) = __rq_kappa_euclidean(only(κ.α), d²)
80101

81102
metric(k::RationalQuadraticKernel) = k.metric
82103
metric(::RationalQuadraticKernel{<:Real,<:Euclidean}) = SqEuclidean()
83104

105+
# AD-performance optimisation. Is unit tested.
106+
function kernelmatrix(k::RationalQuadraticKernel, x::AbstractVector, y::AbstractVector)
107+
return __rq_kappa.(only(k.α), pairwise(metric(k), x, y))
108+
end
109+
110+
# AD-performance optimisation. Is unit tested.
111+
function kernelmatrix(k::RationalQuadraticKernel, x::AbstractVector)
112+
return __rq_kappa.(only(k.α), pairwise(metric(k), x))
113+
end
114+
115+
# AD-performance optimisation. Is unit tested.
116+
function kernelmatrix_diag(k::RationalQuadraticKernel, x::AbstractVector, y::AbstractVector)
117+
return __rq_kappa.(only(k.α), colwise(metric(k), x, y))
118+
end
119+
120+
# AD-performance optimisation. Is unit tested.
121+
function kernelmatrix_diag(k::RationalQuadraticKernel, x::AbstractVector)
122+
return __rq_kappa.(only(k.α), colwise(metric(k), x))
123+
end
124+
125+
# AD-performance optimisation. Is unit tested.
126+
function kernelmatrix(k::_RQ_Euclidean, x::AbstractVector, y::AbstractVector)
127+
return __rq_kappa_euclidean.(only(k.α), pairwise(SqEuclidean(), x, y))
128+
end
129+
130+
# AD-performance optimisation. Is unit tested.
131+
function kernelmatrix(k::_RQ_Euclidean, x::AbstractVector)
132+
return __rq_kappa_euclidean.(only(k.α), pairwise(SqEuclidean(), x))
133+
end
134+
135+
# AD-performance optimisation. Is unit tested.
136+
function kernelmatrix_diag(k::_RQ_Euclidean, x::AbstractVector, y::AbstractVector)
137+
return __rq_kappa_euclidean.(only(k.α), colwise(SqEuclidean(), x, y))
138+
end
139+
140+
# AD-performance optimisation. Is unit tested.
141+
function kernelmatrix_diag(k::_RQ_Euclidean, x::AbstractVector)
142+
return __rq_kappa_euclidean.(only(k.α), colwise(SqEuclidean(), x))
143+
end
144+
84145
function Base.show(io::IO, κ::RationalQuadraticKernel)
85146
return print(
86147
io, "Rational Quadratic Kernel (α = ", only(κ.α), ", metric = ", κ.metric, ")"
@@ -121,12 +182,32 @@ end
121182

122183
@functor GammaRationalKernel
123184

124-
function kappa(κ::GammaRationalKernel, d::Real)
125-
return (one(d) + d^only(κ.γ) / only(κ.α))^(-only(κ.α))
126-
end
185+
__grk_kappa(α::Real, γ::Real, d::Real) = (one(d) + d^γ / α)^(-α)
186+
187+
kappa(κ::GammaRationalKernel, d::Real) = __grk_kappa(only(κ.α), only(κ.γ), d)
127188

128189
metric(k::GammaRationalKernel) = k.metric
129190

191+
# AD-performance optimisation. Is unit tested.
192+
function kernelmatrix(k::GammaRationalKernel, x::AbstractVector, y::AbstractVector)
193+
return __grk_kappa.(only(k.α), only(k.γ), pairwise(metric(k), x, y))
194+
end
195+
196+
# AD-performance optimisation. Is unit tested.
197+
function kernelmatrix(k::GammaRationalKernel, x::AbstractVector)
198+
return __grk_kappa.(only(k.α), only(k.γ), pairwise(metric(k), x))
199+
end
200+
201+
# AD-performance optimisation. Is unit tested.
202+
function kernelmatrix_diag(k::GammaRationalKernel, x::AbstractVector, y::AbstractVector)
203+
return __grk_kappa.(only(k.α), only(k.γ), colwise(metric(k), x, y))
204+
end
205+
206+
# AD-performance optimisation. Is unit tested.
207+
function kernelmatrix_diag(k::GammaRationalKernel, x::AbstractVector)
208+
return __grk_kappa.(only(k.α), only(k.γ), colwise(metric(k), x))
209+
end
210+
130211
function Base.show(io::IO, κ::GammaRationalKernel)
131212
return print(
132213
io,

src/kernels/normalizedkernel.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,17 @@ end
2020
(κ::NormalizedKernel)(x, y) = κ.kernel(x, y) / sqrt(κ.kernel(x, x) * κ.kernel(y, y))
2121

2222
function kernelmatrix(κ::NormalizedKernel, x::AbstractVector, y::AbstractVector)
23-
return kernelmatrix(κ.kernel, x, y) ./
24-
sqrt.(
25-
kernelmatrix_diag(κ.kernel, x) .* permutedims(kernelmatrix_diag(κ.kernel, y))
26-
)
23+
x_diag = kernelmatrix_diag(κ.kernel, x)
24+
x_diag_wide = x_diag * ones(eltype(x_diag), 1, length(y)) # ad perf hack. Is unit tested
25+
y_diag = kernelmatrix_diag(κ.kernel, y)
26+
y_diag_wide = y_diag * ones(eltype(y_diag), 1, length(x)) # ad perf hack. Is unit tested
27+
return kernelmatrix(κ.kernel, x, y) ./ sqrt.(x_diag_wide .* y_diag_wide')
2728
end
2829

2930
function kernelmatrix(κ::NormalizedKernel, x::AbstractVector)
3031
x_diag = kernelmatrix_diag(κ.kernel, x)
31-
return kernelmatrix(κ.kernel, x) ./ sqrt.(x_diag .* permutedims(x_diag))
32+
x_diag_wide = x_diag * ones(eltype(x_diag), 1, length(x)) # ad perf hack. Is unit tested
33+
return kernelmatrix(κ.kernel, x) ./ sqrt.(x_diag_wide .* x_diag_wide')
3234
end
3335

3436
function kernelmatrix_diag(κ::NormalizedKernel, x::AbstractVector)

src/kernels/scaledkernel.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,19 +26,19 @@ end
2626
(k::ScaledKernel)(x, y) = only(k.σ²) * k.kernel(x, y)
2727

2828
function kernelmatrix(κ::ScaledKernel, x::AbstractVector, y::AbstractVector)
29-
return κ.σ² .* kernelmatrix(κ.kernel, x, y)
29+
return only(κ.σ²) * kernelmatrix(κ.kernel, x, y)
3030
end
3131

3232
function kernelmatrix(κ::ScaledKernel, x::AbstractVector)
33-
return κ.σ² .* kernelmatrix(κ.kernel, x)
33+
return only(κ.σ²) * kernelmatrix(κ.kernel, x)
3434
end
3535

3636
function kernelmatrix_diag(κ::ScaledKernel, x::AbstractVector)
37-
return κ.σ² .* kernelmatrix_diag(κ.kernel, x)
37+
return only(κ.σ²) * kernelmatrix_diag(κ.kernel, x)
3838
end
3939

4040
function kernelmatrix_diag(κ::ScaledKernel, x::AbstractVector, y::AbstractVector)
41-
return κ.σ² .* kernelmatrix_diag(κ.kernel, x, y)
41+
return only(κ.σ²) * kernelmatrix_diag(κ.kernel, x, y)
4242
end
4343

4444
function kernelmatrix!(

src/transform/ardtransform.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,14 @@ dim(t::ARDTransform) = length(t.v)
3535
(t::ARDTransform)(x::Real) = only(t.v) * x
3636
(t::ARDTransform)(x) = t.v .* x
3737

38-
_map(t::ARDTransform, x::AbstractVector{<:Real}) = t.v' .* x
39-
_map(t::ARDTransform, x::ColVecs) = ColVecs(t.v .* x.X)
40-
_map(t::ARDTransform, x::RowVecs) = RowVecs(t.v' .* x.X)
38+
# Quite specific implementations required to pass correctness and performance tests.
39+
_map(t::ARDTransform, x::AbstractVector{<:Real}) = x * only(t.v)
40+
function _map(t::ARDTransform, x::ColVecs)
41+
return ColVecs((t.v * ones(eltype(t.v), 1, size(x.X, 2))) .* x.X)
42+
end
43+
function _map(t::ARDTransform, x::RowVecs)
44+
return RowVecs(x.X .* (ones(eltype(t.v), size(x.X, 1)) * collect(t.v')))
45+
end
4146

4247
Base.isequal(t::ARDTransform, t2::ARDTransform) = isequal(t.v, t2.v)
4348

src/utils.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,8 @@ Base.zero(x::ColVecs) = ColVecs(zero(x.X))
9797

9898
dim(x::ColVecs) = size(x.X, 1)
9999

100+
_to_colvecs(x::AbstractVector{<:Real}) = ColVecs(reshape(x, 1, :))
101+
100102
pairwise(d::PreMetric, x::ColVecs) = Distances_pairwise(d, x.X; dims=2)
101103
pairwise(d::PreMetric, x::ColVecs, y::ColVecs) = Distances_pairwise(d, x.X, y.X; dims=2)
102104
function pairwise(d::PreMetric, x::AbstractVector, y::ColVecs)

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
1313
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1414
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
1515
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
16+
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
1617
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1718
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1819
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

0 commit comments

Comments
 (0)