Skip to content

Commit 9eb96a5

Browse files
AayushSabharwalChrisRackauckas
authored andcommitted
feat: add support for parameter dependencies
1 parent 6abcc49 commit 9eb96a5

14 files changed

+353
-114
lines changed

src/ModelingToolkit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ using PrecompileTools, Reexport
3636
using RecursiveArrayTools
3737

3838
using SymbolicIndexingInterface
39-
export independent_variables, unknowns, parameters
39+
export independent_variables, unknowns, parameters, full_parameters
4040
import SymbolicUtils
4141
import SymbolicUtils: istree, arguments, operation, similarterm, promote_symtype,
4242
Symbolic, isadd, ismul, ispow, issym, FnType,

src/systems/abstractsystem.jl

Lines changed: 45 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -286,27 +286,12 @@ function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym)
286286
if has_index_cache(sys) && get_index_cache(sys) !== nothing
287287
ic = get_index_cache(sys)
288288
h = getsymbolhash(sym)
289-
return if haskey(ic.param_idx, h)
290-
ParameterIndex(SciMLStructures.Tunable(), ic.param_idx[h])
291-
elseif haskey(ic.discrete_idx, h)
292-
ParameterIndex(SciMLStructures.Discrete(), ic.discrete_idx[h])
293-
elseif haskey(ic.constant_idx, h)
294-
ParameterIndex(SciMLStructures.Constants(), ic.constant_idx[h])
295-
elseif haskey(ic.dependent_idx, h)
296-
ParameterIndex(nothing, ic.dependent_idx[h])
289+
return if (idx = ParameterIndex(ic, sym)) !== nothing
290+
idx
291+
elseif (idx = ParameterIndex(ic, default_toterm(sym))) !== nothing
292+
idx
297293
else
298-
h = getsymbolhash(default_toterm(sym))
299-
if haskey(ic.param_idx, h)
300-
ParameterIndex(SciMLStructures.Tunable(), ic.param_idx[h])
301-
elseif haskey(ic.discrete_idx, h)
302-
ParameterIndex(SciMLStructures.Discrete(), ic.discrete_idx[h])
303-
elseif haskey(ic.constant_idx, h)
304-
ParameterIndex(SciMLStructures.Constants(), ic.constant_idx[h])
305-
elseif haskey(ic.dependent_idx, h)
306-
ParameterIndex(nothing, ic.dependent_idx[h])
307-
else
308-
nothing
309-
end
294+
nothing
310295
end
311296
end
312297

@@ -329,7 +314,7 @@ function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym::Sym
329314
end
330315

331316
function SymbolicIndexingInterface.parameter_symbols(sys::AbstractSystem)
332-
return parameters(sys)
317+
return full_parameters(sys)
333318
end
334319

335320
function SymbolicIndexingInterface.is_independent_variable(sys::AbstractSystem, sym)
@@ -419,6 +404,7 @@ for prop in [:eqs
419404
:metadata
420405
:gui_metadata
421406
:discrete_subsystems
407+
:parameter_dependencies
422408
:solved_unknowns
423409
:split_idxs
424410
:parent
@@ -750,7 +736,29 @@ function parameters(sys::AbstractSystem)
750736
ps = first.(ps)
751737
end
752738
systems = get_systems(sys)
753-
unique(isempty(systems) ? ps : [ps; reduce(vcat, namespace_parameters.(systems))])
739+
result = unique(isempty(systems) ? ps :
740+
[ps; reduce(vcat, namespace_parameters.(systems))])
741+
if has_parameter_dependencies(sys) &&
742+
(pdeps = get_parameter_dependencies(sys)) !== nothing
743+
filter(result) do sym
744+
!haskey(pdeps, sym)
745+
end
746+
else
747+
result
748+
end
749+
end
750+
751+
function dependent_parameters(sys::AbstractSystem)
752+
if has_parameter_dependencies(sys) &&
753+
(pdeps = get_parameter_dependencies(sys)) !== nothing
754+
collect(keys(pdeps))
755+
else
756+
[]
757+
end
758+
end
759+
760+
function full_parameters(sys::AbstractSystem)
761+
vcat(parameters(sys), dependent_parameters(sys))
754762
end
755763

756764
# required in `src/connectors.jl:437`
@@ -1612,7 +1620,7 @@ function linearize_symbolic(sys::AbstractSystem, inputs,
16121620
kwargs...)
16131621
sts = unknowns(sys)
16141622
t = get_iv(sys)
1615-
ps = parameters(sys)
1623+
ps = full_parameters(sys)
16161624
p = reorder_parameters(sys, ps)
16171625

16181626
fun = generate_function(sys, sts, ps; expression = Val{false})[1]
@@ -2123,3 +2131,17 @@ function Symbolics.substitute(sys::AbstractSystem, rules::Union{Vector{<:Pair},
21232131
error("substituting symbols is not supported for $(typeof(sys))")
21242132
end
21252133
end
2134+
2135+
function process_parameter_dependencies(pdeps, ps)
2136+
pdeps === nothing && return pdeps, ps
2137+
if pdeps isa Vector && eltype(pdeps) <: Pair
2138+
pdeps = Dict(pdeps)
2139+
elseif !(pdeps isa Dict)
2140+
error("parameter_dependencies must be a `Dict` or `Vector{<:Pair}`")
2141+
end
2142+
2143+
ps = filter(ps) do p
2144+
!haskey(pdeps, p)
2145+
end
2146+
return pdeps, ps
2147+
end

src/systems/callbacks.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -433,14 +433,14 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin
433433
end
434434

435435
function generate_rootfinding_callback(sys::AbstractODESystem, dvs = unknowns(sys),
436-
ps = parameters(sys); kwargs...)
436+
ps = full_parameters(sys); kwargs...)
437437
cbs = continuous_events(sys)
438438
isempty(cbs) && return nothing
439439
generate_rootfinding_callback(cbs, sys, dvs, ps; kwargs...)
440440
end
441441

442442
function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknowns(sys),
443-
ps = parameters(sys); kwargs...)
443+
ps = full_parameters(sys); kwargs...)
444444
eqs = map(cb -> cb.eqs, cbs)
445445
num_eqs = length.(eqs)
446446
(isempty(eqs) || sum(num_eqs) == 0) && return nothing
@@ -556,7 +556,7 @@ function generate_discrete_callback(cb, sys, dvs, ps; postprocess_affect_expr! =
556556
end
557557

558558
function generate_discrete_callbacks(sys::AbstractSystem, dvs = unknowns(sys),
559-
ps = parameters(sys); kwargs...)
559+
ps = full_parameters(sys); kwargs...)
560560
has_discrete_events(sys) || return nothing
561561
symcbs = discrete_events(sys)
562562
isempty(symcbs) && return nothing

src/systems/clock_inference.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,15 +198,15 @@ function generate_discrete_affect(
198198
throw = false,
199199
expression = true,
200200
output_type = SVector,
201-
ps = reorder_parameters(osys, parameters(sys)))
201+
ps = reorder_parameters(osys, full_parameters(sys)))
202202
ni = length(input)
203203
ns = length(unknowns(sys))
204204
disc = Func(
205205
[
206206
out,
207207
DestructuredArgs(unknowns(osys)),
208208
if use_index_cache
209-
DestructuredArgs.(reorder_parameters(osys, parameters(osys)))
209+
DestructuredArgs.(reorder_parameters(osys, full_parameters(osys)))
210210
else
211211
(DestructuredArgs(appended_parameters),)
212212
end...,

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,8 @@ function calculate_control_jacobian(sys::AbstractODESystem;
8080
return jac
8181
end
8282

83-
function generate_tgrad(sys::AbstractODESystem, dvs = unknowns(sys), ps = parameters(sys);
83+
function generate_tgrad(
84+
sys::AbstractODESystem, dvs = unknowns(sys), ps = full_parameters(sys);
8485
simplify = false, kwargs...)
8586
tgrad = calculate_tgrad(sys, simplify = simplify)
8687
pre = get_preprocess_constants(tgrad)
@@ -100,7 +101,7 @@ function generate_tgrad(sys::AbstractODESystem, dvs = unknowns(sys), ps = parame
100101
end
101102

102103
function generate_jacobian(sys::AbstractODESystem, dvs = unknowns(sys),
103-
ps = parameters(sys);
104+
ps = full_parameters(sys);
104105
simplify = false, sparse = false, kwargs...)
105106
jac = calculate_jacobian(sys; simplify = simplify, sparse = sparse)
106107
pre = get_preprocess_constants(jac)
@@ -118,7 +119,7 @@ function generate_jacobian(sys::AbstractODESystem, dvs = unknowns(sys),
118119
end
119120

120121
function generate_control_jacobian(sys::AbstractODESystem, dvs = unknowns(sys),
121-
ps = parameters(sys);
122+
ps = full_parameters(sys);
122123
simplify = false, sparse = false, kwargs...)
123124
jac = calculate_control_jacobian(sys; simplify = simplify, sparse = sparse)
124125
p = reorder_parameters(sys, ps)
@@ -146,7 +147,7 @@ function generate_dae_jacobian(sys::AbstractODESystem, dvs = unknowns(sys),
146147
end
147148

148149
function generate_function(sys::AbstractODESystem, dvs = unknowns(sys),
149-
ps = parameters(sys);
150+
ps = full_parameters(sys);
150151
implicit_dae = false,
151152
ddvs = implicit_dae ? map(Differential(get_iv(sys)), dvs) :
152153
nothing,
@@ -314,7 +315,7 @@ end
314315

315316
function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
316317
dvs = unknowns(sys),
317-
ps = parameters(sys), u0 = nothing;
318+
ps = full_parameters(sys), u0 = nothing;
318319
version = nothing, tgrad = false,
319320
jac = false, p = nothing,
320321
t = nothing,
@@ -830,7 +831,7 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
830831
kwargs...)
831832
eqs = equations(sys)
832833
dvs = unknowns(sys)
833-
ps = parameters(sys)
834+
ps = full_parameters(sys)
834835
iv = get_iv(sys)
835836

836837
if has_index_cache(sys) && get_index_cache(sys) !== nothing
@@ -845,7 +846,7 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
845846
symbolic_u0)
846847
p, split_idxs = split_parameters_by_type(p)
847848
if p isa Tuple
848-
ps = Base.Fix1(getindex, parameters(sys)).(split_idxs)
849+
ps = Base.Fix1(getindex, full_parameters(sys)).(split_idxs)
849850
ps = (ps...,) #if p is Tuple, ps should be Tuple
850851
end
851852
end
@@ -997,7 +998,7 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
997998
cbs = CallbackSet(discrete_cbs...)
998999
end
9991000
else
1000-
cbs = CallbackSet(cbs, discrete_cbs)
1001+
cbs = CallbackSet(cbs, discrete_cbs...)
10011002
end
10021003
else
10031004
svs = nothing
@@ -1060,7 +1061,7 @@ function DiffEqBase.DAEProblem{iip}(sys::AbstractODESystem, du0map, u0map, tspan
10601061
end
10611062

10621063
function generate_history(sys::AbstractODESystem, u0; kwargs...)
1063-
p = reorder_parameters(sys, parameters(sys))
1064+
p = reorder_parameters(sys, full_parameters(sys))
10641065
build_function(u0, p..., get_iv(sys); expression = Val{false}, kwargs...)
10651066
end
10661067

src/systems/diffeqs/odesystem.jl

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,11 @@ struct ODESystem <: AbstractODESystem
111111
"""
112112
discrete_events::Vector{SymbolicDiscreteCallback}
113113
"""
114+
A mapping from dependent parameters to expressions describing how they are calculated from
115+
other parameters.
116+
"""
117+
parameter_dependencies::Union{Nothing, Dict}
118+
"""
114119
Metadata for the system, to be used by downstream packages.
115120
"""
116121
metadata::Any
@@ -154,7 +159,7 @@ struct ODESystem <: AbstractODESystem
154159
function ODESystem(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad,
155160
jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults,
156161
torn_matching, connector_type, preface, cevents,
157-
devents, metadata = nothing, gui_metadata = nothing,
162+
devents, parameter_dependencies, metadata = nothing, gui_metadata = nothing,
158163
tearing_state = nothing,
159164
substitutions = nothing, complete = false, index_cache = nothing,
160165
discrete_subsystems = nothing, solved_unknowns = nothing,
@@ -171,8 +176,8 @@ struct ODESystem <: AbstractODESystem
171176
end
172177
new(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac,
173178
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, torn_matching,
174-
connector_type, preface, cevents, devents, metadata, gui_metadata,
175-
tearing_state, substitutions, complete, index_cache,
179+
connector_type, preface, cevents, devents, parameter_dependencies, metadata,
180+
gui_metadata, tearing_state, substitutions, complete, index_cache,
176181
discrete_subsystems, solved_unknowns, split_idxs, parent)
177182
end
178183
end
@@ -190,6 +195,7 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
190195
preface = nothing,
191196
continuous_events = nothing,
192197
discrete_events = nothing,
198+
parameter_dependencies = nothing,
193199
checks = true,
194200
metadata = nothing,
195201
gui_metadata = nothing)
@@ -225,10 +231,12 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
225231
end
226232
cont_callbacks = SymbolicContinuousCallbacks(continuous_events)
227233
disc_callbacks = SymbolicDiscreteCallbacks(discrete_events)
234+
parameter_dependencies, ps′ = process_parameter_dependencies(
235+
parameter_dependencies, ps′)
228236
ODESystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
229237
deqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, tgrad, jac,
230238
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, nothing,
231-
connector_type, preface, cont_callbacks, disc_callbacks,
239+
connector_type, preface, cont_callbacks, disc_callbacks, parameter_dependencies,
232240
metadata, gui_metadata, checks = checks)
233241
end
234242

@@ -323,7 +331,7 @@ function build_explicit_observed_function(sys, ts;
323331
output_type = Array,
324332
checkbounds = true,
325333
drop_expr = drop_expr,
326-
ps = parameters(sys),
334+
ps = full_parameters(sys),
327335
throw = true)
328336
if (isscalar = !(ts isa AbstractVector))
329337
ts = [ts]

src/systems/diffeqs/sdesystem.jl

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,11 @@ struct SDESystem <: AbstractODESystem
104104
"""
105105
discrete_events::Vector{SymbolicDiscreteCallback}
106106
"""
107+
A mapping from dependent parameters to expressions describing how they are calculated from
108+
other parameters.
109+
"""
110+
parameter_dependencies::Union{Nothing, Dict}
111+
"""
107112
Metadata for the system, to be used by downstream packages.
108113
"""
109114
metadata::Any
@@ -128,7 +133,7 @@ struct SDESystem <: AbstractODESystem
128133
tgrad,
129134
jac,
130135
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connector_type,
131-
cevents, devents, metadata = nothing, gui_metadata = nothing,
136+
cevents, devents, parameter_dependencies, metadata = nothing, gui_metadata = nothing,
132137
complete = false, index_cache = nothing, parent = nothing;
133138
checks::Union{Bool, Int} = true)
134139
if checks == true || (checks & CheckComponents) > 0
@@ -144,7 +149,7 @@ struct SDESystem <: AbstractODESystem
144149
new(tag, deqs, neqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac,
145150
ctrl_jac,
146151
Wfact, Wfact_t, name, systems, defaults, connector_type, cevents, devents,
147-
metadata, gui_metadata, complete, index_cache, parent)
152+
parameter_dependencies, metadata, gui_metadata, complete, index_cache, parent)
148153
end
149154
end
150155

@@ -161,6 +166,7 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv
161166
checks = true,
162167
continuous_events = nothing,
163168
discrete_events = nothing,
169+
parameter_dependencies = nothing,
164170
metadata = nothing,
165171
gui_metadata = nothing)
166172
name === nothing &&
@@ -195,11 +201,12 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv
195201
Wfact_t = RefValue(EMPTY_JAC)
196202
cont_callbacks = SymbolicContinuousCallbacks(continuous_events)
197203
disc_callbacks = SymbolicDiscreteCallbacks(discrete_events)
198-
204+
parameter_dependencies, ps′ = process_parameter_dependencies(
205+
parameter_dependencies, ps′)
199206
SDESystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
200207
deqs, neqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, tgrad, jac,
201208
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connector_type,
202-
cont_callbacks, disc_callbacks, metadata, gui_metadata; checks = checks)
209+
cont_callbacks, disc_callbacks, parameter_dependencies, metadata, gui_metadata; checks = checks)
203210
end
204211

205212
function SDESystem(sys::ODESystem, neqs; kwargs...)
@@ -220,7 +227,7 @@ function Base.:(==)(sys1::SDESystem, sys2::SDESystem)
220227
end
221228

222229
function generate_diffusion_function(sys::SDESystem, dvs = unknowns(sys),
223-
ps = parameters(sys); isdde = false, kwargs...)
230+
ps = full_parameters(sys); isdde = false, kwargs...)
224231
eqs = get_noiseeqs(sys)
225232
if isdde
226233
eqs = delay_to_function(sys, eqs)
@@ -285,7 +292,7 @@ function stochastic_integral_transform(sys::SDESystem, correction_factor)
285292
end
286293

287294
SDESystem(deqs, get_noiseeqs(sys), get_iv(sys), unknowns(sys), parameters(sys),
288-
name = name, checks = false)
295+
name = name, parameter_dependencies = get_parameter_dependencies(sys), checks = false)
289296
end
290297

291298
"""
@@ -393,7 +400,8 @@ function Girsanov_transform(sys::SDESystem, u; θ0 = 1.0)
393400
# return modified SDE System
394401
SDESystem(deqs, noiseeqs, get_iv(sys), unknown_vars, parameters(sys);
395402
defaults = Dict=> θ0), observed = [weight ~ θ / θ0],
396-
name = name, checks = false)
403+
name = name, parameter_dependencies = get_parameter_dependencies(sys),
404+
checks = false)
397405
end
398406

399407
function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = unknowns(sys),

0 commit comments

Comments
 (0)