1
1
symconvert (:: Type{Symbolics.Struct{T}} , x) where {T} = convert (T, x)
2
2
symconvert (:: Type{T} , x) where {T} = convert (T, x)
3
+ symconvert (:: Type{Real} , x:: Integer ) = convert (Float64, x)
4
+ symconvert (:: Type{V} , x) where {V <: AbstractArray } = convert (V, symconvert .(eltype (V), x))
5
+
3
6
struct MTKParameters{T, D, C, E, N, F, G}
4
7
tunable:: T
5
8
discrete:: D
@@ -107,7 +110,7 @@ function MTKParameters(
107
110
for (sym, val) in p
108
111
sym = unwrap (sym)
109
112
val = unwrap (val)
110
- ctype = concrete_symtype (sym)
113
+ ctype = symtype (sym)
111
114
if symbolic_type (val) != = NotSymbolic ()
112
115
continue
113
116
end
@@ -126,19 +129,27 @@ function MTKParameters(
126
129
end
127
130
end
128
131
end
132
+ tunable_buffer = narrow_buffer_type .(tunable_buffer)
133
+ disc_buffer = narrow_buffer_type .(disc_buffer)
134
+ const_buffer = narrow_buffer_type .(const_buffer)
135
+ nonnumeric_buffer = narrow_buffer_type .(nonnumeric_buffer)
129
136
130
137
if has_parameter_dependencies (sys) &&
131
138
(pdeps = get_parameter_dependencies (sys)) != = nothing
132
139
pdeps = Dict (k => fixpoint_sub (v, pdeps) for (k, v) in pdeps)
133
- dep_exprs = ArrayPartition ((wrap . (v) for v in dep_buffer). .. )
140
+ dep_exprs = ArrayPartition ((Any[ missing for _ in 1 : length (v)] for v in dep_buffer). .. )
134
141
for (sym, val) in pdeps
135
142
i, j = ic. dependent_idx[sym]
136
143
dep_exprs. x[i][j] = wrap (val)
137
144
end
145
+ dep_exprs = identity .(dep_exprs)
138
146
p = reorder_parameters (ic, full_parameters (sys))
139
147
oop, iip = build_function (dep_exprs, p... )
140
148
update_function_iip, update_function_oop = RuntimeGeneratedFunctions. @RuntimeGeneratedFunction (iip),
141
149
RuntimeGeneratedFunctions. @RuntimeGeneratedFunction (oop)
150
+ update_function_iip (ArrayPartition (dep_buffer), tunable_buffer... , disc_buffer... ,
151
+ const_buffer... , nonnumeric_buffer... , dep_buffer... )
152
+ dep_buffer = narrow_buffer_type .(dep_buffer)
142
153
else
143
154
update_function_iip = update_function_oop = nothing
144
155
end
@@ -148,12 +159,26 @@ function MTKParameters(
148
159
typeof (dep_buffer), typeof (nonnumeric_buffer), typeof (update_function_iip),
149
160
typeof (update_function_oop)}(tunable_buffer, disc_buffer, const_buffer, dep_buffer,
150
161
nonnumeric_buffer, update_function_iip, update_function_oop)
151
- if mtkps. dependent_update_iip != = nothing
152
- mtkps. dependent_update_iip (ArrayPartition (mtkps. dependent), mtkps... )
153
- end
154
162
return mtkps
155
163
end
156
164
165
+ function narrow_buffer_type (buffer:: AbstractArray )
166
+ type = Union{}
167
+ for x in buffer
168
+ type = promote_type (type, typeof (x))
169
+ end
170
+ return convert .(type, buffer)
171
+ end
172
+
173
+ function narrow_buffer_type (buffer:: AbstractArray{<:AbstractArray} )
174
+ buffer = narrow_buffer_type .(buffer)
175
+ type = Union{}
176
+ for x in buffer
177
+ type = promote_type (type, eltype (x))
178
+ end
179
+ return broadcast .(convert, type, buffer)
180
+ end
181
+
157
182
function buffer_to_arraypartition (buf)
158
183
return ArrayPartition (ntuple (i -> _buffer_to_arrp_helper (buf[i]), Val (length (buf))))
159
184
end
0 commit comments