POT sinkhorn_stabilized and sinkhorn_epsilon_scaling methods for sinkhorn getting stuck #749
-
Describe the bug
I am running Sinkhorn between two datasets of 100 images (each 28x28; these are just two subsets of FashionMNIST), but for some reason the sinkhorn_epsilon_scaling method gets stuck at a certain error value (log attached), regardless of whether the regularization is 0.0001, 0.001, 0.01, 0.1, or 1. The same thing happens for different datasets of the same size. Why is this?
It. |Err
0|8.288824e-25|
10|1.999996e-03|
20|1.999996e-03|
30|1.999996e-03|
40|1.999996e-03|
50|1.999996e-03|
60|1.999996e-03|
70|1.999996e-03|
80|1.999996e-03|
90|1.999996e-03|
It. |Err
100|1.999996e-03|
110|1.999996e-03|
120|1.999996e-03|
130|1.999996e-03|
140|1.999996e-03|
150|1.999996e-03|
160|1.999996e-03|
170|1.999996e-03|
180|1.999996e-03|
190|1.999996e-03|
Environment (please complete the following information):
- OS (e.g. MacOS, Windows, Linux):
- Python version:
- How was POT installed (source, pip, conda):
- Build command you used (if compiling from source):
- Only for GPU related bugs:
- CUDA version:
- GPU models and configuration:
- Any other relevant information:
Output of the following code snippet:
import platform; print(platform.platform())
import sys; print("Python", sys.version)
import numpy; print("NumPy", numpy.__version__)
import scipy; print("SciPy", scipy.__version__)
import ot; print("POT", ot.__version__)
Additional context
-
I am using torch float64 tensors.
-
N = len(data1['ys'])
X = torch.stack(data1[feature]).reshape((N, -1)).to(torch.float64)
Y = torch.stack(data2[feature]).reshape((N, -1)).to(torch.float64)
M = ot.dist(X, Y)
Mmax = M.max()
M_normalized = M / Mmax if float(Mmax) != 0.0 else M
a = torch.full((N,), 1.0 / N, dtype=torch.float64)
b = torch.full((N,), 1.0 / N, dtype=torch.float64)
W, log = ot.sinkhorn(
    a, b, M_normalized,
    reg=reg,
    method=method,
    numItermax=numItermax,
    verbose=is_verbose,
    log=True,
    warn=True,
)
-
You should use the method='sinkhorn_log' parameter to get a solver that does not have numerical problems, but keep in mind that for some regularization values it might still not converge.
-
sinkhorn_log uses logsumexp and is very stable; you should not get any warnings.
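To illustrate the stability point, here is a minimal log-domain Sinkhorn sketch in NumPy. This is my own illustration of the idea behind method='sinkhorn_log', not POT's actual code: every update goes through logsumexp, so the Gibbs kernel exp(-M/reg) is never exponentiated directly and a small reg does not underflow.

```python
import numpy as np

def logsumexp(X, axis):
    # Stable log(sum(exp(X))) along an axis: shift by the max before exponentiating.
    m = X.max(axis=axis, keepdims=True)
    return (m + np.log(np.exp(X - m).sum(axis=axis, keepdims=True))).squeeze(axis)

def sinkhorn_log(a, b, M, reg, n_iter=1000, tol=1e-9):
    # Log-domain Sinkhorn iterations (a sketch, not POT's implementation).
    log_a, log_b = np.log(a), np.log(b)
    K = -M / reg                 # log of the Gibbs kernel exp(-M/reg)
    f = np.zeros_like(a)         # log-domain scaling vectors
    g = np.zeros_like(b)
    for _ in range(n_iter):
        # These two logsumexp updates replace the unstable u = a / (K @ v),
        # v = b / (K.T @ u) updates of plain Sinkhorn.
        f = log_a - logsumexp(K + g[None, :], axis=1)
        g = log_b - logsumexp(K + f[:, None], axis=0)
        P = np.exp(f[:, None] + K + g[None, :])   # current transport plan
        if np.abs(P.sum(axis=1) - a).max() < tol:
            break
    return P
```

With a normalized cost matrix and reg as small as 1e-3, the iterates here stay finite, whereas the plain scaling form can underflow exp(-M/reg) to zero and produce NaNs.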
-
I think what confused me is that exact OT can solve this problem almost instantaneously (just OT between two sets of MNIST images, 100 each), while sinkhorn_log would take 5 hours and still not converge, even after 1,000,000 iterations. Why might this be?
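For context on why exact OT is fast at this size: with uniform weights and the same number of points on both sides, the optimal plan is a permutation, so exact OT reduces to a linear assignment problem that SciPy's solver handles almost instantly for N=100. A sketch with random stand-in data (not the actual FashionMNIST features):

```python
import numpy as np
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 784))   # stand-ins for the two flattened 28x28 image sets
Y = rng.normal(size=(100, 784))
M = cdist(X, Y, metric="sqeuclidean")

# With uniform marginals a = b = 1/N, exact OT is just a min-cost assignment.
row, col = linear_sum_assignment(M)
exact_cost = M[row, col].mean()   # OT cost: weight 1/N on each matched pair
```

Sinkhorn, by contrast, must run many O(N^2) matrix-vector iterations, and its iteration count blows up as reg shrinks toward the exact-OT regime.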
-
I just wanted to find a method that I could scale to larger datasets with the accuracy of exact OT, but it seems like none of the methods provide this.
-
Indeed, because if you need the accuracy of exact OT, Sinkhorn will be too slow (worse than exact OT). Another strategy to scale OT is to use minibatches, but that is also an approximation... Between us, there are not that many applications that require exact OT, and regularized OT is enough for a lot of deep learning applications on images.
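A minimal sketch of the minibatch idea mentioned above (a hypothetical helper, not a POT API): average the exact OT cost over random same-size minibatches, solving each small problem with the assignment solver. The estimate is biased relative to full OT, but each batch costs only O(batch^3) instead of scaling with the full dataset.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist

def minibatch_ot(X, Y, batch=32, n_batches=50, rng=None):
    # Average exact-OT cost over random minibatches: a scalable but biased
    # estimate of OT(X, Y). Uniform weights, so each batch is an assignment.
    if rng is None:
        rng = np.random.default_rng(0)
    costs = []
    for _ in range(n_batches):
        i = rng.choice(len(X), size=batch, replace=False)
        j = rng.choice(len(Y), size=batch, replace=False)
        M = cdist(X[i], Y[j], metric="sqeuclidean")
        r, c = linear_sum_assignment(M)
        costs.append(M[r, c].mean())
    return float(np.mean(costs))
```

Note the bias: even for two samples from the same distribution, the minibatch estimate is strictly positive, so it should be compared across models, not read as the true OT distance.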
-
Are there statistics for how long properly regularized Sinkhorn should run on a dataset of, say, 30,000 samples (i.e., the MNIST train set split in half)?