Skip to content

Commit 786a1bc

Browse files
fix: avoid scalarizing params in structural_simplify, variable defaults in get_u0
1 parent a64aad8 commit 786a1bc

File tree

3 files changed

+21
-17
lines changed

3 files changed

+21
-17
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -820,19 +820,6 @@ function get_u0(sys, u0map, parammap = nothing; symbolic_u0 = false)
820820
defs = mergedefaults(defs, parammap, ps)
821821
end
822822
defs = mergedefaults(defs, u0map, dvs)
823-
for (k, v) in defs
824-
if Symbolics.isarraysymbolic(k) &&
825-
Symbolics.shape(unwrap(k)) !== Symbolics.Unknown()
826-
ks = scalarize(k)
827-
length(ks) == length(v) || error("$k has default value $v with unmatched size")
828-
for (kk, vv) in zip(ks, v)
829-
if !haskey(defs, kk)
830-
defs[kk] = vv
831-
end
832-
end
833-
end
834-
end
835-
836823
if symbolic_u0
837824
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = false, use_union = false)
838825
else

src/systems/systemstructure.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,12 @@ function TearingState(sys; quick_cancel = false, check = true)
285285
end
286286
vars!(vars, eq.rhs, op = Symbolics.Operator)
287287
for v in vars
288+
_var, _ = var_from_nested_derivative(v)
289+
any(isequal(_var), ivs) && continue
290+
if isparameter(_var) ||
291+
(istree(_var) && isparameter(operation(_var)) || isconstant(_var))
292+
continue
293+
end
288294
v = scalarize(v)
289295
if v isa AbstractArray
290296
v = setmetadata.(v, VariableIrreducible, true)

src/variables.jl

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -187,13 +187,24 @@ end
187187

188188
function _varmap_to_vars(varmap::Dict, varlist; defaults = Dict(), check = false,
189189
toterm = Symbolics.diff2term, initialization_phase = false)
190-
varmap = merge(defaults, varmap) # prefers the `varmap`
191-
varmap = Dict(toterm(value(k)) => value(varmap[k]) for k in keys(varmap))
190+
_varmap = merge(defaults, varmap) # prefers the `varmap`
191+
varmap = Dict()
192+
for (k, v) in _varmap
193+
varmap[value(k)] = value(v)
194+
varmap[toterm(value(k))] = value(v)
195+
end
192196
# resolve symbolic parameter expressions
193197
for (p, v) in pairs(varmap)
194-
varmap[p] = fixpoint_sub(v, varmap)
198+
varmap[p] = fixpoint_sub(unwrap(v), varmap)
199+
end
200+
for var in varlist
201+
var = value(var)
202+
haskey(varmap, var) && continue
203+
val = fixpoint_sub(unwrap(var), varmap)
204+
if symbolic_type(val) === NotSymbolic()
205+
varmap[var] = val
206+
end
195207
end
196-
197208
missingvars = setdiff(varlist, collect(keys(varmap)))
198209
check && (isempty(missingvars) || throw(MissingVariablesError(missingvars)))
199210

0 commit comments

Comments
 (0)