Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ This modularity means that different HMC variants can be easily constructed by c
- Unit metric: `UnitEuclideanMetric(dim)`
- Diagonal metric: `DiagEuclideanMetric(dim)`
- Dense metric: `DenseEuclideanMetric(dim)`
- Rank update metric: `RankUpdateEuclideanMetric(dim)`

where `dim` is the dimensionality of the sampling space.

Expand Down
16 changes: 14 additions & 2 deletions src/AdvancedHMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,18 @@ module AdvancedHMC

using Statistics: mean, var, middle
using LinearAlgebra:
Symmetric, UpperTriangular, mul!, ldiv!, dot, I, diag, cholesky, UniformScaling
Symmetric,
UpperTriangular,
mul!,
ldiv!,
dot,
I,
diag,
cholesky,
UniformScaling,
Diagonal,
qr,
lmul!
using StatsFuns: logaddexp, logsumexp, loghalf
using Random: Random, AbstractRNG
using ProgressMeter: ProgressMeter
Expand Down Expand Up @@ -40,7 +51,8 @@ struct GaussianKinetic <: AbstractKinetic end
export GaussianKinetic

include("metric.jl")
export UnitEuclideanMetric, DiagEuclideanMetric, DenseEuclideanMetric
export UnitEuclideanMetric,
DiagEuclideanMetric, DenseEuclideanMetric, RankUpdateEuclideanMetric

include("hamiltonian.jl")
export Hamiltonian
Expand Down
19 changes: 19 additions & 0 deletions src/hamiltonian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,18 @@ function ∂H∂r(h::Hamiltonian{<:DenseEuclideanMetric,<:GaussianKinetic}, r::A
return M⁻¹ * r
end

function ∂H∂r(
h::Hamiltonian{<:RankUpdateEuclideanMetric,<:GaussianKinetic}, r::AbstractVecOrMat
)
(; M⁻¹) = h.metric
axes_M⁻¹ = __axes(M⁻¹)
axes_r = __axes(r)
(first(axes_M⁻¹) !== first(axes_r)) && throw(
ArgumentError("AxesMismatch: M⁻¹ has axes $(axes_M⁻¹) but r has axes $(axes_r)")
)
return M⁻¹ * r
end

# TODO (kai) make the order of θ and r consistent with neg_energy
# TODO (kai) add stricter types to block hamiltonian.jl#L37 from working on unknown metric/kinetic
# The gradient of a position-dependent Hamiltonian system depends on both θ and r.
Expand Down Expand Up @@ -165,6 +177,13 @@ function neg_energy(
return -dot(r, h.metric._temp) / 2
end

function neg_energy(
h::Hamiltonian{<:RankUpdateEuclideanMetric,<:GaussianKinetic}, r::T, θ::T
) where {T<:AbstractVecOrMat}
M⁻¹ = h.metric.M⁻¹
return -r' * M⁻¹ * r / 2
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

end

energy(args...) = -neg_energy(args...)

####
Expand Down
78 changes: 78 additions & 0 deletions src/metric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,68 @@ function Base.show(io::IO, dem::DenseEuclideanMetric)
return print(io, "DenseEuclideanMetric(diag=$(_string_M⁻¹(dem.M⁻¹)))")
end

"""
RankUpdateEuclideanMetric{T,M} <: AbstractMetric

A Gaussian Euclidean metric whose inverse is constructed by rank-updates.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
A Gaussian Euclidean metric whose inverse is constructed by rank-updates.
A Gaussian Euclidean metric whose inverse is constructed by low-rank updates to a diagonal matrix.


# Constructors

RankUpdateEuclideanMetric(n::Int)

Construct a Gaussian Euclidean metric of size `(n, n)` with inverse of `M⁻¹`.

# Example

```julia
julia> RankUpdateEuclideanMetric(3)
RankUpdateEuclideanMetric(diag=[1.0, 1.0, 1.0])
```
"""
struct RankUpdateEuclideanMetric{T,AM<:AbstractVecOrMat{T},AB,AD,F} <: AbstractMetric
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

More a question than a request: Is there a reason why AM can be a Vector? Also, is it intentional that AB and AD don't have to have the same element type?

# Diagnal of the inverse of the mass matrix
M⁻¹::AM
B::AB
D::AD
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
D::AD
D::AD
"Woodbury factorisation of M⁻¹ + B D transpose(B)"

factorization::F
end

function woodbury_factorize(A, B, D)
cholA = cholesky(A isa Diagonal ? A : Symmetric(A))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pathfinder's implementation allows for arbitrary PD A but for the use-cases in Pathfinder and Bales's paper, diagonal A is sufficient. Since you've already documented A as diagonal, perhaps you can drop this check.

U = cholA.U
Q, R = qr(U' \ B)
V = cholesky(Symmetric(muladd(R, D * R', I))).U
return (U=U, Q=Q, V=V)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return (U=U, Q=Q, V=V)
return (; U, Q, V)

end

function RankUpdateEuclideanMetric(n::Int)
M⁻¹ = Diagonal(ones(n))
B = zeros(n, 0)
D = zeros(0, 0)
factorization = woodbury_factorize(M⁻¹, B, D)
return RankUpdateEuclideanMetric(M⁻¹, B, D, factorization)
end
Comment on lines +146 to +152
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: this is probably fine for now, but there are multiple ways to form the identity matrix here, and later it might be better to initialize a different way (e.g. for tuning a covariance matrix for factor analysis, B=0 and D=0 prohibits convergence)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know it breaks with the API of other metric constructors, but I think you want to be able to initialize the rank of the update as well. Otherwise you have no way to fix the rank for future tuning algorithms.

function RankUpdateEuclideanMetric(::Type{T}, n::Int) where {T}
M⁻¹ = Diagonal(ones(T, n))
B = Matrix{T}(undef, n, 0)
D = Matrix{T}(undef, 0, 0)
factorization = woodbury_factorize(M⁻¹, B, D)
return RankUpdateEuclideanMetric(M⁻¹, B, D, factorization)
end
function RankUpdateEuclideanMetric(::Type{T}, sz::Tuple{Int}) where {T}
return RankUpdateEuclideanMetric(T, first(sz))
end
RankUpdateEuclideanMetric(sz::Tuple{Int}) = RankUpdateEuclideanMetric(Float64, sz)

AdvancedHMC.renew(::RankUpdateEuclideanMetric, M⁻¹) = RankUpdateEuclideanMetric(M⁻¹)

Base.size(metric::RankUpdateEuclideanMetric, dim...) = size(metric.M⁻¹.diag, dim...)

function Base.show(io::IO, metric::RankUpdateEuclideanMetric)
print(io, "RankUpdateEuclideanMetric(diag=$(diag(metric.M⁻¹)))")
return nothing
end

# `rand` functions for `metric` types.

function rand_momentum(
Expand Down Expand Up @@ -131,3 +193,19 @@ function rand_momentum(
ldiv!(metric.cholM⁻¹, r)
return r
end

function rand_momentum(
rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}},
metric::RankUpdateEuclideanMetric{T},
kinetic::GaussianKinetic,
::AbstractVecOrMat,
) where {T}
M⁻¹ = metric.M⁻¹
r = _randn(rng, T, size(M⁻¹.diag)...)
F = metric.factorization
k = min(size(F.U, 1), size(F.V, 1))
@views ldiv!(F.V, r isa AbstractVector ? r[1:k] : r[1:k, :])
lmul!(F.Q, r)
ldiv!(F.U, r)
return r
end
9 changes: 7 additions & 2 deletions test/metric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ using ReTest, Random, AdvancedHMC
UnitEuclideanMetric((D, n_chains)),
DiagEuclideanMetric((D, n_chains)),
# DenseEuclideanMetric((D, n_chains)) # not supported ATM
# RankUpdateEuclideanMetric((D, n_chains)) # not supported ATM
]
r = AdvancedHMC.rand_momentum(rng, metric, GaussianKinetic(), θ)
all_same = true
Expand All @@ -25,8 +26,12 @@ using ReTest, Random, AdvancedHMC
rng = MersenneTwister(1)
θ = randn(rng, D)
ℓπ(θ) = 1
for metric in
[UnitEuclideanMetric(1), DiagEuclideanMetric(1), DenseEuclideanMetric(1)]
for metric in [
UnitEuclideanMetric(1),
DiagEuclideanMetric(1),
DenseEuclideanMetric(1),
RankUpdateEuclideanMetric(1),
]
h = Hamiltonian(metric, ℓπ, ℓπ)
h = AdvancedHMC.resize(h, θ)
@test size(h.metric) == size(θ)
Expand Down
Loading