Skip to content

Commit f7ac226

Browse files
committed
Inject registered functions into expressions used to build EvalFunc. Fixes the regression of #445 caused by #451.
1 parent 5ca00a7 commit f7ac226

File tree

3 files changed

+26
-18
lines changed

3 files changed

+26
-18
lines changed

src/build_function.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -125,15 +125,15 @@ function _build_function(target::JuliaTarget, op::Operation, args...;
125125
end
126126

127127
if expression == Val{true}
128-
return oop_ex
128+
return ModelingToolkit.inject_registered_module_functions(oop_ex)
129129
else
130130
_build_and_inject_function(@__MODULE__, oop_ex)
131131
end
132132
end
133133

134134
function _build_and_inject_function(mod::Module, ex)
135135
# Generate the function, which will process the expression
136-
runtimefn = GeneralizedGenerated.mk_function(@__MODULE__, ex)
136+
runtimefn = GeneralizedGenerated.mk_function(mod, ex)
137137

138138
# Extract the processed expression of the function body
139139
params = typeof(runtimefn).parameters
@@ -339,7 +339,7 @@ function _build_function(target::JuliaTarget, rhss, args...;
339339
end
340340

341341
if expression == Val{true}
342-
return oop_ex, iip_ex
342+
return ModelingToolkit.inject_registered_module_functions(oop_ex), ModelingToolkit.inject_registered_module_functions(iip_ex)
343343
else
344344
return _build_and_inject_function(@__MODULE__, oop_ex), _build_and_inject_function(@__MODULE__, iip_ex)
345345
end

src/function_registration.jl

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -99,18 +99,21 @@ Base.one(::Operation) = 1
9999
const registered_external_functions = Dict{Symbol,Module}()
100100
function inject_registered_module_functions(expr)
101101
MacroTools.postwalk(expr) do x
102-
# We need to find all function calls in the expression...
103-
MacroTools.@capture(x, f_(xs__))
102+
# Find all function calls in the expression and extract the function
103+
# name and calling module.
104+
MacroTools.@capture(x, f_module_.f_name_(xs__))
105+
if isnothing(f_module)
106+
MacroTools.@capture(x, f_name_(xs__))
107+
end
104108

105-
if !isnothing(f) && f isa Expr && f.head == :. && f.args[2] isa QuoteNode
106-
# If the function call matches any of the functions we've
107-
# registered, set the calling module (which is probably
108-
# "ModelingToolkit") to the module it is registered to.
109-
f_name = f.args[2].value # function name
110-
f.args[1] = get(registered_external_functions, f_name, f.args[1])
109+
if !isnothing(f_name)
110+
# Set the calling module to the module that registered it.
111+
mod = get(registered_external_functions, f_name, f_module)
112+
if !isnothing(mod)
113+
x.args[1] = :(getproperty($mod, $(Meta.quot(f_name))))
114+
end
111115
end
112116

113-
# Make sure we rebuild the expression as is.
114117
return x
115118
end
116119
end

test/function_registration.jl

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
# Each test here builds an ODEFunction including some user-registered
2+
# Operations. The test simply checks that calling the ODEFunction
3+
# appropriately calls the registered functions, whether the call is
4+
# qualified (with a module name) or not.
5+
16
# TEST: Function registration in a module.
27
# ------------------------------------------------
38
module MyModule
@@ -11,11 +16,11 @@ module MyModule
1116
end
1217
@register do_something(a)
1318

14-
eq = Dt(u) ~ do_something(x)
19+
eq = Dt(u) ~ do_something(x) + MyModule.do_something(x)
1520
sys = ODESystem([eq], t, [u], [x])
1621
fun = ODEFunction(sys)
1722

18-
@test_broken fun([0.5], [5.0], 0.) == [15.0]
23+
@test fun([0.5], [5.0], 0.) == [30.0]
1924
end
2025

2126
# TEST: Function registration in a nested module.
@@ -32,11 +37,11 @@ module MyModule2
3237
end
3338
@register do_something_2(a)
3439

35-
eq = Dt(u) ~ do_something_2(x)
40+
eq = Dt(u) ~ do_something_2(x) + MyNestedModule.do_something_2(x)
3641
sys = ODESystem([eq], t, [u], [x])
3742
fun = ODEFunction(sys)
3843

39-
@test_broken fun([0.5], [3.0], 0.) == [23.0]
44+
@test fun([0.5], [3.0], 0.) == [46.0]
4045
end
4146
end
4247

@@ -52,8 +57,8 @@ function do_something_3(a)
5257
end
5358
@register do_something_3(a)
5459

55-
eq = Dt(u) ~ do_something_3(x)
60+
eq = Dt(u) ~ do_something_3(x) + (@__MODULE__).do_something_3(x)
5661
sys = ODESystem([eq], t, [u], [x])
5762
fun = ODEFunction(sys)
5863

59-
@test_broken fun([0.5], [7.0], 0.) == [37.0]
64+
@test fun([0.5], [7.0], 0.) == [74.0]

0 commit comments

Comments
 (0)