Skip to content

Commit 7eff5dd

Browse files
refactor: use fixpoint_sub and fast_substitute from Symbolics
1 parent 4db0053 commit 7eff5dd

File tree

6 files changed

+8
-98
lines changed

6 files changed

+8
-98
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
104104
StaticArrays = "0.10, 0.11, 0.12, 1.0"
105105
SymbolicIndexingInterface = "0.3.11"
106106
SymbolicUtils = "1.0"
107-
Symbolics = "5.24"
107+
Symbolics = "5.26"
108108
URIs = "1"
109109
UnPack = "0.1, 1.0"
110110
Unitful = "1.1"

src/ModelingToolkit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ using PrecompileTools, Reexport
5656
VariableSource, getname, variable, Connection, connect,
5757
NAMESPACE_SEPARATOR, set_scalar_metadata, setdefaultval,
5858
initial_state, transition, activeState, entry,
59-
ticksInState, timeInState
59+
ticksInState, timeInState, fixpoint_sub, fast_substitute
6060
import Symbolics: rename, get_variables!, _solve, hessian_sparsity,
6161
jacobian_sparsity, isaffine, islinear, _iszero, _isone,
6262
tosymbol, lower_varname, diff2term, var_from_nested_derivative,

src/structural_transformation/StructuralTransformations.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module StructuralTransformations
33
using Setfield: @set!, @set
44
using UnPack: @unpack
55

6-
using Symbolics: unwrap, linear_expansion
6+
using Symbolics: unwrap, linear_expansion, fast_substitute
77
using SymbolicUtils
88
using SymbolicUtils.Code
99
using SymbolicUtils.Rewriters
@@ -23,7 +23,7 @@ using ModelingToolkit: ODESystem, AbstractSystem, var_from_nested_derivative, Di
2323
IncrementalCycleTracker, add_edge_checked!, topological_sort,
2424
invalidate_cache!, Substitutions, get_or_construct_tearing_state,
2525
filter_kwargs, lower_varname, setio, SparseMatrixCLIL,
26-
fast_substitute, get_fullvars, has_equations, observed,
26+
get_fullvars, has_equations, observed,
2727
Schedule
2828

2929
using ModelingToolkit.BipartiteGraphs

src/structural_transformation/symbolics_tearing.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ function eq_derivative!(ts::TearingState{ODESystem}, ieq::Int)
8181
end
8282

8383
function tearing_sub(expr, dict, s)
84-
expr = ModelingToolkit.fixpoint_sub(expr, dict)
84+
expr = Symbolics.fixpoint_sub(expr, dict)
8585
s ? simplify(expr) : expr
8686
end
8787

@@ -439,7 +439,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching;
439439
order, lv = var_order(iv)
440440
dx = D(simplify_shifts(lower_varname_withshift(
441441
fullvars[lv], idep, order - 1)))
442-
eq = dx ~ simplify_shifts(ModelingToolkit.fixpoint_sub(
442+
eq = dx ~ simplify_shifts(Symbolics.fixpoint_sub(
443443
Symbolics.solve_for(neweqs[ieq],
444444
fullvars[iv]),
445445
total_sub; operator = ModelingToolkit.Shift))
@@ -467,7 +467,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching;
467467
@warn "Tearing: solving $eq for $var is singular!"
468468
else
469469
rhs = -b / a
470-
neweq = var ~ ModelingToolkit.fixpoint_sub(
470+
neweq = var ~ Symbolics.fixpoint_sub(
471471
simplify ?
472472
Symbolics.simplify(rhs) : rhs,
473473
total_sub; operator = ModelingToolkit.Shift)
@@ -481,7 +481,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching;
481481
if !(eq.lhs isa Number && eq.lhs == 0)
482482
rhs = eq.rhs - eq.lhs
483483
end
484-
push!(alge_eqs, 0 ~ ModelingToolkit.fixpoint_sub(rhs, total_sub))
484+
push!(alge_eqs, 0 ~ Symbolics.fixpoint_sub(rhs, total_sub))
485485
push!(algeeq_idxs, ieq)
486486
end
487487
end

src/systems/alias_elimination.jl

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -462,13 +462,3 @@ function observed2graph(eqs, unknowns)
462462

463463
return graph, assigns
464464
end
465-
466-
function fixpoint_sub(x, dict; operator = Nothing)
467-
y = fast_substitute(x, dict; operator)
468-
while !isequal(x, y)
469-
y = x
470-
x = fast_substitute(y, dict; operator)
471-
end
472-
473-
return x
474-
end

src/utils.jl

Lines changed: 0 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -799,86 +799,6 @@ function fold_constants(ex)
799799
end
800800
end
801801

802-
# Symbolics needs to call unwrap on the substitution rules, but most of the time
803-
# we don't want to do that in MTK.
804-
const Eq = Union{Equation, Inequality}
805-
function fast_substitute(eq::Eq, subs; operator = Nothing)
806-
if eq isa Inequality
807-
Inequality(fast_substitute(eq.lhs, subs; operator),
808-
fast_substitute(eq.rhs, subs; operator),
809-
eq.relational_op)
810-
else
811-
Equation(fast_substitute(eq.lhs, subs; operator),
812-
fast_substitute(eq.rhs, subs; operator))
813-
end
814-
end
815-
function fast_substitute(eq::T, subs::Pair; operator = Nothing) where {T <: Eq}
816-
T(fast_substitute(eq.lhs, subs; operator), fast_substitute(eq.rhs, subs; operator))
817-
end
818-
function fast_substitute(eqs::AbstractArray, subs; operator = Nothing)
819-
fast_substitute.(eqs, (subs,); operator)
820-
end
821-
function fast_substitute(eqs::AbstractArray, subs::Pair; operator = Nothing)
822-
fast_substitute.(eqs, (subs,); operator)
823-
end
824-
for (exprType, subsType) in Iterators.product((Num, Symbolics.Arr), (Any, Pair))
825-
@eval function fast_substitute(expr::$exprType, subs::$subsType; operator = Nothing)
826-
fast_substitute(value(expr), subs; operator)
827-
end
828-
end
829-
function fast_substitute(expr, subs; operator = Nothing)
830-
if (_val = get(subs, expr, nothing)) !== nothing
831-
return _val
832-
end
833-
istree(expr) || return expr
834-
op = fast_substitute(operation(expr), subs; operator)
835-
args = SymbolicUtils.unsorted_arguments(expr)
836-
if !(op isa operator)
837-
canfold = Ref(!(op isa Symbolic))
838-
args = let canfold = canfold
839-
map(args) do x
840-
x′ = fast_substitute(x, subs; operator)
841-
canfold[] = canfold[] && !(x′ isa Symbolic)
842-
x′
843-
end
844-
end
845-
canfold[] && return op(args...)
846-
end
847-
similarterm(expr,
848-
op,
849-
args,
850-
symtype(expr);
851-
metadata = metadata(expr))
852-
end
853-
function fast_substitute(expr, pair::Pair; operator = Nothing)
854-
a, b = pair
855-
isequal(expr, a) && return b
856-
if a isa AbstractArray
857-
for (ai, bi) in zip(a, b)
858-
expr = fast_substitute(expr, ai => bi; operator)
859-
end
860-
end
861-
istree(expr) || return expr
862-
op = fast_substitute(operation(expr), pair; operator)
863-
args = SymbolicUtils.unsorted_arguments(expr)
864-
if !(op isa operator)
865-
canfold = Ref(!(op isa Symbolic))
866-
args = let canfold = canfold
867-
map(args) do x
868-
x′ = fast_substitute(x, pair; operator)
869-
canfold[] = canfold[] && !(x′ isa Symbolic)
870-
x′
871-
end
872-
end
873-
canfold[] && return op(args...)
874-
end
875-
similarterm(expr,
876-
op,
877-
args,
878-
symtype(expr);
879-
metadata = metadata(expr))
880-
end
881-
882802
normalize_to_differential(s) = s
883803

884804
function restrict_array_to_union(arr)

0 commit comments

Comments
 (0)