Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 47477c5

Browse files
committed
add sinkhorbn2 +v3
1 parent 0fc1124 commit 47477c5

File tree

4 files changed

+133
-28
lines changed

4 files changed

+133
-28
lines changed

‎README.md‎

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,14 +83,15 @@ import ot
8383
# a,b are 1D histograms (sum to 1 and positive)
8484
# M is the ground cost matrix
8585
Wd=ot.emd2(a,b,M) # exact linear program
86+
Wd_reg=ot.sinkhorn2(a,b,M,reg) # entropic regularized OT
8687
# if b is a matrix compute all distances to a and return a vector
8788
```
8889
* Compute OT matrix
8990
```python
9091
# a,b are 1D histograms (sum to 1 and positive)
9192
# M is the ground cost matrix
92-
Totp=ot.emd(a,b,M) # exact linear program
93-
Totp_reg=ot.sinkhorn(a,b,M,reg) # entropic regularized OT
93+
T=ot.emd(a,b,M) # exact linear program
94+
T_reg=ot.sinkhorn(a,b,M,reg) # entropic regularized OT
9495
```
9596
* Compute Wasserstein barycenter
9697
```python

‎examples/plot_compute_emd.py‎

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,8 @@
6161

6262
#%%
6363
reg=1e-2
64-
d_sinkhorn=ot.sinkhorn(a,B,M,reg)
65-
d_sinkhorn2=ot.sinkhorn(a,B,M2,reg)
64+
d_sinkhorn=ot.sinkhorn2(a,B,M,reg)
65+
d_sinkhorn2=ot.sinkhorn2(a,B,M2,reg)
6666

6767
pl.figure(2)
6868
pl.clf()

‎ot/__init__.py‎

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,14 @@
1616

1717
# OT functions
1818
from .lp import emd, emd2
19-
from .bregman import sinkhorn, barycenter
19+
from .bregman import sinkhorn, sinkhorn2, barycenter
2020
from .da import sinkhorn_lpl1_mm
2121

2222
# utils functions
2323
from .utils import dist, unif, tic, toc, toq
2424

25-
__version__ = "0.2"
25+
__version__ = "0.3"
2626

27-
__all__ = ["emd", "emd2", "sinkhorn", "utils", 'datasets', 'bregman', 'lp',
28-
'plot', 'tic', 'toc', 'toq',
27+
__all__ = ["emd", "emd2", "sinkhorn","sinkhorn2", "utils", 'datasets',
28+
'bregman', 'lp', 'plot', 'tic', 'toc', 'toq',
2929
'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim']

‎ot/bregman.py‎

Lines changed: 124 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def sinkhorn(a,b, M, reg,method='sinkhorn', numItermax = 1000, stopThr=1e-9, ver
4141
Regularization term >0
4242
method : str
4343
method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or
44-
'sinkhorn_epsilon_scaling', see those function for specific parameters
44+
'sinkhorn_epsilon_scaling', see those function for specific parameters
4545
numItermax : int, optional
4646
Max number of iterations
4747
stopThr : float, optional
@@ -91,7 +91,7 @@ def sinkhorn(a,b, M, reg,method='sinkhorn', numItermax = 1000, stopThr=1e-9, ver
9191
ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling [9][10]
9292
9393
"""
94-
94+
9595
if method.lower()=='sinkhorn':
9696
sink= lambda: sinkhorn_knopp(a,b, M, reg,numItermax=numItermax,
9797
stopThr=stopThr, verbose=verbose, log=log,**kwargs)
@@ -100,15 +100,119 @@ def sinkhorn(a,b, M, reg,method='sinkhorn', numItermax = 1000, stopThr=1e-9, ver
100100
stopThr=stopThr, verbose=verbose, log=log, **kwargs)
101101
elif method.lower()=='sinkhorn_epsilon_scaling':
102102
sink= lambda: sinkhorn_epsilon_scaling(a,b, M, reg,numItermax=numItermax,
103-
stopThr=stopThr, verbose=verbose, log=log, **kwargs)
103+
stopThr=stopThr, verbose=verbose, log=log, **kwargs)
104104
else:
105105
print('Warning : unknown method using classic Sinkhorn Knopp')
106106
sink= lambda: sinkhorn_knopp(a,b, M, reg, **kwargs)
107-
107+
108108
return sink()
109+
110+
def sinkhorn2(a,b, M, reg,method='sinkhorn', numItermax = 1000, stopThr=1e-9, verbose=False, log=False,**kwargs):
111+
u"""
112+
Solve the entropic regularization optimal transport problem and return the loss
113+
114+
The function solves the following optimization problem:
115+
116+
.. math::
117+
W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
118+
119+
s.t. \gamma 1 = a
120+
121+
\gamma^T 1= b
122+
123+
\gamma\geq 0
124+
where :
125+
126+
- M is the (ns,nt) metric cost matrix
127+
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
128+
- a and b are source and target weights (sum to 1)
129+
130+
The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_
131+
132+
133+
Parameters
134+
----------
135+
a : np.ndarray (ns,)
136+
samples weights in the source domain
137+
b : np.ndarray (nt,) or np.ndarray (nt,nbb)
138+
samples in the target domain, compute sinkhorn with multiple targets
139+
and fixed M if b is a matrix (return OT loss + dual variables in log)
140+
M : np.ndarray (ns,nt)
141+
loss matrix
142+
reg : float
143+
Regularization term >0
144+
method : str
145+
method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or
146+
'sinkhorn_epsilon_scaling', see those function for specific parameters
147+
numItermax : int, optional
148+
Max number of iterations
149+
stopThr : float, optional
150+
Stop threshol on error (>0)
151+
verbose : bool, optional
152+
Print information along iterations
153+
log : bool, optional
154+
record log if True
155+
156+
157+
Returns
158+
-------
159+
W : (nt) ndarray or float
160+
Optimal transportation matrix for the given parameters
161+
log : dict
162+
log dictionary return only if log==True in parameters
163+
164+
Examples
165+
--------
166+
167+
>>> import ot
168+
>>> a=[.5,.5]
169+
>>> b=[.5,.5]
170+
>>> M=[[0.,1.],[1.,0.]]
171+
>>> ot.sinkhorn2(a,b,M,1)
172+
array([ 0.26894142])
109173
110174
111175
176+
References
177+
----------
178+
179+
.. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
180+
181+
.. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
182+
183+
.. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
184+
185+
186+
187+
See Also
188+
--------
189+
ot.lp.emd : Unregularized OT
190+
ot.optim.cg : General regularized OT
191+
ot.bregman.sinkhorn_knopp : Classic Sinkhorn [2]
192+
ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10]
193+
ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling [9][10]
194+
195+
"""
196+
197+
if method.lower()=='sinkhorn':
198+
sink= lambda: sinkhorn_knopp(a,b, M, reg,numItermax=numItermax,
199+
stopThr=stopThr, verbose=verbose, log=log,**kwargs)
200+
elif method.lower()=='sinkhorn_stabilized':
201+
sink= lambda: sinkhorn_stabilized(a,b, M, reg,numItermax=numItermax,
202+
stopThr=stopThr, verbose=verbose, log=log, **kwargs)
203+
elif method.lower()=='sinkhorn_epsilon_scaling':
204+
sink= lambda: sinkhorn_epsilon_scaling(a,b, M, reg,numItermax=numItermax,
205+
stopThr=stopThr, verbose=verbose, log=log, **kwargs)
206+
else:
207+
print('Warning : unknown method using classic Sinkhorn Knopp')
208+
sink= lambda: sinkhorn_knopp(a,b, M, reg, **kwargs)
209+
210+
b=np.asarray(b,dtype=np.float64)
211+
if len(b.shape)<2:
212+
b=b.reshape((-1,1))
213+
214+
return sink()
215+
112216

113217
def sinkhorn_knopp(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=False,**kwargs):
114218
"""
@@ -189,23 +293,23 @@ def sinkhorn_knopp(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False,
189293
a=np.asarray(a,dtype=np.float64)
190294
b=np.asarray(b,dtype=np.float64)
191295
M=np.asarray(M,dtype=np.float64)
192-
296+
193297

194298
if len(a)==0:
195299
a=np.ones((M.shape[0],),dtype=np.float64)/M.shape[0]
196300
if len(b)==0:
197301
b=np.ones((M.shape[1],),dtype=np.float64)/M.shape[1]
198-
302+
199303

200304
# init data
201305
Nini = len(a)
202306
Nfin = len(b)
203-
307+
204308
if len(b.shape)>1:
205309
nbb=b.shape[1]
206310
else:
207311
nbb=0
208-
312+
209313

210314
if log:
211315
log={'err':[]}
@@ -217,7 +321,7 @@ def sinkhorn_knopp(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False,
217321
else:
218322
u = np.ones(Nini)/Nini
219323
v = np.ones(Nfin)/Nfin
220-
324+
221325

222326
#print(reg)
223327

@@ -261,23 +365,23 @@ def sinkhorn_knopp(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False,
261365
if log:
262366
log['u']=u
263367
log['v']=v
264-
265-
if nbb: #return only loss
368+
369+
if nbb: #return only loss
266370
res=np.zeros((nbb))
267371
for i in range(nbb):
268372
res[i]=np.sum(u[:,i].reshape((-1,1))*K*v[:,i].reshape((1,-1))*M)
269373
if log:
270374
return res,log
271375
else:
272-
return res
273-
376+
return res
377+
274378
else: # return OT matrix
275-
379+
276380
if log:
277381
return u.reshape((-1,1))*K*v.reshape((1,-1)),log
278382
else:
279383
return u.reshape((-1,1))*K*v.reshape((1,-1))
280-
384+
281385

282386
def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,warmstart=None, verbose=False,print_period=20, log=False,**kwargs):
283387
"""
@@ -393,7 +497,7 @@ def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,war
393497
alpha,beta=np.zeros(na),np.zeros(nb)
394498
else:
395499
alpha,beta=warmstart
396-
500+
397501
if nbb:
398502
u,v = np.ones((na,nbb))/na,np.ones((nb,nbb))/nb
399503
else:
@@ -420,7 +524,7 @@ def get_Gamma(alpha,beta,u,v):
420524

421525
uprev = u
422526
vprev = v
423-
527+
424528
# sinkhorn update
425529
v = b/(np.dot(K.T,u)+1e-16)
426530
u = a/(np.dot(K,v)+1e-16)
@@ -471,8 +575,8 @@ def get_Gamma(alpha,beta,u,v):
471575
break
472576

473577
cpt = cpt +1
474-
475-
578+
579+
476580
#print('err=',err,' cpt=',cpt)
477581
if log:
478582
log['logu']=alpha/reg+np.log(u)
@@ -493,7 +597,7 @@ def get_Gamma(alpha,beta,u,v):
493597
res=np.zeros((nbb))
494598
for i in range(nbb):
495599
res[i]=np.sum(get_Gamma(alpha,beta,u[:,i],v[:,i])*M)
496-
return res
600+
return res
497601
else:
498602
return get_Gamma(alpha,beta,u,v)
499603

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /