Skip to content

Commit dd67abb

Browse files
fix: fix and test getu and getp with empty arrays
1 parent e7dd822 commit dd67abb

File tree

4 files changed

+37
-1
lines changed

4 files changed

+37
-1
lines changed

src/parameter_indexing.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,9 @@ struct MultipleParametersGetter{T <: IsIndexerTimeseries, G, I} <:
323323
end
324324

325325
function MultipleParametersGetter(getters)
326+
if isempty(getters)
327+
return MultipleParametersGetter{IndexerNotTimeseries, typeof(getters), Nothing}(getters, nothing)
328+
end
326329
has_timeseries_indexers = any(getters) do g
327330
is_indexer_timeseries(g) == IndexerTimeseries()
328331
end

src/state_indexing.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ for (t1, t2) in [
236236
@eval function _getu(sys, ::NotSymbolic, ::$t1, sym::$t2)
237237
num_observed = count(x -> is_observed(sys, x), sym)
238238
if num_observed == 0 || num_observed == 1 && sym isa Tuple
239-
if all(Base.Fix1(is_parameter, sys), sym) &&
239+
if !isempty(sym) && all(Base.Fix1(is_parameter, sys), sym) &&
240240
all(!Base.Fix1(is_timeseries_parameter, sys), sym)
241241
GetpAtStateTime(getp(sys, sym))
242242
else

test/parameter_indexing_test.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,9 +123,19 @@ for sys in [
123123
@test buffer == collect(val)
124124
end
125125
end
126+
127+
getter = getp(sys, [])
128+
@test getter(fi) == []
126129
end
127130
end
128131

132+
let
133+
sc = SymbolCache(nothing, nothing, :t)
134+
fi = FakeIntegrator(sc, nothing, 0.0, Ref(0))
135+
getter = getp(sc, [])
136+
@test getter(fi) == []
137+
end
138+
129139
struct MyDiffEqArray
130140
t::Vector{Float64}
131141
u::Vector{Vector{Float64}}
@@ -387,6 +397,11 @@ end
387397

388398
@test_throws ErrorException getp(sys, :not_a_param)
389399

400+
let fs = fs, sys = sys
401+
getter = getp(sys, [])
402+
@test getter(fs) == []
403+
end
404+
390405
struct FakeNoTimeSolution
391406
sys::SymbolCache
392407
u::Vector{Float64}

test/state_indexing_test.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,15 @@ for (sym, val, check_inference) in [
9191
@test get(fi) == val
9292
end
9393

94+
let fi = fi, sys = sys
95+
getter = getu(sys, [])
96+
@test getter(fi) == []
97+
sc = SymbolCache(nothing, [:a, :b], :t)
98+
fi = FakeIntegrator(sys, nothing, [1.0, 2.0], 3.0)
99+
getter = getu(sc, [])
100+
@test getter(fi) == []
101+
end
102+
94103
for (sym, oldval, newval, check_inference) in [(:a, p[1], 4.0, true)
95104
(:b, p[2], 5.0, true)
96105
(:c, p[3], 6.0, true)
@@ -232,6 +241,15 @@ for (sym, val) in [(:a, p[1])
232241
@test get(sol) == val
233242
end
234243

244+
let sol = sol, sys = sys
245+
getter = getu(sys, [])
246+
@test getter(sol) == [[] for _ in 1:length(sol.t)]
247+
sc = SymbolCache(nothing, [:a, :b], :t)
248+
sol = FakeSolution(sys, nothing, [1.0, 2.0], [0.0])
249+
getter = getu(sc, [])
250+
@test getter(sol) == [[]]
251+
end
252+
235253
sys = SymbolCache([:x, :y, :z], [:a, :b, :c])
236254
u = [1.0, 2.0, 3.0]
237255
p = [10.0, 20.0, 30.0]

0 commit comments

Comments
 (0)