Skip to content

Commit ac4b237

Browse files
feat: allow specifying variable names in modelingtoolkitize
1 parent 4e91484 commit ac4b237

File tree

3 files changed

+187
-46
lines changed

3 files changed

+187
-46
lines changed

src/systems/diffeqs/modelingtoolkitize.jl

Lines changed: 97 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -3,29 +3,44 @@ $(TYPEDSIGNATURES)
33
44
Generate `ODESystem`, dependent variables, and parameters from an `ODEProblem`.
55
"""
6-
function modelingtoolkitize(prob::DiffEqBase.ODEProblem; kwargs...)
6+
function modelingtoolkitize(
7+
prob::DiffEqBase.ODEProblem; u_names = nothing, p_names = nothing, kwargs...)
78
prob.f isa DiffEqBase.AbstractParameterizedFunction &&
89
return prob.f.sys
9-
@parameters t
10-
10+
t = t_nounits
1111
p = prob.p
1212
has_p = !(p isa Union{DiffEqBase.NullParameters, Nothing})
1313

14-
_vars = define_vars(prob.u0, t)
14+
if u_names !== nothing
15+
varnames_length_check(prob.u0, u_names; is_unknowns = true)
16+
_vars = [_defvar(name)(t) for name in u_names]
17+
elseif SciMLBase.has_sys(prob.f)
18+
varnames = getname.(variable_symbols(prob.f.sys))
19+
varidxs = variable_index.((prob.f.sys,), varnames)
20+
invpermute!(varnames, varidxs)
21+
_vars = [_defvar(name)(t) for name in varnames]
22+
else
23+
_vars = define_vars(prob.u0, t)
24+
end
1525

1626
vars = prob.u0 isa Number ? _vars : ArrayInterface.restructure(prob.u0, _vars)
1727
params = if has_p
18-
_params = define_params(p)
28+
if p_names === nothing && SciMLBase.has_sys(prob.f)
29+
p_names = Dict(parameter_index(prob.f.sys, sym) => sym
30+
for sym in parameter_symbols(prob.f.sys))
31+
end
32+
_params = define_params(p, p_names)
1933
p isa Number ? _params[1] :
20-
(p isa Tuple || p isa NamedTuple || p isa AbstractDict ? _params :
34+
(p isa Tuple || p isa NamedTuple || p isa AbstractDict || p isa MTKParameters ?
35+
_params :
2136
ArrayInterface.restructure(p, _params))
2237
else
2338
[]
2439
end
2540

2641
var_set = Set(vars)
2742

28-
D = Differential(t)
43+
D = D_nounits
2944
mm = prob.f.mass_matrix
3045

3146
if mm === I
@@ -70,6 +85,8 @@ function modelingtoolkitize(prob::DiffEqBase.ODEProblem; kwargs...)
7085
default_p = if has_p
7186
if prob.p isa AbstractDict
7287
Dict(v => prob.p[k] for (k, v) in pairs(_params))
88+
elseif prob.p isa MTKParameters
89+
Dict(params .=> reduce(vcat, prob.p))
7390
else
7491
Dict(params .=> vec(collect(prob.p)))
7592
end
@@ -125,44 +142,96 @@ function Base.showerror(io::IO, e::ModelingtoolkitizeParametersNotSupportedError
125142
println(io, e.type)
126143
end
127144

128-
function define_params(p)
145+
function varnames_length_check(vars, names; is_unknowns = false)
146+
if length(names) != length(vars)
147+
throw(ArgumentError("""
148+
Number of $(is_unknowns ? "unknowns" : "parameters") ($(length(vars))) \
149+
does not match number of names ($(length(names))).
150+
"""))
151+
end
152+
end
153+
154+
function define_params(p, _ = nothing)
129155
throw(ModelingtoolkitizeParametersNotSupportedError(typeof(p)))
130156
end
131157

132-
function define_params(p::AbstractArray)
133-
[toparam(variable(, i)) for i in eachindex(p)]
158+
function define_params(p::AbstractArray, names = nothing)
159+
if names === nothing
160+
[toparam(variable(, i)) for i in eachindex(p)]
161+
else
162+
varnames_length_check(p, names)
163+
[toparam(variable(names[i])) for i in eachindex(p)]
164+
end
134165
end
135166

136-
function define_params(p::Number)
137-
[toparam(variable())]
167+
function define_params(p::Number, names = nothing)
168+
if names === nothing
169+
[toparam(variable())]
170+
elseif names isa Union{AbstractArray, AbstractDict}
171+
varnames_length_check(p, names)
172+
[toparam(variable(names[i])) for i in eachindex(p)]
173+
else
174+
[toparam(variable(names))]
175+
end
138176
end
139177

140-
function define_params(p::AbstractDict)
141-
OrderedDict(k => toparam(variable(, i)) for (i, k) in zip(1:length(p), keys(p)))
178+
function define_params(p::AbstractDict, names = nothing)
179+
if names === nothing
180+
OrderedDict(k => toparam(variable(, i)) for (i, k) in zip(1:length(p), keys(p)))
181+
else
182+
varnames_length_check(p, names)
183+
OrderedDict(k => toparam(variable(names[k])) for k in keys(p))
184+
end
142185
end
143186

144-
function define_params(p::Union{SLArray, LArray})
145-
[toparam(variable(x)) for x in LabelledArrays.symnames(typeof(p))]
187+
function define_params(p::Union{SLArray, LArray}, names = nothing)
188+
if names === nothing
189+
[toparam(variable(x)) for x in LabelledArrays.symnames(typeof(p))]
190+
else
191+
varnames_length_check(p, names)
192+
[toparam(variable(names[i])) for i in eachindex(p)]
193+
end
146194
end
147195

148-
function define_params(p::Tuple)
149-
tuple((toparam(variable(, i)) for i in eachindex(p))...)
196+
function define_params(p::Tuple, names = nothing)
197+
if names === nothing
198+
tuple((toparam(variable(, i)) for i in eachindex(p))...)
199+
else
200+
varnames_length_check(p, names)
201+
tuple((toparam(variable(names[i])) for i in eachindex(p))...)
202+
end
150203
end
151204

152-
function define_params(p::NamedTuple)
153-
NamedTuple(x => toparam(variable(x)) for x in keys(p))
205+
function define_params(p::NamedTuple, names = nothing)
206+
if names === nothing
207+
NamedTuple(x => toparam(variable(x)) for x in keys(p))
208+
else
209+
varnames_length_check(p, names)
210+
NamedTuple(x => toparam(variable(names[x])) for x in keys(p))
211+
end
154212
end
155213

156-
function define_params(p::MTKParameters)
157-
bufs = (p...,)
158-
i = 1
159-
ps = []
160-
for buf in bufs
161-
for _ in buf
162-
push!(ps, toparam(variable(, i)))
214+
function define_params(p::MTKParameters, names = nothing)
215+
if names === nothing
216+
bufs = (p...,)
217+
i = 1
218+
ps = []
219+
for buf in bufs
220+
for _ in buf
221+
push!(
222+
ps,
223+
if names === nothing
224+
toparam(variable(, i))
225+
else
226+
toparam(variable(names[i]))
227+
end
228+
)
229+
end
163230
end
231+
return identity.(ps)
232+
else
233+
return collect(values(names))
164234
end
165-
return identity.(ps)
166235
end
167236

168237
"""

src/systems/nonlinear/modelingtoolkitize.jl

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,34 @@ $(TYPEDSIGNATURES)
33
44
Generate `NonlinearSystem`, dependent variables, and parameters from an `NonlinearProblem`.
55
"""
6-
function modelingtoolkitize(prob::NonlinearProblem; kwargs...)
6+
function modelingtoolkitize(
7+
prob::NonlinearProblem; u_names = nothing, p_names = nothing, kwargs...)
78
p = prob.p
89
has_p = !(p isa Union{DiffEqBase.NullParameters, Nothing})
910

10-
_vars = reshape([variable(:x, i) for i in eachindex(prob.u0)], size(prob.u0))
11+
if u_names !== nothing
12+
varnames_length_check(prob.u0, u_names; is_unknowns = true)
13+
_vars = [variable(name) for name in u_names]
14+
elseif SciMLBase.has_sys(prob.f)
15+
varnames = getname.(variable_symbols(prob.f.sys))
16+
varidxs = variable_index.((prob.f.sys,), varnames)
17+
invpermute!(varnames, varidxs)
18+
_vars = [variable(name) for name in varnames]
19+
else
20+
_vars = [variable(:x, i) for i in eachindex(prob.u0)]
21+
end
22+
_vars = reshape(_vars, size(prob.u0))
1123

1224
vars = prob.u0 isa Number ? _vars : ArrayInterface.restructure(prob.u0, _vars)
1325
params = if has_p
14-
_params = define_params(p)
26+
if p_names === nothing && SciMLBase.has_sys(prob.f)
27+
p_names = Dict(parameter_index(prob.f.sys, sym) => sym
28+
for sym in parameter_symbols(prob.f.sys))
29+
end
30+
_params = define_params(p, p_names)
1531
p isa Number ? _params[1] :
16-
(p isa Tuple || p isa NamedTuple ? _params :
32+
(p isa Tuple || p isa NamedTuple || p isa AbstractDict || p isa MTKParameters ?
33+
_params :
1734
ArrayInterface.restructure(p, _params))
1835
else
1936
[]
@@ -29,14 +46,25 @@ function modelingtoolkitize(prob::NonlinearProblem; kwargs...)
2946
eqs = vcat([0.0 ~ rhs[i] for i in 1:length(out_def)]...)
3047

3148
sts = vec(collect(vars))
32-
49+
_params = params
50+
params = values(params)
3351
params = if params isa Number || (params isa Array && ndims(params) == 0)
3452
[params[1]]
3553
else
3654
vec(collect(params))
3755
end
3856
default_u0 = Dict(sts .=> vec(collect(prob.u0)))
39-
default_p = has_p ? Dict(params .=> vec(collect(prob.p))) : Dict()
57+
default_p = if has_p
58+
if prob.p isa AbstractDict
59+
Dict(v => prob.p[k] for (k, v) in pairs(_params))
60+
elseif prob.p isa MTKParameters
61+
Dict(params .=> reduce(vcat, prob.p))
62+
else
63+
Dict(params .=> vec(collect(prob.p)))
64+
end
65+
else
66+
Dict()
67+
end
4068

4169
de = NonlinearSystem(eqs, sts, params,
4270
defaults = merge(default_u0, default_p);

src/systems/optimization/modelingtoolkitize.jl

Lines changed: 56 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,25 +3,62 @@ $(TYPEDSIGNATURES)
33
44
Generate `OptimizationSystem`, dependent variables, and parameters from an `OptimizationProblem`.
55
"""
6-
function modelingtoolkitize(prob::DiffEqBase.OptimizationProblem; kwargs...)
6+
function modelingtoolkitize(prob::DiffEqBase.OptimizationProblem;
7+
u_names = nothing, p_names = nothing, kwargs...)
78
num_cons = isnothing(prob.lcons) ? 0 : length(prob.lcons)
89
if prob.p isa Tuple || prob.p isa NamedTuple
910
p = [x for x in prob.p]
1011
else
1112
p = prob.p
1213
end
13-
14-
vars = ArrayInterface.restructure(prob.u0,
15-
[variable(:x, i) for i in eachindex(prob.u0)])
16-
params = if p isa DiffEqBase.NullParameters
17-
[]
18-
elseif p isa MTKParameters
19-
[variable(, i) for i in eachindex(vcat(p...))]
14+
has_p = !(p isa Union{DiffEqBase.NullParameters, Nothing})
15+
if u_names !== nothing
16+
varnames_length_check(prob.u0, u_names; is_unknowns = true)
17+
_vars = [variable(name) for name in u_names]
18+
elseif SciMLBase.has_sys(prob.f)
19+
varnames = getname.(variable_symbols(prob.f.sys))
20+
varidxs = variable_index.((prob.f.sys,), varnames)
21+
invpermute!(varnames, varidxs)
22+
_vars = [variable(name) for name in varnames]
2023
else
21-
ArrayInterface.restructure(p, [variable(, i) for i in eachindex(p)])
24+
_vars = [variable(:x, i) for i in eachindex(prob.u0)]
25+
end
26+
_vars = reshape(_vars, size(prob.u0))
27+
vars = ArrayInterface.restructure(prob.u0, _vars)
28+
params = if has_p
29+
if p_names === nothing && SciMLBase.has_sys(prob.f)
30+
p_names = Dict(parameter_index(prob.f.sys, sym) => sym
31+
for sym in parameter_symbols(prob.f.sys))
32+
end
33+
if p isa MTKParameters
34+
old_to_new = Dict()
35+
for sym in parameter_symbols(prob)
36+
idx = parameter_index(prob, sym)
37+
old_to_new[unwrap(sym)] = unwrap(p_names[idx])
38+
end
39+
order = reorder_parameters(prob.f.sys, full_parameters(prob.f.sys))
40+
for arr in order
41+
for i in eachindex(arr)
42+
arr[i] = old_to_new[arr[i]]
43+
end
44+
end
45+
_params = order
46+
else
47+
_params = define_params(p, p_names)
48+
end
49+
p isa Number ? _params[1] :
50+
(p isa Tuple || p isa NamedTuple || p isa AbstractDict || p isa MTKParameters ?
51+
_params :
52+
ArrayInterface.restructure(p, _params))
53+
else
54+
[]
2255
end
2356

24-
eqs = prob.f(vars, params)
57+
if p isa MTKParameters
58+
eqs = prob.f(vars, params...)
59+
else
60+
eqs = prob.f(vars, params)
61+
end
2562

2663
if DiffEqBase.isinplace(prob) && !isnothing(prob.f.cons)
2764
lhs = Array{Num}(undef, num_cons)
@@ -58,8 +95,15 @@ function modelingtoolkitize(prob::DiffEqBase.OptimizationProblem; kwargs...)
5895
else
5996
cons = []
6097
end
61-
62-
de = OptimizationSystem(eqs, vec(vars), vec(toparam.(params));
98+
params = values(params)
99+
params = if params isa Number || (params isa Array && ndims(params) == 0)
100+
[params[1]]
101+
elseif p isa MTKParameters
102+
reduce(vcat, params)
103+
else
104+
vec(collect(params))
105+
end
106+
de = OptimizationSystem(eqs, vec(vars), params;
63107
name = gensym(:MTKizedOpt),
64108
constraints = cons,
65109
kwargs...)

0 commit comments

Comments
 (0)