Skip to content

Commit 3e7a721

Browse files
feat: support generating observed for Symbol variables
1 parent edf6fcb commit 3e7a721

File tree

2 files changed

+47
-4
lines changed

2 files changed

+47
-4
lines changed

src/systems/abstractsystem.jl

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -495,10 +495,40 @@ function SymbolicIndexingInterface.is_observed(sys::AbstractSystem, sym)
495495
end
496496

497497
function SymbolicIndexingInterface.observed(sys::AbstractSystem, sym)
498-
return let _fn = build_explicit_observed_function(sys, sym)
499-
fn(u, p, t) = _fn(u, p, t)
500-
fn(u, p::MTKParameters, t) = _fn(u, p..., t)
501-
fn
498+
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
499+
if sym isa Symbol
500+
_sym = get(ic.symbol_to_variable, sym, nothing)
501+
if _sym === nothing
502+
throw(ArgumentError("Symbol $sym does not exist in the system"))
503+
end
504+
sym = _sym
505+
elseif sym isa AbstractArray && symbolic_type(sym) isa NotSymbolic &&
506+
any(x -> x isa Symbol, sym)
507+
sym = map(sym) do s
508+
if s isa Symbol
509+
_s = get(ic.symbol_to_variable, s, nothing)
510+
if _s === nothing
511+
throw(ArgumentError("Symbol $s does not exist in the system"))
512+
end
513+
return _s
514+
end
515+
return unwrap(s)
516+
end
517+
end
518+
end
519+
_fn = build_explicit_observed_function(sys, sym)
520+
if is_time_dependent(sys)
521+
return let _fn = _fn
522+
fn1(u, p, t) = _fn(u, p, t)
523+
fn1(u, p::MTKParameters, t) = _fn(u, p..., t)
524+
fn1
525+
end
526+
else
527+
return let _fn = _fn
528+
fn2(u, p) = _fn(u, p)
529+
fn2(u, p::MTKParameters) = _fn(u, p...)
530+
fn2
531+
end
502532
end
503533
end
504534

test/symbolic_indexing_interface.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,3 +101,16 @@ using SymbolicIndexingInterface
101101
prob = ODEProblem(complete(sys))
102102
get_dep = @test_nowarn getu(prob, 2p1)
103103
@test get_dep(prob) == [2.0, 4.0]
104+
105+
@testset "Observed functions with variables as `Symbol`s" begin
106+
@variables x(t) y(t) z(t)[1:2]
107+
@parameters p1 p2[1:2, 1:2]
108+
@mtkbuild sys = ODESystem([D(x) ~ x * t + p1, y ~ 2x, D(z) ~ p2 * z], t)
109+
prob = ODEProblem(
110+
sys, [x => 1.0, z => ones(2)], (0.0, 1.0), [p1 => 2.0, p2 => ones(2, 2)])
111+
@test getu(prob, x)(prob) == getu(prob, :x)(prob)
112+
@test getu(prob, [x, y])(prob) == getu(prob, [:x, :y])(prob)
113+
@test getu(prob, z)(prob) == getu(prob, :z)(prob)
114+
@test getu(prob, p1)(prob) == getu(prob, :p1)(prob)
115+
@test getu(prob, p2)(prob) == getu(prob, :p2)(prob)
116+
end

0 commit comments

Comments
 (0)