Skip to content

Commit 67f1b77

Browse files
authored
Merge branch 'master' into ox/nomacrotools
2 parents a6f1f07 + 0e04e8a commit 67f1b77

File tree

10 files changed

+398
-85
lines changed

10 files changed

+398
-85
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ jobs:
3434
- uses: julia-actions/setup-julia@v1
3535
with:
3636
version: ${{ matrix.version }}
37-
- uses: actions/cache@v3
37+
- uses: actions/cache@v4
3838
env:
3939
cache-name: cache-artifacts
4040
with:

Project.toml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1919
DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf"
2020
DynamicQuantities = "06fc5a27-2a28-4c7c-a15d-362465fb6821"
2121
ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
22+
FindFirstFunctions = "64ca27bc-2ba2-4a57-88aa-44e436879224"
2223
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
2324
FunctionWrappersWrappers = "77dc65aa-8811-40c2-897b-53d922fa7daf"
2425
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
@@ -73,8 +74,9 @@ Distributed = "1"
7374
Distributions = "0.23, 0.24, 0.25"
7475
DocStringExtensions = "0.7, 0.8, 0.9"
7576
DomainSets = "0.6"
76-
DynamicQuantities = "0.8, 0.9, 0.10"
77+
DynamicQuantities = "^0.11.2"
7778
ExprTools = "0.1.10"
79+
FindFirstFunctions = "1"
7880
ForwardDiff = "0.10.3"
7981
FunctionWrappersWrappers = "0.1"
8082
Graphs = "1.5.2"
@@ -111,6 +113,7 @@ julia = "1.9"
111113
[extras]
112114
AmplNLWriter = "7c4d4715-977e-5154-bfe0-e096adeac482"
113115
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
116+
ControlSystemsBase = "aaaaaaaa-a6ca-5380-bf3e-84a91bcd477e"
114117
ControlSystemsMTK = "687d7614-c7e5-45fc-bfc3-9ee385575c88"
115118
DeepDiffs = "ab62b9b5-e342-54a8-a765-a90f495de1a6"
116119
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
@@ -134,4 +137,4 @@ Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
134137
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
135138

136139
[targets]
137-
test = ["AmplNLWriter", "BenchmarkTools", "ControlSystemsMTK", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg"]
140+
test = ["AmplNLWriter", "BenchmarkTools", "ControlSystemsBase", "ControlSystemsMTK", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg"]

src/systems/alias_elimination.jl

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ the `constraint`.
153153
mask,
154154
constraint)
155155
eadj = M.row_cols
156-
for i in range
156+
@inbounds for i in range
157157
vertices = eadj[i]
158158
if constraint(length(vertices))
159159
for (j, v) in enumerate(vertices)
@@ -170,7 +170,7 @@ end
170170
range,
171171
mask,
172172
constraint)
173-
for i in range
173+
@inbounds for i in range
174174
row = @view M[i, :]
175175
if constraint(count(!iszero, row))
176176
for (v, val) in enumerate(row)
@@ -382,13 +382,6 @@ end
382382

383383
swap!(v, i, j) = v[i], v[j] = v[j], v[i]
384384

385-
function getcoeff(vars, coeffs, var)
386-
for (vj, v) in enumerate(vars)
387-
v == var && return coeffs[vj]
388-
end
389-
return 0
390-
end
391-
392385
"""
393386
$(SIGNATURES)
394387

src/systems/clock_inference.jl

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ function generate_discrete_affect(syss, inputs, continuous_id, id_to_clock;
150150
param_to_idx = Dict{Any, Int}(reverse(en) for en in enumerate(appended_parameters))
151151
offset = length(appended_parameters)
152152
affect_funs = []
153+
init_funs = []
153154
svs = []
154155
clocks = TimeDomain[]
155156
for (i, (sys, input)) in enumerate(zip(syss, inputs))
@@ -202,6 +203,18 @@ function generate_discrete_affect(syss, inputs, continuous_id, id_to_clock;
202203
push!(save_vec.args, :(p[$(input_offset + i)]))
203204
end
204205
empty_disc = isempty(disc_range)
206+
207+
disc_init = :(function (p, t)
208+
d2c_obs = $disc_to_cont_obs
209+
d2c_view = view(p, $disc_to_cont_idxs)
210+
disc_state = view(p, $disc_range)
211+
copyto!(d2c_view, d2c_obs(disc_state, p, t))
212+
end)
213+
214+
# @show disc_to_cont_idxs
215+
# @show cont_to_disc_idxs
216+
# @show disc_range
217+
205218
affect! = :(function (integrator, saved_values)
206219
@unpack u, p, t = integrator
207220
c2d_obs = $cont_to_disc_obs
@@ -212,27 +225,42 @@ function generate_discrete_affect(syss, inputs, continuous_id, id_to_clock;
212225
d2c_view = view(p, $disc_to_cont_idxs)
213226
disc_state = view(p, $disc_range)
214227
disc = $disc
215-
# Write continuous into to discrete: handles `Sample`
216-
copyto!(c2d_view, c2d_obs(integrator.u, p, t))
217-
# Write discrete into to continuous
218-
# get old discrete states
219-
copyto!(d2c_view, d2c_obs(disc_state, p, t))
228+
220229
push!(saved_values.t, t)
221230
push!(saved_values.saveval, $save_vec)
222-
# update discrete states
231+
232+
# Write continuous into to discrete: handles `Sample`
233+
# Write discrete into to continuous
234+
# Update discrete states
235+
236+
# At a tick, c2d must come first
237+
# state update comes in the middle
238+
# d2c comes last
239+
# @show t
240+
# @show "incoming", p
241+
copyto!(c2d_view, c2d_obs(integrator.u, p, t))
242+
# @show "after c2d", p
223243
$empty_disc || disc(disc_state, disc_state, p, t)
244+
# @show "after state update", p
245+
copyto!(d2c_view, d2c_obs(disc_state, p, t))
246+
# @show "after d2c", p
224247
end)
225248
sv = SavedValues(Float64, Vector{Float64})
226249
push!(affect_funs, affect!)
250+
push!(init_funs, disc_init)
227251
push!(svs, sv)
228252
end
229253
if eval_expression
230254
affects = map(affect_funs) do a
231255
drop_expr(@RuntimeGeneratedFunction(eval_module, toexpr(LiteralExpr(a))))
232256
end
257+
inits = map(init_funs) do a
258+
drop_expr(@RuntimeGeneratedFunction(eval_module, toexpr(LiteralExpr(a))))
259+
end
233260
else
234261
affects = map(a -> toexpr(LiteralExpr(a)), affect_funs)
262+
inits = map(a -> toexpr(LiteralExpr(a)), init_funs)
235263
end
236264
defaults = Dict{Any, Any}(v => 0.0 for v in Iterators.flatten(inputs))
237-
return affects, clocks, svs, appended_parameters, defaults
265+
return affects, inits, clocks, svs, appended_parameters, defaults
238266
end

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -945,8 +945,9 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
945945
has_difference = has_difference,
946946
check_length, kwargs...)
947947
cbs = process_events(sys; callback, has_difference, kwargs...)
948+
inits = []
948949
if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing
949-
affects, clocks, svs = ModelingToolkit.generate_discrete_affect(dss...)
950+
affects, inits, clocks, svs = ModelingToolkit.generate_discrete_affect(dss...)
950951
discrete_cbs = map(affects, clocks, svs) do affect, clock, sv
951952
if clock isa Clock
952953
PeriodicCallback(DiscreteSaveAffect(affect, sv), clock.dt)
@@ -976,7 +977,13 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
976977
if svs !== nothing
977978
kwargs1 = merge(kwargs1, (disc_saved_values = svs,))
978979
end
979-
ODEProblem{iip}(f, u0, tspan, p, pt; kwargs1..., kwargs...)
980+
prob = ODEProblem{iip}(f, u0, tspan, p, pt; kwargs1..., kwargs...)
981+
if !isempty(inits)
982+
for init in inits
983+
init(prob.p, tspan[1])
984+
end
985+
end
986+
prob
980987
end
981988
get_callback(prob::ODEProblem) = prob.kwargs[:callback]
982989

@@ -1045,8 +1052,9 @@ function DiffEqBase.DDEProblem{iip}(sys::AbstractODESystem, u0map = [],
10451052
h = h_oop
10461053
u0 = h(p, tspan[1])
10471054
cbs = process_events(sys; callback, has_difference, kwargs...)
1055+
inits = []
10481056
if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing
1049-
affects, clocks, svs = ModelingToolkit.generate_discrete_affect(dss...)
1057+
affects, inits, clocks, svs = ModelingToolkit.generate_discrete_affect(dss...)
10501058
discrete_cbs = map(affects, clocks, svs) do affect, clock, sv
10511059
if clock isa Clock
10521060
PeriodicCallback(DiscreteSaveAffect(affect, sv), clock.dt)
@@ -1075,7 +1083,13 @@ function DiffEqBase.DDEProblem{iip}(sys::AbstractODESystem, u0map = [],
10751083
if svs !== nothing
10761084
kwargs1 = merge(kwargs1, (disc_saved_values = svs,))
10771085
end
1078-
DDEProblem{iip}(f, u0, h, tspan, p; kwargs1..., kwargs...)
1086+
prob = DDEProblem{iip}(f, u0, h, tspan, p; kwargs1..., kwargs...)
1087+
if !isempty(inits)
1088+
for init in inits
1089+
init(prob.p, tspan[1])
1090+
end
1091+
end
1092+
prob
10791093
end
10801094

10811095
function DiffEqBase.SDDEProblem(sys::AbstractODESystem, args...; kwargs...)
@@ -1099,8 +1113,9 @@ function DiffEqBase.SDDEProblem{iip}(sys::AbstractODESystem, u0map = [],
10991113
h(p, t) = h_oop(p, t)
11001114
u0 = h(p, tspan[1])
11011115
cbs = process_events(sys; callback, has_difference, kwargs...)
1116+
inits = []
11021117
if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing
1103-
affects, clocks, svs = ModelingToolkit.generate_discrete_affect(dss...)
1118+
affects, inits, clocks, svs = ModelingToolkit.generate_discrete_affect(dss...)
11041119
discrete_cbs = map(affects, clocks, svs) do affect, clock, sv
11051120
if clock isa Clock
11061121
PeriodicCallback(DiscreteSaveAffect(affect, sv), clock.dt)
@@ -1140,8 +1155,15 @@ function DiffEqBase.SDDEProblem{iip}(sys::AbstractODESystem, u0map = [],
11401155
else
11411156
noise_rate_prototype = zeros(eltype(u0), size(noiseeqs))
11421157
end
1143-
SDDEProblem{iip}(f, f.g, u0, h, tspan, p; noise_rate_prototype =
1158+
prob = SDDEProblem{iip}(f, f.g, u0, h, tspan, p;
1159+
noise_rate_prototype =
11441160
noise_rate_prototype, kwargs1..., kwargs...)
1161+
if !isempty(inits)
1162+
for init in inits
1163+
init(prob.p, tspan[1])
1164+
end
1165+
end
1166+
prob
11451167
end
11461168

11471169
"""

src/systems/model_parsing.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ function generate_var!(dict, a, b, varclass;
219219
vd isa Vector && (vd = first(vd))
220220
vd[a] = Dict{Symbol, Any}()
221221
var = if indices === nothing
222-
Symbolics.variable(a, T = SymbolicUtils.FnType{Tuple{Real}, Real})(iv)
222+
Symbolics.variable(a, T = SymbolicUtils.FnType{Tuple{Any}, Real})(iv)
223223
else
224224
vd[a][:size] = Tuple(lastindex.(indices))
225225
first(@variables $a(iv)[indices...])

src/systems/sparsematrixclil.jl

Lines changed: 114 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,8 @@ end
129129
# build something that works for us here and worry about it later.
130130
nonzerosmap(a::CLILVector) = NonZeros(a)
131131

132+
using FindFirstFunctions: findfirstequal
133+
132134
function bareiss_update_virtual_colswap_mtk!(zero!, M::SparseMatrixCLIL, k, swapto, pivot,
133135
last_pivot; pivot_equal_optimization = true)
134136
# for ei in nzrows(>= k)
@@ -168,12 +170,11 @@ function bareiss_update_virtual_colswap_mtk!(zero!, M::SparseMatrixCLIL, k, swap
168170
# conservative, we leave it at this, as this captures the most important
169171
# case for MTK (where most pivots are `1` or `-1`).
170172
pivot_equal = pivot_equal_optimization && abs(pivot) == abs(last_pivot)
171-
172-
for ei in (k + 1):size(M, 1)
173+
@inbounds for ei in (k + 1):size(M, 1)
173174
# eliminate `v`
174175
coeff = 0
175176
ivars = eadj[ei]
176-
vj = findfirst(isequal(vpivot), ivars)
177+
vj = findfirstequal(vpivot, ivars)
177178
if vj !== nothing
178179
coeff = old_cadj[ei][vj]
179180
deleteat!(old_cadj[ei], vj)
@@ -189,24 +190,118 @@ function bareiss_update_virtual_colswap_mtk!(zero!, M::SparseMatrixCLIL, k, swap
189190
ivars = eadj[ei]
190191
icoeffs = old_cadj[ei]
191192

192-
tmp_incidence = similar(eadj[ei], 0)
193-
tmp_coeffs = similar(old_cadj[ei], 0)
194-
# TODO: We know both ivars and kvars are sorted, we could just write
195-
# a quick iterator here that does this without allocation/faster.
196-
vars = sort(union(ivars, kvars))
197-
198-
for v in vars
199-
v == vpivot && continue
200-
ck = getcoeff(kvars, kcoeffs, v)
201-
ci = getcoeff(ivars, icoeffs, v)
202-
p1 = Base.Checked.checked_mul(pivot, ci)
203-
p2 = Base.Checked.checked_mul(coeff, ck)
204-
ci = exactdiv(Base.Checked.checked_sub(p1, p2), last_pivot)
205-
if !iszero(ci)
206-
push!(tmp_incidence, v)
207-
push!(tmp_coeffs, ci)
193+
numkvars = length(kvars)
194+
numivars = length(ivars)
195+
tmp_incidence = similar(eadj[ei], numkvars + numivars)
196+
tmp_coeffs = similar(old_cadj[ei], numkvars + numivars)
197+
tmp_len = 0
198+
kvind = ivind = 0
199+
if _debug_mode
200+
# in debug mode, we at least check to confirm we're iterating over
201+
# `v`s in the correct order
202+
vars = sort(union(ivars, kvars))
203+
vi = 0
204+
end
205+
if numivars > 0 && numkvars > 0
206+
kvv = kvars[kvind += 1]
207+
ivv = ivars[ivind += 1]
208+
dobreak = false
209+
while true
210+
if kvv == ivv
211+
v = kvv
212+
ck = kcoeffs[kvind]
213+
ci = icoeffs[ivind]
214+
kvind += 1
215+
ivind += 1
216+
if kvind > numkvars
217+
dobreak = true
218+
else
219+
kvv = kvars[kvind]
220+
end
221+
if ivind > numivars
222+
dobreak = true
223+
else
224+
ivv = ivars[ivind]
225+
end
226+
p1 = Base.Checked.checked_mul(pivot, ci)
227+
p2 = Base.Checked.checked_mul(coeff, ck)
228+
ci = exactdiv(Base.Checked.checked_sub(p1, p2), last_pivot)
229+
elseif kvv < ivv
230+
v = kvv
231+
ck = kcoeffs[kvind]
232+
kvind += 1
233+
if kvind > numkvars
234+
dobreak = true
235+
else
236+
kvv = kvars[kvind]
237+
end
238+
p2 = Base.Checked.checked_mul(coeff, ck)
239+
ci = exactdiv(Base.Checked.checked_neg(p2), last_pivot)
240+
else # kvv > ivv
241+
v = ivv
242+
ci = icoeffs[ivind]
243+
ivind += 1
244+
if ivind > numivars
245+
dobreak = true
246+
else
247+
ivv = ivars[ivind]
248+
end
249+
ci = exactdiv(Base.Checked.checked_mul(pivot, ci), last_pivot)
250+
end
251+
if _debug_mode
252+
@assert v == vars[vi += 1]
253+
end
254+
if v != vpivot && !iszero(ci)
255+
tmp_incidence[tmp_len += 1] = v
256+
tmp_coeffs[tmp_len] = ci
257+
end
258+
dobreak && break
259+
end
260+
elseif numkvars > 0
261+
ivind = 1
262+
kvv = kvars[kvind += 1]
263+
elseif numivars > 0
264+
kvind = 1
265+
ivv = ivars[ivind += 1]
266+
end
267+
if kvind <= numkvars
268+
v = kvv
269+
while true
270+
if _debug_mode
271+
@assert v == vars[vi += 1]
272+
end
273+
if v != vpivot
274+
ck = kcoeffs[kvind]
275+
p2 = Base.Checked.checked_mul(coeff, ck)
276+
ci = exactdiv(Base.Checked.checked_neg(p2), last_pivot)
277+
if !iszero(ci)
278+
tmp_incidence[tmp_len += 1] = v
279+
tmp_coeffs[tmp_len] = ci
280+
end
281+
end
282+
(kvind == numkvars) && break
283+
v = kvars[kvind += 1]
284+
end
285+
elseif ivind <= numivars
286+
v = ivv
287+
while true
288+
if _debug_mode
289+
@assert v == vars[vi += 1]
290+
end
291+
if v != vpivot
292+
p1 = Base.Checked.checked_mul(pivot, icoeffs[ivind])
293+
ci = exactdiv(p1, last_pivot)
294+
if !iszero(ci)
295+
tmp_incidence[tmp_len += 1] = v
296+
tmp_coeffs[tmp_len] = ci
297+
end
298+
end
299+
(ivind == numivars) && break
300+
v = ivars[ivind += 1]
208301
end
209302
end
303+
resize!(tmp_incidence, tmp_len)
304+
resize!(tmp_coeffs, tmp_len)
210305
eadj[ei] = tmp_incidence
211306
old_cadj[ei] = tmp_coeffs
212307
end

0 commit comments

Comments
 (0)