Skip to content

Commit 0e7c570

Browse files
test: add MTKParameters tests
1 parent 6015036 commit 0e7c570

File tree

2 files changed

+93
-0
lines changed

2 files changed

+93
-0
lines changed

test/mtkparameters.jl

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
using ModelingToolkit
2+
using ModelingToolkit: t_nounits as t, D_nounits as D, MTKParameters
3+
using SymbolicIndexingInterface
4+
using SciMLStructures: SciMLStructures, canonicalize, Tunable, Discrete, Constants
5+
6+
@parameters a b c d::Integer e[1:3] f[1:3, 1:3]::Int g::Vector{AbstractFloat} h::String
7+
@named sys = ODESystem(
8+
Equation[], t, [], [a, c, d, e, f, g, h], parameter_dependencies = [b => 2a],
9+
continuous_events = [[a ~ 0] => [c ~ 0]], defaults = Dict(a => 0.0))
10+
sys = complete(sys)
11+
12+
ivs = Dict(c => 3a, d => 4, e => [5.0, 6.0, 7.0],
13+
f => ones(Int, 3, 3), g => [0.1, 0.2, 0.3], h => "foo")
14+
15+
ps = MTKParameters(sys, ivs)
16+
@test_nowarn copy(ps)
17+
# dependent initialization, also using defaults
18+
@test getp(sys, a)(ps) == getp(sys, b)(ps) == getp(sys, c)(ps) == 0.0
19+
@test getp(sys, d)(ps) isa Int
20+
21+
ivs[a] = 1.0
22+
ps = MTKParameters(sys, ivs)
23+
@test_broken getp(sys, g) # SII bug
24+
for (p, val) in ivs
25+
isequal(p, g) && continue # broken
26+
if isequal(p, c)
27+
val = 3ivs[a]
28+
end
29+
idx = parameter_index(sys, p)
30+
# ensure getindex with `ParameterIndex` works
31+
@test ps[idx] == getp(sys, p)(ps) == val
32+
end
33+
34+
# ensure setindex! with `ParameterIndex` works
35+
ps[parameter_index(sys, a)] = 3.0
36+
@test getp(sys, a)(ps) == 3.0
37+
setp(sys, a)(ps, 1.0)
38+
39+
@test getp(sys, a)(ps) == getp(sys, b)(ps) / 2 == getp(sys, c)(ps) / 3 == 1.0
40+
41+
for (portion, values) in [(Tunable(), vcat(ones(9), [1.0, 4.0, 5.0, 6.0, 7.0]))
42+
(Discrete(), [3.0])
43+
(Constants(), [0.1, 0.2, 0.3])]
44+
buffer, repack, alias = canonicalize(portion, ps)
45+
@test alias
46+
@test sort(collect(buffer)) == values
47+
@test all(isone,
48+
canonicalize(portion, SciMLStructures.replace(portion, ps, ones(length(buffer))))[1])
49+
# make sure it is out-of-place
50+
@test sort(collect(buffer)) == values
51+
SciMLStructures.replace!(portion, ps, ones(length(buffer)))
52+
# make sure it is in-place
53+
@test all(isone, canonicalize(portion, ps)[1])
54+
repack(zeros(length(buffer)))
55+
@test all(iszero, canonicalize(portion, ps)[1])
56+
end
57+
58+
setp(sys, a)(ps, 2.0) # test set_parameter!
59+
@test getp(sys, a)(ps) == 2.0
60+
61+
setp(sys, e)(ps, 5ones(3)) # with an array
62+
@test getp(sys, e)(ps) == 5ones(3)
63+
64+
setp(sys, f[2, 2])(ps, 42) # with a sub-index
65+
@test getp(sys, f[2, 2])(ps) == 42
66+
67+
# SII bug
68+
@test_broken setp(sys, g)(ps, ones(100)) # with non-fixed-length array
69+
@test_broken getp(sys, g)(ps) == ones(100)
70+
71+
setp(sys, h)(ps, "bar") # with a non-numeric
72+
@test getp(sys, h)(ps) == "bar"
73+
74+
newps = remake_buffer(sys,
75+
ps,
76+
Dict(a => 1.0f0, b => 5.0f0, c => 2.0, d => 0x5, e => [0.4, 0.5, 0.6],
77+
f => 3ones(UInt, 3, 3), g => ones(Float32, 4), h => "bar"))
78+
79+
for fname in (:tunable, :discrete, :constant, :dependent)
80+
# ensure same number of sub-buffers
81+
@test length(getfield(ps, fname)) == length(getfield(newps, fname))
82+
end
83+
@test ps.dependent_update_iip === newps.dependent_update_iip
84+
@test ps.dependent_update_oop === newps.dependent_update_oop
85+
86+
@test getp(sys, a)(newps) isa Float32
87+
@test getp(sys, b)(newps) == 2.0f0 # ensure dependent update still happened, despite explicit value
88+
@test getp(sys, c)(newps) isa Float64
89+
@test getp(sys, d)(newps) isa UInt8
90+
@test getp(sys, f)(newps) isa Matrix{UInt}
91+
# SII bug
92+
@test_broken getp(sys, g)(newps) isa Vector{Float32}

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ end
6969
@safetestset "Generate Custom Function Test" include("generate_custom_function.jl")
7070
@safetestset "Initial Values Test" include("initial_values.jl")
7171
@safetestset "Discrete System" include("discrete_system.jl")
72+
@safetestset "MTKParameters Test" include("mtkparameters.jl")
7273
end
7374
end
7475

0 commit comments

Comments
 (0)