Skip to content

Commit 150a01a

Browse files
Merge pull request #2547 from AayushSabharwal/as/fix-io
fix: fix io handling in structural_simplify, input_output_handling tests
2 parents f42a6f3 + 1d3cc2b commit 150a01a

File tree

6 files changed

+13
-10
lines changed

6 files changed

+13
-10
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -870,12 +870,15 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
870870
if eltype(u0map) <: Number
871871
u0map = unknowns(sys) .=> u0map
872872
end
873+
if isempty(u0map)
874+
u0map = Dict()
875+
end
873876
initializeprob = ModelingToolkit.InitializationProblem(
874877
sys, t, u0map, parammap; guesses, warn_initialize_determined)
875878
initializeprobmap = getu(initializeprob, unknowns(sys))
876879

877-
zerovars = setdiff(unknowns(sys), keys(defaults(sys))) .=> 0.0
878-
trueinit = identity.([zerovars; u0map])
880+
zerovars = Dict(setdiff(unknowns(sys), keys(defaults(sys))) .=> 0.0)
881+
trueinit = collect(merge(zerovars, eltype(u0map) <: Pair ? todict(u0map) : u0map))
879882
u0map isa StaticArraysCore.StaticArray &&
880883
(trueinit = SVector{length(trueinit)}(trueinit))
881884
else
@@ -913,7 +916,6 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
913916
du0 = nothing
914917
ddvs = nothing
915918
end
916-
917919
check_eqs_u0(eqs, dvs, u0; kwargs...)
918920

919921
f = constructor(sys, dvs, ps, u0; ddvs = ddvs, tgrad = tgrad, jac = jac,

src/systems/systemstructure.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -630,7 +630,7 @@ function _structural_simplify!(state::TearingState, io; simplify = false,
630630
if has_io
631631
ModelingToolkit.markio!(state, orig_inputs, io...)
632632
end
633-
if io !== nothing || any(isinput, state.fullvars)
633+
if io !== nothing
634634
state, input_idxs = ModelingToolkit.inputs_to_parameters!(state, io)
635635
else
636636
input_idxs = 0:-1 # Empty range

test/input_output_handling.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ end
5050
@test !is_bound(sys31, sys1.v[2])
5151

5252
# simplification turns input variables into parameters
53-
ssys = structural_simplify(sys)
53+
ssys, _ = structural_simplify(sys, ([u], []))
5454
@test ModelingToolkit.isparameter(unbound_inputs(ssys)[])
5555
@test !is_bound(ssys, u)
5656
@test u Set(unbound_inputs(ssys))
@@ -236,7 +236,7 @@ i = findfirst(isequal(u[1]), out)
236236
@variables x(t) u(t) [input = true]
237237
eqs = [D(x) ~ u]
238238
@named sys = ODESystem(eqs, t)
239-
@test_nowarn structural_simplify(sys)
239+
@test_nowarn structural_simplify(sys, ([u], []))
240240

241241
#=
242242
## Disturbance input handling

test/reduction.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ eqs = [D(x) ~ σ * (y - x)
234234
u ~ z + a]
235235

236236
lorenz1 = ODESystem(eqs, t, name = :lorenz1)
237-
lorenz1_reduced = structural_simplify(lorenz1)
237+
lorenz1_reduced, _ = structural_simplify(lorenz1, ([z], []))
238238
@test z in Set(parameters(lorenz1_reduced))
239239

240240
# #2064

test/serialization.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,4 +65,4 @@ probexpr = ODEProblemExpr{true}(ss, [capacitor.v => 0.0], (0, 0.1); observedfun_
6565
prob_obs = eval(probexpr)
6666
sol_obs = solve(prob_obs, ImplicitEuler())
6767
@show all_obs
68-
@test_broken sol_obs[all_obs] == sol[all_obs]
68+
@test sol_obs[all_obs] == sol[all_obs]

test/structural_transformation/tearing.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,9 +164,10 @@ prob.f(du, u, pr, tt)
164164
@test du[u[2], u[1] + sin(u[2]) - pr * tt] atol=1e-5
165165

166166
# test the initial guess is respected
167-
@named sys = ODESystem(eqs, t, defaults = Dict(z => Inf))
167+
@named sys = ODESystem(eqs, t, defaults = Dict(z => NaN))
168168
infprob = ODEProblem(structural_simplify(sys), [x => 1.0], (0, 1.0), [p => 0.2])
169-
@test_throws Any infprob.f(du, infprob.u0, pr, tt)
169+
infprob.f(du, infprob.u0, pr, tt)
170+
@test any(isnan, du)
170171

171172
sol1 = solve(prob, RosShamp4(), reltol = 8e-7)
172173
sol2 = solve(ODEProblem{false}((u, p, t) -> [-asin(u[1] - pr * t)],

0 commit comments

Comments
 (0)