Skip to content

Commit c5bff6a

Browse files
Merge pull request #2500 from ven-k/vkb/type
Provision to enforce types in parameters, variables and structural_parameters in `@mtkmodel`
2 parents c9bb725 + 44a9776 commit c5bff6a

File tree

3 files changed

+112
-36
lines changed

3 files changed

+112
-36
lines changed

docs/src/basics/MTKModel_Connector.md

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -229,12 +229,12 @@ end
229229

230230
- `:components`: List of sub-components in the form of [[name, sub_component_name],...].
231231
- `:extend`: The list of extended unknowns, name given to the base system, and name of the base system.
232-
- `:structural_parameters`: Dictionary of structural parameters mapped to their default values.
232+
- `:structural_parameters`: Dictionary of structural parameters mapped to their metadata.
233233
- `:parameters`: Dictionary of symbolic parameters mapped to their metadata. For
234234
parameter arrays, length is added to the metadata as `:size`.
235235
- `:variables`: Dictionary of symbolic variables mapped to their metadata. For
236236
variable arrays, length is added to the metadata as `:size`.
237-
- `:kwargs`: Dictionary of keyword arguments mapped to their default values.
237+
- `:kwargs`: Dictionary of keyword arguments mapped to their metadata.
238238
- `:independent_variable`: Independent variable, which is added while generating the Model.
239239
- `:equations`: List of equations (represented as strings).
240240

@@ -243,13 +243,14 @@ For example, the structure of `ModelC` is:
243243
```julia
244244
julia> ModelC.structure
245245
Dict{Symbol, Any} with 7 entries:
246-
:components => [[:model_a, :ModelA]]
247-
:variables => Dict{Symbol, Dict{Symbol, Any}}(:v=>Dict(:default=>:v_var), :v_array=>Dict(:size=>(2, 3)))
248-
:icon => URI("https://github.com/SciML/SciMLDocs/blob/main/docs/src/assets/logo.png")
249-
:kwargs => Dict{Symbol, Any}(:f=>:sin, :v=>:v_var, :v_array=>nothing, :model_a__k_array=>nothing, :p1=>nothing)
250-
:independent_variable => t
251-
:extend => Any[[:p2, :p1], Symbol("#mtkmodel__anonymous__ModelB"), :ModelB]
252-
:equations => ["model_a.k ~ f(v)"]
246+
:components => [[:model_a, :ModelA]]
247+
:variables => Dict{Symbol, Dict{Symbol, Any}}(:v=>Dict(:default=>:v_var), :v_array=>Dict(:size=>(2, 3)))
248+
:icon => URI("https://github.com/SciML/SciMLDocs/blob/main/docs/src/assets/logo.png")
249+
:kwargs => Dict{Symbol, Dict}(:f=>Dict(:value=>:sin), :v=>Dict{Symbol, Union{Nothing, Symbol}}(:value=>:v_var, :type=>nothing), :v_array=>Dict(:value=>nothing, :type=>nothing), :p1=>Dict(:value=>nothing))
250+
:structural_parameters => Dict{Symbol, Dict}(:f=>Dict(:value=>:sin))
251+
:independent_variable => t
252+
:extend => Any[[:p2, :p1], Symbol("#mtkmodel__anonymous__ModelB"), :ModelB]
253+
:equations => ["model_a.k ~ f(v)"]
253254
```
254255

255256
### Using conditional statements
@@ -322,11 +323,12 @@ The conditional parts are reflected in the `structure`. For `BranchOutsideTheBlo
322323
```julia
323324
julia> BranchOutsideTheBlock.structure
324325
Dict{Symbol, Any} with 5 entries:
325-
:components => Any[(:if, :flag, [[:sys1, :C]], Any[])]
326-
:kwargs => Dict{Symbol, Any}(:flag=>true)
327-
:independent_variable => t
328-
:parameters => Dict{Symbol, Dict{Symbol, Any}}(:a1=>Dict(:condition=>(:if, :flag, Dict{Symbol, Any}(:kwargs => Dict{Any, Any}(:a1 => nothing), :parameters => Any[Dict{Symbol, Dict{Symbol, Any}}(:a1 => Dict())]), Dict{Symbol, Any}(:kwargs => Dict{Any, Any}(:a2 => nothing), :parameters => Any[Dict{Symbol, Dict{Symbol, Any}}(:a2 => Dict())]))
329-
:equations => Any[(:if, :flag, ["a1 ~ 0"], ["a2 ~ 0"])]
326+
:components => Any[(:if, :flag, [[:sys1, :C]], Any[])]
327+
:kwargs => Dict{Symbol, Dict}(:flag=>Dict{Symbol, Bool}(:value=>1))
328+
:structural_parameters => Dict{Symbol, Dict}(:flag=>Dict{Symbol, Bool}(:value=>1))
329+
:independent_variable => t
330+
:parameters => Dict{Symbol, Dict{Symbol, Any}}(:a1=>Dict(:condition=>(:if, :flag, Dict{Symbol, Any}(:kwargs => Dict{Any, Any}(:a1 => nothing), :parameters => Any[Dict{Symbol, Dict{Symbol, Any}}(:a1 => Dict())]), Dict{Symbol, Any}(:kwargs => Dict{Any, Any}(:a2 => nothing), :parameters => Any[Dict{Symbol, Dict{Symbol, Any}}(:a2 => Dict())]))
331+
:equations => Any[(:if, :flag, ["a1 ~ 0"], ["a2 ~ 0"])]
330332
```
331333
332334
Conditional entries are entered in the format of `(branch, condition, [case when it is true], [case when it is false])`;

src/systems/model_parsing.jl

Lines changed: 57 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,10 @@ end
3535

3636
function _model_macro(mod, name, expr, isconnector)
3737
exprs = Expr(:block)
38-
dict = Dict{Symbol, Any}()
39-
dict[:kwargs] = Dict{Symbol, Any}()
38+
dict = Dict{Symbol, Any}(
39+
:kwargs => Dict{Symbol, Dict}(),
40+
:structural_parameters => Dict{Symbol, Dict}()
41+
)
4042
comps = Symbol[]
4143
ext = Ref{Any}(nothing)
4244
eqs = Expr[]
@@ -107,7 +109,8 @@ function _model_macro(mod, name, expr, isconnector)
107109
end
108110

109111
function parse_variable_def!(dict, mod, arg, varclass, kwargs;
110-
def = nothing, indices::Union{Vector{UnitRange{Int}}, Nothing} = nothing)
112+
def = nothing, indices::Union{Vector{UnitRange{Int}}, Nothing} = nothing,
113+
type::Union{Type, Nothing} = nothing)
111114
metatypes = [(:connection_type, VariableConnectType),
112115
(:description, VariableDescription),
113116
(:unit, VariableUnit),
@@ -125,15 +128,34 @@ function parse_variable_def!(dict, mod, arg, varclass, kwargs;
125128
arg isa LineNumberNode && return
126129
MLStyle.@match arg begin
127130
a::Symbol => begin
128-
push!(kwargs, Expr(:kw, a, nothing))
131+
if type isa Nothing
132+
push!(kwargs, Expr(:kw, a, nothing))
133+
else
134+
push!(kwargs, Expr(:kw, Expr(:(::), a, Union{Nothing, type}), nothing))
135+
end
129136
var = generate_var!(dict, a, varclass; indices)
130-
dict[:kwargs][getname(var)] = def
137+
dict[:kwargs][getname(var)] = Dict(:value => def, :type => type)
131138
(var, def)
132139
end
140+
Expr(:(::), a, type) => begin
141+
type = Core.eval(mod, type)
142+
_type_check!(a, type)
143+
parse_variable_def!(dict, mod, a, varclass, kwargs; def, type)
144+
end
145+
Expr(:(::), Expr(:call, a, b), type) => begin
146+
type = Core.eval(mod, type)
147+
def = _type_check!(def, a, type)
148+
parse_variable_def!(dict, mod, a, varclass, kwargs; def, type)
149+
end
133150
Expr(:call, a, b) => begin
134-
push!(kwargs, Expr(:kw, a, nothing))
151+
if type isa Nothing
152+
push!(kwargs, Expr(:kw, a, nothing))
153+
else
154+
push!(kwargs, Expr(:kw, Expr(:(::), a, Union{Nothing, type}), nothing))
155+
end
135156
var = generate_var!(dict, a, b, varclass; indices)
136-
dict[:kwargs][getname(var)] = def
157+
type !== nothing && (dict[varclass][getname(var)][:type] = type)
158+
dict[:kwargs][getname(var)] = Dict(:value => def, :type => type)
137159
(var, def)
138160
end
139161
Expr(:(=), a, b) => begin
@@ -304,15 +326,23 @@ function parse_structural_parameters!(exprs, sps, dict, mod, body, kwargs)
304326
Base.remove_linenums!(body)
305327
for arg in body.args
306328
MLStyle.@match arg begin
329+
Expr(:(=), Expr(:(::), a, type), b) => begin
330+
type = Core.eval(mod, type)
331+
b = _type_check!(Core.eval(mod, b), a, type)
332+
push!(sps, a)
333+
push!(kwargs, Expr(:kw, Expr(:(::), a, type), b))
334+
dict[:structural_parameters][a] = dict[:kwargs][a] = Dict(
335+
:value => b, :type => type)
336+
end
307337
Expr(:(=), a, b) => begin
308338
push!(sps, a)
309339
push!(kwargs, Expr(:kw, a, b))
310-
dict[:kwargs][a] = b
340+
dict[:structural_parameters][a] = dict[:kwargs][a] = Dict(:value => b)
311341
end
312342
a => begin
313343
push!(sps, a)
314344
push!(kwargs, a)
315-
dict[:kwargs][a] = nothing
345+
dict[:structural_parameters][a] = dict[:kwargs][a] = Dict(:value => nothing)
316346
end
317347
end
318348
end
@@ -336,17 +366,17 @@ function extend_args!(a, b, dict, expr, kwargs, varexpr, has_param = false)
336366
end
337367
end
338368
push!(kwargs, Expr(:kw, x, nothing))
339-
dict[:kwargs][x] = nothing
369+
dict[:kwargs][x] = Dict(:value => nothing)
340370
end
341371
Expr(:kw, x) => begin
342372
push!(kwargs, Expr(:kw, x, nothing))
343-
dict[:kwargs][x] = nothing
373+
dict[:kwargs][x] = Dict(:value => nothing)
344374
end
345375
Expr(:kw, x, y) => begin
346376
b.args[i] = Expr(:kw, x, x)
347377
push!(varexpr.args, :($x = $x === nothing ? $y : $x))
348378
push!(kwargs, Expr(:kw, x, nothing))
349-
dict[:kwargs][x] = nothing
379+
dict[:kwargs][x] = Dict(:value => nothing)
350380
end
351381
Expr(:parameters, x...) => begin
352382
has_param = true
@@ -851,3 +881,18 @@ function parse_conditional_model_statements(comps, dict, eqs, exprs, kwargs, mod
851881
$equations_blk
852882
end))
853883
end
884+
885+
_type_check!(a, type) = return
886+
function _type_check!(val, a, type)
887+
if val isa type
888+
return val
889+
else
890+
try
891+
return convert(type, val)
892+
catch
893+
(e)
894+
throw(TypeError(Symbol("`@mtkmodel`"),
895+
"`@structural_parameters`, while assigning to `$a`", type, typeof(val)))
896+
end
897+
end
898+
end

test/model_parsing.jl

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ resistor = getproperty(rc, :resistor; namespace = false)
164164
# Test that `C_val` passed via argument is set as default of C.
165165
@test getdefault(rc.capacitor.C) == C_val
166166
# Test that `k`'s default value is unchanged.
167-
@test getdefault(rc.constant.k) == RC.structure[:kwargs][:k_val]
167+
@test getdefault(rc.constant.k) == RC.structure[:kwargs][:k_val][:value]
168168
@test getdefault(rc.capacitor.v) == 0.0
169169

170170
@test get_gui_metadata(rc.resistor).layout == Resistor.structure[:icon] ==
@@ -241,6 +241,33 @@ resistor = getproperty(rc, :resistor; namespace = false)
241241
@test isequal(getdefault(model.k), model.kval)
242242
end
243243

244+
@testset "Type annotation" begin
245+
@mtkmodel TypeModel begin
246+
@structural_parameters begin
247+
flag::Bool = true
248+
end
249+
@parameters begin
250+
par0::Bool = true
251+
par1::Int = 1
252+
par2(t)::Int,
253+
[description = "Enforced `par4` to be an Int by setting the type to the keyword-arg."]
254+
par3(t)::Float64 = 1.0
255+
par4(t)::Float64 = 1 # converts 1 to 1.0 of Float64 type
256+
end
257+
end
258+
259+
@named type_model = TypeModel()
260+
261+
@test getname.(parameters(type_model)) == [:par0, :par1, :par2, :par3, :par4]
262+
263+
@test_throws TypeError TypeModel(; name = :throws, flag = 1)
264+
@test_throws TypeError TypeModel(; name = :throws, par0 = 1)
265+
@test_throws TypeError TypeModel(; name = :throws, par1 = 1.5)
266+
@test_throws TypeError TypeModel(; name = :throws, par2 = 1.5)
267+
@test_throws TypeError TypeModel(; name = :throws, par3 = true)
268+
@test_throws TypeError TypeModel(; name = :throws, par4 = true)
269+
end
270+
244271
@testset "Defaults of subcomponents MTKModel" begin
245272
@mtkmodel A begin
246273
@parameters begin
@@ -322,7 +349,9 @@ end
322349
@test A.structure[:parameters] == Dict(:p => Dict())
323350
@test A.structure[:extend] == [[:e], :extended_e, :E]
324351
@test A.structure[:equations] == ["e ~ 0"]
325-
@test A.structure[:kwargs] == Dict(:p => nothing, :v => nothing)
352+
@test A.structure[:kwargs] ==
353+
Dict{Symbol, Dict}(:p => Dict(:value => nothing, :type => nothing),
354+
:v => Dict(:value => nothing, :type => nothing))
326355
@test A.structure[:components] == [[:cc, :C]]
327356
end
328357

@@ -392,9 +421,9 @@ end
392421
@named else_in_sys = InsideTheBlock(flag = 3)
393422
else_in_sys = complete(else_in_sys)
394423

395-
@test nameof.(parameters(if_in_sys)) == [:if_parameter, :eq]
396-
@test nameof.(parameters(elseif_in_sys)) == [:elseif_parameter, :eq]
397-
@test nameof.(parameters(else_in_sys)) == [:else_parameter, :eq]
424+
@test getname.(parameters(if_in_sys)) == [:if_parameter, :eq]
425+
@test getname.(parameters(elseif_in_sys)) == [:elseif_parameter, :eq]
426+
@test getname.(parameters(else_in_sys)) == [:else_parameter, :eq]
398427

399428
@test nameof.(get_systems(if_in_sys)) == [:if_sys, :default_sys]
400429
@test nameof.(get_systems(elseif_in_sys)) == [:elseif_sys, :default_sys]
@@ -481,9 +510,9 @@ end
481510
@named ternary_out_sys = OutsideTheBlock(condition = 4)
482511
else_out_sys = complete(else_out_sys)
483512

484-
@test nameof.(parameters(if_out_sys)) == [:if_parameter, :default_parameter]
485-
@test nameof.(parameters(elseif_out_sys)) == [:elseif_parameter, :default_parameter]
486-
@test nameof.(parameters(else_out_sys)) == [:else_parameter, :default_parameter]
513+
@test getname.(parameters(if_out_sys)) == [:if_parameter, :default_parameter]
514+
@test getname.(parameters(elseif_out_sys)) == [:elseif_parameter, :default_parameter]
515+
@test getname.(parameters(else_out_sys)) == [:else_parameter, :default_parameter]
487516

488517
@test nameof.(get_systems(if_out_sys)) == [:if_sys, :default_sys]
489518
@test nameof.(get_systems(elseif_out_sys)) == [:elseif_sys, :default_sys]
@@ -529,8 +558,8 @@ end
529558
@named ternary_false = TernaryBranchingOutsideTheBlock(condition = false)
530559
ternary_false = complete(ternary_false)
531560

532-
@test nameof.(parameters(ternary_true)) == [:ternary_parameter_true]
533-
@test nameof.(parameters(ternary_false)) == [:ternary_parameter_false]
561+
@test getname.(parameters(ternary_true)) == [:ternary_parameter_true]
562+
@test getname.(parameters(ternary_false)) == [:ternary_parameter_false]
534563

535564
@test nameof.(get_systems(ternary_true)) == [:ternary_sys_true]
536565
@test nameof.(get_systems(ternary_false)) == [:ternary_sys_false]

0 commit comments

Comments
 (0)