Skip to content

fix: fix callback codegen, observed eqs with non-scalarized symbolic arrays #2605

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 5, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 36 additions & 9 deletions src/structural_transformation/symbolics_tearing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -551,7 +551,43 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
end

sys = state.sys

obs_sub = dummy_sub
for eq in neweqs
isdiffeq(eq) || continue
obs_sub[eq.lhs] = eq.rhs
end
# TODO: compute the dependency correctly so that we don't have to do this
obs = [fast_substitute(observed(sys), obs_sub); subeqs]

# HACK: Substitute non-scalarized symbolic arrays of observed variables
# E.g. if `p[1] ~ (...)` and `p[2] ~ (...)` then substitute `p => [p[1], p[2]]` in all equations
# ideally, we want to support equations such as `p ~ [p[1], p[2]]` which will then be handled
# by the topological sorting and dependency identification pieces
obs_arr_subs = Dict()

for eq in obs
lhs = eq.lhs
istree(lhs) || continue
operation(lhs) === getindex || continue
Symbolics.shape(lhs) !== Symbolics.Unknown() || continue
arg1 = arguments(lhs)[1]
haskey(obs_arr_subs, arg1) && continue
obs_arr_subs[arg1] = [arg1[i] for i in eachindex(arg1)]
end
for i in eachindex(neweqs)
neweqs[i] = fast_substitute(neweqs[i], obs_arr_subs; operator = Symbolics.Operator)
end
for i in eachindex(obs)
obs[i] = fast_substitute(obs[i], obs_arr_subs; operator = Symbolics.Operator)
end
for i in eachindex(subeqs)
subeqs[i] = fast_substitute(subeqs[i], obs_arr_subs; operator = Symbolics.Operator)
end

@set! sys.eqs = neweqs
@set! sys.observed = obs

unknowns = Any[v
for (i, v) in enumerate(fullvars)
if diff_to_var[i] === nothing && ispresent(i)]
Expand All @@ -563,15 +599,6 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
@set! sys.unknowns = unknowns
@set! sys.substitutions = Substitutions(subeqs, deps)

obs_sub = dummy_sub
for eq in equations(sys)
isdiffeq(eq) || continue
obs_sub[eq.lhs] = eq.rhs
end
# TODO: compute the dependency correctly so that we don't have to do this
obs = [fast_substitute(observed(sys), obs_sub); subeqs]
@set! sys.observed = obs

# Only makes sense for time-dependent
# TODO: generalize to SDE
if sys isa ODESystem
Expand Down
27 changes: 16 additions & 11 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,15 +153,15 @@ generate_custom_function(sys::AbstractSystem, exprs, dvs = unknowns(sys),

Generate a function to evaluate `exprs`. `exprs` is a symbolic expression or
array of symbolic expression involving symbolic variables in `sys`. The symbolic variables
may be subsetted using `dvs` and `ps`. All `kwargs` except `postprocess_fbody` and `states`
are passed to the internal [`build_function`](@ref) call. The returned function can be called
as `f(u, p, t)` or `f(du, u, p, t)` for time-dependent systems and `f(u, p)` or `f(du, u, p)`
for time-independent systems. If `split=true` (the default) was passed to [`complete`](@ref),
may be subsetted using `dvs` and `ps`. All `kwargs` are passed to the internal
[`build_function`](@ref) call. The returned function can be called as `f(u, p, t)` or
`f(du, u, p, t)` for time-dependent systems and `f(u, p)` or `f(du, u, p)` for
time-independent systems. If `split=true` (the default) was passed to [`complete`](@ref),
[`structural_simplify`](@ref) or [`@mtkbuild`](@ref), `p` is expected to be an `MTKParameters`
object.
"""
function generate_custom_function(sys::AbstractSystem, exprs, dvs = unknowns(sys),
ps = parameters(sys); wrap_code = nothing, kwargs...)
ps = parameters(sys); wrap_code = nothing, postprocess_fbody = nothing, states = nothing, kwargs...)
if !iscomplete(sys)
error("A completed system is required. Call `complete` or `structural_simplify` on the system.")
end
Expand All @@ -170,16 +170,21 @@ function generate_custom_function(sys::AbstractSystem, exprs, dvs = unknowns(sys
if wrap_code === nothing
wrap_code = isscalar ? identity : (identity, identity)
end
pre, sol_states = get_substitutions_and_solved_unknowns(sys)

pre, sol_states = get_substitutions_and_solved_unknowns(sys, isscalar ? [exprs] : exprs)
if postprocess_fbody === nothing
postprocess_fbody = pre
end
if states === nothing
states = sol_states
end
if is_time_dependent(sys)
return build_function(exprs,
dvs,
p...,
get_iv(sys);
kwargs...,
postprocess_fbody = pre,
states = sol_states,
postprocess_fbody,
states,
wrap_code = wrap_code .∘ wrap_mtkparameters(sys, isscalar) .∘
wrap_array_vars(sys, exprs; dvs)
)
Expand All @@ -188,8 +193,8 @@ function generate_custom_function(sys::AbstractSystem, exprs, dvs = unknowns(sys
dvs,
p...;
kwargs...,
postprocess_fbody = pre,
states = sol_states,
postprocess_fbody,
states,
wrap_code = wrap_code .∘ wrap_mtkparameters(sys, isscalar) .∘
wrap_array_vars(sys, exprs; dvs)
)
Expand Down
17 changes: 7 additions & 10 deletions src/systems/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,7 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin
return (args...) -> () # We don't do anything in the callback, we're just after the event
end
else
eqs = flatten_equations(eqs)
rhss = map(x -> x.rhs, eqs)
outvar = :u
if outputidxs === nothing
Expand Down Expand Up @@ -457,7 +458,7 @@ end

function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknowns(sys),
ps = full_parameters(sys); kwargs...)
eqs = map(cb -> cb.eqs, cbs)
eqs = map(cb -> flatten_equations(cb.eqs), cbs)
num_eqs = length.(eqs)
(isempty(eqs) || sum(num_eqs) == 0) && return nothing
# fuse equations to create VectorContinuousCallback
Expand All @@ -471,12 +472,8 @@ function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknow
rhss = map(x -> x.rhs, eqs)
root_eq_vars = unique(collect(Iterators.flatten(map(ModelingToolkit.vars, rhss))))

u = map(x -> time_varying_as_func(value(x), sys), dvs)
p = map.(x -> time_varying_as_func(value(x), sys), reorder_parameters(sys, ps))
t = get_iv(sys)
pre = get_preprocess_constants(rhss)
rf_oop, rf_ip = build_function(rhss, u, p..., t; expression = Val{false},
postprocess_fbody = pre, kwargs...)
rf_oop, rf_ip = generate_custom_function(sys, rhss, dvs, ps; expression = Val{false},
kwargs...)

affect_functions = map(cbs) do cb # Keep affect function separate
eq_aff = affects(cb)
Expand All @@ -487,16 +484,16 @@ function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknow
cond = function (u, t, integ)
if DiffEqBase.isinplace(integ.sol.prob)
tmp, = DiffEqBase.get_tmp_cache(integ)
rf_ip(tmp, u, parameter_values(integ)..., t)
rf_ip(tmp, u, parameter_values(integ), t)
tmp[1]
else
rf_oop(u, parameter_values(integ)..., t)
rf_oop(u, parameter_values(integ), t)
end
end
ContinuousCallback(cond, affect_functions[])
else
cond = function (out, u, t, integ)
rf_ip(out, u, parameter_values(integ)..., t)
rf_ip(out, u, parameter_values(integ), t)
end

# since there may be different number of conditions and affects,
Expand Down
1 change: 1 addition & 0 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -857,6 +857,7 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
varmap = u0map === nothing || isempty(u0map) || eltype(u0map) <: Number ?
defaults(sys) :
merge(defaults(sys), todict(u0map))
varmap = canonicalize_varmap(varmap)
varlist = collect(map(unwrap, dvs))
missingvars = setdiff(varlist, collect(keys(varmap)))

Expand Down
5 changes: 4 additions & 1 deletion src/systems/optimization/constraints_system.jl
Original file line number Diff line number Diff line change
Expand Up @@ -226,12 +226,15 @@ function generate_canonical_form_lhss(sys)
lhss = subs_constants([Symbolics.canonical_form(eq).lhs for eq in constraints(sys)])
end

function get_cmap(sys::ConstraintsSystem)
function get_cmap(sys::ConstraintsSystem, exprs = nothing)
#Inject substitutions for constants => values
cs = collect_constants([get_constraints(sys); get_observed(sys)]) #ctrls? what else?
if !empty_substitutions(sys)
cs = [cs; collect_constants(get_substitutions(sys).subs)]
end
if exprs !== nothing
cs = [cs; collect_contants(exprs)]
end
# Swap constants for their values
cmap = map(x -> x ~ getdefault(x), cs)
return cmap, cs
Expand Down
9 changes: 6 additions & 3 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -564,19 +564,22 @@ function empty_substitutions(sys)
isnothing(subs) || isempty(subs.deps)
end

function get_cmap(sys)
function get_cmap(sys, exprs = nothing)
#Inject substitutions for constants => values
cs = collect_constants([get_eqs(sys); get_observed(sys)]) #ctrls? what else?
if !empty_substitutions(sys)
cs = [cs; collect_constants(get_substitutions(sys).subs)]
end
if exprs !== nothing
cs = [cs; collect_constants(exprs)]
end
# Swap constants for their values
cmap = map(x -> x ~ getdefault(x), cs)
return cmap, cs
end

function get_substitutions_and_solved_unknowns(sys; no_postprocess = false)
cmap, cs = get_cmap(sys)
function get_substitutions_and_solved_unknowns(sys, exprs = nothing; no_postprocess = false)
cmap, cs = get_cmap(sys, exprs)
if empty_substitutions(sys) && isempty(cs)
sol_states = Code.LazyState()
pre = no_postprocess ? (ex -> ex) : get_postprocess_fbody(sys)
Expand Down
1 change: 1 addition & 0 deletions src/variables.jl
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ function canonicalize_varmap(varmap; toterm = Symbolics.diff2term)
if Symbolics.isarraysymbolic(k) && Symbolics.shape(k) !== Symbolics.Unknown()
for i in eachindex(k)
new_varmap[k[i]] = v[i]
new_varmap[toterm(k[i])] = v[i]
end
end
end
Expand Down
7 changes: 7 additions & 0 deletions test/initial_values.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,10 @@ varmap = Dict(p => ones(3), q => 2ones(3))
cvarmap = ModelingToolkit.canonicalize_varmap(varmap)
target_varmap = Dict(p => ones(3), q => 2ones(3), q[1] => 2.0, q[2] => 2.0, q[3] => 2.0)
@test cvarmap == target_varmap

# Initialization of ODEProblem with dummy derivatives of multidimensional arrays
# Issue#1283
@variables z(t)[1:2, 1:2]
eqs = [D(D(z)) ~ ones(2, 2)]
@mtkbuild sys = ODESystem(eqs, t)
@test_nowarn ODEProblem(sys, [z => zeros(2, 2), D(z) => ones(2, 2)], (0.0, 10.0))
35 changes: 35 additions & 0 deletions test/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -995,3 +995,38 @@ let # Issue https://github.com/SciML/ModelingToolkit.jl/issues/2322
sol = solve(prob, Rodas4())
@test sol(1)[]≈0.6065307685451087 rtol=1e-4
end

# Issue#2599
@variables x(t) y(t)
eqs = [D(x) ~ x * t, y ~ 2x]
@mtkbuild sys = ODESystem(eqs, t; continuous_events = [[y ~ 3] => [x ~ 2]])
prob = ODEProblem(sys, [x => 1.0], (0.0, 10.0))
@test_nowarn solve(prob, Tsit5())

# Issue#2383
@variables x(t)[1:3]
@parameters p[1:3, 1:3]
eqs = [
D(x) ~ p * x
]
@mtkbuild sys = ODESystem(eqs, t; continuous_events = [[norm(x) ~ 3.0] => [x ~ ones(3)]])
# array affect equations used to not work
prob1 = @test_nowarn ODEProblem(sys, [x => ones(3)], (0.0, 10.0), [p => ones(3, 3)])
sol1 = @test_nowarn solve(prob1, Tsit5())

# array condition equations also used to not work
@mtkbuild sys = ODESystem(
eqs, t; continuous_events = [[x ~ sqrt(3) * ones(3)] => [x ~ ones(3)]])
# array affect equations used to not work
prob2 = @test_nowarn ODEProblem(sys, [x => ones(3)], (0.0, 10.0), [p => ones(3, 3)])
sol2 = @test_nowarn solve(prob2, Tsit5())

@test sol1 ≈ sol2

# Requires fix in symbolics for `linear_expansion(p * x, D(y))`
@test_broken begin
@variables x(t)[1:3] y(t)
@parameters p[1:3, 1:3]
@test_nowarn @mtkbuild sys = ODESystem([D(x) ~ p * x, D(y) ~ x' * p * x], t)
@test_nowarn ODEProblem(sys, [x => ones(3), y => 2], (0.0, 10.0), [p => ones(3, 3)])
end