@@ -26,9 +26,6 @@ function (kernel::TensorProduct)(x, y)
26
26
return prod (k (xi, yi) for (k, xi, yi) in zip (kernel. kernels, x, y))
27
27
end
28
28
29
- # TODO : General implementation of `kernelmatrix` and `kerneldiagmatrix`
30
- # Default implementation assumes 1D observations
31
-
32
29
function validate_domain (k:: TensorProduct , x:: AbstractVector )
33
30
dim (x) == length (k) ||
34
31
error (" number of kernels and groups of features are not consistent" )
@@ -70,25 +67,6 @@ function kernelmatrix!(
70
67
return K
71
68
end
72
69
73
- # mapreduce with multiple iterators requires Julia 1.2 or later.
74
-
75
- function kernelmatrix (k:: TensorProduct , x:: AbstractVector )
76
- validate_domain (k, x)
77
-
78
- return mapreduce ((x, y) -> x .* y, zip (k. kernels, slices (x))) do (k, xi)
79
- kernelmatrix (k, xi)
80
- end
81
- end
82
-
83
- function kernelmatrix (k:: TensorProduct , x:: AbstractVector , y:: AbstractVector )
84
- validate_domain (k, x)
85
-
86
- kernels_and_inputs = zip (k. kernels, slices (x), slices (y))
87
- return mapreduce ((x, y) -> x .* y, kernels_and_inputs) do (k, xi, yi)
88
- kernelmatrix (k, xi, yi)
89
- end
90
- end
91
-
92
70
function kerneldiagmatrix! (K:: AbstractVector , k:: TensorProduct , x:: AbstractVector )
93
71
validate_inplace_dims (K, x)
94
72
validate_domain (k, x)
@@ -102,13 +80,19 @@ function kerneldiagmatrix!(K::AbstractVector, k::TensorProduct, x::AbstractVecto
102
80
return K
103
81
end
104
82
105
- function kerneldiagmatrix (k:: TensorProduct , x:: AbstractVector )
83
+ function kernelmatrix (k:: TensorProduct , x:: AbstractVector )
106
84
validate_domain (k, x)
85
+ return mapreduce (kernelmatrix, hadamard, k. kernels, slices (x))
86
+ end
107
87
108
- kernels_and_inputs = zip (k. kernels, slices (x))
109
- return mapreduce ((x, y) -> x .* y, kernels_and_inputs) do (k, xi)
110
- kerneldiagmatrix (k, xi)
111
- end
88
+ function kernelmatrix (k:: TensorProduct , x:: AbstractVector , y:: AbstractVector )
89
+ validate_domain (k, x)
90
+ return mapreduce (kernelmatrix, hadamard, k. kernels, slices (x), slices (y))
91
+ end
92
+
93
+ function kerneldiagmatrix (k:: TensorProduct , x:: AbstractVector )
94
+ validate_domain (k, x)
95
+ return mapreduce (kerneldiagmatrix, hadamard, k. kernels, slices (x))
112
96
end
113
97
114
98
Base. show (io:: IO , kernel:: TensorProduct ) = printshifted (io, kernel, 0 )
0 commit comments