Skip to content

Commit a328556

Browse files
Merge pull request #86 from SciML/as/empty-getu
fix: fix and test `getu` and `getp` with empty arrays
2 parents e7dd822 + 8a1d5f7 commit a328556

File tree

4 files changed

+52
-1
lines changed

4 files changed

+52
-1
lines changed

src/parameter_indexing.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,10 @@ struct MultipleParametersGetter{T <: IsIndexerTimeseries, G, I} <:
323323
end
324324

325325
function MultipleParametersGetter(getters)
326+
if isempty(getters)
327+
return MultipleParametersGetter{IndexerNotTimeseries, typeof(getters), Nothing}(
328+
getters, nothing)
329+
end
326330
has_timeseries_indexers = any(getters) do g
327331
is_indexer_timeseries(g) == IndexerTimeseries()
328332
end

src/state_indexing.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ for (t1, t2) in [
236236
@eval function _getu(sys, ::NotSymbolic, ::$t1, sym::$t2)
237237
num_observed = count(x -> is_observed(sys, x), sym)
238238
if num_observed == 0 || num_observed == 1 && sym isa Tuple
239-
if all(Base.Fix1(is_parameter, sys), sym) &&
239+
if !isempty(sym) && all(Base.Fix1(is_parameter, sys), sym) &&
240240
all(!Base.Fix1(is_timeseries_parameter, sys), sym)
241241
GetpAtStateTime(getp(sys, sym))
242242
else

test/parameter_indexing_test.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,9 +123,23 @@ for sys in [
123123
@test buffer == collect(val)
124124
end
125125
end
126+
127+
getter = getp(sys, [])
128+
@test getter(fi) == []
129+
getter = getp(sys, ())
130+
@test getter(fi) == ()
126131
end
127132
end
128133

134+
let
135+
sc = SymbolCache(nothing, nothing, :t)
136+
fi = FakeIntegrator(sc, nothing, 0.0, Ref(0))
137+
getter = getp(sc, [])
138+
@test getter(fi) == []
139+
getter = getp(sc, ())
140+
@test getter(fi) == ()
141+
end
142+
129143
struct MyDiffEqArray
130144
t::Vector{Float64}
131145
u::Vector{Vector{Float64}}
@@ -387,6 +401,13 @@ end
387401

388402
@test_throws ErrorException getp(sys, :not_a_param)
389403

404+
let fs = fs, sys = sys
405+
getter = getp(sys, [])
406+
@test getter(fs) == []
407+
getter = getp(sys, ())
408+
@test getter(fs) == ()
409+
end
410+
390411
struct FakeNoTimeSolution
391412
sys::SymbolCache
392413
u::Vector{Float64}

test/state_indexing_test.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,19 @@ for (sym, val, check_inference) in [
9191
@test get(fi) == val
9292
end
9393

94+
let fi = fi, sys = sys
95+
getter = getu(sys, [])
96+
@test getter(fi) == []
97+
getter = getu(sys, ())
98+
@test getter(fi) == ()
99+
sc = SymbolCache(nothing, [:a, :b], :t)
100+
fi = FakeIntegrator(sys, nothing, [1.0, 2.0], 3.0)
101+
getter = getu(sc, [])
102+
@test getter(fi) == []
103+
getter = getu(sc, ())
104+
@test getter(fi) == ()
105+
end
106+
94107
for (sym, oldval, newval, check_inference) in [(:a, p[1], 4.0, true)
95108
(:b, p[2], 5.0, true)
96109
(:c, p[3], 6.0, true)
@@ -232,6 +245,19 @@ for (sym, val) in [(:a, p[1])
232245
@test get(sol) == val
233246
end
234247

248+
let sol = sol, sys = sys
249+
getter = getu(sys, [])
250+
@test getter(sol) == [[] for _ in 1:length(sol.t)]
251+
getter = getu(sys, ())
252+
@test getter(sol) == [() for _ in 1:length(sol.t)]
253+
sc = SymbolCache(nothing, [:a, :b], :t)
254+
sol = FakeSolution(sys, [], [1.0, 2.0], [])
255+
getter = getu(sc, [])
256+
@test getter(sol) == []
257+
getter = getu(sc, ())
258+
@test getter(sol) == []
259+
end
260+
235261
sys = SymbolCache([:x, :y, :z], [:a, :b, :c])
236262
u = [1.0, 2.0, 3.0]
237263
p = [10.0, 20.0, 30.0]

0 commit comments

Comments
 (0)