Skip to content

Commit d52de90

Browse files
authored
Merge pull request #2437 from ven-k/vkb/at-defaults
feat: add `@defaults` to `@mtkmodel`
2 parents 3ee066f + 9ee11b7 commit d52de90

File tree

4 files changed

+73
-8
lines changed

4 files changed

+73
-8
lines changed

docs/src/basics/MTKModel_Connector.md

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ equations.
2525

2626
- `@components`: for listing sub-components of the system
2727
- `@constants`: for declaring constants
28+
- `@defaults`: for passing `defaults` to ODESystem
2829
- `@equations`: for the list of equations
2930
- `@extend`: for extending a base system and unpacking its unknowns
3031
- `@icon` : for embedding the model icon
@@ -66,6 +67,7 @@ end
6667
@variables begin
6768
v(t) = v_var
6869
v_array(t)[1:2, 1:3]
70+
v_for_defaults(t)
6971
end
7072
@extend ModelB(; p1)
7173
@components begin
@@ -79,6 +81,9 @@ end
7981
@equations begin
8082
model_a.k ~ f(v)
8183
end
84+
@defaults begin
85+
v_for_defaults => 2.0
86+
end
8287
end
8388
```
8489

@@ -172,6 +177,11 @@ getdefault(model_c3.model_a.k_array[2])
172177

173178
- List all the equations here
174179

180+
#### `@defaults` begin block
181+
182+
- Default values can be passed as pairs.
183+
- This is equivalent to passing `defaults` argument to `ODESystem`.
184+
175185
#### A begin block
176186

177187
- Any other Julia operations can be included with dedicated begin blocks.
@@ -239,6 +249,7 @@ end
239249

240250
- `:components`: The list of sub-components in the form of [[name, sub_component_name],...].
241251
- `:constants`: Dictionary of constants mapped to its metadata.
252+
- `:defaults`: Dictionary of variables and default values specified in the `@defaults`.
242253
- `:extend`: The list of extended unknowns, name given to the base system, and name of the base system.
243254
- `:structural_parameters`: Dictionary of structural parameters mapped to their metadata.
244255
- `:parameters`: Dictionary of symbolic parameters mapped to their metadata. For
@@ -253,15 +264,16 @@ For example, the structure of `ModelC` is:
253264

254265
```julia
255266
julia> ModelC.structure
256-
Dict{Symbol, Any} with 9 entries:
267+
Dict{Symbol, Any} with 10 entries:
257268
:components => Any[Union{Expr, Symbol}[:model_a, :ModelA], Union{Expr, Symbol}[:model_array_a, :ModelA, :(1:N)], Union{Expr, Symbol}[:model_array_b, :ModelA, :(1:N)]]
258-
:variables => Dict{Symbol, Dict{Symbol, Any}}(:v=>Dict(:default=>:v_var, :type=>Real), :v_array=>Dict(:type=>Real, :size=>(2, 3)))
269+
:variables => Dict{Symbol, Dict{Symbol, Any}}(:v=>Dict(:default=>:v_var, :type=>Real), :v_array=>Dict(:type=>Real, :size=>(2, 3)), :v_for_defaults=>Dict(:type=>Real))
259270
:icon => URI("https://github.com/SciML/SciMLDocs/blob/main/docs/src/assets/logo.png")
260-
:kwargs => Dict{Symbol, Dict}(:f=>Dict(:value=>:sin), :N=>Dict(:value=>2), :v=>Dict{Symbol, Any}(:value=>:v_var, :type=>Real), :v_array=>Dict{Symbol, Union{Nothing, UnionAll}}(:value=>nothing, :type=>AbstractArray{Real}), :p1=>Dict(:value=>nothing))
271+
:kwargs => Dict{Symbol, Dict}(:f=>Dict(:value=>:sin), :N=>Dict(:value=>2), :v=>Dict{Symbol, Any}(:value=>:v_var, :type=>Real), :v_array=>Dict{Symbol, Union{Nothing, UnionAll}}(:value=>nothing, :type=>AbstractArray{Real}), :v_for_defaults=>Dict{Symbol, Union{Nothing, DataType}}(:value=>nothing, :type=>Real), :p1=>Dict(:value=>nothing))
261272
:structural_parameters => Dict{Symbol, Dict}(:f=>Dict(:value=>:sin), :N=>Dict(:value=>2))
262273
:independent_variable => t
263274
:constants => Dict{Symbol, Dict}(:c=>Dict{Symbol, Any}(:value=>1, :type=>Int64, :description=>"Example constant."))
264275
:extend => Any[[:p2, :p1], Symbol("#mtkmodel__anonymous__ModelB"), :ModelB]
276+
:defaults => Dict{Symbol, Any}(:v_for_defaults=>2.0)
265277
:equations => Any["model_a.k ~ f(v)"]
266278
```
267279

@@ -327,6 +339,9 @@ used inside the if-elseif-else statements.
327339
a2 ~ 0
328340
end
329341
end
342+
@defaults begin
343+
a1 => 10
344+
end
330345
end
331346
```
332347

src/systems/model_parsing.jl

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ function _model_macro(mod, name, expr, isconnector)
3939
exprs = Expr(:block)
4040
dict = Dict{Symbol, Any}(
4141
:constants => Dict{Symbol, Dict}(),
42+
:defaults => Dict{Symbol, Any}(),
4243
:kwargs => Dict{Symbol, Dict}(),
4344
:structural_parameters => Dict{Symbol, Dict}()
4445
)
@@ -54,6 +55,7 @@ function _model_macro(mod, name, expr, isconnector)
5455
push!(exprs.args, :(parameters = []))
5556
push!(exprs.args, :(systems = ODESystem[]))
5657
push!(exprs.args, :(equations = Equation[]))
58+
push!(exprs.args, :(defaults = Dict{Num, Union{Number, Symbol, Function}}()))
5759

5860
Base.remove_linenums!(expr)
5961
for arg in expr.args
@@ -99,8 +101,11 @@ function _model_macro(mod, name, expr, isconnector)
99101
gui_metadata = isassigned(icon) > 0 ? GUIMetadata(GlobalRef(mod, name), icon[]) :
100102
GUIMetadata(GlobalRef(mod, name))
101103

104+
@inline pop_structure_dict!.(
105+
Ref(dict), [:constants, :defaults, :kwargs, :structural_parameters])
106+
102107
sys = :($ODESystem($Equation[equations...], $iv, variables, parameters;
103-
name, systems, gui_metadata = $gui_metadata))
108+
name, systems, gui_metadata = $gui_metadata, defaults))
104109

105110
if ext[] === nothing
106111
push!(exprs.args, :(var"#___sys___" = $sys))
@@ -122,6 +127,8 @@ function _model_macro(mod, name, expr, isconnector)
122127
:($name = $Model($f, $dict, $isconnector))
123128
end
124129

130+
pop_structure_dict!(dict, key) = length(dict[key]) == 0 && pop!(dict, key)
131+
125132
function update_kwargs_and_metadata!(dict, kwargs, a, def, indices, type, var,
126133
varclass, where_types)
127134
if indices isa Nothing
@@ -355,6 +362,8 @@ function parse_model!(exprs, comps, ext, eqs, icon, vs, ps, sps,
355362
elseif mname == Symbol("@icon")
356363
isassigned(icon) && error("This model has more than one icon.")
357364
parse_icon!(body, dict, icon, mod)
365+
elseif mname == Symbol("@defaults")
366+
parse_system_defaults!(exprs, arg, dict)
358367
else
359368
error("$mname is not handled.")
360369
end
@@ -400,6 +409,28 @@ function parse_constants!(exprs, dict, body, mod)
400409
end
401410
end
402411

412+
push_additional_defaults!(dict, a, b::Number) = dict[:defaults][a] = b
413+
push_additional_defaults!(dict, a, b::QuoteNode) = dict[:defaults][a] = b.value
414+
function push_additional_defaults!(dict, a, b::Expr)
415+
dict[:defaults][a] = readable_code(b)
416+
end
417+
418+
function parse_system_defaults!(exprs, defaults_body, dict)
419+
for default_arg in defaults_body.args[end].args
420+
# for arg in default_arg.args
421+
MLStyle.@match default_arg begin
422+
# For cases like `p => 1` and `p => f()`. In both cases the definitions of
423+
# `a`, here `p` and when `b` is a function, here `f` are available while
424+
# defining the model
425+
Expr(:call, :(=>), a, b) => begin
426+
push!(exprs, :(defaults[$a] = $b))
427+
push_additional_defaults!(dict, a, b)
428+
end
429+
_ => error("Invalid `defaults` entry $default_arg $(typeof(a)) $(typeof(b))")
430+
end
431+
end
432+
end
433+
403434
function parse_structural_parameters!(exprs, sps, dict, mod, body, kwargs)
404435
Base.remove_linenums!(body)
405436
for arg in body.args

test/model_parsing.jl

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
using ModelingToolkit, Test
2-
using ModelingToolkit: get_gui_metadata, get_systems, get_connector_type,
3-
get_ps, getdefault, getname, scalarize, symtype,
4-
VariableDescription, RegularConnector
2+
using ModelingToolkit: get_connector_type, get_defaults, get_gui_metadata,
3+
get_systems, get_ps, getdefault, getname, scalarize, symtype,
4+
VariableDescription,
5+
RegularConnector
56
using URIs: URI
67
using Distributions
78
using DynamicQuantities, OrdinaryDiffEq
@@ -219,17 +220,26 @@ end
219220
j(t) = jval, [description = "j(t)"]
220221
k = kval, [description = "k"]
221222
l(t)[1:2, 1:3] = 2, [description = "l is more than 1D"]
223+
n # test defaults with Number input
224+
n2 # test defaults with Function input
222225
end
223226
@structural_parameters begin
224227
m = 1
225228
func
226229
end
230+
begin
231+
g() = 5
232+
end
233+
@defaults begin
234+
n => 1.0
235+
n2 => g()
236+
end
227237
end
228238

229239
kval = 5
230240
@named model = MockModel(; b2 = [1, 3], kval, cval = 1, func = identity)
231241

232-
@test lastindex(parameters(model)) == 29
242+
@test lastindex(parameters(model)) == 31
233243

234244
@test all(getdescription.([model.e2...]) .== "e2")
235245
@test all(getdescription.([model.h2...]) .== "h2(t)")
@@ -256,6 +266,10 @@ end
256266
@test all(getdefault.(scalarize(model.l)) .== 2)
257267
@test isequal(getdefault(model.j), model.jval)
258268
@test isequal(getdefault(model.k), model.kval)
269+
@test get_defaults(model)[model.n] == 1.0
270+
@test get_defaults(model)[model.n2] == 5
271+
272+
@test MockModel.structure[:defaults] == Dict(:n => 1.0, :n2 => "g()")
259273
end
260274

261275
@testset "Type annotation" begin

test/runtests.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using SafeTestsets, Pkg, Test
22

3+
#=
34
const GROUP = get(ENV, "GROUP", "All")
45
56
function activate_extensions_env()
@@ -91,3 +92,7 @@ end
9192
@safetestset "BifurcationKit Extension Test" include("extensions/bifurcationkit.jl")
9293
end
9394
end
95+
96+
=#
97+
98+
@safetestset "Model Parsing Test" include("model_parsing.jl")

0 commit comments

Comments
 (0)