@@ -3,29 +3,44 @@ $(TYPEDSIGNATURES)
3
3
4
4
Generate `ODESystem`, dependent variables, and parameters from an `ODEProblem`.
5
5
"""
6
- function modelingtoolkitize (prob:: DiffEqBase.ODEProblem ; kwargs... )
6
+ function modelingtoolkitize (
7
+ prob:: DiffEqBase.ODEProblem ; u_names = nothing , p_names = nothing , kwargs... )
7
8
prob. f isa DiffEqBase. AbstractParameterizedFunction &&
8
9
return prob. f. sys
9
- @parameters t
10
-
10
+ t = t_nounits
11
11
p = prob. p
12
12
has_p = ! (p isa Union{DiffEqBase. NullParameters, Nothing})
13
13
14
- _vars = define_vars (prob. u0, t)
14
+ if u_names != = nothing
15
+ varnames_length_check (prob. u0, u_names; is_unknowns = true )
16
+ _vars = [_defvar (name)(t) for name in u_names]
17
+ elseif SciMLBase. has_sys (prob. f)
18
+ varnames = getname .(variable_symbols (prob. f. sys))
19
+ varidxs = variable_index .((prob. f. sys,), varnames)
20
+ invpermute! (varnames, varidxs)
21
+ _vars = [_defvar (name)(t) for name in varnames]
22
+ else
23
+ _vars = define_vars (prob. u0, t)
24
+ end
15
25
16
26
vars = prob. u0 isa Number ? _vars : ArrayInterface. restructure (prob. u0, _vars)
17
27
params = if has_p
18
- _params = define_params (p)
28
+ if p_names === nothing && SciMLBase. has_sys (prob. f)
29
+ p_names = Dict (parameter_index (prob. f. sys, sym) => sym
30
+ for sym in parameter_symbols (prob. f. sys))
31
+ end
32
+ _params = define_params (p, p_names)
19
33
p isa Number ? _params[1 ] :
20
- (p isa Tuple || p isa NamedTuple || p isa AbstractDict ? _params :
34
+ (p isa Tuple || p isa NamedTuple || p isa AbstractDict || p isa MTKParameters ?
35
+ _params :
21
36
ArrayInterface. restructure (p, _params))
22
37
else
23
38
[]
24
39
end
25
40
26
41
var_set = Set (vars)
27
42
28
- D = Differential (t)
43
+ D = D_nounits
29
44
mm = prob. f. mass_matrix
30
45
31
46
if mm === I
@@ -70,6 +85,8 @@ function modelingtoolkitize(prob::DiffEqBase.ODEProblem; kwargs...)
70
85
default_p = if has_p
71
86
if prob. p isa AbstractDict
72
87
Dict (v => prob. p[k] for (k, v) in pairs (_params))
88
+ elseif prob. p isa MTKParameters
89
+ Dict (params .=> reduce (vcat, prob. p))
73
90
else
74
91
Dict (params .=> vec (collect (prob. p)))
75
92
end
@@ -125,44 +142,96 @@ function Base.showerror(io::IO, e::ModelingtoolkitizeParametersNotSupportedError
125
142
println (io, e. type)
126
143
end
127
144
128
- function define_params (p)
145
+ function varnames_length_check (vars, names; is_unknowns = false )
146
+ if length (names) != length (vars)
147
+ throw (ArgumentError ("""
148
+ Number of $(is_unknowns ? " unknowns" : " parameters" ) ($(length (vars)) ) \
149
+ does not match number of names ($(length (names)) ).
150
+ """ ))
151
+ end
152
+ end
153
+
154
+ function define_params (p, _ = nothing )
129
155
throw (ModelingtoolkitizeParametersNotSupportedError (typeof (p)))
130
156
end
131
157
132
- function define_params (p:: AbstractArray )
133
- [toparam (variable (:α , i)) for i in eachindex (p)]
158
+ function define_params (p:: AbstractArray , names = nothing )
159
+ if names === nothing
160
+ [toparam (variable (:α , i)) for i in eachindex (p)]
161
+ else
162
+ varnames_length_check (p, names)
163
+ [toparam (variable (names[i])) for i in eachindex (p)]
164
+ end
134
165
end
135
166
136
- function define_params (p:: Number )
137
- [toparam (variable (:α ))]
167
+ function define_params (p:: Number , names = nothing )
168
+ if names === nothing
169
+ [toparam (variable (:α ))]
170
+ elseif names isa Union{AbstractArray, AbstractDict}
171
+ varnames_length_check (p, names)
172
+ [toparam (variable (names[i])) for i in eachindex (p)]
173
+ else
174
+ [toparam (variable (names))]
175
+ end
138
176
end
139
177
140
- function define_params (p:: AbstractDict )
141
- OrderedDict (k => toparam (variable (:α , i)) for (i, k) in zip (1 : length (p), keys (p)))
178
+ function define_params (p:: AbstractDict , names = nothing )
179
+ if names === nothing
180
+ OrderedDict (k => toparam (variable (:α , i)) for (i, k) in zip (1 : length (p), keys (p)))
181
+ else
182
+ varnames_length_check (p, names)
183
+ OrderedDict (k => toparam (variable (names[k])) for k in keys (p))
184
+ end
142
185
end
143
186
144
- function define_params (p:: Union{SLArray, LArray} )
145
- [toparam (variable (x)) for x in LabelledArrays. symnames (typeof (p))]
187
+ function define_params (p:: Union{SLArray, LArray} , names = nothing )
188
+ if names === nothing
189
+ [toparam (variable (x)) for x in LabelledArrays. symnames (typeof (p))]
190
+ else
191
+ varnames_length_check (p, names)
192
+ [toparam (variable (names[i])) for i in eachindex (p)]
193
+ end
146
194
end
147
195
148
- function define_params (p:: Tuple )
149
- tuple ((toparam (variable (:α , i)) for i in eachindex (p)). .. )
196
+ function define_params (p:: Tuple , names = nothing )
197
+ if names === nothing
198
+ tuple ((toparam (variable (:α , i)) for i in eachindex (p)). .. )
199
+ else
200
+ varnames_length_check (p, names)
201
+ tuple ((toparam (variable (names[i])) for i in eachindex (p)). .. )
202
+ end
150
203
end
151
204
152
- function define_params (p:: NamedTuple )
153
- NamedTuple (x => toparam (variable (x)) for x in keys (p))
205
+ function define_params (p:: NamedTuple , names = nothing )
206
+ if names === nothing
207
+ NamedTuple (x => toparam (variable (x)) for x in keys (p))
208
+ else
209
+ varnames_length_check (p, names)
210
+ NamedTuple (x => toparam (variable (names[x])) for x in keys (p))
211
+ end
154
212
end
155
213
156
- function define_params (p:: MTKParameters )
157
- bufs = (p... ,)
158
- i = 1
159
- ps = []
160
- for buf in bufs
161
- for _ in buf
162
- push! (ps, toparam (variable (:α , i)))
214
+ function define_params (p:: MTKParameters , names = nothing )
215
+ if names === nothing
216
+ bufs = (p... ,)
217
+ i = 1
218
+ ps = []
219
+ for buf in bufs
220
+ for _ in buf
221
+ push! (
222
+ ps,
223
+ if names === nothing
224
+ toparam (variable (:α , i))
225
+ else
226
+ toparam (variable (names[i]))
227
+ end
228
+ )
229
+ end
163
230
end
231
+ return identity .(ps)
232
+ else
233
+ return collect (values (names))
164
234
end
165
- return identity .(ps)
166
235
end
167
236
168
237
"""
0 commit comments