Skip to content

Commit 09cd20e

Browse files
committed
Change to output covariance type - revert
1 parent f46bd61 commit 09cd20e

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

src/mokernels/independent.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,16 @@ function (κ::IndependentMOKernel)((x, px)::Tuple{Any,Int}, (y, py)::Tuple{Any,I
2727
return κ.kernel(x, y) * (px == py)
2828
end
2929

30-
_mo_output_covariance(k::IndependentMOKernel, out_dim) = Eye{Bool}(out_dim)
30+
function _mo_output_covariance(k::IndependentMOKernel, Kfeatures, out_dim)
31+
return Eye{eltype(Kfeatures)}(out_dim)
32+
end
3133

3234
function kernelmatrix(
3335
k::IndependentMOKernel, x::IsotopicMOInputsUnion, y::IsotopicMOInputsUnion
3436
)
3537
@assert x.out_dim == y.out_dim
3638
Kfeatures = kernelmatrix(k.kernel, x.x, y.x)
37-
Koutputs = _mo_output_covariance(k, x.out_dim)
39+
Koutputs = _mo_output_covariance(k, Kfeatures, x.out_dim)
3840
return _kernelmatrix_kron_helper(x, Kfeatures, Koutputs)
3941
end
4042

@@ -47,7 +49,7 @@ if VERSION >= v"1.6"
4749
)
4850
@assert x.out_dim == y.out_dim
4951
Kfeatures = kernelmatrix(k.kernel, x.x, y.x)
50-
Koutputs = _mo_output_covariance(k, x.out_dim)
52+
Koutputs = _mo_output_covariance(k, Kfeatures, x.out_dim)
5153
return _kernelmatrix_kron_helper!(K, x, Kfeatures, Koutputs)
5254
end
5355
end

src/mokernels/intrinsiccoregion.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ function (k::IntrinsicCoregionMOKernel)((x, px)::Tuple{Any,Int}, (y, py)::Tuple{
4242
return k.B[px, py] * k.kernel(x, y)
4343
end
4444

45-
function _mo_output_covariance(k::IntrinsicCoregionMOKernel, out_dim)
45+
function _mo_output_covariance(k::IntrinsicCoregionMOKernel, Kfeatures, out_dim)
4646
@assert size(k.B) == (out_dim, out_dim)
4747
return k.B
4848
end
@@ -52,7 +52,7 @@ function kernelmatrix(
5252
)
5353
@assert x.out_dim == y.out_dim
5454
Kfeatures = kernelmatrix(k.kernel, x.x, y.x)
55-
Koutputs = _mo_output_covariance(k, x.out_dim)
55+
Koutputs = _mo_output_covariance(k, Kfeatures, x.out_dim)
5656
return _kernelmatrix_kron_helper(x, Kfeatures, Koutputs)
5757
end
5858

@@ -65,7 +65,7 @@ if VERSION >= v"1.6"
6565
)
6666
@assert x.out_dim == y.out_dim
6767
Kfeatures = kernelmatrix(k.kernel, x.x, y.x)
68-
Koutputs = _mo_output_covariance(k, x.out_dim)
68+
Koutputs = _mo_output_covariance(k, Kfeatures, x.out_dim)
6969
return _kernelmatrix_kron_helper!(K, x, Kfeatures, Koutputs)
7070
end
7171
end

0 commit comments

Comments
 (0)