Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 76 additions & 1 deletion src/exact.jl
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ end
"""
ot_plan(c, μ::DiscreteNonParametric, ν::DiscreteNonParametric)

Compute the optimal transport cost for the Monge-Kantorovich problem with univariate
Compute the optimal transport plan for the Monge-Kantorovich problem with univariate
discrete distributions `μ` and `ν` as source and target marginals and cost function `c`
of the form ``c(x, y) = h(|x - y|)`` where ``h`` is a convex function.

Expand Down Expand Up @@ -308,6 +308,81 @@ function ot_cost(c, μ::DiscreteNonParametric, ν::DiscreteNonParametric; plan=n
return _ot_cost(c, μ, ν, plan)
end

"""
ot_cost(
c,
usupport::AbstractVector,
vsupport::AbstractVector,
uprobs::AbstractVector{<:Real}=fill(inv(length(usupport)), length(usupport)),
vprobs::AbstractVector{<:Real}=fill(inv(length(vsupport)), length(vsupport))
; plan=nothing
)

Compute the optimal transport cost for the Monge-Kantorovich problem with discrete
univariate distributions where `usupport` and `vsupport` are the vectors,
`uprobs` and `vprobs` are the probabilities (weights),
and cost function `c`
is of the form ``c(x, y) = h(|x - y|)`` where ``h`` is a convex function.

In case `uprobs` and `vprobs` are not specified, it's attributed equal probability.

A pre-computed optimal transport `plan` may be provided.

See also: [`ot_plan`](@ref), [`emd2`](@ref)
"""
function ot_cost(
c,
usupport::AbstractVector{<:Real},
vsupport::AbstractVector{<:Real},
;
uprobs::AbstractVector{<:Real}=fill(inv(length(usupport)), length(usupport)),
vprobs::AbstractVector{<:Real}=fill(inv(length(vsupport)), length(vsupport)),
plan=nothing,
)
μ = discretemeasure(usupport, uprobs)
ν = discretemeasure(vsupport, vprobs)
if plan === nothing
return _ot_cost(c, μ, ν, plan)
else
return _ot_cost(c, μ, ν, plan[sortperm(usupport), sortperm(vsupport)])
end
end

"""
ot_plan(
c,
usupport::AbstractVector,
vsupport::AbstractVector,
;uprobs::AbstractVector{<:Real}=fill(inv(length(usupport)), length(usupport)),
vprobs::AbstractVector{<:Real}=fill(inv(length(vsupport)), length(vsupport))
)

Compute the optimal transport plan for the Monge-Kantorovich problem with discrete
univariate distributions where `u` and `v` are the vectors,
`uprobs` and `vprobs` are the probabilities (weights),
and cost function `c`
is of the form ``c(x, y) = h(|x - y|)`` where ``h`` is a convex function.

In case `uprobs` and `vprobs` are not specified, it's attributed equal probability.

A pre-computed optimal transport `plan` may be provided.

See also: [`ot_plan`](@ref), [`emd2`](@ref)
"""

function ot_plan(
c,
usupport::AbstractVector{<:Real},
vsupport::AbstractVector{<:Real};
uprobs::AbstractVector{<:Real}=fill(inv(length(usupport)), length(usupport)),
vprobs::AbstractVector{<:Real}=fill(inv(length(vsupport)), length(vsupport)),
)
μ = discretemeasure(usupport, uprobs)
ν = discretemeasure(vsupport, vprobs)
γ = ot_plan(c, μ, ν)
return γ[invperm(sortperm(usupport)), invperm(sortperm(vsupport))]
end

# compute cost from scratch if no plan is provided
function _ot_cost(c, μ::DiscreteNonParametric, ν::DiscreteNonParametric, ::Nothing)
# unpack the probabilities of the two distributions
Expand Down
52 changes: 52 additions & 0 deletions test/exact.jl
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,58 @@ Random.seed!(100)
c2 = @inferred(ot_cost(euclidean, μ2, ν2; plan=Matrix(γ)))
@test c2 ≈ c
end

@testset "discrete case with vectors" begin
# random source and target marginal
m = 30
uprobs = normalize!(rand(m), 1)
usupport = randn(m)

n = 50
vprobs = normalize!(rand(n), 1)
vsupport = randn(n)

# compute OT plan
γ = @inferred(
ot_plan(euclidean, usupport, vsupport; uprobs=uprobs, vprobs=vprobs)
)
@test γ isa SparseMatrixCSC
@test size(γ) == (m, n)
@test vec(sum(γ; dims=2)) ≈ uprobs
@test vec(sum(γ; dims=1)) ≈ vprobs

# consistency checks
I, J, W = findnz(γ[sortperm(usupport), sortperm(vsupport)])
@test all(w > zero(w) for w in W)
@test sum(W) ≈ 1
@test sort(unique(I)) == 1:m
@test sort(unique(J)) == 1:n
@test sort(I .+ J) == 2:(m + n)

# compute OT cost
c = @inferred(
ot_cost(euclidean, usupport, vsupport; uprobs=uprobs, vprobs=vprobs)
)

# compare with computation with explicit cost matrix
# DiscreteNonParametric sorts the support automatically, here we have to sort
# manually
C = pairwise(Euclidean(), usupport', vsupport'; dims=2)
c2 = emd2(uprobs, vprobs, C, Tulip.Optimizer())
@test c2 ≈ c rtol = 1e-5

# do not use the probabilities of u and v check the use of the default
c = @inferred(ot_cost(euclidean, usupport, vsupport))

C = pairwise(Euclidean(), usupport', vsupport'; dims=2)
c2 = emd2(fill(1 / m, m), fill(1 / n, n), C, Tulip.Optimizer())
@test c2 ≈ c rtol = 1e-5

γ = @inferred(ot_plan(euclidean, usupport, vsupport))
# used
c2 = @inferred(ot_cost(euclidean, usupport, vsupport; plan=γ))
@test c2 ≈ c
end
end

@testset "Multivariate Gaussians" begin
Expand Down
49 changes: 48 additions & 1 deletion test/wasserstein.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ using ExactOptimalTransport

using Distances
using Distributions
using LinearAlgebra

using Random
using Test
Expand Down Expand Up @@ -36,7 +37,53 @@ Random.seed!(100)
end
end

@testset "wasserstein" begin
@testset "wasserstein 1D discrete" begin
m = 10
n = 15
usupport = randn(m)
uprobs = normalize!(rand(m), 1)
vsupport = rand(n)
vprobs = normalize!(rand(n), 1)

for p in (1, 2, 3, randexp()), metric in (Euclidean(), TotalVariation())
for _p in (p, Val(p))
# without additional keyword arguments
w = wasserstein(
usupport, vsupport; p=_p, metric=metric, uprobs=uprobs, vprobs=vprobs
)
@test w ≈
ot_cost(
(x, y) -> metric(x, y)^p,
usupport,
vsupport;
uprobs=uprobs,
vprobs=vprobs,
)^(1 / p)

w = wasserstein(usupport, vsupport; p=_p, metric=metric)
@test w ≈ ot_cost((x, y) -> metric(x, y)^p, usupport, vsupport)^(1 / p)

# without passing the probabilities
T = ot_plan((x, y) -> metric(x, y)^p, usupport, vsupport)
w2 = wasserstein(usupport, vsupport; p=_p, metric=metric, plan=T)
@test w ≈ w2
end
end

# check that `Euclidean` is the default `metric`
for p in (1, 2, 3, randexp()), _p in (p, Val(p))
w = wasserstein(usupport, vsupport; p=_p)
@test w ≈ wasserstein(usupport, vsupport; p=_p, metric=Euclidean())
end

# check that `Val(1)` is the default `p`
for metric in (Euclidean(), TotalVariation())
w = wasserstein(usupport, vsupport; metric=metric)
@test w ≈ wasserstein(usupport, vsupport; p=Val(1), metric=metric)
end
end

@testset "wasserstein continuous" begin
μ = Normal(randn(), randexp())
ν = Normal(randn(), randexp())
for p in (1, 2, 3, randexp()), metric in (Euclidean(), TotalVariation())
Expand Down