Skip to content

Commit d54f877

Browse files
Merge pull request #2632 from AayushSabharwal/as/nonlinearsys-paramdeps
feat: support parameter dependencies for NonlinearSystem
2 parents c97d559 + 2a4fbd2 commit d54f877

File tree

2 files changed

+31
-5
lines changed

2 files changed

+31
-5
lines changed

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,11 @@ struct NonlinearSystem <: AbstractTimeIndependentSystem
5757
"""
5858
connector_type::Any
5959
"""
60+
A mapping from dependent parameters to expressions describing how they are calculated from
61+
other parameters.
62+
"""
63+
parameter_dependencies::Union{Nothing, Dict}
64+
"""
6065
Metadata for the system, to be used by downstream packages.
6166
"""
6267
metadata::Any
@@ -87,7 +92,7 @@ struct NonlinearSystem <: AbstractTimeIndependentSystem
8792

8893
function NonlinearSystem(tag, eqs, unknowns, ps, var_to_name, observed, jac, name,
8994
systems,
90-
defaults, connector_type, metadata = nothing,
95+
defaults, connector_type, parameter_dependencies = nothing, metadata = nothing,
9196
gui_metadata = nothing,
9297
tearing_state = nothing, substitutions = nothing,
9398
complete = false, index_cache = nothing, parent = nothing; checks::Union{
@@ -97,8 +102,8 @@ struct NonlinearSystem <: AbstractTimeIndependentSystem
97102
check_units(u, eqs)
98103
end
99104
new(tag, eqs, unknowns, ps, var_to_name, observed, jac, name, systems, defaults,
100-
connector_type, metadata, gui_metadata, tearing_state, substitutions, complete,
101-
index_cache, parent)
105+
connector_type, parameter_dependencies, metadata, gui_metadata, tearing_state,
106+
substitutions, complete, index_cache, parent)
102107
end
103108
end
104109

@@ -113,6 +118,7 @@ function NonlinearSystem(eqs, unknowns, ps;
113118
continuous_events = nothing, # this argument is only required for ODESystems, but is added here for the constructor to accept it without error
114119
discrete_events = nothing, # this argument is only required for ODESystems, but is added here for the constructor to accept it without error
115120
checks = true,
121+
parameter_dependencies = nothing,
116122
metadata = nothing,
117123
gui_metadata = nothing)
118124
continuous_events === nothing || isempty(continuous_events) ||
@@ -148,9 +154,11 @@ function NonlinearSystem(eqs, unknowns, ps;
148154
process_variables!(var_to_name, defaults, ps)
149155
isempty(observed) || collect_var_to_name!(var_to_name, (eq.lhs for eq in observed))
150156

157+
parameter_dependencies, ps = process_parameter_dependencies(
158+
parameter_dependencies, ps)
151159
NonlinearSystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
152160
eqs, unknowns, ps, var_to_name, observed, jac, name, systems, defaults,
153-
connector_type, metadata, gui_metadata, checks = checks)
161+
connector_type, parameter_dependencies, metadata, gui_metadata, checks = checks)
154162
end
155163

156164
function NonlinearSystem(eqs; kwargs...)
@@ -233,6 +241,7 @@ function generate_function(
233241
pre, sol_states = get_substitutions_and_solved_unknowns(sys)
234242

235243
p = reorder_parameters(sys, value.(ps))
244+
@show p ps
236245
return build_function(rhss, value.(dvs), p...; postprocess_fbody = pre,
237246
states = sol_states, kwargs...)
238247
end
@@ -385,7 +394,7 @@ function process_NonlinearProblem(constructor, sys::NonlinearSystem, u0map, para
385394
kwargs...)
386395
eqs = equations(sys)
387396
dvs = unknowns(sys)
388-
ps = parameters(sys)
397+
ps = full_parameters(sys)
389398

390399
if has_index_cache(sys) && get_index_cache(sys) !== nothing
391400
u0, defs = get_u0(sys, u0map, parammap)

test/parameter_dependencies.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ using JumpProcesses
77
using StableRNGs
88
using SciMLStructures: canonicalize, Tunable, replace, replace!
99
using SymbolicIndexingInterface
10+
using NonlinearSolve
1011

1112
@testset "ODESystem with callbacks" begin
1213
@parameters p1=1.0 p2=1.0
@@ -162,6 +163,22 @@ end
162163
@test integ.ps[β] == 0.0002
163164
end
164165

166+
@testset "NonlinearSystem" begin
167+
@parameters p1=1.0 p2=1.0
168+
@variables x(t)
169+
eqs = [0 ~ p1 * x * exp(x) + p2]
170+
@mtkbuild sys = NonlinearSystem(eqs; parameter_dependencies = [p2 => 2p1])
171+
@test isequal(only(parameters(sys)), p1)
172+
@test Set(full_parameters(sys)) == Set([p1, p2])
173+
prob = NonlinearProblem(sys, [x => 1.0])
174+
@test prob.ps[p1] == 1.0
175+
@test prob.ps[p2] == 2.0
176+
@test_nowarn solve(prob, NewtonRaphson())
177+
prob = NonlinearProblem(sys, [x => 1.0], [p1 => 2.0])
178+
@test prob.ps[p1] == 2.0
179+
@test prob.ps[p2] == 4.0
180+
end
181+
165182
@testset "SciMLStructures interface" begin
166183
@parameters p1=1.0 p2=1.0
167184
@variables x(t)

0 commit comments

Comments
 (0)