Skip to content

Commit ae65529

Browse files
fix: fix observed function generation with array variables
1 parent 2f67a8b commit ae65529

File tree

2 files changed

+16
-5
lines changed

2 files changed

+16
-5
lines changed

src/systems/diffeqs/odesystem.jl

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -380,10 +380,10 @@ function build_explicit_observed_function(sys, ts;
380380
ps = full_parameters(sys),
381381
op = Operator,
382382
throw = true)
383-
if (isscalar = !(ts isa AbstractVector))
383+
if (isscalar = symbolic_type(ts) !== NotSymbolic())
384384
ts = [ts]
385385
end
386-
ts = unwrap.(Symbolics.scalarize(ts))
386+
ts = unwrap.(ts)
387387

388388
vars = Set()
389389
foreach(v -> vars!(vars, v; op), ts)
@@ -399,9 +399,17 @@ function build_explicit_observed_function(sys, ts;
399399
end
400400

401401
sts = Set(unknowns(sys))
402+
sts = union(sts,
403+
Set(arguments(st)[1] for st in sts if istree(st) && operation(st) === getindex))
404+
402405
observed_idx = Dict(x.lhs => i for (i, x) in enumerate(obs))
403406
param_set = Set(parameters(sys))
407+
param_set = union(param_set,
408+
Set(arguments(p)[1] for p in param_set if istree(p) && operation(p) === getindex))
404409
param_set_ns = Set(unknowns(sys, p) for p in parameters(sys))
410+
param_set_ns = union(param_set_ns,
411+
Set(arguments(p)[1]
412+
for p in param_set_ns if istree(p) && operation(p) === getindex))
405413
namespaced_to_obs = Dict(unknowns(sys, x.lhs) => x.lhs for x in obs)
406414
namespaced_to_sts = Dict(unknowns(sys, x) => x for x in unknowns(sys))
407415

@@ -473,9 +481,9 @@ function build_explicit_observed_function(sys, ts;
473481
pre = get_postprocess_fbody(sys)
474482

475483
ex = Func(args, [],
476-
pre(Let(obsexprs,
477-
isscalar ? ts[1] : MakeArray(ts, output_type),
478-
false))) |> toexpr
484+
pre(Let(obsexprs,
485+
isscalar ? ts[1] : MakeArray(ts, output_type),
486+
false))) |> wrap_array_vars(sys, ts)[1] |> toexpr
479487
expression ? ex : drop_expr(@RuntimeGeneratedFunction(ex))
480488
end
481489

test/odesystem.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,9 @@ eqs = [D(x) ~ foo(x, ms); D(ms) ~ bar(ms, p)]
543543
prob = ODEProblem(
544544
outersys, [sys.x => 1.0, sys.ms => 1:3], (0.0, 1.0), [sys.p => ones(3, 3)])
545545
@test_nowarn solve(prob, Tsit5())
546+
obsfn = ModelingToolkit.build_explicit_observed_function(
547+
outersys, bar(3outersys.sys.ms, 3outersys.sys.p))
548+
@test_nowarn obsfn(sol.u[1], prob.p..., sol.t[1])
546549

547550
# x/x
548551
@variables x(t)

0 commit comments

Comments
 (0)