Skip to content

Commit 7bec3f3

Browse files
authored
precise central point for power mean (#860)
* precise central point for power mean * formatting * typo * remove unused branch
1 parent 01488be commit 7bec3f3

File tree

1 file changed

+21
-21
lines changed

1 file changed

+21
-21
lines changed

src/Cones/hypopowermean.jl

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -273,27 +273,27 @@ function dder3(cone::HypoPowerMean{T}, dir::AbstractVector{T}) where {T <: Real}
273273
return dder3
274274
end
275275

276-
# see analysis in
277-
# https://github.com/lkapelevich/HypatiaSupplements.jl/tree/master/centralpoints
278-
function get_central_ray_hypopowermean::Vector{T}) where {T <: Real}
279-
d = length(α)
280-
# predict w given α and d
281-
w = zeros(T, d)
282-
if d == 1
283-
w .= 1.306563
284-
elseif d == 2
285-
@. w = 1.0049885 + 0.2986276 * α
286-
elseif d <= 5
287-
@. w = 1.0040142949 - 0.0004885108 * d + 0.3016645951 * α
288-
elseif d <= 20
289-
@. w = 1.001168 - 4.547017e-05 * d + 3.032880e-01 * α
290-
elseif d <= 100
291-
@. w = 1.000069 - 5.469926e-07 * d + 3.074084e-01 * α
292-
else
293-
@. w = 1 + 3.086535e-01 * α
276+
function get_central_ray_hypopowermean::Vector{T}) where {T <: AbstractFloat}
277+
s = (T(5) - 1) / 2
278+
tol = sqrt(eps(T))
279+
maxiter = 2ceil(log2(-log2(tol)))
280+
counter = 0
281+
while counter < maxiter
282+
counter += 1
283+
step = _newton_ratio_powermean(s, α)
284+
s -= step
285+
if abs(step) < tol
286+
break
287+
end
294288
end
295-
# get u in closed form from w
296-
p = exp(sum(α_i * log(w_i) for (α_i, w_i) in zip(α, w)))
297-
u = p - p / d * sum(α_i / (abs2(w_i) - 1) for (α_i, w_i) in zip(α, w))
289+
counter == maxiter && error("Failed to compute initial point.")
290+
u = -√(1 - s)
291+
w = sqrt.(s .* α .+ 1)
298292
return (u, w)
299293
end
294+
295+
function _newton_ratio_powermean(s, α)
296+
logf = 2log(s) - log(1 - s) - sum(αi * log(αi * s + 1) for αi in α)
297+
dlogf = 2 / s + 1 / (1 - s) - sum(αi^2 / (αi * s + 1) for αi in α)
298+
return logf / dlogf
299+
end

0 commit comments

Comments
 (0)