Skip to content

Commit 1a83c92

Browse files
fixup! feat: allow specifying variable names in modelingtoolkitize
1 parent 6852141 commit 1a83c92

File tree

1 file changed

+23
-1
lines changed

1 file changed

+23
-1
lines changed

src/systems/optimization/modelingtoolkitize.jl

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,13 @@ function modelingtoolkitize(prob::DiffEqBase.OptimizationProblem;
2020
varidxs = variable_index.((prob.f.sys,), varnames)
2121
invpermute!(varnames, varidxs)
2222
_vars = [variable(name) for name in varnames]
23+
if prob.f.sys isa OptimizationSystem
24+
for (i, sym) in enumerate(variable_symbols(prob.f.sys))
25+
if hasbounds(sym)
26+
_vars[i] = Symbolics.setmetadata(_vars[i], VariableBounds, getbounds(sym))
27+
end
28+
end
29+
end
2330
else
2431
_vars = [variable(:x, i) for i in eachindex(prob.u0)]
2532
end
@@ -103,9 +110,24 @@ function modelingtoolkitize(prob::DiffEqBase.OptimizationProblem;
103110
else
104111
vec(collect(params))
105112
end
106-
de = OptimizationSystem(eqs, vec(vars), params;
113+
114+
sts = vec(collect(vars))
115+
default_u0 = Dict(sts .=> vec(collect(prob.u0)))
116+
default_p = if has_p
117+
if prob.p isa AbstractDict
118+
Dict(v => prob.p[k] for (k, v) in pairs(_params))
119+
elseif prob.p isa MTKParameters
120+
Dict(params .=> reduce(vcat, prob.p))
121+
else
122+
Dict(params .=> vec(collect(prob.p)))
123+
end
124+
else
125+
Dict()
126+
end
127+
de = OptimizationSystem(eqs, sts, params;
107128
name = gensym(:MTKizedOpt),
108129
constraints = cons,
130+
defaults = merge(default_u0, default_p),
109131
kwargs...)
110132
de
111133
end

0 commit comments

Comments
 (0)