Skip to content

Add conservative kwarg in structural_transformation #2805

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/structural_transformation/pantelides.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,8 @@ end

Perform Pantelides algorithm.
"""
function pantelides!(state::TransformationState; finalize = true, maxiters = 8000)
function pantelides!(
state::TransformationState; finalize = true, maxiters = 8000, kwargs...)
@unpack graph, solvable_graph, var_to_diff, eq_to_diff = state.structure
neqs = nsrcs(graph)
nvars = nv(var_to_diff)
Expand Down Expand Up @@ -181,7 +182,7 @@ function pantelides!(state::TransformationState; finalize = true, maxiters = 800
ecolor[eq] || continue
# introduce a new equation
neqs += 1
eq_derivative!(state, eq)
eq_derivative!(state, eq; kwargs...)
end

for var in eachindex(vcolor)
Expand Down
2 changes: 1 addition & 1 deletion src/structural_transformation/partial_state_selection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ function dummy_derivative_graph!(state::TransformationState, jac = nothing;
state_priority = nothing, log = Val(false), kwargs...)
state.structure.solvable_graph === nothing && find_solvables!(state; kwargs...)
complete!(state.structure)
var_eq_matching = complete(pantelides!(state))
var_eq_matching = complete(pantelides!(state; kwargs...))
dummy_derivative_graph!(state.structure, var_eq_matching, jac, state_priority, log)
end

Expand Down
5 changes: 3 additions & 2 deletions src/structural_transformation/symbolics_tearing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ function eq_derivative_graph!(s::SystemStructure, eq::Int)
return eq_diff
end

function eq_derivative!(ts::TearingState{ODESystem}, ieq::Int)
function eq_derivative!(ts::TearingState{ODESystem}, ieq::Int; kwargs...)
s = ts.structure

eq_diff = eq_derivative_graph!(s, ieq)
Expand All @@ -75,7 +75,8 @@ function eq_derivative!(ts::TearingState{ODESystem}, ieq::Int)
add_edge!(s.graph, eq_diff, s.var_to_diff[var])
end
s.solvable_graph === nothing ||
find_eq_solvables!(ts, eq_diff; may_be_zero = true, allow_symbolic = false)
find_eq_solvables!(
ts, eq_diff; may_be_zero = true, allow_symbolic = false, kwargs...)

return eq_diff
end
Expand Down
5 changes: 4 additions & 1 deletion src/structural_transformation/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,9 @@ end

function find_eq_solvables!(state::TearingState, ieq, to_rm = Int[], coeffs = nothing;
may_be_zero = false,
allow_symbolic = false, allow_parameter = true, kwargs...)
allow_symbolic = false, allow_parameter = true,
conservative = false,
kwargs...)
fullvars = state.fullvars
@unpack graph, solvable_graph = state.structure
eq = equations(state)[ieq]
Expand Down Expand Up @@ -220,6 +222,7 @@ function find_eq_solvables!(state::TearingState, ieq, to_rm = Int[], coeffs = no
coeffs === nothing || push!(coeffs, convert(Int, a))
else
all_int_vars = false
conservative && continue
end
if a != 0
add_edge!(solvable_graph, ieq, j)
Expand Down
6 changes: 4 additions & 2 deletions src/systems/systems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@ $(SIGNATURES)
Structurally simplify algebraic equations in a system and compute the
topological sort of the observed equations. When `simplify=true`, the `simplify`
function will be applied during the tearing process. It also takes kwargs
`allow_symbolic=false` and `allow_parameter=true` which limits the coefficient
types during tearing.
`allow_symbolic=false`, `allow_parameter=true`, and `conservative=false` which
limits the coefficient types during tearing. In particular, `conservative=true`
limits tearing to only solve for trivial linear systems where the coefficient
has the absolute value of ``1``.

The optional argument `io` may take a tuple `(inputs, outputs)`.
This will convert all `inputs` to parameters and allow them to be unconnected, i.e.,
Expand Down
9 changes: 6 additions & 3 deletions src/systems/systemstructure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -691,15 +691,18 @@ function _structural_simplify!(state::TearingState, io; simplify = false,
ModelingToolkit.check_consistency(state, orig_inputs)
end
if fully_determined && dummy_derivative
sys = ModelingToolkit.dummy_derivative(sys, state; simplify, mm, check_consistency)
sys = ModelingToolkit.dummy_derivative(
sys, state; simplify, mm, check_consistency, kwargs...)
elseif fully_determined
var_eq_matching = pantelides!(state; finalize = false, kwargs...)
sys = pantelides_reassemble(state, var_eq_matching)
state = TearingState(sys)
sys, mm = ModelingToolkit.alias_elimination!(state; kwargs...)
sys = ModelingToolkit.dummy_derivative(sys, state; simplify, mm, check_consistency)
sys = ModelingToolkit.dummy_derivative(
sys, state; simplify, mm, check_consistency, kwargs...)
else
sys = ModelingToolkit.tearing(sys, state; simplify, mm, check_consistency)
sys = ModelingToolkit.tearing(
sys, state; simplify, mm, check_consistency, kwargs...)
end
fullunknowns = [map(eq -> eq.lhs, observed(sys)); unknowns(sys)]
@set! sys.observed = ModelingToolkit.topsort_equations(observed(sys), fullunknowns)
Expand Down
9 changes: 9 additions & 0 deletions test/nonlinearsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -274,3 +274,12 @@ eqs = [u3 ~ u1 + u2, u4 ~ 2 * (u1 + u2), u3 + u4 ~ 3 * (u1 + u2)]
@named ns = NonlinearSystem(eqs, [u1, u2], [u3, u4])
sys = structural_simplify(ns; fully_determined = false)
@test length(unknowns(sys)) == 1

# Conservative
@variables X(t)
alg_eqs = [1 ~ 2X]
@named ns = NonlinearSystem(alg_eqs)
sys = structural_simplify(ns)
@test length(equations(sys)) == 0
sys = structural_simplify(ns; conservative = true)
@test length(equations(sys)) == 1
Loading