Skip to content

Commit 69b380b

Browse files
committed
Add eval_expression support to ODEProblem/SDEProblem/SDEFunction.
1 parent d858058 commit 69b380b

File tree

3 files changed

+24
-10
lines changed

3 files changed

+24
-10
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
142142
ps = parameters(sys), u0 = nothing;
143143
version = nothing, tgrad=false,
144144
jac = false, Wfact = false,
145-
sparse = false, eval_expression=true,
145+
sparse = false, eval_expression = true,
146146
kwargs...) where {iip}
147147

148148
f_gen = generate_function(sys, dvs, ps; expression=Val{eval_expression}, kwargs...)
@@ -291,14 +291,15 @@ function DiffEqBase.ODEProblem{iip}(sys::AbstractODESystem,u0map,tspan,
291291
jac = false, Wfact = false,
292292
checkbounds = false, sparse = false,
293293
linenumbers = true, parallel=SerialForm(),
294+
eval_expression = true,
294295
kwargs...) where iip
295296
dvs = states(sys)
296297
ps = parameters(sys)
297298
u0 = varmap_to_vars(u0map,dvs)
298299
p = varmap_to_vars(parammap,ps)
299300
f = ODEFunction{iip}(sys,dvs,ps,u0;tgrad=tgrad,jac=jac,Wfact=Wfact,checkbounds=checkbounds,
300301
linenumbers=linenumbers,parallel=parallel,
301-
sparse=sparse)
302+
sparse=sparse,eval_expression=eval_expression)
302303
ODEProblem{iip}(f,u0,tspan,p;kwargs...)
303304
end
304305

src/systems/diffeqs/sdesystem.jl

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -101,25 +101,29 @@ respectively.
101101
function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = sys.states, ps = sys.ps,
102102
u0 = nothing;
103103
version = nothing, tgrad=false, sparse = false,
104-
jac = false, Wfact = false, kwargs...) where {iip}
105-
f_oop,f_iip = ModelingToolkit.eval.(generate_function(sys, dvs, ps; expression=Val{true}, kwargs...))
106-
g_oop,g_iip = ModelingToolkit.eval.(generate_diffusion_function(sys, dvs, ps; expression=Val{true}, kwargs...))
104+
jac = false, Wfact = false, eval_expression = true, kwargs...) where {iip}
105+
f_gen = generate_function(sys, dvs, ps; expression=Val{eval_expression}, kwargs...)
106+
f_oop,f_iip = eval_expression ? ModelingToolkit.eval.(f_gen) : f_gen
107+
g_gen = generate_diffusion_function(sys, dvs, ps; expression=Val{eval_expression}, kwargs...)
108+
g_oop,g_iip = eval_expression ? ModelingToolkit.eval.(g_gen) : g_gen
107109

108110
f(u,p,t) = f_oop(u,p,t)
109111
f(du,u,p,t) = f_iip(du,u,p,t)
110112
g(u,p,t) = g_oop(u,p,t)
111113
g(du,u,p,t) = g_iip(du,u,p,t)
112114

113115
if tgrad
114-
tgrad_oop,tgrad_iip = ModelingToolkit.eval.(generate_tgrad(sys, dvs, ps; expression=Val{true}, kwargs...))
116+
tgrad_gen = generate_tgrad(sys, dvs, ps; expression=Val{eval_expression}, kwargs...)
117+
tgrad_oop,tgrad_iip = eval_expression ? ModelingToolkit.eval.(tgrad_gen) : tgrad_gen
115118
_tgrad(u,p,t) = tgrad_oop(u,p,t)
116119
_tgrad(J,u,p,t) = tgrad_iip(J,u,p,t)
117120
else
118121
_tgrad = nothing
119122
end
120123

121124
if jac
122-
jac_oop,jac_iip = ModelingToolkit.eval.(generate_jacobian(sys, dvs, ps; expression=Val{true}, sparse=sparse, kwargs...))
125+
jac_gen = generate_jacobian(sys, dvs, ps; expression=Val{eval_expression}, sparse=sparse, kwargs...)
126+
jac_oop,jac_iip = eval_expression ? ModelingToolkit.eval.(jac_gen) : jac_gen
123127
_jac(u,p,t) = jac_oop(u,p,t)
124128
_jac(J,u,p,t) = jac_iip(J,u,p,t)
125129
else
@@ -128,8 +132,8 @@ function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = sys.states, ps = sys.
128132

129133
if Wfact
130134
tmp_Wfact,tmp_Wfact_t = generate_factorized_W(sys, dvs, ps, true; expression=Val{true}, kwargs...)
131-
Wfact_oop, Wfact_iip = ModelingToolkit.eval.(tmp_Wfact)
132-
Wfact_oop_t, Wfact_iip_t = ModelingToolkit.eval.(tmp_Wfact_t)
135+
Wfact_oop, Wfact_iip = eval_expression ? ModelingToolkit.eval.(tmp_Wfact) : tmp_Wfact
136+
Wfact_oop_t, Wfact_iip_t = eval_expression ? ModelingToolkit.eval.(tmp_Wfact_t) : tmp_Wfact_t
133137
_Wfact(u,p,dtgamma,t) = Wfact_oop(u,p,dtgamma,t)
134138
_Wfact(W,u,p,dtgamma,t) = Wfact_iip(W,u,p,dtgamma,t)
135139
_Wfact_t(u,p,dtgamma,t) = Wfact_oop_t(u,p,dtgamma,t)
@@ -253,6 +257,7 @@ function DiffEqBase.SDEProblem{iip}(sys::SDESystem,u0map,tspan,parammap=DiffEqBa
253257
checkbounds = false, sparse = false,
254258
sparsenoise = sparse,
255259
linenumbers = true, parallel=SerialForm(),
260+
eval_expression = true,
256261
kwargs...) where iip
257262

258263
dvs = states(sys)
@@ -262,7 +267,7 @@ function DiffEqBase.SDEProblem{iip}(sys::SDESystem,u0map,tspan,parammap=DiffEqBa
262267
f = SDEFunction{iip}(sys,dvs,ps,u0;tgrad=tgrad,jac=jac,Wfact=Wfact,
263268
checkbounds=checkbounds,
264269
linenumbers=linenumbers,parallel=parallel,
265-
sparse=sparse)
270+
sparse=sparse, eval_expression=eval_expression)
266271
if typeof(sys.noiseeqs) <: AbstractVector
267272
noise_rate_prototype = nothing
268273
elseif sparsenoise

test/sdesystem.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,11 @@ prob = SDEProblem(de,u0map,(0.0,100.0),parammap,sparsenoise=true)
7070
@test size(prob.noise_rate_prototype) == (3,3)
7171
@test prob.noise_rate_prototype isa SparseMatrixCSC
7272
sol = solve(prob,EM(),dt=0.001)
73+
74+
# Test eval_expression=false
75+
function test_SDEFunction_no_eval()
76+
# Need to test within a function scope to trigger world age issues
77+
f = SDEFunction(de, eval_expression=false)
78+
@test f([1.0,0.0,0.0], (10.0,26.0,2.33), (0.0,100.0)) [-10.0, 26.0, 0.0]
79+
end
80+
test_SDEFunction_no_eval()

0 commit comments

Comments
 (0)