Skip to content

Commit 93e888c

Browse files
fix: fix observed function generation with array variables
1 parent 2f67a8b commit 93e888c

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

src/systems/diffeqs/odesystem.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -380,10 +380,10 @@ function build_explicit_observed_function(sys, ts;
380380
ps = full_parameters(sys),
381381
op = Operator,
382382
throw = true)
383-
if (isscalar = !(ts isa AbstractVector))
383+
if (isscalar = symbolic_type(ts) !== NotSymbolic())
384384
ts = [ts]
385385
end
386-
ts = unwrap.(Symbolics.scalarize(ts))
386+
ts = unwrap.(ts)
387387

388388
vars = Set()
389389
foreach(v -> vars!(vars, v; op), ts)
@@ -399,9 +399,13 @@ function build_explicit_observed_function(sys, ts;
399399
end
400400

401401
sts = Set(unknowns(sys))
402+
sts = union(sts, Set(arguments(st)[1] for st in sts if istree(st) && operation(st) === getindex))
403+
402404
observed_idx = Dict(x.lhs => i for (i, x) in enumerate(obs))
403405
param_set = Set(parameters(sys))
406+
param_set = union(param_set, Set(arguments(p)[1] for p in param_set if istree(p) && operation(p) === getindex))
404407
param_set_ns = Set(unknowns(sys, p) for p in parameters(sys))
408+
param_set_ns = union(param_set_ns, Set(arguments(p)[1] for p in param_set_ns if istree(p) && operation(p) === getindex))
405409
namespaced_to_obs = Dict(unknowns(sys, x.lhs) => x.lhs for x in obs)
406410
namespaced_to_sts = Dict(unknowns(sys, x) => x for x in unknowns(sys))
407411

@@ -475,7 +479,7 @@ function build_explicit_observed_function(sys, ts;
475479
ex = Func(args, [],
476480
pre(Let(obsexprs,
477481
isscalar ? ts[1] : MakeArray(ts, output_type),
478-
false))) |> toexpr
482+
false))) |> wrap_array_vars(sys, ts)[1] |> toexpr
479483
expression ? ex : drop_expr(@RuntimeGeneratedFunction(ex))
480484
end
481485

test/odesystem.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,7 @@ eqs = [D(x) ~ foo(x, ms); D(ms) ~ bar(ms, p)]
543543
prob = ODEProblem(
544544
outersys, [sys.x => 1.0, sys.ms => 1:3], (0.0, 1.0), [sys.p => ones(3, 3)])
545545
@test_nowarn solve(prob, Tsit5())
546+
obsfn = observed(outersys, bar(3x, 3p))
546547

547548
# x/x
548549
@variables x(t)

0 commit comments

Comments
 (0)