Skip to content

Commit 9b0ab2c

Browse files
Merge pull request #368 from SciML/complex-u0-for-OptimizationOptimisers
error with complex number u0 in OptimizationOptimisers
2 parents becceb8 + d52ae70 commit 9b0ab2c

File tree

3 files changed

+15
-1
lines changed

3 files changed

+15
-1
lines changed

lib/OptimizationOptimisers/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
99
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1010
ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
1111
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
12+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1213

1314
[compat]
1415
julia = "1"

lib/OptimizationOptimisers/src/OptimizationOptimisers.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ function SciMLBase.__solve(prob::OptimizationProblem, opt::OptimisersOptimizers,
2323
G = copy(θ)
2424

2525
local x, min_err, min_θ
26-
min_err = typemax(eltype(prob.u0)) #dummy variables
26+
min_err = typemax(eltype(real(prob.u0))) #dummy variables
2727
min_opt = 1
2828
min_θ = prob.u0
2929

lib/OptimizationOptimisers/test/runtests.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using OptimizationOptimisers, Optimization, ForwardDiff
22
using Test
3+
using Zygote
34

45
@testset "OptimizationOptimisers.jl" begin
56
rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2
@@ -17,4 +18,16 @@ using Test
1718
prob = OptimizationProblem(optprob, x0, _p)
1819
sol = solve(prob, Optimisers.ADAM(), maxiters = 1000, progress = false)
1920
@test 10 * sol.minimum < l1
21+
22+
x0 = 2 * ones(ComplexF64, 2)
23+
_p = ones(2)
24+
sumfunc(x0, _p) = sum(abs2, (x0 - _p))
25+
l1 = sumfunc(x0, _p)
26+
27+
optprob = OptimizationFunction(sumfunc, Optimization.AutoZygote())
28+
29+
prob = OptimizationProblem(optprob, x0, _p)
30+
31+
sol = solve(prob, Optimisers.ADAM(), maxiters = 1000)
32+
@test 10 * sol.minimum < l1
2033
end

0 commit comments

Comments
 (0)