Skip to content

Commit 388eb29

Browse files
Merge pull request #477 from dpad/allow-gg-ODEfunctions
Provide as_expression=false option to ODEs/SDEs, for specific use cases that need to avoid world age issues (generally not needed)
2 parents af86f8c + 69b380b commit 388eb29

File tree

4 files changed

+71
-20
lines changed

4 files changed

+71
-20
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -142,33 +142,36 @@ function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
142142
ps = parameters(sys), u0 = nothing;
143143
version = nothing, tgrad=false,
144144
jac = false, Wfact = false,
145-
sparse = false,
145+
sparse = false, eval_expression = true,
146146
kwargs...) where {iip}
147147

148-
f_oop,f_iip = ModelingToolkit.eval.(generate_function(sys, dvs, ps; expression=Val{true}, kwargs...))
148+
f_gen = generate_function(sys, dvs, ps; expression=Val{eval_expression}, kwargs...)
149+
f_oop,f_iip = eval_expression ? ModelingToolkit.eval.(f_gen) : f_gen
149150
f(u,p,t) = f_oop(u,p,t)
150151
f(du,u,p,t) = f_iip(du,u,p,t)
151152

152153
if tgrad
153-
tgrad_oop,tgrad_iip = ModelingToolkit.eval.(generate_tgrad(sys, dvs, ps; expression=Val{true}, kwargs...))
154+
tgrad_gen = generate_tgrad(sys, dvs, ps; expression=Val{eval_expression}, kwargs...)
155+
tgrad_oop,tgrad_iip = eval_expression ? ModelingToolkit.eval.(tgrad_gen) : tgrad_gen
154156
_tgrad(u,p,t) = tgrad_oop(u,p,t)
155157
_tgrad(J,u,p,t) = tgrad_iip(J,u,p,t)
156158
else
157159
_tgrad = nothing
158160
end
159161

160162
if jac
161-
jac_oop,jac_iip = ModelingToolkit.eval.(generate_jacobian(sys, dvs, ps; sparse = sparse, expression=Val{true}, kwargs...))
163+
jac_gen = generate_jacobian(sys, dvs, ps; sparse = sparse, expression=Val{eval_expression}, kwargs...)
164+
jac_oop,jac_iip = eval_expression ? ModelingToolkit.eval.(jac_gen) : jac_gen
162165
_jac(u,p,t) = jac_oop(u,p,t)
163166
_jac(J,u,p,t) = jac_iip(J,u,p,t)
164167
else
165168
_jac = nothing
166169
end
167170

168171
if Wfact
169-
tmp_Wfact,tmp_Wfact_t = generate_factorized_W(sys, dvs, ps; expression=Val{true}, kwargs...)
170-
Wfact_oop, Wfact_iip = ModelingToolkit.eval.(tmp_Wfact)
171-
Wfact_oop_t, Wfact_iip_t = ModelingToolkit.eval.(tmp_Wfact_t)
172+
tmp_Wfact,tmp_Wfact_t = generate_factorized_W(sys, dvs, ps; expression=Val{eval_expression}, kwargs...)
173+
Wfact_oop, Wfact_iip = eval_expression ? ModelingToolkit.eval.(tmp_Wfact) : tmp_Wfact
174+
Wfact_oop_t, Wfact_iip_t = eval_expression ? ModelingToolkit.eval.(tmp_Wfact_t) : tmp_Wfact_t
172175
_Wfact(u,p,dtgamma,t) = Wfact_oop(u,p,dtgamma,t)
173176
_Wfact(W,u,p,dtgamma,t) = Wfact_iip(W,u,p,dtgamma,t)
174177
_Wfact_t(u,p,dtgamma,t) = Wfact_oop_t(u,p,dtgamma,t)
@@ -288,14 +291,15 @@ function DiffEqBase.ODEProblem{iip}(sys::AbstractODESystem,u0map,tspan,
288291
jac = false, Wfact = false,
289292
checkbounds = false, sparse = false,
290293
linenumbers = true, parallel=SerialForm(),
294+
eval_expression = true,
291295
kwargs...) where iip
292296
dvs = states(sys)
293297
ps = parameters(sys)
294298
u0 = varmap_to_vars(u0map,dvs)
295299
p = varmap_to_vars(parammap,ps)
296300
f = ODEFunction{iip}(sys,dvs,ps,u0;tgrad=tgrad,jac=jac,Wfact=Wfact,checkbounds=checkbounds,
297301
linenumbers=linenumbers,parallel=parallel,
298-
sparse=sparse)
302+
sparse=sparse,eval_expression=eval_expression)
299303
ODEProblem{iip}(f,u0,tspan,p;kwargs...)
300304
end
301305

src/systems/diffeqs/sdesystem.jl

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -101,25 +101,29 @@ respectively.
101101
function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = sys.states, ps = sys.ps,
102102
u0 = nothing;
103103
version = nothing, tgrad=false, sparse = false,
104-
jac = false, Wfact = false, kwargs...) where {iip}
105-
f_oop,f_iip = ModelingToolkit.eval.(generate_function(sys, dvs, ps; expression=Val{true}, kwargs...))
106-
g_oop,g_iip = ModelingToolkit.eval.(generate_diffusion_function(sys, dvs, ps; expression=Val{true}, kwargs...))
104+
jac = false, Wfact = false, eval_expression = true, kwargs...) where {iip}
105+
f_gen = generate_function(sys, dvs, ps; expression=Val{eval_expression}, kwargs...)
106+
f_oop,f_iip = eval_expression ? ModelingToolkit.eval.(f_gen) : f_gen
107+
g_gen = generate_diffusion_function(sys, dvs, ps; expression=Val{eval_expression}, kwargs...)
108+
g_oop,g_iip = eval_expression ? ModelingToolkit.eval.(g_gen) : g_gen
107109

108110
f(u,p,t) = f_oop(u,p,t)
109111
f(du,u,p,t) = f_iip(du,u,p,t)
110112
g(u,p,t) = g_oop(u,p,t)
111113
g(du,u,p,t) = g_iip(du,u,p,t)
112114

113115
if tgrad
114-
tgrad_oop,tgrad_iip = ModelingToolkit.eval.(generate_tgrad(sys, dvs, ps; expression=Val{true}, kwargs...))
116+
tgrad_gen = generate_tgrad(sys, dvs, ps; expression=Val{eval_expression}, kwargs...)
117+
tgrad_oop,tgrad_iip = eval_expression ? ModelingToolkit.eval.(tgrad_gen) : tgrad_gen
115118
_tgrad(u,p,t) = tgrad_oop(u,p,t)
116119
_tgrad(J,u,p,t) = tgrad_iip(J,u,p,t)
117120
else
118121
_tgrad = nothing
119122
end
120123

121124
if jac
122-
jac_oop,jac_iip = ModelingToolkit.eval.(generate_jacobian(sys, dvs, ps; expression=Val{true}, sparse=sparse, kwargs...))
125+
jac_gen = generate_jacobian(sys, dvs, ps; expression=Val{eval_expression}, sparse=sparse, kwargs...)
126+
jac_oop,jac_iip = eval_expression ? ModelingToolkit.eval.(jac_gen) : jac_gen
123127
_jac(u,p,t) = jac_oop(u,p,t)
124128
_jac(J,u,p,t) = jac_iip(J,u,p,t)
125129
else
@@ -128,8 +132,8 @@ function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = sys.states, ps = sys.
128132

129133
if Wfact
130134
tmp_Wfact,tmp_Wfact_t = generate_factorized_W(sys, dvs, ps, true; expression=Val{true}, kwargs...)
131-
Wfact_oop, Wfact_iip = ModelingToolkit.eval.(tmp_Wfact)
132-
Wfact_oop_t, Wfact_iip_t = ModelingToolkit.eval.(tmp_Wfact_t)
135+
Wfact_oop, Wfact_iip = eval_expression ? ModelingToolkit.eval.(tmp_Wfact) : tmp_Wfact
136+
Wfact_oop_t, Wfact_iip_t = eval_expression ? ModelingToolkit.eval.(tmp_Wfact_t) : tmp_Wfact_t
133137
_Wfact(u,p,dtgamma,t) = Wfact_oop(u,p,dtgamma,t)
134138
_Wfact(W,u,p,dtgamma,t) = Wfact_iip(W,u,p,dtgamma,t)
135139
_Wfact_t(u,p,dtgamma,t) = Wfact_oop_t(u,p,dtgamma,t)
@@ -253,6 +257,7 @@ function DiffEqBase.SDEProblem{iip}(sys::SDESystem,u0map,tspan,parammap=DiffEqBa
253257
checkbounds = false, sparse = false,
254258
sparsenoise = sparse,
255259
linenumbers = true, parallel=SerialForm(),
260+
eval_expression = true,
256261
kwargs...) where iip
257262

258263
dvs = states(sys)
@@ -262,7 +267,7 @@ function DiffEqBase.SDEProblem{iip}(sys::SDESystem,u0map,tspan,parammap=DiffEqBa
262267
f = SDEFunction{iip}(sys,dvs,ps,u0;tgrad=tgrad,jac=jac,Wfact=Wfact,
263268
checkbounds=checkbounds,
264269
linenumbers=linenumbers,parallel=parallel,
265-
sparse=sparse)
270+
sparse=sparse, eval_expression=eval_expression)
266271
if typeof(sys.noiseeqs) <: AbstractVector
267272
noise_rate_prototype = nothing
268273
elseif sparsenoise

test/function_registration.jl

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# appropriately calls the registered functions, whether the call is
44
# qualified (with a module name) or not.
55

6+
67
# TEST: Function registration in a module.
78
# ------------------------------------------------
89
module MyModule
@@ -20,9 +21,11 @@ module MyModule
2021
sys = ODESystem([eq], t, [u], [x])
2122
fun = ODEFunction(sys)
2223

23-
@test fun([0.5], [5.0], 0.) == [30.0]
24+
u0 = 5.0
25+
@test fun([0.5], [u0], 0.) == [do_something(u0) * 2]
2426
end
2527

28+
2629
# TEST: Function registration in a nested module.
2730
# ------------------------------------------------
2831
module MyModule2
@@ -41,10 +44,12 @@ module MyModule2
4144
sys = ODESystem([eq], t, [u], [x])
4245
fun = ODEFunction(sys)
4346

44-
@test fun([0.5], [3.0], 0.) == [46.0]
47+
u0 = 3.0
48+
@test fun([0.5], [u0], 0.) == [do_something_2(u0) * 2]
4549
end
4650
end
4751

52+
4853
# TEST: Function registration outside any modules.
4954
# ------------------------------------------------
5055
using ModelingToolkit, DiffEqBase, LinearAlgebra, Test
@@ -61,9 +66,12 @@ eq = Dt(u) ~ do_something_3(x) + (@__MODULE__).do_something_3(x)
6166
sys = ODESystem([eq], t, [u], [x])
6267
fun = ODEFunction(sys)
6368

64-
@test fun([0.5], [7.0], 0.) == [74.0]
69+
u0 = 7.0
70+
@test fun([0.5], [u0], 0.) == [do_something_3(u0) * 2]
71+
6572

66-
# derivative
73+
# TEST: Function registration works with derivatives.
74+
# ---------------------------------------------------
6775
foo(x, y) = sin(x) * cos(y)
6876
@parameters t; @variables x(t) y(t) z(t); @derivatives D'~t;
6977
@register foo(x, y)
@@ -74,3 +82,29 @@ expr = foo(x, y)
7482
ModelingToolkit.derivative(::typeof(foo), (x, y), ::Val{1}) = cos(x) * cos(y) # derivative w.r.t. the first argument
7583
ModelingToolkit.derivative(::typeof(foo), (x, y), ::Val{2}) = -sin(x) * sin(y) # derivative w.r.t. the second argument
7684
@test isequal(expand_derivatives(D(foo(x, y))), expand_derivatives(D(sin(x) * cos(y))))
85+
86+
87+
# TEST: Function registration run from inside a function.
88+
# -------------------------------------------------------
89+
# This tests that we can get around the world age issue by falling back to
90+
# GeneralizedGenerated instead of function expressions.
91+
# Might be useful in cases where someone wants to define functions that build
92+
# up and use ODEFunctions given some parameters.
93+
function do_something_4(a)
94+
a + 30
95+
end
96+
@register do_something_4(a)
97+
function build_ode()
98+
@parameters t x
99+
@variables u(t)
100+
@derivatives Dt'~t
101+
eq = Dt(u) ~ do_something_4(x) + (@__MODULE__).do_something_4(x)
102+
sys = ODESystem([eq], t, [u], [x])
103+
fun = ODEFunction(sys, eval_expression=false)
104+
end
105+
function run_test()
106+
fun = build_ode()
107+
u0 = 10.0
108+
@test fun([0.5], [u0], 0.) == [do_something_4(u0) * 2]
109+
end
110+
run_test()

test/sdesystem.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,11 @@ prob = SDEProblem(de,u0map,(0.0,100.0),parammap,sparsenoise=true)
7070
@test size(prob.noise_rate_prototype) == (3,3)
7171
@test prob.noise_rate_prototype isa SparseMatrixCSC
7272
sol = solve(prob,EM(),dt=0.001)
73+
74+
# Test eval_expression=false
75+
function test_SDEFunction_no_eval()
76+
# Need to test within a function scope to trigger world age issues
77+
f = SDEFunction(de, eval_expression=false)
78+
@test f([1.0,0.0,0.0], (10.0,26.0,2.33), (0.0,100.0)) [-10.0, 26.0, 0.0]
79+
end
80+
test_SDEFunction_no_eval()

0 commit comments

Comments
 (0)