Skip to content

Commit 5d7693e

Browse files
feat: add support for Arr and array expressions in fast_substitute
1 parent 62dee1b commit 5d7693e

File tree

1 file changed

+35
-10
lines changed

1 file changed

+35
-10
lines changed

src/utils.jl

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -815,21 +815,46 @@ end
815815
function fast_substitute(eqs::AbstractArray, subs; operator = Nothing)
816816
fast_substitute.(eqs, (subs,); operator)
817817
end
818-
function fast_substitute(a, b; operator = Nothing)
819-
b = Dict(value(k) => value(v) for (k, v) in b)
820-
a = value(a)
821-
haskey(b, a) && return b[a]
822-
for _b in b
823-
a = fast_substitute(a, _b; operator)
818+
function fast_substitute(eqs::AbstractArray, subs::Pair; operator = Nothing)
819+
fast_substitute.(eqs, (subs,); operator)
820+
end
821+
for (exprType, subsType) in Iterators.product((Num, Symbolics.Arr), (Any, Pair))
822+
@eval function fast_substitute(expr::$exprType, subs::$subsType; operator = Nothing)
823+
fast_substitute(value(expr), subs; operator)
824+
end
825+
end
826+
function fast_substitute(expr, subs; operator = Nothing)
827+
if (_val = get(subs, expr, nothing)) !== nothing
828+
return _val
829+
end
830+
istree(expr) || return expr
831+
op = fast_substitute(operation(expr), subs; operator)
832+
args = SymbolicUtils.unsorted_arguments(expr)
833+
if !(op isa operator)
834+
canfold = Ref(!(op isa Symbolic))
835+
args = let canfold = canfold
836+
map(args) do x
837+
x′ = fast_substitute(x, subs; operator)
838+
canfold[] = canfold[] && !(x′ isa Symbolic)
839+
x′
840+
end
841+
end
842+
canfold[] && return op(args...)
824843
end
825-
a
844+
similarterm(expr,
845+
op,
846+
args,
847+
symtype(expr);
848+
metadata = metadata(expr))
826849
end
827850
function fast_substitute(expr, pair::Pair; operator = Nothing)
828851
a, b = pair
829-
a = value(a)
830-
b = value(b)
831852
isequal(expr, a) && return b
832-
853+
if a isa AbstractArray
854+
for (ai, bi) in zip(a, b)
855+
expr = fast_substitute(expr, ai => bi; operator)
856+
end
857+
end
833858
istree(expr) || return expr
834859
op = fast_substitute(operation(expr), pair; operator)
835860
args = SymbolicUtils.unsorted_arguments(expr)

0 commit comments

Comments
 (0)