@@ -9,7 +9,6 @@ export CMAEvolutionStrategyOpt
99struct CMAEvolutionStrategyOpt end
1010
1111SciMLBase. allowsbounds (:: CMAEvolutionStrategyOpt ) = true
12- SciMLBase. allowscallback (:: CMAEvolutionStrategyOpt ) = false # looks like `logger` kwarg can be used to pass it, so should be implemented
1312SciMLBase. supports_opt_cache_interface (opt:: CMAEvolutionStrategyOpt ) = true
1413
1514function __map_optimizer_args (prob:: OptimizationCache , opt:: CMAEvolutionStrategyOpt ;
@@ -23,7 +22,7 @@ function __map_optimizer_args(prob::OptimizationCache, opt::CMAEvolutionStrategy
2322 end
2423
2524 mapped_args = (; lower = prob. lb,
26- upper = prob. ub)
25+ upper = prob. ub, logger = CMAEvolutionStrategy . BasicLogger (prob . u0; verbosity = 0 , callback = callback) )
2726
2827 if ! isnothing (maxiters)
2928 mapped_args = (; mapped_args... , maxiter = maxiters)
@@ -74,12 +73,18 @@ function SciMLBase.__solve(cache::OptimizationCache{
7473
7574 cur, state = iterate (cache. data)
7675
77- function _cb (trace)
78- cb_call = cache. callback (decompose_trace (trace). metadata[" x" ], trace. value... )
76+ function _cb (opt, y, fvals, perm)
77+ curr_u = opt. logger. xbest[end ]
78+ opt_state = Optimization. OptimizationState (; iteration = length (opt. logger. fmedian),
79+ u = curr_u,
80+ objective = opt. logger. fbest[end ],
81+ solver_state = opt. logger)
82+
83+ cb_call = cache. callback (opt_state, x... )
7984 if ! (cb_call isa Bool)
8085 error (" The callback should return a boolean `halt` for whether to stop the optimization process." )
8186 end
82- cur, state = iterate (data, state)
87+ cur, state = iterate (cache . data, state)
8388 cb_call
8489 end
8590
@@ -100,11 +105,12 @@ function SciMLBase.__solve(cache::OptimizationCache{
100105 t1 = time ()
101106
102107 opt_ret = opt_res. stop. reason
103-
108+ stats = Optimization . OptimizationStats (; iterations = length (opt_res . logger . fmedian), time = t1 - t0, fevals = length (opt_res . logger . fmedian))
104109 SciMLBase. build_solution (cache, cache. opt,
105110 opt_res. logger. xbest[end ],
106111 opt_res. logger. fbest[end ]; original = opt_res,
107- retcode = opt_ret, solve_time = t1 - t0)
112+ retcode = opt_ret,
113+ stats = stats)
108114end
109115
110116end
0 commit comments