Skip to content

Commit af10f44

Browse files
fix: fix SymScope metadata for array variables
Co-authored-by: contradict <[email protected]>
1 parent 7fb1d99 commit af10f44

File tree

2 files changed

+48
-13
lines changed

2 files changed

+48
-13
lines changed

src/systems/abstractsystem.jl

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -741,38 +741,64 @@ end
741741
abstract type SymScope end
742742

743743
struct LocalScope <: SymScope end
744-
function LocalScope(sym::Union{Num, Symbolic})
744+
function LocalScope(sym::Union{Num, Symbolic, Symbolics.Arr{Num}})
745745
apply_to_variables(sym) do sym
746-
setmetadata(sym, SymScope, LocalScope())
746+
if istree(sym) && operation(sym) === getindex
747+
args = arguments(sym)
748+
a1 = setmetadata(args[1], SymScope, LocalScope())
749+
similarterm(sym, operation(sym), [a1, args[2:end]...])
750+
else
751+
setmetadata(sym, SymScope, LocalScope())
752+
end
747753
end
748754
end
749755

750756
struct ParentScope <: SymScope
751757
parent::SymScope
752758
end
753-
function ParentScope(sym::Union{Num, Symbolic})
759+
function ParentScope(sym::Union{Num, Symbolic, Symbolics.Arr{Num}})
754760
apply_to_variables(sym) do sym
755-
setmetadata(sym, SymScope,
756-
ParentScope(getmetadata(value(sym), SymScope, LocalScope())))
761+
if istree(sym) && operation(sym) == getindex
762+
args = arguments(sym)
763+
a1 = setmetadata(args[1], SymScope,
764+
ParentScope(getmetadata(value(args[1]), SymScope, LocalScope())))
765+
similarterm(sym, operation(sym), [a1, args[2:end]...])
766+
else
767+
setmetadata(sym, SymScope,
768+
ParentScope(getmetadata(value(sym), SymScope, LocalScope())))
769+
end
757770
end
758771
end
759772

760773
struct DelayParentScope <: SymScope
761774
parent::SymScope
762775
N::Int
763776
end
764-
function DelayParentScope(sym::Union{Num, Symbolic}, N)
777+
function DelayParentScope(sym::Union{Num, Symbolic, Symbolics.Arr{Num}}, N)
765778
apply_to_variables(sym) do sym
766-
setmetadata(sym, SymScope,
767-
DelayParentScope(getmetadata(value(sym), SymScope, LocalScope()), N))
779+
if istree(sym) && operation(sym) == getindex
780+
args = arguments(sym)
781+
a1 = setmetadata(args[1], SymScope,
782+
DelayParentScope(getmetadata(value(args[1]), SymScope, LocalScope()), N))
783+
similarterm(sym, operation(sym), [a1, args[2:end]...])
784+
else
785+
setmetadata(sym, SymScope,
786+
DelayParentScope(getmetadata(value(sym), SymScope, LocalScope()), N))
787+
end
768788
end
769789
end
770-
DelayParentScope(sym::Union{Num, Symbolic}) = DelayParentScope(sym, 1)
790+
DelayParentScope(sym::Union{Num, Symbolic, Symbolics.Arr{Num}}) = DelayParentScope(sym, 1)
771791

772792
struct GlobalScope <: SymScope end
773-
function GlobalScope(sym::Union{Num, Symbolic})
793+
function GlobalScope(sym::Union{Num, Symbolic, Symbolics.Arr{Num}})
774794
apply_to_variables(sym) do sym
775-
setmetadata(sym, SymScope, GlobalScope())
795+
if istree(sym) && operation(sym) == getindex
796+
args = arguments(sym)
797+
a1 = setmetadata(args[1], SymScope, GlobalScope())
798+
similarterm(sym, operation(sym), [a1, args[2:end]...])
799+
else
800+
setmetadata(sym, SymScope, GlobalScope())
801+
end
776802
end
777803
end
778804

@@ -1495,8 +1521,7 @@ function default_to_parentscope(v)
14951521
uv isa Symbolic || return v
14961522
apply_to_variables(v) do sym
14971523
if !hasmetadata(uv, SymScope)
1498-
setmetadata(sym, SymScope,
1499-
ParentScope(getmetadata(value(sym), SymScope, LocalScope())))
1524+
ParentScope(sym)
15001525
else
15011526
sym
15021527
end

test/variable_scope.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,3 +73,13 @@ ps = ModelingToolkit.getname.(parameters(level3))
7373
@test isequal(ps[4], :level2₊level0₊d)
7474
@test isequal(ps[5], :level1₊level0₊e)
7575
@test isequal(ps[6], :f)
76+
77+
# Issue@2252
78+
# Tests from PR#2354
79+
@parameters xx[1:2]
80+
arr_p = [ParentScope(xx[1]), xx[2]]
81+
arr0 = ODESystem(Equation[], t, [], arr_p; name = :arr0)
82+
arr1 = ODESystem(Equation[], t, [], []; name = :arr1) arr0
83+
arr_ps = ModelingToolkit.getname.(parameters(arr1))
84+
@test isequal(arr_ps[1], Symbol("xx"))
85+
@test isequal(arr_ps[2], Symbol("arr0₊xx"))

0 commit comments

Comments
 (0)