Skip to content

Commit 6e7ca17

Browse files
Crown421theogfwilltebbuttst--
authored
Fix input types, improve readability (#369)
* Fix input types, improve readability * Add missing bit * Add doc string * Fix mistake * Add docstring to docs * Reformulate * Bump version * Update src/matrix/kernelkroneckermat.jl Co-authored-by: Théo Galy-Fajou <[email protected]> * Bump version further, rename api section * Apply format suggestions from code review Co-authored-by: willtebbutt <[email protected]> * Improve error handling. Co-authored-by: st-- <[email protected]> * Formatter Co-authored-by: Théo Galy-Fajou <[email protected]> Co-authored-by: willtebbutt <[email protected]> Co-authored-by: st-- <[email protected]>
1 parent 3356fa6 commit 6e7ca17

File tree

6 files changed

+61
-38
lines changed

6 files changed

+61
-38
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "KernelFunctions"
22
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
3-
version = "0.10.19"
3+
version = "0.10.20"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

docs/src/api.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,3 +88,12 @@ kernelpdmat
8888
nystrom
8989
NystromFact
9090
```
91+
92+
## Conditional Utilities
93+
To keep the dependencies of KernelFunctions lean, some functionality is only available if specific other packages are explicitly loaded (`using`).
94+
95+
### Kronecker.jl
96+
[*https://github.com/MichielStock/Kronecker.jl*](https://github.com/MichielStock/Kronecker.jl)
97+
```@docs
98+
kronecker_kernelmatrix
99+
```

src/matrix/kernelkroneckermat.jl

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,31 +27,43 @@ end
2727
"""
2828
@inline iskroncompatible::Kernel) = false # Default return for kernels
2929

30-
function _kernelmatrix_kroneckerjl_helper(::MOInputIsotopicByFeatures, Kfeatures, Koutputs)
30+
function _kernelmatrix_kroneckerjl_helper(
31+
::Type{<:MOInputIsotopicByFeatures}, Kfeatures, Koutputs
32+
)
3133
return Kronecker.kronecker(Kfeatures, Koutputs)
3234
end
3335

34-
function _kernelmatrix_kroneckerjl_helper(::MOInputIsotopicByOutputs, Kfeatures, Koutputs)
36+
function _kernelmatrix_kroneckerjl_helper(
37+
::Type{<:MOInputIsotopicByOutputs}, Kfeatures, Koutputs
38+
)
3539
return Kronecker.kronecker(Koutputs, Kfeatures)
3640
end
3741

42+
"""
43+
kronecker_kernelmatrix(
44+
k::Union{IndependentMOKernel,IntrinsicCoregionMOKernel}, x::MOI, y::MOI
45+
) where {MOI<:IsotopicMOInputsUnion}
46+
47+
Requires Kronecker.jl: Computes the `kernelmatrix` for the `IndependentMOKernel` and the
48+
`IntrinsicCoregionMOKernel`, but returns a lazy kronecker product. This object can be very
49+
efficiently inverted or decomposed. See also [`kernelmatrix`](@ref).
50+
"""
3851
function kronecker_kernelmatrix(
39-
k::Union{IndependentMOKernel,IntrinsicCoregionMOKernel},
40-
x::IsotopicMOInputsUnion,
41-
y::IsotopicMOInputsUnion,
42-
)
43-
@assert x.out_dim == y.out_dim
52+
k::Union{IndependentMOKernel,IntrinsicCoregionMOKernel}, x::MOI, y::MOI
53+
) where {MOI<:IsotopicMOInputsUnion}
54+
x.out_dim == y.out_dim ||
55+
throw(DimensionMismatch("`x` and `y` must have the same `out_dim`"))
4456
Kfeatures = kernelmatrix(k.kernel, x.x, y.x)
4557
Koutputs = _mo_output_covariance(k, x.out_dim)
46-
return _kernelmatrix_kroneckerjl_helper(x, Kfeatures, Koutputs)
58+
return _kernelmatrix_kroneckerjl_helper(MOI, Kfeatures, Koutputs)
4759
end
4860

4961
function kronecker_kernelmatrix(
50-
k::Union{IndependentMOKernel,IntrinsicCoregionMOKernel}, x::IsotopicMOInputsUnion
51-
)
62+
k::Union{IndependentMOKernel,IntrinsicCoregionMOKernel}, x::MOI
63+
) where {MOI<:IsotopicMOInputsUnion}
5264
Kfeatures = kernelmatrix(k.kernel, x.x)
5365
Koutputs = _mo_output_covariance(k, x.out_dim)
54-
return _kernelmatrix_kroneckerjl_helper(x, Kfeatures, Koutputs)
66+
return _kernelmatrix_kroneckerjl_helper(MOI, Kfeatures, Koutputs)
5567
end
5668

5769
function kronecker_kernelmatrix(

src/mokernels/independent.jl

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -30,25 +30,24 @@ end
3030
_mo_output_covariance(k::IndependentMOKernel, out_dim) = Eye{Bool}(out_dim)
3131

3232
function kernelmatrix(
33-
k::IndependentMOKernel, x::IsotopicMOInputsUnion, y::IsotopicMOInputsUnion
34-
)
35-
@assert x.out_dim == y.out_dim
33+
k::IndependentMOKernel, x::MOI, y::MOI
34+
) where {MOI<:IsotopicMOInputsUnion}
35+
x.out_dim == y.out_dim ||
36+
throw(DimensionMismatch("`x` and `y` must have the same `out_dim`"))
3637
Kfeatures = kernelmatrix(k.kernel, x.x, y.x)
3738
Koutputs = _mo_output_covariance(k, x.out_dim)
38-
return _kernelmatrix_kron_helper(x, Kfeatures, Koutputs)
39+
return _kernelmatrix_kron_helper(MOI, Kfeatures, Koutputs)
3940
end
4041

4142
if VERSION >= v"1.6"
4243
function kernelmatrix!(
43-
K::AbstractMatrix,
44-
k::IndependentMOKernel,
45-
x::IsotopicMOInputsUnion,
46-
y::IsotopicMOInputsUnion,
47-
)
48-
@assert x.out_dim == y.out_dim
44+
K::AbstractMatrix, k::IndependentMOKernel, x::MOI, y::MOI
45+
) where {MOI<:IsotopicMOInputsUnion}
46+
x.out_dim == y.out_dim ||
47+
throw(DimensionMismatch("`x` and `y` must have the same `out_dim`"))
4948
Kfeatures = kernelmatrix(k.kernel, x.x, y.x)
5049
Koutputs = _mo_output_covariance(k, x.out_dim)
51-
return _kernelmatrix_kron_helper!(K, x, Kfeatures, Koutputs)
50+
return _kernelmatrix_kron_helper!(K, MOI, Kfeatures, Koutputs)
5251
end
5352
end
5453

src/mokernels/intrinsiccoregion.jl

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -48,25 +48,24 @@ function _mo_output_covariance(k::IntrinsicCoregionMOKernel, out_dim)
4848
end
4949

5050
function kernelmatrix(
51-
k::IntrinsicCoregionMOKernel, x::IsotopicMOInputsUnion, y::IsotopicMOInputsUnion
52-
)
53-
@assert x.out_dim == y.out_dim
51+
k::IntrinsicCoregionMOKernel, x::MOI, y::MOI
52+
) where {MOI<:IsotopicMOInputsUnion}
53+
x.out_dim == y.out_dim ||
54+
throw(DimensionMismatch("`x` and `y` must have the same `out_dim`"))
5455
Kfeatures = kernelmatrix(k.kernel, x.x, y.x)
5556
Koutputs = _mo_output_covariance(k, x.out_dim)
56-
return _kernelmatrix_kron_helper(x, Kfeatures, Koutputs)
57+
return _kernelmatrix_kron_helper(MOI, Kfeatures, Koutputs)
5758
end
5859

5960
if VERSION >= v"1.6"
6061
function kernelmatrix!(
61-
K::AbstractMatrix,
62-
k::IntrinsicCoregionMOKernel,
63-
x::IsotopicMOInputsUnion,
64-
y::IsotopicMOInputsUnion,
65-
)
66-
@assert x.out_dim == y.out_dim
62+
K::AbstractMatrix, k::IntrinsicCoregionMOKernel, x::MOI, y::MOI
63+
) where {MOI<:IsotopicMOInputsUnion}
64+
x.out_dim == y.out_dim ||
65+
throw(DimensionMismatch("`x` and `y` must have the same `out_dim`"))
6766
Kfeatures = kernelmatrix(k.kernel, x.x, y.x)
6867
Koutputs = _mo_output_covariance(k, x.out_dim)
69-
return _kernelmatrix_kron_helper!(K, x, Kfeatures, Koutputs)
68+
return _kernelmatrix_kron_helper!(K, MOI, Kfeatures, Koutputs)
7069
end
7170
end
7271

src/mokernels/mokernel.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,24 @@ Abstract type for kernels with multiple outpus.
55
"""
66
abstract type MOKernel <: Kernel end
77

8-
function _kernelmatrix_kron_helper(::MOInputIsotopicByFeatures, Kfeatures, Koutputs)
8+
function _kernelmatrix_kron_helper(::Type{<:MOInputIsotopicByFeatures}, Kfeatures, Koutputs)
99
return kron(Kfeatures, Koutputs)
1010
end
1111

12-
function _kernelmatrix_kron_helper(::MOInputIsotopicByOutputs, Kfeatures, Koutputs)
12+
function _kernelmatrix_kron_helper(::Type{<:MOInputIsotopicByOutputs}, Kfeatures, Koutputs)
1313
return kron(Koutputs, Kfeatures)
1414
end
1515

1616
if VERSION >= v"1.6"
17-
function _kernelmatrix_kron_helper!(K, ::MOInputIsotopicByFeatures, Kfeatures, Koutputs)
17+
function _kernelmatrix_kron_helper!(
18+
K, ::Type{<:MOInputIsotopicByFeatures}, Kfeatures, Koutputs
19+
)
1820
return kron!(K, Kfeatures, Koutputs)
1921
end
2022

21-
function _kernelmatrix_kron_helper!(K, ::MOInputIsotopicByOutputs, Kfeatures, Koutputs)
23+
function _kernelmatrix_kron_helper!(
24+
K, ::Type{<:MOInputIsotopicByOutputs}, Kfeatures, Koutputs
25+
)
2226
return kron!(K, Koutputs, Kfeatures)
2327
end
2428
end

0 commit comments

Comments
 (0)