Skip to content

Commit 802d441

Browse files
feat: update clock inference codegen, fix bugs
1 parent dff1b5f commit 802d441

File tree

8 files changed

+63
-41
lines changed

8 files changed

+63
-41
lines changed

src/discretedomain.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ struct Shift <: Operator
2828
Shift(t, steps = 1) = new(value(t), steps)
2929
end
3030
Shift(steps::Int) = new(nothing, steps)
31-
normalize_to_differential(s::Shift) = Differential(s.t)^s.steps
31+
normalize_to_differential(s::Shift) = Differential(s.t)^abs(s.steps)
3232
function (D::Shift)(x, allow_zero = false)
3333
!allow_zero && D.steps == 0 && return x
3434
Term{symtype(x)}(D, Any[x])
@@ -114,7 +114,7 @@ Base.hash(D::Sample, u::UInt) = hash(D.clock, xor(u, 0x055640d6d952f101))
114114
115115
Returns true if the expression or equation `O` contains [`Sample`](@ref) terms.
116116
"""
117-
hassample(O) = recursive_hasoperator(Sample, O)
117+
hassample(O) = recursive_hasoperator(Sample, unwrap(O))
118118

119119
# Hold
120120

@@ -140,7 +140,7 @@ Hold(x) = Hold()(x)
140140
141141
Returns true if the expression or equation `O` contains [`Hold`](@ref) terms.
142142
"""
143-
hashold(O) = recursive_hasoperator(Hold, O)
143+
hashold(O) = recursive_hasoperator(Hold, unwrap(O))
144144

145145
# ShiftIndex
146146

src/systems/clock_inference.jl

Lines changed: 29 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -139,15 +139,15 @@ function split_system(ci::ClockInference{S}) where {S}
139139
return tss, inputs, continuous_id, id_to_clock
140140
end
141141

142-
function generate_discrete_affect(syss, inputs, continuous_id, id_to_clock;
142+
function generate_discrete_affect(osys::AbstractODESystem, syss, inputs, continuous_id, id_to_clock;
143143
checkbounds = true,
144144
eval_module = @__MODULE__, eval_expression = true)
145145
@static if VERSION < v"1.7"
146146
error("The `generate_discrete_affect` function requires at least Julia 1.7")
147147
end
148148
out = Sym{Any}(:out)
149149
appended_parameters = parameters(syss[continuous_id])
150-
param_to_idx = Dict{Any, Int}(reverse(en) for en in enumerate(appended_parameters))
150+
param_to_idx = Dict{Any, ParameterIndex}(p => parameter_index(osys, p) for p in appended_parameters)
151151
offset = length(appended_parameters)
152152
affect_funs = []
153153
init_funs = []
@@ -172,11 +172,13 @@ function generate_discrete_affect(syss, inputs, continuous_id, id_to_clock;
172172
vv = arguments(v)[1]
173173
if vv in fullvars
174174
push!(needed_disc_to_cont_obs, vv)
175-
push!(disc_to_cont_idxs, param_to_idx[v])
175+
# @show param_to_idx[v] v
176+
# @assert param_to_idx[v].portion isa SciMLStructures.Discrete # TOOD: remove
177+
push!(disc_to_cont_idxs, param_to_idx[v].idx)
176178
end
177179
end
178180
append!(appended_parameters, input, unknowns(sys))
179-
cont_to_disc_obs = build_explicit_observed_function(syss[continuous_id],
181+
cont_to_disc_obs = build_explicit_observed_function(osys,
180182
needed_cont_to_disc_obs,
181183
throw = false,
182184
expression = true,
@@ -185,30 +187,31 @@ function generate_discrete_affect(syss, inputs, continuous_id, id_to_clock;
185187
disc_to_cont_obs = build_explicit_observed_function(sys, needed_disc_to_cont_obs,
186188
throw = false,
187189
expression = true,
188-
output_type = SVector)
189-
ni = length(input)
190-
ns = length(unknowns(sys))
190+
output_type = SVector,
191+
ps = reorder_parameters(osys, parameters(sys)))
191192
disc = Func([
192193
out,
193-
DestructuredArgs(unknowns(sys)),
194-
DestructuredArgs(appended_parameters),
194+
DestructuredArgs(unknowns(osys)),
195+
DestructuredArgs.(reorder_parameters(osys, parameters(osys)))...,
196+
# DestructuredArgs(appended_parameters),
195197
get_iv(sys),
196198
], [],
197199
let_block)
198-
cont_to_disc_idxs = (offset + 1):(offset += ni)
199-
input_offset = offset
200-
disc_range = (offset + 1):(offset += ns)
200+
cont_to_disc_idxs = [parameter_index(osys, sym).idx for sym in input]
201+
disc_range = [parameter_index(osys, sym).idx for sym in unknowns(sys)]
201202
save_vec = Expr(:ref, :Float64)
202-
for i in 1:ns
203-
push!(save_vec.args, :(p[$(input_offset + i)]))
203+
for unk in unknowns(sys)
204+
idx = parameter_index(osys, unk).idx
205+
push!(save_vec.args, :(discretes[$idx]))
204206
end
205207
empty_disc = isempty(disc_range)
206-
207208
disc_init = :(function (p, t)
208209
d2c_obs = $disc_to_cont_obs
209-
d2c_view = view(p, $disc_to_cont_idxs)
210-
disc_state = view(p, $disc_range)
211-
copyto!(d2c_view, d2c_obs(disc_state, p, t))
210+
discretes, repack, _ = $(SciMLStructures.canonicalize)($(SciMLStructures.Discrete()), p)
211+
d2c_view = view(discretes, $disc_to_cont_idxs)
212+
disc_state = view(discretes, $disc_range)
213+
copyto!(d2c_view, d2c_obs(disc_state, p..., t))
214+
repack(discretes)
212215
end)
213216

214217
# @show disc_to_cont_idxs
@@ -220,10 +223,11 @@ function generate_discrete_affect(syss, inputs, continuous_id, id_to_clock;
220223
c2d_obs = $cont_to_disc_obs
221224
d2c_obs = $disc_to_cont_obs
222225
# Like Sample
223-
c2d_view = view(p, $cont_to_disc_idxs)
226+
discretes, repack, _ = $(SciMLStructures.canonicalize)($(SciMLStructures.Discrete()), p)
227+
c2d_view = view(discretes, $cont_to_disc_idxs)
224228
# Like Hold
225-
d2c_view = view(p, $disc_to_cont_idxs)
226-
disc_unknowns = view(p, $disc_range)
229+
d2c_view = view(discretes, $disc_to_cont_idxs)
230+
disc_unknowns = view(discretes, $disc_range)
227231
disc = $disc
228232

229233
push!(saved_values.t, t)
@@ -238,12 +242,13 @@ function generate_discrete_affect(syss, inputs, continuous_id, id_to_clock;
238242
# d2c comes last
239243
# @show t
240244
# @show "incoming", p
241-
copyto!(c2d_view, c2d_obs(integrator.u, p, t))
245+
copyto!(c2d_view, c2d_obs(integrator.u, p..., t))
242246
# @show "after c2d", p
243-
$empty_disc || disc(disc_unknowns, disc_unknowns, p, t)
247+
$empty_disc || disc(disc_unknowns, integrator.u, p..., t)
244248
# @show "after state update", p
245-
copyto!(d2c_view, d2c_obs(disc_unknowns, p, t))
249+
copyto!(d2c_view, d2c_obs(disc_unknowns, p..., t))
246250
# @show "after d2c", p
251+
repack(discretes)
247252
end)
248253
sv = SavedValues(Float64, Vector{Float64})
249254
push!(affect_funs, affect!)

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -800,7 +800,7 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
800800
u0 = u0_constructor(u0)
801801
end
802802

803-
p = MTKParameters(sys, parammap; toterm = default_toterm)
803+
p = MTKParameters(sys, parammap)
804804

805805
if implicit_dae && du0map !== nothing
806806
ddvs = map(Differential(iv), dvs)
@@ -931,7 +931,7 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
931931
cbs = process_events(sys; callback, kwargs...)
932932
inits = []
933933
if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing
934-
affects, inits, clocks, svs = ModelingToolkit.generate_discrete_affect(dss...)
934+
affects, inits, clocks, svs = ModelingToolkit.generate_discrete_affect(sys, dss...)
935935
discrete_cbs = map(affects, clocks, svs) do affect, clock, sv
936936
if clock isa Clock
937937
PeriodicCallback(DiscreteSaveAffect(affect, sv), clock.dt)
@@ -1035,7 +1035,7 @@ function DiffEqBase.DDEProblem{iip}(sys::AbstractODESystem, u0map = [],
10351035
cbs = process_events(sys; callback, kwargs...)
10361036
inits = []
10371037
if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing
1038-
affects, inits, clocks, svs = ModelingToolkit.generate_discrete_affect(dss...)
1038+
affects, inits, clocks, svs = ModelingToolkit.generate_discrete_affect(sys, dss...)
10391039
discrete_cbs = map(affects, clocks, svs) do affect, clock, sv
10401040
if clock isa Clock
10411041
PeriodicCallback(DiscreteSaveAffect(affect, sv), clock.dt)
@@ -1097,7 +1097,7 @@ function DiffEqBase.SDDEProblem{iip}(sys::AbstractODESystem, u0map = [],
10971097
cbs = process_events(sys; callback, kwargs...)
10981098
inits = []
10991099
if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing
1100-
affects, inits, clocks, svs = ModelingToolkit.generate_discrete_affect(dss...)
1100+
affects, inits, clocks, svs = ModelingToolkit.generate_discrete_affect(sys, dss...)
11011101
discrete_cbs = map(affects, clocks, svs) do affect, clock, sv
11021102
if clock isa Clock
11031103
PeriodicCallback(DiscreteSaveAffect(affect, sv), clock.dt)

src/systems/diffeqs/odesystem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,8 +172,8 @@ struct ODESystem <: AbstractODESystem
172172
new(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac,
173173
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, torn_matching,
174174
connector_type, preface, cevents, devents, metadata, gui_metadata,
175-
tearing_state, substitutions, complete, index_cache, discrete_subsystems,
176-
solved_unknowns, split_idxs, parent)
175+
tearing_state, substitutions, complete, index_cache,
176+
discrete_subsystems, solved_unknowns, split_idxs, parent)
177177
end
178178
end
179179

src/systems/index_cache.jl

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,8 @@ function IndexCache(sys::AbstractSystem)
2828
unks = solved_unknowns(sys)
2929
unk_idxs = Dict{UInt, Int}()
3030
for (i, sym) in enumerate(unks)
31-
h = hash(unwrap(sym))
31+
h = getsymbolhash(sym)
3232
unk_idxs[h] = i
33-
setmetadata(sym, SymbolHash, h)
3433
end
3534

3635
disc_buffers = Dict{DataType, Set{BasicSymbolic}}()
@@ -58,6 +57,13 @@ function IndexCache(sys::AbstractSystem)
5857
end
5958
end
6059
end
60+
if has_discrete_subsystems(sys) && get_discrete_subsystems(sys) !== nothing
61+
_, inputs, continuous_id, _ = get_discrete_subsystems(sys)
62+
for par in inputs[continuous_id]
63+
is_parameter(sys, par) || error("Discrete subsytem input is not a parameter")
64+
insert_by_type!(disc_buffers, par)
65+
end
66+
end
6167

6268
all_ps = Set(unwrap.(parameters(sys)))
6369
for (sym, value) in defaults(sys)
@@ -91,8 +97,9 @@ function IndexCache(sys::AbstractSystem)
9197
idx = 1
9298
for (T, buf) in buffers
9399
for p in buf
94-
h = hash(p)
95-
setmetadata(p, SymbolHash, h)
100+
h = getsymbolhash(p)
101+
idxs[h] = idx
102+
h = getsymbolhash(default_toterm(p))
96103
idxs[h] = idx
97104
idx += 1
98105
end

src/systems/parameter_buffer.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,23 +6,26 @@ struct MTKParameters{T, D, C, E, F}
66
dependent_update::F
77
end
88

9-
function MTKParameters(sys::AbstractSystem, p; toterm = default_toterm, tofloat = false, use_union = false)
9+
function MTKParameters(sys::AbstractSystem, p; tofloat = false, use_union = false)
1010
ic = if has_index_cache(sys) && get_index_cache(sys) !== nothing
1111
get_index_cache(sys)
1212
else
1313
error("Cannot create MTKParameters if system does not have index_cache")
1414
end
1515
all_ps = Set(unwrap.(parameters(sys)))
16+
union!(all_ps, default_toterm.(unwrap.(parameters(sys))))
1617
if p isa Vector && !(eltype(p) <: Pair) && !isempty(p)
1718
ps = parameters(sys)
1819
length(p) == length(ps) || error("Invalid parameters")
1920
p = ps .=> p
2021
end
21-
defs = Dict(unwrap(k) => v for (k, v) in defaults(sys) if unwrap(k) in all_ps)
22+
defs = Dict(default_toterm(unwrap(k)) => v for (k, v) in defaults(sys) if unwrap(k) in all_ps)
2223
if p isa SciMLBase.NullParameters
2324
p = defs
2425
else
25-
p = merge(defs, Dict(unwrap(k) => v for (k, v) in p))
26+
extra_params = Dict(unwrap(k) => v for (k, v) in p if !in(unwrap(k), all_ps))
27+
p = merge(defs, Dict(default_toterm(unwrap(k)) => v for (k, v) in p))
28+
p = Dict(k => fixpoint_sub(v, extra_params) for (k, v) in p if !haskey(extra_params, unwrap(k)))
2629
end
2730

2831
tunable_buffer = ArrayPartition((Vector{temp.type}(undef, temp.length) for temp in ic.param_buffer_sizes)...)
@@ -41,6 +44,10 @@ function MTKParameters(sys::AbstractSystem, p; toterm = default_toterm, tofloat
4144
elseif haskey(ic.dependent_idx, h)
4245
dep_buffer[ic.dependent_idx[h]] = val
4346
dependencies[wrap(sym)] = wrap(p[sym])
47+
elseif !isequal(default_toterm(sym), sym)
48+
set_value(default_toterm(sym), val)
49+
else
50+
error("Symbol $sym does not have an index")
4451
end
4552
end
4653

src/systems/systemstructure.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -558,6 +558,7 @@ function structural_simplify!(state::TearingState, io = nothing; simplify = fals
558558
if state.sys isa ODESystem
559559
ci = ModelingToolkit.ClockInference(state)
560560
ModelingToolkit.infer_clocks!(ci)
561+
time_domains = merge(Dict(state.fullvars .=> ci.var_domain), Dict(default_toterm.(state.fullvars) .=> ci.var_domain))
561562
tss, inputs, continuous_id, id_to_clock = ModelingToolkit.split_system(ci)
562563
cont_io = merge_io(io, inputs[continuous_id])
563564
sys, input_idxs = _structural_simplify!(tss[continuous_id], cont_io; simplify,
@@ -586,6 +587,8 @@ function structural_simplify!(state::TearingState, io = nothing; simplify = fals
586587
@set! sys.defaults = merge(ModelingToolkit.defaults(sys),
587588
Dict(v => 0.0 for v in Iterators.flatten(inputs)))
588589
end
590+
ps = [setmetadata(sym, TimeDomain, get(time_domains, sym, Continuous())) for sym in get_ps(sys)]
591+
@set! sys.ps = ps
589592
else
590593
sys, input_idxs = _structural_simplify!(state, io; simplify, check_consistency,
591594
fully_determined, kwargs...)

test/clock.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ ss = structural_simplify(sys);
116116
Tf = 1.0
117117
prob = ODEProblem(ss, [x => 0.0, y => 0.0], (0.0, Tf),
118118
[kp => 1.0; z => 3.0; z(k + 1) => 2.0])
119-
@test sort(prob.p) == [0, 1.0, 2.0, 3.0, 4.0] # yd, kp, z(k+1), z(k), ud
119+
@test sort(vcat(prob.p...)) == [0, 1.0, 2.0, 3.0, 4.0] # yd, kp, z(k+1), z(k), ud
120120
sol = solve(prob, Tsit5(), kwargshandle = KeywordArgSilent)
121121
# For all inputs in parameters, just initialize them to 0.0, and then set them
122122
# in the callback.

0 commit comments

Comments
 (0)