Skip to content

Commit ab9bc1b

Browse files
committed
Add gradient and hessian to state if available
1 parent 1723965 commit ab9bc1b

File tree

1 file changed

+13
-5
lines changed

1 file changed

+13
-5
lines changed

lib/OptimizationOptimJL/src/OptimizationOptimJL.jl

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -133,11 +133,13 @@ function SciMLBase.__solve(cache::OptimizationCache{
133133
error("Use OptimizationFunction to pass the derivatives or automatically generate them with one of the autodiff backends")
134134

135135
function _cb(trace)
136-
θ = cache.opt isa Optim.NelderMead ? decompose_trace(trace).metadata["centroid"] :
137-
decompose_trace(trace).metadata["x"]
136+
metadata = decompose_trace(trace).metadata
137+
θ = metadata[cache.opt isa Optim.NelderMead ? "centroid" : "x"]
138138
opt_state = Optimization.OptimizationState(iter = trace.iteration,
139139
u = θ,
140140
objective = x[1],
141+
grad = get(metadata, "g(x)", nothing),
142+
hess = get(metadata, "h(x)", nothing),
141143
original = trace)
142144
cb_call = cache.callback(opt_state, x...)
143145
if !(cb_call isa Bool)
@@ -252,12 +254,15 @@ function SciMLBase.__solve(cache::OptimizationCache{
252254
cur, state = iterate(cache.data)
253255

254256
function _cb(trace)
257+
metadata = decompose_trace(trace).metadata
255258
θ = !(cache.opt isa Optim.SAMIN) && cache.opt.method == Optim.NelderMead() ?
256-
decompose_trace(trace).metadata["centroid"] :
257-
decompose_trace(trace).metadata["x"]
259+
metadata["centroid"] :
260+
metadata["x"]
258261
opt_state = Optimization.OptimizationState(iter = trace.iteration,
259262
u = θ,
260263
objective = x[1],
264+
grad = get(metadata, "g(x)", nothing),
265+
hess = get(metadata, "h(x)", nothing),
261266
original = trace)
262267
cb_call = cache.callback(opt_state, x...)
263268
if !(cb_call isa Bool)
@@ -341,8 +346,11 @@ function SciMLBase.__solve(cache::OptimizationCache{
341346
cur, state = iterate(cache.data)
342347

343348
function _cb(trace)
349+
metadata = decompose_trace(trace).metadata
344350
opt_state = Optimization.OptimizationState(iter = trace.iteration,
345-
u = decompose_trace(trace).metadata["x"],
351+
u = metadata["x"],
352+
grad = get(metadata, "g(x)", nothing),
353+
hess = get(metadata, "h(x)", nothing),
346354
objective = x[1],
347355
original = trace)
348356
cb_call = cache.callback(opt_state, x...)

0 commit comments

Comments
 (0)