Skip to content

Commit 709148e

Browse files
Merge pull request #2603 from AayushSabharwal/as/fix-namespacing-defaults
fix: fix variable namespacing issues
2 parents 7bc758b + 8ef7908 commit 709148e

File tree

4 files changed

+112
-17
lines changed

4 files changed

+112
-17
lines changed

src/systems/abstractsystem.jl

Lines changed: 45 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -746,38 +746,64 @@ end
746746
abstract type SymScope end
747747

748748
struct LocalScope <: SymScope end
749-
function LocalScope(sym::Union{Num, Symbolic})
749+
function LocalScope(sym::Union{Num, Symbolic, Symbolics.Arr{Num}})
750750
apply_to_variables(sym) do sym
751-
setmetadata(sym, SymScope, LocalScope())
751+
if istree(sym) && operation(sym) === getindex
752+
args = arguments(sym)
753+
a1 = setmetadata(args[1], SymScope, LocalScope())
754+
similarterm(sym, operation(sym), [a1, args[2:end]...])
755+
else
756+
setmetadata(sym, SymScope, LocalScope())
757+
end
752758
end
753759
end
754760

755761
struct ParentScope <: SymScope
756762
parent::SymScope
757763
end
758-
function ParentScope(sym::Union{Num, Symbolic})
764+
function ParentScope(sym::Union{Num, Symbolic, Symbolics.Arr{Num}})
759765
apply_to_variables(sym) do sym
760-
setmetadata(sym, SymScope,
761-
ParentScope(getmetadata(value(sym), SymScope, LocalScope())))
766+
if istree(sym) && operation(sym) === getindex
767+
args = arguments(sym)
768+
a1 = setmetadata(args[1], SymScope,
769+
ParentScope(getmetadata(value(args[1]), SymScope, LocalScope())))
770+
similarterm(sym, operation(sym), [a1, args[2:end]...])
771+
else
772+
setmetadata(sym, SymScope,
773+
ParentScope(getmetadata(value(sym), SymScope, LocalScope())))
774+
end
762775
end
763776
end
764777

765778
struct DelayParentScope <: SymScope
766779
parent::SymScope
767780
N::Int
768781
end
769-
function DelayParentScope(sym::Union{Num, Symbolic}, N)
782+
function DelayParentScope(sym::Union{Num, Symbolic, Symbolics.Arr{Num}}, N)
770783
apply_to_variables(sym) do sym
771-
setmetadata(sym, SymScope,
772-
DelayParentScope(getmetadata(value(sym), SymScope, LocalScope()), N))
784+
if istree(sym) && operation(sym) == getindex
785+
args = arguments(sym)
786+
a1 = setmetadata(args[1], SymScope,
787+
DelayParentScope(getmetadata(value(args[1]), SymScope, LocalScope()), N))
788+
similarterm(sym, operation(sym), [a1, args[2:end]...])
789+
else
790+
setmetadata(sym, SymScope,
791+
DelayParentScope(getmetadata(value(sym), SymScope, LocalScope()), N))
792+
end
773793
end
774794
end
775-
DelayParentScope(sym::Union{Num, Symbolic}) = DelayParentScope(sym, 1)
795+
DelayParentScope(sym::Union{Num, Symbolic, Symbolics.Arr{Num}}) = DelayParentScope(sym, 1)
776796

777797
struct GlobalScope <: SymScope end
778-
function GlobalScope(sym::Union{Num, Symbolic})
798+
function GlobalScope(sym::Union{Num, Symbolic, Symbolics.Arr{Num}})
779799
apply_to_variables(sym) do sym
780-
setmetadata(sym, SymScope, GlobalScope())
800+
if istree(sym) && operation(sym) == getindex
801+
args = arguments(sym)
802+
a1 = setmetadata(args[1], SymScope, GlobalScope())
803+
similarterm(sym, operation(sym), [a1, args[2:end]...])
804+
else
805+
setmetadata(sym, SymScope, GlobalScope())
806+
end
781807
end
782808
end
783809

@@ -793,6 +819,11 @@ function renamespace(sys, x)
793819
return similarterm(x, operation(x),
794820
Any[renamespace(sys, only(arguments(x)))])::T
795821
end
822+
if istree(x) && operation(x) === getindex
823+
args = arguments(x)
824+
return similarterm(
825+
x, operation(x), vcat(renamespace(sys, args[1]), args[2:end]))::T
826+
end
796827
let scope = getmetadata(x, SymScope, LocalScope())
797828
if scope isa LocalScope
798829
rename(x, renamespace(getname(sys), getname(x)))::T
@@ -849,7 +880,8 @@ function namespace_assignment(eq::Assignment, sys)
849880
Assignment(_lhs, _rhs)
850881
end
851882

852-
function namespace_expr(O, sys, n = nameof(sys); ivs = independent_variables(sys))
883+
function namespace_expr(
884+
O, sys, n = nameof(sys); ivs = independent_variables(sys))
853885
O = unwrap(O)
854886
if any(isequal(O), ivs)
855887
return O
@@ -1500,8 +1532,7 @@ function default_to_parentscope(v)
15001532
uv isa Symbolic || return v
15011533
apply_to_variables(v) do sym
15021534
if !hasmetadata(uv, SymScope)
1503-
setmetadata(sym, SymScope,
1504-
ParentScope(getmetadata(value(sym), SymScope, LocalScope())))
1535+
ParentScope(sym)
15051536
else
15061537
sym
15071538
end

test/input_output_handling.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@ end
2020
@named sys = ODESystem([D(x) ~ -x + u], t) # both u and x are unbound
2121
@named sys1 = ODESystem([D(x) ~ -x + v[1] + v[2]], t) # both v and x are unbound
2222
@named sys2 = ODESystem([D(x) ~ -sys.x], t, systems = [sys]) # this binds sys.x in the context of sys2, sys2.x is still unbound
23-
@named sys21 = ODESystem([D(x) ~ -sys.x], t, systems = [sys1]) # this binds sys.x in the context of sys2, sys2.x is still unbound
23+
@named sys21 = ODESystem([D(x) ~ -sys1.x], t, systems = [sys1]) # this binds sys.x in the context of sys2, sys2.x is still unbound
2424
@named sys3 = ODESystem([D(x) ~ -sys.x + sys.u], t, systems = [sys]) # This binds both sys.x and sys.u
25-
@named sys31 = ODESystem([D(x) ~ -sys.x + sys1.v[1]], t, systems = [sys1]) # This binds both sys.x and sys1.v[1]
25+
@named sys31 = ODESystem([D(x) ~ -sys1.x + sys1.v[1]], t, systems = [sys1]) # This binds both sys.x and sys1.v[1]
2626

2727
@named sys4 = ODESystem([D(x) ~ -sys.x, u ~ sys.u], t, systems = [sys]) # This binds both sys.x and sys3.u, this system is one layer deeper than the previous. u is directly forwarded to sys.u, and in this case sys.u is bound while u is not
2828

@@ -43,7 +43,7 @@ end
4343
@test is_bound(sys2, sys.x)
4444
@test !is_bound(sys2, sys.u)
4545
@test !is_bound(sys2, sys2.sys.u)
46-
@test is_bound(sys21, sys.x)
46+
@test is_bound(sys21, sys1.x)
4747
@test !is_bound(sys21, sys1.v[1])
4848
@test !is_bound(sys21, sys1.v[2])
4949
@test is_bound(sys31, sys1.v[1])

test/odesystem.jl

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1066,3 +1066,57 @@ prob = SteadyStateProblem(sys, u0, p)
10661066
@test prob isa SteadyStateProblem
10671067
prob = SteadyStateProblem(ODEProblem(sys, u0, (0.0, 10.0), p))
10681068
@test prob isa SteadyStateProblem
1069+
1070+
# Issue#2344
1071+
using ModelingToolkitStandardLibrary.Blocks
1072+
1073+
function FML2(; name)
1074+
@parameters begin
1075+
k2[1:1] = [1.0]
1076+
end
1077+
systems = @named begin
1078+
constant = Constant(k = k2[1])
1079+
end
1080+
@variables begin
1081+
x(t) = 0
1082+
end
1083+
eqs = [
1084+
D(x) ~ constant.output.u + k2[1]
1085+
]
1086+
ODESystem(eqs, t; systems, name)
1087+
end
1088+
1089+
@mtkbuild model = FML2()
1090+
1091+
@test isequal(ModelingToolkit.defaults(model)[model.constant.k], model.k2[1])
1092+
@test_nowarn ODEProblem(model, [], (0.0, 10.0))
1093+
1094+
# Issue#2477
1095+
function RealExpression(; name, y)
1096+
vars = @variables begin
1097+
u(t)
1098+
end
1099+
eqns = [
1100+
u ~ y
1101+
]
1102+
sys = ODESystem(eqns, t, vars, []; name)
1103+
end
1104+
1105+
function RealExpressionSystem(; name)
1106+
vars = @variables begin
1107+
x(t)
1108+
z(t)[1:1]
1109+
end # doing a collect on z doesn't work either.
1110+
@named e1 = RealExpression(y = x) # This works perfectly.
1111+
@named e2 = RealExpression(y = z[1]) # This bugs. However, `full_equations(e2)` works as expected.
1112+
systems = [e1, e2]
1113+
ODESystem(Equation[], t, Iterators.flatten(vars), []; systems, name)
1114+
end
1115+
1116+
@named sys = RealExpressionSystem()
1117+
sys = complete(sys)
1118+
@test Set(equations(sys)) == Set([sys.e1.u ~ sys.x, sys.e2.u ~ sys.z[1]])
1119+
tearing_state = TearingState(expand_connections(sys))
1120+
ts_vars = tearing_state.fullvars
1121+
orig_vars = unknowns(sys)
1122+
@test isempty(setdiff(ts_vars, orig_vars))

test/variable_scope.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,3 +73,13 @@ ps = ModelingToolkit.getname.(parameters(level3))
7373
@test isequal(ps[4], :level2₊level0₊d)
7474
@test isequal(ps[5], :level1₊level0₊e)
7575
@test isequal(ps[6], :f)
76+
77+
# Issue@2252
78+
# Tests from PR#2354
79+
@parameters xx[1:2]
80+
arr_p = [ParentScope(xx[1]), xx[2]]
81+
arr0 = ODESystem(Equation[], t, [], arr_p; name = :arr0)
82+
arr1 = ODESystem(Equation[], t, [], []; name = :arr1) arr0
83+
arr_ps = ModelingToolkit.getname.(parameters(arr1))
84+
@test isequal(arr_ps[1], Symbol("xx"))
85+
@test isequal(arr_ps[2], Symbol("arr0₊xx"))

0 commit comments

Comments
 (0)