Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 50 additions & 2 deletions src/Cones/hypogeomean.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ mutable struct HypoGeoMean{T <: Real} <: Cone{T}
feas_updated::Bool
grad_updated::Bool
hess_updated::Bool
hess_aux_updated::Bool
inv_hess_updated::Bool
hess_fact_updated::Bool
is_feas::Bool
Expand Down Expand Up @@ -49,6 +50,8 @@ end

use_heuristic_neighborhood(cone::HypoGeoMean) = false

reset_data(cone::HypoGeoMean) = (cone.feas_updated = cone.grad_updated = cone.hess_updated = cone.inv_hess_updated = cone.hess_fact_updated = cone.hess_aux_updated = false)

function setup_extra_data(cone::HypoGeoMean{T}) where {T <: Real}
dim = cone.dim
cone.hess = Symmetric(zeros(T, dim, dim), :U)
Expand Down Expand Up @@ -111,8 +114,17 @@ function update_grad(cone::HypoGeoMean)
return cone.grad
end

function update_hess_aux(cone::HypoGeoMean)
H = cone.hess.data
H[1, 1] = abs2(cone.grad[1])
@. @views H[1, 2:end] = -cone.wgeo / d / z
cone.hess_aux_updated = true
return
end

function update_hess(cone::HypoGeoMean)
@assert cone.grad_updated
update_hess_aux(cone)
u = cone.point[1]
@views w = cone.point[2:end]
z = cone.z
Expand All @@ -122,12 +134,12 @@ function update_hess(cone::HypoGeoMean)
constww = wgeoz * (1 + wgeozm1) + 1
H = cone.hess.data

H[1, 1] = abs2(cone.grad[1])
# H[1, 1] = abs2(cone.grad[1])
@inbounds for j in eachindex(w)
j1 = j + 1
wj = w[j]
wgeozwj = wgeoz / wj
H[1, j1] = -wgeozwj / z
# H[1, j1] = -wgeozwj / z
wgeozwj2 = wgeozwj * wgeozm1
@inbounds for i in 1:(j - 1)
H[i + 1, j1] = wgeozwj2 / w[i]
Expand Down Expand Up @@ -163,6 +175,42 @@ function hess_prod!(prod::AbstractVecOrMat{T}, arr::AbstractVecOrMat{T}, cone::H
return prod
end

# TODO figure out where to store temp_np_np, when to allocate
function hess_outer_prod!(prod::AbstractVecOrMat{T}, arr::AbstractVecOrMat{T}, cone::HypoGeoMean{T}) where T
@assert cone.grad_updated
update_hess_aux(cone)
u = cone.point[1]
@views w = cone.point[2:end]
z = cone.z
d = cone.wdim
wgeo = cone.wgeo
temp_np = zeros(T, size(arr 1)) # TODO
temp_np_np = zeros(T, size(arr)...) # TODO

@views A_u = arr[:, 1]
@views A_w = arr[:, 2:end]
A_wt = A_w'

@. temp_np = A_u / z
mul!(res, temp_np, temp_np')

@views mul!(temp_np, A_w, cone.hess[1, 2:end])
mul!(temp_np_np, temp_np, A_u')
@. prod += temp_np_np
@. prod += temp_np_np'

ldiv!(temp_np, Diagonal(w), A_w)
@. temp_np /= d * z
mul!(prod, temp_np, temp_np', wgeo * u, true)

temp_sd_np = zeros(T, d, size(arr, 1)) # TODO
@. @views temp_sd_np = A_wt / w

mul!(prod, temp_sd_np', temp_sd_np, wgeo / d / z + 1, true)

return prod
end

function update_inv_hess(cone::HypoGeoMean{T}) where T
@assert !cone.inv_hess_updated
u = cone.point[1]
Expand Down