Skip to content

Commit 2447c8d

Browse files
feat!: do not scalarize parameters, fix some tests
1 parent ea31448 commit 2447c8d

14 files changed

+347
-203
lines changed

src/systems/abstractsystem.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1397,7 +1397,6 @@ macro mtkbuild(exprs...)
13971397
else
13981398
(;)
13991399
end
1400-
@show kwargs
14011400
esc(quote
14021401
$named_expr
14031402
$name = $structural_simplify($name; $(kwargs)...)
@@ -1518,13 +1517,12 @@ function linearization_function(sys::AbstractSystem, inputs,
15181517
sys = ssys
15191518
x0 = merge(defaults(sys), Dict(missing_variable_defaults(sys)), op)
15201519
u0, _p, _ = get_u0_p(sys, x0, p; use_union = false, tofloat = true)
1520+
ps = parameters(sys)
15211521
if has_index_cache(sys) && get_index_cache(sys) !== nothing
15221522
p = MTKParameters(sys, p)
1523-
ps = reorder_parameters(sys, parameters(sys))
15241523
else
15251524
p = _p
15261525
p, split_idxs = split_parameters_by_type(p)
1527-
ps = parameters(sys)
15281526
if p isa Tuple
15291527
ps = Base.Fix1(getindex, ps).(split_idxs)
15301528
ps = (ps...,) #if p is Tuple, ps should be Tuple

src/systems/callbacks.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -390,11 +390,8 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin
390390
if has_index_cache(sys) && get_index_cache(sys) !== nothing
391391
ic = get_index_cache(sys)
392392
update_inds = map(update_vars) do sym
393-
@unpack portion, idx = parameter_index(sys, sym)
394-
if portion == SciMLStructures.Discrete()
395-
idx += length(ic.param_idx)
396-
end
397-
idx
393+
pind = parameter_index(sys, sym)
394+
discrete_linear_index(ic, pind)
398395
end
399396
else
400397
psind = Dict(reverse(en) for en in enumerate(ps))

src/systems/clock_inference.jl

Lines changed: 32 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ function generate_discrete_affect(
146146
@static if VERSION < v"1.7"
147147
error("The `generate_discrete_affect` function requires at least Julia 1.7")
148148
end
149+
has_index_cache(osys) && get_index_cache(osys) !== nothing || error("System must have index_cache for clock support")
149150
out = Sym{Any}(:out)
150151
appended_parameters = parameters(syss[continuous_id])
151152
param_to_idx = Dict{Any, ParameterIndex}(p => parameter_index(osys, p)
@@ -169,14 +170,14 @@ function generate_discrete_affect(
169170
push!(fullvars, s)
170171
end
171172
needed_disc_to_cont_obs = []
172-
disc_to_cont_idxs = Int[]
173+
disc_to_cont_idxs = ParameterIndex[]
173174
for v in inputs[continuous_id]
174175
vv = arguments(v)[1]
175176
if vv in fullvars
176177
push!(needed_disc_to_cont_obs, vv)
177178
# @show param_to_idx[v] v
178179
# @assert param_to_idx[v].portion isa SciMLStructures.Discrete # TOOD: remove
179-
push!(disc_to_cont_idxs, param_to_idx[v].idx)
180+
push!(disc_to_cont_idxs, param_to_idx[v])
180181
end
181182
end
182183
append!(appended_parameters, input, unknowns(sys))
@@ -201,39 +202,36 @@ function generate_discrete_affect(
201202
],
202203
[],
203204
let_block)
204-
cont_to_disc_idxs = [parameter_index(osys, sym).idx for sym in input]
205-
disc_range = [parameter_index(osys, sym).idx for sym in unknowns(sys)]
205+
cont_to_disc_idxs = [parameter_index(osys, sym) for sym in input]
206+
disc_range = [parameter_index(osys, sym) for sym in unknowns(sys)]
206207
save_vec = Expr(:ref, :Float64)
207208
for unk in unknowns(sys)
208-
idx = parameter_index(osys, unk).idx
209-
push!(save_vec.args, :(discretes[$idx]))
209+
idx = parameter_index(osys, unk)
210+
push!(save_vec.args, :($(parameter_values)(p, $idx)))
210211
end
211212
empty_disc = isempty(disc_range)
212213
disc_init = :(function (p, t)
213214
d2c_obs = $disc_to_cont_obs
215+
disc_state = Tuple($(parameter_values)(p, i) for i in $disc_range)
216+
result = d2c_obs(disc_state, p..., t)
217+
for (val, i) in zip(result, $disc_to_cont_idxs)
218+
# prevent multiple updates to dependents
219+
_set_parameter_unchecked!(p, val, i; update_dependent = false)
220+
end
214221
discretes, repack, _ = $(SciMLStructures.canonicalize)(
215222
$(SciMLStructures.Discrete()), p)
216-
d2c_view = view(discretes, $disc_to_cont_idxs)
217-
disc_state = view(discretes, $disc_range)
218-
copyto!(d2c_view, d2c_obs(disc_state, p..., t))
219-
repack(discretes)
223+
repack(discretes) # to force recalculation of dependents
220224
end)
221225

222226
# @show disc_to_cont_idxs
223227
# @show cont_to_disc_idxs
224228
# @show disc_range
225-
226229
affect! = :(function (integrator, saved_values)
227230
@unpack u, p, t = integrator
228231
c2d_obs = $cont_to_disc_obs
229232
d2c_obs = $disc_to_cont_obs
230-
# Like Sample
231-
discretes, repack, _ = $(SciMLStructures.canonicalize)(
232-
$(SciMLStructures.Discrete()), p)
233-
c2d_view = view(discretes, $cont_to_disc_idxs)
234-
# Like Hold
235-
d2c_view = view(discretes, $disc_to_cont_idxs)
236-
disc_unknowns = view(discretes, $disc_range)
233+
# TODO: find a way to do this without allocating
234+
disc_unknowns = [$(parameter_values)(p, i) for i in $disc_range]
237235
disc = $disc
238236

239237
push!(saved_values.t, t)
@@ -248,12 +246,25 @@ function generate_discrete_affect(
248246
# d2c comes last
249247
# @show t
250248
# @show "incoming", p
251-
copyto!(c2d_view, c2d_obs(integrator.u, p..., t))
249+
result = c2d_obs(integrator.u, p..., t)
250+
for (val, i) in zip(result, $cont_to_disc_idxs)
251+
$(_set_parameter_unchecked!)(p, val, i; update_dependent = false)
252+
end
252253
# @show "after c2d", p
253-
$empty_disc || disc(disc_unknowns, integrator.u, p..., t)
254+
if !$empty_disc
255+
disc(disc_unknowns, integrator.u, p..., t)
256+
for (val, i) in zip(disc_unknowns, $disc_range)
257+
$(_set_parameter_unchecked!)(p, val, i; update_dependent = false)
258+
end
259+
end
254260
# @show "after state update", p
255-
copyto!(d2c_view, d2c_obs(disc_unknowns, p..., t))
261+
result = d2c_obs(disc_unknowns, p..., t)
262+
for (val, i) in zip(result, $disc_to_cont_idxs)
263+
$(_set_parameter_unchecked!)(p, val, i; update_dependent = false)
264+
end
256265
# @show "after d2c", p
266+
discretes, repack, _ = $(SciMLStructures.canonicalize)(
267+
$(SciMLStructures.Discrete()), p)
257268
repack(discretes)
258269
end)
259270
sv = SavedValues(Float64, Vector{Float64})

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -167,11 +167,7 @@ function generate_function(sys::AbstractODESystem, dvs = unknowns(sys),
167167

168168
# TODO: add an optional check on the ordering of observed equations
169169
u = map(x -> time_varying_as_func(value(x), sys), dvs)
170-
p = if has_index_cache(sys) && get_index_cache(sys) !== nothing
171-
reorder_parameters(get_index_cache(sys), ps isa Tuple ? reduce(vcat, ps) : ps)
172-
else
173-
(map(x -> time_varying_as_func(value(x), sys), ps),)
174-
end
170+
p = map.(x -> time_varying_as_func(value(x), sys), reorder_parameters(sys, ps))
175171
t = get_iv(sys)
176172

177173
if isdde
@@ -775,6 +771,20 @@ function get_u0_p(sys,
775771
u0, p, defs
776772
end
777773

774+
function get_u0(sys, u0map; symbolic_u0 = false)
775+
dvs = unknowns(sys)
776+
777+
defs = defaults(sys)
778+
defs = mergedefaults(defs, u0map, dvs)
779+
780+
if symbolic_u0
781+
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = false, use_union = false)
782+
else
783+
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = true)
784+
end
785+
return u0, defs
786+
end
787+
778788
function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
779789
implicit_dae = false, du0map = nothing,
780790
version = nothing, tgrad = false,
@@ -793,20 +803,24 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
793803
ps = parameters(sys)
794804
iv = get_iv(sys)
795805

796-
u0, _p, defs = get_u0_p(sys,
797-
u0map,
798-
parammap;
799-
tofloat,
800-
use_union,
801-
symbolic_u0)
802-
if u0 !== nothing
803-
u0 = u0_constructor(u0)
804-
end
805-
806806
if has_index_cache(sys) && get_index_cache(sys) !== nothing
807+
u0, defs = get_u0(sys, u0map; symbolic_u0)
807808
p = MTKParameters(sys, parammap)
808809
else
809-
p = _p
810+
u0, p, defs = get_u0_p(sys,
811+
u0map,
812+
parammap;
813+
tofloat,
814+
use_union,
815+
symbolic_u0)
816+
p, split_idxs = split_parameters_by_type(p)
817+
if p isa Tuple
818+
ps = Base.Fix1(getindex, parameters(sys)).(split_idxs)
819+
ps = (ps...,) #if p is Tuple, ps should be Tuple
820+
end
821+
end
822+
if u0 !== nothing
823+
u0 = u0_constructor(u0)
810824
end
811825

812826
if implicit_dae && du0map !== nothing

src/systems/diffeqs/odesystem.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -197,21 +197,18 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
197197
throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
198198
deqs = scalarize(deqs)
199199
@assert all(control -> any(isequal.(control, ps)), controls) "All controls must also be parameters."
200-
201200
iv′ = value(scalarize(iv))
202-
ps′ = value.(scalarize(ps))
201+
ps′ = ps
203202
ctrl′ = value.(scalarize(controls))
204203
dvs′ = value.(scalarize(dvs))
205204
dvs′ = filter(x -> !isdelay(x, iv), dvs′)
206-
207205
if !(isempty(default_u0) && isempty(default_p))
208206
Base.depwarn(
209207
"`default_u0` and `default_p` are deprecated. Use `defaults` instead.",
210208
:ODESystem, force = true)
211209
end
212210
defaults = todict(defaults)
213211
defaults = Dict{Any, Any}(value(k) => value(v) for (k, v) in pairs(defaults))
214-
215212
var_to_name = Dict()
216213
process_variables!(var_to_name, defaults, dvs′)
217214
process_variables!(var_to_name, defaults, ps′)
@@ -279,7 +276,7 @@ function ODESystem(eqs, iv; kwargs...)
279276
algevars = setdiff(allunknowns, diffvars)
280277
# the orders here are very important!
281278
return ODESystem(Equation[diffeq; algeeq; compressed_eqs], iv,
282-
collect(Iterators.flatten((diffvars, algevars))), ps; kwargs...)
279+
collect(Iterators.flatten((diffvars, algevars))), collect(ps); kwargs...)
283280
end
284281

285282
# NOTE: equality does not check cached Jacobian

src/systems/diffeqs/sdesystem.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,6 @@ function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = unknowns(sys),
407407
error("A completed `SDESystem` is required. Call `complete` or `structural_simplify` on the system before creating an `SDEFunction`")
408408
end
409409
dvs = scalarize.(dvs)
410-
ps = scalarize.(ps)
411410

412411
f_gen = generate_function(sys, dvs, ps; expression = Val{eval_expression}, kwargs...)
413412
f_oop, f_iip = eval_expression ?

0 commit comments

Comments
 (0)