Skip to content

Commit 7bc758b

Browse files
Merge pull request #2640 from AayushSabharwal/as/error-missing-params
fix: error when all parameters are not initialized
2 parents c1b5d83 + c75ca6d commit 7bc758b

File tree

4 files changed

+66
-30
lines changed

4 files changed

+66
-30
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 34 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -876,31 +876,34 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
876876
parammap = Dict(unwrap.(parameters(sys)) .=> parammap)
877877
end
878878
end
879-
clockedparammap = Dict()
880-
defs = ModelingToolkit.get_defaults(sys)
881-
for v in ps
882-
v = unwrap(v)
883-
is_discrete_domain(v) || continue
884-
op = operation(v)
885-
if !isa(op, Symbolics.Operator) && parammap != SciMLBase.NullParameters() &&
886-
haskey(parammap, v)
887-
error("Initial conditions for discrete variables must be for the past state of the unknown. Instead of providing the condition for $v, provide the condition for $(Shift(iv, -1)(v)).")
879+
880+
if has_discrete_subsystems(sys) && get_discrete_subsystems(sys) !== nothing
881+
clockedparammap = Dict()
882+
defs = ModelingToolkit.get_defaults(sys)
883+
for v in ps
884+
v = unwrap(v)
885+
is_discrete_domain(v) || continue
886+
op = operation(v)
887+
if !isa(op, Symbolics.Operator) && parammap != SciMLBase.NullParameters() &&
888+
haskey(parammap, v)
889+
error("Initial conditions for discrete variables must be for the past state of the unknown. Instead of providing the condition for $v, provide the condition for $(Shift(iv, -1)(v)).")
890+
end
891+
shiftedv = StructuralTransformations.simplify_shifts(Shift(iv, -1)(v))
892+
if parammap != SciMLBase.NullParameters() &&
893+
(val = get(parammap, shiftedv, nothing)) !== nothing
894+
clockedparammap[v] = val
895+
elseif op isa Shift
896+
root = arguments(v)[1]
897+
haskey(defs, root) || error("Initial condition for $v not provided.")
898+
clockedparammap[v] = defs[root]
899+
end
888900
end
889-
shiftedv = StructuralTransformations.simplify_shifts(Shift(iv, -1)(v))
890-
if parammap != SciMLBase.NullParameters() &&
891-
(val = get(parammap, shiftedv, nothing)) !== nothing
892-
clockedparammap[v] = val
893-
elseif op isa Shift
894-
root = arguments(v)[1]
895-
haskey(defs, root) || error("Initial condition for $v not provided.")
896-
clockedparammap[v] = defs[root]
901+
parammap = if parammap == SciMLBase.NullParameters()
902+
clockedparammap
903+
else
904+
merge(parammap, clockedparammap)
897905
end
898906
end
899-
parammap = if parammap == SciMLBase.NullParameters()
900-
clockedparammap
901-
else
902-
merge(parammap, clockedparammap)
903-
end
904907
# TODO: make it work with clocks
905908
# ModelingToolkit.get_tearing_state(sys) !== nothing => Requires structural_simplify first
906909
if sys isa ODESystem && build_initializeprob &&
@@ -931,7 +934,12 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
931934
if has_index_cache(sys) && get_index_cache(sys) !== nothing
932935
u0, defs = get_u0(sys, trueinit, parammap; symbolic_u0)
933936
check_eqs_u0(eqs, dvs, u0; kwargs...)
934-
p = MTKParameters(sys, parammap, trueinit)
937+
p = if parammap === nothing ||
938+
parammap == SciMLBase.NullParameters() && isempty(defs)
939+
nothing
940+
else
941+
MTKParameters(sys, parammap, trueinit)
942+
end
935943
else
936944
u0, p, defs = get_u0_p(sys,
937945
trueinit,
@@ -1592,7 +1600,6 @@ function InitializationProblem{iip, specialize}(sys::AbstractODESystem,
15921600
if !iscomplete(sys)
15931601
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `ODEProblem`")
15941602
end
1595-
15961603
if isempty(u0map) && get_initializesystem(sys) !== nothing
15971604
isys = get_initializesystem(sys)
15981605
elseif isempty(u0map) && get_initializesystem(sys) === nothing
@@ -1620,9 +1627,9 @@ function InitializationProblem{iip, specialize}(sys::AbstractODESystem,
16201627
@warn "Initialization system is underdetermined. $neqs equations for $nunknown unknowns. Initialization will default to using least squares. To suppress this warning pass warn_initialize_determined = false."
16211628
end
16221629

1623-
parammap isa DiffEqBase.NullParameters || isempty(parammap) ?
1624-
[get_iv(sys) => t] :
1625-
merge(todict(parammap), Dict(get_iv(sys) => t))
1630+
parammap = parammap isa DiffEqBase.NullParameters || isempty(parammap) ?
1631+
[get_iv(sys) => t] :
1632+
merge(todict(parammap), Dict(get_iv(sys) => t))
16261633

16271634
if neqs == nunknown
16281635
NonlinearProblem(isys, guesses, parammap)

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,6 @@ function generate_function(
241241
pre, sol_states = get_substitutions_and_solved_unknowns(sys)
242242

243243
p = reorder_parameters(sys, value.(ps))
244-
@show p ps
245244
return build_function(rhss, value.(dvs), p...; postprocess_fbody = pre,
246245
states = sol_states, kwargs...)
247246
end
@@ -395,7 +394,6 @@ function process_NonlinearProblem(constructor, sys::NonlinearSystem, u0map, para
395394
eqs = equations(sys)
396395
dvs = unknowns(sys)
397396
ps = full_parameters(sys)
398-
399397
if has_index_cache(sys) && get_index_cache(sys) !== nothing
400398
u0, defs = get_u0(sys, u0map, parammap)
401399
check_eqs_u0(eqs, dvs, u0; kwargs...)

src/systems/parameter_buffer.jl

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,14 +43,32 @@ function MTKParameters(
4343
p = merge(defs, p)
4444
p = merge(Dict(unwrap(k) => v for (k, v) in p),
4545
Dict(default_toterm(unwrap(k)) => v for (k, v) in p))
46-
p = Dict(k => fixpoint_sub(v, p) for (k, v) in p)
46+
p = Dict(unwrap(k) => fixpoint_sub(v, p) for (k, v) in p)
4747
for (sym, _) in p
4848
if istree(sym) && operation(sym) === getindex &&
4949
first(arguments(sym)) in all_ps
5050
error("Scalarized parameter values ($sym) are not supported. Instead of `[p[1] => 1.0, p[2] => 2.0]` use `[p => [1.0, 2.0]]`")
5151
end
5252
end
5353

54+
missing_params = Set()
55+
for idxmap in (ic.tunable_idx, ic.discrete_idx, ic.constant_idx, ic.nonnumeric_idx)
56+
for sym in keys(idxmap)
57+
sym isa Symbol && continue
58+
haskey(p, sym) && continue
59+
hasname(sym) && haskey(p, getname(sym)) && continue
60+
ttsym = default_toterm(sym)
61+
haskey(p, ttsym) && continue
62+
hasname(ttsym) && haskey(p, getname(ttsym)) && continue
63+
64+
istree(sym) && operation(sym) === getindex && haskey(p, arguments(sym)[1]) &&
65+
continue
66+
push!(missing_params, sym)
67+
end
68+
end
69+
70+
isempty(missing_params) || throw(MissingVariablesError(collect(missing_params)))
71+
5472
tunable_buffer = Tuple(Vector{temp.type}(undef, temp.length)
5573
for temp in ic.tunable_buffer_sizes)
5674
disc_buffer = Tuple(Vector{temp.type}(undef, temp.length)

test/mtkparameters.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,3 +121,16 @@ ps = MTKParameters(sys, [p => 1.0, q => 2.0, r => 3.0])
121121
newps = remake_buffer(sys, ps, Dict(p => 1.0f0))
122122
@test newps.tunable[1] isa Vector{Float32}
123123
@test newps.tunable[1] == [1.0f0, 2.0f0, 3.0f0]
124+
125+
# Issue#2624
126+
@parameters p d
127+
@variables X(t)
128+
eqs = [D(X) ~ p - d * X]
129+
@mtkbuild sys = ODESystem(eqs, t)
130+
131+
u0 = [X => 1.0]
132+
tspan = (0.0, 100.0)
133+
ps = [p => 1.0] # Value for `d` is missing
134+
135+
@test_throws ModelingToolkit.MissingVariablesError ODEProblem(sys, u0, tspan, ps)
136+
@test_nowarn ODEProblem(sys, u0, tspan, [ps..., d => 1.0])

0 commit comments

Comments
 (0)