Skip to content

Commit 4e91484

Browse files
Merge pull request #2668 from SebastianM-C/obs_dep
Fixes for dependent parameters
2 parents 2adc33a + 55a9dc8 commit 4e91484

File tree

7 files changed

+120
-13
lines changed

7 files changed

+120
-13
lines changed

src/systems/abstractsystem.jl

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -878,6 +878,11 @@ function namespace_guesses(sys)
878878
Dict(unknowns(sys, k) => namespace_expr(v, sys) for (k, v) in guess)
879879
end
880880

881+
function namespace_parameter_dependencies(sys)
882+
pdeps = parameter_dependencies(sys)
883+
Dict(parameters(sys, k) => namespace_expr(v, sys) for (k, v) in pdeps)
884+
end
885+
881886
function namespace_equations(sys::AbstractSystem, ivs = independent_variables(sys))
882887
eqs = equations(sys)
883888
isempty(eqs) && return Equation[]
@@ -965,7 +970,7 @@ function parameters(sys::AbstractSystem)
965970
result = unique(isempty(systems) ? ps :
966971
[ps; reduce(vcat, namespace_parameters.(systems))])
967972
if has_parameter_dependencies(sys) &&
968-
(pdeps = get_parameter_dependencies(sys)) !== nothing
973+
(pdeps = parameter_dependencies(sys)) !== nothing
969974
filter(result) do sym
970975
!haskey(pdeps, sym)
971976
end
@@ -976,13 +981,27 @@ end
976981

977982
function dependent_parameters(sys::AbstractSystem)
978983
if has_parameter_dependencies(sys) &&
979-
(pdeps = get_parameter_dependencies(sys)) !== nothing
980-
collect(keys(pdeps))
984+
!isempty(parameter_dependencies(sys))
985+
collect(keys(parameter_dependencies(sys)))
981986
else
982987
[]
983988
end
984989
end
985990

991+
function parameter_dependencies(sys::AbstractSystem)
992+
pdeps = get_parameter_dependencies(sys)
993+
if isnothing(pdeps)
994+
pdeps = Dict()
995+
end
996+
systems = get_systems(sys)
997+
isempty(systems) && return pdeps
998+
for subsys in systems
999+
pdeps = merge(pdeps, namespace_parameter_dependencies(subsys))
1000+
end
1001+
# @info pdeps
1002+
return pdeps
1003+
end
1004+
9861005
function full_parameters(sys::AbstractSystem)
9871006
vcat(parameters(sys), dependent_parameters(sys))
9881007
end
@@ -2372,8 +2391,8 @@ function extend(sys::AbstractSystem, basesys::AbstractSystem; name::Symbol = nam
23722391
eqs = union(get_eqs(basesys), get_eqs(sys))
23732392
sts = union(get_unknowns(basesys), get_unknowns(sys))
23742393
ps = union(get_ps(basesys), get_ps(sys))
2375-
base_deps = get_parameter_dependencies(basesys)
2376-
deps = get_parameter_dependencies(sys)
2394+
base_deps = parameter_dependencies(basesys)
2395+
deps = parameter_dependencies(sys)
23772396
dep_ps = isnothing(base_deps) ? deps :
23782397
isnothing(deps) ? base_deps : union(base_deps, deps)
23792398
obs = union(get_observed(basesys), get_observed(sys))

src/systems/diffeqs/odesystem.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,7 @@ function flatten(sys::ODESystem, noeqs = false)
355355
get_iv(sys),
356356
unknowns(sys),
357357
parameters(sys),
358+
parameter_dependencies = parameter_dependencies(sys),
358359
guesses = guesses(sys),
359360
observed = observed(sys),
360361
continuous_events = continuous_events(sys),
@@ -405,10 +406,10 @@ function build_explicit_observed_function(sys, ts;
405406
Set(arguments(st)[1] for st in sts if iscall(st) && operation(st) === getindex))
406407

407408
observed_idx = Dict(x.lhs => i for (i, x) in enumerate(obs))
408-
param_set = Set(parameters(sys))
409+
param_set = Set(full_parameters(sys))
409410
param_set = union(param_set,
410411
Set(arguments(p)[1] for p in param_set if iscall(p) && operation(p) === getindex))
411-
param_set_ns = Set(unknowns(sys, p) for p in parameters(sys))
412+
param_set_ns = Set(unknowns(sys, p) for p in full_parameters(sys))
412413
param_set_ns = union(param_set_ns,
413414
Set(arguments(p)[1]
414415
for p in param_set_ns if iscall(p) && operation(p) === getindex))

src/systems/diffeqs/sdesystem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ function stochastic_integral_transform(sys::SDESystem, correction_factor)
291291
end
292292

293293
SDESystem(deqs, get_noiseeqs(sys), get_iv(sys), unknowns(sys), parameters(sys),
294-
name = name, parameter_dependencies = get_parameter_dependencies(sys), checks = false)
294+
name = name, parameter_dependencies = parameter_dependencies(sys), checks = false)
295295
end
296296

297297
"""
@@ -399,7 +399,7 @@ function Girsanov_transform(sys::SDESystem, u; θ0 = 1.0)
399399
# return modified SDE System
400400
SDESystem(deqs, noiseeqs, get_iv(sys), unknown_vars, parameters(sys);
401401
defaults = Dict=> θ0), observed = [weight ~ θ / θ0],
402-
name = name, parameter_dependencies = get_parameter_dependencies(sys),
402+
name = name, parameter_dependencies = parameter_dependencies(sys),
403403
checks = false)
404404
end
405405

src/systems/index_cache.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,8 @@ function IndexCache(sys::AbstractSystem)
115115
end
116116
end
117117

118-
if has_parameter_dependencies(sys) &&
119-
(pdeps = get_parameter_dependencies(sys)) !== nothing
118+
if has_parameter_dependencies(sys)
119+
pdeps = parameter_dependencies(sys)
120120
for (sym, value) in pdeps
121121
sym = unwrap(sym)
122122
insert_by_type!(dependent_buffers, sym)

src/systems/nonlinear/initializesystem.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ function generate_initializesystem(sys::ODESystem;
102102
full_states,
103103
pars;
104104
defaults = merge(ModelingToolkit.defaults(sys), todict(u0), dd_guess),
105+
parameter_dependencies = parameter_dependencies(sys),
105106
name,
106107
kwargs...)
107108

src/systems/parameter_buffer.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ function MTKParameters(
136136
nonnumeric_buffer = nonnumeric_buffer
137137

138138
if has_parameter_dependencies(sys) &&
139-
(pdeps = get_parameter_dependencies(sys)) !== nothing
139+
(pdeps = parameter_dependencies(sys)) !== nothing
140140
pdeps = Dict(k => fixpoint_sub(v, pdeps) for (k, v) in pdeps)
141141
dep_exprs = ArrayPartition((Any[missing for _ in 1:length(v)] for v in dep_buffer)...)
142142
for (sym, val) in pdeps

test/parameter_dependencies.jl

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,28 @@ using NonlinearSolve
4949
@test integ.ps[p2] == 10.0
5050
end
5151

52+
@testset "vector parameter deps" begin
53+
@parameters p1[1:2]=[1.0, 2.0] p2[1:2]=[0.0, 0.0]
54+
@variables x(t) = 0
55+
56+
@named sys = ODESystem(
57+
[D(x) ~ sum(p1) * t + sum(p2)],
58+
t;
59+
parameter_dependencies = [p2 => 2p1]
60+
)
61+
prob = ODEProblem(complete(sys))
62+
setp1! = setp(prob, p1)
63+
get_p1 = getp(prob, p1)
64+
get_p2 = getp(prob, p2)
65+
setp1!(prob, [1.5, 2.5])
66+
67+
@test get_p1(prob) == [1.5, 2.5]
68+
@test get_p2(prob) == [3.0, 5.0]
69+
end
70+
5271
@testset "extend" begin
5372
@parameters p1=1.0 p2=1.0
54-
@variables x(t)
73+
@variables x(t) = 0
5574

5675
@mtkbuild sys1 = ODESystem(
5776
[D(x) ~ p1 * t + p2],
@@ -65,6 +84,73 @@ end
6584
sys = extend(sys2, sys1)
6685
@test isequal(only(parameters(sys)), p1)
6786
@test Set(full_parameters(sys)) == Set([p1, p2])
87+
prob = ODEProblem(complete(sys))
88+
get_dep = getu(prob, 2p2)
89+
@test get_dep(prob) == 4
90+
end
91+
92+
@testset "getu with parameter deps" begin
93+
@parameters p1=1.0 p2=1.0
94+
@variables x(t) = 0
95+
96+
@named sys = ODESystem(
97+
[D(x) ~ p1 * t + p2],
98+
t;
99+
parameter_dependencies = [p2 => 2p1]
100+
)
101+
prob = ODEProblem(complete(sys))
102+
get_dep = getu(prob, 2p2)
103+
@test get_dep(prob) == 4
104+
end
105+
106+
@testset "getu with vector parameter deps" begin
107+
@parameters p1[1:2]=[1.0, 2.0] p2[1:2]=[0.0, 0.0]
108+
@variables x(t) = 0
109+
110+
@named sys = ODESystem(
111+
[D(x) ~ sum(p1) * t + sum(p2)],
112+
t;
113+
parameter_dependencies = [p2 => 2p1]
114+
)
115+
prob = ODEProblem(complete(sys))
116+
get_dep = getu(prob, 2p1)
117+
@test get_dep(prob) == [2.0, 4.0]
118+
end
119+
120+
@testset "composing systems with parameter deps" begin
121+
@parameters p1=1.0 p2=2.0
122+
@variables x(t) = 0
123+
124+
@mtkbuild sys1 = ODESystem(
125+
[D(x) ~ p1 * t + p2],
126+
t
127+
)
128+
@named sys2 = ODESystem(
129+
[D(x) ~ p1 * t - p2],
130+
t;
131+
parameter_dependencies = [p2 => 2p1]
132+
)
133+
sys = complete(ODESystem([], t, systems = [sys1, sys2], name = :sys))
134+
135+
prob = ODEProblem(sys)
136+
v1 = sys.sys2.p2
137+
v2 = 2 * v1
138+
@test is_parameter(prob, v1)
139+
@test is_observed(prob, v2)
140+
get_v1 = getu(prob, v1)
141+
get_v2 = getu(prob, v2)
142+
@test get_v1(prob) == 2
143+
@test get_v2(prob) == 4
144+
145+
setp1! = setp(prob, sys2.p1)
146+
setp1!(prob, 2.5)
147+
@test prob.ps[sys2.p2] == 5.0
148+
149+
new_prob = remake(prob, p = [sys2.p1 => 1.5])
150+
151+
@test !isempty(ModelingToolkit.parameter_dependencies(sys2))
152+
@test new_prob.ps[sys2.p1] == 1.5
153+
@test new_prob.ps[sys2.p2] == 3.0
68154
end
69155

70156
@testset "Clock system" begin

0 commit comments

Comments
 (0)