Skip to content

Commit a285fc2

Browse files
Merge pull request #2768 from AayushSabharwal/as/mtkize-names
feat: allow specifying variable names in `modelingtoolkitize`
2 parents 4e91484 + 9ed1b73 commit a285fc2

File tree

4 files changed

+378
-48
lines changed

4 files changed

+378
-48
lines changed

src/systems/diffeqs/modelingtoolkitize.jl

Lines changed: 97 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -3,29 +3,44 @@ $(TYPEDSIGNATURES)
33
44
Generate `ODESystem`, dependent variables, and parameters from an `ODEProblem`.
55
"""
6-
function modelingtoolkitize(prob::DiffEqBase.ODEProblem; kwargs...)
6+
function modelingtoolkitize(
7+
prob::DiffEqBase.ODEProblem; u_names = nothing, p_names = nothing, kwargs...)
78
prob.f isa DiffEqBase.AbstractParameterizedFunction &&
89
return prob.f.sys
9-
@parameters t
10-
10+
t = t_nounits
1111
p = prob.p
1212
has_p = !(p isa Union{DiffEqBase.NullParameters, Nothing})
1313

14-
_vars = define_vars(prob.u0, t)
14+
if u_names !== nothing
15+
varnames_length_check(prob.u0, u_names; is_unknowns = true)
16+
_vars = [_defvar(name)(t) for name in u_names]
17+
elseif SciMLBase.has_sys(prob.f)
18+
varnames = getname.(variable_symbols(prob.f.sys))
19+
varidxs = variable_index.((prob.f.sys,), varnames)
20+
invpermute!(varnames, varidxs)
21+
_vars = [_defvar(name)(t) for name in varnames]
22+
else
23+
_vars = define_vars(prob.u0, t)
24+
end
1525

1626
vars = prob.u0 isa Number ? _vars : ArrayInterface.restructure(prob.u0, _vars)
1727
params = if has_p
18-
_params = define_params(p)
28+
if p_names === nothing && SciMLBase.has_sys(prob.f)
29+
p_names = Dict(parameter_index(prob.f.sys, sym) => sym
30+
for sym in parameter_symbols(prob.f.sys))
31+
end
32+
_params = define_params(p, p_names)
1933
p isa Number ? _params[1] :
20-
(p isa Tuple || p isa NamedTuple || p isa AbstractDict ? _params :
34+
(p isa Tuple || p isa NamedTuple || p isa AbstractDict || p isa MTKParameters ?
35+
_params :
2136
ArrayInterface.restructure(p, _params))
2237
else
2338
[]
2439
end
2540

2641
var_set = Set(vars)
2742

28-
D = Differential(t)
43+
D = D_nounits
2944
mm = prob.f.mass_matrix
3045

3146
if mm === I
@@ -70,6 +85,8 @@ function modelingtoolkitize(prob::DiffEqBase.ODEProblem; kwargs...)
7085
default_p = if has_p
7186
if prob.p isa AbstractDict
7287
Dict(v => prob.p[k] for (k, v) in pairs(_params))
88+
elseif prob.p isa MTKParameters
89+
Dict(params .=> reduce(vcat, prob.p))
7390
else
7491
Dict(params .=> vec(collect(prob.p)))
7592
end
@@ -125,44 +142,96 @@ function Base.showerror(io::IO, e::ModelingtoolkitizeParametersNotSupportedError
125142
println(io, e.type)
126143
end
127144

128-
function define_params(p)
145+
function varnames_length_check(vars, names; is_unknowns = false)
146+
if length(names) != length(vars)
147+
throw(ArgumentError("""
148+
Number of $(is_unknowns ? "unknowns" : "parameters") ($(length(vars))) \
149+
does not match number of names ($(length(names))).
150+
"""))
151+
end
152+
end
153+
154+
function define_params(p, _ = nothing)
129155
throw(ModelingtoolkitizeParametersNotSupportedError(typeof(p)))
130156
end
131157

132-
function define_params(p::AbstractArray)
133-
[toparam(variable(, i)) for i in eachindex(p)]
158+
function define_params(p::AbstractArray, names = nothing)
159+
if names === nothing
160+
[toparam(variable(, i)) for i in eachindex(p)]
161+
else
162+
varnames_length_check(p, names)
163+
[toparam(variable(names[i])) for i in eachindex(p)]
164+
end
134165
end
135166

136-
function define_params(p::Number)
137-
[toparam(variable())]
167+
function define_params(p::Number, names = nothing)
168+
if names === nothing
169+
[toparam(variable())]
170+
elseif names isa Union{AbstractArray, AbstractDict}
171+
varnames_length_check(p, names)
172+
[toparam(variable(names[i])) for i in eachindex(p)]
173+
else
174+
[toparam(variable(names))]
175+
end
138176
end
139177

140-
function define_params(p::AbstractDict)
141-
OrderedDict(k => toparam(variable(, i)) for (i, k) in zip(1:length(p), keys(p)))
178+
function define_params(p::AbstractDict, names = nothing)
179+
if names === nothing
180+
OrderedDict(k => toparam(variable(, i)) for (i, k) in zip(1:length(p), keys(p)))
181+
else
182+
varnames_length_check(p, names)
183+
OrderedDict(k => toparam(variable(names[k])) for k in keys(p))
184+
end
142185
end
143186

144-
function define_params(p::Union{SLArray, LArray})
145-
[toparam(variable(x)) for x in LabelledArrays.symnames(typeof(p))]
187+
function define_params(p::Union{SLArray, LArray}, names = nothing)
188+
if names === nothing
189+
[toparam(variable(x)) for x in LabelledArrays.symnames(typeof(p))]
190+
else
191+
varnames_length_check(p, names)
192+
[toparam(variable(names[i])) for i in eachindex(p)]
193+
end
146194
end
147195

148-
function define_params(p::Tuple)
149-
tuple((toparam(variable(, i)) for i in eachindex(p))...)
196+
function define_params(p::Tuple, names = nothing)
197+
if names === nothing
198+
tuple((toparam(variable(, i)) for i in eachindex(p))...)
199+
else
200+
varnames_length_check(p, names)
201+
tuple((toparam(variable(names[i])) for i in eachindex(p))...)
202+
end
150203
end
151204

152-
function define_params(p::NamedTuple)
153-
NamedTuple(x => toparam(variable(x)) for x in keys(p))
205+
function define_params(p::NamedTuple, names = nothing)
206+
if names === nothing
207+
NamedTuple(x => toparam(variable(x)) for x in keys(p))
208+
else
209+
varnames_length_check(p, names)
210+
NamedTuple(x => toparam(variable(names[x])) for x in keys(p))
211+
end
154212
end
155213

156-
function define_params(p::MTKParameters)
157-
bufs = (p...,)
158-
i = 1
159-
ps = []
160-
for buf in bufs
161-
for _ in buf
162-
push!(ps, toparam(variable(, i)))
214+
function define_params(p::MTKParameters, names = nothing)
215+
if names === nothing
216+
bufs = (p...,)
217+
i = 1
218+
ps = []
219+
for buf in bufs
220+
for _ in buf
221+
push!(
222+
ps,
223+
if names === nothing
224+
toparam(variable(, i))
225+
else
226+
toparam(variable(names[i]))
227+
end
228+
)
229+
end
163230
end
231+
return identity.(ps)
232+
else
233+
return collect(values(names))
164234
end
165-
return identity.(ps)
166235
end
167236

168237
"""

src/systems/nonlinear/modelingtoolkitize.jl

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,34 @@ $(TYPEDSIGNATURES)
33
44
Generate `NonlinearSystem`, dependent variables, and parameters from an `NonlinearProblem`.
55
"""
6-
function modelingtoolkitize(prob::NonlinearProblem; kwargs...)
6+
function modelingtoolkitize(
7+
prob::NonlinearProblem; u_names = nothing, p_names = nothing, kwargs...)
78
p = prob.p
89
has_p = !(p isa Union{DiffEqBase.NullParameters, Nothing})
910

10-
_vars = reshape([variable(:x, i) for i in eachindex(prob.u0)], size(prob.u0))
11+
if u_names !== nothing
12+
varnames_length_check(prob.u0, u_names; is_unknowns = true)
13+
_vars = [variable(name) for name in u_names]
14+
elseif SciMLBase.has_sys(prob.f)
15+
varnames = getname.(variable_symbols(prob.f.sys))
16+
varidxs = variable_index.((prob.f.sys,), varnames)
17+
invpermute!(varnames, varidxs)
18+
_vars = [variable(name) for name in varnames]
19+
else
20+
_vars = [variable(:x, i) for i in eachindex(prob.u0)]
21+
end
22+
_vars = reshape(_vars, size(prob.u0))
1123

1224
vars = prob.u0 isa Number ? _vars : ArrayInterface.restructure(prob.u0, _vars)
1325
params = if has_p
14-
_params = define_params(p)
26+
if p_names === nothing && SciMLBase.has_sys(prob.f)
27+
p_names = Dict(parameter_index(prob.f.sys, sym) => sym
28+
for sym in parameter_symbols(prob.f.sys))
29+
end
30+
_params = define_params(p, p_names)
1531
p isa Number ? _params[1] :
16-
(p isa Tuple || p isa NamedTuple ? _params :
32+
(p isa Tuple || p isa NamedTuple || p isa AbstractDict || p isa MTKParameters ?
33+
_params :
1734
ArrayInterface.restructure(p, _params))
1835
else
1936
[]
@@ -29,14 +46,25 @@ function modelingtoolkitize(prob::NonlinearProblem; kwargs...)
2946
eqs = vcat([0.0 ~ rhs[i] for i in 1:length(out_def)]...)
3047

3148
sts = vec(collect(vars))
32-
49+
_params = params
50+
params = values(params)
3351
params = if params isa Number || (params isa Array && ndims(params) == 0)
3452
[params[1]]
3553
else
3654
vec(collect(params))
3755
end
3856
default_u0 = Dict(sts .=> vec(collect(prob.u0)))
39-
default_p = has_p ? Dict(params .=> vec(collect(prob.p))) : Dict()
57+
default_p = if has_p
58+
if prob.p isa AbstractDict
59+
Dict(v => prob.p[k] for (k, v) in pairs(_params))
60+
elseif prob.p isa MTKParameters
61+
Dict(params .=> reduce(vcat, prob.p))
62+
else
63+
Dict(params .=> vec(collect(prob.p)))
64+
end
65+
else
66+
Dict()
67+
end
4068

4169
de = NonlinearSystem(eqs, sts, params,
4270
defaults = merge(default_u0, default_p);

src/systems/optimization/modelingtoolkitize.jl

Lines changed: 78 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,25 +3,70 @@ $(TYPEDSIGNATURES)
33
44
Generate `OptimizationSystem`, dependent variables, and parameters from an `OptimizationProblem`.
55
"""
6-
function modelingtoolkitize(prob::DiffEqBase.OptimizationProblem; kwargs...)
6+
function modelingtoolkitize(prob::DiffEqBase.OptimizationProblem;
7+
u_names = nothing, p_names = nothing, kwargs...)
78
num_cons = isnothing(prob.lcons) ? 0 : length(prob.lcons)
89
if prob.p isa Tuple || prob.p isa NamedTuple
910
p = [x for x in prob.p]
1011
else
1112
p = prob.p
1213
end
13-
14-
vars = ArrayInterface.restructure(prob.u0,
15-
[variable(:x, i) for i in eachindex(prob.u0)])
16-
params = if p isa DiffEqBase.NullParameters
17-
[]
18-
elseif p isa MTKParameters
19-
[variable(, i) for i in eachindex(vcat(p...))]
14+
has_p = !(p isa Union{DiffEqBase.NullParameters, Nothing})
15+
if u_names !== nothing
16+
varnames_length_check(prob.u0, u_names; is_unknowns = true)
17+
_vars = [variable(name) for name in u_names]
18+
elseif SciMLBase.has_sys(prob.f)
19+
varnames = getname.(variable_symbols(prob.f.sys))
20+
varidxs = variable_index.((prob.f.sys,), varnames)
21+
invpermute!(varnames, varidxs)
22+
_vars = [variable(name) for name in varnames]
23+
if prob.f.sys isa OptimizationSystem
24+
for (i, sym) in enumerate(variable_symbols(prob.f.sys))
25+
if hasbounds(sym)
26+
_vars[i] = Symbolics.setmetadata(
27+
_vars[i], VariableBounds, getbounds(sym))
28+
end
29+
end
30+
end
31+
else
32+
_vars = [variable(:x, i) for i in eachindex(prob.u0)]
33+
end
34+
_vars = reshape(_vars, size(prob.u0))
35+
vars = ArrayInterface.restructure(prob.u0, _vars)
36+
params = if has_p
37+
if p_names === nothing && SciMLBase.has_sys(prob.f)
38+
p_names = Dict(parameter_index(prob.f.sys, sym) => sym
39+
for sym in parameter_symbols(prob.f.sys))
40+
end
41+
if p isa MTKParameters
42+
old_to_new = Dict()
43+
for sym in parameter_symbols(prob)
44+
idx = parameter_index(prob, sym)
45+
old_to_new[unwrap(sym)] = unwrap(p_names[idx])
46+
end
47+
order = reorder_parameters(prob.f.sys, full_parameters(prob.f.sys))
48+
for arr in order
49+
for i in eachindex(arr)
50+
arr[i] = old_to_new[arr[i]]
51+
end
52+
end
53+
_params = order
54+
else
55+
_params = define_params(p, p_names)
56+
end
57+
p isa Number ? _params[1] :
58+
(p isa Tuple || p isa NamedTuple || p isa AbstractDict || p isa MTKParameters ?
59+
_params :
60+
ArrayInterface.restructure(p, _params))
2061
else
21-
ArrayInterface.restructure(p, [variable(, i) for i in eachindex(p)])
62+
[]
2263
end
2364

24-
eqs = prob.f(vars, params)
65+
if p isa MTKParameters
66+
eqs = prob.f(vars, params...)
67+
else
68+
eqs = prob.f(vars, params)
69+
end
2570

2671
if DiffEqBase.isinplace(prob) && !isnothing(prob.f.cons)
2772
lhs = Array{Num}(undef, num_cons)
@@ -58,10 +103,32 @@ function modelingtoolkitize(prob::DiffEqBase.OptimizationProblem; kwargs...)
58103
else
59104
cons = []
60105
end
106+
params = values(params)
107+
params = if params isa Number || (params isa Array && ndims(params) == 0)
108+
[params[1]]
109+
elseif p isa MTKParameters
110+
reduce(vcat, params)
111+
else
112+
vec(collect(params))
113+
end
61114

62-
de = OptimizationSystem(eqs, vec(vars), vec(toparam.(params));
115+
sts = vec(collect(vars))
116+
default_u0 = Dict(sts .=> vec(collect(prob.u0)))
117+
default_p = if has_p
118+
if prob.p isa AbstractDict
119+
Dict(v => prob.p[k] for (k, v) in pairs(_params))
120+
elseif prob.p isa MTKParameters
121+
Dict(params .=> reduce(vcat, prob.p))
122+
else
123+
Dict(params .=> vec(collect(prob.p)))
124+
end
125+
else
126+
Dict()
127+
end
128+
de = OptimizationSystem(eqs, sts, params;
63129
name = gensym(:MTKizedOpt),
64130
constraints = cons,
131+
defaults = merge(default_u0, default_p),
65132
kwargs...)
66133
de
67134
end

0 commit comments

Comments
 (0)