Skip to content

Commit c7347d2

Browse files
Update OptimizationZygoteExt.jl
1 parent f2001e9 commit c7347d2

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

ext/OptimizationZygoteExt.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ function Optimization.instantiate_function(f, x, adtype::AutoZygote, p,
1313
grad = function (res, θ, args...)
1414
val = Zygote.gradient(x -> _f(x, args...), θ)[1]
1515
if val === nothing
16-
res .= 0
16+
res .= zero(typeof(θ))
1717
else
1818
res .= val
1919
end
@@ -90,7 +90,7 @@ function Optimization.instantiate_function(f, cache::Optimization.ReInitCache,
9090
grad = function (res, θ, args...)
9191
val = Zygote.gradient(x -> _f(x, args...), θ)[1]
9292
if val === nothing
93-
res .= 0
93+
res .= zero(typeof(θ))
9494
else
9595
res .= val
9696
end

0 commit comments

Comments
 (0)