Skip to content

Commit 0717ef7

Browse files
fixup! fix: fix callback codegen, observed eqs with non-scalarized symbolic arrays
1 parent ec2e297 commit 0717ef7

File tree

2 files changed

+17
-11
lines changed

2 files changed

+17
-11
lines changed

src/systems/abstractsystem.jl

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -153,15 +153,15 @@ generate_custom_function(sys::AbstractSystem, exprs, dvs = unknowns(sys),
153153
154154
Generate a function to evaluate `exprs`. `exprs` is a symbolic expression or
155155
array of symbolic expression involving symbolic variables in `sys`. The symbolic variables
156-
may be subsetted using `dvs` and `ps`. All `kwargs` except `postprocess_fbody` and `states`
157-
are passed to the internal [`build_function`](@ref) call. The returned function can be called
158-
as `f(u, p, t)` or `f(du, u, p, t)` for time-dependent systems and `f(u, p)` or `f(du, u, p)`
159-
for time-independent systems. If `split=true` (the default) was passed to [`complete`](@ref),
156+
may be subsetted using `dvs` and `ps`. All `kwargs` are passed to the internal
157+
[`build_function`](@ref) call. The returned function can be called as `f(u, p, t)` or
158+
`f(du, u, p, t)` for time-dependent systems and `f(u, p)` or `f(du, u, p)` for
159+
time-independent systems. If `split=true` (the default) was passed to [`complete`](@ref),
160160
[`structural_simplify`](@ref) or [`@mtkbuild`](@ref), `p` is expected to be an `MTKParameters`
161161
object.
162162
"""
163163
function generate_custom_function(sys::AbstractSystem, exprs, dvs = unknowns(sys),
164-
ps = parameters(sys); wrap_code = nothing, kwargs...)
164+
ps = parameters(sys); wrap_code = nothing, postprocess_fbody = nothing, states = nothing, kwargs...)
165165
if !iscomplete(sys)
166166
error("A completed system is required. Call `complete` or `structural_simplify` on the system.")
167167
end
@@ -171,15 +171,20 @@ function generate_custom_function(sys::AbstractSystem, exprs, dvs = unknowns(sys
171171
wrap_code = isscalar ? identity : (identity, identity)
172172
end
173173
pre, sol_states = get_substitutions_and_solved_unknowns(sys)
174-
174+
if postprocess_fbody === nothing
175+
postprocess_fbody = pre
176+
end
177+
if states === nothing
178+
states = sol_states
179+
end
175180
if is_time_dependent(sys)
176181
return build_function(exprs,
177182
dvs,
178183
p...,
179184
get_iv(sys);
180185
kwargs...,
181-
postprocess_fbody = pre,
182-
states = sol_states,
186+
postprocess_fbody,
187+
states,
183188
wrap_code = wrap_code .∘ wrap_mtkparameters(sys, isscalar) .∘
184189
wrap_array_vars(sys, exprs; dvs)
185190
)
@@ -188,8 +193,8 @@ function generate_custom_function(sys::AbstractSystem, exprs, dvs = unknowns(sys
188193
dvs,
189194
p...;
190195
kwargs...,
191-
postprocess_fbody = pre,
192-
states = sol_states,
196+
postprocess_fbody,
197+
states,
193198
wrap_code = wrap_code .∘ wrap_mtkparameters(sys, isscalar) .∘
194199
wrap_array_vars(sys, exprs; dvs)
195200
)

src/systems/callbacks.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -472,8 +472,9 @@ function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknow
472472
rhss = map(x -> x.rhs, eqs)
473473
root_eq_vars = unique(collect(Iterators.flatten(map(ModelingToolkit.vars, rhss))))
474474

475+
pre = get_preprocess_constants(rhss)
475476
rf_oop, rf_ip = generate_custom_function(sys, rhss, dvs, ps; expression = Val{false},
476-
kwargs...)
477+
postprocess_fbody = pre, kwargs...)
477478

478479
affect_functions = map(cbs) do cb # Keep affect function separate
479480
eq_aff = affects(cb)

0 commit comments

Comments
 (0)