Skip to content

feat!: use SciMLStructures and add new MTKParameters #2447

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Feb 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
Expand Down Expand Up @@ -99,6 +100,7 @@ SciMLBase = "2.0.1"
Serialization = "1"
Setfield = "0.7, 0.8, 1"
SimpleNonlinearSolve = "0.1.0, 1"
SciMLStructures = "1.0"
SparseArrays = "1"
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
StaticArrays = "0.10, 0.11, 0.12, 1.0"
Expand Down
1 change: 1 addition & 0 deletions ext/MTKBifurcationKitExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ function BifurcationKit.BifurcationProblem(nsys::NonlinearSystem,
if !ModelingToolkit.iscomplete(nsys)
error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `BifurcationProblem`")
end
@set! nsys.index_cache = nothing # force usage of a parameter vector instead of `MTKParameters`
# Creates F and J functions.
ofun = NonlinearFunction(nsys; jac = jac)
F = ofun.f
Expand Down
5 changes: 4 additions & 1 deletion src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ using PrecompileTools, Reexport
import Distributions
import FunctionWrappersWrappers
using URIs: URI
using SciMLStructures

using RecursiveArrayTools

Expand Down Expand Up @@ -62,7 +63,7 @@ using PrecompileTools, Reexport
ParallelForm, SerialForm, MultithreadedForm, build_function,
rhss, lhss, prettify_expr, gradient,
jacobian, hessian, derivative, sparsejacobian, sparsehessian,
substituter, scalarize, getparent
substituter, scalarize, getparent, hasderiv, hasdiff

import DiffEqBase: @add_kwonly
import OrdinaryDiffEq
Expand Down Expand Up @@ -128,6 +129,8 @@ include("constants.jl")
include("utils.jl")
include("domains.jl")

include("systems/index_cache.jl")
include("systems/parameter_buffer.jl")
include("systems/abstractsystem.jl")
include("systems/model_parsing.jl")
include("systems/connectors.jl")
Expand Down
4 changes: 3 additions & 1 deletion src/clock.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,9 @@ true if `x` contains only discrete-domain signals.
See also [`has_discrete_domain`](@ref)
"""
function is_discrete_domain(x)
issym(x) && return getmetadata(x, TimeDomain, false) isa Discrete
if hasmetadata(x, TimeDomain) || issym(x)
return getmetadata(x, TimeDomain, false) isa AbstractDiscrete
end
!has_discrete_domain(x) && has_continuous_domain(x)
end

Expand Down
4 changes: 2 additions & 2 deletions src/discretedomain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ Base.hash(D::Sample, u::UInt) = hash(D.clock, xor(u, 0x055640d6d952f101))

Returns true if the expression or equation `O` contains [`Sample`](@ref) terms.
"""
hassample(O) = recursive_hasoperator(Sample, O)
hassample(O) = recursive_hasoperator(Sample, unwrap(O))

# Hold

Expand All @@ -140,7 +140,7 @@ Hold(x) = Hold()(x)

Returns true if the expression or equation `O` contains [`Hold`](@ref) terms.
"""
hashold(O) = recursive_hasoperator(Hold, O)
hashold(O) = recursive_hasoperator(Hold, unwrap(O))

# ShiftIndex

Expand Down
118 changes: 101 additions & 17 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -187,8 +187,14 @@ function SymbolicIndexingInterface.is_variable(sys::AbstractSystem, sym)
if unwrap(sym) isa Int # [x, 1] coerces 1 to a Num
return unwrap(sym) in 1:length(variable_symbols(sys))
end
return any(isequal(sym), variable_symbols(sys)) ||
if has_index_cache(sys) && get_index_cache(sys) !== nothing
ic = get_index_cache(sys)
h = getsymbolhash(sym)
return haskey(ic.unknown_idx, h) || haskey(ic.unknown_idx, getsymbolhash(default_toterm(sym))) || hasname(sym) && is_variable(sys, getname(sym))
else
return any(isequal(sym), variable_symbols(sys)) ||
hasname(sym) && is_variable(sys, getname(sym))
end
end

function SymbolicIndexingInterface.is_variable(sys::AbstractSystem, sym::Symbol)
Expand All @@ -202,6 +208,22 @@ function SymbolicIndexingInterface.variable_index(sys::AbstractSystem, sym)
if unwrap(sym) isa Int
return unwrap(sym)
end
if has_index_cache(sys) && get_index_cache(sys) !== nothing
ic = get_index_cache(sys)
h = getsymbolhash(sym)
return if haskey(ic.unknown_idx, h)
ic.unknown_idx[h]
else
h = getsymbolhash(default_toterm(sym))
if haskey(ic.unknown_idx, h)
ic.unknown_idx[h]
elseif hasname(sym)
variable_index(sys, getname(sym))
else
nothing
end
end
end
idx = findfirst(isequal(sym), variable_symbols(sys))
if idx === nothing && hasname(sym)
idx = variable_index(sys, getname(sym))
Expand Down Expand Up @@ -230,7 +252,19 @@ function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym)
if unwrap(sym) isa Int
return unwrap(sym) in 1:length(parameter_symbols(sys))
end

if has_index_cache(sys) && get_index_cache(sys) !== nothing
ic = get_index_cache(sys)
h = getsymbolhash(sym)
return if haskey(ic.param_idx, h) || haskey(ic.discrete_idx, h) ||
haskey(ic.constant_idx, h) || haskey(ic.dependent_idx, h)
true
else
h = getsymbolhash(default_toterm(sym))
haskey(ic.param_idx, h) || haskey(ic.discrete_idx, h) ||
haskey(ic.constant_idx, h) || haskey(ic.dependent_idx, h) ||
hasname(sym) && is_parameter(sys, getname(sym))
end
end
return any(isequal(sym), parameter_symbols(sys)) ||
hasname(sym) && is_parameter(sys, getname(sym))
end
Expand All @@ -246,6 +280,33 @@ function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym)
if unwrap(sym) isa Int
return unwrap(sym)
end
if has_index_cache(sys) && get_index_cache(sys) !== nothing
ic = get_index_cache(sys)
h = getsymbolhash(sym)
return if haskey(ic.param_idx, h)
ParameterIndex(SciMLStructures.Tunable(), ic.param_idx[h])
elseif haskey(ic.discrete_idx, h)
ParameterIndex(SciMLStructures.Discrete(), ic.discrete_idx[h])
elseif haskey(ic.constant_idx, h)
ParameterIndex(SciMLStructures.Constants(), ic.constant_idx[h])
elseif haskey(ic.dependent_idx, h)
ParameterIndex(nothing, ic.dependent_idx[h])
else
h = getsymbolhash(default_toterm(sym))
if haskey(ic.param_idx, h)
ParameterIndex(SciMLStructures.Tunable(), ic.param_idx[h])
elseif haskey(ic.discrete_idx, h)
ParameterIndex(SciMLStructures.Discrete(), ic.discrete_idx[h])
elseif haskey(ic.constant_idx, h)
ParameterIndex(SciMLStructures.Constants(), ic.constant_idx[h])
elseif haskey(ic.dependent_idx, h)
ParameterIndex(nothing, ic.dependent_idx[h])
else
nothing
end
end
end

idx = findfirst(isequal(sym), parameter_symbols(sys))
if idx === nothing && hasname(sym)
idx = parameter_index(sys, getname(sym))
Expand Down Expand Up @@ -313,6 +374,9 @@ Mark a system as completed. If a system is complete, the system will no longer
namespace its subsystems or variables, i.e. `isequal(complete(sys).v.i, v.i)`.
"""
function complete(sys::AbstractSystem)
if has_index_cache(sys)
@set! sys.index_cache = IndexCache(sys)
end
isdefined(sys, :complete) ? (@set! sys.complete = true) : sys
end

Expand Down Expand Up @@ -354,7 +418,8 @@ for prop in [:eqs
:discrete_subsystems
:solved_unknowns
:split_idxs
:parent]
:parent
:index_cache]
fname1 = Symbol(:get_, prop)
fname2 = Symbol(:has_, prop)
@eval begin
Expand Down Expand Up @@ -1437,14 +1502,19 @@ function linearization_function(sys::AbstractSystem, inputs,
end
sys = ssys
x0 = merge(defaults(sys), Dict(missing_variable_defaults(sys)), op)
u0, p, _ = get_u0_p(sys, x0, p; use_union = false, tofloat = true)
p, split_idxs = split_parameters_by_type(p)
ps = parameters(sys)
if p isa Tuple
ps = Base.Fix1(getindex, ps).(split_idxs)
ps = (ps...,) #if p is Tuple, ps should be Tuple
u0, _p, _ = get_u0_p(sys, x0, p; use_union = false, tofloat = true)
if has_index_cache(sys) && get_index_cache(sys) !== nothing
p = MTKParameters(sys, p)
ps = reorder_parameters(sys, parameters(sys))
else
p = _p
p, split_idxs = split_parameters_by_type(p)
ps = parameters(sys)
if p isa Tuple
ps = Base.Fix1(getindex, ps).(split_idxs)
ps = (ps...,) #if p is Tuple, ps should be Tuple
end
end

lin_fun = let diff_idxs = diff_idxs,
alge_idxs = alge_idxs,
input_idxs = input_idxs,
Expand All @@ -1468,7 +1538,7 @@ function linearization_function(sys::AbstractSystem, inputs,
uf = SciMLBase.UJacobianWrapper(fun, t, p)
fg_xz = ForwardDiff.jacobian(uf, u)
h_xz = ForwardDiff.jacobian(let p = p, t = t
xz -> h(xz, p, t)
xz -> p isa MTKParameters ? h(xz, p..., t) : h(xz, p, t)
end, u)
pf = SciMLBase.ParamJacobianWrapper(fun, t, u)
fg_u = jacobian_wrt_vars(pf, p, input_idxs, chunk)
Expand All @@ -1479,7 +1549,9 @@ function linearization_function(sys::AbstractSystem, inputs,
h_xz = fg_u = zeros(0, length(inputs))
end
hp = let u = u, t = t
p -> h(u, p, t)
_hp(p) = h(u, p, t)
_hp(p::MTKParameters) = h(u, p..., t)
_hp
end
h_u = jacobian_wrt_vars(hp, p, input_idxs, chunk)
(f_x = fg_xz[diff_idxs, diff_idxs],
Expand Down Expand Up @@ -1521,13 +1593,14 @@ function linearize_symbolic(sys::AbstractSystem, inputs,
kwargs...)
sts = unknowns(sys)
t = get_iv(sys)
p = parameters(sys)
ps = parameters(sys)
p = reorder_parameters(sys, ps)

fun = generate_function(sys, sts, p; expression = Val{false})[1]
dx = fun(sts, p, t)
fun = generate_function(sys, sts, ps; expression = Val{false})[1]
dx = fun(sts, p..., t)

h = build_explicit_observed_function(sys, outputs)
y = h(sts, p, t)
y = h(sts, p..., t)

fg_xz = Symbolics.jacobian(dx, sts)
fg_u = Symbolics.jacobian(dx, inputs)
Expand Down Expand Up @@ -1722,7 +1795,18 @@ function linearize(sys, lin_fun; t = 0.0, op = Dict(), allow_input_derivatives =
p = DiffEqBase.NullParameters())
x0 = merge(defaults(sys), op)
u0, p2, _ = get_u0_p(sys, x0, p; use_union = false, tofloat = true)

if has_index_cache(sys) && get_index_cache(sys) !== nothing
if p isa SciMLBase.NullParameters
p = op
elseif p isa Dict
p = merge(p, op)
elseif p isa Vector && eltype(p) <: Pair
p = merge(Dict(p), op)
elseif p isa Vector
p = merge(Dict(parameters(sys) .=> p), op)
end
p2 = MTKParameters(sys, p)
end
linres = lin_fun(u0, p2, t)
f_x, f_z, g_x, g_z, f_u, g_u, h_x, h_z, h_u = linres

Expand Down
Loading