Skip to content

Commit c6a40c0

Browse files
Merge remote-tracking branch 'origin/master' into unkwnsarray
2 parents d2684e6 + 335c7ba commit c6a40c0

File tree

3 files changed

+117
-26
lines changed

3 files changed

+117
-26
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb"
119119
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
120120
Ipopt = "b6b21f68-93f8-5de0-b562-5493be1d77c9"
121121
Ipopt_jll = "9cc047cb-c261-5740-88fc-0cf96f7bdcc7"
122+
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
122123
ModelingToolkitStandardLibrary = "16a59e39-deab-5bd0-87e4-056b12336739"
123124
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
124125
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
@@ -137,4 +138,4 @@ Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
137138
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
138139

139140
[targets]
140-
test = ["AmplNLWriter", "BenchmarkTools", "ControlSystemsBase", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg"]
141+
test = ["AmplNLWriter", "BenchmarkTools", "ControlSystemsBase", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg", "JET"]

src/systems/parameter_buffer.jl

Lines changed: 43 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -155,35 +155,53 @@ function MTKParameters(
155155
end
156156

157157
function buffer_to_arraypartition(buf)
158-
return ArrayPartition(Tuple(eltype(v) <: AbstractArray ? buffer_to_arraypartition(v) :
159-
v for v in buf))
158+
return ArrayPartition(ntuple(i -> _buffer_to_arrp_helper(buf[i]), Val(length(buf))))
160159
end
161160

162-
function split_into_buffers(raw::AbstractArray, buf; recurse = true)
163-
idx = 1
164-
function _helper(buf_v; recurse = true)
165-
if eltype(buf_v) <: AbstractArray && recurse
166-
return _helper.(buf_v; recurse = false)
167-
else
168-
res = reshape(raw[idx:(idx + length(buf_v) - 1)], size(buf_v))
169-
idx += length(buf_v)
170-
return res
171-
end
172-
end
173-
return Tuple(_helper(buf_v; recurse) for buf_v in buf)
161+
_buffer_to_arrp_helper(v::T) where {T} = _buffer_to_arrp_helper(eltype(T), v)
162+
_buffer_to_arrp_helper(::Type{<:AbstractArray}, v) = buffer_to_arraypartition(v)
163+
_buffer_to_arrp_helper(::Any, v) = v
164+
165+
function _split_helper(buf_v::T, recurse, raw, idx) where {T}
166+
_split_helper(eltype(T), buf_v, recurse, raw, idx)
167+
end
168+
169+
function _split_helper(::Type{<:AbstractArray}, buf_v, ::Val{true}, raw, idx)
170+
map(b -> _split_helper(eltype(b), b, Val(false), raw, idx), buf_v)
171+
end
172+
173+
function _split_helper(::Type{<:AbstractArray}, buf_v, ::Val{false}, raw, idx)
174+
_split_helper((), buf_v, (), raw, idx)
175+
end
176+
177+
function _split_helper(_, buf_v, _, raw, idx)
178+
res = reshape(raw[idx[]:(idx[] + length(buf_v) - 1)], size(buf_v))
179+
idx[] += length(buf_v)
180+
return res
181+
end
182+
183+
function split_into_buffers(raw::AbstractArray, buf, recurse = Val(true))
184+
idx = Ref(1)
185+
ntuple(i -> _split_helper(buf[i], recurse, raw, idx), Val(length(buf)))
186+
end
187+
188+
function _update_tuple_helper(buf_v::T, raw, idx) where {T}
189+
_update_tuple_helper(eltype(T), buf_v, raw, idx)
190+
end
191+
192+
function _update_tuple_helper(::Type{<:AbstractArray}, buf_v, raw, idx)
193+
ntuple(i -> _update_tuple_helper(buf_v[i], raw, idx), Val(length(buf_v)))
194+
end
195+
196+
function _update_tuple_helper(::Any, buf_v, raw, idx)
197+
copyto!(buf_v, view(raw, idx[]:(idx[] + length(buf_v) - 1)))
198+
idx[] += length(buf_v)
199+
return nothing
174200
end
175201

176202
function update_tuple_of_buffers(raw::AbstractArray, buf)
177-
idx = 1
178-
function _helper(buf_v)
179-
if eltype(buf_v) <: AbstractArray
180-
_helper.(buf_v)
181-
else
182-
copyto!(buf_v, view(raw, idx:(idx + length(buf_v) - 1)))
183-
idx += length(buf_v)
184-
end
185-
end
186-
_helper.(buf)
203+
idx = Ref(1)
204+
ntuple(i -> _update_tuple_helper(buf[i], raw, idx), Val(length(buf)))
187205
end
188206

189207
SciMLStructures.isscimlstructure(::MTKParameters) = true
@@ -213,7 +231,7 @@ for (Portion, field) in [(SciMLStructures.Tunable, :tunable)
213231
@set! p.$field = split_into_buffers(newvals, p.$field)
214232
if p.dependent_update_oop !== nothing
215233
raw = p.dependent_update_oop(p...)
216-
@set! p.dependent = split_into_buffers(raw, p.dependent; recurse = false)
234+
@set! p.dependent = split_into_buffers(raw, p.dependent, Val(false))
217235
end
218236
p
219237
end

test/mtkparameters.jl

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using SymbolicIndexingInterface
44
using SciMLStructures: SciMLStructures, canonicalize, Tunable, Discrete, Constants
55
using OrdinaryDiffEq
66
using ForwardDiff
7+
using JET
78

89
@parameters a b c d::Integer e[1:3] f[1:3, 1:3]::Int g::Vector{AbstractFloat} h::String
910
@named sys = ODESystem(
@@ -136,6 +137,77 @@ ps = [p => 1.0] # Value for `d` is missing
136137
@test_throws ModelingToolkit.MissingVariablesError ODEProblem(sys, u0, tspan, ps)
137138
@test_nowarn ODEProblem(sys, u0, tspan, [ps..., d => 1.0])
138139

140+
# JET tests
141+
142+
# scalar parameters only
143+
function level1()
144+
@parameters p1=0.5 [tunable = true] p2=1 [tunable = true] p3=3 [tunable = false] p4=3 [tunable = true] y0=1
145+
@variables x(t)=2 y(t)=y0
146+
D = Differential(t)
147+
148+
eqs = [D(x) ~ p1 * x - p2 * x * y
149+
D(y) ~ -p3 * y + p4 * x * y]
150+
151+
sys = structural_simplify(complete(ODESystem(
152+
eqs, t, tspan = (0, 3.0), name = :sys, parameter_dependencies = [y0 => 2p4])))
153+
prob = ODEProblem{true, SciMLBase.FullSpecialize}(sys)
154+
end
155+
156+
# scalar and vector parameters
157+
function level2()
158+
@parameters p1=0.5 [tunable = true] (p23[1:2]=[1, 3.0]) [tunable = true] p4=3 [tunable = false] y0=1
159+
@variables x(t)=2 y(t)=y0
160+
D = Differential(t)
161+
162+
eqs = [D(x) ~ p1 * x - p23[1] * x * y
163+
D(y) ~ -p23[2] * y + p4 * x * y]
164+
165+
sys = structural_simplify(complete(ODESystem(
166+
eqs, t, tspan = (0, 3.0), name = :sys, parameter_dependencies = [y0 => 2p4])))
167+
prob = ODEProblem{true, SciMLBase.FullSpecialize}(sys)
168+
end
169+
170+
# scalar and vector parameters with different scalar types
171+
function level3()
172+
@parameters p1=0.5 [tunable = true] (p23[1:2]=[1, 3.0]) [tunable = true] p4::Int=3 [tunable = true] y0::Int=1
173+
@variables x(t)=2 y(t)=y0
174+
D = Differential(t)
175+
176+
eqs = [D(x) ~ p1 * x - p23[1] * x * y
177+
D(y) ~ -p23[2] * y + p4 * x * y]
178+
179+
sys = structural_simplify(complete(ODESystem(
180+
eqs, t, tspan = (0, 3.0), name = :sys, parameter_dependencies = [y0 => 2p4])))
181+
prob = ODEProblem{true, SciMLBase.FullSpecialize}(sys)
182+
end
183+
184+
@testset "level$i" for (i, prob) in enumerate([level1(), level2(), level3()])
185+
ps = prob.p
186+
@testset "Type stability of $portion" for portion in [
187+
Tunable(), Discrete(), Constants()]
188+
@test_call canonicalize(portion, ps)
189+
# @inferred canonicalize(portion, ps)
190+
broken = (i [2, 3] && portion == Tunable())
191+
192+
# broken because the size of a vector of vectors can't be determined at compile time
193+
@test_opt broken=broken target_modules=(ModelingToolkit,) canonicalize(
194+
portion, ps)
195+
196+
buffer, repack, alias = canonicalize(portion, ps)
197+
198+
@test_call SciMLStructures.replace(portion, ps, ones(length(buffer)))
199+
@inferred SciMLStructures.replace(portion, ps, ones(length(buffer)))
200+
@test_opt target_modules=(ModelingToolkit,) SciMLStructures.replace(
201+
portion, ps, ones(length(buffer)))
202+
203+
@test_call target_modules=(ModelingToolkit,) SciMLStructures.replace!(
204+
portion, ps, ones(length(buffer)))
205+
@inferred SciMLStructures.replace!(portion, ps, ones(length(buffer)))
206+
@test_opt target_modules=(ModelingToolkit,) SciMLStructures.replace!(
207+
portion, ps, ones(length(buffer)))
208+
end
209+
end
210+
139211
# Issue#2642
140212
@parameters α β γ δ
141213
@variables x(t) y(t)

0 commit comments

Comments
 (0)