Skip to content

Commit 59ff422

Browse files
zstevedevmotion
andauthored
Unify solve! for SinkhornSolver and SinkhornBarycenterSolver (#123)
* unify solve! for SinkhornSolver and SinkhornBarycenterSolver * format * reorganize and format * removed examples/benchmark * Update test/entropic/sinkhorn_barycenter.jl Co-authored-by: David Widmann <[email protected]> Co-authored-by: David Widmann <[email protected]>
1 parent 9b9b803 commit 59ff422

File tree

6 files changed

+159
-170
lines changed

6 files changed

+159
-170
lines changed

src/OptimalTransport.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ include("entropic/sinkhorn_epsscaling.jl")
4040
include("entropic/sinkhorn_unbalanced.jl")
4141
include("entropic/sinkhorn_barycenter.jl")
4242
include("entropic/sinkhorn_barycenter_gibbs.jl")
43+
include("entropic/sinkhorn_solve.jl")
4344

4445
include("quadratic.jl")
4546
include("quadratic_newton.jl")

src/entropic/sinkhorn.jl

Lines changed: 19 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -89,123 +89,35 @@ function build_convergence_cache(
8989
)
9090
end
9191

92-
# Sinkhorn algorithm
92+
# Sinkhorn algorithm steps (see solve!)
93+
function init_step!(solver::SinkhornSolver)
94+
return A_batched_mul_B!(solver.cache.Kv, solver.cache.K, solver.cache.v)
95+
end
9396

94-
function solve!(solver::SinkhornSolver)
95-
# unpack solver
97+
function step!(solver::SinkhornSolver, iter::Int)
9698
μ = solver.source
9799
ν = solver.target
98-
atol = solver.atol
99-
rtol = solver.rtol
100-
maxiter = solver.maxiter
101-
check_convergence = solver.check_convergence
102100
cache = solver.cache
103-
convergence_cache = solver.convergence_cache
104-
105-
# unpack cache
106101
u = cache.u
107102
v = cache.v
108-
K = cache.K
109103
Kv = cache.Kv
104+
K = cache.K
110105

111-
A_batched_mul_B!(Kv, K, v)
112-
113-
isconverged = false
114-
to_check_step = check_convergence
115-
for iter in 1:maxiter
116-
# computations before the Sinkhorn iteration (e.g., absorption step)
117-
prestep!(solver, iter)
118-
119-
# perform Sinkhorn iteration
120-
u .= μ ./ Kv
121-
At_batched_mul_B!(v, K, u)
122-
v .= ν ./ v
123-
A_batched_mul_B!(Kv, K, v)
124-
125-
# check source marginal
126-
# always check convergence after the final iteration
127-
to_check_step -= 1
128-
if to_check_step == 0 || iter == maxiter
129-
# reset counter
130-
to_check_step = check_convergence
131-
132-
isconverged, abserror = OptimalTransport.check_convergence(
133-
μ, u, Kv, convergence_cache, atol, rtol
134-
)
135-
@debug string(solver.alg) *
136-
" (" *
137-
string(iter) *
138-
"/" *
139-
string(maxiter) *
140-
": absolute error of source marginal = " *
141-
string(maximum(abserror))
142-
143-
if isconverged
144-
@debug "$(solver.alg) ($iter/$maxiter): converged"
145-
break
146-
end
147-
end
148-
end
149-
150-
if !isconverged
151-
@warn "$(solver.alg) ($maxiter/$maxiter): not converged"
152-
end
153-
154-
return nothing
155-
end
156-
157-
# for single inputs
158-
function check_convergence(
159-
μ::AbstractVector,
160-
u::AbstractVector,
161-
Kv::AbstractVector,
162-
cache::SinkhornConvergenceCache,
163-
atol::Real,
164-
rtol::Real,
165-
)
166-
# unpack
167-
tmp = cache.tmp
168-
norm_μ = cache.norm_source
169-
170-
# do not overwrite `Kv` but reuse it for computing `u` if not converged
171-
tmp .= u .* Kv
172-
norm_uKv = sum(abs, tmp)
173-
tmp .= abs.(μ .- tmp)
174-
norm_diff = sum(tmp)
175-
176-
isconverged = norm_diff < max(atol, rtol * max(norm_μ, norm_uKv))
177-
178-
return isconverged, norm_diff
106+
u .= μ ./ Kv
107+
At_batched_mul_B!(v, K, u)
108+
v .= ν ./ v
109+
return A_batched_mul_B!(Kv, K, v)
179110
end
180111

181-
# for batches
182-
function check_convergence(
183-
μ::AbstractVecOrMat,
184-
u::AbstractMatrix,
185-
Kv::AbstractMatrix,
186-
cache::SinkhornBatchConvergenceCache,
187-
atol::Real,
188-
rtol::Real,
189-
)
190-
# unpack
191-
tmp = cache.tmp
192-
tmp2 = cache.tmp2
193-
norm_μ = cache.norm_source
194-
norm_uKv = cache.norm_uKv
195-
norm_diff = cache.norm_diff
196-
isconverged = cache.isconverged
197-
198-
# do not overwrite `Kv` but reuse it for computing `u` if not converged
199-
tmp .= u .* Kv
200-
tmp2 .= abs.(tmp)
201-
sum!(norm_uKv, tmp2)
202-
tmp .= abs.(μ .- tmp)
203-
sum!(norm_diff, tmp)
204-
205-
# check stopping criterion
206-
@. isconverged = norm_diff < max(atol, rtol * max(norm_μ, norm_uKv))
207-
208-
return all(isconverged), norm_diff
112+
function check_convergence(solver::SinkhornSolver)
113+
return OptimalTransport.check_convergence(
114+
solver.source,
115+
solver.cache.u,
116+
solver.cache.Kv,
117+
solver.convergence_cache,
118+
solver.atol,
119+
solver.rtol,
120+
)
209121
end
210122

211123
# API

src/entropic/sinkhorn_barycenter.jl

Lines changed: 0 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -48,69 +48,6 @@ function build_solver(
4848
return solver
4949
end
5050

51-
function solve!(solver::SinkhornBarycenterSolver)
52-
# unpack solver
53-
μ = solver.source
54-
w = solver.w
55-
atol = solver.atol
56-
rtol = solver.rtol
57-
58-
maxiter = solver.maxiter
59-
check_convergence = solver.check_convergence
60-
cache = solver.cache
61-
convergence_cache = solver.convergence_cache
62-
63-
# unpack cache
64-
u = cache.u
65-
v = cache.v
66-
K = cache.K
67-
Kv = cache.Kv
68-
a = cache.a
69-
70-
isconverged = false
71-
to_check_step = check_convergence
72-
A_batched_mul_B!(Kv, K, v)
73-
for iter in 1:maxiter
74-
# prestep if needed (not used for SinkhornBarycenterSolver{SinkhornGibbs})
75-
prestep!(solver, iter)
76-
77-
# Sinkhorn iteration
78-
a .= prod(Kv' .^ w; dims=1)' # TODO: optimise
79-
u .= a ./ Kv
80-
At_batched_mul_B!(v, K, u)
81-
v .= μ ./ v
82-
A_batched_mul_B!(Kv, K, v)
83-
84-
# decrement check marginal step
85-
to_check_step -= 1
86-
# check convergence
87-
if to_check_step == 0 || iter == maxiter
88-
# reset counter
89-
to_check_step = check_convergence
90-
91-
isconverged, abserror = OptimalTransport.check_convergence(
92-
a, u, Kv, convergence_cache, atol, rtol
93-
)
94-
@debug string(solver.alg) *
95-
" (" *
96-
string(iter) *
97-
"/" *
98-
string(maxiter) *
99-
": absolute error of source marginal = " *
100-
string(maximum(abserror))
101-
102-
if isconverged
103-
@debug "$(solver.alg) ($iter/$maxiter): converged"
104-
break
105-
end
106-
end
107-
end
108-
if !isconverged
109-
@warn "$(solver.alg) ($maxiter/$maxiter): not converged"
110-
end
111-
return nothing
112-
end
113-
11451
"""
11552
sinkhorn_barycenter(μ, C, ε, w, alg = SinkhornGibbs(); kwargs...)
11653

src/entropic/sinkhorn_barycenter_gibbs.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,41 @@ function build_cache(
3030
return SinkhornBarycenterGibbsCache(u, v, K, Kv, a)
3131
end
3232

33+
# Sinkhorn algorithm steps (see solve!)
3334
prestep!(::SinkhornBarycenterSolver{SinkhornGibbs}, ::Int) = nothing
3435

36+
function init_step!(solver::SinkhornBarycenterSolver{SinkhornGibbs})
37+
return A_batched_mul_B!(solver.cache.Kv, solver.cache.K, solver.cache.v)
38+
end
39+
40+
function step!(solver::SinkhornBarycenterSolver{SinkhornGibbs}, iter::Int)
41+
μ = solver.source
42+
w = solver.w
43+
cache = solver.cache
44+
u = cache.u
45+
v = cache.v
46+
Kv = cache.Kv
47+
K = cache.K
48+
a = cache.a
49+
50+
a .= prod(Kv' .^ w; dims=1)' # TODO: optimise
51+
u .= a ./ Kv
52+
At_batched_mul_B!(v, K, u)
53+
v .= μ ./ v
54+
return A_batched_mul_B!(Kv, K, v)
55+
end
56+
57+
function check_convergence(solver::SinkhornBarycenterSolver{SinkhornGibbs})
58+
return OptimalTransport.check_convergence(
59+
solver.cache.a,
60+
solver.cache.u,
61+
solver.cache.Kv,
62+
solver.convergence_cache,
63+
solver.atol,
64+
solver.rtol,
65+
)
66+
end
67+
3568
function solution(solver::SinkhornBarycenterSolver{SinkhornGibbs})
3669
cache = solver.cache
3770
return cache.u[:, 1] .* cache.Kv[:, 1]

src/entropic/sinkhorn_gibbs.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ function sinkhorn(
8181
kwargs...,
8282
)
8383
end
84+
8485
function sinkhorn2(
8586
μ,
8687
ν,

src/entropic/sinkhorn_solve.jl

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
# Convergence checks
2+
#
3+
# for single inputs
4+
function check_convergence(
5+
μ::AbstractVector,
6+
u::AbstractVector,
7+
Kv::AbstractVector,
8+
cache::SinkhornConvergenceCache,
9+
atol::Real,
10+
rtol::Real,
11+
)
12+
# unpack
13+
tmp = cache.tmp
14+
norm_μ = cache.norm_source
15+
16+
# do not overwrite `Kv` but reuse it for computing `u` if not converged
17+
tmp .= u .* Kv
18+
norm_uKv = sum(abs, tmp)
19+
tmp .= abs.(μ .- tmp)
20+
norm_diff = sum(tmp)
21+
22+
isconverged = norm_diff < max(atol, rtol * max(norm_μ, norm_uKv))
23+
24+
return isconverged, norm_diff
25+
end
26+
27+
# for batches
28+
function check_convergence(
29+
μ::AbstractVecOrMat,
30+
u::AbstractMatrix,
31+
Kv::AbstractMatrix,
32+
cache::SinkhornBatchConvergenceCache,
33+
atol::Real,
34+
rtol::Real,
35+
)
36+
# unpack
37+
tmp = cache.tmp
38+
tmp2 = cache.tmp2
39+
norm_μ = cache.norm_source
40+
norm_uKv = cache.norm_uKv
41+
norm_diff = cache.norm_diff
42+
isconverged = cache.isconverged
43+
44+
# do not overwrite `Kv` but reuse it for computing `u` if not converged
45+
tmp .= u .* Kv
46+
tmp2 .= abs.(tmp)
47+
sum!(norm_uKv, tmp2)
48+
tmp .= abs.(μ .- tmp)
49+
sum!(norm_diff, tmp)
50+
51+
# check stopping criterion
52+
@. isconverged = norm_diff < max(atol, rtol * max(norm_μ, norm_uKv))
53+
54+
return all(isconverged), norm_diff
55+
end
56+
57+
# Common solve! operation
58+
function solve!(solver::Union{SinkhornSolver,SinkhornBarycenterSolver})
59+
# unpack solver
60+
atol = solver.atol
61+
rtol = solver.rtol
62+
maxiter = solver.maxiter
63+
check_convergence = solver.check_convergence
64+
cache = solver.cache
65+
convergence_cache = solver.convergence_cache
66+
67+
isconverged = false
68+
to_check_step = check_convergence
69+
# initial step if needed
70+
init_step!(solver)
71+
for iter in 1:maxiter
72+
# computations before the Sinkhorn iteration (e.g., absorption step)
73+
prestep!(solver, iter)
74+
# perform Sinkhorn iteration
75+
step!(solver, iter)
76+
77+
# check source marginal
78+
# always check convergence after the final iteration
79+
to_check_step -= 1
80+
if to_check_step == 0 || iter == maxiter
81+
# reset counter
82+
to_check_step = check_convergence
83+
84+
isconverged, abserror = OptimalTransport.check_convergence(solver)
85+
@debug string(solver.alg) *
86+
" (" *
87+
string(iter) *
88+
"/" *
89+
string(maxiter) *
90+
": absolute error of source marginal = " *
91+
string(maximum(abserror))
92+
93+
if isconverged
94+
@debug "$(solver.alg) ($iter/$maxiter): converged"
95+
break
96+
end
97+
end
98+
end
99+
100+
if !isconverged
101+
@warn "$(solver.alg) ($maxiter/$maxiter): not converged"
102+
end
103+
104+
return nothing
105+
end

0 commit comments

Comments
 (0)