Skip to content

Commit d858058

Browse files
committed
Provide eval_expression=false option in ODEFunction, for very specific use cases that need to avoid world age issues arising from evaluating expressions.
1 parent 90bace8 commit d858058

File tree

2 files changed

+48
-11
lines changed

2 files changed

+48
-11
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -142,33 +142,36 @@ function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
142142
ps = parameters(sys), u0 = nothing;
143143
version = nothing, tgrad=false,
144144
jac = false, Wfact = false,
145-
sparse = false,
145+
sparse = false, eval_expression=true,
146146
kwargs...) where {iip}
147147

148-
f_oop,f_iip = ModelingToolkit.eval.(generate_function(sys, dvs, ps; expression=Val{true}, kwargs...))
148+
f_gen = generate_function(sys, dvs, ps; expression=Val{eval_expression}, kwargs...)
149+
f_oop,f_iip = eval_expression ? ModelingToolkit.eval.(f_gen) : f_gen
149150
f(u,p,t) = f_oop(u,p,t)
150151
f(du,u,p,t) = f_iip(du,u,p,t)
151152

152153
if tgrad
153-
tgrad_oop,tgrad_iip = ModelingToolkit.eval.(generate_tgrad(sys, dvs, ps; expression=Val{true}, kwargs...))
154+
tgrad_gen = generate_tgrad(sys, dvs, ps; expression=Val{eval_expression}, kwargs...)
155+
tgrad_oop,tgrad_iip = eval_expression ? ModelingToolkit.eval.(tgrad_gen) : tgrad_gen
154156
_tgrad(u,p,t) = tgrad_oop(u,p,t)
155157
_tgrad(J,u,p,t) = tgrad_iip(J,u,p,t)
156158
else
157159
_tgrad = nothing
158160
end
159161

160162
if jac
161-
jac_oop,jac_iip = ModelingToolkit.eval.(generate_jacobian(sys, dvs, ps; sparse = sparse, expression=Val{true}, kwargs...))
163+
jac_gen = generate_jacobian(sys, dvs, ps; sparse = sparse, expression=Val{eval_expression}, kwargs...)
164+
jac_oop,jac_iip = eval_expression ? ModelingToolkit.eval.(jac_gen) : jac_gen
162165
_jac(u,p,t) = jac_oop(u,p,t)
163166
_jac(J,u,p,t) = jac_iip(J,u,p,t)
164167
else
165168
_jac = nothing
166169
end
167170

168171
if Wfact
169-
tmp_Wfact,tmp_Wfact_t = generate_factorized_W(sys, dvs, ps; expression=Val{true}, kwargs...)
170-
Wfact_oop, Wfact_iip = ModelingToolkit.eval.(tmp_Wfact)
171-
Wfact_oop_t, Wfact_iip_t = ModelingToolkit.eval.(tmp_Wfact_t)
172+
tmp_Wfact,tmp_Wfact_t = generate_factorized_W(sys, dvs, ps; expression=Val{eval_expression}, kwargs...)
173+
Wfact_oop, Wfact_iip = eval_expression ? ModelingToolkit.eval.(tmp_Wfact) : tmp_Wfact
174+
Wfact_oop_t, Wfact_iip_t = eval_expression ? ModelingToolkit.eval.(tmp_Wfact_t) : tmp_Wfact_t
172175
_Wfact(u,p,dtgamma,t) = Wfact_oop(u,p,dtgamma,t)
173176
_Wfact(W,u,p,dtgamma,t) = Wfact_iip(W,u,p,dtgamma,t)
174177
_Wfact_t(u,p,dtgamma,t) = Wfact_oop_t(u,p,dtgamma,t)

test/function_registration.jl

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# appropriately calls the registered functions, whether the call is
44
# qualified (with a module name) or not.
55

6+
67
# TEST: Function registration in a module.
78
# ------------------------------------------------
89
module MyModule
@@ -20,9 +21,11 @@ module MyModule
2021
sys = ODESystem([eq], t, [u], [x])
2122
fun = ODEFunction(sys)
2223

23-
@test fun([0.5], [5.0], 0.) == [30.0]
24+
u0 = 5.0
25+
@test fun([0.5], [u0], 0.) == [do_something(u0) * 2]
2426
end
2527

28+
2629
# TEST: Function registration in a nested module.
2730
# ------------------------------------------------
2831
module MyModule2
@@ -41,10 +44,12 @@ module MyModule2
4144
sys = ODESystem([eq], t, [u], [x])
4245
fun = ODEFunction(sys)
4346

44-
@test fun([0.5], [3.0], 0.) == [46.0]
47+
u0 = 3.0
48+
@test fun([0.5], [u0], 0.) == [do_something_2(u0) * 2]
4549
end
4650
end
4751

52+
4853
# TEST: Function registration outside any modules.
4954
# ------------------------------------------------
5055
using ModelingToolkit, DiffEqBase, LinearAlgebra, Test
@@ -61,9 +66,12 @@ eq = Dt(u) ~ do_something_3(x) + (@__MODULE__).do_something_3(x)
6166
sys = ODESystem([eq], t, [u], [x])
6267
fun = ODEFunction(sys)
6368

64-
@test fun([0.5], [7.0], 0.) == [74.0]
69+
u0 = 7.0
70+
@test fun([0.5], [u0], 0.) == [do_something_3(u0) * 2]
71+
6572

66-
# derivative
73+
# TEST: Function registration works with derivatives.
74+
# ---------------------------------------------------
6775
foo(x, y) = sin(x) * cos(y)
6876
@parameters t; @variables x(t) y(t) z(t); @derivatives D'~t;
6977
@register foo(x, y)
@@ -74,3 +82,29 @@ expr = foo(x, y)
7482
ModelingToolkit.derivative(::typeof(foo), (x, y), ::Val{1}) = cos(x) * cos(y) # derivative w.r.t. the first argument
7583
ModelingToolkit.derivative(::typeof(foo), (x, y), ::Val{2}) = -sin(x) * sin(y) # derivative w.r.t. the second argument
7684
@test isequal(expand_derivatives(D(foo(x, y))), expand_derivatives(D(sin(x) * cos(y))))
85+
86+
87+
# TEST: Function registration run from inside a function.
88+
# -------------------------------------------------------
89+
# This tests that we can get around the world age issue by falling back to
90+
# GeneralizedGenerated instead of function expressions.
91+
# Might be useful in cases where someone wants to define functions that build
92+
# up and use ODEFunctions given some parameters.
93+
function do_something_4(a)
94+
a + 30
95+
end
96+
@register do_something_4(a)
97+
function build_ode()
98+
@parameters t x
99+
@variables u(t)
100+
@derivatives Dt'~t
101+
eq = Dt(u) ~ do_something_4(x) + (@__MODULE__).do_something_4(x)
102+
sys = ODESystem([eq], t, [u], [x])
103+
fun = ODEFunction(sys, eval_expression=false)
104+
end
105+
function run_test()
106+
fun = build_ode()
107+
u0 = 10.0
108+
@test fun([0.5], [u0], 0.) == [do_something_4(u0) * 2]
109+
end
110+
run_test()

0 commit comments

Comments
 (0)