Skip to content

Commit c6c96dd

Browse files
fix: do not scalarize in system constructors
1 parent a41a64f commit c6c96dd

File tree

4 files changed

+4
-23
lines changed

4 files changed

+4
-23
lines changed

src/systems/diffeqs/odesystem.jl

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -202,23 +202,10 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
202202
name === nothing &&
203203
throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
204204
@assert all(control -> any(isequal.(control, ps)), controls) "All controls must also be parameters."
205-
deqs = mapreduce(vcat, deqs; init = Equation[]) do eq
206-
islhsarr = eq.lhs isa AbstractArray || Symbolics.isarraysymbolic(eq.lhs)
207-
isrhsarr = eq.rhs isa AbstractArray || Symbolics.isarraysymbolic(eq.rhs)
208-
if islhsarr || isrhsarr
209-
islhsarr && isrhsarr ||
210-
error("LHS ($(eq.lhs)) and RHS ($(eq.rhs)) must either both be array expressions or both scalar")
211-
size(eq.lhs) == size(eq.rhs) ||
212-
error("Size of LHS ($(eq.lhs)) and RHS ($(eq.rhs)) must match: got $(size(eq.lhs)) and $(size(eq.rhs))")
213-
return collect(eq.lhs) .~ collect(eq.rhs)
214-
else
215-
eq
216-
end
217-
end
218205
iv′ = value(iv)
219206
ps′ = value.(ps)
220207
ctrl′ = value.(controls)
221-
dvs′ = value.(symbolic_type(dvs) === NotSymbolic() ? dvs : [dvs])
208+
dvs′ = value.(dvs)
222209
dvs′ = filter(x -> !isdelay(x, iv), dvs′)
223210
if !(isempty(default_u0) && isempty(default_p))
224211
Base.depwarn(

src/systems/diffeqs/sdesystem.jl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -171,14 +171,6 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv
171171
gui_metadata = nothing)
172172
name === nothing &&
173173
throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
174-
deqs = flatten_equations(deqs)
175-
neqs = mapreduce(vcat, neqs) do expr
176-
if expr isa AbstractArray || Symbolics.isarraysymbolic(expr)
177-
collect(expr)
178-
else
179-
expr
180-
end
181-
end
182174
iv′ = value(iv)
183175
dvs′ = value.(dvs)
184176
ps′ = value.(ps)

src/systems/jumps/jumpsystem.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ function JumpSystem(eqs, iv, unknowns, ps;
151151
kwargs...)
152152
name === nothing &&
153153
throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
154+
eqs = scalarize.(eqs)
154155
sysnames = nameof.(systems)
155156
if length(unique(sysnames)) != length(sysnames)
156157
throw(ArgumentError("System names must be unique."))

src/systems/systemstructure.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,8 @@ function TearingState(sys; quick_cancel = false, check = true)
251251
sys = flatten(sys)
252252
ivs = independent_variables(sys)
253253
iv = length(ivs) == 1 ? ivs[1] : nothing
254-
eqs = copy(equations(sys))
254+
# scalarize array equations, without scalarizing arguments to registered functions
255+
eqs = flatten_equations(copy(equations(sys)))
255256
neqs = length(eqs)
256257
dervaridxs = OrderedSet{Int}()
257258
var2idx = Dict{Any, Int}()

0 commit comments

Comments
 (0)