Skip to content

Commit 9723a5c

Browse files
authored
Merge pull request #104 from devmotion/_scale
Optimization of ScaleTransform for SimpleKernel
2 parents 754bee4 + a65b857 commit 9723a5c

File tree

2 files changed

+18
-5
lines changed

2 files changed

+18
-5
lines changed

src/generic.jl

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,6 @@ Base.iterate(k::Kernel, ::Any) = nothing
55

66
printshifted(io::IO, o, shift::Int) = print(io, o)
77

8-
# See https://github.com/JuliaGaussianProcesses/KernelFunctions.jl/issues/96
9-
_scale(t::ScaleTransform, metric::Euclidean, x, y) = first(t.s) * evaluate(metric, x, y)
10-
_scale(t::ScaleTransform, metric::Union{SqEuclidean,DotProduct}, x, y) = first(t.s)^2 * evaluate(metric, x, y)
11-
_scale(t::ScaleTransform, metric, x, y) = evaluate(metric, apply(t, x), apply(t, y))
12-
138
### Syntactic sugar for creating matrices and using kernel functions
149
function concretetypes(k, ktypes::Vector)
1510
isempty(subtypes(k)) ? push!(ktypes, k) : concretetypes.(subtypes(k), Ref(ktypes))

src/kernels/transformedkernel.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,24 @@ end
1111

1212
(k::TransformedKernel)(x, y) = k.kernel(k.transform(x), k.transform(y))
1313

14+
# Optimizations for scale transforms of simple kernels to save allocations:
15+
# Instead of a multiplying every element of the inputs before evaluating the metric,
16+
# we perform a scalar multiplcation of the distance of the original inputs, if possible.
17+
function (k::TransformedKernel{<:SimpleKernel,<:ScaleTransform})(
18+
x::AbstractVector{<:Real},
19+
y::AbstractVector{<:Real},
20+
)
21+
return kappa(k.kernel, _scale(k.transform, metric(k.kernel), x, y))
22+
end
23+
24+
function _scale(t::ScaleTransform, metric::Euclidean, x, y)
25+
return first(t.s) * evaluate(metric, x, y)
26+
end
27+
function _scale(t::ScaleTransform, metric::Union{SqEuclidean,DotProduct}, x, y)
28+
return first(t.s)^2 * evaluate(metric, x, y)
29+
end
30+
_scale(t::ScaleTransform, metric, x, y) = evaluate(metric, t(x), t(y))
31+
1432
"""
1533
```julia
1634
transform(k::BaseKernel, t::Transform) (1)

0 commit comments

Comments
 (0)