Skip to content

Commit acb56d4

Browse files
authored
[Nonlinear] improve code coverage of operators.jl (#2671)
1 parent 178060e commit acb56d4

File tree

2 files changed

+33
-13
lines changed

2 files changed

+33
-13
lines changed

src/Nonlinear/operators.jl

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,22 +21,26 @@ function _create_binary_switch(ids, exprs)
2121
)
2222
end
2323

24-
# We use a let block here for `expr` to create a local variable that does not
25-
# persist in the scope of the module. All we care about is the _eval_univariate
26-
# function that is eval'd as a result.
27-
let
24+
function _generate_eval_univariate()
2825
exprs = map(SYMBOLIC_UNIVARIATE_EXPRESSIONS) do arg
2926
return :(return $(arg[1])(x), $(arg[2]))
3027
end
31-
@eval @inline function _eval_univariate(id, x::T) where {T}
32-
$(_create_binary_switch(1:length(exprs), exprs))
33-
return error("Invalid id for univariate operator: $id")
34-
end
35-
∇²f_exprs = map(arg -> :(return $(arg[3])), SYMBOLIC_UNIVARIATE_EXPRESSIONS)
36-
@eval @inline function _eval_univariate_2nd_deriv(id, x::T) where {T}
37-
$(_create_binary_switch(1:length(∇²f_exprs), ∇²f_exprs))
38-
return error("Invalid id for univariate operator: $id")
39-
end
28+
return _create_binary_switch(1:length(exprs), exprs)
29+
end
30+
31+
@eval @inline function _eval_univariate(id, x::T) where {T}
32+
$(_generate_eval_univariate())
33+
return error("Invalid id for univariate operator: $id")
34+
end
35+
36+
function _generate_eval_univariate_2nd_deriv()
37+
exprs = map(arg -> :(return $(arg[3])), SYMBOLIC_UNIVARIATE_EXPRESSIONS)
38+
return _create_binary_switch(1:length(exprs), exprs)
39+
end
40+
41+
@eval @inline function _eval_univariate_2nd_deriv(id, x::T) where {T}
42+
$(_generate_eval_univariate_2nd_deriv())
43+
return error("Invalid id for univariate operator: $id")
4044
end
4145

4246
struct _UnivariateOperator{F,F′,F′′}

test/Nonlinear/Nonlinear.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1401,6 +1401,22 @@ function test_create_binary_switch()
14011401
),
14021402
)
14031403
@test MOI.Nonlinear._create_binary_switch(1:4, [:a, :b, :c, :d]) == target
1404+
# Just test that these functions don't error. We'll test their contents by
1405+
# evaluating the actual fuctions that are `@eval`ed.
1406+
MOI.Nonlinear._generate_eval_univariate()
1407+
MOI.Nonlinear._generate_eval_univariate_2nd_deriv()
1408+
return
1409+
end
1410+
1411+
function test_intercept_ForwardDiff_MethodError()
1412+
r = Nonlinear.OperatorRegistry()
1413+
f(x::Float64) = sin(x)^2
1414+
g(x) = x > 1 ? f(x) : zero(x)
1415+
Nonlinear.register_operator(r, :g, 1, g)
1416+
@test Nonlinear.eval_univariate_function(r, :g, 0.0) == 0.0
1417+
@test Nonlinear.eval_univariate_function(r, :g, 2.0) sin(2.0)^2
1418+
@test Nonlinear.eval_univariate_gradient(r, :g, 0.0) == 0.0
1419+
@test_throws ErrorException Nonlinear.eval_univariate_gradient(r, :g, 2.0)
14041420
return
14051421
end
14061422

0 commit comments

Comments
 (0)