Skip to content

Fixes for dependent parameters #2668

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 24 additions & 5 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -878,6 +878,11 @@
Dict(unknowns(sys, k) => namespace_expr(v, sys) for (k, v) in guess)
end

function namespace_parameter_dependencies(sys)
pdeps = parameter_dependencies(sys)
Dict(parameters(sys, k) => namespace_expr(v, sys) for (k, v) in pdeps)
end

function namespace_equations(sys::AbstractSystem, ivs = independent_variables(sys))
eqs = equations(sys)
isempty(eqs) && return Equation[]
Expand Down Expand Up @@ -965,7 +970,7 @@
result = unique(isempty(systems) ? ps :
[ps; reduce(vcat, namespace_parameters.(systems))])
if has_parameter_dependencies(sys) &&
(pdeps = get_parameter_dependencies(sys)) !== nothing
(pdeps = parameter_dependencies(sys)) !== nothing
filter(result) do sym
!haskey(pdeps, sym)
end
Expand All @@ -976,13 +981,27 @@

function dependent_parameters(sys::AbstractSystem)
if has_parameter_dependencies(sys) &&
(pdeps = get_parameter_dependencies(sys)) !== nothing
collect(keys(pdeps))
!isempty(parameter_dependencies(sys))
collect(keys(parameter_dependencies(sys)))

Check warning on line 985 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L985

Added line #L985 was not covered by tests
else
[]
end
end

function parameter_dependencies(sys::AbstractSystem)
pdeps = get_parameter_dependencies(sys)
if isnothing(pdeps)
pdeps = Dict()
end
systems = get_systems(sys)
isempty(systems) && return pdeps
for subsys in systems
pdeps = merge(pdeps, namespace_parameter_dependencies(subsys))
end
# @info pdeps
return pdeps
end

function full_parameters(sys::AbstractSystem)
vcat(parameters(sys), dependent_parameters(sys))
end
Expand Down Expand Up @@ -2372,8 +2391,8 @@
eqs = union(get_eqs(basesys), get_eqs(sys))
sts = union(get_unknowns(basesys), get_unknowns(sys))
ps = union(get_ps(basesys), get_ps(sys))
base_deps = get_parameter_dependencies(basesys)
deps = get_parameter_dependencies(sys)
base_deps = parameter_dependencies(basesys)
deps = parameter_dependencies(sys)
dep_ps = isnothing(base_deps) ? deps :
isnothing(deps) ? base_deps : union(base_deps, deps)
obs = union(get_observed(basesys), get_observed(sys))
Expand Down
5 changes: 3 additions & 2 deletions src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,7 @@ function flatten(sys::ODESystem, noeqs = false)
get_iv(sys),
unknowns(sys),
parameters(sys),
parameter_dependencies = parameter_dependencies(sys),
guesses = guesses(sys),
observed = observed(sys),
continuous_events = continuous_events(sys),
Expand Down Expand Up @@ -405,10 +406,10 @@ function build_explicit_observed_function(sys, ts;
Set(arguments(st)[1] for st in sts if iscall(st) && operation(st) === getindex))

observed_idx = Dict(x.lhs => i for (i, x) in enumerate(obs))
param_set = Set(parameters(sys))
param_set = Set(full_parameters(sys))
param_set = union(param_set,
Set(arguments(p)[1] for p in param_set if iscall(p) && operation(p) === getindex))
param_set_ns = Set(unknowns(sys, p) for p in parameters(sys))
param_set_ns = Set(unknowns(sys, p) for p in full_parameters(sys))
param_set_ns = union(param_set_ns,
Set(arguments(p)[1]
for p in param_set_ns if iscall(p) && operation(p) === getindex))
Expand Down
4 changes: 2 additions & 2 deletions src/systems/diffeqs/sdesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ function stochastic_integral_transform(sys::SDESystem, correction_factor)
end

SDESystem(deqs, get_noiseeqs(sys), get_iv(sys), unknowns(sys), parameters(sys),
name = name, parameter_dependencies = get_parameter_dependencies(sys), checks = false)
name = name, parameter_dependencies = parameter_dependencies(sys), checks = false)
end

"""
Expand Down Expand Up @@ -399,7 +399,7 @@ function Girsanov_transform(sys::SDESystem, u; θ0 = 1.0)
# return modified SDE System
SDESystem(deqs, noiseeqs, get_iv(sys), unknown_vars, parameters(sys);
defaults = Dict(θ => θ0), observed = [weight ~ θ / θ0],
name = name, parameter_dependencies = get_parameter_dependencies(sys),
name = name, parameter_dependencies = parameter_dependencies(sys),
checks = false)
end

Expand Down
4 changes: 2 additions & 2 deletions src/systems/index_cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,8 @@ function IndexCache(sys::AbstractSystem)
end
end

if has_parameter_dependencies(sys) &&
(pdeps = get_parameter_dependencies(sys)) !== nothing
if has_parameter_dependencies(sys)
pdeps = parameter_dependencies(sys)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not technically necessary, although it will work regardless. If the system is simplified, all parameter dependencies will be available at the top level. If it isn't simplified, we can't simulate it anyway so an index cache is immaterial.

for (sym, value) in pdeps
sym = unwrap(sym)
insert_by_type!(dependent_buffers, sym)
Expand Down
1 change: 1 addition & 0 deletions src/systems/nonlinear/initializesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ function generate_initializesystem(sys::ODESystem;
full_states,
pars;
defaults = merge(ModelingToolkit.defaults(sys), todict(u0), dd_guess),
parameter_dependencies = parameter_dependencies(sys),
name,
kwargs...)

Expand Down
2 changes: 1 addition & 1 deletion src/systems/parameter_buffer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ function MTKParameters(
nonnumeric_buffer = nonnumeric_buffer

if has_parameter_dependencies(sys) &&
(pdeps = get_parameter_dependencies(sys)) !== nothing
(pdeps = parameter_dependencies(sys)) !== nothing
pdeps = Dict(k => fixpoint_sub(v, pdeps) for (k, v) in pdeps)
dep_exprs = ArrayPartition((Any[missing for _ in 1:length(v)] for v in dep_buffer)...)
for (sym, val) in pdeps
Expand Down
88 changes: 87 additions & 1 deletion test/parameter_dependencies.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,28 @@ using NonlinearSolve
@test integ.ps[p2] == 10.0
end

@testset "vector parameter deps" begin
@parameters p1[1:2]=[1.0, 2.0] p2[1:2]=[0.0, 0.0]
@variables x(t) = 0

@named sys = ODESystem(
[D(x) ~ sum(p1) * t + sum(p2)],
t;
parameter_dependencies = [p2 => 2p1]
)
prob = ODEProblem(complete(sys))
setp1! = setp(prob, p1)
get_p1 = getp(prob, p1)
get_p2 = getp(prob, p2)
setp1!(prob, [1.5, 2.5])

@test get_p1(prob) == [1.5, 2.5]
@test get_p2(prob) == [3.0, 5.0]
end

@testset "extend" begin
@parameters p1=1.0 p2=1.0
@variables x(t)
@variables x(t) = 0

@mtkbuild sys1 = ODESystem(
[D(x) ~ p1 * t + p2],
Expand All @@ -65,6 +84,73 @@ end
sys = extend(sys2, sys1)
@test isequal(only(parameters(sys)), p1)
@test Set(full_parameters(sys)) == Set([p1, p2])
prob = ODEProblem(complete(sys))
get_dep = getu(prob, 2p2)
@test get_dep(prob) == 4
end

@testset "getu with parameter deps" begin
@parameters p1=1.0 p2=1.0
@variables x(t) = 0

@named sys = ODESystem(
[D(x) ~ p1 * t + p2],
t;
parameter_dependencies = [p2 => 2p1]
)
prob = ODEProblem(complete(sys))
get_dep = getu(prob, 2p2)
@test get_dep(prob) == 4
end

Comment on lines +92 to +105
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you also add a testcase for hierarchical systems where the subsystems have parameter dependencies? I have a feeling that will fail

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, it failed. I added a test in db89f68 and I tried to do the renamespacing, but something is off and the dependent parameter does not seem to update correctly.

@testset "getu with vector parameter deps" begin
@parameters p1[1:2]=[1.0, 2.0] p2[1:2]=[0.0, 0.0]
@variables x(t) = 0

@named sys = ODESystem(
[D(x) ~ sum(p1) * t + sum(p2)],
t;
parameter_dependencies = [p2 => 2p1]
)
prob = ODEProblem(complete(sys))
get_dep = getu(prob, 2p1)
@test get_dep(prob) == [2.0, 4.0]
end

@testset "composing systems with parameter deps" begin
@parameters p1=1.0 p2=2.0
@variables x(t) = 0

@mtkbuild sys1 = ODESystem(
[D(x) ~ p1 * t + p2],
t
)
@named sys2 = ODESystem(
[D(x) ~ p1 * t - p2],
t;
parameter_dependencies = [p2 => 2p1]
)
sys = complete(ODESystem([], t, systems = [sys1, sys2], name = :sys))

prob = ODEProblem(sys)
v1 = sys.sys2.p2
v2 = 2 * v1
@test is_parameter(prob, v1)
@test is_observed(prob, v2)
get_v1 = getu(prob, v1)
get_v2 = getu(prob, v2)
@test get_v1(prob) == 2
@test get_v2(prob) == 4

setp1! = setp(prob, sys2.p1)
setp1!(prob, 2.5)
@test prob.ps[sys2.p2] == 5.0

new_prob = remake(prob, p = [sys2.p1 => 1.5])

@test !isempty(ModelingToolkit.parameter_dependencies(sys2))
@test new_prob.ps[sys2.p1] == 1.5
@test new_prob.ps[sys2.p2] == 3.0
end

@testset "Clock system" begin
Expand Down
Loading