Skip to content

Commit 635595a

Browse files
fix: observed getu generation, tuple wrapper
1 parent c6bf421 commit 635595a

File tree

2 files changed

+34
-4
lines changed

2 files changed

+34
-4
lines changed

src/state_indexing.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,14 @@ struct AsTupleWrapper{G} <: AbstractIndexer
135135
getter::G
136136
end
137137

138-
function (atw::AsTupleWrapper)(::IsTimeseriesTrait, args...)
139-
return Tuple(atw.getter(args...))
138+
function (atw::AsTupleWrapper)(::Timeseries, prob)
139+
return Tuple.(atw.getter(prob))
140+
end
141+
function (atw::AsTupleWrapper)(::Timeseries, prob, i)
142+
return Tuple(atw.getter(prob, i))
143+
end
144+
function (atw::AsTupleWrapper)(::NotTimeseries, prob)
145+
return Tuple(atw.getter(prob))
140146
end
141147

142148
for (t1, t2) in [
@@ -151,7 +157,7 @@ for (t1, t2) in [
151157
return MultipleGetters(getters)
152158
else
153159
obs = observed(sys, sym isa Tuple ? collect(sym) : sym)
154-
getter = if is_timeseries(sys)
160+
getter = if is_time_dependent(sys)
155161
TimeDependentObservedFunction(obs)
156162
else
157163
TimeIndependentObservedFunction(obs)

test/state_indexing_test.jl

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,18 @@ for (sym, val, newval, check_inference) in [(:x, u[1], 4.0, true)
7575
@test get(u) == val
7676
end
7777

78+
for (sym, val, check_inference) in [
79+
(:(x + y), u[1] + u[2], true),
80+
([:(x + y), :z], [u[1] + u[2], u[3]], false),
81+
((:(x + y), :(z + y)), (u[1] + u[2], u[2] + u[3]), false)
82+
]
83+
get = getu(sys, sym)
84+
if check_inference
85+
@inferred get(fi)
86+
end
87+
@test get(fi) == val
88+
end
89+
7890
for (sym, oldval, newval, check_inference) in [(:a, p[1], 4.0, true)
7991
(:b, p[2], 5.0, true)
8092
(:c, p[3], 6.0, true)
@@ -101,7 +113,7 @@ end
101113
for (sym, val, check_inference) in [
102114
(:t, t, true),
103115
([:x, :a, :t], [u[1], p[1], t], false),
104-
((:x, :a, :t), (u[1], p[1], t), true)
116+
((:x, :a, :t), (u[1], p[1], t), false)
105117
]
106118
get = getu(fi, sym)
107119
if check_inference
@@ -182,6 +194,18 @@ for (sym, ans, check_inference) in [(:x, xvals, true)
182194
end
183195
end
184196

197+
for (sym, val, check_inference) in [
198+
(:(x + y), xvals .+ yvals, true),
199+
([:(x + y), :z], vcat.(xvals .+ yvals, zvals), false),
200+
((:(x + y), :(z + y)), tuple.(xvals .+ yvals, yvals .+ zvals), false)
201+
]
202+
get = getu(sys, sym)
203+
if check_inference
204+
@inferred get(sol)
205+
end
206+
@test get(sol) == val
207+
end
208+
185209
for (sym, val) in [(:a, p[1])
186210
(:b, p[2])
187211
(:c, p[3])

0 commit comments

Comments
 (0)