
Commit 65b853f

Make manually passed derivatives be 3 argument functions
1 parent 22b7ded commit 65b853f

6 files changed: +35 −44 lines changed

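In practical terms: derivative callbacks passed manually to OptimizationFunction (grad, hess, cons_j, cons_h) are now invoked with the problem parameters p as an extra trailing argument, matching the AD-generated callbacks, instead of being forwarded unchanged. A minimal sketch of the new calling convention, assuming the surrounding Optimization.jl-style API; the Rosenbrock objective and the hand-written gradient are illustrative, not part of this commit:

    using Optimization  # assumed package context for instantiate_function

    rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2

    # Under the old convention this gradient would be written as (G, x);
    # after this commit it receives the parameters p as a third argument.
    function rosenbrock_grad!(G, x, p)
        G[1] = -2 * (p[1] - x[1]) - 4 * p[2] * x[1] * (x[2] - x[1]^2)
        G[2] = 2 * p[2] * (x[2] - x[1]^2)
    end

    f = OptimizationFunction(rosenbrock, Optimization.AutoForwardDiff();
                             grad = rosenbrock_grad!)
    prob = OptimizationProblem(f, zeros(2), [1.0, 100.0])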

src/function/finitediff.jl

Lines changed: 8 additions & 10 deletions
@@ -59,7 +59,7 @@ function instantiate_function(f, x, adtype::AutoFiniteDiff, p, num_cons = 0)
                                               args...),
                             θ, gradcache)
     else
-        grad = f.grad
+        grad = (G, θ, args...) -> f.grad(G, θ, p, args...)
     end

     if f.hess === nothing
@@ -71,7 +71,7 @@ function instantiate_function(f, x, adtype::AutoFiniteDiff, p, num_cons = 0)
                             updatecache(hesscache,
                                         θ))
     else
-        hess = f.hess
+        hess = (H, θ, args...) -> f.hess(H, θ, p, args...)
     end

     if f.hv === nothing
@@ -102,7 +102,7 @@ function instantiate_function(f, x, adtype::AutoFiniteDiff, p, num_cons = 0)
             FiniteDiff.finite_difference_jacobian!(J, cons, θ, jaccache)
         end
     else
-        cons_j = f.cons_j
+        cons_j = (J, θ) -> f.cons_j(J, θ, p)
     end

     if cons !== nothing && f.cons_h === nothing
@@ -120,13 +120,11 @@ function instantiate_function(f, x, adtype::AutoFiniteDiff, p, num_cons = 0)
             end
         end
     else
-        cons_h = f.cons_h
+        cons_h = (res, θ) -> f.cons_h(res, θ, p)
     end

-    return OptimizationFunction{true}(f, adtype; grad = grad, hess = hess, hv = hv,
-                                      cons = cons, cons_j = cons_j, cons_h = cons_h,
-                                      cons_jac_colorvec = cons_jac_colorvec,
-                                      hess_prototype = nothing,
-                                      cons_jac_prototype = f.cons_jac_prototype,
-                                      cons_hess_prototype = nothing)
+    return OptimizationFunction{true}(f, adtype; grad=grad, hess=hess, hv=hv,
+                                      cons=cons, cons_j=cons_j, cons_h=cons_h,
+                                      cons_jac_colorvec = cons_jac_colorvec,
+                                      hess_prototype=f.hess_prototype, cons_jac_prototype=f.cons_jac_prototype, cons_hess_prototype=f.cons_hess_prototype)
 end
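Each backend below repeats the same pattern: instead of forwarding the user's callback untouched, instantiate_function now wraps it in a closure that captures p, so downstream code can keep the two-argument call grad(G, θ, args...). A generic sketch of that closure (wrap_derivative is an illustrative name, not a function in this codebase):

    # Capture p once; callers keep the old (G, θ, args...) convention while
    # the user-supplied callback sees (G, θ, p, args...).
    wrap_derivative(user_grad!, p) = (G, θ, args...) -> user_grad!(G, θ, p, args...)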

src/function/forwarddiff.jl

Lines changed: 7 additions & 9 deletions
@@ -51,15 +51,15 @@ function instantiate_function(f::OptimizationFunction{true}, x,
         grad = (res, θ, args...) -> ForwardDiff.gradient!(res, x -> _f(x, args...), θ,
                                                           gradcfg, Val{false}())
     else
-        grad = f.grad
+        grad = (G, θ, args...) -> f.grad(G, θ, p, args...)
     end

     if f.hess === nothing
         hesscfg = ForwardDiff.HessianConfig(_f, x, ForwardDiff.Chunk{chunksize}())
         hess = (res, θ, args...) -> ForwardDiff.hessian!(res, x -> _f(x, args...), θ,
                                                          hesscfg, Val{false}())
     else
-        hess = f.hess
+        hess = (H, θ, args...) -> f.hess(H, θ, p, args...)
     end

     if f.hv === nothing
@@ -85,7 +85,7 @@ function instantiate_function(f::OptimizationFunction{true}, x,
             ForwardDiff.jacobian!(J, cons_oop, θ, cjconfig)
         end
     else
-        cons_j = f.cons_j
+        cons_j = (J, θ) -> f.cons_j(J, θ, p)
     end

     if cons !== nothing && f.cons_h === nothing
@@ -99,12 +99,10 @@ function instantiate_function(f::OptimizationFunction{true}, x,
             end
         end
     else
-        cons_h = f.cons_h
+        cons_h = (res, θ) -> f.cons_h(res, θ, p)
     end

-    return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv,
-                                      cons = cons, cons_j = cons_j, cons_h = cons_h,
-                                      hess_prototype = nothing,
-                                      cons_jac_prototype = f.cons_jac_prototype,
-                                      cons_hess_prototype = f.cons_hess_prototype)
+    return OptimizationFunction{true}(f.f, adtype; grad=grad, hess=hess, hv=hv,
+                                      cons=cons, cons_j=cons_j, cons_h=cons_h,
+                                      hess_prototype=f.hess_prototype, cons_jac_prototype=f.cons_jac_prototype, cons_hess_prototype=f.cons_hess_prototype)
 end

src/function/mtk.jl

Lines changed: 4 additions & 4 deletions
@@ -15,15 +15,15 @@ function instantiate_function(f, x, adtype::AutoModelingToolkit, p, num_cons = 0
         grad_oop, grad_iip = ModelingToolkit.generate_gradient(sys, expression = Val{false})
         grad(J, u) = (grad_iip(J, u, p); J)
     else
-        grad = f.grad
+        grad = (G, θ, args...) -> f.grad(G, θ, p, args...)
     end

     if f.hess === nothing
         hess_oop, hess_iip = ModelingToolkit.generate_hessian(sys, expression = Val{false},
                                                               sparse = adtype.obj_sparse)
         hess(H, u) = (hess_iip(H, u, p); H)
     else
-        hess = f.hess
+        hess = (H, θ, args...) -> f.hess(H, θ, p, args...)
     end

     if f.hv === nothing
@@ -69,7 +69,7 @@ function instantiate_function(f, x, adtype::AutoModelingToolkit, p, num_cons = 0
             jac_iip(J, θ, p)
         end
     else
-        cons_j = f.cons_j
+        cons_j = (J, θ) -> f.cons_j(J, θ, p)
     end

     if f.cons !== nothing && f.cons_h === nothing
@@ -82,7 +82,7 @@ function instantiate_function(f, x, adtype::AutoModelingToolkit, p, num_cons = 0
             cons_hess_iip(res, θ, p)
         end
     else
-        cons_h = f.cons_h
+        cons_h = (res, θ) -> f.cons_h(res, θ, p)
     end

     if adtype.obj_sparse

src/function/reversediff.jl

Lines changed: 5 additions & 7 deletions
@@ -53,7 +53,7 @@ function instantiate_function(f, x, adtype::AutoReverseDiff, p = SciMLBase.NullP
         grad = (res, θ, args...) -> ReverseDiff.gradient!(res, x -> _f(x, args...), θ,
                                                           ReverseDiff.GradientConfig(θ))
     else
-        grad = f.grad
+        grad = (G, θ, args...) -> f.grad(G, θ, p, args...)
     end

     if f.hess === nothing
@@ -70,7 +70,7 @@ function instantiate_function(f, x, adtype::AutoReverseDiff, p = SciMLBase.NullP
             end
         end
     else
-        hess = f.hess
+        hess = (H, θ, args...) -> f.hess(H, θ, p, args...)
     end

     if f.hv === nothing
@@ -84,9 +84,7 @@ function instantiate_function(f, x, adtype::AutoReverseDiff, p = SciMLBase.NullP
         hv = f.hv
     end

-    return OptimizationFunction{false}(f, adtype; grad = grad, hess = hess, hv = hv,
-                                       cons = nothing, cons_j = nothing, cons_h = nothing,
-                                       hess_prototype = nothing,
-                                       cons_jac_prototype = nothing,
-                                       cons_hess_prototype = nothing)
+    return OptimizationFunction{false}(f, adtype; grad=grad, hess=hess, hv=hv,
+                                       cons=nothing, cons_j=nothing, cons_h=nothing,
+                                       hess_prototype=f.hess_prototype, cons_jac_prototype=nothing, cons_hess_prototype=nothing)
 end

src/function/tracker.jl

Lines changed: 6 additions & 7 deletions
@@ -36,13 +36,13 @@ function instantiate_function(f, x, adtype::AutoTracker, p, num_cons = 0)
                                     res .= Tracker.data(Tracker.gradient(x -> _f(x, args...),
                                                                          θ)[1])
     else
-        grad = f.grad
+        grad = (G, θ, args...) -> f.grad(G, θ, p, args...)
     end

     if f.hess === nothing
         hess = (res, θ, args...) -> error("Hessian based methods not supported with Tracker backend, pass in the `hess` kwarg")
     else
-        hess = f.hess
+        hess = (H, θ, args...) -> f.hess(H, θ, p, args...)
     end

     if f.hv === nothing
@@ -51,9 +51,8 @@ function instantiate_function(f, x, adtype::AutoTracker, p, num_cons = 0)
         hv = f.hv
     end

-    return OptimizationFunction{false}(f, adtype; grad = grad, hess = hess, hv = hv,
-                                       cons = nothing, cons_j = nothing, cons_h = nothing,
-                                       hess_prototype = nothing,
-                                       cons_jac_prototype = nothing,
-                                       cons_hess_prototype = nothing)
+
+    return OptimizationFunction{false}(f, adtype; grad=grad, hess=hess, hv=hv,
+                                       cons=nothing, cons_j=nothing, cons_h=nothing,
+                                       hess_prototype=f.hess_prototype, cons_jac_prototype=nothing, cons_hess_prototype=nothing)
 end

src/function/zygote.jl

Lines changed: 5 additions & 7 deletions
@@ -36,7 +36,7 @@ function instantiate_function(f, x, adtype::AutoZygote, p, num_cons = 0)
                                                             θ)[1]) :
                                     res .= Zygote.gradient(x -> _f(x, args...), θ)[1]
     else
-        grad = f.grad
+        grad = (G, θ, args...) -> f.grad(G, θ, p, args...)
     end

     if f.hess === nothing
@@ -52,7 +52,7 @@ function instantiate_function(f, x, adtype::AutoZygote, p, num_cons = 0)
             end
         end
     else
-        hess = f.hess
+        hess = (H, θ, args...) -> f.hess(H, θ, p, args...)
     end

     if f.hv === nothing
@@ -66,9 +66,7 @@ function instantiate_function(f, x, adtype::AutoZygote, p, num_cons = 0)
         hv = f.hv
     end

-    return OptimizationFunction{false}(f, adtype; grad = grad, hess = hess, hv = hv,
-                                       cons = nothing, cons_j = nothing, cons_h = nothing,
-                                       hess_prototype = nothing,
-                                       cons_jac_prototype = nothing,
-                                       cons_hess_prototype = nothing)
+    return OptimizationFunction{false}(f, adtype; grad=grad, hess=hess, hv=hv,
+                                       cons=nothing, cons_j=nothing, cons_h=nothing,
+                                       hess_prototype=f.hess_prototype, cons_jac_prototype=nothing, cons_hess_prototype=nothing)
 end
