Skip to content

Commit 0287c1d

Browse files
AayushSabharwalChrisRackauckas
authored andcommitted
feat!: do not scalarize parameters, fix some tests
1 parent ae6bd49 commit 0287c1d

14 files changed

+349
-200
lines changed

src/systems/abstractsystem.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1521,13 +1521,12 @@ function linearization_function(sys::AbstractSystem, inputs,
15211521
sys = ssys
15221522
x0 = merge(defaults(sys), Dict(missing_variable_defaults(sys)), op)
15231523
u0, _p, _ = get_u0_p(sys, x0, p; use_union = false, tofloat = true)
1524+
ps = parameters(sys)
15241525
if has_index_cache(sys) && get_index_cache(sys) !== nothing
15251526
p = MTKParameters(sys, p)
1526-
ps = reorder_parameters(sys, parameters(sys))
15271527
else
15281528
p = _p
15291529
p, split_idxs = split_parameters_by_type(p)
1530-
ps = parameters(sys)
15311530
if p isa Tuple
15321531
ps = Base.Fix1(getindex, ps).(split_idxs)
15331532
ps = (ps...,) #if p is Tuple, ps should be Tuple

src/systems/callbacks.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -390,11 +390,8 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin
390390
if has_index_cache(sys) && get_index_cache(sys) !== nothing
391391
ic = get_index_cache(sys)
392392
update_inds = map(update_vars) do sym
393-
@unpack portion, idx = parameter_index(sys, sym)
394-
if portion == SciMLStructures.Discrete()
395-
idx += length(ic.param_idx)
396-
end
397-
idx
393+
pind = parameter_index(sys, sym)
394+
discrete_linear_index(ic, pind)
398395
end
399396
else
400397
psind = Dict(reverse(en) for en in enumerate(ps))

src/systems/clock_inference.jl

Lines changed: 32 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ function generate_discrete_affect(
146146
@static if VERSION < v"1.7"
147147
error("The `generate_discrete_affect` function requires at least Julia 1.7")
148148
end
149+
has_index_cache(osys) && get_index_cache(osys) !== nothing || error("System must have index_cache for clock support")
149150
out = Sym{Any}(:out)
150151
appended_parameters = parameters(syss[continuous_id])
151152
param_to_idx = Dict{Any, ParameterIndex}(p => parameter_index(osys, p)
@@ -169,14 +170,14 @@ function generate_discrete_affect(
169170
push!(fullvars, s)
170171
end
171172
needed_disc_to_cont_obs = []
172-
disc_to_cont_idxs = Int[]
173+
disc_to_cont_idxs = ParameterIndex[]
173174
for v in inputs[continuous_id]
174175
vv = arguments(v)[1]
175176
if vv in fullvars
176177
push!(needed_disc_to_cont_obs, vv)
177178
# @show param_to_idx[v] v
178179
# @assert param_to_idx[v].portion isa SciMLStructures.Discrete # TOOD: remove
179-
push!(disc_to_cont_idxs, param_to_idx[v].idx)
180+
push!(disc_to_cont_idxs, param_to_idx[v])
180181
end
181182
end
182183
append!(appended_parameters, input, unknowns(sys))
@@ -201,39 +202,36 @@ function generate_discrete_affect(
201202
],
202203
[],
203204
let_block)
204-
cont_to_disc_idxs = [parameter_index(osys, sym).idx for sym in input]
205-
disc_range = [parameter_index(osys, sym).idx for sym in unknowns(sys)]
205+
cont_to_disc_idxs = [parameter_index(osys, sym) for sym in input]
206+
disc_range = [parameter_index(osys, sym) for sym in unknowns(sys)]
206207
save_vec = Expr(:ref, :Float64)
207208
for unk in unknowns(sys)
208-
idx = parameter_index(osys, unk).idx
209-
push!(save_vec.args, :(discretes[$idx]))
209+
idx = parameter_index(osys, unk)
210+
push!(save_vec.args, :($(parameter_values)(p, $idx)))
210211
end
211212
empty_disc = isempty(disc_range)
212213
disc_init = :(function (p, t)
213214
d2c_obs = $disc_to_cont_obs
215+
disc_state = Tuple($(parameter_values)(p, i) for i in $disc_range)
216+
result = d2c_obs(disc_state, p..., t)
217+
for (val, i) in zip(result, $disc_to_cont_idxs)
218+
# prevent multiple updates to dependents
219+
_set_parameter_unchecked!(p, val, i; update_dependent = false)
220+
end
214221
discretes, repack, _ = $(SciMLStructures.canonicalize)(
215222
$(SciMLStructures.Discrete()), p)
216-
d2c_view = view(discretes, $disc_to_cont_idxs)
217-
disc_state = view(discretes, $disc_range)
218-
copyto!(d2c_view, d2c_obs(disc_state, p..., t))
219-
repack(discretes)
223+
repack(discretes) # to force recalculation of dependents
220224
end)
221225

222226
# @show disc_to_cont_idxs
223227
# @show cont_to_disc_idxs
224228
# @show disc_range
225-
226229
affect! = :(function (integrator, saved_values)
227230
@unpack u, p, t = integrator
228231
c2d_obs = $cont_to_disc_obs
229232
d2c_obs = $disc_to_cont_obs
230-
# Like Sample
231-
discretes, repack, _ = $(SciMLStructures.canonicalize)(
232-
$(SciMLStructures.Discrete()), p)
233-
c2d_view = view(discretes, $cont_to_disc_idxs)
234-
# Like Hold
235-
d2c_view = view(discretes, $disc_to_cont_idxs)
236-
disc_unknowns = view(discretes, $disc_range)
233+
# TODO: find a way to do this without allocating
234+
disc_unknowns = [$(parameter_values)(p, i) for i in $disc_range]
237235
disc = $disc
238236

239237
push!(saved_values.t, t)
@@ -248,12 +246,25 @@ function generate_discrete_affect(
248246
# d2c comes last
249247
# @show t
250248
# @show "incoming", p
251-
copyto!(c2d_view, c2d_obs(integrator.u, p..., t))
249+
result = c2d_obs(integrator.u, p..., t)
250+
for (val, i) in zip(result, $cont_to_disc_idxs)
251+
$(_set_parameter_unchecked!)(p, val, i; update_dependent = false)
252+
end
252253
# @show "after c2d", p
253-
$empty_disc || disc(disc_unknowns, integrator.u, p..., t)
254+
if !$empty_disc
255+
disc(disc_unknowns, integrator.u, p..., t)
256+
for (val, i) in zip(disc_unknowns, $disc_range)
257+
$(_set_parameter_unchecked!)(p, val, i; update_dependent = false)
258+
end
259+
end
254260
# @show "after state update", p
255-
copyto!(d2c_view, d2c_obs(disc_unknowns, p..., t))
261+
result = d2c_obs(disc_unknowns, p..., t)
262+
for (val, i) in zip(result, $disc_to_cont_idxs)
263+
$(_set_parameter_unchecked!)(p, val, i; update_dependent = false)
264+
end
256265
# @show "after d2c", p
266+
discretes, repack, _ = $(SciMLStructures.canonicalize)(
267+
$(SciMLStructures.Discrete()), p)
257268
repack(discretes)
258269
end)
259270
sv = SavedValues(Float64, Vector{Float64})

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -183,11 +183,7 @@ function generate_function(sys::AbstractODESystem, dvs = unknowns(sys),
183183

184184
# TODO: add an optional check on the ordering of observed equations
185185
u = map(x -> time_varying_as_func(value(x), sys), dvs)
186-
p = if has_index_cache(sys) && get_index_cache(sys) !== nothing
187-
reorder_parameters(get_index_cache(sys), ps isa Tuple ? reduce(vcat, ps) : ps)
188-
else
189-
(map(x -> time_varying_as_func(value(x), sys), ps),)
190-
end
186+
p = map.(x -> time_varying_as_func(value(x), sys), reorder_parameters(sys, ps))
191187
t = get_iv(sys)
192188

193189
if isdde
@@ -802,6 +798,23 @@ function get_u0_p(sys,
802798
u0, p, defs
803799
end
804800

801+
function get_u0(sys, u0map, parammap = nothing; symbolic_u0 = false)
802+
dvs = unknowns(sys)
803+
ps = parameters(sys)
804+
defs = defaults(sys)
805+
if parammap !== nothing
806+
defs = mergedefaults(defs, parammap, ps)
807+
end
808+
defs = mergedefaults(defs, u0map, dvs)
809+
810+
if symbolic_u0
811+
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = false, use_union = false)
812+
else
813+
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = true)
814+
end
815+
return u0, defs
816+
end
817+
805818
function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
806819
implicit_dae = false, du0map = nothing,
807820
version = nothing, tgrad = false,
@@ -820,20 +833,24 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
820833
ps = parameters(sys)
821834
iv = get_iv(sys)
822835

823-
u0, _p, defs = get_u0_p(sys,
824-
u0map,
825-
parammap;
826-
tofloat,
827-
use_union,
828-
symbolic_u0)
829-
if u0 !== nothing
830-
u0 = u0_constructor(u0)
831-
end
832-
833836
if has_index_cache(sys) && get_index_cache(sys) !== nothing
837+
u0, defs = get_u0(sys, u0map, parammap; symbolic_u0)
834838
p = MTKParameters(sys, parammap)
835839
else
836-
p = _p
840+
u0, p, defs = get_u0_p(sys,
841+
u0map,
842+
parammap;
843+
tofloat,
844+
use_union,
845+
symbolic_u0)
846+
p, split_idxs = split_parameters_by_type(p)
847+
if p isa Tuple
848+
ps = Base.Fix1(getindex, parameters(sys)).(split_idxs)
849+
ps = (ps...,) #if p is Tuple, ps should be Tuple
850+
end
851+
end
852+
if u0 !== nothing
853+
u0 = u0_constructor(u0)
837854
end
838855

839856
if implicit_dae && du0map !== nothing

src/systems/diffeqs/odesystem.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -202,15 +202,13 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
202202
ctrl′ = value.(controls)
203203
dvs′ = value.(dvs)
204204
dvs′ = filter(x -> !isdelay(x, iv), dvs′)
205-
206205
if !(isempty(default_u0) && isempty(default_p))
207206
Base.depwarn(
208207
"`default_u0` and `default_p` are deprecated. Use `defaults` instead.",
209208
:ODESystem, force = true)
210209
end
211210
defaults = todict(defaults)
212211
defaults = Dict{Any, Any}(value(k) => value(v) for (k, v) in pairs(defaults))
213-
214212
var_to_name = Dict()
215213
process_variables!(var_to_name, defaults, dvs′)
216214
process_variables!(var_to_name, defaults, ps′)
@@ -277,7 +275,7 @@ function ODESystem(eqs, iv; kwargs...)
277275
algevars = setdiff(allunknowns, diffvars)
278276
# the orders here are very important!
279277
return ODESystem(Equation[diffeq; algeeq; compressed_eqs], iv,
280-
collect(Iterators.flatten((diffvars, algevars))), ps; kwargs...)
278+
collect(Iterators.flatten((diffvars, algevars))), collect(ps); kwargs...)
281279
end
282280

283281
# NOTE: equality does not check cached Jacobian

src/systems/diffeqs/sdesystem.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,6 @@ function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = unknowns(sys),
407407
error("A completed `SDESystem` is required. Call `complete` or `structural_simplify` on the system before creating an `SDEFunction`")
408408
end
409409
dvs = scalarize.(dvs)
410-
ps = scalarize.(ps)
411410

412411
f_gen = generate_function(sys, dvs, ps; expression = Val{eval_expression}, kwargs...)
413412
f_oop, f_iip = eval_expression ?

0 commit comments

Comments
 (0)