Skip to content

Commit 2288569

Browse files
Merge pull request #2524 from AayushSabharwal/as/no-scalarize
fix: avoid scalarizing params in structural_simplify, variable defaults in get_u0
2 parents d52de90 + e2807f6 commit 2288569

File tree

8 files changed

+63
-27
lines changed

8 files changed

+63
-27
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ jobs:
3737
#
3838
# julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter", version="0.13.0"))'
3939
run: |
40-
julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))'
40+
julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter", version="1.0.50"))'
4141
julia -e 'using JuliaFormatter; format(".", verbose=true)'
4242
- name: Format check
4343
run: |

src/clock.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,6 @@ struct SolverStepClock <: AbstractClock
142142
end
143143
SolverStepClock() = SolverStepClock(nothing)
144144

145-
sampletime(c) = nothing
146145
Base.hash(c::SolverStepClock, seed::UInt) = seed 0x953d7b9a18874b91
147146
function Base.:(==)(c1::SolverStepClock, c2::SolverStepClock)
148147
((c1.t === nothing || c2.t === nothing) || isequal(c1.t, c2.t))

src/systems/abstractsystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ function wrap_array_vars(sys::AbstractSystem, exprs; dvs = unknowns(sys))
207207
for (j, x) in enumerate(dvs)
208208
if istree(x) && operation(x) == getindex
209209
arg = arguments(x)[1]
210-
arg in allvars || continue
210+
any(isequal(arg), allvars) || continue
211211
inds = get!(() -> Int[], array_vars, arg)
212212
push!(inds, j)
213213
end

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -820,19 +820,6 @@ function get_u0(sys, u0map, parammap = nothing; symbolic_u0 = false)
820820
defs = mergedefaults(defs, parammap, ps)
821821
end
822822
defs = mergedefaults(defs, u0map, dvs)
823-
for (k, v) in defs
824-
if Symbolics.isarraysymbolic(k) &&
825-
Symbolics.shape(unwrap(k)) !== Symbolics.Unknown()
826-
ks = scalarize(k)
827-
length(ks) == length(v) || error("$k has default value $v with unmatched size")
828-
for (kk, vv) in zip(ks, v)
829-
if !haskey(defs, kk)
830-
defs[kk] = vv
831-
end
832-
end
833-
end
834-
end
835-
836823
if symbolic_u0
837824
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = false, use_union = false)
838825
else

src/systems/systemstructure.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,12 @@ function TearingState(sys; quick_cancel = false, check = true)
285285
end
286286
vars!(vars, eq.rhs, op = Symbolics.Operator)
287287
for v in vars
288+
_var, _ = var_from_nested_derivative(v)
289+
any(isequal(_var), ivs) && continue
290+
if isparameter(_var) ||
291+
(istree(_var) && isparameter(operation(_var)) || isconstant(_var))
292+
continue
293+
end
288294
v = scalarize(v)
289295
if v isa AbstractArray
290296
v = setmetadata.(v, VariableIrreducible, true)

src/variables.jl

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -187,17 +187,28 @@ end
187187

188188
function _varmap_to_vars(varmap::Dict, varlist; defaults = Dict(), check = false,
189189
toterm = Symbolics.diff2term, initialization_phase = false)
190-
varmap = merge(defaults, varmap) # prefers the `varmap`
191-
varmap = Dict(toterm(value(k)) => value(varmap[k]) for k in keys(varmap))
192-
# resolve symbolic parameter expressions
193-
for (p, v) in pairs(varmap)
194-
varmap[p] = fixpoint_sub(v, varmap)
190+
varmap = canonicalize_varmap(varmap; toterm)
191+
defaults = canonicalize_varmap(defaults; toterm)
192+
values = Dict()
193+
for var in varlist
194+
var = unwrap(var)
195+
val = unwrap(fixpoint_sub(fixpoint_sub(var, varmap), defaults))
196+
if symbolic_type(val) === NotSymbolic()
197+
values[var] = val
198+
end
195199
end
196-
197-
missingvars = setdiff(varlist, collect(keys(varmap)))
200+
missingvars = setdiff(varlist, collect(keys(values)))
198201
check && (isempty(missingvars) || throw(MissingVariablesError(missingvars)))
202+
return [values[unwrap(var)] for var in varlist]
203+
end
199204

200-
out = [varmap[var] for var in varlist]
205+
function canonicalize_varmap(varmap; toterm = Symbolics.diff2term)
206+
new_varmap = Dict()
207+
for (k, v) in varmap
208+
new_varmap[unwrap(k)] = unwrap(v)
209+
new_varmap[toterm(unwrap(k))] = unwrap(v)
210+
end
211+
return new_varmap
201212
end
202213

203214
@noinline function throw_missingvars(vars)

test/initial_values.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
using ModelingToolkit
2+
using ModelingToolkit: t_nounits as t, D_nounits as D, get_u0
3+
using SymbolicIndexingInterface: getu
4+
5+
@variables x(t)[1:3]=[1.0, 2.0, 3.0] y(t) z(t)[1:2]
6+
7+
@mtkbuild sys=ODESystem([D(x) ~ t * x], t) simplify=false
8+
@test get_u0(sys, [])[1] == [1.0, 2.0, 3.0]
9+
@test get_u0(sys, [x => [2.0, 3.0, 4.0]])[1] == [2.0, 3.0, 4.0]
10+
@test get_u0(sys, [x[1] => 2.0, x[2] => 3.0, x[3] => 4.0])[1] == [2.0, 3.0, 4.0]
11+
@test get_u0(sys, [2.0, 3.0, 4.0])[1] == [2.0, 3.0, 4.0]
12+
13+
@mtkbuild sys=ODESystem([
14+
D(x) ~ 3x,
15+
D(y) ~ t,
16+
D(z[1]) ~ z[2] + t,
17+
D(z[2]) ~ y + z[1]
18+
], t) simplify=false
19+
20+
@test_throws ModelingToolkit.MissingVariablesError get_u0(sys, [])
21+
getter = getu(sys, [x..., y, z...])
22+
@test getter(get_u0(sys, [y => 4.0, z => [5.0, 6.0]])[1]) == collect(1.0:6.0)
23+
@test getter(get_u0(sys, [y => 4.0, z => [3y, 4y]])[1]) == [1.0, 2.0, 3.0, 4.0, 12.0, 16.0]
24+
@test getter(get_u0(sys, [y => 3.0, z[1] => 3y, z[2] => 2x[1]])[1]) ==
25+
[1.0, 2.0, 3.0, 3.0, 9.0, 2.0]
26+
27+
@variables w(t)
28+
@parameters p1 p2
29+
30+
@test getter(get_u0(sys, [y => 2p1, z => [3y, 2p2]], [p1 => 5.0, p2 => 6.0])[1]) ==
31+
[1.0, 2.0, 3.0, 10.0, 30.0, 12.0]
32+
@test_throws Any getter(get_u0(sys, [y => 2w, w => 3.0, z[1] => 2p1, z[2] => 3p2]))
33+
@test getter(get_u0(
34+
sys, [y => 2w, w => 3.0, z[1] => 2p1, z[2] => 3p2], [p1 => 3.0, p2 => 4.0])[1]) ==
35+
[1.0, 2.0, 3.0, 6.0, 6.0, 12.0]

test/runtests.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
using SafeTestsets, Pkg, Test
22

3-
#=
43
const GROUP = get(ENV, "GROUP", "All")
54

65
function activate_extensions_env()
@@ -67,6 +66,7 @@ end
6766
@safetestset "Constants Test" include("constants.jl")
6867
@safetestset "Parameter Dependency Test" include("parameter_dependencies.jl")
6968
@safetestset "Generate Custom Function Test" include("generate_custom_function.jl")
69+
@safetestset "Initial Values Test" include("initial_values.jl")
7070
end
7171
end
7272

@@ -93,6 +93,4 @@ end
9393
end
9494
end
9595

96-
=#
97-
9896
@safetestset "Model Parsing Test" include("model_parsing.jl")

0 commit comments

Comments
 (0)