Skip to content

Commit 9ed933a

Browse files
Merge pull request #2563 from AayushSabharwal/as/fix-param-defaults
feat: allow parameter defaults to depend on initial values of unknowns
2 parents 50af6c9 + f94e4c2 commit 9ed933a

File tree

8 files changed

+31
-70
lines changed

8 files changed

+31
-70
lines changed

docs/src/basics/FAQ.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ and the `MTKParameters` constructor for parameters. For example:
9191

9292
```julia
9393
unew = varmap_to_vars([x => 1.0, y => 2.0, z => 3.0], unknowns(sys))
94-
pnew = ModelingToolkit.MTKParameters(sys, [β => 3.0, c => 10.0, γ => 2.0])
94+
pnew = ModelingToolkit.MTKParameters(sys, [β => 3.0, c => 10.0, γ => 2.0], unew)
9595
```
9696

9797
## How do I handle `if` statements in my symbolic forms?

src/systems/abstractsystem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1708,7 +1708,7 @@ function linearization_function(sys::AbstractSystem, inputs,
17081708
u0, _p, _ = get_u0_p(sys, x0, p; use_union = false, tofloat = true)
17091709
ps = parameters(sys)
17101710
if has_index_cache(sys) && get_index_cache(sys) !== nothing
1711-
p = MTKParameters(sys, p)
1711+
p = MTKParameters(sys, p, u0)
17121712
else
17131713
p = _p
17141714
p, split_idxs = split_parameters_by_type(p)
@@ -2011,7 +2011,7 @@ function linearize(sys, lin_fun; t = 0.0, op = Dict(), allow_input_derivatives =
20112011
elseif p isa Vector
20122012
p = merge(Dict(parameters(sys) .=> p), op)
20132013
end
2014-
p2 = MTKParameters(sys, p)
2014+
p2 = MTKParameters(sys, p, Dict(unknowns(sys) .=> u0))
20152015
end
20162016
linres = lin_fun(u0, p2, t)
20172017
f_x, f_z, g_x, g_z, f_u, g_u, h_x, h_z, h_u = linres

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -889,7 +889,8 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
889889

890890
if has_index_cache(sys) && get_index_cache(sys) !== nothing
891891
u0, defs = get_u0(sys, trueinit, parammap; symbolic_u0)
892-
p = MTKParameters(sys, parammap)
892+
check_eqs_u0(eqs, dvs, u0; kwargs...)
893+
p = MTKParameters(sys, parammap, trueinit)
893894
else
894895
u0, p, defs = get_u0_p(sys,
895896
trueinit,

src/systems/jumps/jumpsystem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,7 @@ function DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan::Union{Tuple,
338338

339339
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = false)
340340
if has_index_cache(sys) && get_index_cache(sys) !== nothing
341-
p = MTKParameters(sys, parammap)
341+
p = MTKParameters(sys, parammap, u0map)
342342
else
343343
p = varmap_to_vars(parammap, ps; defaults = defs, tofloat = false, use_union)
344344
end
@@ -395,7 +395,7 @@ function DiscreteProblemExpr{iip}(sys::JumpSystem, u0map, tspan::Union{Tuple, No
395395

396396
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = false)
397397
if has_index_cache(sys) && get_index_cache(sys) !== nothing
398-
p = MTKParameters(sys, parammap)
398+
p = MTKParameters(sys, parammap, u0map)
399399
else
400400
p = varmap_to_vars(parammap, ps; defaults = defs, tofloat = false, use_union)
401401
end

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -363,11 +363,12 @@ function process_NonlinearProblem(constructor, sys::NonlinearSystem, u0map, para
363363

364364
if has_index_cache(sys) && get_index_cache(sys) !== nothing
365365
u0, defs = get_u0(sys, u0map, parammap)
366-
p = MTKParameters(sys, parammap)
366+
check_eqs_u0(eqs, dvs, u0; kwargs...)
367+
p = MTKParameters(sys, parammap, u0map)
367368
else
368369
u0, p, defs = get_u0_p(sys, u0map, parammap; tofloat, use_union)
370+
check_eqs_u0(eqs, dvs, u0; kwargs...)
369371
end
370-
check_eqs_u0(eqs, dvs, u0; kwargs...)
371372

372373
f = constructor(sys, dvs, ps, u0; jac = jac, checkbounds = checkbounds,
373374
linenumbers = linenumbers, parallel = parallel, simplify = simplify,

src/systems/optimization/optimizationsystem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map,
280280
if parammap isa MTKParameters
281281
p = parammap
282282
elseif has_index_cache(sys) && get_index_cache(sys) !== nothing
283-
p = MTKParameters(sys, parammap)
283+
p = MTKParameters(sys, parammap, u0map)
284284
else
285285
p = varmap_to_vars(parammap, ps; defaults = defs, tofloat = false, use_union)
286286
end
@@ -516,7 +516,7 @@ function OptimizationProblemExpr{iip}(sys::OptimizationSystem, u0map,
516516

517517
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = false)
518518
if has_index_cache(sys) && get_index_cache(sys) !== nothing
519-
p = MTKParameters(sys, parammap)
519+
p = MTKParameters(sys, parammap, u0map)
520520
else
521521
p = varmap_to_vars(parammap, ps; defaults = defs, tofloat = false, use_union)
522522
end

src/systems/parameter_buffer.jl

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ struct MTKParameters{T, D, C, E, N, F, G}
1010
dependent_update_oop::G
1111
end
1212

13-
function MTKParameters(sys::AbstractSystem, p; tofloat = false, use_union = false)
13+
function MTKParameters(
14+
sys::AbstractSystem, p, u0 = Dict(); tofloat = false, use_union = false)
1415
ic = if has_index_cache(sys) && get_index_cache(sys) !== nothing
1516
get_index_cache(sys)
1617
else
@@ -23,21 +24,27 @@ function MTKParameters(sys::AbstractSystem, p; tofloat = false, use_union = fals
2324
length(p) == length(ps) || error("Invalid parameters")
2425
p = ps .=> p
2526
end
27+
if p isa SciMLBase.NullParameters || isempty(p)
28+
p = Dict()
29+
end
30+
p = todict(p)
2631
defs = Dict(default_toterm(unwrap(k)) => v
2732
for (k, v) in defaults(sys)
2833
if unwrap(k) in all_ps || default_toterm(unwrap(k)) in all_ps)
29-
if p isa SciMLBase.NullParameters
30-
p = defs
31-
else
32-
extra_params = Dict(unwrap(k) => v
33-
for (k, v) in p if !in(unwrap(k), all_ps) && !in(default_toterm(unwrap(k)), all_ps))
34-
p = merge(defs,
35-
Dict(default_toterm(unwrap(k)) => v
36-
for (k, v) in p if unwrap(k) in all_ps || default_toterm(unwrap(k)) in all_ps))
37-
p = Dict(k => fixpoint_sub(v, extra_params)
38-
for (k, v) in p if !haskey(extra_params, unwrap(k)))
34+
if eltype(u0) <: Pair
35+
u0 = todict(u0)
36+
elseif u0 isa AbstractArray && !isempty(u0)
37+
u0 = Dict(unknowns(sys) .=> vec(u0))
38+
elseif u0 === nothing || isempty(u0)
39+
u0 = Dict()
3940
end
40-
41+
defs = merge(defs, u0)
42+
defs = merge(defs, Dict(eq.lhs => eq.rhs for eq in observed(sys)))
43+
p = merge(defs, p)
44+
p = merge(Dict(unwrap(k) => v for (k, v) in p),
45+
Dict(default_toterm(unwrap(k)) => v for (k, v) in p))
46+
p = Dict(k => fixpoint_sub(v, p)
47+
for (k, v) in p if k in all_ps || default_toterm(k) in all_ps)
4148
for (sym, _) in p
4249
if istree(sym) && operation(sym) === getindex &&
4350
first(arguments(sym)) in all_ps

src/variables.jl

Lines changed: 0 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -215,54 +215,6 @@ end
215215
throw(ArgumentError("$vars are missing from the variable map."))
216216
end
217217

218-
"""
219-
$(SIGNATURES)
220-
221-
Intercept the call to `process_p_u0_symbolic` and process symbolic maps of `p` and/or `u0` if the
222-
user has `ModelingToolkit` loaded.
223-
"""
224-
function SciMLBase.process_p_u0_symbolic(
225-
prob::Union{SciMLBase.AbstractDEProblem,
226-
NonlinearProblem, OptimizationProblem,
227-
SciMLBase.AbstractOptimizationCache},
228-
p,
229-
u0)
230-
# check if a symbolic remake is possible
231-
if p isa Vector && !(eltype(p) <: Pair)
232-
error("Parameter values must be specified as a `Dict` or `Vector{<:Pair}`")
233-
end
234-
if eltype(p) <: Pair
235-
hasproperty(prob.f, :sys) && hasfield(typeof(prob.f.sys), :ps) ||
236-
throw(ArgumentError("This problem does not support symbolic maps with `remake`, i.e. it does not have a symbolic origin." *
237-
" Please use `remake` with the `p` keyword argument as a vector of values, paying attention to parameter order."))
238-
end
239-
if eltype(u0) <: Pair
240-
hasproperty(prob.f, :sys) && hasfield(typeof(prob.f.sys), :unknowns) ||
241-
throw(ArgumentError("This problem does not support symbolic maps with `remake`, i.e. it does not have a symbolic origin." *
242-
" Please use `remake` with the `u0` keyword argument as a vector of values, paying attention to unknown variable order."))
243-
end
244-
245-
sys = prob.f.sys
246-
defs = defaults(sys)
247-
ps = parameters(sys)
248-
if has_split_idxs(sys) && (split_idxs = get_split_idxs(sys)) !== nothing
249-
for (i, idxs) in enumerate(split_idxs)
250-
defs = mergedefaults(defs, prob.p[i], ps[idxs])
251-
end
252-
else
253-
# assemble defaults
254-
defs = defaults(sys)
255-
defs = mergedefaults(defs, prob.p, ps)
256-
end
257-
defs = mergedefaults(defs, p, ps)
258-
sts = unknowns(sys)
259-
defs = mergedefaults(defs, prob.u0, sts)
260-
defs = mergedefaults(defs, u0, sts)
261-
u0, _, defs = get_u0_p(sys, defs)
262-
p = MTKParameters(sys, p)
263-
return p, u0
264-
end
265-
266218
struct IsHistory end
267219
ishistory(x) = ishistory(unwrap(x))
268220
ishistory(x::Symbolic) = getmetadata(x, IsHistory, false)

0 commit comments

Comments
 (0)