Skip to content

Commit b4246b8

Browse files
feat!: add MTKParameters struct, use as ODEProblem.p
1 parent ec687fd commit b4246b8

File tree

5 files changed

+125
-19
lines changed

5 files changed

+125
-19
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
3939
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
4040
RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47"
4141
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
42+
SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226"
4243
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
4344
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
4445
SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
@@ -99,6 +100,7 @@ SciMLBase = "2.0.1"
99100
Serialization = "1"
100101
Setfield = "0.7, 0.8, 1"
101102
SimpleNonlinearSolve = "0.1.0, 1"
103+
SciMLStructures = "1.0"
102104
SparseArrays = "1"
103105
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
104106
StaticArrays = "0.10, 0.11, 0.12, 1.0"

src/ModelingToolkit.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ using PrecompileTools, Reexport
3131
import Distributions
3232
import FunctionWrappersWrappers
3333
using URIs: URI
34+
using SciMLStructures
3435

3536
using RecursiveArrayTools
3637

@@ -132,6 +133,8 @@ include("systems/abstractsystem.jl")
132133
include("systems/model_parsing.jl")
133134
include("systems/connectors.jl")
134135
include("systems/callbacks.jl")
136+
include("systems/index_cache.jl")
137+
include("systems/parameter_buffer.jl")
135138

136139
include("systems/diffeqs/odesystem.jl")
137140
include("systems/diffeqs/sdesystem.jl")

src/clock.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,9 @@ true if `x` contains only discrete-domain signals.
8383
See also [`has_discrete_domain`](@ref)
8484
"""
8585
function is_discrete_domain(x)
86-
issym(x) && return getmetadata(x, TimeDomain, false) isa Discrete
86+
if hasmetadata(x, TimeDomain) || issym(x)
87+
return getmetadata(x, TimeDomain, false) isa AbstractDiscrete
88+
end
8789
!has_discrete_domain(x) && has_continuous_domain(x)
8890
end
8991

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -160,27 +160,26 @@ function generate_function(sys::AbstractODESystem, dvs = unknowns(sys), ps = par
160160

161161
# TODO: add an optional check on the ordering of observed equations
162162
u = map(x -> time_varying_as_func(value(x), sys), dvs)
163-
p = map(x -> time_varying_as_func(value(x), sys), ps)
163+
p = if has_index_cache(sys)
164+
reorder_parameters(get_index_cache(sys), ps)
165+
else
166+
(map(x -> time_varying_as_func(value(x), sys), ps),)
167+
end
164168
t = get_iv(sys)
165169

166170
if isdde
167-
build_function(rhss, u, DDE_HISTORY_FUN, p, t; kwargs...)
171+
build_function(rhss, u, DDE_HISTORY_FUN, p..., t; kwargs...)
168172
else
169173
pre, sol_states = get_substitutions_and_solved_unknowns(sys)
170174

171175
if implicit_dae
172-
build_function(rhss, ddvs, u, p, t; postprocess_fbody = pre,
176+
build_function(rhss, ddvs, u, p..., t; postprocess_fbody = pre,
173177
states = sol_states,
174178
kwargs...)
175179
else
176-
if p isa Tuple
177-
build_function(rhss, u, p..., t; postprocess_fbody = pre,
178-
states = sol_states,
179-
kwargs...)
180-
else
181-
build_function(rhss, u, p, t; postprocess_fbody = pre, states = sol_states,
182-
kwargs...)
183-
end
180+
build_function(rhss, u, p..., t; postprocess_fbody = pre,
181+
states = sol_states,
182+
kwargs...)
184183
end
185184
end
186185
end
@@ -321,6 +320,10 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = u
321320
g(u, p, t) = f_oop(u, p..., t)
322321
g(du, u, p, t) = f_iip(du, u, p..., t)
323322
f = g
323+
elseif p isa MTKParameters
324+
h(u, p, t) = f_oop(u, raw_vectors(p)..., t)
325+
h(du, u, p, t) = f_iip(du, u, raw_vectors(p)..., t)
326+
f = h
324327
else
325328
k(u, p, t) = f_oop(u, p, t)
326329
k(du, u, p, t) = f_iip(du, u, p, t)
@@ -758,7 +761,7 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
758761
ps = parameters(sys)
759762
iv = get_iv(sys)
760763

761-
u0, p, defs = get_u0_p(sys,
764+
u0, _, defs = get_u0_p(sys,
762765
u0map,
763766
parammap;
764767
tofloat,
@@ -768,11 +771,7 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
768771
u0 = u0_constructor(u0)
769772
end
770773

771-
p, split_idxs = split_parameters_by_type(p)
772-
if p isa Tuple
773-
ps = Base.Fix1(getindex, parameters(sys)).(split_idxs)
774-
ps = (ps...,) #if p is Tuple, ps should be Tuple
775-
end
774+
p = MTKParameters(sys, parammap; toterm = default_toterm)
776775

777776
if implicit_dae && du0map !== nothing
778777
ddvs = map(Differential(iv), dvs)
@@ -789,7 +788,7 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
789788
f = constructor(sys, dvs, ps, u0; ddvs = ddvs, tgrad = tgrad, jac = jac,
790789
checkbounds = checkbounds, p = p,
791790
linenumbers = linenumbers, parallel = parallel, simplify = simplify,
792-
sparse = sparse, eval_expression = eval_expression, split_idxs,
791+
sparse = sparse, eval_expression = eval_expression,
793792
kwargs...)
794793
implicit_dae ? (f, du0, u0, p) : (f, u0, p)
795794
end

src/systems/parameter_buffer.jl

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
struct MTKParameters{T, D}
2+
tunable::T
3+
discrete::D
4+
end
5+
6+
function MTKParameters(sys::AbstractSystem, p; toterm = default_toterm)
7+
ic = if has_index_cache(sys)
8+
get_index_cache(sys)
9+
else
10+
IndexCache(sys)
11+
end
12+
tunable_buffer = if length(ic.param_buffer_type_and_size) == 0
13+
Float64[]
14+
elseif length(ic.param_buffer_type_and_size) == 1
15+
T, sz = only(ic.param_buffer_type_and_size)
16+
Vector{T == Real ? Float64 : T}(undef, sz)
17+
else
18+
ArrayPartition((Vector{T == Real ? Float64 : T}(undef, sz) for (T, sz) in ic.param_buffer_type_and_size)...)
19+
end
20+
21+
disc_buffer = if length(ic.discrete_buffer_type_and_size) == 0
22+
Float64[]
23+
elseif length(ic.discrete_buffer_type_and_size) == 1
24+
T, sz = only(ic.discrete_buffer_type_and_size)
25+
Vector{T == Real ? Float64 : T}(undef, sz)
26+
else
27+
ArrayPartition((Vector{T == Real ? Float64 : T}(undef, sz) for (T, sz) in ic.discrete_buffer_type_and_size)...)
28+
end
29+
30+
for (sym, value) in defaults(sys)
31+
sym = toterm(unwrap(sym))
32+
h = hasmetadata(sym, SymbolHash) ? getmetadata(sym, SymbolHash) : hash(sym)
33+
if haskey(ic.discrete_idx, h)
34+
disc_buffer[ic.discrete_idx[h]] = value
35+
elseif haskey(ic.param_idx, h)
36+
tunable_buffer[ic.param_idx[h]] = value
37+
end
38+
end
39+
40+
if !isa(p, SciMLBase.NullParameters)
41+
for (sym, value) in p
42+
sym = toterm(unwrap(sym))
43+
h = hasmetadata(sym, SymbolHash) ? getmetadata(sym, SymbolHash) : hash(sym)
44+
if haskey(ic.discrete_idx, h)
45+
disc_buffer[ic.discrete_idx[h]] = value
46+
elseif haskey(ic.param_idx, h)
47+
tunable_buffer[ic.param_idx[h]] = value
48+
else
49+
error("Invalid parameter $sym")
50+
end
51+
end
52+
end
53+
54+
return MTKParameters{typeof(tunable_buffer), typeof(disc_buffer)}(tunable_buffer, disc_buffer)
55+
end
56+
57+
SciMLStructures.isscimlstructure(::MTKParameters) = true
58+
59+
SciMLStructures.ismutablescimlstructure(::MTKParameters) = true
60+
61+
for (Portion, field) in [
62+
(SciMLStructures.Tunable, :tunable)
63+
(SciMLStructures.Discrete, :discrete)
64+
]
65+
@eval function SciMLStructures.canonicalize(::$Portion, p::MTKParameters)
66+
function repack(values)
67+
p.$field .= values
68+
end
69+
return p.$field, repack, !isa(p.$field, ArrayPartition)
70+
end
71+
72+
@eval function SciMLStructures.replace(::$Portion, p::MTKParameters, newvals)
73+
new_field = similar(p.$field)
74+
new_field .= newvals
75+
@set p.$field = new_field
76+
end
77+
78+
@eval function SciMLStructures.replace!(::$Portion, p::MTKParameters, newvals)
79+
p.$field .= newvals
80+
nothing
81+
end
82+
end
83+
84+
function raw_vectors(buf::MTKParameters)
85+
tunable = if isempty(buf.tunable)
86+
()
87+
elseif buf.tunable isa ArrayPartition
88+
buf.tunable.x
89+
else
90+
(buf.tunable,)
91+
end
92+
discrete = if isempty(buf.discrete)
93+
()
94+
elseif buf.discrete isa ArrayPartition
95+
buf.discrete.x
96+
else
97+
(buf.discrete,)
98+
end
99+
return (tunable..., discrete...)
100+
end

0 commit comments

Comments
 (0)