Skip to content

feat: complete eval_expression and eval_module support #2826

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions src/inputoutput.jl
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,8 @@ function generate_control_function(sys::AbstractODESystem, inputs = unbound_inpu
disturbance_inputs = disturbances(sys);
implicit_dae = false,
simplify = false,
eval_expression = false,
eval_module = @__MODULE__,
kwargs...)
isempty(inputs) && @warn("No unbound inputs were found in system.")

Expand Down Expand Up @@ -240,7 +242,8 @@ function generate_control_function(sys::AbstractODESystem, inputs = unbound_inpu
end
process = get_postprocess_fbody(sys)
f = build_function(rhss, args...; postprocess_fbody = process,
expression = Val{false}, kwargs...)
expression = Val{true}, kwargs...)
f = eval_or_rgf.(f; eval_expression, eval_module)
(; f, dvs, ps, io_sys = sys)
end

Expand Down Expand Up @@ -395,7 +398,7 @@ model_outputs = [model.inertia1.w, model.inertia2.w, model.inertia1.phi, model.i

`f_oop` will have an extra state corresponding to the integrator in the disturbance model. This state will not be affected by any input, but will affect the dynamics from where it enters, in this case it will affect additively from `model.torque.tau.u`.
"""
function add_input_disturbance(sys, dist::DisturbanceModel, inputs = nothing)
function add_input_disturbance(sys, dist::DisturbanceModel, inputs = nothing; kwargs...)
t = get_iv(sys)
@variables d(t)=0 [disturbance = true]
@variables u(t)=0 [input = true] # New system input
Expand All @@ -418,6 +421,6 @@ function add_input_disturbance(sys, dist::DisturbanceModel, inputs = nothing)
augmented_sys = extend(augmented_sys, sys)

(f_oop, f_ip), dvs, p = generate_control_function(augmented_sys, all_inputs,
[d])
[d]; kwargs...)
(f_oop, f_ip), augmented_sys, dvs, p
end
126 changes: 63 additions & 63 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,8 @@ time-independent systems. If `split=true` (the default) was passed to [`complete
object.
"""
function generate_custom_function(sys::AbstractSystem, exprs, dvs = unknowns(sys),
ps = parameters(sys); wrap_code = nothing, postprocess_fbody = nothing, states = nothing, kwargs...)
ps = parameters(sys); wrap_code = nothing, postprocess_fbody = nothing, states = nothing,
expression = Val{true}, eval_expression = false, eval_module = @__MODULE__, kwargs...)
if !iscomplete(sys)
error("A completed system is required. Call `complete` or `structural_simplify` on the system.")
end
Expand All @@ -177,28 +178,38 @@ function generate_custom_function(sys::AbstractSystem, exprs, dvs = unknowns(sys
if states === nothing
states = sol_states
end
if is_time_dependent(sys)
return build_function(exprs,
fnexpr = if is_time_dependent(sys)
build_function(exprs,
dvs,
p...,
get_iv(sys);
kwargs...,
postprocess_fbody,
states,
wrap_code = wrap_code .∘ wrap_mtkparameters(sys, isscalar) .∘
wrap_array_vars(sys, exprs; dvs)
wrap_array_vars(sys, exprs; dvs),
expression = Val{true}
)
else
return build_function(exprs,
build_function(exprs,
dvs,
p...;
kwargs...,
postprocess_fbody,
states,
wrap_code = wrap_code .∘ wrap_mtkparameters(sys, isscalar) .∘
wrap_array_vars(sys, exprs; dvs)
wrap_array_vars(sys, exprs; dvs),
expression = Val{true}
)
end
if expression == Val{true}
return fnexpr
end
if fnexpr isa Tuple
return eval_or_rgf.(fnexpr; eval_expression, eval_module)
else
return eval_or_rgf(fnexpr; eval_expression, eval_module)
end
end

function wrap_assignments(isscalar, assignments; let_block = false)
Expand Down Expand Up @@ -509,7 +520,8 @@ function SymbolicIndexingInterface.is_observed(sys::AbstractSystem, sym)
!is_independent_variable(sys, sym) && symbolic_type(sym) != NotSymbolic()
end

function SymbolicIndexingInterface.observed(sys::AbstractSystem, sym)
function SymbolicIndexingInterface.observed(
sys::AbstractSystem, sym; eval_expression = false, eval_module = @__MODULE__)
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
if sym isa Symbol
_sym = get(ic.symbol_to_variable, sym, nothing)
Expand All @@ -531,7 +543,8 @@ function SymbolicIndexingInterface.observed(sys::AbstractSystem, sym)
end
end
end
_fn = build_explicit_observed_function(sys, sym)
_fn = build_explicit_observed_function(sys, sym; eval_expression, eval_module)

if is_time_dependent(sys)
return let _fn = _fn
fn1(u, p, t) = _fn(u, p, t)
Expand Down Expand Up @@ -1210,19 +1223,30 @@ end
struct ObservedFunctionCache{S}
sys::S
dict::Dict{Any, Any}
eval_expression::Bool
eval_module::Module
end

function ObservedFunctionCache(sys)
return ObservedFunctionCache(sys, Dict())
let sys = sys, dict = Dict()
function generated_observed(obsvar, args...)
end
end
function ObservedFunctionCache(sys; eval_expression = false, eval_module = @__MODULE__)
return ObservedFunctionCache(sys, Dict(), eval_expression, eval_module)
end

# This is hit because ensemble problems do a deepcopy
function Base.deepcopy_internal(ofc::ObservedFunctionCache, stackdict::IdDict)
sys = deepcopy(ofc.sys)
dict = deepcopy(ofc.dict)
eval_expression = ofc.eval_expression
eval_module = ofc.eval_module
newofc = ObservedFunctionCache(sys, dict, eval_expression, eval_module)
stackdict[ofc] = newofc
return newofc
end

function (ofc::ObservedFunctionCache)(obsvar, args...)
obs = get!(ofc.dict, value(obsvar)) do
SymbolicIndexingInterface.observed(ofc.sys, obsvar)
SymbolicIndexingInterface.observed(
ofc.sys, obsvar; eval_expression = ofc.eval_expression,
eval_module = ofc.eval_module)
end
if args === ()
return obs
Expand Down Expand Up @@ -1871,6 +1895,7 @@ function linearization_function(sys::AbstractSystem, inputs,
p = DiffEqBase.NullParameters(),
zero_dummy_der = false,
initialization_solver_alg = TrustRegion(),
eval_expression = false, eval_module = @__MODULE__,
kwargs...)
inputs isa AbstractVector || (inputs = [inputs])
outputs isa AbstractVector || (outputs = [outputs])
Expand All @@ -1895,85 +1920,58 @@ function linearization_function(sys::AbstractSystem, inputs,
end
x0 = merge(defaults_and_guesses(sys), op)
if has_index_cache(sys) && get_index_cache(sys) !== nothing
sys_ps = MTKParameters(sys, p, x0)
sys_ps = MTKParameters(sys, p, x0; eval_expression, eval_module)
else
sys_ps = varmap_to_vars(p, parameters(sys); defaults = x0)
end
p[get_iv(sys)] = NaN
if has_index_cache(initsys) && get_index_cache(initsys) !== nothing
oldps = MTKParameters(initsys, p, merge(guesses(sys), defaults(sys), op))
oldps = MTKParameters(initsys, p, merge(guesses(sys), defaults(sys), op);
eval_expression, eval_module)
initsys_ps = parameters(initsys)
initsys_idxs = [parameter_index(initsys, param) for param in initsys_ps]
tunable_ps = [initsys_ps[i]
for i in eachindex(initsys_ps)
if initsys_idxs[i].portion == SciMLStructures.Tunable()]
tunable_getter = isempty(tunable_ps) ? nothing : getu(sys, tunable_ps)
discrete_ps = [initsys_ps[i]
for i in eachindex(initsys_ps)
if initsys_idxs[i].portion == SciMLStructures.Discrete()]
disc_getter = isempty(discrete_ps) ? nothing : getu(sys, discrete_ps)
constant_ps = [initsys_ps[i]
for i in eachindex(initsys_ps)
if initsys_idxs[i].portion == SciMLStructures.Constants()]
const_getter = isempty(constant_ps) ? nothing : getu(sys, constant_ps)
nonnum_ps = [initsys_ps[i]
for i in eachindex(initsys_ps)
if initsys_idxs[i].portion == NONNUMERIC_PORTION]
nonnum_getter = isempty(nonnum_ps) ? nothing : getu(sys, nonnum_ps)
p_getter = build_explicit_observed_function(
sys, initsys_ps; eval_expression, eval_module)

u_getter = isempty(unknowns(initsys)) ? (_...) -> nothing :
getu(sys, unknowns(initsys))
get_initprob_u_p = let tunable_getter = tunable_getter,
disc_getter = disc_getter,
const_getter = const_getter,
nonnum_getter = nonnum_getter,
oldps = oldps,
build_explicit_observed_function(
sys, unknowns(initsys); eval_expression, eval_module)
get_initprob_u_p = let p_getter,
p_setter! = setp(initsys, initsys_ps),
u_getter = u_getter

function (u, p, t)
state = ProblemState(; u, p, t)
if tunable_getter !== nothing
SciMLStructures.replace!(
SciMLStructures.Tunable(), oldps, tunable_getter(state))
end
if disc_getter !== nothing
SciMLStructures.replace!(
SciMLStructures.Discrete(), oldps, disc_getter(state))
end
if const_getter !== nothing
SciMLStructures.replace!(
SciMLStructures.Constants(), oldps, const_getter(state))
end
if nonnum_getter !== nothing
SciMLStructures.replace!(
NONNUMERIC_PORTION, oldps, nonnum_getter(state))
end
p_setter!(oldps, p_getter(state))
newu = u_getter(state)
return newu, oldps
end
end
else
get_initprob_u_p = let p_getter = getu(sys, parameters(initsys)),
u_getter = getu(sys, unknowns(initsys))
u_getter = build_explicit_observed_function(
sys, unknowns(initsys); eval_expression, eval_module)

function (u, p, t)
state = ProblemState(; u, p, t)
return u_getter(state), p_getter(state)
end
end
end
initfn = NonlinearFunction(initsys)
initprobmap = getu(initsys, unknowns(sys))
initfn = NonlinearFunction(initsys; eval_expression, eval_module)
initprobmap = build_explicit_observed_function(
initsys, unknowns(sys); eval_expression, eval_module)
ps = full_parameters(sys)
h = build_explicit_observed_function(sys, outputs; eval_expression, eval_module)
lin_fun = let diff_idxs = diff_idxs,
alge_idxs = alge_idxs,
input_idxs = input_idxs,
sts = unknowns(sys),
get_initprob_u_p = get_initprob_u_p,
fun = ODEFunction{true, SciMLBase.FullSpecialize}(
sys, unknowns(sys), ps),
sys, unknowns(sys), ps; eval_expression, eval_module),
initfn = initfn,
initprobmap = initprobmap,
h = build_explicit_observed_function(sys, outputs),
h = h,
chunk = ForwardDiff.Chunk(input_idxs),
sys_ps = sys_ps,
initialize = initialize,
Expand Down Expand Up @@ -2056,6 +2054,7 @@ where `x` are differential unknown variables, `z` algebraic variables, `u` input
"""
function linearize_symbolic(sys::AbstractSystem, inputs,
outputs; simplify = false, allow_input_derivatives = false,
eval_expression = false, eval_module = @__MODULE__,
kwargs...)
sys, diff_idxs, alge_idxs, input_idxs = io_preprocessing(
sys, inputs, outputs; simplify,
Expand All @@ -2065,10 +2064,11 @@ function linearize_symbolic(sys::AbstractSystem, inputs,
ps = full_parameters(sys)
p = reorder_parameters(sys, ps)

fun = generate_function(sys, sts, ps; expression = Val{false})[1]
fun_expr = generate_function(sys, sts, ps; expression = Val{true})[1]
fun = eval_or_rgf(fun_expr; eval_expression, eval_module)
dx = fun(sts, p..., t)

h = build_explicit_observed_function(sys, outputs)
h = build_explicit_observed_function(sys, outputs; eval_expression, eval_module)
y = h(sts, p..., t)

fg_xz = Symbolics.jacobian(dx, sts)
Expand Down
22 changes: 14 additions & 8 deletions src/systems/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ Notes
- `kwargs` are passed through to `Symbolics.build_function`.
"""
function compile_condition(cb::SymbolicDiscreteCallback, sys, dvs, ps;
expression = Val{true}, kwargs...)
expression = Val{true}, eval_expression = false, eval_module = @__MODULE__, kwargs...)
u = map(x -> time_varying_as_func(value(x), sys), dvs)
p = map.(x -> time_varying_as_func(value(x), sys), reorder_parameters(sys, ps))
t = get_iv(sys)
Expand All @@ -353,8 +353,13 @@ function compile_condition(cb::SymbolicDiscreteCallback, sys, dvs, ps;
cmap = map(x -> x => getdefault(x), cs)
condit = substitute(condit, cmap)
end
build_function(condit, u, t, p...; expression, wrap_code = condition_header(sys),
expr = build_function(
condit, u, t, p...; expression = Val{true}, wrap_code = condition_header(sys),
kwargs...)
if expression == Val{true}
return expr
end
return eval_or_rgf(expr; eval_expression, eval_module)
end

function compile_affect(cb::SymbolicContinuousCallback, args...; kwargs...)
Expand All @@ -379,7 +384,8 @@ Notes
- `kwargs` are passed through to `Symbolics.build_function`.
"""
function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothing,
expression = Val{true}, checkvars = true,
expression = Val{true}, checkvars = true, eval_expression = false,
eval_module = @__MODULE__,
postprocess_affect_expr! = nothing, kwargs...)
if isempty(eqs)
if expression == Val{true}
Expand Down Expand Up @@ -432,20 +438,20 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin
end
t = get_iv(sys)
integ = gensym(:MTKIntegrator)
getexpr = (postprocess_affect_expr! === nothing) ? expression : Val{true}
pre = get_preprocess_constants(rhss)
rf_oop, rf_ip = build_function(rhss, u, p..., t; expression = getexpr,
rf_oop, rf_ip = build_function(rhss, u, p..., t; expression = Val{true},
wrap_code = add_integrator_header(sys, integ, outvar),
outputidxs = update_inds,
postprocess_fbody = pre,
kwargs...)
# applied user-provided function to the generated expression
if postprocess_affect_expr! !== nothing
postprocess_affect_expr!(rf_ip, integ)
(expression == Val{false}) &&
(return drop_expr(@RuntimeGeneratedFunction(rf_ip)))
end
rf_ip
if expression == Val{false}
return eval_or_rgf(rf_ip; eval_expression, eval_module)
end
return rf_ip
end
end

Expand Down
2 changes: 1 addition & 1 deletion src/systems/connectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ function Base.merge(csets::AbstractVector{<:ConnectionSet}, allouter = false)
id2set = Dict{Int, Int}()
merged_set = ConnectionSet[]
for (id, ele) in enumerate(idx2ele)
rid = find_root(union_find, id)
rid = find_root!(union_find, id)
set_idx = get!(id2set, rid) do
set = ConnectionSet()
push!(merged_set, set)
Expand Down
Loading
Loading