Skip to content

Commit 85e1863

Browse files
Merge pull request #2581 from AayushSabharwal/as/mtkparameters-tests
test: add MTKParameters tests, fix bugs
2 parents df6b314 + 3b39362 commit 85e1863

13 files changed

+112
-115
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
104104
StaticArrays = "0.10, 0.11, 0.12, 1.0"
105105
SymbolicIndexingInterface = "0.3.11"
106106
SymbolicUtils = "1.0"
107-
Symbolics = "5.24"
107+
Symbolics = "5.26"
108108
URIs = "1"
109109
UnPack = "0.1, 1.0"
110110
Unitful = "1.1"

src/ModelingToolkit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ using PrecompileTools, Reexport
5656
VariableSource, getname, variable, Connection, connect,
5757
NAMESPACE_SEPARATOR, set_scalar_metadata, setdefaultval,
5858
initial_state, transition, activeState, entry,
59-
ticksInState, timeInState
59+
ticksInState, timeInState, fixpoint_sub, fast_substitute
6060
import Symbolics: rename, get_variables!, _solve, hessian_sparsity,
6161
jacobian_sparsity, isaffine, islinear, _iszero, _isone,
6262
tosymbol, lower_varname, diff2term, var_from_nested_derivative,

src/structural_transformation/StructuralTransformations.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module StructuralTransformations
33
using Setfield: @set!, @set
44
using UnPack: @unpack
55

6-
using Symbolics: unwrap, linear_expansion
6+
using Symbolics: unwrap, linear_expansion, fast_substitute
77
using SymbolicUtils
88
using SymbolicUtils.Code
99
using SymbolicUtils.Rewriters
@@ -23,7 +23,7 @@ using ModelingToolkit: ODESystem, AbstractSystem, var_from_nested_derivative, Di
2323
IncrementalCycleTracker, add_edge_checked!, topological_sort,
2424
invalidate_cache!, Substitutions, get_or_construct_tearing_state,
2525
filter_kwargs, lower_varname, setio, SparseMatrixCLIL,
26-
fast_substitute, get_fullvars, has_equations, observed,
26+
get_fullvars, has_equations, observed,
2727
Schedule
2828

2929
using ModelingToolkit.BipartiteGraphs

src/structural_transformation/symbolics_tearing.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ function eq_derivative!(ts::TearingState{ODESystem}, ieq::Int)
8181
end
8282

8383
function tearing_sub(expr, dict, s)
84-
expr = ModelingToolkit.fixpoint_sub(expr, dict)
84+
expr = Symbolics.fixpoint_sub(expr, dict)
8585
s ? simplify(expr) : expr
8686
end
8787

@@ -439,7 +439,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching;
439439
order, lv = var_order(iv)
440440
dx = D(simplify_shifts(lower_varname_withshift(
441441
fullvars[lv], idep, order - 1)))
442-
eq = dx ~ simplify_shifts(ModelingToolkit.fixpoint_sub(
442+
eq = dx ~ simplify_shifts(Symbolics.fixpoint_sub(
443443
Symbolics.solve_for(neweqs[ieq],
444444
fullvars[iv]),
445445
total_sub; operator = ModelingToolkit.Shift))
@@ -467,7 +467,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching;
467467
@warn "Tearing: solving $eq for $var is singular!"
468468
else
469469
rhs = -b / a
470-
neweq = var ~ ModelingToolkit.fixpoint_sub(
470+
neweq = var ~ Symbolics.fixpoint_sub(
471471
simplify ?
472472
Symbolics.simplify(rhs) : rhs,
473473
total_sub; operator = ModelingToolkit.Shift)
@@ -481,7 +481,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching;
481481
if !(eq.lhs isa Number && eq.lhs == 0)
482482
rhs = eq.rhs - eq.lhs
483483
end
484-
push!(alge_eqs, 0 ~ ModelingToolkit.fixpoint_sub(rhs, total_sub))
484+
push!(alge_eqs, 0 ~ Symbolics.fixpoint_sub(rhs, total_sub))
485485
push!(algeeq_idxs, ieq)
486486
end
487487
end

src/systems/abstractsystem.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -489,7 +489,12 @@ function SymbolicIndexingInterface.is_observed(sys::AbstractSystem, sym)
489489
!is_independent_variable(sys, sym) && symbolic_type(sym) != NotSymbolic()
490490
end
491491

492-
SymbolicIndexingInterface.default_values(sys::AbstractSystem) = get_defaults(sys)
492+
function SymbolicIndexingInterface.default_values(sys::AbstractSystem)
493+
return merge(
494+
Dict(eq.lhs => eq.rhs for eq in observed(sys)),
495+
defaults(sys)
496+
)
497+
end
493498

494499
SymbolicIndexingInterface.is_time_dependent(::AbstractTimeDependentSystem) = true
495500
SymbolicIndexingInterface.is_time_dependent(::AbstractTimeIndependentSystem) = false

src/systems/alias_elimination.jl

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -462,13 +462,3 @@ function observed2graph(eqs, unknowns)
462462

463463
return graph, assigns
464464
end
465-
466-
function fixpoint_sub(x, dict; operator = Nothing)
467-
y = fast_substitute(x, dict; operator)
468-
while !isequal(x, y)
469-
y = x
470-
x = fast_substitute(y, dict; operator)
471-
end
472-
473-
return x
474-
end

src/systems/discrete_system/discrete_system.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,6 @@ function process_DiscreteProblem(constructor, sys::DiscreteSystem, u0map, paramm
231231
trueu0map[var] = defs[root]
232232
end
233233
end
234-
@show trueu0map u0map
235234
if has_index_cache(sys) && get_index_cache(sys) !== nothing
236235
u0, defs = get_u0(sys, trueu0map, parammap)
237236
p = MTKParameters(sys, parammap, trueu0map)

src/systems/index_cache.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,14 +190,16 @@ end
190190
function check_index_map(idxmap, sym)
191191
if (idx = get(idxmap, sym, nothing)) !== nothing
192192
return idx
193-
elseif hasname(sym) && (idx = get(idxmap, getname(sym), nothing)) !== nothing
193+
elseif !isa(sym, Symbol) && (!istree(sym) || operation(sym) !== getindex) &&
194+
hasname(sym) && (idx = get(idxmap, getname(sym), nothing)) !== nothing
194195
return idx
195196
end
196197
dsym = default_toterm(sym)
197198
isequal(sym, dsym) && return nothing
198199
if (idx = get(idxmap, dsym, nothing)) !== nothing
199200
idx
200-
elseif hasname(dsym) && (idx = get(idxmap, getname(dsym), nothing)) !== nothing
201+
elseif !isa(dsym, Symbol) && (!istree(dsym) || operation(dsym) !== getindex) &&
202+
hasname(dsym) && (idx = get(idxmap, getname(dsym), nothing)) !== nothing
201203
idx
202204
else
203205
nothing

src/systems/parameter_buffer.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -224,13 +224,13 @@ function SymbolicIndexingInterface.parameter_values(p::MTKParameters, pind::Para
224224
@unpack portion, idx = pind
225225
i, j, k... = idx
226226
if portion isa SciMLStructures.Tunable
227-
return p.tunable[i][j][k...]
227+
return isempty(k) ? p.tunable[i][j] : p.tunable[i][j][k...]
228228
elseif portion isa SciMLStructures.Discrete
229-
return p.discrete[i][j][k...]
229+
return isempty(k) ? p.discrete[i][j] : p.discrete[i][j][k...]
230230
elseif portion isa SciMLStructures.Constants
231-
return p.constant[i][j][k...]
231+
return isempty(k) ? p.constant[i][j] : p.constant[i][j][k...]
232232
elseif portion === DEPENDENT_PORTION
233-
return p.dependent[i][j][k...]
233+
return isempty(k) ? p.dependent[i][j] : p.dependent[i][j][k...]
234234
elseif portion === NONNUMERIC_PORTION
235235
return isempty(k) ? p.nonnumeric[i][j] : p.nonnumeric[i][j][k...]
236236
else

src/utils.jl

Lines changed: 0 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -799,86 +799,6 @@ function fold_constants(ex)
799799
end
800800
end
801801

802-
# Symbolics needs to call unwrap on the substitution rules, but most of the time
803-
# we don't want to do that in MTK.
804-
const Eq = Union{Equation, Inequality}
805-
function fast_substitute(eq::Eq, subs; operator = Nothing)
806-
if eq isa Inequality
807-
Inequality(fast_substitute(eq.lhs, subs; operator),
808-
fast_substitute(eq.rhs, subs; operator),
809-
eq.relational_op)
810-
else
811-
Equation(fast_substitute(eq.lhs, subs; operator),
812-
fast_substitute(eq.rhs, subs; operator))
813-
end
814-
end
815-
function fast_substitute(eq::T, subs::Pair; operator = Nothing) where {T <: Eq}
816-
T(fast_substitute(eq.lhs, subs; operator), fast_substitute(eq.rhs, subs; operator))
817-
end
818-
function fast_substitute(eqs::AbstractArray, subs; operator = Nothing)
819-
fast_substitute.(eqs, (subs,); operator)
820-
end
821-
function fast_substitute(eqs::AbstractArray, subs::Pair; operator = Nothing)
822-
fast_substitute.(eqs, (subs,); operator)
823-
end
824-
for (exprType, subsType) in Iterators.product((Num, Symbolics.Arr), (Any, Pair))
825-
@eval function fast_substitute(expr::$exprType, subs::$subsType; operator = Nothing)
826-
fast_substitute(value(expr), subs; operator)
827-
end
828-
end
829-
function fast_substitute(expr, subs; operator = Nothing)
830-
if (_val = get(subs, expr, nothing)) !== nothing
831-
return _val
832-
end
833-
istree(expr) || return expr
834-
op = fast_substitute(operation(expr), subs; operator)
835-
args = SymbolicUtils.unsorted_arguments(expr)
836-
if !(op isa operator)
837-
canfold = Ref(!(op isa Symbolic))
838-
args = let canfold = canfold
839-
map(args) do x
840-
x′ = fast_substitute(x, subs; operator)
841-
canfold[] = canfold[] && !(x′ isa Symbolic)
842-
x′
843-
end
844-
end
845-
canfold[] && return op(args...)
846-
end
847-
similarterm(expr,
848-
op,
849-
args,
850-
symtype(expr);
851-
metadata = metadata(expr))
852-
end
853-
function fast_substitute(expr, pair::Pair; operator = Nothing)
854-
a, b = pair
855-
isequal(expr, a) && return b
856-
if a isa AbstractArray
857-
for (ai, bi) in zip(a, b)
858-
expr = fast_substitute(expr, ai => bi; operator)
859-
end
860-
end
861-
istree(expr) || return expr
862-
op = fast_substitute(operation(expr), pair; operator)
863-
args = SymbolicUtils.unsorted_arguments(expr)
864-
if !(op isa operator)
865-
canfold = Ref(!(op isa Symbolic))
866-
args = let canfold = canfold
867-
map(args) do x
868-
x′ = fast_substitute(x, pair; operator)
869-
canfold[] = canfold[] && !(x′ isa Symbolic)
870-
x′
871-
end
872-
end
873-
canfold[] && return op(args...)
874-
end
875-
similarterm(expr,
876-
op,
877-
args,
878-
symtype(expr);
879-
metadata = metadata(expr))
880-
end
881-
882802
normalize_to_differential(s) = s
883803

884804
function restrict_array_to_union(arr)

test/mtkparameters.jl

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
using ModelingToolkit
2+
using ModelingToolkit: t_nounits as t, D_nounits as D, MTKParameters
3+
using SymbolicIndexingInterface
4+
using SciMLStructures: SciMLStructures, canonicalize, Tunable, Discrete, Constants
5+
6+
@parameters a b c d::Integer e[1:3] f[1:3, 1:3]::Int g::Vector{AbstractFloat} h::String
7+
@named sys = ODESystem(
8+
Equation[], t, [], [a, c, d, e, f, g, h], parameter_dependencies = [b => 2a],
9+
continuous_events = [[a ~ 0] => [c ~ 0]], defaults = Dict(a => 0.0))
10+
sys = complete(sys)
11+
12+
ivs = Dict(c => 3a, d => 4, e => [5.0, 6.0, 7.0],
13+
f => ones(Int, 3, 3), g => [0.1, 0.2, 0.3], h => "foo")
14+
15+
ps = MTKParameters(sys, ivs)
16+
@test_nowarn copy(ps)
17+
# dependent initialization, also using defaults
18+
@test getp(sys, a)(ps) == getp(sys, b)(ps) == getp(sys, c)(ps) == 0.0
19+
@test getp(sys, d)(ps) isa Int
20+
21+
ivs[a] = 1.0
22+
ps = MTKParameters(sys, ivs)
23+
@test_broken getp(sys, g) # SII bug
24+
for (p, val) in ivs
25+
isequal(p, g) && continue # broken
26+
if isequal(p, c)
27+
val = 3ivs[a]
28+
end
29+
idx = parameter_index(sys, p)
30+
# ensure getindex with `ParameterIndex` works
31+
@test ps[idx] == getp(sys, p)(ps) == val
32+
end
33+
34+
# ensure setindex! with `ParameterIndex` works
35+
ps[parameter_index(sys, a)] = 3.0
36+
@test getp(sys, a)(ps) == 3.0
37+
setp(sys, a)(ps, 1.0)
38+
39+
@test getp(sys, a)(ps) == getp(sys, b)(ps) / 2 == getp(sys, c)(ps) / 3 == 1.0
40+
41+
for (portion, values) in [(Tunable(), vcat(ones(9), [1.0, 4.0, 5.0, 6.0, 7.0]))
42+
(Discrete(), [3.0])
43+
(Constants(), [0.1, 0.2, 0.3])]
44+
buffer, repack, alias = canonicalize(portion, ps)
45+
@test alias
46+
@test sort(collect(buffer)) == values
47+
@test all(isone,
48+
canonicalize(portion, SciMLStructures.replace(portion, ps, ones(length(buffer))))[1])
49+
# make sure it is out-of-place
50+
@test sort(collect(buffer)) == values
51+
SciMLStructures.replace!(portion, ps, ones(length(buffer)))
52+
# make sure it is in-place
53+
@test all(isone, canonicalize(portion, ps)[1])
54+
repack(zeros(length(buffer)))
55+
@test all(iszero, canonicalize(portion, ps)[1])
56+
end
57+
58+
setp(sys, a)(ps, 2.0) # test set_parameter!
59+
@test getp(sys, a)(ps) == 2.0
60+
61+
setp(sys, e)(ps, 5ones(3)) # with an array
62+
@test getp(sys, e)(ps) == 5ones(3)
63+
64+
setp(sys, f[2, 2])(ps, 42) # with a sub-index
65+
@test getp(sys, f[2, 2])(ps) == 42
66+
67+
# SII bug
68+
@test_broken setp(sys, g)(ps, ones(100)) # with non-fixed-length array
69+
@test_broken getp(sys, g)(ps) == ones(100)
70+
71+
setp(sys, h)(ps, "bar") # with a non-numeric
72+
@test getp(sys, h)(ps) == "bar"

test/runtests.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ end
2424
@safetestset "Parsing Test" include("variable_parsing.jl")
2525
@safetestset "Simplify Test" include("simplify.jl")
2626
@safetestset "Direct Usage Test" include("direct.jl")
27-
@safetestset "SymbolicIndeingInterface test" include("symbolic_indexing_interface.jl")
2827
@safetestset "System Linearity Test" include("linearity.jl")
2928
@safetestset "Input Output Test" include("input_output_handling.jl")
3029
@safetestset "Clock Test" include("clock.jl")
@@ -72,6 +71,11 @@ end
7271
end
7372
end
7473

74+
if GROUP == "All" || GROUP == "InterfaceI" || GROUP == "SymbolicIndexingInterface"
75+
@safetestset "SymbolicIndexingInterface test" include("symbolic_indexing_interface.jl")
76+
@safetestset "MTKParameters Test" include("mtkparameters.jl")
77+
end
78+
7579
if GROUP == "All" || GROUP == "InterfaceII"
7680
println("C compilation test requires gcc available in the path!")
7781
@safetestset "C Compilation Test" include("ccompile.jl")

test/symbolic_indexing_interface.jl

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
using ModelingToolkit, SymbolicIndexingInterface, SciMLBase
2+
using ModelingToolkit: t_nounits as t, D_nounits as D
23

3-
@parameters t a b
4-
@variables x(t)=1.0 y(t)=2.0
5-
D = Differential(t)
4+
@parameters a b
5+
@variables x(t)=1.0 y(t)=2.0 xy(t)
66
eqs = [D(x) ~ a * y + t, D(y) ~ b * t]
7-
@named odesys = ODESystem(eqs, t, [x, y], [a, b])
7+
@named odesys = ODESystem(eqs, t, [x, y], [a, b]; observed = [xy ~ x + y])
88

99
@test all(is_variable.((odesys,), [x, y, 1, 2, :x, :y]))
1010
@test all(.!is_variable.((odesys,), [a, b, t, 3, 0, :a, :b]))
@@ -24,6 +24,11 @@ eqs = [D(x) ~ a * y + t, D(y) ~ b * t]
2424
@test !isempty(default_values(odesys))
2525
@test default_values(odesys)[x] == 1.0
2626
@test default_values(odesys)[y] == 2.0
27+
@test isequal(default_values(odesys)[xy], x + y)
28+
29+
@named odesys = ODESystem(
30+
eqs, t, [x, y], [a, b]; defaults = [xy => 3.0], observed = [xy ~ x + y])
31+
@test default_values(odesys)[xy] == 3.0
2732

2833
@variables x y z
2934
@parameters σ ρ β
@@ -36,10 +41,10 @@ eqs = [0 ~ σ * (y - x),
3641
@test !is_time_dependent(ns)
3742

3843
@parameters x
39-
@variables t u(..)
44+
@variables u(..)
4045
Dxx = Differential(x)^2
4146
Dtt = Differential(t)^2
42-
Dt = Differential(t)
47+
Dt = D
4348

4449
#2D PDE
4550
C = 1
@@ -60,10 +65,10 @@ domains = [t ∈ (0.0, 1.0),
6065
@test pde_system.ps == SciMLBase.NullParameters()
6166
@test parameter_symbols(pde_system) == []
6267

63-
@parameters t x
68+
@parameters x
6469
@constants h = 1
6570
@variables u(..)
66-
Dt = Differential(t)
71+
Dt = D
6772
Dxx = Differential(x)^2
6873
eq = Dt(u(t, x)) ~ h * Dxx(u(t, x))
6974
bcs = [u(0, x) ~ -h * x * (x - 1) * sin(x),

0 commit comments

Comments
 (0)