Skip to content

Commit 2c2e914

Browse files
Merge pull request #2836 from AayushSabharwal/as/more-metadata
fix: fix edge cases with metadata dumping, add tests
2 parents 6311360 + 105bdbe commit 2c2e914

File tree

3 files changed

+24
-9
lines changed

3 files changed

+24
-9
lines changed

src/systems/abstractsystem.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2776,11 +2776,15 @@ ModelingToolkit.dump_unknowns(sys)
27762776
See also: [`ModelingToolkit.dump_variable_metadata`](@ref), [`ModelingToolkit.dump_parameters`](@ref)
27772777
"""
27782778
function dump_unknowns(sys::AbstractSystem)
2779-
defs = defaults(sys)
2779+
defs = varmap_with_toterm(defaults(sys))
2780+
gs = varmap_with_toterm(guesses(sys))
27802781
map(dump_variable_metadata.(unknowns(sys))) do meta
27812782
if haskey(defs, meta.var)
27822783
meta = merge(meta, (; default = defs[meta.var]))
27832784
end
2785+
if haskey(gs, meta.var)
2786+
meta = merge(meta, (; guess = gs[meta.var]))
2787+
end
27842788
meta
27852789
end
27862790
end

src/variables.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,13 @@ ModelingToolkit.dump_variable_metadata(p)
3030
function dump_variable_metadata(var)
3131
uvar = unwrap(var)
3232
vartype, name = get(uvar.metadata, VariableSource, (:unknown, :unknown))
33-
shape = Symbolics.shape(var)
34-
if shape == ()
33+
type = symtype(uvar)
34+
if type <: AbstractArray
35+
shape = Symbolics.shape(var)
36+
if shape == ()
37+
shape = nothing
38+
end
39+
else
3540
shape = nothing
3641
end
3742
unit = get(uvar.metadata, VariableUnit, nothing)
@@ -208,6 +213,10 @@ function _varmap_to_vars(varmap::Dict, varlist; defaults = Dict(), check = false
208213
return [values[unwrap(var)] for var in varlist]
209214
end
210215

216+
function varmap_with_toterm(varmap; toterm = Symbolics.diff2term)
217+
return merge(todict(varmap), Dict(toterm(unwrap(k)) => v for (k, v) in varmap))
218+
end
219+
211220
function canonicalize_varmap(varmap; toterm = Symbolics.diff2term)
212221
new_varmap = Dict()
213222
for (k, v) in varmap

test/test_variable_metadata.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -131,13 +131,15 @@ sp = Set(p)
131131

132132
@test_nowarn show(stdout, "text/plain", sys)
133133

134-
# Defaults overridden by system, parameter dependencies
135-
@variables x(t) = 1.0
134+
# Defaults, guesses overridden by system, parameter dependencies
135+
@variables x(t)=1.0 y(t) [guess = 1.0]
136136
@parameters p=2.0 q
137-
@named sys = ODESystem(Equation[], t, [x], [p]; defaults = Dict(x => 2.0, p => 3.0),
138-
parameter_dependencies = [q => 2p])
139-
x_meta = ModelingToolkit.dump_unknowns(sys)[]
140-
@test x_meta.default == 2.0
137+
@named sys = ODESystem(Equation[], t, [x, y], [p]; defaults = Dict(x => 2.0, p => 3.0),
138+
guesses = Dict(y => 2.0), parameter_dependencies = [q => 2p])
139+
unks_meta = ModelingToolkit.dump_unknowns(sys)
140+
unks_meta = Dict([ModelingToolkit.getname(meta.var) => meta for meta in unks_meta])
141+
@test unks_meta[:x].default == 2.0
142+
@test unks_meta[:y].guess == 2.0
141143
params_meta = ModelingToolkit.dump_parameters(sys)
142144
params_meta = Dict([ModelingToolkit.getname(meta.var) => meta for meta in params_meta])
143145
@test params_meta[:p].default == 3.0

0 commit comments

Comments
 (0)