Skip to content

Commit b7bf09b

Browse files
Merge pull request #2791 from AayushSabharwal/as/sym-obs
feat: store observed equation lhs in `symbol_to_variable` mapping
2 parents c62cff0 + 3e7a721 commit b7bf09b

File tree

3 files changed

+53
-4
lines changed

3 files changed

+53
-4
lines changed

src/systems/abstractsystem.jl

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -495,10 +495,40 @@ function SymbolicIndexingInterface.is_observed(sys::AbstractSystem, sym)
495495
end
496496

497497
function SymbolicIndexingInterface.observed(sys::AbstractSystem, sym)
498-
return let _fn = build_explicit_observed_function(sys, sym)
499-
fn(u, p, t) = _fn(u, p, t)
500-
fn(u, p::MTKParameters, t) = _fn(u, p..., t)
501-
fn
498+
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
499+
if sym isa Symbol
500+
_sym = get(ic.symbol_to_variable, sym, nothing)
501+
if _sym === nothing
502+
throw(ArgumentError("Symbol $sym does not exist in the system"))
503+
end
504+
sym = _sym
505+
elseif sym isa AbstractArray && symbolic_type(sym) isa NotSymbolic &&
506+
any(x -> x isa Symbol, sym)
507+
sym = map(sym) do s
508+
if s isa Symbol
509+
_s = get(ic.symbol_to_variable, s, nothing)
510+
if _s === nothing
511+
throw(ArgumentError("Symbol $s does not exist in the system"))
512+
end
513+
return _s
514+
end
515+
return unwrap(s)
516+
end
517+
end
518+
end
519+
_fn = build_explicit_observed_function(sys, sym)
520+
if is_time_dependent(sys)
521+
return let _fn = _fn
522+
fn1(u, p, t) = _fn(u, p, t)
523+
fn1(u, p::MTKParameters, t) = _fn(u, p..., t)
524+
fn1
525+
end
526+
else
527+
return let _fn = _fn
528+
fn2(u, p) = _fn(u, p)
529+
fn2(u, p::MTKParameters) = _fn(u, p...)
530+
fn2
531+
end
502532
end
503533
end
504534

src/systems/index_cache.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,12 @@ function IndexCache(sys::AbstractSystem)
8080
end
8181
end
8282

83+
for eq in observed(sys)
84+
if symbolic_type(eq.lhs) != NotSymbolic() && hasname(eq.lhs)
85+
symbol_to_variable[getname(eq.lhs)] = eq.lhs
86+
end
87+
end
88+
8389
disc_buffers = Dict{Any, Set{BasicSymbolic}}()
8490
tunable_buffers = Dict{Any, Set{BasicSymbolic}}()
8591
constant_buffers = Dict{Any, Set{BasicSymbolic}}()

test/symbolic_indexing_interface.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,3 +101,16 @@ using SymbolicIndexingInterface
101101
prob = ODEProblem(complete(sys))
102102
get_dep = @test_nowarn getu(prob, 2p1)
103103
@test get_dep(prob) == [2.0, 4.0]
104+
105+
@testset "Observed functions with variables as `Symbol`s" begin
106+
@variables x(t) y(t) z(t)[1:2]
107+
@parameters p1 p2[1:2, 1:2]
108+
@mtkbuild sys = ODESystem([D(x) ~ x * t + p1, y ~ 2x, D(z) ~ p2 * z], t)
109+
prob = ODEProblem(
110+
sys, [x => 1.0, z => ones(2)], (0.0, 1.0), [p1 => 2.0, p2 => ones(2, 2)])
111+
@test getu(prob, x)(prob) == getu(prob, :x)(prob)
112+
@test getu(prob, [x, y])(prob) == getu(prob, [:x, :y])(prob)
113+
@test getu(prob, z)(prob) == getu(prob, :z)(prob)
114+
@test getu(prob, p1)(prob) == getu(prob, :p1)(prob)
115+
@test getu(prob, p2)(prob) == getu(prob, :p2)(prob)
116+
end

0 commit comments

Comments
 (0)