@@ -10,7 +10,14 @@ struct TransformedKernel{Tk<:Kernel,Tr<:Transform} <: Kernel
10
10
end
11
11
12
12
(k:: TransformedKernel )(x, y) = k. kernel (k. transform (x), k. transform (y))
13
- function (k:: TransformedKernel{<:SimpleKernel,<:ScaleTransform} )(x, y)
13
+
14
+ # Optimizations for scale transforms of simple kernels to save allocations:
15
+ # Instead of a multiplying every element of the inputs before evaluating the metric,
16
+ # we perform a scalar multiplcation of the distance of the original inputs, if possible.
17
+ function (k:: TransformedKernel{<:SimpleKernel,<:ScaleTransform} )(
18
+ x:: AbstractVector{<:Real} ,
19
+ y:: AbstractVector{<:Real}
20
+ )
14
21
return kappa (k. kernel, _scale (k. transform, metric (k. kernel), x, y))
15
22
end
16
23
20
27
function _scale (t:: ScaleTransform , metric:: Union{SqEuclidean,DotProduct} , x, y)
21
28
return first (t. s)^ 2 * evaluate (metric, x, y)
22
29
end
23
- function _scale (t:: ScaleTransform , metric, x, y)
24
- evaluate (metric, t (x), t (y))
25
- end
30
+ _scale (t:: ScaleTransform , metric, x, y) = evaluate (metric, t (x), t (y))
26
31
27
32
"""
28
33
```julia
0 commit comments