Skip to content

Commit a203b86

Browse files
Merge pull request #2826 from AayushSabharwal/as/m-e-g-a
feat: complete `eval_expression` and `eval_module` support
2 parents b0c4b2b + add87d5 commit a203b86

File tree

13 files changed

+265
-194
lines changed

13 files changed

+265
-194
lines changed

src/inputoutput.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,8 @@ function generate_control_function(sys::AbstractODESystem, inputs = unbound_inpu
195195
disturbance_inputs = disturbances(sys);
196196
implicit_dae = false,
197197
simplify = false,
198+
eval_expression = false,
199+
eval_module = @__MODULE__,
198200
kwargs...)
199201
isempty(inputs) && @warn("No unbound inputs were found in system.")
200202

@@ -240,7 +242,8 @@ function generate_control_function(sys::AbstractODESystem, inputs = unbound_inpu
240242
end
241243
process = get_postprocess_fbody(sys)
242244
f = build_function(rhss, args...; postprocess_fbody = process,
243-
expression = Val{false}, kwargs...)
245+
expression = Val{true}, kwargs...)
246+
f = eval_or_rgf.(f; eval_expression, eval_module)
244247
(; f, dvs, ps, io_sys = sys)
245248
end
246249

@@ -395,7 +398,7 @@ model_outputs = [model.inertia1.w, model.inertia2.w, model.inertia1.phi, model.i
395398
396399
`f_oop` will have an extra state corresponding to the integrator in the disturbance model. This state will not be affected by any input, but will affect the dynamics from where it enters, in this case it will affect additively from `model.torque.tau.u`.
397400
"""
398-
function add_input_disturbance(sys, dist::DisturbanceModel, inputs = nothing)
401+
function add_input_disturbance(sys, dist::DisturbanceModel, inputs = nothing; kwargs...)
399402
t = get_iv(sys)
400403
@variables d(t)=0 [disturbance = true]
401404
@variables u(t)=0 [input = true] # New system input
@@ -418,6 +421,6 @@ function add_input_disturbance(sys, dist::DisturbanceModel, inputs = nothing)
418421
augmented_sys = extend(augmented_sys, sys)
419422

420423
(f_oop, f_ip), dvs, p = generate_control_function(augmented_sys, all_inputs,
421-
[d])
424+
[d]; kwargs...)
422425
(f_oop, f_ip), augmented_sys, dvs, p
423426
end

src/systems/abstractsystem.jl

Lines changed: 63 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,8 @@ time-independent systems. If `split=true` (the default) was passed to [`complete
161161
object.
162162
"""
163163
function generate_custom_function(sys::AbstractSystem, exprs, dvs = unknowns(sys),
164-
ps = parameters(sys); wrap_code = nothing, postprocess_fbody = nothing, states = nothing, kwargs...)
164+
ps = parameters(sys); wrap_code = nothing, postprocess_fbody = nothing, states = nothing,
165+
expression = Val{true}, eval_expression = false, eval_module = @__MODULE__, kwargs...)
165166
if !iscomplete(sys)
166167
error("A completed system is required. Call `complete` or `structural_simplify` on the system.")
167168
end
@@ -177,28 +178,38 @@ function generate_custom_function(sys::AbstractSystem, exprs, dvs = unknowns(sys
177178
if states === nothing
178179
states = sol_states
179180
end
180-
if is_time_dependent(sys)
181-
return build_function(exprs,
181+
fnexpr = if is_time_dependent(sys)
182+
build_function(exprs,
182183
dvs,
183184
p...,
184185
get_iv(sys);
185186
kwargs...,
186187
postprocess_fbody,
187188
states,
188189
wrap_code = wrap_code .∘ wrap_mtkparameters(sys, isscalar) .∘
189-
wrap_array_vars(sys, exprs; dvs)
190+
wrap_array_vars(sys, exprs; dvs),
191+
expression = Val{true}
190192
)
191193
else
192-
return build_function(exprs,
194+
build_function(exprs,
193195
dvs,
194196
p...;
195197
kwargs...,
196198
postprocess_fbody,
197199
states,
198200
wrap_code = wrap_code .∘ wrap_mtkparameters(sys, isscalar) .∘
199-
wrap_array_vars(sys, exprs; dvs)
201+
wrap_array_vars(sys, exprs; dvs),
202+
expression = Val{true}
200203
)
201204
end
205+
if expression == Val{true}
206+
return fnexpr
207+
end
208+
if fnexpr isa Tuple
209+
return eval_or_rgf.(fnexpr; eval_expression, eval_module)
210+
else
211+
return eval_or_rgf(fnexpr; eval_expression, eval_module)
212+
end
202213
end
203214

204215
function wrap_assignments(isscalar, assignments; let_block = false)
@@ -509,7 +520,8 @@ function SymbolicIndexingInterface.is_observed(sys::AbstractSystem, sym)
509520
!is_independent_variable(sys, sym) && symbolic_type(sym) != NotSymbolic()
510521
end
511522

512-
function SymbolicIndexingInterface.observed(sys::AbstractSystem, sym)
523+
function SymbolicIndexingInterface.observed(
524+
sys::AbstractSystem, sym; eval_expression = false, eval_module = @__MODULE__)
513525
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
514526
if sym isa Symbol
515527
_sym = get(ic.symbol_to_variable, sym, nothing)
@@ -531,7 +543,8 @@ function SymbolicIndexingInterface.observed(sys::AbstractSystem, sym)
531543
end
532544
end
533545
end
534-
_fn = build_explicit_observed_function(sys, sym)
546+
_fn = build_explicit_observed_function(sys, sym; eval_expression, eval_module)
547+
535548
if is_time_dependent(sys)
536549
return let _fn = _fn
537550
fn1(u, p, t) = _fn(u, p, t)
@@ -1210,19 +1223,30 @@ end
12101223
struct ObservedFunctionCache{S}
12111224
sys::S
12121225
dict::Dict{Any, Any}
1226+
eval_expression::Bool
1227+
eval_module::Module
12131228
end
12141229

1215-
function ObservedFunctionCache(sys)
1216-
return ObservedFunctionCache(sys, Dict())
1217-
let sys = sys, dict = Dict()
1218-
function generated_observed(obsvar, args...)
1219-
end
1220-
end
1230+
function ObservedFunctionCache(sys; eval_expression = false, eval_module = @__MODULE__)
1231+
return ObservedFunctionCache(sys, Dict(), eval_expression, eval_module)
1232+
end
1233+
1234+
# This is hit because ensemble problems do a deepcopy
1235+
function Base.deepcopy_internal(ofc::ObservedFunctionCache, stackdict::IdDict)
1236+
sys = deepcopy(ofc.sys)
1237+
dict = deepcopy(ofc.dict)
1238+
eval_expression = ofc.eval_expression
1239+
eval_module = ofc.eval_module
1240+
newofc = ObservedFunctionCache(sys, dict, eval_expression, eval_module)
1241+
stackdict[ofc] = newofc
1242+
return newofc
12211243
end
12221244

12231245
function (ofc::ObservedFunctionCache)(obsvar, args...)
12241246
obs = get!(ofc.dict, value(obsvar)) do
1225-
SymbolicIndexingInterface.observed(ofc.sys, obsvar)
1247+
SymbolicIndexingInterface.observed(
1248+
ofc.sys, obsvar; eval_expression = ofc.eval_expression,
1249+
eval_module = ofc.eval_module)
12261250
end
12271251
if args === ()
12281252
return obs
@@ -1871,6 +1895,7 @@ function linearization_function(sys::AbstractSystem, inputs,
18711895
p = DiffEqBase.NullParameters(),
18721896
zero_dummy_der = false,
18731897
initialization_solver_alg = TrustRegion(),
1898+
eval_expression = false, eval_module = @__MODULE__,
18741899
kwargs...)
18751900
inputs isa AbstractVector || (inputs = [inputs])
18761901
outputs isa AbstractVector || (outputs = [outputs])
@@ -1895,85 +1920,58 @@ function linearization_function(sys::AbstractSystem, inputs,
18951920
end
18961921
x0 = merge(defaults_and_guesses(sys), op)
18971922
if has_index_cache(sys) && get_index_cache(sys) !== nothing
1898-
sys_ps = MTKParameters(sys, p, x0)
1923+
sys_ps = MTKParameters(sys, p, x0; eval_expression, eval_module)
18991924
else
19001925
sys_ps = varmap_to_vars(p, parameters(sys); defaults = x0)
19011926
end
19021927
p[get_iv(sys)] = NaN
19031928
if has_index_cache(initsys) && get_index_cache(initsys) !== nothing
1904-
oldps = MTKParameters(initsys, p, merge(guesses(sys), defaults(sys), op))
1929+
oldps = MTKParameters(initsys, p, merge(guesses(sys), defaults(sys), op);
1930+
eval_expression, eval_module)
19051931
initsys_ps = parameters(initsys)
1906-
initsys_idxs = [parameter_index(initsys, param) for param in initsys_ps]
1907-
tunable_ps = [initsys_ps[i]
1908-
for i in eachindex(initsys_ps)
1909-
if initsys_idxs[i].portion == SciMLStructures.Tunable()]
1910-
tunable_getter = isempty(tunable_ps) ? nothing : getu(sys, tunable_ps)
1911-
discrete_ps = [initsys_ps[i]
1912-
for i in eachindex(initsys_ps)
1913-
if initsys_idxs[i].portion == SciMLStructures.Discrete()]
1914-
disc_getter = isempty(discrete_ps) ? nothing : getu(sys, discrete_ps)
1915-
constant_ps = [initsys_ps[i]
1916-
for i in eachindex(initsys_ps)
1917-
if initsys_idxs[i].portion == SciMLStructures.Constants()]
1918-
const_getter = isempty(constant_ps) ? nothing : getu(sys, constant_ps)
1919-
nonnum_ps = [initsys_ps[i]
1920-
for i in eachindex(initsys_ps)
1921-
if initsys_idxs[i].portion == NONNUMERIC_PORTION]
1922-
nonnum_getter = isempty(nonnum_ps) ? nothing : getu(sys, nonnum_ps)
1932+
p_getter = build_explicit_observed_function(
1933+
sys, initsys_ps; eval_expression, eval_module)
1934+
19231935
u_getter = isempty(unknowns(initsys)) ? (_...) -> nothing :
1924-
getu(sys, unknowns(initsys))
1925-
get_initprob_u_p = let tunable_getter = tunable_getter,
1926-
disc_getter = disc_getter,
1927-
const_getter = const_getter,
1928-
nonnum_getter = nonnum_getter,
1929-
oldps = oldps,
1936+
build_explicit_observed_function(
1937+
sys, unknowns(initsys); eval_expression, eval_module)
1938+
get_initprob_u_p = let p_getter,
1939+
p_setter! = setp(initsys, initsys_ps),
19301940
u_getter = u_getter
19311941

19321942
function (u, p, t)
19331943
state = ProblemState(; u, p, t)
1934-
if tunable_getter !== nothing
1935-
SciMLStructures.replace!(
1936-
SciMLStructures.Tunable(), oldps, tunable_getter(state))
1937-
end
1938-
if disc_getter !== nothing
1939-
SciMLStructures.replace!(
1940-
SciMLStructures.Discrete(), oldps, disc_getter(state))
1941-
end
1942-
if const_getter !== nothing
1943-
SciMLStructures.replace!(
1944-
SciMLStructures.Constants(), oldps, const_getter(state))
1945-
end
1946-
if nonnum_getter !== nothing
1947-
SciMLStructures.replace!(
1948-
NONNUMERIC_PORTION, oldps, nonnum_getter(state))
1949-
end
1944+
p_setter!(oldps, p_getter(state))
19501945
newu = u_getter(state)
19511946
return newu, oldps
19521947
end
19531948
end
19541949
else
19551950
get_initprob_u_p = let p_getter = getu(sys, parameters(initsys)),
1956-
u_getter = getu(sys, unknowns(initsys))
1951+
u_getter = build_explicit_observed_function(
1952+
sys, unknowns(initsys); eval_expression, eval_module)
19571953

19581954
function (u, p, t)
19591955
state = ProblemState(; u, p, t)
19601956
return u_getter(state), p_getter(state)
19611957
end
19621958
end
19631959
end
1964-
initfn = NonlinearFunction(initsys)
1965-
initprobmap = getu(initsys, unknowns(sys))
1960+
initfn = NonlinearFunction(initsys; eval_expression, eval_module)
1961+
initprobmap = build_explicit_observed_function(
1962+
initsys, unknowns(sys); eval_expression, eval_module)
19661963
ps = full_parameters(sys)
1964+
h = build_explicit_observed_function(sys, outputs; eval_expression, eval_module)
19671965
lin_fun = let diff_idxs = diff_idxs,
19681966
alge_idxs = alge_idxs,
19691967
input_idxs = input_idxs,
19701968
sts = unknowns(sys),
19711969
get_initprob_u_p = get_initprob_u_p,
19721970
fun = ODEFunction{true, SciMLBase.FullSpecialize}(
1973-
sys, unknowns(sys), ps),
1971+
sys, unknowns(sys), ps; eval_expression, eval_module),
19741972
initfn = initfn,
19751973
initprobmap = initprobmap,
1976-
h = build_explicit_observed_function(sys, outputs),
1974+
h = h,
19771975
chunk = ForwardDiff.Chunk(input_idxs),
19781976
sys_ps = sys_ps,
19791977
initialize = initialize,
@@ -2056,6 +2054,7 @@ where `x` are differential unknown variables, `z` algebraic variables, `u` input
20562054
"""
20572055
function linearize_symbolic(sys::AbstractSystem, inputs,
20582056
outputs; simplify = false, allow_input_derivatives = false,
2057+
eval_expression = false, eval_module = @__MODULE__,
20592058
kwargs...)
20602059
sys, diff_idxs, alge_idxs, input_idxs = io_preprocessing(
20612060
sys, inputs, outputs; simplify,
@@ -2065,10 +2064,11 @@ function linearize_symbolic(sys::AbstractSystem, inputs,
20652064
ps = full_parameters(sys)
20662065
p = reorder_parameters(sys, ps)
20672066

2068-
fun = generate_function(sys, sts, ps; expression = Val{false})[1]
2067+
fun_expr = generate_function(sys, sts, ps; expression = Val{true})[1]
2068+
fun = eval_or_rgf(fun_expr; eval_expression, eval_module)
20692069
dx = fun(sts, p..., t)
20702070

2071-
h = build_explicit_observed_function(sys, outputs)
2071+
h = build_explicit_observed_function(sys, outputs; eval_expression, eval_module)
20722072
y = h(sts, p..., t)
20732073

20742074
fg_xz = Symbolics.jacobian(dx, sts)

src/systems/callbacks.jl

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ Notes
343343
- `kwargs` are passed through to `Symbolics.build_function`.
344344
"""
345345
function compile_condition(cb::SymbolicDiscreteCallback, sys, dvs, ps;
346-
expression = Val{true}, kwargs...)
346+
expression = Val{true}, eval_expression = false, eval_module = @__MODULE__, kwargs...)
347347
u = map(x -> time_varying_as_func(value(x), sys), dvs)
348348
p = map.(x -> time_varying_as_func(value(x), sys), reorder_parameters(sys, ps))
349349
t = get_iv(sys)
@@ -353,8 +353,13 @@ function compile_condition(cb::SymbolicDiscreteCallback, sys, dvs, ps;
353353
cmap = map(x -> x => getdefault(x), cs)
354354
condit = substitute(condit, cmap)
355355
end
356-
build_function(condit, u, t, p...; expression, wrap_code = condition_header(sys),
356+
expr = build_function(
357+
condit, u, t, p...; expression = Val{true}, wrap_code = condition_header(sys),
357358
kwargs...)
359+
if expression == Val{true}
360+
return expr
361+
end
362+
return eval_or_rgf(expr; eval_expression, eval_module)
358363
end
359364

360365
function compile_affect(cb::SymbolicContinuousCallback, args...; kwargs...)
@@ -379,7 +384,8 @@ Notes
379384
- `kwargs` are passed through to `Symbolics.build_function`.
380385
"""
381386
function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothing,
382-
expression = Val{true}, checkvars = true,
387+
expression = Val{true}, checkvars = true, eval_expression = false,
388+
eval_module = @__MODULE__,
383389
postprocess_affect_expr! = nothing, kwargs...)
384390
if isempty(eqs)
385391
if expression == Val{true}
@@ -432,20 +438,20 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin
432438
end
433439
t = get_iv(sys)
434440
integ = gensym(:MTKIntegrator)
435-
getexpr = (postprocess_affect_expr! === nothing) ? expression : Val{true}
436441
pre = get_preprocess_constants(rhss)
437-
rf_oop, rf_ip = build_function(rhss, u, p..., t; expression = getexpr,
442+
rf_oop, rf_ip = build_function(rhss, u, p..., t; expression = Val{true},
438443
wrap_code = add_integrator_header(sys, integ, outvar),
439444
outputidxs = update_inds,
440445
postprocess_fbody = pre,
441446
kwargs...)
442447
# applied user-provided function to the generated expression
443448
if postprocess_affect_expr! !== nothing
444449
postprocess_affect_expr!(rf_ip, integ)
445-
(expression == Val{false}) &&
446-
(return drop_expr(@RuntimeGeneratedFunction(rf_ip)))
447450
end
448-
rf_ip
451+
if expression == Val{false}
452+
return eval_or_rgf(rf_ip; eval_expression, eval_module)
453+
end
454+
return rf_ip
449455
end
450456
end
451457

src/systems/connectors.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -408,7 +408,7 @@ function Base.merge(csets::AbstractVector{<:ConnectionSet}, allouter = false)
408408
id2set = Dict{Int, Int}()
409409
merged_set = ConnectionSet[]
410410
for (id, ele) in enumerate(idx2ele)
411-
rid = find_root(union_find, id)
411+
rid = find_root!(union_find, id)
412412
set_idx = get!(id2set, rid) do
413413
set = ConnectionSet()
414414
push!(merged_set, set)

0 commit comments

Comments
 (0)