Skip to content

Commit b5bb9e7

Browse files
Adding a MOInputHeterotopic type
1 parent 9f708d0 commit b5bb9e7

File tree

4 files changed

+71
-2
lines changed

4 files changed

+71
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "KernelFunctions"
22
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
3-
version = "0.10.23"
3+
version = "0.10.24"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/KernelFunctions.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ export spectral_mixture_kernel, spectral_mixture_product_kernel
3737

3838
export ColVecs, RowVecs
3939

40-
export MOInput, prepare_isotopic_multi_output_data, prepare_heterotopic_multi_output_data
40+
export MOInput, MOInputHeterotopic,
41+
prepare_isotopic_multi_output_data, prepare_heterotopic_multi_output_data
4142
export IndependentMOKernel,
4243
LatentFactorMOKernel, IntrinsicCoregionMOKernel, LinearMixingModelKernel
4344

src/mokernels/moinput.jl

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,43 @@ struct MOInputIsotopicByOutputs{S,T<:AbstractVector{S}} <: AbstractVector{Tuple{
5858
out_dim::Integer
5959
end
6060

61+
"""
62+
MOInputsHeterotopic(x::AbstractVector, output_indices::Integer)
63+
64+
`MOInputsHeterotopic(x, output_indices)` has length `length(x)`.
65+
66+
```jldoctest
67+
julia> x = [1, 2, 3, 4, 5, 6];
68+
69+
julia> out_inds = [1, 1, 2, 3, 2, 1];
70+
71+
julia> KernelFunctions.MOInputsHeterotopic(x, out_inds)
72+
6-element KernelFunctions.MOInputsHeterotopic{Int64, Vector{Int64}}:
73+
(1, 1)
74+
(2, 1)
75+
(3, 2)
76+
(4, 3)
77+
(5, 2)
78+
(6, 1)
79+
```
80+
81+
Accommodates modelling multi-dimensional output data where not all outputs are observed
82+
for every input.
83+
84+
As shown above, an `MOInputsHeterotopic` represents a vector of tuples.
85+
The `length(x)` elements represent the inputs that are observed at the locations specified
86+
by `output_indices`.
87+
"""
88+
struct MOInputsHeterotopic{S ,T<:AbstractVector{S}} <: AbstractVector{Tuple{S,Int}}
89+
x::T
90+
output_indices::AbstractVector{Int}
91+
end
92+
93+
# Return the inputs at a specific output
94+
function get_inputs_at_output(inp::MOInputsHeterotopic, output)
95+
return [input[1] for input in inputs if input[2]==output]
96+
end
97+
6198
const IsotopicMOInputsUnion = Union{MOInputIsotopicByFeatures,MOInputIsotopicByOutputs}
6299

63100
function Base.getindex(inp::MOInputIsotopicByOutputs, ind::Integer)
@@ -74,7 +111,13 @@ function Base.getindex(inp::MOInputIsotopicByFeatures, ind::Integer)
74111
return feature, output_index
75112
end
76113

114+
function Base.getindex(inp::MOInputsHeterotopic, ind::Integer)
115+
@boundscheck checkbounds(inp, ind)
116+
return inp.x[ind], inp.output_indices[ind]
117+
end
118+
77119
Base.size(inp::IsotopicMOInputsUnion) = (inp.out_dim * length(inp.x),)
120+
Base.size(inp::MOInputsHeterotopic) = (length(inp.output_indices),)
78121

79122
function Base.vcat(x::MOInputIsotopicByFeatures, y::MOInputIsotopicByFeatures)
80123
x.out_dim == y.out_dim || throw(DimensionMismatch("out_dim mismatch"))
@@ -86,6 +129,10 @@ function Base.vcat(x::MOInputIsotopicByOutputs, y::MOInputIsotopicByOutputs)
86129
return MOInputIsotopicByOutputs(vcat(x.x, y.x), x.out_dim)
87130
end
88131

132+
function Base.vcat(x::MOInputsHeterotopic, y::MOInputsHeterotopic)
133+
return MOInputsHeterotopic(vcat(x.x, y.x), vcat(x.output_indices, y.output_indices))
134+
end
135+
89136
"""
90137
MOInput(x::AbstractVector, out_dim::Integer)
91138

test/mokernels/moinput.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,27 @@
4848
@test all([(x_, i) for x_ in x for i in 1:3] .== ibf)
4949
end
5050

51+
@testset "heterotopic" begin
52+
out_inds = [1, 2, 3, 2]
53+
mo_input = KernelFunctions.MOInputsHeterotopic(x, out_inds)
54+
@test isa(mo_input, type_1) == true
55+
@test isa(mo_input, type_2) == false
56+
57+
@test length(mo_input) == 4
58+
@test size(mo_input) == (4,)
59+
@test size(mo_input, 1) == 4
60+
@test size(mo_input, 2) == 1
61+
@test lastindex(mo_input) == 4
62+
@test firstindex(mo_input) == 1
63+
@test_throws BoundsError mo_input[0]
64+
@test vcat(mo_input, mo_input) == KernelFunctions.MOInputsHeterotopic(vcat(x, x), vcat(out_inds, out_inds))
65+
66+
@test mo_input[2] == (x[2], 2)
67+
@test mo_input[3] == (x[3], 3)
68+
@test mo_input[4] == (x[4], 2)
69+
@test all([(x_, i) for (x_, i) in zip(x, out_inds)] .== mo_input)
70+
end
71+
5172
@testset "prepare_isotopic_multi_output_data" begin
5273
@testset "ColVecs" begin
5374
N = 5

0 commit comments

Comments
 (0)