@@ -41,7 +41,7 @@ def sinkhorn(a,b, M, reg,method='sinkhorn', numItermax = 1000, stopThr=1e-9, ver
4141 Regularization term >0
4242 method : str
4343 method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or
44- 'sinkhorn_epsilon_scaling', see those function for specific parameters
44+ 'sinkhorn_epsilon_scaling', see those function for specific parameters
4545 numItermax : int, optional
4646 Max number of iterations
4747 stopThr : float, optional
@@ -91,7 +91,7 @@ def sinkhorn(a,b, M, reg,method='sinkhorn', numItermax = 1000, stopThr=1e-9, ver
9191 ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling [9][10]
9292
9393 """
94-
94+ 9595 if method .lower ()== 'sinkhorn' :
9696 sink = lambda : sinkhorn_knopp (a ,b , M , reg ,numItermax = numItermax ,
9797 stopThr = stopThr , verbose = verbose , log = log ,** kwargs )
@@ -100,15 +100,119 @@ def sinkhorn(a,b, M, reg,method='sinkhorn', numItermax = 1000, stopThr=1e-9, ver
100100 stopThr = stopThr , verbose = verbose , log = log , ** kwargs )
101101 elif method .lower ()== 'sinkhorn_epsilon_scaling' :
102102 sink = lambda : sinkhorn_epsilon_scaling (a ,b , M , reg ,numItermax = numItermax ,
103- stopThr = stopThr , verbose = verbose , log = log , ** kwargs )
103+ stopThr = stopThr , verbose = verbose , log = log , ** kwargs )
104104 else :
105105 print ('Warning : unknown method using classic Sinkhorn Knopp' )
106106 sink = lambda : sinkhorn_knopp (a ,b , M , reg , ** kwargs )
107-
107+ 108108 return sink ()
109+ 110+ def sinkhorn2 (a ,b , M , reg ,method = 'sinkhorn' , numItermax = 1000 , stopThr = 1e-9 , verbose = False , log = False ,** kwargs ):
111+ u"""
112+ Solve the entropic regularization optimal transport problem and return the loss
113+
114+ The function solves the following optimization problem:
115+
116+ .. math::
117+ W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
118+
119+ s.t. \gamma 1 = a
120+
121+ \gamma^T 1= b
122+
123+ \gamma\geq 0
124+ where :
125+
126+ - M is the (ns,nt) metric cost matrix
127+ - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
128+ - a and b are source and target weights (sum to 1)
129+
130+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_
131+
132+
133+ Parameters
134+ ----------
135+ a : np.ndarray (ns,)
136+ samples weights in the source domain
137+ b : np.ndarray (nt,) or np.ndarray (nt,nbb)
138+ samples in the target domain, compute sinkhorn with multiple targets
139+ and fixed M if b is a matrix (return OT loss + dual variables in log)
140+ M : np.ndarray (ns,nt)
141+ loss matrix
142+ reg : float
143+ Regularization term >0
144+ method : str
145+ method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or
146+ 'sinkhorn_epsilon_scaling', see those function for specific parameters
147+ numItermax : int, optional
148+ Max number of iterations
149+ stopThr : float, optional
150+ Stop threshol on error (>0)
151+ verbose : bool, optional
152+ Print information along iterations
153+ log : bool, optional
154+ record log if True
155+
156+
157+ Returns
158+ -------
159+ W : (nt) ndarray or float
160+ Optimal transportation matrix for the given parameters
161+ log : dict
162+ log dictionary return only if log==True in parameters
163+
164+ Examples
165+ --------
166+
167+ >>> import ot
168+ >>> a=[.5,.5]
169+ >>> b=[.5,.5]
170+ >>> M=[[0.,1.],[1.,0.]]
171+ >>> ot.sinkhorn2(a,b,M,1)
172+ array([ 0.26894142])
109173
110174
111175
176+ References
177+ ----------
178+
179+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
180+
181+ .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
182+
183+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
184+
185+
186+
187+ See Also
188+ --------
189+ ot.lp.emd : Unregularized OT
190+ ot.optim.cg : General regularized OT
191+ ot.bregman.sinkhorn_knopp : Classic Sinkhorn [2]
192+ ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10]
193+ ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling [9][10]
194+
195+ """
196+ 197+ if method .lower ()== 'sinkhorn' :
198+ sink = lambda : sinkhorn_knopp (a ,b , M , reg ,numItermax = numItermax ,
199+ stopThr = stopThr , verbose = verbose , log = log ,** kwargs )
200+ elif method .lower ()== 'sinkhorn_stabilized' :
201+ sink = lambda : sinkhorn_stabilized (a ,b , M , reg ,numItermax = numItermax ,
202+ stopThr = stopThr , verbose = verbose , log = log , ** kwargs )
203+ elif method .lower ()== 'sinkhorn_epsilon_scaling' :
204+ sink = lambda : sinkhorn_epsilon_scaling (a ,b , M , reg ,numItermax = numItermax ,
205+ stopThr = stopThr , verbose = verbose , log = log , ** kwargs )
206+ else :
207+ print ('Warning : unknown method using classic Sinkhorn Knopp' )
208+ sink = lambda : sinkhorn_knopp (a ,b , M , reg , ** kwargs )
209+ 210+ b = np .asarray (b ,dtype = np .float64 )
211+ if len (b .shape )< 2 :
212+ b = b .reshape ((- 1 ,1 ))
213+ 214+ return sink ()
215+ 112216
113217def sinkhorn_knopp (a ,b , M , reg , numItermax = 1000 , stopThr = 1e-9 , verbose = False , log = False ,** kwargs ):
114218 """
@@ -189,23 +293,23 @@ def sinkhorn_knopp(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False,
189293 a = np .asarray (a ,dtype = np .float64 )
190294 b = np .asarray (b ,dtype = np .float64 )
191295 M = np .asarray (M ,dtype = np .float64 )
192-
296+ 193297
194298 if len (a )== 0 :
195299 a = np .ones ((M .shape [0 ],),dtype = np .float64 )/ M .shape [0 ]
196300 if len (b )== 0 :
197301 b = np .ones ((M .shape [1 ],),dtype = np .float64 )/ M .shape [1 ]
198-
302+ 199303
200304 # init data
201305 Nini = len (a )
202306 Nfin = len (b )
203-
307+ 204308 if len (b .shape )> 1 :
205309 nbb = b .shape [1 ]
206310 else :
207311 nbb = 0
208-
312+ 209313
210314 if log :
211315 log = {'err' :[]}
@@ -217,7 +321,7 @@ def sinkhorn_knopp(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False,
217321 else :
218322 u = np .ones (Nini )/ Nini
219323 v = np .ones (Nfin )/ Nfin
220-
324+ 221325
222326 #print(reg)
223327
@@ -261,23 +365,23 @@ def sinkhorn_knopp(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False,
261365 if log :
262366 log ['u' ]= u
263367 log ['v' ]= v
264-
265- if nbb : #return only loss
368+ 369+ if nbb : #return only loss
266370 res = np .zeros ((nbb ))
267371 for i in range (nbb ):
268372 res [i ]= np .sum (u [:,i ].reshape ((- 1 ,1 ))* K * v [:,i ].reshape ((1 ,- 1 ))* M )
269373 if log :
270374 return res ,log
271375 else :
272- return res
273-
376+ return res
377+ 274378 else : # return OT matrix
275-
379+ 276380 if log :
277381 return u .reshape ((- 1 ,1 ))* K * v .reshape ((1 ,- 1 )),log
278382 else :
279383 return u .reshape ((- 1 ,1 ))* K * v .reshape ((1 ,- 1 ))
280-
384+ 281385
282386def sinkhorn_stabilized (a ,b , M , reg , numItermax = 1000 ,tau = 1e3 , stopThr = 1e-9 ,warmstart = None , verbose = False ,print_period = 20 , log = False ,** kwargs ):
283387 """
@@ -393,7 +497,7 @@ def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,war
393497 alpha ,beta = np .zeros (na ),np .zeros (nb )
394498 else :
395499 alpha ,beta = warmstart
396-
500+ 397501 if nbb :
398502 u ,v = np .ones ((na ,nbb ))/ na ,np .ones ((nb ,nbb ))/ nb
399503 else :
@@ -420,7 +524,7 @@ def get_Gamma(alpha,beta,u,v):
420524
421525 uprev = u
422526 vprev = v
423-
527+ 424528 # sinkhorn update
425529 v = b / (np .dot (K .T ,u )+ 1e-16 )
426530 u = a / (np .dot (K ,v )+ 1e-16 )
@@ -471,8 +575,8 @@ def get_Gamma(alpha,beta,u,v):
471575 break
472576
473577 cpt = cpt + 1
474-
475-
578+ 579+ 476580 #print('err=',err,' cpt=',cpt)
477581 if log :
478582 log ['logu' ]= alpha / reg + np .log (u )
@@ -493,7 +597,7 @@ def get_Gamma(alpha,beta,u,v):
493597 res = np .zeros ((nbb ))
494598 for i in range (nbb ):
495599 res [i ]= np .sum (get_Gamma (alpha ,beta ,u [:,i ],v [:,i ])* M )
496- return res
600+ return res
497601 else :
498602 return get_Gamma (alpha ,beta ,u ,v )
499603
0 commit comments