Skip to content

AD Performance #467

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 46 commits into from
Aug 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
6d820ac
Tuples rather than vectors
Aug 18, 2022
f6b66ca
Testing
Aug 19, 2022
dfd1846
Bump patch
Aug 19, 2022
006183e
Formatting
willtebbutt Aug 19, 2022
6ca9e98
Docstring line lengths
Aug 22, 2022
4df98a5
Data generation
Aug 22, 2022
76a88d1
Add StableRNGs to test Project
Aug 22, 2022
6405a50
Add ad tests to test utils
Aug 22, 2022
7f65a49
Test constant kernels ad perf
Aug 22, 2022
2498bd8
Improve ad perf testing
Aug 22, 2022
1c8d105
Fix rational performance issues
Aug 22, 2022
42725db
Improve FBM AD performance
Aug 22, 2022
c2c72d1
Fix FBM AD performance
Aug 22, 2022
26eb90f
Test matern kernels performance
Aug 22, 2022
f57eb2f
Optimise NeuralNetworkKernel
Aug 22, 2022
c84eed8
Enable all nn tests
Aug 22, 2022
ebee280
linear and polynomial performance
Aug 22, 2022
49c3fa7
Uncomment nn test
Aug 22, 2022
81bc9e6
Add performance tests to basekernels
Aug 22, 2022
ee92219
Merge in master
Aug 22, 2022
617be1c
Uncomment tests
Aug 22, 2022
8af78a9
Finish up merge conflict resolution
Aug 22, 2022
0046b85
Test non-base-kernels
Aug 22, 2022
c8c3f4e
Check transform performance
Aug 22, 2022
4e41b01
Update runtests
Aug 22, 2022
fb3821b
1.6-compatible test_broken
Aug 22, 2022
cc50510
Failing tests on LTS
Aug 23, 2022
de709eb
Formatting
willtebbutt Aug 23, 2022
890ba92
Fix formatting
Aug 23, 2022
ea3d774
Formatting
willtebbutt Aug 23, 2022
9cd3880
Improve docstrings
Aug 23, 2022
5347aa3
Merge branch 'wct/ad-perf' of https://github.com/JuliaGaussianProcesses/KernelFunctions.jl
Aug 23, 2022
b6aec47
Apply suggestions from code review
willtebbutt Aug 23, 2022
74df981
Remove redundant only calls
Aug 23, 2022
dbc24c2
Tidy up docstring
Aug 23, 2022
79fb7fa
__example_inputs -> example_inputs
Aug 23, 2022
90db327
Eltype in ones
Aug 23, 2022
7b5ed26
Sort ard promotion
Aug 23, 2022
db91a77
Bump patch
Aug 23, 2022
bcb15c3
Fix typo
Aug 23, 2022
b73051c
Factor out _to_colvecs
Aug 25, 2022
f009284
Use struct rather than closure
Aug 25, 2022
956b581
Test all rationalquadratickernel methods
Aug 25, 2022
aaf70e6
Formatting
willtebbutt Aug 25, 2022
be582d5
Uncomment tests
Aug 25, 2022
2e560ff
Merge branch 'wct/ad-perf' of https://github.com/JuliaGaussianProcesses/KernelFunctions.jl
Aug 25, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "KernelFunctions"
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
version = "0.10.42"
version = "0.10.43"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
29 changes: 27 additions & 2 deletions src/TestUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,11 +149,13 @@ end
"""
test_interface([rng::AbstractRNG], k::Kernel, ::Type{T}; kwargs...) where {T<:Real}

Run the [`test_interface`](@ref) tests for randomly generated inputs of types `Vector{T}`, `Vector{Vector{T}}`, `ColVecs{T}`, and `RowVecs{T}`.
Run the [`test_interface`](@ref) tests for randomly generated inputs of types `Vector{T}`,
`Vector{Vector{T}}`, `ColVecs{T}`, and `RowVecs{T}`.

For other input types, please provide the data manually.

The keyword arguments are forwarded to the invocations of [`test_interface`](@ref) with the randomly generated inputs.
The keyword arguments are forwarded to the invocations of [`test_interface`](@ref) with the
randomly generated inputs.
"""
function test_interface(rng::AbstractRNG, k::Kernel, ::Type{T}; kwargs...) where {T<:Real}
@testset "Vector{$T}" begin
Expand All @@ -174,4 +176,27 @@ function test_interface(k::Kernel, T::Type{<:Real}=Float64; kwargs...)
return test_interface(Random.GLOBAL_RNG, k, T; kwargs...)
end

"""
    example_inputs(rng::AbstractRNG, type)

Return a `Tuple` of four inputs of type `type`. Consult `methods(example_inputs)` to see
which input types are supported. Using `StableRNGs.jl` for `rng` is recommended in order
to obtain consistent results across Julia versions.
"""
function example_inputs(rng::AbstractRNG, ::Type{Vector{Float64}})
    # Four vectors of lengths 1, 2, 3 and 4 respectively.
    return ntuple(n -> randn(rng, Float64, n), 4)
end

# Four `ColVecs` collections of `dim`-dimensional columns, containing 1–4 vectors each.
function example_inputs(
    rng::AbstractRNG, ::Type{ColVecs{Float64,Matrix{Float64}}}; dim::Int=2
)
    return ntuple(n -> ColVecs(randn(rng, dim, n)), 4)
end

# Four `RowVecs` collections of `dim`-dimensional rows, containing 1–4 vectors each.
function example_inputs(
    rng::AbstractRNG, ::Type{RowVecs{Float64,Matrix{Float64}}}; dim::Int=2
)
    return ntuple(n -> RowVecs(randn(rng, n, dim)), 4)
end

end # module
11 changes: 7 additions & 4 deletions src/basekernels/fbm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,9 @@ _mod(x::RowVecs) = vec(sum(abs2, x.X; dims=2))

# Full kernel matrix for the fractional-Brownian-motion kernel on a single input set.
# The scrape of this diff retained the superseded `return _fbm.(modx, modx', …, κ.h)`
# line above the final return, leaving the intended implementation unreachable; the
# stale line is removed here.
function kernelmatrix(κ::FBMKernel, x::AbstractVector)
    modx = _mod(x)
    # Widen `modx` to a matrix via an outer product instead of broadcasting a vector
    # against its adjoint — ad perf hack -- is unit tested.
    modx_wide = modx * ones(eltype(modx), 1, length(modx))
    modxx = pairwise(SqEuclidean(), x)
    # `only(κ.h)` passes the Hurst parameter as a scalar to the broadcast.
    return _fbm.(modx_wide, modx_wide', modxx, only(κ.h))
end

function kernelmatrix!(K::AbstractMatrix, κ::FBMKernel, x::AbstractVector)
Expand All @@ -63,7 +64,9 @@ end

# Cross kernel matrix for the FBM kernel. The scraped diff retained the superseded
# `return _fbm.(_mod(x), _mod(y)', modxy, κ.h)` line, which made the widened-matrix
# implementation below unreachable; the stale line is removed here.
function kernelmatrix(κ::FBMKernel, x::AbstractVector, y::AbstractVector)
    modxy = pairwise(SqEuclidean(), x, y)
    modx_wide = _mod(x) * ones(eltype(modxy), 1, length(y)) # ad perf hack -- is unit tested
    mody_wide = _mod(y) * ones(eltype(modxy), 1, length(x)) # ad perf hack -- is unit tested
    return _fbm.(modx_wide, mody_wide', modxy, only(κ.h))
end

function kernelmatrix!(
Expand All @@ -77,10 +80,10 @@ end
# Diagonal of the FBM kernel matrix. The scraped diff contained both the old
# `κ.h` and the new `only(κ.h)` return lines; only the latter is kept, consistent
# with the other FBM methods (the Hurst parameter is passed as a scalar).
function kernelmatrix_diag(κ::FBMKernel, x::AbstractVector)
    modx = _mod(x)
    modxx = colwise(SqEuclidean(), x)
    return _fbm.(modx, modx, modxx, only(κ.h))
end

# Elementwise (diagonal) cross evaluation of the FBM kernel. The scraped diff
# contained both the old `κ.h` and new `only(κ.h)` return lines; the stale first
# return (which would have shadowed the fix) is removed.
function kernelmatrix_diag(κ::FBMKernel, x::AbstractVector, y::AbstractVector)
    modxy = colwise(SqEuclidean(), x, y)
    return _fbm.(_mod(x), _mod(y), modxy, only(κ.h))
end
20 changes: 20 additions & 0 deletions src/basekernels/nn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,26 @@ function (κ::NeuralNetworkKernel)(x, y)
return asin(dot(x, y) / sqrt((1 + sum(abs2, x)) * (1 + sum(abs2, y))))
end

# Lift real-valued inputs to single-feature `ColVecs` and reuse that implementation.
function kernelmatrix(
    k::NeuralNetworkKernel, x::AbstractVector{<:Real}, y::AbstractVector{<:Real}
)
    x_cv = _to_colvecs(x)
    y_cv = _to_colvecs(y)
    return kernelmatrix(k, x_cv, y_cv)
end

# Lift real-valued inputs to single-feature `ColVecs` and reuse that implementation.
function kernelmatrix(k::NeuralNetworkKernel, x::AbstractVector{<:Real})
    x_cv = _to_colvecs(x)
    return kernelmatrix(k, x_cv)
end

# Lift real-valued inputs to single-feature `ColVecs` and reuse that implementation.
function kernelmatrix_diag(
    k::NeuralNetworkKernel, x::AbstractVector{<:Real}, y::AbstractVector{<:Real}
)
    x_cv = _to_colvecs(x)
    y_cv = _to_colvecs(y)
    return kernelmatrix_diag(k, x_cv, y_cv)
end

# Lift real-valued inputs to single-feature `ColVecs` and reuse that implementation.
function kernelmatrix_diag(k::NeuralNetworkKernel, x::AbstractVector{<:Real})
    x_cv = _to_colvecs(x)
    return kernelmatrix_diag(k, x_cv)
end

function kernelmatrix(::NeuralNetworkKernel, x::ColVecs, y::ColVecs)
validate_inputs(x, y)
X_2 = sum(x.X .* x.X; dims=1)
Expand Down
44 changes: 42 additions & 2 deletions src/basekernels/polynomial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,28 @@ LinearKernel(; c::Real=0.0) = LinearKernel(c)

@functor LinearKernel

# Elementwise map for `LinearKernel`: xᵀy ↦ xᵀy + c. A free function of the scalar `c`
# so it can be broadcast cheaply; NOTE(review): presumably an AD-performance measure —
# see the PR title. The scraped diff retained the superseded
# `kappa(κ::LinearKernel, xᵀy) = xᵀy + only(κ.c)` definition above this, a duplicate
# method; it is removed here.
__linear_kappa(c::Real, xᵀy::Real) = xᵀy + c

# `only` unwraps the length-1 parameter container `κ.c`.
kappa(κ::LinearKernel, xᵀy::Real) = __linear_kappa(only(κ.c), xᵀy)

# Base metric for `LinearKernel`: `kappa` is applied elementwise to dot products xᵀy.
metric(::LinearKernel) = DotProduct()

# Broadcast the scalar-offset map over all pairwise dot products.
function kernelmatrix(k::LinearKernel, x::AbstractVector, y::AbstractVector)
    c = only(k.c)
    return __linear_kappa.(c, pairwise(metric(k), x, y))
end

# Broadcast the scalar-offset map over the symmetric pairwise dot products.
function kernelmatrix(k::LinearKernel, x::AbstractVector)
    c = only(k.c)
    return __linear_kappa.(c, pairwise(metric(k), x))
end

# Elementwise (diagonal) evaluation via `colwise` dot products.
function kernelmatrix_diag(k::LinearKernel, x::AbstractVector, y::AbstractVector)
    c = only(k.c)
    return __linear_kappa.(c, colwise(metric(k), x, y))
end

# Diagonal of the symmetric kernel matrix via `colwise` dot products.
function kernelmatrix_diag(k::LinearKernel, x::AbstractVector)
    c = only(k.c)
    return __linear_kappa.(c, colwise(metric(k), x))
end

# Human-readable display, e.g. `Linear Kernel (c = 0.0)`; `only` unwraps the parameter.
Base.show(io::IO, κ::LinearKernel) = print(io, "Linear Kernel (c = ", only(κ.c), ")")

"""
Expand Down Expand Up @@ -68,10 +86,32 @@ function Functors.functor(::Type{<:PolynomialKernel}, x)
return (c=x.c,), reconstruct_polynomialkernel
end

# Callable struct capturing `degree`, used instead of an anonymous closure — per this
# PR's "Use struct rather than closure" change, apparently so the broadcast callable
# has a concrete type (AD performance). The scraped diff retained the superseded
# `kappa(κ::PolynomialKernel, xᵀy) = (xᵀy + only(κ.c))^κ.degree` definition above the
# struct, a duplicate method; it is removed here.
struct _PolynomialKappa
    degree::Int
end

# Elementwise map: xᵀy ↦ (xᵀy + c)^degree.
(κ::_PolynomialKappa)(c::Real, xᵀy::Real) = (xᵀy + c)^κ.degree

# `only` unwraps the length-1 parameter container `κ.c`.
kappa(κ::PolynomialKernel, xᵀy::Real) = _PolynomialKappa(κ.degree)(only(κ.c), xᵀy)

# Base metric for `PolynomialKernel`: `kappa` is applied elementwise to dot products xᵀy.
metric(::PolynomialKernel) = DotProduct()

# Broadcast the degree-specialised callable over all pairwise dot products.
function kernelmatrix(k::PolynomialKernel, x::AbstractVector, y::AbstractVector)
    κ = _PolynomialKappa(k.degree)
    return κ.(only(k.c), pairwise(metric(k), x, y))
end

# Broadcast the degree-specialised callable over the symmetric pairwise dot products.
function kernelmatrix(k::PolynomialKernel, x::AbstractVector)
    κ = _PolynomialKappa(k.degree)
    return κ.(only(k.c), pairwise(metric(k), x))
end

# Elementwise (diagonal) evaluation via `colwise` dot products.
function kernelmatrix_diag(k::PolynomialKernel, x::AbstractVector, y::AbstractVector)
    κ = _PolynomialKappa(k.degree)
    return κ.(only(k.c), colwise(metric(k), x, y))
end

# Diagonal of the symmetric kernel matrix via `colwise` dot products.
function kernelmatrix_diag(k::PolynomialKernel, x::AbstractVector)
    κ = _PolynomialKappa(k.degree)
    return κ.(only(k.c), colwise(metric(k), x))
end

# Display as `Polynomial Kernel (c = …, degree = …)`.
function Base.show(io::IO, κ::PolynomialKernel)
    return print(io, "Polynomial Kernel (c = $(only(κ.c)), degree = $(κ.degree))")
end
105 changes: 93 additions & 12 deletions src/basekernels/rational.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,32 @@ end

@functor RationalKernel

# Elementwise map for `RationalKernel`: d ↦ (1 + d/α)^(-α), a free function of the
# scalar α so it broadcasts cheaply. The scraped diff retained the superseded
# three-line `kappa(κ::RationalKernel, d)` definition above this, a duplicate
# method; it is removed here.
__rational_kappa(α::Real, d::Real) = (one(d) + d / α)^(-α)

# `only` unwraps the length-1 parameter container `κ.α`.
kappa(κ::RationalKernel, d::Real) = __rational_kappa(only(κ.α), d)

# The metric is stored on the kernel instance rather than fixed by the type.
metric(k::RationalKernel) = k.metric

# AD-performance optimisation; covered by unit tests.
function kernelmatrix(k::RationalKernel, x::AbstractVector, y::AbstractVector)
    α = only(k.α)
    return __rational_kappa.(α, pairwise(metric(k), x, y))
end

# AD-performance optimisation; covered by unit tests.
function kernelmatrix(k::RationalKernel, x::AbstractVector)
    α = only(k.α)
    return __rational_kappa.(α, pairwise(metric(k), x))
end

# AD-performance optimisation; covered by unit tests.
function kernelmatrix_diag(k::RationalKernel, x::AbstractVector, y::AbstractVector)
    α = only(k.α)
    return __rational_kappa.(α, colwise(metric(k), x, y))
end

# AD-performance optimisation; covered by unit tests.
function kernelmatrix_diag(k::RationalKernel, x::AbstractVector)
    α = only(k.α)
    return __rational_kappa.(α, colwise(metric(k), x))
end

# Display as `Rational Kernel (α = …, metric = …)`.
function Base.show(io::IO, κ::RationalKernel)
    return print(io, "Rational Kernel (α = $(only(κ.α)), metric = $(κ.metric))")
end
Expand Down Expand Up @@ -69,18 +89,59 @@ struct RationalQuadraticKernel{Tα<:Real,M} <: SimpleKernel
end
end

# Shorthand for rational-quadratic kernels with a Euclidean metric; these dispatch to
# the squared-Euclidean-distance implementations below.
const _RQ_Euclidean = RationalQuadraticKernel{<:Real,<:Euclidean}

@functor RationalQuadraticKernel

# Elementwise maps for the rational-quadratic kernel. The general form takes a distance
# d; the Euclidean specialisation is fed squared distances d² directly (its `metric`
# is `SqEuclidean`), avoiding a redundant sqrt/square round trip. The scraped diff
# retained both superseded `kappa` function definitions above these, duplicate
# methods; they are removed here.
__rq_kappa(α::Real, d::Real) = (one(d) + d^2 / (2 * α))^(-α)
__rq_kappa_euclidean(α::Real, d²::Real) = (one(d²) + d² / (2 * α))^(-α)

# `only` unwraps the length-1 parameter container `κ.α`.
kappa(κ::RationalQuadraticKernel, d::Real) = __rq_kappa(only(κ.α), d)
kappa(κ::_RQ_Euclidean, d²::Real) = __rq_kappa_euclidean(only(κ.α), d²)

# Generic case: the metric is stored on the kernel instance.
metric(k::RationalQuadraticKernel) = k.metric
# Euclidean case: work with squared distances and absorb the square into `kappa`.
metric(::RationalQuadraticKernel{<:Real,<:Euclidean}) = SqEuclidean()

# AD-performance optimisation; covered by unit tests.
function kernelmatrix(k::RationalQuadraticKernel, x::AbstractVector, y::AbstractVector)
    α = only(k.α)
    return __rq_kappa.(α, pairwise(metric(k), x, y))
end

# AD-performance optimisation; covered by unit tests.
function kernelmatrix(k::RationalQuadraticKernel, x::AbstractVector)
    α = only(k.α)
    return __rq_kappa.(α, pairwise(metric(k), x))
end

# AD-performance optimisation; covered by unit tests.
function kernelmatrix_diag(k::RationalQuadraticKernel, x::AbstractVector, y::AbstractVector)
    α = only(k.α)
    return __rq_kappa.(α, colwise(metric(k), x, y))
end

# AD-performance optimisation; covered by unit tests.
function kernelmatrix_diag(k::RationalQuadraticKernel, x::AbstractVector)
    α = only(k.α)
    return __rq_kappa.(α, colwise(metric(k), x))
end

# AD-performance optimisation; covered by unit tests.
function kernelmatrix(k::_RQ_Euclidean, x::AbstractVector, y::AbstractVector)
    α = only(k.α)
    return __rq_kappa_euclidean.(α, pairwise(SqEuclidean(), x, y))
end

# AD-performance optimisation; covered by unit tests.
function kernelmatrix(k::_RQ_Euclidean, x::AbstractVector)
    α = only(k.α)
    return __rq_kappa_euclidean.(α, pairwise(SqEuclidean(), x))
end

# AD-performance optimisation; covered by unit tests.
function kernelmatrix_diag(k::_RQ_Euclidean, x::AbstractVector, y::AbstractVector)
    α = only(k.α)
    return __rq_kappa_euclidean.(α, colwise(SqEuclidean(), x, y))
end

# AD-performance optimisation; covered by unit tests.
function kernelmatrix_diag(k::_RQ_Euclidean, x::AbstractVector)
    α = only(k.α)
    return __rq_kappa_euclidean.(α, colwise(SqEuclidean(), x))
end

function Base.show(io::IO, κ::RationalQuadraticKernel)
return print(
io, "Rational Quadratic Kernel (α = ", only(κ.α), ", metric = ", κ.metric, ")"
Expand Down Expand Up @@ -121,12 +182,32 @@ end

@functor GammaRationalKernel

# Elementwise map for `GammaRationalKernel`: d ↦ (1 + d^γ/α)^(-α), a free function of
# the scalars α and γ so it broadcasts cheaply. The scraped diff retained the
# superseded three-line `kappa(κ::GammaRationalKernel, d)` definition above this, a
# duplicate method; it is removed here.
__grk_kappa(α::Real, γ::Real, d::Real) = (one(d) + d^γ / α)^(-α)

# `only` unwraps the length-1 parameter containers `κ.α` and `κ.γ`.
kappa(κ::GammaRationalKernel, d::Real) = __grk_kappa(only(κ.α), only(κ.γ), d)

# The metric is stored on the kernel instance rather than fixed by the type.
metric(k::GammaRationalKernel) = k.metric

# AD-performance optimisation; covered by unit tests.
function kernelmatrix(k::GammaRationalKernel, x::AbstractVector, y::AbstractVector)
    α, γ = only(k.α), only(k.γ)
    return __grk_kappa.(α, γ, pairwise(metric(k), x, y))
end

# AD-performance optimisation; covered by unit tests.
function kernelmatrix(k::GammaRationalKernel, x::AbstractVector)
    α, γ = only(k.α), only(k.γ)
    return __grk_kappa.(α, γ, pairwise(metric(k), x))
end

# AD-performance optimisation; covered by unit tests.
function kernelmatrix_diag(k::GammaRationalKernel, x::AbstractVector, y::AbstractVector)
    α, γ = only(k.α), only(k.γ)
    return __grk_kappa.(α, γ, colwise(metric(k), x, y))
end

# AD-performance optimisation; covered by unit tests.
function kernelmatrix_diag(k::GammaRationalKernel, x::AbstractVector)
    α, γ = only(k.α), only(k.γ)
    return __grk_kappa.(α, γ, colwise(metric(k), x))
end

function Base.show(io::IO, κ::GammaRationalKernel)
return print(
io,
Expand Down
12 changes: 7 additions & 5 deletions src/kernels/normalizedkernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,17 @@ end
(κ::NormalizedKernel)(x, y) = κ.kernel(x, y) / sqrt(κ.kernel(x, x) * κ.kernel(y, y))

# Cross kernel matrix of the normalised kernel: k(x,y)/sqrt(k(x,x)k(y,y)). The scraped
# diff retained the superseded `permutedims`-based return above the widened-matrix
# implementation, making the latter unreachable; the stale return is removed here.
function kernelmatrix(κ::NormalizedKernel, x::AbstractVector, y::AbstractVector)
    x_diag = kernelmatrix_diag(κ.kernel, x)
    x_diag_wide = x_diag * ones(eltype(x_diag), 1, length(y)) # ad perf hack. Is unit tested
    y_diag = kernelmatrix_diag(κ.kernel, y)
    y_diag_wide = y_diag * ones(eltype(y_diag), 1, length(x)) # ad perf hack. Is unit tested
    return kernelmatrix(κ.kernel, x, y) ./ sqrt.(x_diag_wide .* y_diag_wide')
end

# Symmetric kernel matrix of the normalised kernel. The scraped diff retained the
# superseded `permutedims`-based return line, which would shadow the widened-matrix
# implementation below; the stale return is removed here.
function kernelmatrix(κ::NormalizedKernel, x::AbstractVector)
    x_diag = kernelmatrix_diag(κ.kernel, x)
    x_diag_wide = x_diag * ones(eltype(x_diag), 1, length(x)) # ad perf hack. Is unit tested
    return kernelmatrix(κ.kernel, x) ./ sqrt.(x_diag_wide .* x_diag_wide')
end

function kernelmatrix_diag(κ::NormalizedKernel, x::AbstractVector)
Expand Down
8 changes: 4 additions & 4 deletions src/kernels/scaledkernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,19 @@ end
(k::ScaledKernel)(x, y) = only(k.σ²) * k.kernel(x, y)

# Scale the base kernel matrix by the variance. `only(κ.σ²)` yields a plain scalar, so
# this is a scalar-matrix product rather than a broadcast against a length-1 vector.
# The scraped diff retained the superseded `κ.σ² .* …` return line above the fix,
# making the fix unreachable; the stale line is removed here.
function kernelmatrix(κ::ScaledKernel, x::AbstractVector, y::AbstractVector)
    return only(κ.σ²) * kernelmatrix(κ.kernel, x, y)
end

# Scale the base kernel matrix by the scalar variance. The scraped diff retained the
# superseded `κ.σ² .* …` return line above the fix; the stale line is removed here.
function kernelmatrix(κ::ScaledKernel, x::AbstractVector)
    return only(κ.σ²) * kernelmatrix(κ.kernel, x)
end

# Scale the base kernel diagonal by the scalar variance. The scraped diff retained the
# superseded `κ.σ² .* …` return line above the fix; the stale line is removed here.
function kernelmatrix_diag(κ::ScaledKernel, x::AbstractVector)
    return only(κ.σ²) * kernelmatrix_diag(κ.kernel, x)
end

# Scale the base kernel cross-diagonal by the scalar variance. The scraped diff retained
# the superseded `κ.σ² .* …` return line above the fix; the stale line is removed here.
function kernelmatrix_diag(κ::ScaledKernel, x::AbstractVector, y::AbstractVector)
    return only(κ.σ²) * kernelmatrix_diag(κ.kernel, x, y)
end

function kernelmatrix!(
Expand Down
11 changes: 8 additions & 3 deletions src/transform/ardtransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,14 @@ dim(t::ARDTransform) = length(t.v)
(t::ARDTransform)(x::Real) = only(t.v) * x
(t::ARDTransform)(x) = t.v .* x

# Quite specific implementations required to pass correctness and performance tests.
# The scraped diff retained the three superseded broadcast-based `_map` definitions
# (`t.v' .* x`, `t.v .* x.X`, `t.v' .* x.X`) above these replacements — duplicate
# methods whose first set is dead; they are removed here.
# Scalar-length-scale case: `only(t.v)` requires `t.v` to have exactly one element.
_map(t::ARDTransform, x::AbstractVector{<:Real}) = x * only(t.v)
# Widen `t.v` to a matrix via an outer product with a row of ones (AD perf hack).
function _map(t::ARDTransform, x::ColVecs)
    return ColVecs((t.v * ones(eltype(t.v), 1, size(x.X, 2))) .* x.X)
end
# Row-major analogue: widen `t.v'` down the rows of `x.X` (AD perf hack).
function _map(t::ARDTransform, x::RowVecs)
    return RowVecs(x.X .* (ones(eltype(t.v), size(x.X, 1)) * collect(t.v')))
end

Base.isequal(t::ARDTransform, t2::ARDTransform) = isequal(t.v, t2.v)

Expand Down
2 changes: 2 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ Base.zero(x::ColVecs) = ColVecs(zero(x.X))

dim(x::ColVecs) = size(x.X, 1)

# Reinterpret a vector of reals as a 1×N `ColVecs` (one scalar feature per column) so
# real-valued inputs can reuse the `ColVecs` kernelmatrix methods; `reshape` shares data.
_to_colvecs(x::AbstractVector{<:Real}) = ColVecs(reshape(x, 1, :))

pairwise(d::PreMetric, x::ColVecs) = Distances_pairwise(d, x.X; dims=2)
pairwise(d::PreMetric, x::ColVecs, y::ColVecs) = Distances_pairwise(d, x.X, y.X; dims=2)
function pairwise(d::PreMetric, x::AbstractVector, y::ColVecs)
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
Expand Down
Loading