Skip to content

Bye bye ₊, use var symbols with . #2798

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
Expand Down Expand Up @@ -67,6 +68,7 @@ ArrayInterface = "6, 7"
BifurcationKit = "0.3"
Combinatorics = "1"
Compat = "3.42, 4"
ComponentArrays = "0.15"
ConstructionBase = "1"
DataStructures = "0.17, 0.18"
DeepDiffs = "1"
Expand Down
32 changes: 16 additions & 16 deletions docs/src/basics/Composition.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@ connected = compose(
equations(connected)

#4-element Vector{Equation}:
# Differential(t)(decay1f(t)) ~ 0
# decay2f(t) ~ decay1x(t)
# Differential(t)(decay1x(t)) ~ decay1f(t) - (decay1a*(decay1x(t)))
# Differential(t)(decay2x(t)) ~ decay2f(t) - (decay2a*(decay2x(t)))
# Differential(t)(decay1.f(t)) ~ 0
# decay2.f(t) ~ decay1.x(t)
# Differential(t)(decay1.x(t)) ~ decay1.f(t) - (decay1.a*(decay1.x(t)))
# Differential(t)(decay2.x(t)) ~ decay2.f(t) - (decay2.a*(decay2.x(t)))

simplified_sys = structural_simplify(connected)

Expand Down Expand Up @@ -149,27 +149,27 @@ p = [a, b, c, d, e, f]
level0 = ODESystem(Equation[], t, [], p; name = :level0)
level1 = ODESystem(Equation[], t, [], []; name = :level1) ∘ level0
parameters(level1)
#level0a
#level0.a
#b
#c
#level0d
#level0e
#level0.d
#level0.e
#f
level2 = ODESystem(Equation[], t, [], []; name = :level2) ∘ level1
parameters(level2)
#level1level0a
#level1b
#level1.level0.a
#level1.b
#c
#level0d
#level1level0e
#level0.d
#level1.level0.e
#f
level3 = ODESystem(Equation[], t, [], []; name = :level3) ∘ level2
parameters(level3)
#level2level1level0a
#level2level1b
#level2c
#level2level0d
#level1level0e
#level2.level1.level0.a
#level2.level1.b
#level2.c
#level2.level0.d
#level1.level0.e
#f
```

Expand Down
4 changes: 2 additions & 2 deletions docs/src/basics/Events.md
Original file line number Diff line number Diff line change
Expand Up @@ -172,10 +172,10 @@ When accessing variables of a sub-system, it can be useful to rename them
(alternatively, an affect function may be reused in different contexts):

```julia
[x ~ 0] => (affect!, [resistorv => :v, x], [p, q => :p2], [], ctx)
[x ~ 0] => (affect!, [resistor.v => :v, x], [p, q => :p2], [], ctx)
```

Here, the symbolic variable `resistorv` is passed as `v` while the symbolic
Here, the symbolic variable `resistor.v` is passed as `v` while the symbolic
parameter `q` has been renamed `p2`.

As an example, here is the bouncing ball example from above using the functional
Expand Down
4 changes: 2 additions & 2 deletions docs/src/tutorials/domain_connections.md
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ end
nothing #hide
```

To see how the domain works, we can examine the set parameter values for each of the ports `src.port` and `vol.port`. First we assemble the system using `structural_simplify()` and then check the default value of `vol.port.ρ`, whichs points to the setter value `fluidρ`. Likewise, `src.port.ρ`, will also point to the setter value `fluidρ`. Therefore, there is now only 1 defined density value `fluidρ` which sets the density for the connected network.
To see how the domain works, we can examine the set parameter values for each of the ports `src.port` and `vol.port`. First we assemble the system using `structural_simplify()` and then check the default value of `vol.port.ρ`, whichs points to the setter value `fluid.ρ`. Likewise, `src.port.ρ`, will also point to the setter value `fluid.ρ`. Therefore, there is now only 1 defined density value `fluid.ρ` which sets the density for the connected network.

```@repl domain
sys = structural_simplify(odesys)
Expand Down Expand Up @@ -181,7 +181,7 @@ end
nothing #hide
```

After running `structural_simplify()` on `actsys2`, the defaults will show that `act.port_a.ρ` points to `fluid_aρ` and `act.port_b.ρ` points to `fluid_bρ`. This is a special case, in most cases a hydraulic system will have only 1 fluid, however this simple system has 2 separate domain networks. Therefore, we can connect a single fluid to both networks. This does not interfere with the mathematical equations of the system, since no unknown variables are connected.
After running `structural_simplify()` on `actsys2`, the defaults will show that `act.port_a.ρ` points to `fluid_a.ρ` and `act.port_b.ρ` points to `fluid_b.ρ`. This is a special case, in most cases a hydraulic system will have only 1 fluid, however this simple system has 2 separate domain networks. Therefore, we can connect a single fluid to both networks. This does not interfere with the mathematical equations of the system, since no unknown variables are connected.

```@example domain
@component function ActuatorSystem1(; name)
Expand Down
1 change: 1 addition & 0 deletions src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ using DiffEqCallbacks
using Graphs
import ExprTools: splitdef, combinedef
import OrderedCollections
import ComponentArrays

using SymbolicIndexingInterface
using LinearAlgebra, SparseArrays, LabelledArrays
Expand Down
22 changes: 11 additions & 11 deletions src/inputoutput.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,11 @@
In the following scenario
julia> observed(syss)
2-element Vector{Equation}:
sysy(tv) ~ sysx(tv)
y(tv) ~ sysx(tv)
sysy(t) is bound to the outer y(t) through the variable sysx(t) and should thus return is_bound(sysy(t)) = true.
When asking is_bound(sysy(t)), we know that we are looking through observed equations and can thus ask
if var is bound, if it is, then sysy(t) is also bound. This can lead to an infinite recursion, so we maintain a stack of variables we have previously asked about to be able to break cycles
sys.y(tv) ~ sys.x(tv)
y(tv) ~ sys.x(tv)
sys.y(t) is bound to the outer y(t) through the variable sys.x(t) and should thus return is_bound(sys.y(t)) = true.
When asking is_bound(sys.y(t)), we know that we are looking through observed equations and can thus ask
if var is bound, if it is, then sys.y(t) is also bound. This can lead to an infinite recursion, so we maintain a stack of variables we have previously asked about to be able to break cycles
=#
u ∈ Set(stack) && return false # Cycle detected
eqs = equations(sys)
Expand Down Expand Up @@ -119,17 +119,17 @@
nv = get_namespace(var)
nu == nv || # namespaces are the same
startswith(nv, nu) || # or nv starts with nu, i.e., nv is an inner namespace to nu
occursin('', string(getname(var))) &&
!occursin('', string(getname(u))) # or u is top level but var is internal
occursin('.', string(getname(var))) &&
!occursin('.', string(getname(u))) # or u is top level but var is internal
end

function inner_namespace(u, var)
nu = get_namespace(u)
nv = get_namespace(var)
nu == nv && return false
startswith(nv, nu) || # or nv starts with nu, i.e., nv is an inner namespace to nu
occursin('', string(getname(var))) &&
!occursin('', string(getname(u))) # or u is top level but var is internal
occursin('.', string(getname(var))) &&
!occursin('.', string(getname(u))) # or u is top level but var is internal
end

"""
Expand All @@ -139,11 +139,11 @@
"""
function get_namespace(x)
sname = string(getname(x))
parts = split(sname, '')
parts = split(sname, '.')

Check warning on line 142 in src/inputoutput.jl

View check run for this annotation

Codecov / codecov/patch

src/inputoutput.jl#L142

Added line #L142 was not covered by tests
if length(parts) == 1
return ""
end
join(parts[1:(end - 1)], '')
join(parts[1:(end - 1)], '.')

Check warning on line 146 in src/inputoutput.jl

View check run for this annotation

Codecov / codecov/patch

src/inputoutput.jl#L146

Added line #L146 was not covered by tests
end

"""
Expand Down
20 changes: 10 additions & 10 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -365,8 +365,8 @@
return is_variable(ic, sym)
end
return any(isequal(sym), getname.(variable_symbols(sys))) ||
count('', string(sym)) == 1 &&
count(isequal(sym), Symbol.(nameof(sys), :, getname.(variable_symbols(sys)))) ==
count('.', string(sym)) == 1 &&
count(isequal(sym), Symbol.(nameof(sys), :., getname.(variable_symbols(sys)))) ==
1
end

Expand Down Expand Up @@ -399,9 +399,9 @@
idx = findfirst(isequal(sym), getname.(variable_symbols(sys)))
if idx !== nothing
return idx
elseif count('', string(sym)) == 1
elseif count('.', string(sym)) == 1

Check warning on line 402 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L402

Added line #L402 was not covered by tests
return findfirst(isequal(sym),
Symbol.(nameof(sys), :, getname.(variable_symbols(sys))))
Symbol.(nameof(sys), :., getname.(variable_symbols(sys))))
end
return nothing
end
Expand Down Expand Up @@ -431,9 +431,9 @@
return is_parameter(ic, sym)
end
return any(isequal(sym), getname.(parameter_symbols(sys))) ||
count('', string(sym)) == 1 &&
count('.', string(sym)) == 1 &&
count(isequal(sym),
Symbol.(nameof(sys), :, getname.(parameter_symbols(sys)))) == 1
Symbol.(nameof(sys), :., getname.(parameter_symbols(sys)))) == 1
end

function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym)
Expand Down Expand Up @@ -466,9 +466,9 @@
idx = findfirst(isequal(sym), getname.(parameter_symbols(sys)))
if idx !== nothing
return idx
elseif count('', string(sym)) == 1
elseif count('.', string(sym)) == 1

Check warning on line 469 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L469

Added line #L469 was not covered by tests
return findfirst(isequal(sym),
Symbol.(nameof(sys), :, getname.(parameter_symbols(sys))))
Symbol.(nameof(sys), :., getname.(parameter_symbols(sys))))
end
return nothing
end
Expand Down Expand Up @@ -889,7 +889,7 @@
elseif x isa AbstractSystem
rename(x, renamespace(sys, nameof(x)))
else
Symbol(getname(sys), :, x)
Symbol(getname(sys), :., x)
end
end

Expand Down Expand Up @@ -1248,7 +1248,7 @@
syss = get_systems(eq.rhs)
call = Expr(:call, connect)
for sys in syss
strs = split(string(nameof(sys)), "")
strs = split(string(nameof(sys)), ".")
s = Symbol(strs[1])
for st in strs[2:end]
s = Expr(:., s, Meta.quot(Symbol(st)))
Expand Down
23 changes: 21 additions & 2 deletions src/systems/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,24 @@
end
end

# Put a wrapper on NamedTuple so that u.resistor.v indexes like u.var"resistor.v"
# Required for hierarchical, but a hack that should be fixed in the future
struct NamedTupleSymbolFix{T}
x::T
sym::Symbol
end
NamedTupleSymbolFix(x) = NamedTupleSymbolFix(x, Symbol(""))
function Base.getproperty(u::NamedTupleSymbolFix, s::Symbol)
newsym = getfield(u, :sym) == Symbol("") ? s : Symbol(getfield(u, :sym), ".", s)
x = getfield(u, :x)
if newsym in keys(x)
getproperty(x, newsym)

Check warning on line 526 in src/systems/callbacks.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/callbacks.jl#L521-L526

Added lines #L521 - L526 were not covered by tests
else
NamedTupleSymbolFix(x, newsym)

Check warning on line 528 in src/systems/callbacks.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/callbacks.jl#L528

Added line #L528 was not covered by tests
end
end
Base.getindex(u::NamedTupleSymbolFix, idxs::Int...) = getfield(u, :x)[idxs...]

Check warning on line 531 in src/systems/callbacks.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/callbacks.jl#L531

Added line #L531 was not covered by tests

function compile_user_affect(affect::FunctionalAffect, sys, dvs, ps; kwargs...)
dvs_ind = Dict(reverse(en) for en in enumerate(dvs))
v_inds = map(sym -> dvs_ind[sym], unknowns(affect))
Expand All @@ -526,9 +544,10 @@
# HACK: filter out eliminated symbols. Not clear this is the right thing to do
# (MTK should keep these symbols)
u = filter(x -> !isnothing(x[2]), collect(zip(unknowns_syms(affect), v_inds))) |>
NamedTuple
NamedTuple |> NamedTupleSymbolFix

p = filter(x -> !isnothing(x[2]), collect(zip(parameters_syms(affect), p_inds))) |>
NamedTuple
NamedTuple |> NamedTupleSymbolFix

let u = u, p = p, user_affect = func(affect), ctx = context(affect)
function (integ)
Expand Down
2 changes: 1 addition & 1 deletion src/systems/connectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ function generate_isouter(sys::AbstractSystem)
function isouter(sys)::Bool
s = string(nameof(sys))
isconnector(sys) || error("$s is not a connector!")
idx = findfirst(isequal(''), s)
idx = findfirst(isequal('.'), s)
parent_name = Symbol(idx === nothing ? s : s[1:prevind(s, idx)])
parent_name in outer_connectors
end
Expand Down
4 changes: 2 additions & 2 deletions test/funcaffect.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ i8 = findfirst(==(8.0), sol[:t])
ctx = [0]
function affect4!(integ, u, p, ctx)
ctx[1] += 1
@test u.resistorv == 1
@test u.resistor.v == 1
end
s1 = compose(
ODESystem(Equation[], t, [], [], name = :s1,
Expand All @@ -137,7 +137,7 @@ sol = solve(prob, Tsit5())
include("../examples/rc_model.jl")

function affect5!(integ, u, p, ctx)
@test integ.u[u.capacitorv] ≈ 0.3
@test integ.u[u.capacitor.v] ≈ 0.3
integ.ps[p.C] *= 200
end

Expand Down
4 changes: 2 additions & 2 deletions test/input_output_handling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ end
@test get_namespace(x) == ""
@test get_namespace(sys.x) == "sys"
@test get_namespace(sys2.x) == "sys2"
@test get_namespace(sys2.sys.x) == "sys2sys"
@test get_namespace(sys21.sys1.v) == "sys21sys1"
@test get_namespace(sys2.sys.x) == "sys2.sys"
@test get_namespace(sys21.sys1.v) == "sys21.sys1"

@test !is_bound(sys, u)
@test !is_bound(sys, x)
Expand Down
6 changes: 3 additions & 3 deletions test/inputoutput.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ connected = ODESystem(Equation[], t, [], [], observed = connections,

sys = connected

@variables lorenz1F lorenz2F
@test pins(connected) == Variable[lorenz1F, lorenz2F]
@variables lorenz1.F lorenz2.F
@test pins(connected) == Variable[lorenz1.F, lorenz2.F]
@test isequal(observed(connected),
[connections...,
lorenz1.u ~ lorenz1.x + lorenz1.y - lorenz1.z,
Expand All @@ -40,7 +40,7 @@ simplifyeqs(eqs) = Equation.((x -> x.lhs).(eqs), simplify.((x -> x.rhs).(eqs)))

@test isequal(simplifyeqs(equations(connected)), simplifyeqs(collapsed_eqs))

# Variables indicated to be input/output
# Variables indicated to be input/output
@variables x [input = true]
@test hasmetadata(x, Symbolics.option_to_metadata_type(Val(:input)))
@test getmetadata(x, Symbolics.option_to_metadata_type(Val(:input))) == true
Expand Down
8 changes: 4 additions & 4 deletions test/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1109,9 +1109,9 @@ function RealExpressionSystem(; name)
vars = @variables begin
x(t)
z(t)[1:1]
end # doing a collect on z doesn't work either.
@named e1 = RealExpression(y = x) # This works perfectly.
@named e2 = RealExpression(y = z[1]) # This bugs. However, `full_equations(e2)` works as expected.
end # doing a collect on z doesn't work either.
@named e1 = RealExpression(y = x) # This works perfectly.
@named e2 = RealExpression(y = z[1]) # This bugs. However, `full_equations(e2)` works as expected.
systems = [e1, e2]
ODESystem(Equation[], t, Iterators.flatten(vars), []; systems, name)
end
Expand Down Expand Up @@ -1166,7 +1166,7 @@ end
# Namespacing of array variables
@variables x(t)[1:2]
@named sys = ODESystem(Equation[], t)
@test getname(unknowns(sys, x)) == :sys₊x
@test getname(unknowns(sys, x)) == Symbol("sys.x")
@test size(unknowns(sys, x)) == size(x)

# Issue#2667
Expand Down
Loading
Loading