Skip to content

Commit 4aeaaea

Browse files
Merge pull request #2847 from SciML/nlls_modelingtoolkitize
Handle modelingtoolkitize for nonlinearleastsquaresproblem
2 parents 622408b + 13c427a commit 4aeaaea

File tree

2 files changed

+29
-5
lines changed

2 files changed

+29
-5
lines changed

src/systems/nonlinear/modelingtoolkitize.jl

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@ $(TYPEDSIGNATURES)
44
Generate `NonlinearSystem`, dependent variables, and parameters from an `NonlinearProblem`.
55
"""
66
function modelingtoolkitize(
7-
prob::NonlinearProblem; u_names = nothing, p_names = nothing, kwargs...)
7+
prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem};
8+
u_names = nothing, p_names = nothing, kwargs...)
89
p = prob.p
910
has_p = !(p isa Union{DiffEqBase.NullParameters, Nothing})
1011

@@ -37,13 +38,22 @@ function modelingtoolkitize(
3738
end
3839

3940
if DiffEqBase.isinplace(prob)
40-
rhs = ArrayInterface.restructure(prob.u0, similar(vars, Num))
41-
prob.f(rhs, vars, params)
41+
if prob isa NonlinearLeastSquaresProblem
42+
rhs = ArrayInterface.restructure(
43+
prob.f.resid_prototype, similar(prob.f.resid_prototype, Num))
44+
prob.f(rhs, vars, params)
45+
eqs = vcat([0.0 ~ rhs[i] for i in 1:length(prob.f.resid_prototype)]...)
46+
else
47+
rhs = ArrayInterface.restructure(prob.u0, similar(vars, Num))
48+
prob.f(rhs, vars, params)
49+
eqs = vcat([0.0 ~ rhs[i] for i in 1:length(rhs)]...)
50+
end
51+
4252
else
4353
rhs = prob.f(vars, params)
54+
out_def = prob.f(prob.u0, prob.p)
55+
eqs = vcat([0.0 ~ rhs[i] for i in 1:length(out_def)]...)
4456
end
45-
out_def = prob.f(prob.u0, prob.p)
46-
eqs = vcat([0.0 ~ rhs[i] for i in 1:length(out_def)]...)
4757

4858
sts = vec(collect(vars))
4959
_params = params

test/modelingtoolkitize.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,3 +473,17 @@ sys = modelingtoolkitize(prob)
473473
end
474474
end
475475
end
476+
477+
## NonlinearLeastSquaresProblem
478+
479+
function nlls!(du, u, p)
480+
du[1] = 2u[1] - 2
481+
du[2] = u[1] - 4u[2]
482+
du[3] = 0
483+
end
484+
u0 = [0.0, 0.0]
485+
prob = NonlinearLeastSquaresProblem(
486+
NonlinearFunction(nlls!, resid_prototype = zeros(3)), u0)
487+
sys = modelingtoolkitize(prob)
488+
@test length(equations(sys)) == 3
489+
@test length(equations(structural_simplify(sys; fully_determined = false))) == 0

0 commit comments

Comments
 (0)