Skip to content

[Nonlinear] improve test coverage of operators.jl #2650

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 24 additions & 28 deletions src/Nonlinear/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,42 +11,31 @@
push!(out.args, _create_binary_switch(ids[2:end], exprs[2:end]))
end
return out
else
mid = length(exprs) >>> 1
return Expr(
:if,
Expr(:call, :(<=), :id, ids[mid]),
_create_binary_switch(ids[1:mid], exprs[1:mid]),
_create_binary_switch(ids[mid+1:end], exprs[mid+1:end]),
)
end
mid = length(exprs) >>> 1
return Expr(
:if,
Expr(:call, :(<=), :id, ids[mid]),
_create_binary_switch(ids[1:mid], exprs[1:mid]),
_create_binary_switch(ids[mid+1:end], exprs[mid+1:end]),
)
end

# We use a let block here for `expr` to create a local variable that does not
# persist in the scope of the module. All we care about is the _eval_univariate
# function that is eval'd as a result.
let exprs = map(SYMBOLIC_UNIVARIATE_EXPRESSIONS) do arg
let
exprs = map(SYMBOLIC_UNIVARIATE_EXPRESSIONS) do arg
return :(return $(arg[1])(x), $(arg[2]))
end
@eval @inline function _eval_univariate(id, x::T) where {T}
$(_create_binary_switch(1:length(exprs), exprs))
return error("Invalid operator_id")
end
end

# We use a let block here for `expr` to create a local variable that does not
# persist in the scope of the module. All we care about is the function that is
# eval'd as a result.
let exprs = map(SYMBOLIC_UNIVARIATE_EXPRESSIONS) do arg
if arg === :(nothing) # f''(x) isn't defined
:(error("Invalid operator_id"))
else
:(return $(arg[3]))
end
return error("Invalid id for univariate operator: $id")
end
∇²f_exprs = map(arg -> :(return $(arg[3])), SYMBOLIC_UNIVARIATE_EXPRESSIONS)

Check warning on line 35 in src/Nonlinear/operators.jl

View check run for this annotation

Codecov / codecov/patch

src/Nonlinear/operators.jl#L35

Added line #L35 was not covered by tests
@eval @inline function _eval_univariate_2nd_deriv(id, x::T) where {T}
$(_create_binary_switch(1:length(exprs), exprs))
return error("Invalid operator_id")
$(_create_binary_switch(1:length(∇²f_exprs), ∇²f_exprs))
return error("Invalid id for univariate operator: $id")
end
end

Expand Down Expand Up @@ -339,7 +328,7 @@
y = f(zeros(dimension)...)
end
catch
# We hit some other error, perhaps we called a function like log(0).
# We hit some other error, perhaps we called a function like log(-1).
# Ignore for now, and hope that a useful error is shown to the user
# during the solve.
end
Expand All @@ -363,7 +352,7 @@
_FORWARD_DIFF_METHOD_ERROR_HELPER,
)
end
# We hit some other error, perhaps we called a function like log(0).
# We hit some other error, perhaps we called a function like log(-1).
# Ignore for now, and hope that a useful error is shown to the user
# during the solve.
end
Expand Down Expand Up @@ -747,7 +736,12 @@
x::T,
) where {T}
if id <= registry.univariate_user_operator_start
return _eval_univariate_2nd_deriv(id, x)::T
ret = _eval_univariate_2nd_deriv(id, x)
if ret === nothing
op = registry.univariate_operators[id]
error("Hessian is not defined for operator $op")
end
return ret::T
end
offset = id - registry.univariate_user_operator_start
operator = registry.registered_univariate_operators[offset]
Expand Down Expand Up @@ -910,13 +904,15 @@
op::Symbol,
H::AbstractMatrix,
x::AbstractVector{T},
) where {T}
)::Bool where {T}

Evaluate the Hessian of operator `∇²op(x)`, where `op` is a multivariate
function in `registry`.

The Hessian is stored in the lower-triangular part of the matrix `H`.

Returns a `Bool` indicating whether non-zeros were stored in the matrix.

!!! note
Implementations of the Hessian operators will not fill structural zeros.
Therefore, before calling this function you should pre-populate the matrix
Expand Down
79 changes: 78 additions & 1 deletion test/Nonlinear/Nonlinear.jl
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,28 @@ function test_eval_univariate_function()
return
end

function test_eval_univariate_missing_hessian()
r = Nonlinear.OperatorRegistry()
x = 2.0
@test Nonlinear.eval_univariate_function(r, :asec, x) ≈ asec(x)
@test Nonlinear.eval_univariate_gradient(r, :asec, x) ≈
1 / (abs(x) * sqrt(x^2 - 1))
@test_throws(
ErrorException("Hessian is not defined for operator asec"),
Nonlinear.eval_univariate_hessian(r, :asec, x),
)
return
end

function test_eval_univariate_hessian_bad_id()
r = Nonlinear.OperatorRegistry()
err = ErrorException("Invalid id for univariate operator: -1")
@test_throws err Nonlinear.eval_univariate_function(r, -1, 1.0)
@test_throws err Nonlinear.eval_univariate_gradient(r, -1, 1.0)
@test_throws err Nonlinear.eval_univariate_hessian(r, -1, 1.0)
return
end

function test_eval_univariate_gradient()
r = Nonlinear.OperatorRegistry()
for (op, x, y) in [
Expand Down Expand Up @@ -594,7 +616,29 @@ function test_eval_multivariate_gradient_mult()
x = [1.1, 0.0, 2.2]
g = zeros(3)
Nonlinear.eval_multivariate_gradient(r, :*, g, x)
@test g == [0.0, 1.1 * 2.2, 0.0]
@test g ≈ [0.0, 1.1 * 2.2, 0.0]
x = [1.1, 3.3, 2.2]
Nonlinear.eval_multivariate_gradient(r, :*, g, x)
@test g ≈ [3.3 * 2.2, 1.1 * 2.2, 1.1 * 3.3]
return
end

function test_eval_multivariate_gradient_univariate_mult()
r = Nonlinear.OperatorRegistry()
x = [1.1]
g = zeros(1)
Nonlinear.eval_multivariate_gradient(r, :*, g, x)
@test g == [1.0]
return
end

function test_eval_multivariate_hessian_shortcut()
r = Nonlinear.OperatorRegistry()
x = [1.1]
H = LinearAlgebra.LowerTriangular(zeros(1, 1))
for op in (:+, :-, :ifelse)
@test !MOI.Nonlinear.eval_multivariate_hessian(r, op, H, x)
end
return
end

Expand Down Expand Up @@ -670,6 +714,18 @@ function test_eval_multivariate_function_registered()
return
end

function test_eval_multivariate_function_registered_log()
r = Nonlinear.OperatorRegistry()
f(x...) = log(x[1] - 1)
Nonlinear.register_operator(r, :f, 2, f)
x = [1.1, 2.2]
@test Nonlinear.eval_multivariate_function(r, :f, x) ≈ f(x...)
x = [0.0, 0.0]
g = zeros(2)
@test_throws DomainError Nonlinear.eval_multivariate_gradient(r, :f, g, x)
return
end

function test_eval_multivariate_function_method_error()
r = Nonlinear.OperatorRegistry()
function f(x...)
Expand Down Expand Up @@ -1327,6 +1383,27 @@ function test_convert_to_expr()
return
end

function test_create_binary_switch()
target = Expr(
:if,
Expr(:call, :(<=), :id, 2),
Expr(
:if,
Expr(:call, :(==), :id, 1),
:a,
Expr(:if, Expr(:call, :(==), :id, 2), :b),
),
Expr(
:if,
Expr(:call, :(==), :id, 3),
:c,
Expr(:if, Expr(:call, :(==), :id, 4), :d),
),
)
@test MOI.Nonlinear._create_binary_switch(1:4, [:a, :b, :c, :d]) == target
return
end

end # TestNonlinear

TestNonlinear.runtests()
Expand Down
Loading