@@ -98,7 +98,8 @@ Create an `SDEFunction` from the [`SDESystem`](@ref). The arguments `dvs` and `p
98
98
are used to set the order of the dependent variable and parameter vectors,
99
99
respectively.
100
100
"""
101
- function DiffEqBase. SDEFunction {iip} (sys:: SDESystem , dvs = sys. states, ps = sys. ps;
101
+ function DiffEqBase. SDEFunction {iip} (sys:: SDESystem , dvs = sys. states, ps = sys. ps,
102
+ u0 = nothing ;
102
103
version = nothing , tgrad= false , sparse = false ,
103
104
jac = false , Wfact = false , kwargs... ) where {iip}
104
105
f_oop,f_iip = generate_function (sys, dvs, ps; expression= Val{false }, kwargs... )
@@ -138,12 +139,13 @@ function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = sys.states, ps = sys.
138
139
end
139
140
140
141
M = calculate_massmatrix (sys)
142
+ _M = u0 === nothing ? M : ArrayInterface. restructure (u0 .* u0' ,M)
141
143
142
144
SDEFunction {iip} (f,g,jac= _jac,
143
145
tgrad = _tgrad,
144
146
Wfact = _Wfact,
145
147
Wfact_t = _Wfact_t,
146
- mass_matrix = M ,
148
+ mass_matrix = _M ,
147
149
syms = Symbol .(sys. states))
148
150
end
149
151
@@ -177,11 +179,14 @@ function DiffEqBase.SDEProblem{iip}(sys::SDESystem,u0map,tspan,parammap=DiffEqBa
177
179
linenumbers = true , parallel= SerialForm (),
178
180
kwargs... ) where iip
179
181
180
- f = SDEFunction {iip} (sys;tgrad= tgrad,jac= jac,Wfact= Wfact,checkbounds= checkbounds,
181
- linenumbers= linenumbers,parallel= parallel,
182
- sparse= sparse)
183
- u0 = varmap_to_vars (u0map,states (sys))
184
- p = varmap_to_vars (parammap,parameters (sys))
182
+ dvs = states (sys)
183
+ ps = parameters (sys)
184
+ u0 = varmap_to_vars (u0map,dvs)
185
+ p = varmap_to_vars (parammap,ps)
186
+ f = SDEFunction {iip} (sys,dvs,ps,u0;tgrad= tgrad,jac= jac,Wfact= Wfact,
187
+ checkbounds= checkbounds,
188
+ linenumbers= linenumbers,parallel= parallel,
189
+ sparse= sparse)
185
190
if typeof (sys. noiseeqs) <: AbstractVector
186
191
noise_rate_prototype = nothing
187
192
elseif sparsenoise
0 commit comments