Skip to content

Commit c26d4d9

Browse files
AayushSabharwalChrisRackauckas
authored andcommitted
feat: un-scalarize inferred parameters, improve parameter initialization
1 parent 55f2730 commit c26d4d9

File tree

3 files changed

+33
-10
lines changed

3 files changed

+33
-10
lines changed

src/systems/abstractsystem.jl

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -192,8 +192,7 @@ function SymbolicIndexingInterface.is_variable(sys::AbstractSystem, sym)
192192
ic = get_index_cache(sys)
193193
h = getsymbolhash(sym)
194194
return haskey(ic.unknown_idx, h) ||
195-
haskey(ic.unknown_idx, getsymbolhash(default_toterm(sym))) ||
196-
hasname(sym) && is_variable(sys, getname(sym))
195+
haskey(ic.unknown_idx, getsymbolhash(default_toterm(sym)))
197196
else
198197
return any(isequal(sym), variable_symbols(sys)) ||
199198
hasname(sym) && is_variable(sys, getname(sym))
@@ -220,8 +219,6 @@ function SymbolicIndexingInterface.variable_index(sys::AbstractSystem, sym)
220219
h = getsymbolhash(default_toterm(sym))
221220
if haskey(ic.unknown_idx, h)
222221
ic.unknown_idx[h]
223-
elseif hasname(sym)
224-
variable_index(sys, getname(sym))
225222
else
226223
nothing
227224
end
@@ -264,8 +261,7 @@ function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym)
264261
else
265262
h = getsymbolhash(default_toterm(sym))
266263
haskey(ic.param_idx, h) || haskey(ic.discrete_idx, h) ||
267-
haskey(ic.constant_idx, h) || haskey(ic.dependent_idx, h) ||
268-
hasname(sym) && is_parameter(sys, getname(sym))
264+
haskey(ic.constant_idx, h) || haskey(ic.dependent_idx, h)
269265
end
270266
end
271267
return any(isequal(sym), parameter_symbols(sys)) ||

src/systems/diffeqs/odesystem.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,10 +280,23 @@ function ODESystem(eqs, iv; kwargs...)
280280
isdelay(v, iv) || continue
281281
collect_vars!(allunknowns, ps, arguments(v)[1], iv)
282282
end
283+
new_ps = OrderedSet()
284+
for p in ps
285+
if istree(p) && operation(p) === getindex
286+
par = arguments(p)[begin]
287+
if Symbolics.shape(Symbolics.unwrap(par)) !== Symbolics.Unknown() && all(par[i] in ps for i in eachindex(par))
288+
push!(new_ps, par)
289+
else
290+
push!(new_ps, p)
291+
end
292+
else
293+
push!(new_ps, p)
294+
end
295+
end
283296
algevars = setdiff(allunknowns, diffvars)
284297
# the orders here are very important!
285298
return ODESystem(Equation[diffeq; algeeq; compressed_eqs], iv,
286-
collect(Iterators.flatten((diffvars, algevars))), collect(ps); kwargs...)
299+
collect(Iterators.flatten((diffvars, algevars))), collect(new_ps); kwargs...)
287300
end
288301

289302
# NOTE: equality does not check cached Jacobian

src/systems/parameter_buffer.jl

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,12 @@ function MTKParameters(sys::AbstractSystem, p; tofloat = false, use_union = fals
3636
for (k, v) in p if !haskey(extra_params, unwrap(k)))
3737
end
3838

39+
for (sym, _) in p
40+
if istree(sym) && operation(sym) === getindex && is_parameter(sys, arguments(sym)[begin])
41+
# error("Scalarized parameter values are not supported. Instead of `[p[1] => 1.0, p[2] => 2.0]` use `[p => [1.0, 2.0]]`")
42+
end
43+
end
44+
3945
tunable_buffer = Tuple(Vector{temp.type}(undef, temp.length)
4046
for temp in ic.param_buffer_sizes)
4147
disc_buffer = Tuple(Vector{temp.type}(undef, temp.length)
@@ -48,6 +54,7 @@ function MTKParameters(sys::AbstractSystem, p; tofloat = false, use_union = fals
4854
for temp in ic.nonnumeric_buffer_sizes)
4955
function set_value(sym, val)
5056
h = getsymbolhash(sym)
57+
done = true
5158
if haskey(ic.param_idx, h)
5259
i, j = ic.param_idx[h]
5360
tunable_buffer[i][j] = val
@@ -64,17 +71,24 @@ function MTKParameters(sys::AbstractSystem, p; tofloat = false, use_union = fals
6471
i, j = ic.nonnumeric_idx[h]
6572
nonnumeric_buffer[i][j] = val
6673
elseif !isequal(default_toterm(sym), sym)
67-
set_value(default_toterm(sym), val)
74+
done = set_value(default_toterm(sym), val)
6875
else
69-
error("Symbol $sym does not have an index")
76+
done = false
7077
end
78+
return done
7179
end
7280

7381
for (sym, val) in p
7482
sym = unwrap(sym)
7583
ctype = concrete_symtype(sym)
7684
val = convert(ctype, fixpoint_sub(val, p))
77-
set_value(sym, val)
85+
done = set_value(sym, val)
86+
if !done && Symbolics.isarraysymbolic(sym)
87+
done = all(set_value.(collect(sym), val))
88+
end
89+
if !done
90+
error("Symbol $sym does not have an index")
91+
end
7892
end
7993

8094
if has_parameter_dependencies(sys) &&

0 commit comments

Comments
 (0)