Skip to content

Commit f96fc45

Browse files
authored
Merge pull request #2690 from SciML/inferred_sampletime
add `sampletime` operator
2 parents 6f4bad4 + b95fb3d commit f96fc45

File tree

5 files changed

+67
-17
lines changed

5 files changed

+67
-17
lines changed

src/ModelingToolkit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ export debug_system
269269
#export Continuous, Discrete, sampletime, input_timedomain, output_timedomain
270270
#export has_discrete_domain, has_continuous_domain
271271
#export is_discrete_domain, is_continuous_domain, is_hybrid_domain
272-
export Sample, Hold, Shift, ShiftIndex
272+
export Sample, Hold, Shift, ShiftIndex, sampletime, SampleTime
273273
export Clock #, InferredDiscrete,
274274

275275
end # module

src/discretedomain.jl

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
using Symbolics: Operator, Num, Term, value, recursive_hasoperator
22

3+
struct SampleTime <: Operator end
4+
SymbolicUtils.promote_symtype(::Type{SampleTime}, t...) = Real
5+
function SampleTime()
6+
SymbolicUtils.term(SampleTime, type = Real)
7+
end
8+
39
# Shift
410

511
"""
@@ -15,8 +21,6 @@ $(FIELDS)
1521
```jldoctest
1622
julia> using Symbolics
1723
18-
julia> @variables t;
19-
2024
julia> Δ = Shift(t)
2125
(::Shift) (generic function with 2 methods)
2226
```
@@ -176,16 +180,18 @@ end
176180
function (xn::Num)(k::ShiftIndex)
177181
@unpack clock, steps = k
178182
x = value(xn)
179-
t = clock.t
180183
# Verify that the independent variables of k and x match and that the expression doesn't have multiple variables
181184
vars = Symbolics.get_variables(x)
182185
length(vars) == 1 ||
183186
error("Cannot shift a multivariate expression $x. Either create a new unknown and shift this, or shift the individual variables in the expression.")
184187
args = Symbolics.arguments(vars[]) # args should be one element vector with the t in x(t)
185188
length(args) == 1 ||
186189
error("Cannot shift an expression with multiple independent variables $x.")
187-
isequal(args[], t) ||
188-
error("Independent variable of $xn is not the same as that of the ShiftIndex $(k.t)")
190+
t = args[]
191+
if hasfield(typeof(clock), :t)
192+
isequal(t, clock.t) ||
193+
error("Independent variable of $xn is not the same as that of the ShiftIndex $(k.t)")
194+
end
189195

190196
# d, _ = propagate_time_domain(xn)
191197
# if d != clock # this is only required if the variable has another clock

src/systems/clock_inference.jl

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,52 @@ function ClockInference(ts::TransformationState)
2121
ClockInference(ts, eq_domain, var_domain, inferred)
2222
end
2323

24+
struct NotInferedTimeDomain end
25+
function error_sample_time(eq)
26+
error("$eq\ncontains `SampleTime` but it is not an infered discrete equation.")
27+
end
28+
function substitute_sample_time(ci::ClockInference)
29+
@unpack ts, eq_domain = ci
30+
eqs = copy(equations(ts))
31+
@assert length(eqs) == length(eq_domain)
32+
for i in eachindex(eqs)
33+
eq = eqs[i]
34+
domain = eq_domain[i]
35+
dt = sampletime(domain)
36+
neweq = substitute_sample_time(eq, dt)
37+
if neweq isa NotInferedTimeDomain
38+
error_sample_time(eq)
39+
end
40+
eqs[i] = neweq
41+
end
42+
@set! ts.sys.eqs = eqs
43+
@set! ci.ts = ts
44+
end
45+
46+
function substitute_sample_time(eq::Equation, dt)
47+
substitute_sample_time(eq.lhs, dt) ~ substitute_sample_time(eq.rhs, dt)
48+
end
49+
50+
function substitute_sample_time(ex, dt)
51+
istree(ex) || return ex
52+
op = operation(ex)
53+
args = arguments(ex)
54+
if op == SampleTime
55+
dt === nothing && return NotInferedTimeDomain()
56+
return dt
57+
else
58+
new_args = similar(args)
59+
for (i, arg) in enumerate(args)
60+
ex_arg = substitute_sample_time(arg, dt)
61+
if ex_arg isa NotInferedTimeDomain
62+
return ex_arg
63+
end
64+
new_args[i] = ex_arg
65+
end
66+
similarterm(ex, op, new_args; metadata = metadata(ex))
67+
end
68+
end
69+
2470
function infer_clocks!(ci::ClockInference)
2571
@unpack ts, eq_domain, var_domain, inferred = ci
2672
@unpack var_to_diff, graph = ts.structure
@@ -66,6 +112,7 @@ function infer_clocks!(ci::ClockInference)
66112
end
67113
end
68114

115+
ci = substitute_sample_time(ci)
69116
return ci
70117
end
71118

src/systems/systemstructure.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -627,7 +627,7 @@ function structural_simplify!(state::TearingState, io = nothing; simplify = fals
627627
kwargs...)
628628
if state.sys isa ODESystem
629629
ci = ModelingToolkit.ClockInference(state)
630-
ModelingToolkit.infer_clocks!(ci)
630+
ci = ModelingToolkit.infer_clocks!(ci)
631631
time_domains = merge(Dict(state.fullvars .=> ci.var_domain),
632632
Dict(default_toterm.(state.fullvars) .=> ci.var_domain))
633633
tss, inputs, continuous_id, id_to_clock = ModelingToolkit.split_system(ci)

test/clock.jl

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,7 @@ using ModelingToolkitStandardLibrary.Blocks
330330

331331
dt = 0.05
332332
d = Clock(t, dt)
333-
k = ShiftIndex(d)
333+
k = ShiftIndex()
334334

335335
@mtkmodel DiscretePI begin
336336
@components begin
@@ -347,7 +347,7 @@ k = ShiftIndex(d)
347347
y(t)
348348
end
349349
@equations begin
350-
x(k) ~ x(k - 1) + ki * u(k)
350+
x(k) ~ x(k - 1) + ki * u(k) * SampleTime() / dt
351351
output.u(k) ~ y(k)
352352
input.u(k) ~ u(k)
353353
y(k) ~ x(k - 1) + kp * u(k)
@@ -364,21 +364,18 @@ end
364364
end
365365
end
366366

367-
@mtkmodel Holder begin
368-
@components begin
369-
input = RealInput()
370-
output = RealOutput()
371-
end
367+
@mtkmodel ZeroOrderHold begin
368+
@extend u, y = siso = Blocks.SISO()
372369
@equations begin
373-
output.u ~ Hold(input.u)
370+
y ~ Hold(u)
374371
end
375372
end
376373

377374
@mtkmodel ClosedLoop begin
378375
@components begin
379376
plant = FirstOrder(k = 0.3, T = 1)
380377
sampler = Sampler()
381-
holder = Holder()
378+
holder = ZeroOrderHold()
382379
controller = DiscretePI(kp = 2, ki = 2)
383380
feedback = Feedback()
384381
ref = Constant(k = 0.5)
@@ -444,7 +441,7 @@ prob = ODEProblem(ssys,
444441
[model.plant.x => 0.0; model.controller.kp => 2.0; model.controller.ki => 2.0],
445442
(0.0, Tf))
446443
int = init(prob, Tsit5(); kwargshandle = KeywordArgSilent)
447-
@test int.ps[Hold(ssys.holder.input.u)] == 2 # constant output * kp issue https://github.com/SciML/ModelingToolkit.jl/issues/2356
444+
@test_broken int.ps[Hold(ssys.holder.input.u)] == 2 # constant output * kp issue https://github.com/SciML/ModelingToolkit.jl/issues/2356
448445
@test int.ps[ssys.controller.x] == 1 # c2d
449446
@test int.ps[Sample(d)(ssys.sampler.input.u)] == 0 # disc state
450447
sol = solve(prob,

0 commit comments

Comments
 (0)