Skip to content

Commit 77d57f0

Browse files
authored
Merge pull request #186 from SciML/bgc/ae_hack
`Parameter` type optimizations
2 parents 588470a + 46985fc commit 77d57f0

File tree

7 files changed

+128
-28
lines changed

7 files changed

+128
-28
lines changed

src/Blocks/sources.jl

Lines changed: 54 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -428,11 +428,10 @@ end
428428
struct Parameter{T <: Real}
429429
data::Vector{T}
430430
ref::T
431-
n::Int
432431
end
433432

434433
function Base.isequal(x::Parameter, y::Parameter)
435-
b0 = x.n == y.n
434+
b0 = length(x.data) == length(y.data)
436435
if b0
437436
b1 = all(x.data .== y.data)
438437
b2 = x.ref == y.ref
@@ -465,7 +464,7 @@ Base.:^(x::Parameter, y::Parameter) = Base.power_by_squaring(x.ref, y.ref)
465464
Base.isless(x::Parameter, y::Number) = Base.isless(x.ref, y)
466465
Base.isless(y::Number, x::Parameter) = Base.isless(y, x.ref)
467466

468-
Base.copy(x::Parameter{T}) where {T} = Parameter{T}(copy(x.data), x.ref, x.n)
467+
Base.copy(x::Parameter{T}) where {T} = Parameter{T}(copy(x.data), x.ref)
469468

470469
function Base.show(io::IO, m::MIME"text/plain", p::Parameter)
471470
if !isempty(p.data)
@@ -484,9 +483,8 @@ function Parameter(x::T; tofloat = true) where {T <: Real}
484483
P = T
485484
end
486485

487-
return Parameter(P[], x, 0)
486+
return Parameter(P[], x)
488487
end
489-
Parameter(x::Vector{T}, dt::T) where {T <: Real} = Parameter(x, dt, length(x))
490488

491489
function get_sampled_data(t, memory::Parameter{T}) where {T}
492490
if t < 0
@@ -505,18 +503,19 @@ function get_sampled_data(t, memory::Parameter{T}) where {T}
505503
i2 = i1 + 1
506504

507505
t1 = (i1 - 1) * memory.ref
508-
x1 = @inbounds getindex(memory.data, i1)
506+
x1 = @inbounds memory.data[i1]
509507

510508
if t == t1
511509
return x1
512510
else
513-
if i2 > memory.n
514-
i2 = memory.n
511+
n = length(memory.data)
512+
if i2 > n
513+
i2 = n
515514
i1 = i2 - 1
516515
end
517516

518517
t2 = (i2 - 1) * memory.ref
519-
x2 = @inbounds getindex(memory.data, i2)
518+
x2 = @inbounds memory.data[i2]
520519
return linear_interpolation(x1, x2, t1, t2, t)
521520
end
522521
end
@@ -535,10 +534,13 @@ function first_order_backwards_difference(t, memory)
535534
end
536535

537536
function Symbolics.derivative(::typeof(get_sampled_data), args::NTuple{2, Any}, ::Val{1})
538-
first_order_backwards_difference(args[1], args[2])
537+
t = @inbounds args[1]
538+
memory = @inbounds args[2]
539+
first_order_backwards_difference(t, memory)
539540
end
540541

541542
SampledData(T::Type; name) = SampledData(T[], zero(T); name)
543+
SampledData(dt::T) where {T <: Real} = SampledData(T[], dt; name)
542544
function SampledData(data::Vector{T}, dt::T; name) where {T <: Real}
543545
SampledData(; name, buffer = Parameter(data, dt))
544546
end
@@ -571,3 +573,45 @@ end
571573
@deprecate Input SampledData
572574

573575
Base.convert(::Type{T}, x::Parameter{T}) where {T <: Real} = x.ref
576+
577+
# Beta Code for potential AE Hack ----------------------
578+
function set_sampled_data!(memory::Parameter{T}, t, x, Δt::Parameter{T}) where {T}
579+
if t < 0
580+
t = zero(t)
581+
end
582+
583+
if t == zero(t)
584+
empty!(memory.data)
585+
end
586+
587+
n = length(memory.data)
588+
i = round(Int, t / Δt) + 1 #expensive
589+
if i == n + 1
590+
push!(memory.data, x)
591+
elseif i <= n
592+
@inbounds memory.data[i] = x
593+
else
594+
error("Memory buffer skipped a step: n=$n, i=$i")
595+
end
596+
597+
# memory.ref = Δt
598+
599+
return x
600+
end
601+
Symbolics.@register_symbolic set_sampled_data!(memory, t, x, Δt)
602+
603+
function Symbolics.derivative(::typeof(set_sampled_data!), args::NTuple{4, Any}, ::Val{2})
604+
memory = @inbounds args[1]
605+
t = @inbounds args[2]
606+
x = @inbounds args[3]
607+
Δt = @inbounds args[4]
608+
first_order_backwards_difference(t, x, Δt, memory)
609+
end
610+
Symbolics.derivative(::typeof(set_sampled_data!), args::NTuple{4, Any}, ::Val{3}) = 1 #set_sampled_data returns x, therefore d/dx (x) = 1
611+
612+
function first_order_backwards_difference(t, x, Δt, memory)
613+
x1 = set_sampled_data!(memory, t, x, Δt)
614+
x0 = get_sampled_data(t - Δt, memory)
615+
616+
return (x1 - x0) / Δt
617+
end

src/Hydraulic/IsothermalCompressible/components.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -506,7 +506,7 @@ dm ────► │ │ area
506506

507507
ports = @named begin
508508
port = HydraulicPort(; p_int)
509-
flange = MechanicalPort()
509+
flange = MechanicalPort(; f_int = p_int*area)
510510
damper = ValveBase(; p_a_int = p_int, p_b_int = p_int, area_int = 1, Cd,
511511
Cd_reverse, minimum_area)
512512
end
@@ -528,8 +528,7 @@ dm ────► │ │ area
528528
Δx = ParentScope(x_max) / N
529529
x₀ = ParentScope(x_int)
530530

531-
@named moving_volume = VolumeBase(; p_int, x_int = 0, area, dead_volume = 0, Χ1 = 0,
532-
Χ2 = 1)
531+
@named moving_volume = VolumeBase(; p_int, x_int = 0, area, dead_volume = 0, Χ1 = 0, Χ2 = 1)
533532

534533
volumes = []
535534
for i in 1:N
@@ -582,6 +581,8 @@ dm ────► │ │ area
582581
defaults = [flange.v => 0])
583582
end
584583

584+
585+
585586
@component function SpoolValve(reversible = false; p_a_int, p_b_int, x_int, Cd, d, name)
586587
pars = @parameters begin
587588
p_a_int = p_a_int

src/Mechanical/Translational/components.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
11
"""
22
Free(; name)
33
4-
Use to close a system that has un-connected `MechanicalPort`'s.
4+
Use to close a system that has un-connected `MechanicalPort`'s where the force should not be zero (i.e. you want to solve for the force to produce the given movement of the port)
55
66
# Connectors:
77
88
- `flange`: 1-dim. translational flange
99
"""
1010
@component function Free(; name)
1111
@named flange = MechanicalPort()
12-
vars = []
12+
vars = @variables f(t) = 0
1313
eqs = [
14-
flange.f ~ 0,
14+
flange.f ~ f
1515
]
1616
return compose(ODESystem(eqs, t, vars, []; name, defaults = [flange.v => 0]),
1717
flange)

src/Mechanical/Translational/sources.jl

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ Linear 1D position input source
3636
- `flange`: 1-dim. translational flange
3737
- `s`: real input
3838
"""
39-
@component function Position(; s_0 = 0, name)
39+
@component function Position(solves_force=true; s_0 = 0, name)
4040
systems = @named begin
4141
flange = MechanicalPort()
4242
s = RealInput()
@@ -45,8 +45,50 @@ Linear 1D position input source
4545
pars = @parameters s_0 = s_0
4646
vars = @variables x(t) = s_0
4747

48-
eqs = [D(x) ~ flange.v
49-
s.u ~ x]
48+
eqs = [
49+
D(x) ~ flange.v
50+
s.u ~ x
51+
]
52+
53+
!solves_force && push!(eqs, 0 ~ flange.f)
5054

5155
ODESystem(eqs, t, vars, pars; name, systems, defaults = [flange.v => 0, s.u => s_0])
5256
end
57+
58+
59+
@component function Velocity(solves_force=true; name)
60+
systems = @named begin
61+
flange = MechanicalPort()
62+
v = RealInput()
63+
end
64+
65+
pars = []
66+
vars = []
67+
68+
eqs = [
69+
v.u ~ flange.v
70+
]
71+
72+
!solves_force && push!(eqs, 0 ~ flange.f)
73+
74+
ODESystem(eqs, t, vars, pars; name, systems, defaults = [flange.v => 0])
75+
end
76+
77+
78+
@component function Acceleration(solves_force=true; s_0 = 0, name)
79+
systems = @named begin
80+
flange = MechanicalPort()
81+
a = RealInput()
82+
end
83+
84+
pars = []
85+
vars = @variables v(t) = 0
86+
87+
eqs = [
88+
v ~ flange.v
89+
D(v) ~ a.u]
90+
91+
!solves_force && push!(eqs, 0 ~ flange.f)
92+
93+
ODESystem(eqs, t, vars, pars; name, systems, defaults = [flange.v => 0])
94+
end

src/Mechanical/Translational/utils.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1-
@connector function MechanicalPort(; name)
2-
pars = []
1+
@connector function MechanicalPort(; name, f_int=0, v_int=0)
2+
pars = @parameters begin
3+
f_int = f_int
4+
v_int = v_int
5+
end
36
vars = @variables begin
4-
v(t)
7+
v(t) = v_int
58
f(t), [connect = Flow]
69
end
7-
ODESystem(Equation[], t, vars, pars, name = name, defaults = [f => 0])
10+
ODESystem(Equation[], t, vars, pars; name, defaults=[f=>f_int])
811
end
912
Base.@doc """
1013
MechanicalPort(;name)

test/Blocks/sources.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ using ModelingToolkitStandardLibrary.Blocks: smooth_sin, smooth_cos, smooth_damp
66
using OrdinaryDiffEq: ReturnCode.Success
77

88
@parameters t
9+
D = Differential(t)
910

1011
@testset "Constant" begin
1112
@named src = Constant(k = 2)
@@ -413,11 +414,14 @@ end
413414
time = 0:dt:t_end
414415
x = @. time^2 + 1.0
415416

417+
vars = @variables y(t)=1 dy(t)=0 ddy(t)=0
418+
416419
@named src = SampledData(Float64)
417420
@named int = Integrator()
418-
@named iosys = ODESystem([
419-
connect(src.output, int.input),
420-
],
421+
@named iosys = ODESystem([y ~ src.output.u
422+
D(y) ~ dy
423+
D(dy) ~ ddy
424+
connect(src.output, int.input)],
421425
t,
422426
systems = [int, src])
423427
sys = structural_simplify(iosys)
@@ -432,4 +436,6 @@ end
432436

433437
@test sol(time)[src.output.u]x atol=1e-3
434438
@test sol[int.output.u][end]1 / 3 * 10^3 + 10.0 atol=1e-3 # closed-form solution to integral
439+
@test sol[dy][end]2 * time[end] atol=1e-3
440+
@test sol[ddy][end]2 atol=1e-3
435441
end

test/Mechanical/translational.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,14 @@ D = Differential(t)
1010
@testset "Free" begin
1111
function System(; name)
1212
systems = @named begin
13-
mass = TV.Mass(; m = 100, g = -10)
13+
acc = TV.Acceleration(false)
14+
a = Constant(k = -10)
15+
mass = TV.Mass(; m = 100)
1416
free = TV.Free()
1517
end
1618

17-
eqs = [connect(mass.flange, free.flange)]
19+
eqs = [connect(a.output, acc.a)
20+
connect(mass.flange, acc.flange, free.flange)]
1821

1922
ODESystem(eqs, t, [], []; name, systems)
2023
end
@@ -26,6 +29,7 @@ D = Differential(t)
2629
sol = solve(prob, Rosenbrock23())
2730

2831
@test sol[s.mass.flange.v][end]-0.1 * 10 atol=1e-3
32+
@test sol[s.free.f][end] 100 * 10
2933
end
3034

3135
@testset "spring damper mass fixed" begin

0 commit comments

Comments
 (0)