Skip to content

Commit 34ea159

Browse files
fix: fix varmap_to_vars
1 parent 4db0053 commit 34ea159

File tree

2 files changed

+20
-2
lines changed

2 files changed

+20
-2
lines changed

src/variables.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -192,11 +192,11 @@ function _varmap_to_vars(varmap::Dict, varlist; defaults = Dict(), check = false
192192
toterm = Symbolics.diff2term, initialization_phase = false)
193193
varmap = canonicalize_varmap(varmap; toterm)
194194
defaults = canonicalize_varmap(defaults; toterm)
195+
varmap = merge(defaults, varmap)
195196
values = Dict()
196197
for var in varlist
197198
var = unwrap(var)
198-
val = unwrap(fixpoint_sub(fixpoint_sub(var, varmap; operator = Symbolics.Operator),
199-
defaults; operator = Symbolics.Operator))
199+
val = unwrap(fixpoint_sub(var, varmap; operator = Symbolics.Operator))
200200
if symbolic_type(val) === NotSymbolic()
201201
values[var] = val
202202
end
@@ -211,6 +211,11 @@ function canonicalize_varmap(varmap; toterm = Symbolics.diff2term)
211211
for (k, v) in varmap
212212
new_varmap[unwrap(k)] = unwrap(v)
213213
new_varmap[toterm(unwrap(k))] = unwrap(v)
214+
if Symbolics.isarraysymbolic(k) && Symbolics.shape(k) !== Symbolics.Unknown()
215+
for i in eachindex(k)
216+
new_varmap[k[i]] = v[i]
217+
end
218+
end
214219
end
215220
return new_varmap
216221
end

test/initial_values.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,16 @@ getter = getu(sys, [x..., y, z...])
3333
@test getter(get_u0(
3434
sys, [y => 2w, w => 3.0, z[1] => 2p1, z[2] => 3p2], [p1 => 3.0, p2 => 4.0])[1]) ==
3535
[1.0, 2.0, 3.0, 6.0, 6.0, 12.0]
36+
37+
# Issue#2566
38+
@variables X(t)
39+
@parameters p1 p2 p3
40+
41+
p_vals = [p1 => 1.0, p2 => 2.0]
42+
u_vals = [X => 3.0]
43+
44+
var_vals = [p1 => 1.0, p2 => 2.0, X => 3.0]
45+
desired_values = [p1, p2, p3]
46+
defaults = Dict([p3 => X])
47+
vals = ModelingToolkit.varmap_to_vars(var_vals, desired_values; defaults = defaults)
48+
@test vals == [1.0, 2.0, 3.0]

0 commit comments

Comments
 (0)