Skip to content

Commit c3d4cbe

Browse files
Merge pull request #397 from SciML/mass_matrix
Fix type of mass matrix to be based on u0
2 parents ce386bd + 5a7bc0f commit c3d4cbe

File tree

2 files changed

+21
-12
lines changed

2 files changed

+21
-12
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ are used to set the order of the dependent variable and parameter vectors,
139139
respectively.
140140
"""
141141
function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
142-
ps = parameters(sys);
142+
ps = parameters(sys), u0 = nothing;
143143
version = nothing, tgrad=false,
144144
jac = false, Wfact = false,
145145
sparse = false,
@@ -179,11 +179,13 @@ function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
179179

180180
M = calculate_massmatrix(sys)
181181

182+
_M = (u0 === nothing || M == I) ? M : ArrayInterface.restructure(u0 .* u0',M)
183+
182184
ODEFunction{iip}(f,jac=_jac,
183185
tgrad = _tgrad,
184186
Wfact = _Wfact,
185187
Wfact_t = _Wfact_t,
186-
mass_matrix = M,
188+
mass_matrix = _M,
187189
syms = Symbol.(states(sys)))
188190
end
189191

@@ -212,10 +214,12 @@ function DiffEqBase.ODEProblem{iip}(sys::AbstractODESystem,u0map,tspan,
212214
checkbounds = false, sparse = false,
213215
linenumbers = true, parallel=SerialForm(),
214216
kwargs...) where iip
215-
f = ODEFunction{iip}(sys;tgrad=tgrad,jac=jac,Wfact=Wfact,checkbounds=checkbounds,
217+
dvs = states(sys)
218+
ps = parameters(sys)
219+
u0 = varmap_to_vars(u0map,dvs)
220+
p = varmap_to_vars(parammap,ps)
221+
f = ODEFunction{iip}(sys,dvs,ps,u0;tgrad=tgrad,jac=jac,Wfact=Wfact,checkbounds=checkbounds,
216222
linenumbers=linenumbers,parallel=parallel,
217223
sparse=sparse)
218-
u0 = varmap_to_vars(u0map,states(sys))
219-
p = varmap_to_vars(parammap,parameters(sys))
220224
ODEProblem{iip}(f,u0,tspan,p;kwargs...)
221225
end

src/systems/diffeqs/sdesystem.jl

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,8 @@ Create an `SDEFunction` from the [`SDESystem`](@ref). The arguments `dvs` and `p
9898
are used to set the order of the dependent variable and parameter vectors,
9999
respectively.
100100
"""
101-
function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = sys.states, ps = sys.ps;
101+
function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = sys.states, ps = sys.ps,
102+
u0 = nothing;
102103
version = nothing, tgrad=false, sparse = false,
103104
jac = false, Wfact = false, kwargs...) where {iip}
104105
f_oop,f_iip = generate_function(sys, dvs, ps; expression=Val{false}, kwargs...)
@@ -138,12 +139,13 @@ function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = sys.states, ps = sys.
138139
end
139140

140141
M = calculate_massmatrix(sys)
142+
_M = (u0 === nothing || M == I) ? M : ArrayInterface.restructure(u0 .* u0',M)
141143

142144
SDEFunction{iip}(f,g,jac=_jac,
143145
tgrad = _tgrad,
144146
Wfact = _Wfact,
145147
Wfact_t = _Wfact_t,
146-
mass_matrix = M,
148+
mass_matrix = _M,
147149
syms = Symbol.(sys.states))
148150
end
149151

@@ -177,11 +179,14 @@ function DiffEqBase.SDEProblem{iip}(sys::SDESystem,u0map,tspan,parammap=DiffEqBa
177179
linenumbers = true, parallel=SerialForm(),
178180
kwargs...) where iip
179181

180-
f = SDEFunction{iip}(sys;tgrad=tgrad,jac=jac,Wfact=Wfact,checkbounds=checkbounds,
181-
linenumbers=linenumbers,parallel=parallel,
182-
sparse=sparse)
183-
u0 = varmap_to_vars(u0map,states(sys))
184-
p = varmap_to_vars(parammap,parameters(sys))
182+
dvs = states(sys)
183+
ps = parameters(sys)
184+
u0 = varmap_to_vars(u0map,dvs)
185+
p = varmap_to_vars(parammap,ps)
186+
f = SDEFunction{iip}(sys,dvs,ps,u0;tgrad=tgrad,jac=jac,Wfact=Wfact,
187+
checkbounds=checkbounds,
188+
linenumbers=linenumbers,parallel=parallel,
189+
sparse=sparse)
185190
if typeof(sys.noiseeqs) <: AbstractVector
186191
noise_rate_prototype = nothing
187192
elseif sparsenoise

0 commit comments

Comments
 (0)