@@ -89,123 +89,35 @@ function build_convergence_cache(
89
89
)
90
90
end
91
91
92
- # Sinkhorn algorithm
92
+ # Sinkhorn algorithm steps (see solve!)
93
+ function init_step! (solver:: SinkhornSolver )
94
+ return A_batched_mul_B! (solver. cache. Kv, solver. cache. K, solver. cache. v)
95
+ end
93
96
94
- function solve! (solver:: SinkhornSolver )
95
- # unpack solver
97
+ function step! (solver:: SinkhornSolver , iter:: Int )
96
98
μ = solver. source
97
99
ν = solver. target
98
- atol = solver. atol
99
- rtol = solver. rtol
100
- maxiter = solver. maxiter
101
- check_convergence = solver. check_convergence
102
100
cache = solver. cache
103
- convergence_cache = solver. convergence_cache
104
-
105
- # unpack cache
106
101
u = cache. u
107
102
v = cache. v
108
- K = cache. K
109
103
Kv = cache. Kv
104
+ K = cache. K
110
105
111
- A_batched_mul_B! (Kv, K, v)
112
-
113
- isconverged = false
114
- to_check_step = check_convergence
115
- for iter in 1 : maxiter
116
- # computations before the Sinkhorn iteration (e.g., absorption step)
117
- prestep! (solver, iter)
118
-
119
- # perform Sinkhorn iteration
120
- u .= μ ./ Kv
121
- At_batched_mul_B! (v, K, u)
122
- v .= ν ./ v
123
- A_batched_mul_B! (Kv, K, v)
124
-
125
- # check source marginal
126
- # always check convergence after the final iteration
127
- to_check_step -= 1
128
- if to_check_step == 0 || iter == maxiter
129
- # reset counter
130
- to_check_step = check_convergence
131
-
132
- isconverged, abserror = OptimalTransport. check_convergence (
133
- μ, u, Kv, convergence_cache, atol, rtol
134
- )
135
- @debug string (solver. alg) *
136
- " (" *
137
- string (iter) *
138
- " /" *
139
- string (maxiter) *
140
- " : absolute error of source marginal = " *
141
- string (maximum (abserror))
142
-
143
- if isconverged
144
- @debug " $(solver. alg) ($iter /$maxiter ): converged"
145
- break
146
- end
147
- end
148
- end
149
-
150
- if ! isconverged
151
- @warn " $(solver. alg) ($maxiter /$maxiter ): not converged"
152
- end
153
-
154
- return nothing
155
- end
156
-
157
- # for single inputs
158
- function check_convergence (
159
- μ:: AbstractVector ,
160
- u:: AbstractVector ,
161
- Kv:: AbstractVector ,
162
- cache:: SinkhornConvergenceCache ,
163
- atol:: Real ,
164
- rtol:: Real ,
165
- )
166
- # unpack
167
- tmp = cache. tmp
168
- norm_μ = cache. norm_source
169
-
170
- # do not overwrite `Kv` but reuse it for computing `u` if not converged
171
- tmp .= u .* Kv
172
- norm_uKv = sum (abs, tmp)
173
- tmp .= abs .(μ .- tmp)
174
- norm_diff = sum (tmp)
175
-
176
- isconverged = norm_diff < max (atol, rtol * max (norm_μ, norm_uKv))
177
-
178
- return isconverged, norm_diff
106
+ u .= μ ./ Kv
107
+ At_batched_mul_B! (v, K, u)
108
+ v .= ν ./ v
109
+ return A_batched_mul_B! (Kv, K, v)
179
110
end
180
111
181
- # for batches
182
- function check_convergence (
183
- μ:: AbstractVecOrMat ,
184
- u:: AbstractMatrix ,
185
- Kv:: AbstractMatrix ,
186
- cache:: SinkhornBatchConvergenceCache ,
187
- atol:: Real ,
188
- rtol:: Real ,
189
- )
190
- # unpack
191
- tmp = cache. tmp
192
- tmp2 = cache. tmp2
193
- norm_μ = cache. norm_source
194
- norm_uKv = cache. norm_uKv
195
- norm_diff = cache. norm_diff
196
- isconverged = cache. isconverged
197
-
198
- # do not overwrite `Kv` but reuse it for computing `u` if not converged
199
- tmp .= u .* Kv
200
- tmp2 .= abs .(tmp)
201
- sum! (norm_uKv, tmp2)
202
- tmp .= abs .(μ .- tmp)
203
- sum! (norm_diff, tmp)
204
-
205
- # check stopping criterion
206
- @. isconverged = norm_diff < max (atol, rtol * max (norm_μ, norm_uKv))
207
-
208
- return all (isconverged), norm_diff
112
+ function check_convergence (solver:: SinkhornSolver )
113
+ return OptimalTransport. check_convergence (
114
+ solver. source,
115
+ solver. cache. u,
116
+ solver. cache. Kv,
117
+ solver. convergence_cache,
118
+ solver. atol,
119
+ solver. rtol,
120
+ )
209
121
end
210
122
211
123
# API
0 commit comments