Skip to content

Commit 27df802

Browse files
Merge pull request #55 from SciML/as/getu-param
fix: fix `getu` with parameter symbols
2 parents a7c70c8 + ad18dbf commit 27df802

File tree

2 files changed

+29
-2
lines changed

2 files changed

+29
-2
lines changed

src/state_indexing.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,16 @@ function _getu(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, sym)
9595
return getu(sys, idx)
9696
elseif is_parameter(sys, sym)
9797
return let fn = getp(sys, sym)
98-
getter(prob, args...) = fn(prob)
99-
getter
98+
_getter_p(::NotTimeseries, prob) = fn(prob)
99+
function _getter_p(::Timeseries, prob)
100+
[fn(parameter_values_at_state_time(prob, i))
101+
for i in eachindex(current_time(prob))]
102+
end
103+
_getter_p(::Timeseries, prob, i) = fn(parameter_values_at_state_time(prob, i))
104+
let _getter = _getter_p
105+
getter(prob, args...) = _getter(is_timeseries(prob), prob, args...)
106+
getter
107+
end
100108
end
101109
elseif is_independent_variable(sys, sym)
102110
_getter(::IsTimeseriesTrait, prob) = current_time(prob)

test/parameter_indexing_test.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,8 @@ end
9090
function Base.getproperty(fs::FakeSolution, s::Symbol)
9191
s === :ps ? ParameterIndexingProxy(fs) : getfield(fs, s)
9292
end
93+
SymbolicIndexingInterface.state_values(fs::FakeSolution) = fs.u
94+
SymbolicIndexingInterface.current_time(fs::FakeSolution) = fs.t
9395
SymbolicIndexingInterface.symbolic_container(fs::FakeSolution) = fs.sys
9496
SymbolicIndexingInterface.parameter_values(fs::FakeSolution) = fs.p[end]
9597
SymbolicIndexingInterface.parameter_values(fs::FakeSolution, i) = fs.p[end][i]
@@ -149,3 +151,20 @@ for (sym, val, arrval, check_inference) in [
149151
@test get(fs, sub_inds) == arrval[sub_inds]
150152
end
151153
end
154+
155+
ps = fs.p[2:2:end]
156+
avals = getindex.(ps, 1)
157+
bvals = getindex.(ps, 2)
158+
cvals = getindex.(ps, 3)
159+
for (sym, val, arrval) in [
160+
(:a, p[1], avals),
161+
((:b, :c), p[2:3], tuple.(bvals, cvals)),
162+
([:c, :a], p[[3, 1]], vcat.(cvals, avals))
163+
]
164+
get = getu(sys, sym)
165+
@inferred get(fs)
166+
@test get(fs) == arrval
167+
for i in eachindex(ps)
168+
@test get(fs, i) == arrval[i]
169+
end
170+
end

0 commit comments

Comments
 (0)