Skip to content

Commit d678d99

Browse files
committed
feat: support array of components in @mtkmodel
- for loop or a list comprehension can be used to declare component arrays
1 parent 3fbbdb0 commit d678d99

File tree

3 files changed

+129
-48
lines changed

3 files changed

+129
-48
lines changed

src/systems/abstractsystem.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1132,7 +1132,7 @@ function _named(name, call, runtime = false)
11321132
end
11331133
end
11341134

1135-
function _named_idxs(name::Symbol, idxs, call)
1135+
function _named_idxs(name::Symbol, idxs, call; extra_args = "")
11361136
if call.head !== :->
11371137
throw(ArgumentError("Not an anonymous function"))
11381138
end
@@ -1143,7 +1143,10 @@ function _named_idxs(name::Symbol, idxs, call)
11431143
ex = Base.Cartesian.poplinenum(ex)
11441144
ex = _named(:(Symbol($(Meta.quot(name)), :_, $sym)), ex, true)
11451145
ex = Base.Cartesian.poplinenum(ex)
1146-
:($name = $map($sym -> $ex, $idxs))
1146+
:($name = map($sym -> begin
1147+
$extra_args
1148+
$ex
1149+
end, $idxs))
11471150
end
11481151

11491152
function single_named_expr(expr)

src/systems/model_parsing.jl

Lines changed: 86 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ function _model_macro(mod, name, expr, isconnector)
3737
exprs = Expr(:block)
3838
dict = Dict{Symbol, Any}()
3939
dict[:kwargs] = Dict{Symbol, Any}()
40-
comps = Symbol[]
40+
comps = Union{Symbol, Expr}[]
4141
ext = Ref{Any}(nothing)
4242
eqs = Expr[]
4343
icon = Ref{Union{String, URI}}()
@@ -634,7 +634,7 @@ end
634634

635635
### Parsing Components:
636636

637-
function component_args!(a, b, expr, varexpr, kwargs)
637+
function component_args!(a, b, varexpr, kwargs; index_name = nothing)
638638
# Whenever `b` is a function call, skip the first arg aka the function name.
639639
# Whenever it is a kwargs list, include it.
640640
start = b.head == :call ? 2 : 1
@@ -643,73 +643,115 @@ function component_args!(a, b, expr, varexpr, kwargs)
643643
arg isa LineNumberNode && continue
644644
MLStyle.@match arg begin
645645
x::Symbol || Expr(:kw, x) => begin
646-
_v = _rename(a, x)
647-
b.args[i] = Expr(:kw, x, _v)
648-
push!(varexpr.args, :((@isdefined $x) && ($_v = $x)))
649-
push!(kwargs, Expr(:kw, _v, nothing))
650-
# dict[:kwargs][_v] = nothing
646+
varname, _varname = _rename(a, x)
647+
b.args[i] = Expr(:kw, x, _varname)
648+
push!(varexpr.args, :((if $varname !== nothing
649+
$_varname = $varname
650+
elseif @isdefined $x
651+
# Allow users to define a var in `structural_parameters` and set
652+
# that as positional arg of subcomponents; it is useful for cases
653+
# where it needs to be passed to multiple subcomponents.
654+
$_varname = $x
655+
end)))
656+
push!(kwargs, Expr(:kw, varname, nothing))
657+
# dict[:kwargs][varname] = nothing
651658
end
652659
Expr(:parameters, x...) => begin
653-
component_args!(a, arg, expr, varexpr, kwargs)
660+
component_args!(a, arg, varexpr, kwargs)
654661
end
655662
Expr(:kw, x, y) => begin
656-
_v = _rename(a, x)
657-
b.args[i] = Expr(:kw, x, _v)
658-
push!(varexpr.args, :($_v = $_v === nothing ? $y : $_v))
659-
push!(kwargs, Expr(:kw, _v, nothing))
660-
# dict[:kwargs][_v] = nothing
663+
varname, _varname = _rename(a, x)
664+
b.args[i] = Expr(:kw, x, _varname)
665+
if isnothing(index_name)
666+
push!(varexpr.args, :($_varname = $varname === nothing ? $y : $varname))
667+
else
668+
push!(varexpr.args,
669+
:($_varname = $varname === nothing ? $y : $varname[$index_name]))
670+
end
671+
push!(kwargs, Expr(:kw, varname, nothing))
672+
# dict[:kwargs][varname] = nothing
661673
end
662674
_ => error("Could not parse $arg of component $a")
663675
end
664676
end
665677
end
666678

667-
function _parse_components!(exprs, body, kwargs)
668-
expr = Expr(:block)
679+
model_name(name, range) = Symbol.(name, :_, collect(range))
680+
681+
function _parse_components!(body, kwargs)
682+
local expr
669683
varexpr = Expr(:block)
670-
# push!(exprs, varexpr)
671-
comps = Vector{Union{Symbol, Expr}}[]
684+
comps = Vector{Union{Union{Expr, Symbol}, Expr}}[]
672685
comp_names = []
673686

674-
for arg in body.args
675-
arg isa LineNumberNode && continue
676-
MLStyle.@match arg begin
677-
Expr(:block) => begin
678-
# TODO: Do we need this?
679-
error("Multiple `@components` block detected within a single block")
680-
end
681-
Expr(:(=), a, b) => begin
682-
arg = deepcopy(arg)
683-
b = deepcopy(arg.args[2])
687+
Base.remove_linenums!(body)
688+
arg = body.args[end]
689+
690+
MLStyle.@match arg begin
691+
Expr(:(=), a, Expr(:comprehension, Expr(:generator, b, Expr(:(=), c, d)))) => begin
692+
array_varexpr = Expr(:block)
684693

685-
component_args!(a, b, expr, varexpr, kwargs)
694+
push!(comp_names, :($a...))
695+
push!(comps, [a, b.args[1], d])
696+
b = deepcopy(b)
686697

687-
arg.args[2] = b
688-
push!(expr.args, arg)
689-
push!(comp_names, a)
698+
component_args!(a, b, array_varexpr, kwargs; index_name = c)
699+
700+
expr = _named_idxs(a, d, :($c -> $b); extra_args = array_varexpr)
701+
end
702+
Expr(:(=), a, Expr(:comprehension, Expr(:generator, b, Expr(:filter, e, Expr(:(=), c, d))))) => begin
703+
error("List comprehensions with conditional statements aren't supported.")
704+
end
705+
Expr(:(=), a, Expr(:comprehension, Expr(:generator, b, Expr(:(=), c, d), e...))) => begin
706+
# Note that `e` is of the form `Tuple{Expr(:(=), c, d)}`
707+
error("More than one index isn't supported while building component array")
708+
end
709+
Expr(:block) => begin
710+
# TODO: Do we need this?
711+
error("Multiple `@components` block detected within a single block")
712+
end
713+
Expr(:(=), a, Expr(:for, Expr(:(=), c, d), b)) => begin
714+
Base.remove_linenums!(b)
715+
array_varexpr = Expr(:block)
716+
push!(array_varexpr.args, b.args[1:(end - 1)]...)
717+
push!(comp_names, :($a...))
718+
push!(comps, [a, b.args[end].args[1], d])
719+
b = deepcopy(b)
720+
721+
component_args!(a, b.args[end], array_varexpr, kwargs; index_name = c)
722+
723+
expr = _named_idxs(a, d, :($c -> $(b.args[end])); extra_args = array_varexpr)
724+
end
725+
Expr(:(=), a, b) => begin
726+
arg = deepcopy(arg)
727+
b = deepcopy(arg.args[2])
728+
729+
component_args!(a, b, varexpr, kwargs)
730+
731+
arg.args[2] = b
732+
expr = :(@named $arg)
733+
push!(comp_names, a)
690734
if (isa(b.args[1], Symbol) || Meta.isexpr(b.args[1], :.))
691-
push!(comps, [a, b.args[1]])
735+
push!(comps, [a, b.args[1]])
692736
end
693-
end
694-
_ => error("Couldn't parse the component body: $arg")
695737
end
738+
_ => error("Couldn't parse the component body: $arg")
696739
end
740+
697741
return comp_names, comps, expr, varexpr
698742
end
699743

700744
function push_conditional_component!(ifexpr, expr_vec, comp_names, varexpr)
701745
blk = Expr(:block)
702746
push!(blk.args, varexpr)
703-
push!(blk.args, :(@named begin
704-
$(expr_vec.args...)
705-
end))
747+
push!(blk.args, expr_vec)
706748
push!(blk.args, :($push!(systems, $(comp_names...))))
707749
push!(ifexpr.args, blk)
708750
end
709751

710752
function handle_if_x!(mod, exprs, ifexpr, x, kwargs, condition = nothing)
711753
push!(ifexpr.args, condition)
712-
comp_names, comps, expr_vec, varexpr = _parse_components!(ifexpr, x, kwargs)
754+
comp_names, comps, expr_vec, varexpr = _parse_components!(x, kwargs)
713755
push_conditional_component!(ifexpr, expr_vec, comp_names, varexpr)
714756
comps
715757
end
@@ -725,7 +767,7 @@ function handle_if_y!(exprs, ifexpr, y, kwargs)
725767
push!(ifexpr.args, elseifexpr)
726768
(comps...,)
727769
else
728-
comp_names, comps, expr_vec, varexpr = _parse_components!(exprs, y, kwargs)
770+
comp_names, comps, expr_vec, varexpr = _parse_components!(y, kwargs)
729771
push_conditional_component!(ifexpr, expr_vec, comp_names, varexpr)
730772
comps
731773
end
@@ -750,25 +792,23 @@ function parse_components!(exprs, cs, dict, compbody, kwargs)
750792
Expr(:if, condition, x, y) => begin
751793
handle_conditional_components(condition, dict, exprs, kwargs, x, y)
752794
end
753-
Expr(:(=), a, b) => begin
754-
comp_names, comps, expr_vec, varexpr = _parse_components!(exprs,
755-
:(begin
795+
# Either the arg is top level component declaration or an invalid cause - both are handled by `_parse_components`
796+
_ => begin
797+
comp_names, comps, expr_vec, varexpr = _parse_components!(:(begin
756798
$arg
757799
end),
758800
kwargs)
759801
push!(cs, comp_names...)
760802
push!(dict[:components], comps...)
761-
push!(exprs, varexpr, :(@named begin
762-
$(expr_vec.args...)
763-
end))
803+
push!(exprs, varexpr, expr_vec)
764804
end
765-
_ => error("Couldn't parse the component body $compbody")
766805
end
767806
end
768807
end
769808

770809
function _rename(compname, varname)
771810
compname = Symbol(compname, :__, varname)
811+
(compname, Symbol(:_, compname))
772812
end
773813

774814
# Handle top level branching

test/model_parsing.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -539,3 +539,41 @@ end
539539
@test Equation[ternary_true.ternary_parameter_true ~ 0] == equations(ternary_true)
540540
@test Equation[ternary_false.ternary_parameter_false ~ 0] == equations(ternary_false)
541541
end
542+
543+
@testset "Component array" begin
544+
@mtkmodel SubComponent begin
545+
@parameters begin
546+
sc
547+
end
548+
end
549+
550+
@mtkmodel Component begin
551+
@structural_parameters begin
552+
N = 2
553+
end
554+
@components begin
555+
comprehension = [SubComponent(sc = i) for i in 1:N]
556+
written_out_for = for i in 1:N
557+
sc = i + 1
558+
SubComponent(; sc)
559+
end
560+
single_sub_component = SubComponent()
561+
end
562+
end
563+
564+
@named component = Component()
565+
component = complete(component)
566+
567+
@test nameof.(ModelingToolkit.get_systems(component)) == [
568+
:comprehension_1,
569+
:comprehension_2,
570+
:written_out_for_1,
571+
:written_out_for_2,
572+
:single_sub_component,
573+
]
574+
575+
@test getdefault(component.comprehension_1.sc) == 1
576+
@test getdefault(component.comprehension_2.sc) == 2
577+
@test getdefault(component.written_out_for_1.sc) == 2
578+
@test getdefault(component.written_out_for_2.sc) == 3
579+
end

0 commit comments

Comments
 (0)