Skip to content

Commit 5fd5157

Browse files
authored
Rename TensorProduct and implement TensorCore.tensor (#232)
1 parent 0ff8761 commit 5fd5157

17 files changed

+251
-243
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "KernelFunctions"
22
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
3-
version = "0.8.18"
3+
version = "0.8.19"
44

55
[deps]
66
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
@@ -13,6 +13,7 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1313
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
1414
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1515
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
16+
TensorCore = "62fd8b95-f654-4bbd-a8a5-9c27f68ccd50"
1617
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1718
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
1819

@@ -24,5 +25,6 @@ Requires = "1.0.1"
2425
SpecialFunctions = "0.8, 0.9, 0.10, 1"
2526
StatsBase = "0.32, 0.33"
2627
StatsFuns = "0.8, 0.9"
28+
TensorCore = "0.1"
2729
ZygoteRules = "0.2"
2830
julia = "1.3"

docs/src/kernels.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ transform(::Kernel, ::AbstractVector)
124124
ScaledKernel
125125
KernelSum
126126
KernelProduct
127-
TensorProduct
127+
KernelTensorProduct
128128
```
129129

130130
## Multi-output Kernels

src/KernelFunctions.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,8 @@ export LinearKernel, PolynomialKernel
2929
export RationalQuadraticKernel, GammaRationalQuadraticKernel
3030
export GaborKernel, PiecewisePolynomialKernel
3131
export PeriodicKernel, NeuralNetworkKernel
32-
export KernelSum, KernelProduct
32+
export KernelSum, KernelProduct, KernelTensorProduct
3333
export TransformedKernel, ScaledKernel
34-
export TensorProduct
3534

3635
export Transform,
3736
SelectTransform,
@@ -52,6 +51,9 @@ export ColVecs, RowVecs
5251
export MOInput
5352
export IndependentMOKernel, LatentFactorMOKernel
5453

54+
# Reexports
55+
export tensor, ⊗
56+
5557
using Compat
5658
using Requires
5759
using Distances, LinearAlgebra
@@ -61,6 +63,7 @@ using ZygoteRules: @adjoint, pullback
6163
using StatsFuns: logtwo
6264
using InteractiveUtils: subtypes
6365
using StatsBase
66+
using TensorCore
6467

6568
abstract type Kernel end
6669
abstract type SimpleKernel <: Kernel end
@@ -100,7 +103,8 @@ include(joinpath("kernels", "scaledkernel.jl"))
100103
include(joinpath("matrix", "kernelmatrix.jl"))
101104
include(joinpath("kernels", "kernelsum.jl"))
102105
include(joinpath("kernels", "kernelproduct.jl"))
103-
include(joinpath("kernels", "tensorproduct.jl"))
106+
include(joinpath("kernels", "kerneltensorproduct.jl"))
107+
include(joinpath("kernels", "overloads.jl"))
104108
include(joinpath("approximations", "nystrom.jl"))
105109
include("generic.jl")
106110

src/basekernels/sm.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ function spectral_mixture_product_kernel(
9292
if !(size(αs) == size(γs) == size(ωs))
9393
throw(DimensionMismatch("The dimensions of αs, γs, and ωs do not match"))
9494
end
95-
return TensorProduct(
95+
return KernelTensorProduct(
9696
spectral_mixture_kernel(h, α, reshape(γ, 1, :), reshape(ω, 1, :)) for
9797
(α, γ, ω) in zip(eachrow(αs), eachrow(γs), eachrow(ωs))
9898
)

src/deprecated.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,8 @@
77
@deprecate PiecewisePolynomialKernel{V}(A::AbstractMatrix{<:Real}) where {V} transform(
88
PiecewisePolynomialKernel{V}(size(A, 1)), LinearTransform(cholesky(A).U)
99
)
10+
11+
@deprecate TensorProduct(kernels) KernelTensorProduct(kernels)
12+
@deprecate TensorProduct(kernel::Kernel, kernels::Kernel...) KernelTensorProduct(
13+
kernel, kernels...
14+
)

src/kernels/kernelproduct.jl

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -41,31 +41,6 @@ end
4141

4242
@functor KernelProduct
4343

44-
Base.:*(k1::Kernel, k2::Kernel) = KernelProduct(k1, k2)
45-
46-
function Base.:*(
47-
k1::KernelProduct{<:AbstractVector{<:Kernel}},
48-
k2::KernelProduct{<:AbstractVector{<:Kernel}},
49-
)
50-
return KernelProduct(vcat(k1.kernels, k2.kernels))
51-
end
52-
53-
function Base.:*(k1::KernelProduct, k2::KernelProduct)
54-
return KernelProduct(k1.kernels..., k2.kernels...)
55-
end
56-
57-
function Base.:*(k::Kernel, ks::KernelProduct{<:AbstractVector{<:Kernel}})
58-
return KernelProduct(vcat(k, ks.kernels))
59-
end
60-
61-
Base.:*(k::Kernel, kp::KernelProduct) = KernelProduct(k, kp.kernels...)
62-
63-
function Base.:*(ks::KernelProduct{<:AbstractVector{<:Kernel}}, k::Kernel)
64-
return KernelProduct(vcat(ks.kernels, k))
65-
end
66-
67-
Base.:*(kp::KernelProduct, k::Kernel) = KernelProduct(kp.kernels..., k)
68-
6944
Base.length(k::KernelProduct) = length(k.kernels)
7045

7146
(κ::KernelProduct)(x, y) = prod(k(x, y) for k in κ.kernels)

src/kernels/kernelsum.jl

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -41,28 +41,6 @@ end
4141

4242
@functor KernelSum
4343

44-
Base.:+(k1::Kernel, k2::Kernel) = KernelSum(k1, k2)
45-
46-
function Base.:+(
47-
k1::KernelSum{<:AbstractVector{<:Kernel}}, k2::KernelSum{<:AbstractVector{<:Kernel}}
48-
)
49-
return KernelSum(vcat(k1.kernels, k2.kernels))
50-
end
51-
52-
Base.:+(k1::KernelSum, k2::KernelSum) = KernelSum(k1.kernels..., k2.kernels...)
53-
54-
function Base.:+(k::Kernel, ks::KernelSum{<:AbstractVector{<:Kernel}})
55-
return KernelSum(vcat(k, ks.kernels))
56-
end
57-
58-
Base.:+(k::Kernel, ks::KernelSum) = KernelSum(k, ks.kernels...)
59-
60-
function Base.:+(ks::KernelSum{<:AbstractVector{<:Kernel}}, k::Kernel)
61-
return KernelSum(vcat(ks.kernels, k))
62-
end
63-
64-
Base.:+(ks::KernelSum, k::Kernel) = KernelSum(ks.kernels..., k)
65-
6644
Base.length(k::KernelSum) = length(k.kernels)
6745

6846
(κ::KernelSum)(x, y) = sum(k(x, y) for k in κ.kernels)

src/kernels/kerneltensorproduct.jl

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
"""
    KernelTensorProduct

Tensor product of kernels.

# Definition

For inputs ``x = (x_1, \\ldots, x_n)`` and ``x' = (x'_1, \\ldots, x'_n)``, the tensor
product of kernels ``k_1, \\ldots, k_n`` is defined as
```math
k(x, x'; k_1, \\ldots, k_n) = \\Big(\\bigotimes_{i=1}^n k_i\\Big)(x, x') = \\prod_{i=1}^n k_i(x_i, x'_i).
```

# Construction

The simplest way to specify a `KernelTensorProduct` is to use the overloaded `tensor`
operator or its alias `⊗` (can be typed by `\\otimes<tab>`).
```jldoctest tensorproduct
julia> k1 = SqExponentialKernel(); k2 = LinearKernel(); X = rand(5, 2);

julia> kernelmatrix(k1 ⊗ k2, RowVecs(X)) == kernelmatrix(k1, X[:, 1]) .* kernelmatrix(k2, X[:, 2])
true
```

You can also specify a `KernelTensorProduct` by providing kernels as individual arguments
or as an iterable data structure such as a `Tuple` or a `Vector`. Using a tuple or
individual arguments guarantees that `KernelTensorProduct` is concretely typed but might
lead to large compilation times if the number of kernels is large.
```jldoctest tensorproduct
julia> KernelTensorProduct(k1, k2) == k1 ⊗ k2
true

julia> KernelTensorProduct((k1, k2)) == k1 ⊗ k2
true

julia> KernelTensorProduct([k1, k2]) == k1 ⊗ k2
true
```
"""
struct KernelTensorProduct{K} <: Kernel
    kernels::K  # iterable of component kernels, one per input dimension
end

# Varargs convenience constructor: collects the kernels into a tuple so the
# resulting `KernelTensorProduct` is concretely typed.
function KernelTensorProduct(kernel::Kernel, kernels::Kernel...)
    return KernelTensorProduct((kernel, kernels...))
end

@functor KernelTensorProduct

# Number of component kernels, i.e. the number of input dimensions handled.
Base.length(kernel::KernelTensorProduct) = length(kernel.kernels)
51+
52+
# Evaluate the tensor product kernel at a pair of inputs: the product of each
# component kernel applied to the corresponding coordinates of `x` and `y`.
# Throws `DimensionMismatch` if the number of coordinates does not match the
# number of component kernels.
function (kernel::KernelTensorProduct)(x, y)
    if !(length(x) == length(y) == length(kernel))
        # Fix: the message was split across source lines, embedding a newline
        # and stray indentation in the thrown error text.
        throw(
            DimensionMismatch(
                "number of kernels and number of features are not consistent"
            ),
        )
    end
    return prod(k(xi, yi) for (k, xi, yi) in zip(kernel.kernels, x, y))
end
59+
60+
# Check that the number of feature groups in `x` equals the number of component
# kernels; errors otherwise.
function validate_domain(k::KernelTensorProduct, x::AbstractVector)
    dim(x) == length(k) && return true
    return error("number of kernels and groups of features are not consistent")
end
64+
65+
# Utility for slicing up inputs: returns one slice of the data per component
# kernel, in the order the kernels are applied.
slices(v::AbstractVector{<:Real}) = (v,)  # single dimension: the vector itself
slices(v::ColVecs) = eachrow(v.X)  # one slice per row of the underlying matrix
slices(v::RowVecs) = eachcol(v.X)  # one slice per column of the underlying matrix
69+
70+
# In-place kernel matrix: write the first component's matrix into `K`, then
# fold in the remaining components' matrices via elementwise products.
function kernelmatrix!(K::AbstractMatrix, k::KernelTensorProduct, x::AbstractVector)
    validate_inplace_dims(K, x)
    validate_domain(k, x)

    kxs = zip(k.kernels, slices(x))
    kernelmatrix!(K, first(kxs)...)
    for (kernel, xs) in Iterators.drop(kxs, 1)
        K .*= kernelmatrix(kernel, xs)
    end

    return K
end
82+
83+
# In-place cross kernel matrix between `x` and `y`: first component written
# directly into `K`, remaining components accumulated elementwise.
function kernelmatrix!(
    K::AbstractMatrix, k::KernelTensorProduct, x::AbstractVector, y::AbstractVector
)
    validate_inplace_dims(K, x, y)
    validate_domain(k, x)

    kxys = zip(k.kernels, slices(x), slices(y))
    kernelmatrix!(K, first(kxys)...)
    for (kernel, xs, ys) in Iterators.drop(kxys, 1)
        K .*= kernelmatrix(kernel, xs, ys)
    end

    return K
end
97+
98+
# In-place diagonal of the kernel matrix: same accumulation strategy as
# `kernelmatrix!`, restricted to the diagonal entries.
function kerneldiagmatrix!(K::AbstractVector, k::KernelTensorProduct, x::AbstractVector)
    validate_inplace_dims(K, x)
    validate_domain(k, x)

    kxs = zip(k.kernels, slices(x))
    kerneldiagmatrix!(K, first(kxs)...)
    for (kernel, xs) in Iterators.drop(kxs, 1)
        K .*= kerneldiagmatrix(kernel, xs)
    end

    return K
end
110+
111+
# Kernel matrix as the Hadamard (elementwise) product of the per-dimension
# kernel matrices.
function kernelmatrix(k::KernelTensorProduct, x::AbstractVector)
    validate_domain(k, x)
    mats = (kernelmatrix(kernel, xs) for (kernel, xs) in zip(k.kernels, slices(x)))
    return reduce(hadamard, mats)
end
115+
116+
# Cross kernel matrix between `x` and `y` as a Hadamard product of the
# per-dimension cross kernel matrices.
function kernelmatrix(k::KernelTensorProduct, x::AbstractVector, y::AbstractVector)
    validate_domain(k, x)
    mats = (
        kernelmatrix(kernel, xs, ys) for
        (kernel, xs, ys) in zip(k.kernels, slices(x), slices(y))
    )
    return reduce(hadamard, mats)
end
120+
121+
# Diagonal of the kernel matrix as a Hadamard product of the per-dimension
# diagonals.
function kerneldiagmatrix(k::KernelTensorProduct, x::AbstractVector)
    validate_domain(k, x)
    diags = (kerneldiagmatrix(kernel, xs) for (kernel, xs) in zip(k.kernels, slices(x)))
    return reduce(hadamard, diags)
end
125+
126+
# Display a tensor product kernel via `printshifted` with no initial shift.
function Base.show(io::IO, kernel::KernelTensorProduct)
    return printshifted(io, kernel, 0)
end
127+
128+
# Equality compares the component kernels elementwise rather than via
# `x.kernels == y.kernels`, so that tuple-backed and vector-backed products
# holding the same kernels compare equal.
function Base.:(==)(x::KernelTensorProduct, y::KernelTensorProduct)
    length(x.kernels) == length(y.kernels) || return false
    return all(kx == ky for (kx, ky) in zip(x.kernels, y.kernels))
end
134+
135+
# Print a header followed by each component kernel on its own line, indented
# one tab deeper than the current shift; nested composites shift two further.
function printshifted(io::IO, kernel::KernelTensorProduct, shift::Int)
    print(io, "Tensor product of ", length(kernel), " kernels:")
    for k in kernel.kernels
        print(io, "\n", "\t"^(shift + 1))
        printshifted(io, k, shift + 2)
    end
end

src/kernels/overloads.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Generate the binary operators that combine kernels into composite kernels:
# `+` builds a `KernelSum`, `*` a `KernelProduct`, and `TensorCore.tensor`
# (alias `⊗`) a `KernelTensorProduct`. For each operator we define methods for
# kernel-kernel, composite-composite, and kernel-composite combinations, with
# specializations that keep vector-backed composites vector-backed (`vcat`)
# instead of splatting their kernels into a tuple.
for (mod, op, composite) in (
    (:Base, :+, :KernelSum),
    (:Base, :*, :KernelProduct),
    (:TensorCore, :tensor, :KernelTensorProduct),
)
    @eval begin
        $mod.$op(k1::Kernel, k2::Kernel) = $composite(k1, k2)

        $mod.$op(k1::$composite, k2::$composite) = $composite(k1.kernels..., k2.kernels...)
        function $mod.$op(
            k1::$composite{<:AbstractVector{<:Kernel}},
            k2::$composite{<:AbstractVector{<:Kernel}},
        )
            return $composite(vcat(k1.kernels, k2.kernels))
        end

        $mod.$op(k::Kernel, ks::$composite) = $composite(k, ks.kernels...)
        function $mod.$op(k::Kernel, ks::$composite{<:AbstractVector{<:Kernel}})
            return $composite(vcat(k, ks.kernels))
        end

        $mod.$op(ks::$composite, k::Kernel) = $composite(ks.kernels..., k)
        function $mod.$op(ks::$composite{<:AbstractVector{<:Kernel}}, k::Kernel)
            return $composite(vcat(ks.kernels, k))
        end
    end
end

0 commit comments

Comments
 (0)