Skip to content

Commit a4628af

Browse files
fix: fix callback codegen, observed eqs with non-scalarized symbolic arrays
1 parent 2a0938c commit a4628af

File tree

3 files changed

+70
-19
lines changed

3 files changed

+70
-19
lines changed

src/structural_transformation/symbolics_tearing.jl

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -551,7 +551,43 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
551551
end
552552

553553
sys = state.sys
554+
555+
obs_sub = dummy_sub
556+
for eq in neweqs
557+
isdiffeq(eq) || continue
558+
obs_sub[eq.lhs] = eq.rhs
559+
end
560+
# TODO: compute the dependency correctly so that we don't have to do this
561+
obs = [fast_substitute(observed(sys), obs_sub); subeqs]
562+
563+
# HACK: Substitute non-scalarized symbolic arrays of observed variables
564+
# E.g. if `p[1] ~ (...)` and `p[2] ~ (...)` then substitute `p => [p[1], p[2]]` in all equations
565+
# ideally, we want to support equations such as `p ~ [p[1], p[2]]` which will then be handled
566+
# by the topological sorting and dependency identification pieces
567+
obs_arr_subs = Dict()
568+
569+
for eq in obs
570+
lhs = eq.lhs
571+
istree(lhs) || continue
572+
operation(lhs) === getindex || continue
573+
Symbolics.shape(lhs) !== Symbolics.Unknown() || continue
574+
arg1 = arguments(lhs)[1]
575+
haskey(obs_arr_subs, arg1) && continue
576+
obs_arr_subs[arg1] = [arg1[i] for i in eachindex(arg1)]
577+
end
578+
for i in eachindex(neweqs)
579+
neweqs[i] = fast_substitute(neweqs[i], obs_arr_subs; operator = Symbolics.Operator)
580+
end
581+
for i in eachindex(obs)
582+
obs[i] = fast_substitute(obs[i], obs_arr_subs; operator = Symbolics.Operator)
583+
end
584+
for i in eachindex(subeqs)
585+
subeqs[i] = fast_substitute(subeqs[i], obs_arr_subs; operator = Symbolics.Operator)
586+
end
587+
554588
@set! sys.eqs = neweqs
589+
@set! sys.observed = obs
590+
555591
unknowns = Any[v
556592
for (i, v) in enumerate(fullvars)
557593
if diff_to_var[i] === nothing && ispresent(i)]
@@ -563,15 +599,6 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
563599
@set! sys.unknowns = unknowns
564600
@set! sys.substitutions = Substitutions(subeqs, deps)
565601

566-
obs_sub = dummy_sub
567-
for eq in equations(sys)
568-
isdiffeq(eq) || continue
569-
obs_sub[eq.lhs] = eq.rhs
570-
end
571-
# TODO: compute the dependency correctly so that we don't have to do this
572-
obs = [fast_substitute(observed(sys), obs_sub); subeqs]
573-
@set! sys.observed = obs
574-
575602
# Only makes sense for time-dependent
576603
# TODO: generalize to SDE
577604
if sys isa ODESystem

src/systems/callbacks.jl

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,7 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin
388388
return (args...) -> () # We don't do anything in the callback, we're just after the event
389389
end
390390
else
391+
eqs = flatten_equations(eqs)
391392
rhss = map(x -> x.rhs, eqs)
392393
outvar = :u
393394
if outputidxs === nothing
@@ -457,7 +458,7 @@ end
457458

458459
function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknowns(sys),
459460
ps = full_parameters(sys); kwargs...)
460-
eqs = map(cb -> cb.eqs, cbs)
461+
eqs = map(cb -> flatten_equations(cb.eqs), cbs)
461462
num_eqs = length.(eqs)
462463
(isempty(eqs) || sum(num_eqs) == 0) && return nothing
463464
# fuse equations to create VectorContinuousCallback
@@ -471,12 +472,8 @@ function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknow
471472
rhss = map(x -> x.rhs, eqs)
472473
root_eq_vars = unique(collect(Iterators.flatten(map(ModelingToolkit.vars, rhss))))
473474

474-
u = map(x -> time_varying_as_func(value(x), sys), dvs)
475-
p = map.(x -> time_varying_as_func(value(x), sys), reorder_parameters(sys, ps))
476-
t = get_iv(sys)
477-
pre = get_preprocess_constants(rhss)
478-
rf_oop, rf_ip = build_function(rhss, u, p..., t; expression = Val{false},
479-
postprocess_fbody = pre, kwargs...)
475+
rf_oop, rf_ip = generate_custom_function(sys, rhss, dvs, ps; expression = Val{false},
476+
kwargs...)
480477

481478
affect_functions = map(cbs) do cb # Keep affect function separate
482479
eq_aff = affects(cb)
@@ -487,16 +484,16 @@ function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknow
487484
cond = function (u, t, integ)
488485
if DiffEqBase.isinplace(integ.sol.prob)
489486
tmp, = DiffEqBase.get_tmp_cache(integ)
490-
rf_ip(tmp, u, parameter_values(integ)..., t)
487+
rf_ip(tmp, u, parameter_values(integ), t)
491488
tmp[1]
492489
else
493-
rf_oop(u, parameter_values(integ)..., t)
490+
rf_oop(u, parameter_values(integ), t)
494491
end
495492
end
496493
ContinuousCallback(cond, affect_functions[])
497494
else
498495
cond = function (out, u, t, integ)
499-
rf_ip(out, u, parameter_values(integ)..., t)
496+
rf_ip(out, u, parameter_values(integ), t)
500497
end
501498

502499
# since there may be different number of conditions and affects,

test/odesystem.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -995,3 +995,30 @@ let # Issue https://github.com/SciML/ModelingToolkit.jl/issues/2322
995995
sol = solve(prob, Rodas4())
996996
@test sol(1)[]0.6065307685451087 rtol=1e-4
997997
end
998+
999+
# Issue#2599
1000+
@variables x(t) y(t)
1001+
eqs = [D(x) ~ x * t, y ~ 2x]
1002+
@mtkbuild sys = ODESystem(eqs, t; continuous_events = [[y ~ 3] => [x ~ 2]])
1003+
prob = ODEProblem(sys, [x => 1.0], (0.0, 10.0))
1004+
@test_nowarn solve(prob, Tsit5())
1005+
1006+
# Issue#2383
1007+
@variables x(t)[1:3]
1008+
@parameters p[1:3, 1:3]
1009+
eqs = [
1010+
D(x) ~ p * x
1011+
]
1012+
@mtkbuild sys = ODESystem(eqs, t; continuous_events = [[norm(x) ~ 3.0] => [x ~ ones(3)]])
1013+
# array affect equations used to not work
1014+
prob1 = @test_nowarn ODEProblem(sys, [x => ones(3)], (0.0, 10.0), [p => ones(3, 3)])
1015+
sol1 = @test_nowarn solve(prob1, Tsit5())
1016+
1017+
# array condition equations also used to not work
1018+
@mtkbuild sys = ODESystem(
1019+
eqs, t; continuous_events = [[x ~ sqrt(3) * ones(3)] => [x ~ ones(3)]])
1020+
# array affect equations used to not work
1021+
prob2 = @test_nowarn ODEProblem(sys, [x => ones(3)], (0.0, 10.0), [p => ones(3, 3)])
1022+
sol2 = @test_nowarn solve(prob, Tsit5())
1023+
1024+
@test sol1 sol2

0 commit comments

Comments
 (0)