Skip to content

Commit d47af77

Browse files
committed
feat: support array of components in @mtkmodel
- for loop or a list comprehension can be used to declare component arrays
1 parent 26960c8 commit d47af77

File tree

3 files changed

+129
-48
lines changed

3 files changed

+129
-48
lines changed

src/systems/abstractsystem.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1402,7 +1402,7 @@ function _named(name, call, runtime = false)
14021402
end
14031403
end
14041404

1405-
function _named_idxs(name::Symbol, idxs, call)
1405+
function _named_idxs(name::Symbol, idxs, call; extra_args = "")
14061406
if call.head !== :->
14071407
throw(ArgumentError("Not an anonymous function"))
14081408
end
@@ -1413,7 +1413,10 @@ function _named_idxs(name::Symbol, idxs, call)
14131413
ex = Base.Cartesian.poplinenum(ex)
14141414
ex = _named(:(Symbol($(Meta.quot(name)), :_, $sym)), ex, true)
14151415
ex = Base.Cartesian.poplinenum(ex)
1416-
:($name = $map($sym -> $ex, $idxs))
1416+
:($name = map($sym -> begin
1417+
$extra_args
1418+
$ex
1419+
end, $idxs))
14171420
end
14181421

14191422
function single_named_expr(expr)

src/systems/model_parsing.jl

Lines changed: 86 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ function _model_macro(mod, name, expr, isconnector)
4040
:kwargs => Dict{Symbol, Dict}(),
4141
:structural_parameters => Dict{Symbol, Dict}()
4242
)
43-
comps = Symbol[]
43+
comps = Union{Symbol, Expr}[]
4444
ext = Ref{Any}(nothing)
4545
eqs = Expr[]
4646
icon = Ref{Union{String, URI}}()
@@ -745,7 +745,7 @@ end
745745

746746
### Parsing Components:
747747

748-
function component_args!(a, b, expr, varexpr, kwargs)
748+
function component_args!(a, b, varexpr, kwargs; index_name = nothing)
749749
# Whenever `b` is a function call, skip the first arg aka the function name.
750750
# Whenever it is a kwargs list, include it.
751751
start = b.head == :call ? 2 : 1
@@ -754,73 +754,115 @@ function component_args!(a, b, expr, varexpr, kwargs)
754754
arg isa LineNumberNode && continue
755755
MLStyle.@match arg begin
756756
x::Symbol || Expr(:kw, x) => begin
757-
_v = _rename(a, x)
758-
b.args[i] = Expr(:kw, x, _v)
759-
push!(varexpr.args, :((@isdefined $x) && ($_v = $x)))
760-
push!(kwargs, Expr(:kw, _v, nothing))
761-
# dict[:kwargs][_v] = nothing
757+
varname, _varname = _rename(a, x)
758+
b.args[i] = Expr(:kw, x, _varname)
759+
push!(varexpr.args, :((if $varname !== nothing
760+
$_varname = $varname
761+
elseif @isdefined $x
762+
# Allow users to define a var in `structural_parameters` and set
763+
# that as positional arg of subcomponents; it is useful for cases
764+
# where it needs to be passed to multiple subcomponents.
765+
$_varname = $x
766+
end)))
767+
push!(kwargs, Expr(:kw, varname, nothing))
768+
# dict[:kwargs][varname] = nothing
762769
end
763770
Expr(:parameters, x...) => begin
764-
component_args!(a, arg, expr, varexpr, kwargs)
771+
component_args!(a, arg, varexpr, kwargs)
765772
end
766773
Expr(:kw, x, y) => begin
767-
_v = _rename(a, x)
768-
b.args[i] = Expr(:kw, x, _v)
769-
push!(varexpr.args, :($_v = $_v === nothing ? $y : $_v))
770-
push!(kwargs, Expr(:kw, _v, nothing))
771-
# dict[:kwargs][_v] = nothing
774+
varname, _varname = _rename(a, x)
775+
b.args[i] = Expr(:kw, x, _varname)
776+
if isnothing(index_name)
777+
push!(varexpr.args, :($_varname = $varname === nothing ? $y : $varname))
778+
else
779+
push!(varexpr.args,
780+
:($_varname = $varname === nothing ? $y : $varname[$index_name]))
781+
end
782+
push!(kwargs, Expr(:kw, varname, nothing))
783+
# dict[:kwargs][varname] = nothing
772784
end
773785
_ => error("Could not parse $arg of component $a")
774786
end
775787
end
776788
end
777789

778-
function _parse_components!(exprs, body, kwargs)
779-
expr = Expr(:block)
790+
model_name(name, range) = Symbol.(name, :_, collect(range))
791+
792+
function _parse_components!(body, kwargs)
793+
local expr
780794
varexpr = Expr(:block)
781-
# push!(exprs, varexpr)
782-
comps = Vector{Union{Symbol, Expr}}[]
795+
comps = Vector{Union{Union{Expr, Symbol}, Expr}}[]
783796
comp_names = []
784797

785-
for arg in body.args
786-
arg isa LineNumberNode && continue
787-
MLStyle.@match arg begin
788-
Expr(:block) => begin
789-
# TODO: Do we need this?
790-
error("Multiple `@components` block detected within a single block")
791-
end
792-
Expr(:(=), a, b) => begin
793-
arg = deepcopy(arg)
794-
b = deepcopy(arg.args[2])
798+
Base.remove_linenums!(body)
799+
arg = body.args[end]
800+
801+
MLStyle.@match arg begin
802+
Expr(:(=), a, Expr(:comprehension, Expr(:generator, b, Expr(:(=), c, d)))) => begin
803+
array_varexpr = Expr(:block)
795804

796-
component_args!(a, b, expr, varexpr, kwargs)
805+
push!(comp_names, :($a...))
806+
push!(comps, [a, b.args[1], d])
807+
b = deepcopy(b)
797808

798-
arg.args[2] = b
799-
push!(expr.args, arg)
800-
push!(comp_names, a)
809+
component_args!(a, b, array_varexpr, kwargs; index_name = c)
810+
811+
expr = _named_idxs(a, d, :($c -> $b); extra_args = array_varexpr)
812+
end
813+
Expr(:(=), a, Expr(:comprehension, Expr(:generator, b, Expr(:filter, e, Expr(:(=), c, d))))) => begin
814+
error("List comprehensions with conditional statements aren't supported.")
815+
end
816+
Expr(:(=), a, Expr(:comprehension, Expr(:generator, b, Expr(:(=), c, d), e...))) => begin
817+
# Note that `e` is of the form `Tuple{Expr(:(=), c, d)}`
818+
error("More than one index isn't supported while building component array")
819+
end
820+
Expr(:block) => begin
821+
# TODO: Do we need this?
822+
error("Multiple `@components` block detected within a single block")
823+
end
824+
Expr(:(=), a, Expr(:for, Expr(:(=), c, d), b)) => begin
825+
Base.remove_linenums!(b)
826+
array_varexpr = Expr(:block)
827+
push!(array_varexpr.args, b.args[1:(end - 1)]...)
828+
push!(comp_names, :($a...))
829+
push!(comps, [a, b.args[end].args[1], d])
830+
b = deepcopy(b)
831+
832+
component_args!(a, b.args[end], array_varexpr, kwargs; index_name = c)
833+
834+
expr = _named_idxs(a, d, :($c -> $(b.args[end])); extra_args = array_varexpr)
835+
end
836+
Expr(:(=), a, b) => begin
837+
arg = deepcopy(arg)
838+
b = deepcopy(arg.args[2])
839+
840+
component_args!(a, b, varexpr, kwargs)
841+
842+
arg.args[2] = b
843+
expr = :(@named $arg)
844+
push!(comp_names, a)
801845
if (isa(b.args[1], Symbol) || Meta.isexpr(b.args[1], :.))
802-
push!(comps, [a, b.args[1]])
846+
push!(comps, [a, b.args[1]])
803847
end
804-
end
805-
_ => error("Couldn't parse the component body: $arg")
806848
end
849+
_ => error("Couldn't parse the component body: $arg")
807850
end
851+
808852
return comp_names, comps, expr, varexpr
809853
end
810854

811855
function push_conditional_component!(ifexpr, expr_vec, comp_names, varexpr)
812856
blk = Expr(:block)
813857
push!(blk.args, varexpr)
814-
push!(blk.args, :(@named begin
815-
$(expr_vec.args...)
816-
end))
858+
push!(blk.args, expr_vec)
817859
push!(blk.args, :($push!(systems, $(comp_names...))))
818860
push!(ifexpr.args, blk)
819861
end
820862

821863
function handle_if_x!(mod, exprs, ifexpr, x, kwargs, condition = nothing)
822864
push!(ifexpr.args, condition)
823-
comp_names, comps, expr_vec, varexpr = _parse_components!(ifexpr, x, kwargs)
865+
comp_names, comps, expr_vec, varexpr = _parse_components!(x, kwargs)
824866
push_conditional_component!(ifexpr, expr_vec, comp_names, varexpr)
825867
comps
826868
end
@@ -836,7 +878,7 @@ function handle_if_y!(exprs, ifexpr, y, kwargs)
836878
push!(ifexpr.args, elseifexpr)
837879
(comps...,)
838880
else
839-
comp_names, comps, expr_vec, varexpr = _parse_components!(exprs, y, kwargs)
881+
comp_names, comps, expr_vec, varexpr = _parse_components!(y, kwargs)
840882
push_conditional_component!(ifexpr, expr_vec, comp_names, varexpr)
841883
comps
842884
end
@@ -861,25 +903,23 @@ function parse_components!(exprs, cs, dict, compbody, kwargs)
861903
Expr(:if, condition, x, y) => begin
862904
handle_conditional_components(condition, dict, exprs, kwargs, x, y)
863905
end
864-
Expr(:(=), a, b) => begin
865-
comp_names, comps, expr_vec, varexpr = _parse_components!(exprs,
866-
:(begin
906+
# Either the arg is top level component declaration or an invalid cause - both are handled by `_parse_components`
907+
_ => begin
908+
comp_names, comps, expr_vec, varexpr = _parse_components!(:(begin
867909
$arg
868910
end),
869911
kwargs)
870912
push!(cs, comp_names...)
871913
push!(dict[:components], comps...)
872-
push!(exprs, varexpr, :(@named begin
873-
$(expr_vec.args...)
874-
end))
914+
push!(exprs, varexpr, expr_vec)
875915
end
876-
_ => error("Couldn't parse the component body $compbody")
877916
end
878917
end
879918
end
880919

881920
function _rename(compname, varname)
882921
compname = Symbol(compname, :__, varname)
922+
(compname, Symbol(:_, compname))
883923
end
884924

885925
# Handle top level branching

test/model_parsing.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -650,3 +650,41 @@ end
650650
@named m = MyModel()
651651
@variables x___(t)
652652
@test isequal(x___, _b[])
653+
654+
@testset "Component array" begin
655+
@mtkmodel SubComponent begin
656+
@parameters begin
657+
sc
658+
end
659+
end
660+
661+
@mtkmodel Component begin
662+
@structural_parameters begin
663+
N = 2
664+
end
665+
@components begin
666+
comprehension = [SubComponent(sc = i) for i in 1:N]
667+
written_out_for = for i in 1:N
668+
sc = i + 1
669+
SubComponent(; sc)
670+
end
671+
single_sub_component = SubComponent()
672+
end
673+
end
674+
675+
@named component = Component()
676+
component = complete(component)
677+
678+
@test nameof.(ModelingToolkit.get_systems(component)) == [
679+
:comprehension_1,
680+
:comprehension_2,
681+
:written_out_for_1,
682+
:written_out_for_2,
683+
:single_sub_component,
684+
]
685+
686+
@test getdefault(component.comprehension_1.sc) == 1
687+
@test getdefault(component.comprehension_2.sc) == 2
688+
@test getdefault(component.written_out_for_1.sc) == 2
689+
@test getdefault(component.written_out_for_2.sc) == 3
690+
end

0 commit comments

Comments
 (0)