Skip to content

Commit a020b50

Browse files
authored
[Nonlinear] fix _UnsafeVectorView with [email protected] (#2708)
1 parent 981ddb5 commit a020b50

File tree

2 files changed

+73
-2
lines changed

2 files changed

+73
-2
lines changed

src/Nonlinear/ReverseAD/utils.jl

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,25 @@ struct _UnsafeVectorView{T} <: DenseVector{T}
4242
ptr::Ptr{T}
4343
end
4444

45-
Base.getindex(x::_UnsafeVectorView, i) = unsafe_load(x.ptr, i + x.offset)
45+
function Base.getindex(x::_UnsafeVectorView, i::Integer)
46+
return unsafe_load(x.ptr, i + x.offset)
47+
end
48+
49+
Base.getindex(x::_UnsafeVectorView, i::CartesianIndex{1}) = getindex(x, i[1])
4650

47-
function Base.setindex!(x::_UnsafeVectorView, value, i)
51+
function Base.setindex!(x::_UnsafeVectorView, value, i::Integer)
52+
# We don't need to worry about `value` being the right type here because
53+
# x.ptr is a `::Ptr{T}`, so even though it is called `unsafe_store!`, there
54+
# is still a type convertion that happens so that we're not just chucking
55+
# the bits of value into `x.ptr`.
4856
unsafe_store!(x.ptr, value, i + x.offset)
4957
return value
5058
end
5159

60+
function Base.setindex!(x::_UnsafeVectorView, value, i::CartesianIndex{1})
61+
return setindex!(x, value, i[1])
62+
end
63+
5264
Base.length(v::_UnsafeVectorView) = v.len
5365

5466
Base.size(v::_UnsafeVectorView) = (v.len,)

test/Nonlinear/ReverseAD.jl

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1304,6 +1304,65 @@ function test_toposort_subexpressions()
13041304
return
13051305
end
13061306

1307+
function test_eval_user_defined_operator_ForwardDiff_gradient!()
1308+
model = MOI.Nonlinear.Model()
1309+
x = MOI.VariableIndex.(1:4)
1310+
p = MOI.Nonlinear.add_parameter(model, 2.0)
1311+
ex = MOI.Nonlinear.add_expression(model, :($p * $(x[1])))
1312+
ψ(x) = sin(x)
1313+
t(x, y) = x + 3y
1314+
MOI.Nonlinear.register_operator(model, , 1, ψ)
1315+
MOI.Nonlinear.register_operator(model, :t, 2, t)
1316+
MOI.Nonlinear.add_constraint(
1317+
model,
1318+
:($ex^3 + sin($(x[2])) / ψ($(x[2])) + t($(x[3]), $(x[4]))),
1319+
MOI.LessThan(0.0),
1320+
)
1321+
d = MOI.Nonlinear.Evaluator(model, MOI.Nonlinear.SparseReverseMode(), x)
1322+
MOI.initialize(d, [:Jac])
1323+
X = [1.1, 1.2, 1.3, 1.4]
1324+
g = [NaN]
1325+
MOI.eval_constraint(d, g, X)
1326+
@test only(g) 17.148
1327+
@test MOI.jacobian_structure(d) == [(1, 1), (1, 2), (1, 3), (1, 4)]
1328+
J = [NaN, NaN, NaN, NaN]
1329+
MOI.eval_constraint_jacobian(d, J, X)
1330+
@test J [2.0^3 * 3.0 * 1.1^2, 0.0, 1.0, 3.0]
1331+
return
1332+
end
1333+
1334+
function test_eval_user_defined_operator_type_mismatch()
1335+
model = MOI.Nonlinear.Model()
1336+
x = MOI.VariableIndex.(1:4)
1337+
p = MOI.Nonlinear.add_parameter(model, 2.0)
1338+
ex = MOI.Nonlinear.add_expression(model, :($p * $(x[1])))
1339+
ψ(x) = sin(x)
1340+
t(x, y) = x + 3y
1341+
function ∇t(ret, x, y)
1342+
ret[1] = 1 # These are intentionally the wrong type
1343+
ret[2] = 3 // 1 # These are intentionally the wrong type
1344+
return
1345+
end
1346+
MOI.Nonlinear.register_operator(model, , 1, ψ, cos)
1347+
MOI.Nonlinear.register_operator(model, :t, 2, t, ∇t)
1348+
MOI.Nonlinear.add_constraint(
1349+
model,
1350+
:($ex^3 + sin($(x[2])) / ψ($(x[2])) + t($(x[3]), $(x[4]))),
1351+
MOI.LessThan(0.0),
1352+
)
1353+
d = MOI.Nonlinear.Evaluator(model, MOI.Nonlinear.SparseReverseMode(), x)
1354+
MOI.initialize(d, [:Jac])
1355+
X = [1.1, 1.2, 1.3, 1.4]
1356+
g = [NaN]
1357+
MOI.eval_constraint(d, g, X)
1358+
@test only(g) 17.148
1359+
@test MOI.jacobian_structure(d) == [(1, 1), (1, 2), (1, 3), (1, 4)]
1360+
J = [NaN, NaN, NaN, NaN]
1361+
MOI.eval_constraint_jacobian(d, J, X)
1362+
@test J [2.0^3 * 3.0 * 1.1^2, 0.0, 1.0, 3.0]
1363+
return
1364+
end
1365+
13071366
end # module
13081367

13091368
TestReverseAD.runtests()

0 commit comments

Comments
 (0)