Skip to content

Commit c8134ec

Browse files
authored
Add explanation and be more conservative
1 parent 6b094ad commit c8134ec

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

src/kernels/transformedkernel.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,14 @@ struct TransformedKernel{Tk<:Kernel,Tr<:Transform} <: Kernel
1010
end
1111

1212
(k::TransformedKernel)(x, y) = k.kernel(k.transform(x), k.transform(y))
13-
function (k::TransformedKernel{<:SimpleKernel,<:ScaleTransform})(x, y)
13+
14+
# Optimizations for scale transforms of simple kernels to save allocations:
15+
# Instead of a multiplying every element of the inputs before evaluating the metric,
16+
# we perform a scalar multiplcation of the distance of the original inputs, if possible.
17+
function (k::TransformedKernel{<:SimpleKernel,<:ScaleTransform})(
18+
x::AbstractVector{<:Real},
19+
y::AbstractVector{<:Real}
20+
)
1421
return kappa(k.kernel, _scale(k.transform, metric(k.kernel), x, y))
1522
end
1623

@@ -20,9 +27,7 @@ end
2027
function _scale(t::ScaleTransform, metric::Union{SqEuclidean,DotProduct}, x, y)
2128
return first(t.s)^2 * evaluate(metric, x, y)
2229
end
23-
function _scale(t::ScaleTransform, metric, x, y)
24-
evaluate(metric, t(x), t(y))
25-
end
30+
_scale(t::ScaleTransform, metric, x, y) = evaluate(metric, t(x), t(y))
2631

2732
"""
2833
```julia

0 commit comments

Comments
 (0)