Gradient Descent Using Sinkhorn #416
-
Hi,
I want to implement the following POT example here, but using Sinkhorn instead of `ot.emd2`. I tried `ot.sinkhorn2`, but it strips the gradients from the result. Is there a different function I can use that keeps the gradients?
Replies: 1 comment 6 replies
-
This is a bug that we fixed in the current master branch; the fix is not yet released.
Master is very stable and you can install it with:
```
pip install -U https://github.com/PythonOT/POT/archive/master.zip
```
-
We have a release in the works (once some major PRs are merged), but for the moment you should use the git version because the current release does not have the fix yet.
-
I updated POT using `pip install -U https://github.com/PythonOT/POT/archive/master.zip`, which installed version 0.8.3.dev0. But I am still getting errors.
-
I'm sorry but on master (the same version), when editing the loss in the example file as such:
    def get_loss(w):
        a = torch.mv(H2, w)  # distribution reweighting
        # return ot.emd2(a, b, M2)  # squared Wasserstein 2
        return ot.sinkhorn2(a, b, M2, 0.1)  # squared Wasserstein 2
-
Interestingly, when I pass in `ot.sinkhorn2(a, b, M2, 0.1)` it works, but the following parameters strip the gradients: `ot.sinkhorn2(a, b, M2, 0.0005, numItermax=10000000)`. I assume a low entropic regularization term causes the gradients to disappear?
-
For such a small value, Sinkhorn is not numerically stable and the gradient can be bad. You should at least use `method='sinkhorn_log'` to ensure that the iterations do not diverge. But for such a small epsilon, `emd2` will be faster on small datasets.
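To make the stability point concrete, here is a minimal standard-library sketch (not POT's implementation) of what happens to the Gibbs kernel `exp(-M / reg)` that the plain Sinkhorn iterations rely on. The 2x2 cost matrix and the helper names `gibbs_kernel` and `logsumexp` are made up for illustration; only the two `reg` values come from this thread.

```python
import math


def gibbs_kernel(cost, reg):
    """Entries exp(-cost / reg), as used by the standard Sinkhorn scaling updates."""
    return [[math.exp(-c / reg) for c in row] for row in cost]


def logsumexp(values):
    """Stable log(sum(exp(v))): subtract the maximum before exponentiating."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


# Hypothetical cost matrix with entries of order 1 (an assumption for illustration).
cost = [[0.5, 1.0], [1.0, 0.5]]

# reg = 0.1: every kernel entry is a small but positive float.
large_reg_kernel = gibbs_kernel(cost, 0.1)

# reg = 0.0005: exp(-1000) and exp(-2000) underflow to exactly 0.0 in float64,
# so the scaling update a / (K @ v) would divide by zero.
small_reg_kernel = gibbs_kernel(cost, 0.0005)

# Log domain: the same row reduction stays finite because no tiny number is
# ever materialized; only differences of exponents are exponentiated.
log_row_sum = logsumexp([-c / 0.0005 for c in cost[0]])

print(large_reg_kernel)
print(small_reg_kernel)
print(log_row_sum)
```

With `reg=0.0005` the whole kernel collapses to zeros, while the log-sum-exp of the same row stays finite. That is the general idea behind running the iterations in the log domain, which is what `method='sinkhorn_log'` is meant to do.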