1
- struct MTKParameters{T, D}
1
+ struct MTKParameters{T, D, C, E, F }
2
2
tunable:: T
3
3
discrete:: D
4
+ constant:: C
5
+ dependent:: E
6
+ dependent_update:: F
4
7
end
5
8
6
9
function MTKParameters (sys:: AbstractSystem , p; toterm = default_toterm, tofloat = false , use_union = false )
@@ -10,7 +13,7 @@ function MTKParameters(sys::AbstractSystem, p; toterm = default_toterm, tofloat
10
13
error (" Cannot create MTKParameters if system does not have index_cache" )
11
14
end
12
15
all_ps = Set (unwrap .(parameters (sys)))
13
- if p isa Vector && ! (eltype (p) <: Pair )
16
+ if p isa Vector && ! (eltype (p) <: Pair ) && ! isempty (p)
14
17
ps = parameters (sys)
15
18
length (p) == length (ps) || error (" Invalid parameters" )
16
19
p = ps .=> p
@@ -24,12 +27,20 @@ function MTKParameters(sys::AbstractSystem, p; toterm = default_toterm, tofloat
24
27
25
28
tunable_buffer = ArrayPartition ((Vector {temp.type} (undef, temp. length) for temp in ic. param_buffer_sizes). .. )
26
29
disc_buffer = ArrayPartition ((Vector {temp.type} (undef, temp. length) for temp in ic. discrete_buffer_sizes). .. )
30
+ const_buffer = ArrayPartition ((Vector {temp.type} (undef, temp. length) for temp in ic. constant_buffer_sizes). .. )
31
+ dep_buffer = ArrayPartition ((Vector {temp.type} (undef, temp. length) for temp in ic. dependent_buffer_sizes). .. )
32
+ dependencies = Dict {Num, Num} ()
27
33
function set_value (sym, val)
28
34
h = getsymbolhash (sym)
29
35
if haskey (ic. param_idx, h)
30
36
tunable_buffer[ic. param_idx[h]] = val
31
37
elseif haskey (ic. discrete_idx, h)
32
38
disc_buffer[ic. discrete_idx[h]] = val
39
+ elseif haskey (ic. constant_idx, h)
40
+ const_buffer[ic. constant_idx[h]] = val
41
+ elseif haskey (ic. dependent_idx, h)
42
+ dep_buffer[ic. dependent_idx[h]] = val
43
+ dependencies[wrap (sym)] = wrap (p[sym])
33
44
end
34
45
end
35
46
@@ -49,22 +60,46 @@ function MTKParameters(sys::AbstractSystem, p; toterm = default_toterm, tofloat
49
60
end
50
61
end
51
62
63
+ dep_exprs = ArrayPartition ((wrap .(v) for v in dep_buffer. x). .. )
64
+ for (sym, val) in dependencies
65
+ h = getsymbolhash (sym)
66
+ idx = ic. dependent_idx[h]
67
+ dep_exprs[idx] = wrap (fixpoint_sub (val, dependencies))
68
+ end
69
+ p = reorder_parameters (ic, parameters (sys))[begin : end - length (dep_buffer. x)]
70
+ update_function = if isempty (dep_exprs. x)
71
+ (_... ) -> ()
72
+ else
73
+ RuntimeGeneratedFunctions. @RuntimeGeneratedFunction (build_function (dep_exprs, p... )[2 ])
74
+ end
52
75
# everything is an ArrayPartition so it's easy to figure out how many
53
76
# distinct vectors we have for each portion as `ArrayPartition.x`
54
- if tunable_buffer isa ArrayPartition && isempty (tunable_buffer. x) || isempty (tunable_buffer )
77
+ if isempty (tunable_buffer. x)
55
78
tunable_buffer = ArrayPartition (Float64[])
56
79
end
57
- if disc_buffer isa ArrayPartition && isempty (disc_buffer. x) || isempty (disc_buffer )
80
+ if isempty (disc_buffer. x)
58
81
disc_buffer = ArrayPartition (Float64[])
59
82
end
83
+ if isempty (const_buffer. x)
84
+ const_buffer = ArrayPartition (Float64[])
85
+ end
86
+ if isempty (dep_buffer. x)
87
+ dep_buffer = ArrayPartition (Float64[])
88
+ end
60
89
if use_union
61
90
tunable_buffer = ArrayPartition (restrict_array_to_union (tunable_buffer))
62
91
disc_buffer = ArrayPartition (restrict_array_to_union (disc_buffer))
92
+ const_buffer = ArrayPartition (restrict_array_to_union (const_buffer))
93
+ dep_buffer = ArrayPartition (restrict_array_to_union (dep_buffer))
63
94
elseif tofloat
64
95
tunable_buffer = ArrayPartition (Float64 .(tunable_buffer))
65
96
disc_buffer = ArrayPartition (Float64 .(disc_buffer))
97
+ const_buffer = ArrayPartition (Float64 .(const_buffer))
98
+ dep_buffer = ArrayPartition (Float64 .(dep_buffer))
66
99
end
67
- return MTKParameters {typeof(tunable_buffer), typeof(disc_buffer)} (tunable_buffer, disc_buffer)
100
+ return MTKParameters{typeof (tunable_buffer), typeof (disc_buffer), typeof (const_buffer),
101
+ typeof (dep_buffer), typeof (update_function)}(tunable_buffer,
102
+ disc_buffer, const_buffer, dep_buffer, update_function)
68
103
end
69
104
70
105
SciMLStructures. isscimlstructure (:: MTKParameters ) = true
@@ -74,20 +109,24 @@ SciMLStructures.ismutablescimlstructure(::MTKParameters) = true
74
109
for (Portion, field) in [
75
110
(SciMLStructures. Tunable, :tunable )
76
111
(SciMLStructures. Discrete, :discrete )
112
+ (SciMLStructures. Constants, :constant )
77
113
]
78
114
@eval function SciMLStructures. canonicalize (:: $Portion , p:: MTKParameters )
79
115
function repack (values)
80
116
p.$ field .= values
117
+ p. dependent_update (p. dependent, p. tunable. x... , p. discrete. x... , p. constant. x... )
81
118
end
82
119
return p.$ field, repack, true
83
120
end
84
121
85
122
@eval function SciMLStructures. replace (:: $Portion , p:: MTKParameters , newvals)
86
- @set p.$ field = newvals
123
+ @set! p.$ field = newvals
124
+ p. dependent_update (p. dependent, p. tunable. x... , p. discrete. x... , p. constant. x... )
87
125
end
88
126
89
127
@eval function SciMLStructures. replace! (:: $Portion , p:: MTKParameters , newvals)
90
128
p.$ field .= newvals
129
+ p. dependent_update (p. dependent, p. tunable. x... , p. discrete. x... , p. constant. x... )
91
130
nothing
92
131
end
93
132
end
@@ -98,6 +137,10 @@ function SymbolicIndexingInterface.parameter_values(p::MTKParameters, i::Paramet
98
137
return p. tunable[idx]
99
138
elseif portion isa SciMLStructures. Discrete
100
139
return p. discrete[idx]
140
+ elseif portion isa SciMLStructures. Constants
141
+ return p. constant[idx]
142
+ elseif portion === nothing
143
+ return p. dependent[idx]
101
144
else
102
145
error (" Unhandled portion $portion " )
103
146
end
@@ -109,51 +152,62 @@ function SymbolicIndexingInterface.set_parameter!(p::MTKParameters, val, idx::Pa
109
152
p. tunable[idx] = val
110
153
elseif portion isa SciMLStructures. Discrete
111
154
p. discrete[idx] = val
155
+ elseif portion isa SciMLStructures. Constants
156
+ p. constant[idx] = val
157
+ elseif portion === nothing
158
+ error (" Cannot set value of parameter: " )
112
159
else
113
160
error (" Unhandled portion $portion " )
114
161
end
162
+ p. dependent_update (p. dependent, p. tunable. x... , p. discrete. x... , p. constant. x... )
115
163
end
116
164
117
165
# for compiling callbacks
118
166
# getindex indexes the vectors, setindex! linearly indexes values
119
167
# it's inconsistent, but we need it to be this way
120
168
function Base. getindex (buf:: MTKParameters , i)
121
- if i <= length (buf. tunable. x)
122
- buf. tunable. x[i]
123
- else
124
- buf. discrete. x[i - length (buf. tunable. x)]
169
+ if ! isempty (buf. tunable)
170
+ i <= length (buf. tunable. x) && return buf. tunable. x[i]
171
+ i -= length (buf. tunable. x)
172
+ end
173
+ if ! isempty (buf. discrete)
174
+ i <= length (buf. discrete. x) && return buf. discrete. x[i]
175
+ i -= length (buf. discrete. x)
176
+ end
177
+ if ! isempty (buf. constant)
178
+ i <= length (buf. constant. x) && return buf. constant. x[i]
179
+ i -= length (buf. constant. x)
125
180
end
181
+ isempty (buf. dependent) || return buf. dependent. x[i]
182
+ throw (BoundsError (buf, i))
126
183
end
127
184
function Base. setindex! (buf:: MTKParameters , val, i)
128
185
if i <= length (buf. tunable)
129
186
buf. tunable[i] = val
130
- else
187
+ elseif i <= length (buf . tunable) + length (buf . discrete)
131
188
buf. discrete[i - length (buf. tunable)] = val
189
+ else
190
+ buf. constant[i - length (buf. tunable) - length (buf. discrete)] = val
132
191
end
192
+ buf. dependent_update (p. dependent, p. tunable. x... , p. discrete. x... , p. constant. x... )
133
193
end
134
194
135
195
function Base. iterate (buf:: MTKParameters , state = 1 )
136
- tunable = if isempty (buf. tunable)
137
- ()
138
- elseif buf. tunable isa ArrayPartition
139
- buf. tunable. x
140
- end
141
- discrete = if isempty (buf. discrete)
142
- ()
143
- elseif buf. discrete isa ArrayPartition
144
- buf. discrete. x
145
- end
146
- if state <= length (tunable)
147
- return (tunable[state], state + 1 )
148
- elseif state <= length (tunable) + length (discrete)
149
- return (discrete[state - length (tunable)], state + 1 )
196
+ total_len = 0
197
+ isempty (buf. tunable) || (total_len += length (buf. tunable. x))
198
+ isempty (buf. discrete) || (total_len += length (buf. discrete. x))
199
+ isempty (buf. constant) || (total_len += length (buf. constant. x))
200
+ isempty (buf. dependent) || (total_len += length (buf. dependent. x))
201
+ if state <= total_len
202
+ return (buf[state], state + 1 )
150
203
else
151
204
return nothing
152
205
end
153
206
end
154
207
155
208
function Base.:(== )(a:: MTKParameters , b:: MTKParameters )
156
- return a. tunable == b. tunable && a. discrete == b. discrete
209
+ return a. tunable == b. tunable && a. discrete == b. discrete &&
210
+ a. constant == b. constant && a. dependent == b. dependent
157
211
end
158
212
159
213
# to support linearize/linearization_function
0 commit comments