diff --git a/src/exact.jl b/src/exact.jl index 82ea71d..25a2cba 100644 --- a/src/exact.jl +++ b/src/exact.jl @@ -253,7 +253,7 @@ end """ ot_plan(c, μ::DiscreteNonParametric, ν::DiscreteNonParametric) -Compute the optimal transport cost for the Monge-Kantorovich problem with univariate +Compute the optimal transport plan for the Monge-Kantorovich problem with univariate discrete distributions `μ` and `ν` as source and target marginals and cost function `c` of the form ``c(x, y) = h(|x - y|)`` where ``h`` is a convex function. @@ -308,6 +308,81 @@ function ot_cost(c, μ::DiscreteNonParametric, ν::DiscreteNonParametric; plan=n return _ot_cost(c, μ, ν, plan) end +""" + ot_cost( + c, + usupport::AbstractVector, + vsupport::AbstractVector, + uprobs::AbstractVector{<:Real}=fill(inv(length(usupport)), length(usupport)), + vprobs::AbstractVector{<:Real}=fill(inv(length(vsupport)), length(vsupport)) + ; plan=nothing + ) + +Compute the optimal transport cost for the Monge-Kantorovich problem with discrete +univariate distributions where `usupport` and `vsupport` are the vectors, +`uprobs` and `vprobs` are the probabilities (weights), +and cost function `c` +is of the form ``c(x, y) = h(|x - y|)`` where ``h`` is a convex function. + +In case `uprobs` and `vprobs` are not specified, it's attributed equal probability. + +A pre-computed optimal transport `plan` may be provided. + +See also: [`ot_plan`](@ref), [`emd2`](@ref) +""" +function ot_cost( + c, + usupport::AbstractVector{<:Real}, + vsupport::AbstractVector{<:Real}, + ; + uprobs::AbstractVector{<:Real}=fill(inv(length(usupport)), length(usupport)), + vprobs::AbstractVector{<:Real}=fill(inv(length(vsupport)), length(vsupport)), + plan=nothing, +) + μ = discretemeasure(usupport, uprobs) + ν = discretemeasure(vsupport, vprobs) + if plan === nothing + return _ot_cost(c, μ, ν, plan) + else + return _ot_cost(c, μ, ν, plan[sortperm(usupport), sortperm(vsupport)]) + end +end + +""" + ot_plan( + c, + usupport::AbstractVector, + vsupport::AbstractVector, + ;uprobs::AbstractVector{<:Real}=fill(inv(length(usupport)), length(usupport)), + vprobs::AbstractVector{<:Real}=fill(inv(length(vsupport)), length(vsupport)) + ) + +Compute the optimal transport plan for the Monge-Kantorovich problem with discrete +univariate distributions where `u` and `v` are the vectors, +`uprobs` and `vprobs` are the probabilities (weights), +and cost function `c` +is of the form ``c(x, y) = h(|x - y|)`` where ``h`` is a convex function. + +In case `uprobs` and `vprobs` are not specified, it's attributed equal probability. + +A pre-computed optimal transport `plan` may be provided. + +See also: [`ot_plan`](@ref), [`emd2`](@ref) +""" + +function ot_plan( + c, + usupport::AbstractVector{<:Real}, + vsupport::AbstractVector{<:Real}; + uprobs::AbstractVector{<:Real}=fill(inv(length(usupport)), length(usupport)), + vprobs::AbstractVector{<:Real}=fill(inv(length(vsupport)), length(vsupport)), +) + μ = discretemeasure(usupport, uprobs) + ν = discretemeasure(vsupport, vprobs) + γ = ot_plan(c, μ, ν) + return γ[invperm(sortperm(usupport)), invperm(sortperm(vsupport))] +end + # compute cost from scratch if no plan is provided function _ot_cost(c, μ::DiscreteNonParametric, ν::DiscreteNonParametric, ::Nothing) # unpack the probabilities of the two distributions diff --git a/test/exact.jl b/test/exact.jl index 26826a8..ef779ef 100644 --- a/test/exact.jl +++ b/test/exact.jl @@ -161,6 +161,58 @@ Random.seed!(100) c2 = @inferred(ot_cost(euclidean, μ2, ν2; plan=Matrix(γ))) @test c2 ≈ c end + + @testset "discrete case with vectors" begin + # random source and target marginal + m = 30 + uprobs = normalize!(rand(m), 1) + usupport = randn(m) + + n = 50 + vprobs = normalize!(rand(n), 1) + vsupport = randn(n) + + # compute OT plan + γ = @inferred( + ot_plan(euclidean, usupport, vsupport; uprobs=uprobs, vprobs=vprobs) + ) + @test γ isa SparseMatrixCSC + @test size(γ) == (m, n) + @test vec(sum(γ; dims=2)) ≈ uprobs + @test vec(sum(γ; dims=1)) ≈ vprobs + + # consistency checks + I, J, W = findnz(γ[sortperm(usupport), sortperm(vsupport)]) + @test all(w > zero(w) for w in W) + @test sum(W) ≈ 1 + @test sort(unique(I)) == 1:m + @test sort(unique(J)) == 1:n + @test sort(I .+ J) == 2:(m + n) + + # compute OT cost + c = @inferred( + ot_cost(euclidean, usupport, vsupport; uprobs=uprobs, vprobs=vprobs) + ) + + # compare with computation with explicit cost matrix + # DiscreteNonParametric sorts the support automatically, here we have to sort + # manually + C = pairwise(Euclidean(), usupport', vsupport'; dims=2) + c2 = emd2(uprobs, vprobs, C, Tulip.Optimizer()) + @test c2 ≈ c rtol = 1e-5 + + # do not use the probabilities of u and v check the use of the default + c = @inferred(ot_cost(euclidean, usupport, vsupport)) + + C = pairwise(Euclidean(), usupport', vsupport'; dims=2) + c2 = emd2(fill(1 / m, m), fill(1 / n, n), C, Tulip.Optimizer()) + @test c2 ≈ c rtol = 1e-5 + + γ = @inferred(ot_plan(euclidean, usupport, vsupport)) + # used + c2 = @inferred(ot_cost(euclidean, usupport, vsupport; plan=γ)) + @test c2 ≈ c + end end @testset "Multivariate Gaussians" begin diff --git a/test/wasserstein.jl b/test/wasserstein.jl index f39e6af..0cfcde9 100644 --- a/test/wasserstein.jl +++ b/test/wasserstein.jl @@ -2,6 +2,7 @@ using ExactOptimalTransport using Distances using Distributions +using LinearAlgebra using Random using Test @@ -36,7 +37,53 @@ Random.seed!(100) end end - @testset "wasserstein" begin + @testset "wasserstein 1D discrete" begin + m = 10 + n = 15 + usupport = randn(m) + uprobs = normalize!(rand(m), 1) + vsupport = rand(n) + vprobs = normalize!(rand(n), 1) + + for p in (1, 2, 3, randexp()), metric in (Euclidean(), TotalVariation()) + for _p in (p, Val(p)) + # without additional keyword arguments + w = wasserstein( + usupport, vsupport; p=_p, metric=metric, uprobs=uprobs, vprobs=vprobs + ) + @test w ≈ + ot_cost( + (x, y) -> metric(x, y)^p, + usupport, + vsupport; + uprobs=uprobs, + vprobs=vprobs, + )^(1 / p) + + w = wasserstein(usupport, vsupport; p=_p, metric=metric) + @test w ≈ ot_cost((x, y) -> metric(x, y)^p, usupport, vsupport)^(1 / p) + + # without passing the probabilities + T = ot_plan((x, y) -> metric(x, y)^p, usupport, vsupport) + w2 = wasserstein(usupport, vsupport; p=_p, metric=metric, plan=T) + @test w ≈ w2 + end + end + + # check that `Euclidean` is the default `metric` + for p in (1, 2, 3, randexp()), _p in (p, Val(p)) + w = wasserstein(usupport, vsupport; p=_p) + @test w ≈ wasserstein(usupport, vsupport; p=_p, metric=Euclidean()) + end + + # check that `Val(1)` is the default `p` + for metric in (Euclidean(), TotalVariation()) + w = wasserstein(usupport, vsupport; metric=metric) + @test w ≈ wasserstein(usupport, vsupport; p=Val(1), metric=metric) + end + end + + @testset "wasserstein continuous" begin μ = Normal(randn(), randexp()) ν = Normal(randn(), randexp()) for p in (1, 2, 3, randexp()), metric in (Euclidean(), TotalVariation())