Skip to content

Add propagation of guesses from parameters and observed #2775

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -790,7 +790,9 @@ function get_u0_p(sys,
@warn "Observed variables cannot be assigned initial values. Initial values for $u0s_in_obs will be ignored."
end
end
defs = mergedefaults(defs, u0map, dvs)
obs = filter!(x -> !(x[1] isa Number), map(x -> x.rhs => x.lhs, observed(sys)))
observedmap = isempty(obs) ? Dict() : todict(obs)
defs = mergedefaults(defs, observedmap, u0map, dvs)
for (k, v) in defs
if Symbolics.isarraysymbolic(k)
ks = scalarize(k)
Expand Down Expand Up @@ -821,7 +823,9 @@ function get_u0(
if parammap !== nothing
defs = mergedefaults(defs, parammap, ps)
end
defs = mergedefaults(defs, u0map, dvs)
obs = filter!(x -> !(x[1] isa Number), map(x -> x.rhs => x.lhs, observed(sys)))
observedmap = isempty(obs) ? Dict() : todict(obs)
defs = mergedefaults(defs, observedmap, u0map, dvs)
if symbolic_u0
u0 = varmap_to_vars(
u0map, dvs; defaults = defs, tofloat = false, use_union = false, toterm)
Expand Down Expand Up @@ -1637,6 +1641,7 @@ function InitializationProblem{iip, specialize}(sys::AbstractODESystem,
if isempty(guesses)
guesses = Dict()
end

u0map = merge(todict(guesses), todict(u0map))
if neqs == nunknown
NonlinearProblem(isys, u0map, parammap)
Expand Down
2 changes: 1 addition & 1 deletion src/systems/nonlinear/initializesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ function generate_initializesystem(sys::ODESystem;
schedule = getfield(sys, :schedule)

if schedule !== nothing
guessmap = [x[2] => get(guesses, x[1], default_dd_value)
guessmap = [x[1] => get(guesses, x[1], default_dd_value)
for x in schedule.dummy_sub]
dd_guess = Dict(filter(x -> !isnothing(x[1]), guessmap))
if u0map === nothing || isempty(u0map)
Expand Down
12 changes: 12 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,18 @@
end
end

function mergedefaults(defaults, observedmap, varmap, vars)
defs = if varmap isa Dict
merge(observedmap, defaults, varmap)
elseif eltype(varmap) <: Pair
merge(observedmap, defaults, Dict(varmap))
elseif eltype(varmap) <: Number
merge(observedmap, defaults, Dict(zip(vars, varmap)))

Check warning on line 626 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L626

Added line #L626 was not covered by tests
else
merge(observedmap, defaults)
end
end

@noinline function throw_missingvars_in_sys(vars)
throw(ArgumentError("$vars are either missing from the variable map or missing from the system's unknowns/parameters list."))
end
Expand Down
75 changes: 75 additions & 0 deletions test/guess_propagation.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
using ModelingToolkit, OrdinaryDiffEq
using ModelingToolkit: D, t_nounits as t
using Test

# Standard case

@variables x(t) [guess = 2]
@variables y(t)
eqs = [D(x) ~ 1
x ~ y]
initialization_eqs = [1 ~ exp(1 + x)]

@named sys = ODESystem(eqs, t; initialization_eqs)
sys = complete(structural_simplify(sys))
tspan = (0.0, 0.2)
prob = ODEProblem(sys, [], tspan, [])

@test prob.f.initializeprob[y] == 2.0
@test prob.f.initializeprob[x] == 2.0
sol = solve(prob.f.initializeprob; show_trace = Val(true))

# Guess via observed

@variables x(t)
@variables y(t) [guess = 2]
eqs = [D(x) ~ 1
x ~ y]
initialization_eqs = [1 ~ exp(1 + x)]

@named sys = ODESystem(eqs, t; initialization_eqs)
sys = complete(structural_simplify(sys))
tspan = (0.0, 0.2)
prob = ODEProblem(sys, [], tspan, [])

@test prob.f.initializeprob[x] == 2.0
@test prob.f.initializeprob[y] == 2.0
sol = solve(prob.f.initializeprob; show_trace = Val(true))

# Guess via parameter

@parameters a = -1.0
@variables x(t) [guess = a]

eqs = [D(x) ~ a]

initialization_eqs = [1 ~ exp(1 + x)]

@named sys = ODESystem(eqs, t; initialization_eqs)
sys = complete(structural_simplify(sys))

tspan = (0.0, 0.2)
prob = ODEProblem(sys, [], tspan, [])

@test prob.f.initializeprob[x] == -1.0
sol = solve(prob.f.initializeprob; show_trace = Val(true))

# Guess via observed parameter

@parameters a = -1.0
@variables x(t)
@variables y(t) [guess = a]

eqs = [D(x) ~ a,
y ~ x]

initialization_eqs = [1 ~ exp(1 + x)]

@named sys = ODESystem(eqs, t; initialization_eqs)
sys = complete(structural_simplify(sys))

tspan = (0.0, 0.2)
prob = ODEProblem(sys, [], tspan, [])

@test prob.f.initializeprob[x] == -1.0
sol = solve(prob.f.initializeprob; show_trace = Val(true))
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ end
@safetestset "DDESystem Test" include("dde.jl")
@safetestset "NonlinearSystem Test" include("nonlinearsystem.jl")
@safetestset "InitializationSystem Test" include("initializationsystem.jl")
@safetestset "Guess Propagation" include("guess_propagation.jl")
@safetestset "Hierarchical Initialization Equations" include("hierarchical_initialization_eqs.jl")
@safetestset "PDE Construction Test" include("pde.jl")
@safetestset "JumpSystem Test" include("jumpsystem.jl")
Expand Down
Loading