Skip to content

Commit 9e9b704

Browse files
Make manually passed derivatives be 3 argument functions
1 parent 7af93cf commit 9e9b704

File tree

6 files changed

+26
-26
lines changed

6 files changed

+26
-26
lines changed

src/function/finitediff.jl

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -54,14 +54,14 @@ function instantiate_function(f, x, adtype::AutoFiniteDiff, p, num_cons=0)
5454
gradcache = FiniteDiff.GradientCache(x, x, adtype.fdtype)
5555
grad = (res, θ, args...) -> FiniteDiff.finite_difference_gradient!(res, x -> _f(x, args...), θ, gradcache)
5656
else
57-
grad = f.grad
57+
grad = (G, θ, args...) -> f.grad(G, θ, p, args...)
5858
end
5959

6060
if f.hess === nothing
6161
hesscache = FiniteDiff.HessianCache(x, adtype.fdhtype)
6262
hess = (res, θ, args...) -> FiniteDiff.finite_difference_hessian!(res, x -> _f(x, args...), θ, updatecache(hesscache, θ))
6363
else
64-
hess = f.hess
64+
hess = (H, θ, args...) -> f.hess(H, θ, p, args...)
6565
end
6666

6767
if f.hv === nothing
@@ -89,7 +89,7 @@ function instantiate_function(f, x, adtype::AutoFiniteDiff, p, num_cons=0)
8989
FiniteDiff.finite_difference_jacobian!(J, cons, θ, jaccache)
9090
end
9191
else
92-
cons_j = f.cons_j
92+
cons_j = (J, θ) -> f.cons_j(J, θ, p)
9393
end
9494

9595
if cons !== nothing && f.cons_h === nothing
@@ -100,11 +100,11 @@ function instantiate_function(f, x, adtype::AutoFiniteDiff, p, num_cons=0)
100100
end
101101
end
102102
else
103-
cons_h = f.cons_h
103+
cons_h = (res, θ) -> f.cons_h(res, θ, p)
104104
end
105105

106106
return OptimizationFunction{true}(f, adtype; grad=grad, hess=hess, hv=hv,
107107
cons=cons, cons_j=cons_j, cons_h=cons_h,
108108
cons_jac_colorvec = cons_jac_colorvec,
109-
hess_prototype=nothing, cons_jac_prototype=f.cons_jac_prototype, cons_hess_prototype=nothing)
109+
hess_prototype=f.hess_prototype, cons_jac_prototype=f.cons_jac_prototype, cons_hess_prototype=f.cons_hess_prototype)
110110
end

src/function/forwarddiff.jl

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -49,14 +49,14 @@ function instantiate_function(f::OptimizationFunction{true}, x, adtype::AutoForw
4949
gradcfg = ForwardDiff.GradientConfig(_f, x, ForwardDiff.Chunk{chunksize}())
5050
grad = (res, θ, args...) -> ForwardDiff.gradient!(res, x -> _f(x, args...), θ, gradcfg, Val{false}())
5151
else
52-
grad = f.grad
52+
grad = (G, θ, args...) -> f.grad(G, θ, p, args...)
5353
end
5454

5555
if f.hess === nothing
5656
hesscfg = ForwardDiff.HessianConfig(_f, x, ForwardDiff.Chunk{chunksize}())
5757
hess = (res, θ, args...) -> ForwardDiff.hessian!(res, x -> _f(x, args...), θ, hesscfg, Val{false}())
5858
else
59-
hess = f.hess
59+
hess = (H, θ, args...) -> f.hess(H, θ, p, args...)
6060
end
6161

6262
if f.hv === nothing
@@ -82,7 +82,7 @@ function instantiate_function(f::OptimizationFunction{true}, x, adtype::AutoForw
8282
ForwardDiff.jacobian!(J, cons_oop, θ, cjconfig)
8383
end
8484
else
85-
cons_j = f.cons_j
85+
cons_j = (J, θ) -> f.cons_j(J, θ, p)
8686
end
8787

8888
if cons !== nothing && f.cons_h === nothing
@@ -94,10 +94,10 @@ function instantiate_function(f::OptimizationFunction{true}, x, adtype::AutoForw
9494
end
9595
end
9696
else
97-
cons_h = f.cons_h
97+
cons_h = (res, θ) -> f.cons_h(res, θ, p)
9898
end
99-
99+
100100
return OptimizationFunction{true}(f.f, adtype; grad=grad, hess=hess, hv=hv,
101101
cons=cons, cons_j=cons_j, cons_h=cons_h,
102-
hess_prototype=nothing, cons_jac_prototype=f.cons_jac_prototype, cons_hess_prototype=f.cons_hess_prototype)
102+
hess_prototype=f.hess_prototype, cons_jac_prototype=f.cons_jac_prototype, cons_hess_prototype=f.cons_hess_prototype)
103103
end

src/function/mtk.jl

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -15,14 +15,14 @@ function instantiate_function(f, x, adtype::AutoModelingToolkit, p, num_cons=0)
1515
grad_oop, grad_iip = ModelingToolkit.generate_gradient(sys, expression=Val{false})
1616
grad(J, u) = (grad_iip(J, u, p); J)
1717
else
18-
grad = f.grad
18+
grad = (G, θ, args...) -> f.grad(G, θ, p, args...)
1919
end
2020

2121
if f.hess === nothing
2222
hess_oop, hess_iip = ModelingToolkit.generate_hessian(sys, expression=Val{false}, sparse = adtype.obj_sparse)
2323
hess(H, u) = (hess_iip(H, u, p); H)
2424
else
25-
hess = f.hess
25+
hess = (H, θ, args...) -> f.hess(H, θ, p, args...)
2626
end
2727

2828
if f.hv === nothing
@@ -61,7 +61,7 @@ function instantiate_function(f, x, adtype::AutoModelingToolkit, p, num_cons=0)
6161
jac_iip(J, θ, p)
6262
end
6363
else
64-
cons_j = f.cons_j
64+
cons_j = (J, θ) -> f.cons_j(J, θ, p)
6565
end
6666

6767
if f.cons !== nothing && f.cons_h === nothing
@@ -70,7 +70,7 @@ function instantiate_function(f, x, adtype::AutoModelingToolkit, p, num_cons=0)
7070
cons_hess_iip(res, θ, p)
7171
end
7272
else
73-
cons_h = f.cons_h
73+
cons_h = (res, θ) -> f.cons_h(res, θ, p)
7474
end
7575

7676
if adtype.obj_sparse

src/function/reversediff.jl

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -51,7 +51,7 @@ function instantiate_function(f, x, adtype::AutoReverseDiff, p=SciMLBase.NullPar
5151
if f.grad === nothing
5252
grad = (res, θ, args...) -> ReverseDiff.gradient!(res, x -> _f(x, args...), θ, ReverseDiff.GradientConfig(θ))
5353
else
54-
grad = f.grad
54+
grad = (G, θ, args...) -> f.grad(G, θ, p, args...)
5555
end
5656

5757
if f.hess === nothing
@@ -67,7 +67,7 @@ function instantiate_function(f, x, adtype::AutoReverseDiff, p=SciMLBase.NullPar
6767
end
6868
end
6969
else
70-
hess = f.hess
70+
hess = (H, θ, args...) -> f.hess(H, θ, p, args...)
7171
end
7272

7373

@@ -84,5 +84,5 @@ function instantiate_function(f, x, adtype::AutoReverseDiff, p=SciMLBase.NullPar
8484

8585
return OptimizationFunction{false}(f, adtype; grad=grad, hess=hess, hv=hv,
8686
cons=nothing, cons_j=nothing, cons_h=nothing,
87-
hess_prototype=nothing, cons_jac_prototype=nothing, cons_hess_prototype=nothing)
87+
hess_prototype=f.hess_prototype, cons_jac_prototype=nothing, cons_hess_prototype=nothing)
8888
end

src/function/tracker.jl

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -30,13 +30,13 @@ function instantiate_function(f, x, adtype::AutoTracker, p, num_cons = 0)
3030
if f.grad === nothing
3131
grad = (res, θ, args...) -> res isa DiffResults.DiffResult ? DiffResults.gradient!(res, Tracker.data(Tracker.gradient(x -> _f(x, args...), θ)[1])) : res .= Tracker.data(Tracker.gradient(x -> _f(x, args...), θ)[1])
3232
else
33-
grad = f.grad
33+
grad = (G, θ, args...) -> f.grad(G, θ, p, args...)
3434
end
3535

3636
if f.hess === nothing
3737
hess = (res, θ, args...) -> error("Hessian based methods not supported with Tracker backend, pass in the `hess` kwarg")
3838
else
39-
hess = f.hess
39+
hess = (H, θ, args...) -> f.hess(H, θ, p, args...)
4040
end
4141

4242
if f.hv === nothing
@@ -46,7 +46,7 @@ function instantiate_function(f, x, adtype::AutoTracker, p, num_cons = 0)
4646
end
4747

4848

49-
return OptimizationFunction{false}(f, adtype; grad=grad, hess=hess, hv=hv,
49+
return OptimizationFunction{false}(f, adtype; grad=grad, hess=hess, hv=hv,
5050
cons=nothing, cons_j=nothing, cons_h=nothing,
51-
hess_prototype=nothing, cons_jac_prototype=nothing, cons_hess_prototype=nothing)
51+
hess_prototype=f.hess_prototype, cons_jac_prototype=nothing, cons_hess_prototype=nothing)
5252
end

src/function/zygote.jl

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,7 @@ function instantiate_function(f, x, adtype::AutoZygote, p, num_cons = 0)
3131
if f.grad === nothing
3232
grad = (res, θ, args...) -> res isa DiffResults.DiffResult ? DiffResults.gradient!(res, Zygote.gradient(x -> _f(x, args...), θ)[1]) : res .= Zygote.gradient(x -> _f(x, args...), θ)[1]
3333
else
34-
grad = f.grad
34+
grad = (G, θ, args...) -> f.grad(G, θ, p, args...)
3535
end
3636

3737
if f.hess === nothing
@@ -47,7 +47,7 @@ function instantiate_function(f, x, adtype::AutoZygote, p, num_cons = 0)
4747
end
4848
end
4949
else
50-
hess = f.hess
50+
hess = (H, θ, args...) -> f.hess(H, θ, p, args...)
5151
end
5252

5353
if f.hv === nothing
@@ -61,7 +61,7 @@ function instantiate_function(f, x, adtype::AutoZygote, p, num_cons = 0)
6161
hv = f.hv
6262
end
6363

64-
return OptimizationFunction{false}(f, adtype; grad=grad, hess=hess, hv=hv,
64+
return OptimizationFunction{false}(f, adtype; grad=grad, hess=hess, hv=hv,
6565
cons=nothing, cons_j=nothing, cons_h=nothing,
66-
hess_prototype=nothing, cons_jac_prototype=nothing, cons_hess_prototype=nothing)
66+
hess_prototype=f.hess_prototype, cons_jac_prototype=nothing, cons_hess_prototype=nothing)
6767
end

0 commit comments

Comments
 (0)