Skip to content

Commit 0a417ea

Browse files
committed
Test state contains gradient/hessian using callback
1 parent ab9bc1b commit 0a417ea

File tree

1 file changed

+44
-8
lines changed

1 file changed

+44
-8
lines changed

lib/OptimizationOptimJL/test/runtests.jl

Lines changed: 44 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,32 @@ using OptimizationOptimJL,
33
Random, ModelingToolkit
44
using Test
55

6+
struct CallbackTester
7+
dim::Int
8+
has_grad::Bool
9+
has_hess::Bool
10+
end
11+
function CallbackTester(dim::Int; has_grad = false, has_hess = false)
12+
CallbackTester(dim, has_grad, has_hess)
13+
end
14+
15+
function (cb::CallbackTester)(state, loss_val)
16+
@test length(state.u) == cb.dim
17+
if cb.has_grad
18+
@test state.grad isa AbstractVector
19+
@test length(state.grad) == cb.dim
20+
else
21+
@test state.grad === nothing
22+
end
23+
if cb.has_hess
24+
@test state.hess isa AbstractMatrix
25+
@test size(state.hess) == (cb.dim, cb.dim)
26+
else
27+
@test state.hess === nothing
28+
end
29+
return false
30+
end
31+
632
@testset "OptimizationOptimJL.jl" begin
733
rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2
834
x0 = zeros(2)
@@ -13,34 +39,43 @@ using Test
1339
sol = solve(prob,
1440
Optim.NelderMead(;
1541
initial_simplex = Optim.AffineSimplexer(; a = 0.025,
16-
b = 0.5)))
42+
b = 0.5)); callback = CallbackTester(length(x0)))
1743
@test 10 * sol.objective < l1
1844

1945
f = OptimizationFunction(rosenbrock, Optimization.AutoForwardDiff())
2046

2147
Random.seed!(1234)
2248
prob = OptimizationProblem(f, x0, _p, lb = [-1.0, -1.0], ub = [0.8, 0.8])
23-
sol = solve(prob, SAMIN())
49+
sol = solve(prob, SAMIN(); callback = CallbackTester(length(x0)))
2450
@test 10 * sol.objective < l1
2551

26-
sol = solve(prob, Optim.IPNewton())
52+
sol = solve(
53+
prob, Optim.IPNewton();
54+
callback = CallbackTester(length(x0); has_grad = true, has_hess = true)
55+
)
2756
@test 10 * sol.objective < l1
2857

2958
prob = OptimizationProblem(f, x0, _p)
3059
Random.seed!(1234)
31-
sol = solve(prob, SimulatedAnnealing())
60+
sol = solve(prob, SimulatedAnnealing(); callback = CallbackTester(length(x0)))
3261
@test 10 * sol.objective < l1
3362

34-
sol = solve(prob, Optim.BFGS())
63+
sol = solve(prob, Optim.BFGS(); callback = CallbackTester(length(x0); has_grad = true))
3564
@test 10 * sol.objective < l1
3665

37-
sol = solve(prob, Optim.Newton())
66+
sol = solve(
67+
prob, Optim.Newton();
68+
callback = CallbackTester(length(x0); has_grad = true, has_hess = true)
69+
)
3870
@test 10 * sol.objective < l1
3971

4072
sol = solve(prob, Optim.KrylovTrustRegion())
4173
@test 10 * sol.objective < l1
4274

43-
sol = solve(prob, Optim.BFGS(), maxiters = 1)
75+
sol = solve(
76+
prob, Optim.BFGS();
77+
maxiters = 1, callback = CallbackTester(length(x0); has_grad = true)
78+
)
4479
@test sol.original.iterations == 1
4580

4681
sol = solve(prob, Optim.BFGS(), maxiters = 1, local_maxiters = 2)
@@ -92,7 +127,8 @@ using Test
92127
optprob = OptimizationFunction(rosenbrock, Optimization.AutoZygote())
93128

94129
prob = OptimizationProblem(optprob, x0, _p, lb = [-1.0, -1.0], ub = [0.8, 0.8])
95-
sol = solve(prob, Optim.Fminbox())
130+
sol = solve(
131+
prob, Optim.Fminbox(); callback = CallbackTester(length(x0); has_grad = true))
96132
@test 10 * sol.objective < l1
97133

98134
Random.seed!(1234)

0 commit comments

Comments
 (0)