Skip to content

Commit 9850aee

Browse files
feat: add parameter type and size validation in remake_buffer and setp
1 parent 335c7ba commit 9850aee

File tree

6 files changed

+191
-59
lines changed

6 files changed

+191
-59
lines changed

src/systems/index_cache.jl

Lines changed: 27 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
struct BufferTemplate
2-
type::DataType
2+
type::Union{DataType, UnionAll}
33
length::Int
44
end
55

@@ -14,8 +14,11 @@ const NONNUMERIC_PORTION = :nonnumeric
1414
struct ParameterIndex{P, I}
1515
portion::P
1616
idx::I
17+
validate_size::Bool
1718
end
1819

20+
ParameterIndex(portion, idx) = ParameterIndex(portion, idx, false)
21+
1922
const ParamIndexMap = Dict{Union{Symbol, BasicSymbolic}, Tuple{Int, Int}}
2023
const UnknownIndexMap = Dict{
2124
Union{Symbol, BasicSymbolic}, Union{Int, UnitRange{Int}, AbstractArray{Int}}}
@@ -32,11 +35,14 @@ struct IndexCache
3235
constant_buffer_sizes::Vector{BufferTemplate}
3336
dependent_buffer_sizes::Vector{BufferTemplate}
3437
nonnumeric_buffer_sizes::Vector{BufferTemplate}
38+
symbol_to_variable::Dict{Symbol, BasicSymbolic}
3539
end
3640

3741
function IndexCache(sys::AbstractSystem)
3842
unks = solved_unknowns(sys)
3943
unk_idxs = UnknownIndexMap()
44+
symbol_to_variable = Dict{Symbol, BasicSymbolic}()
45+
4046
let idx = 1
4147
for sym in unks
4248
usym = unwrap(sym)
@@ -48,7 +54,9 @@ function IndexCache(sys::AbstractSystem)
4854
unk_idxs[usym] = sym_idx
4955

5056
if hasname(sym) && (!istree(sym) || operation(sym) !== getindex)
51-
unk_idxs[getname(usym)] = sym_idx
57+
name = getname(usym)
58+
unk_idxs[name] = sym_idx
59+
symbol_to_variable[name] = sym
5260
end
5361
idx += length(sym)
5462
end
@@ -64,7 +72,9 @@ function IndexCache(sys::AbstractSystem)
6472
end
6573
unk_idxs[arrsym] = idxs
6674
if hasname(arrsym)
67-
unk_idxs[getname(arrsym)] = idxs
75+
name = getname(arrsym)
76+
unk_idxs[name] = idxs
77+
symbol_to_variable[name] = arrsym
6878
end
6979
end
7080
end
@@ -142,14 +152,15 @@ function IndexCache(sys::AbstractSystem)
142152
idxs[default_toterm(p)] = (i, j)
143153
if hasname(p) && (!istree(p) || operation(p) !== getindex)
144154
idxs[getname(p)] = (i, j)
155+
symbol_to_variable[getname(p)] = p
145156
idxs[getname(default_toterm(p))] = (i, j)
157+
symbol_to_variable[getname(default_toterm(p))] = p
146158
end
147159
end
148160
push!(buffer_sizes, BufferTemplate(T, length(buf)))
149161
end
150162
return idxs, buffer_sizes
151163
end
152-
153164
disc_idxs, discrete_buffer_sizes = get_buffer_sizes_and_idxs(disc_buffers)
154165
tunable_idxs, tunable_buffer_sizes = get_buffer_sizes_and_idxs(tunable_buffers)
155166
const_idxs, const_buffer_sizes = get_buffer_sizes_and_idxs(constant_buffers)
@@ -167,7 +178,8 @@ function IndexCache(sys::AbstractSystem)
167178
tunable_buffer_sizes,
168179
const_buffer_sizes,
169180
dependent_buffer_sizes,
170-
nonnumeric_buffer_sizes
181+
nonnumeric_buffer_sizes,
182+
symbol_to_variable
171183
)
172184
end
173185

@@ -188,16 +200,21 @@ function SymbolicIndexingInterface.is_parameter(ic::IndexCache, sym)
188200
end
189201

190202
function SymbolicIndexingInterface.parameter_index(ic::IndexCache, sym)
203+
if sym isa Symbol
204+
sym = ic.symbol_to_variable[sym]
205+
end
206+
validate_size = Symbolics.isarraysymbolic(sym) &&
207+
Symbolics.shape(sym) !== Symbolics.Unknown()
191208
return if (idx = check_index_map(ic.tunable_idx, sym)) !== nothing
192-
ParameterIndex(SciMLStructures.Tunable(), idx)
209+
ParameterIndex(SciMLStructures.Tunable(), idx, validate_size)
193210
elseif (idx = check_index_map(ic.discrete_idx, sym)) !== nothing
194-
ParameterIndex(SciMLStructures.Discrete(), idx)
211+
ParameterIndex(SciMLStructures.Discrete(), idx, validate_size)
195212
elseif (idx = check_index_map(ic.constant_idx, sym)) !== nothing
196-
ParameterIndex(SciMLStructures.Constants(), idx)
213+
ParameterIndex(SciMLStructures.Constants(), idx, validate_size)
197214
elseif (idx = check_index_map(ic.nonnumeric_idx, sym)) !== nothing
198-
ParameterIndex(NONNUMERIC_PORTION, idx)
215+
ParameterIndex(NONNUMERIC_PORTION, idx, validate_size)
199216
elseif (idx = check_index_map(ic.dependent_idx, sym)) !== nothing
200-
ParameterIndex(DEPENDENT_PORTION, idx)
217+
ParameterIndex(DEPENDENT_PORTION, idx, validate_size)
201218
else
202219
nothing
203220
end
@@ -222,26 +239,6 @@ function check_index_map(idxmap, sym)
222239
end
223240
end
224241

225-
function ParameterIndex(ic::IndexCache, p, sub_idx = ())
226-
p = unwrap(p)
227-
return if haskey(ic.tunable_idx, p)
228-
ParameterIndex(SciMLStructures.Tunable(), (ic.tunable_idx[p]..., sub_idx...))
229-
elseif haskey(ic.discrete_idx, p)
230-
ParameterIndex(SciMLStructures.Discrete(), (ic.discrete_idx[p]..., sub_idx...))
231-
elseif haskey(ic.constant_idx, p)
232-
ParameterIndex(SciMLStructures.Constants(), (ic.constant_idx[p]..., sub_idx...))
233-
elseif haskey(ic.dependent_idx, p)
234-
ParameterIndex(DEPENDENT_PORTION, (ic.dependent_idx[p]..., sub_idx...))
235-
elseif haskey(ic.nonnumeric_idx, p)
236-
ParameterIndex(NONNUMERIC_PORTION, (ic.nonnumeric_idx[p]..., sub_idx...))
237-
elseif istree(p) && operation(p) === getindex
238-
_p, sub_idx... = arguments(p)
239-
ParameterIndex(ic, _p, sub_idx)
240-
else
241-
nothing
242-
end
243-
end
244-
245242
function discrete_linear_index(ic::IndexCache, idx::ParameterIndex)
246243
idx.portion isa SciMLStructures.Discrete || error("Discrete variable index expected")
247244
ind = sum(temp.length for temp in ic.tunable_buffer_sizes; init = 0)

src/systems/parameter_buffer.jl

Lines changed: 72 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -282,22 +282,31 @@ end
282282

283283
function SymbolicIndexingInterface.set_parameter!(
284284
p::MTKParameters, val, idx::ParameterIndex)
285-
@unpack portion, idx = idx
285+
@unpack portion, idx, validate_size = idx
286286
i, j, k... = idx
287287
if portion isa SciMLStructures.Tunable
288288
if isempty(k)
289+
if validate_size && size(val) !== size(p.tunable[i][j])
290+
throw(InvalidParameterSizeException(size(p.tunable[i][j]), size(val)))
291+
end
289292
p.tunable[i][j] = val
290293
else
291294
p.tunable[i][j][k...] = val
292295
end
293296
elseif portion isa SciMLStructures.Discrete
294297
if isempty(k)
298+
if validate_size && size(val) !== size(p.discrete[i][j])
299+
throw(InvalidParameterSizeException(size(p.discrete[i][j]), size(val)))
300+
end
295301
p.discrete[i][j] = val
296302
else
297303
p.discrete[i][j][k...] = val
298304
end
299305
elseif portion isa SciMLStructures.Constants
300306
if isempty(k)
307+
if validate_size && size(val) !== size(p.constant[i][j])
308+
throw(InvalidParameterSizeException(size(p.constant[i][j]), size(val)))
309+
end
301310
p.constant[i][j] = val
302311
else
303312
p.constant[i][j][k...] = val
@@ -366,13 +375,59 @@ function narrow_buffer_type_and_fallback_undefs(oldbuf::Vector, newbuf::Vector)
366375
isassigned(newbuf, i) || continue
367376
type = promote_type(type, typeof(newbuf[i]))
368377
end
378+
if type == Union{}
379+
type = eltype(oldbuf)
380+
end
369381
for i in eachindex(newbuf)
370382
isassigned(newbuf, i) && continue
371383
newbuf[i] = convert(type, oldbuf[i])
372384
end
373385
return convert(Vector{type}, newbuf)
374386
end
375387

388+
function validate_parameter_type(ic::IndexCache, p, index, val)
389+
p = unwrap(p)
390+
if p isa Symbol
391+
p = get(ic.symbol_to_variable, p, nothing)
392+
if p === nothing
393+
@warn "No matching variable found for `Symbol` $p, skipping type validation."
394+
return
395+
end
396+
end
397+
(; portion,) = index
398+
if portion === NONNUMERIC_PORTION
399+
stype = concrete_symtype(p)
400+
val isa stype ||
401+
throw(ParameterTypeException(:validate_parameter_type, p, stype, val))
402+
return
403+
end
404+
stype = concrete_symtype(p)
405+
if stype <: AbstractArray && !isa(val, AbstractArray)
406+
throw(ParameterTypeException(:validate_parameter_type, p, stype, val))
407+
end
408+
if stype <: AbstractArray && Symbolics.shape(p) !== Symbolics.Unknown() &&
409+
size(val) != size(p)
410+
throw(InvalidParameterSizeException(p, val))
411+
end
412+
val isa stype && return
413+
if stype <: AbstractArray
414+
etype = eltype(stype)
415+
if etype <: Real
416+
etype = Real
417+
end
418+
etype = SciMLBase.parameterless_type(etype)
419+
eltype(val) <: etype || throw(ParameterTypeException(
420+
:validate_parameter_type, p, AbstractArray{etype}, val))
421+
else
422+
if stype <: Real
423+
stype = Real
424+
end
425+
stype = SciMLBase.parameterless_type(stype)
426+
val isa stype ||
427+
throw(ParameterTypeException(:validate_parameter_type, p, stype, val))
428+
end
429+
end
430+
376431
function SymbolicIndexingInterface.remake_buffer(sys, oldbuf::MTKParameters, vals::Dict)
377432
newbuf = @set oldbuf.tunable = Tuple(Vector{Any}(undef, length(buf))
378433
for buf in oldbuf.tunable)
@@ -383,9 +438,12 @@ function SymbolicIndexingInterface.remake_buffer(sys, oldbuf::MTKParameters, val
383438
@set! newbuf.nonnumeric = Tuple(Vector{Any}(undef, length(buf))
384439
for buf in newbuf.nonnumeric)
385440

441+
ic = get_index_cache(sys)
386442
for (p, val) in vals
443+
idx = parameter_index(sys, p)
444+
validate_parameter_type(ic, p, idx, val)
387445
_set_parameter_unchecked!(
388-
newbuf, val, parameter_index(sys, p); update_dependent = false)
446+
newbuf, val, idx; update_dependent = false)
389447
end
390448

391449
@set! newbuf.tunable = narrow_buffer_type_and_fallback_undefs.(
@@ -549,3 +607,15 @@ function as_duals(p::MTKParameters, dualtype)
549607
discrete = dualtype.(p.discrete)
550608
return MTKParameters{typeof(tunable), typeof(discrete)}(tunable, discrete)
551609
end
610+
611+
function InvalidParameterSizeException(param, val)
612+
DimensionMismatch("InvalidParameterSizeException: For parameter $(param) expected value of size $(size(param)). Received value $(val) of size $(size(val)).")
613+
end
614+
615+
function InvalidParameterSizeException(param::Tuple, val::Tuple)
616+
DimensionMismatch("InvalidParameterSizeException: Expected value of size $(param). Received value of size $(val).")
617+
end
618+
619+
function ParameterTypeException(func, param, expected, val)
620+
TypeError(func, "Parameter $param", expected, val)
621+
end

test/index_cache.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
using ModelingToolkit
2+
using ModelingToolkit: t_nounits as t
3+
4+
# Ensure indexes of array symbolics are cached appropriately
5+
@variables x(t)[1:2]
6+
@named sys = ODESystem(Equation[], t, [x], [])
7+
sys1 = complete(sys)
8+
@named sys = ODESystem(Equation[], t, [x...], [])
9+
sys2 = complete(sys)
10+
for sys in [sys1, sys2]
11+
for (sym, idx) in [(x, 1:2), (x[1], 1), (x[2], 2)]
12+
@test is_variable(sys, sym)
13+
@test variable_index(sys, sym) == idx
14+
end
15+
end
16+
17+
@variables x(t)[1:2, 1:2]
18+
@named sys = ODESystem(Equation[], t, [x], [])
19+
sys1 = complete(sys)
20+
@named sys = ODESystem(Equation[], t, [x...], [])
21+
sys2 = complete(sys)
22+
for sys in [sys1, sys2]
23+
@test is_variable(sys, x)
24+
@test variable_index(sys, x) == [1 3; 2 4]
25+
for i in eachindex(x)
26+
@test is_variable(sys, x[i])
27+
@test variable_index(sys, x[i]) == variable_index(sys, x)[i]
28+
end
29+
end
30+
31+
# Ensure Symbol to symbolic map is correct
32+
@parameters p1 p2[1:2] p3::String
33+
@variables x(t) y(t)[1:2] z(t)
34+
35+
@named sys = ODESystem(Equation[], t, [x, y, z], [p1, p2, p3])
36+
sys = complete(sys)
37+
38+
ic = ModelingToolkit.get_index_cache(sys)
39+
40+
@test isequal(ic.symbol_to_variable[:p1], p1)
41+
@test isequal(ic.symbol_to_variable[:p2], p2)
42+
@test isequal(ic.symbol_to_variable[:p3], p3)
43+
@test isequal(ic.symbol_to_variable[:x], x)
44+
@test isequal(ic.symbol_to_variable[:y], y)
45+
@test isequal(ic.symbol_to_variable[:z], z)

test/mtkparameters.jl

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,3 +228,49 @@ function loss(x)
228228
end
229229

230230
@test_nowarn ForwardDiff.gradient(loss, collect(tunables))
231+
232+
@testset "Parameter type validation" begin
233+
struct Foo{T}
234+
x::T
235+
end
236+
237+
@parameters a b::Int c::Vector{Float64} d[1:2, 1:2]::Int e::Foo{Int} f::Foo
238+
@named sys = ODESystem(Equation[], t, [], [a, b, c, d, e, f])
239+
sys = complete(sys)
240+
ps = MTKParameters(sys,
241+
Dict(a => 1.0, b => 2, c => 3ones(2),
242+
d => 3ones(Int, 2, 2), e => Foo(1), f => Foo("a")))
243+
@test_broken setp(sys, c)(ps, ones(4)) # so this is fixed when SII is fixed
244+
begin
245+
c_idx = parameter_index(sys, c)
246+
@test_nowarn set_parameter!(ps, ones(4), c_idx)
247+
end
248+
@test_throws DimensionMismatch set_parameter!(
249+
ps, 4ones(Int, 3, 2), parameter_index(sys, d))
250+
@test_throws DimensionMismatch set_parameter!(
251+
ps, 4ones(Int, 4), parameter_index(sys, d)) # size has to match, not just length
252+
@test_nowarn setp(sys, f)(ps, Foo(:a)) # can change non-concrete type
253+
254+
# Same flexibility is afforded to `b::Int` to allow for ForwardDiff
255+
for sym in [a, b]
256+
@test_nowarn remake_buffer(sys, ps, Dict(sym => 1))
257+
newps = @test_nowarn remake_buffer(sys, ps, Dict(sym => 1.0f0)) # Can change type if it's numeric
258+
@test getp(sys, sym)(newps) isa Float32
259+
newps = @test_nowarn remake_buffer(sys, ps, Dict(sym => ForwardDiff.Dual(1.0)))
260+
@test getp(sys, sym)(newps) isa ForwardDiff.Dual
261+
@test_throws TypeError remake_buffer(sys, ps, Dict(sym => :a)) # still has to be numeric
262+
end
263+
264+
newps = @test_nowarn remake_buffer(sys, ps, Dict(c => view(1.0:4.0, 2:4))) # can change type of array
265+
@test_broken getp(sys, c)(newps) # so this is fixed when SII is fixed
266+
@test parameter_values(newps, parameter_index(sys, c)) [2.0, 3.0, 4.0]
267+
@test_throws TypeError remake_buffer(sys, ps, Dict(c => [:a, :b, :c])) # can't arbitrarily change eltype
268+
@test_throws TypeError remake_buffer(sys, ps, Dict(c => :a)) # can't arbitrarily change type
269+
270+
newps = @test_nowarn remake_buffer(sys, ps, Dict(d => ForwardDiff.Dual.(ones(2, 2)))) # can change eltype
271+
@test_throws TypeError remake_buffer(sys, ps, Dict(d => [:a :b; :c :d])) # eltype still has to be numeric
272+
@test getp(sys, d)(newps) isa Matrix{<:ForwardDiff.Dual}
273+
274+
@test_throws TypeError remake_buffer(sys, ps, Dict(e => Foo(2.0))) # need exact same type for nonnumeric
275+
@test_nowarn remake_buffer(sys, ps, Dict(f => Foo(:a)))
276+
end

test/odesystem.jl

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1132,30 +1132,3 @@ outer = structural_simplify(outer)
11321132
prob = ODEProblem(outer, [outer.y => 2.0], (0.0, 10.0))
11331133
int = init(prob, Rodas4())
11341134
@test int[outer.sys.x] == 1.0
1135-
1136-
# Ensure indexes of array symbolics are cached appropriately
1137-
@variables x(t)[1:2]
1138-
@named sys = ODESystem(Equation[], t, [x], [])
1139-
sys1 = complete(sys)
1140-
@named sys = ODESystem(Equation[], t, [x...], [])
1141-
sys2 = complete(sys)
1142-
for sys in [sys1, sys2]
1143-
for (sym, idx) in [(x, 1:2), (x[1], 1), (x[2], 2)]
1144-
@test is_variable(sys, sym)
1145-
@test variable_index(sys, sym) == idx
1146-
end
1147-
end
1148-
1149-
@variables x(t)[1:2, 1:2]
1150-
@named sys = ODESystem(Equation[], t, [x], [])
1151-
sys1 = complete(sys)
1152-
@named sys = ODESystem(Equation[], t, [x...], [])
1153-
sys2 = complete(sys)
1154-
for sys in [sys1, sys2]
1155-
@test is_variable(sys, x)
1156-
@test variable_index(sys, x) == [1 3; 2 4]
1157-
for i in eachindex(x)
1158-
@test is_variable(sys, x[i])
1159-
@test variable_index(sys, x[i]) == variable_index(sys, x)[i]
1160-
end
1161-
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ end
2424
@safetestset "Parsing Test" include("variable_parsing.jl")
2525
@safetestset "Simplify Test" include("simplify.jl")
2626
@safetestset "Direct Usage Test" include("direct.jl")
27+
@safetestset "IndexCache Test" include("index_cache.jl")
2728
@safetestset "System Linearity Test" include("linearity.jl")
2829
@safetestset "Input Output Test" include("input_output_handling.jl")
2930
@safetestset "Clock Test" include("clock.jl")

0 commit comments

Comments
 (0)