Skip to content

Commit decfb51

Browse files
feat: flatten equations to avoid scalarizing array arguments
1 parent 6fbd195 commit decfb51

File tree

6 files changed

+69
-7
lines changed

6 files changed

+69
-7
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -807,6 +807,17 @@ function get_u0(sys, u0map, parammap = nothing; symbolic_u0 = false)
807807
defs = mergedefaults(defs, parammap, ps)
808808
end
809809
defs = mergedefaults(defs, u0map, dvs)
810+
for (k, v) in defs
811+
if Symbolics.isarraysymbolic(k)
812+
ks = scalarize(k)
813+
length(ks) == length(v) || error("$k has default value $v with unmatched size")
814+
for (kk, vv) in zip(ks, v)
815+
if !haskey(defs, kk)
816+
defs[kk] = vv
817+
end
818+
end
819+
end
820+
end
810821

811822
if symbolic_u0
812823
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = false, use_union = false)
@@ -1415,3 +1426,19 @@ function isisomorphic(sys1::AbstractODESystem, sys2::AbstractODESystem)
14151426
end
14161427
return false
14171428
end
1429+
1430+
function flatten_equations(eqs)
1431+
mapreduce(vcat, eqs; init = Equation[]) do eq
1432+
islhsarr = eq.lhs isa AbstractArray || Symbolics.isarraysymbolic(eq.lhs)
1433+
isrhsarr = eq.rhs isa AbstractArray || Symbolics.isarraysymbolic(eq.rhs)
1434+
if islhsarr || isrhsarr
1435+
islhsarr && isrhsarr ||
1436+
error("LHS ($(eq.lhs)) and RHS ($(eq.rhs)) must either both be array expressions or both scalar")
1437+
size(eq.lhs) == size(eq.rhs) ||
1438+
error("Size of LHS ($(eq.lhs)) and RHS ($(eq.rhs)) must match: got $(size(eq.lhs)) and $(size(eq.rhs))")
1439+
return collect(eq.lhs) .~ collect(eq.rhs)
1440+
else
1441+
eq
1442+
end
1443+
end
1444+
end

src/systems/diffeqs/odesystem.jl

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,19 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
202202
name === nothing &&
203203
throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
204204
@assert all(control -> any(isequal.(control, ps)), controls) "All controls must also be parameters."
205-
deqs = reduce(vcat, scalarize(deqs); init = Equation[])
205+
deqs = mapreduce(vcat, deqs; init = Equation[]) do eq
206+
islhsarr = eq.lhs isa AbstractArray || Symbolics.isarraysymbolic(eq.lhs)
207+
isrhsarr = eq.rhs isa AbstractArray || Symbolics.isarraysymbolic(eq.rhs)
208+
if islhsarr || isrhsarr
209+
islhsarr && isrhsarr ||
210+
error("LHS ($(eq.lhs)) and RHS ($(eq.rhs)) must either both be array expressions or both scalar")
211+
size(eq.lhs) == size(eq.rhs) ||
212+
error("Size of LHS ($(eq.lhs)) and RHS ($(eq.rhs)) must match: got $(size(eq.lhs)) and $(size(eq.rhs))")
213+
return collect(eq.lhs) .~ collect(eq.rhs)
214+
else
215+
eq
216+
end
217+
end
206218
iv′ = value(iv)
207219
ps′ = value.(ps)
208220
ctrl′ = value.(controls)
@@ -284,7 +296,8 @@ function ODESystem(eqs, iv; kwargs...)
284296
for p in ps
285297
if istree(p) && operation(p) === getindex
286298
par = arguments(p)[begin]
287-
if Symbolics.shape(Symbolics.unwrap(par)) !== Symbolics.Unknown() && all(par[i] in ps for i in eachindex(par))
299+
if Symbolics.shape(Symbolics.unwrap(par)) !== Symbolics.Unknown() &&
300+
all(par[i] in ps for i in eachindex(par))
288301
push!(new_ps, par)
289302
else
290303
push!(new_ps, p)

src/systems/diffeqs/sdesystem.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,14 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv
171171
gui_metadata = nothing)
172172
name === nothing &&
173173
throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
174-
deqs = scalarize(deqs)
174+
deqs = flatten_equations(deqs)
175+
neqs = mapreduce(vcat, neqs) do expr
176+
if expr isa AbstractArray || Symbolics.isarraysymbolic(expr)
177+
collect(expr)
178+
else
179+
expr
180+
end
181+
end
175182
iv′ = value(iv)
176183
dvs′ = value.(dvs)
177184
ps′ = value.(ps)

src/systems/jumps/jumpsystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ function JumpSystem(eqs, iv, unknowns, ps;
151151
kwargs...)
152152
name === nothing &&
153153
throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
154-
eqs = scalarize(eqs)
154+
eqs = flatten_equations(eqs)
155155
sysnames = nameof.(systems)
156156
if length(unique(sysnames)) != length(sysnames)
157157
throw(ArgumentError("System names must be unique."))

src/systems/parameter_buffer.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ function MTKParameters(sys::AbstractSystem, p; tofloat = false, use_union = fals
3737
end
3838

3939
for (sym, _) in p
40-
if istree(sym) && operation(sym) === getindex && is_parameter(sys, arguments(sym)[begin])
40+
if istree(sym) && operation(sym) === getindex &&
41+
is_parameter(sys, arguments(sym)[begin])
4142
# error("Scalarized parameter values are not supported. Instead of `[p[1] => 1.0, p[2] => 2.0]` use `[p => [1.0, 2.0]]`")
4243
end
4344
end

test/odesystem.jl

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -520,15 +520,29 @@ using SymbolicUtils.Code
520520
using Symbolics: unwrap, wrap, @register_symbolic
521521
foo(a, ms::AbstractVector) = a + sum(ms)
522522
@register_symbolic foo(a, ms::AbstractVector)
523-
@variables t x(t) ms(t)[1:3]
524-
D = Differential(t)
523+
@variables x(t) ms(t)[1:3]
525524
eqs = [D(x) ~ foo(x, ms); D(ms) ~ ones(3)]
526525
@named sys = ODESystem(eqs, t, [x; ms], [])
527526
@named emptysys = ODESystem(Equation[], t)
528527
@mtkbuild outersys = compose(emptysys, sys)
529528
prob = ODEProblem(outersys, [sys.x => 1.0, sys.ms => 1:3], (0, 1.0))
530529
@test_nowarn solve(prob, Tsit5())
531530

531+
# array equations
532+
bar(x, p) = p * x
533+
@register_array_symbolic bar(x::AbstractVector, p::AbstractMatrix) begin
534+
size = size(x)
535+
eltype = promote_type(eltype(x), eltype(p))
536+
end
537+
@parameters p[1:3, 1:3]
538+
eqs = [D(x) ~ foo(x, ms); D(ms) ~ bar(ms, p)]
539+
@named sys = ODESystem(eqs, t)
540+
@named emptysys = ODESystem(Equation[], t)
541+
@mtkbuild outersys = compose(emptysys, sys)
542+
prob = ODEProblem(
543+
outersys, [sys.x => 1.0, sys.ms => 1:3], (0.0, 1.0), [sys.p => ones(3, 3)])
544+
@test_nowarn solve(prob, Tsit5())
545+
532546
# x/x
533547
@variables x(t)
534548
@named sys = ODESystem([D(x) ~ x / x], t)

0 commit comments

Comments
 (0)