
Gradient Descent Using Sinkhorn #416

Unanswered
akul-goyal asked this question in Q&A

Hi,

I want to implement the following POT example here, but using Sinkhorn instead of `ot.emd2`. I tried `ot.sinkhorn2`, but that strips the gradients from the result. Is there a different function I can use that will keep the gradients?
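For context, the Sinkhorn iterations themselves are autograd-friendly: a small from-scratch sketch in plain torch (not POT's implementation, just an illustration of the same fixed-point scheme) shows gradients flowing back to the input histogram:

```python
import torch

def sinkhorn_loss(a, b, M, reg, n_iter=100):
    """Entropic OT cost via plain Sinkhorn iterations (differentiable by autograd)."""
    K = torch.exp(-M / reg)            # Gibbs kernel
    v = torch.ones_like(b)
    for _ in range(n_iter):
        u = a / (K @ v)                # row scaling
        v = b / (K.T @ u)              # column scaling
    P = u[:, None] * K * v[None, :]    # transport plan
    return (P * M).sum()               # linear OT cost <P, M>

# toy problem: 5 points on a line, squared-distance cost
x = torch.arange(5, dtype=torch.float64).reshape(-1, 1)
M = (x - x.T) ** 2
a = torch.full((5,), 0.2, dtype=torch.float64, requires_grad=True)
b = torch.full((5,), 0.2, dtype=torch.float64)

loss = sinkhorn_loss(a, b, M, reg=0.1)
loss.backward()
print(a.grad)  # finite gradients w.r.t. the input histogram
```

The point is that nothing in the algorithm itself blocks differentiation; whether `ot.sinkhorn2` preserves gradients depends on the POT version and backend, which is what the reply below addresses.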


Replies: 1 comment 6 replies


This is a bug we fixed in the current master branch; the fix is not yet released.

It is very stable and you can install it with:

```
pip install -U https://github.com/PythonOT/POT/archive/master.zip
```


We have a release in the works (once some major PRs are merged), but at the moment you should use the git version, because the current release does not have the fix yet.


I updated POT using `pip install -U https://github.com/PythonOT/POT/archive/master.zip` and installed version 0.8.3.dev0, but I am still getting errors.


I'm sorry but on master (the same version), when editing the loss in the example file as such:

```python
def get_loss(w):
    a = torch.mv(H2, w)  # distribution reweighting
    # return ot.emd2(a, b, M2)  # squared Wasserstein 2
    return ot.sinkhorn2(a, b, M2, 0.1)  # entropic OT loss
```

It just works...


Interestingly, when I pass in `ot.sinkhorn2(a, b, M2, 0.1)` it works, but the following parameters strip the gradients: `ot.sinkhorn2(a, b, M2, 0.0005, numItermax=10000000)`. I assume a low entropic regularization term causes the gradients to disappear?
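A quick numeric illustration of why a tiny regularization loses gradients (an editor's sketch, with a made-up toy cost matrix): standard Sinkhorn iterates on the Gibbs kernel `exp(-M / reg)`, and for very small `reg` its off-diagonal entries underflow to exactly zero in float64, so the subsequent divisions (and their gradients) blow up:

```python
import numpy as np

# Toy cost matrix; standard Sinkhorn iterates on K = exp(-M / reg).
M = np.array([[0.0, 1.0],
              [1.0, 0.0]])

for reg in (0.1, 0.0005):
    K = np.exp(-M / reg)
    # For reg=0.0005 the off-diagonal entries are exp(-2000), which
    # underflows to exactly 0.0 in float64.
    print(f"reg={reg}: min kernel entry = {K.min()}")
```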


For such a small value, Sinkhorn is not numerically stable and the gradient can be bad. You should at least use `method='sinkhorn_log'` to ensure that the iterations do not diverge. But for such a small epsilon, `emd2` will be faster on small datasets.
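To see why the log-domain variant survives a tiny epsilon, here is a from-scratch sketch of the same idea that `method='sinkhorn_log'` uses (this is an editor's illustration, not POT's code): the updates run on the dual potentials with `logsumexp`, so the kernel is never materialized and nothing underflows:

```python
import torch

def sinkhorn_log_loss(a, b, M, reg, n_iter=1000):
    """Log-domain Sinkhorn: stable even for very small reg."""
    log_a, log_b = torch.log(a), torch.log(b)
    f = torch.zeros_like(a)  # dual potentials
    g = torch.zeros_like(b)
    for _ in range(n_iter):
        # logsumexp replaces the u/v scalings on exp(-M/reg),
        # so no kernel entry ever underflows to zero
        f = reg * (log_a - torch.logsumexp((g[None, :] - M) / reg, dim=1))
        g = reg * (log_b - torch.logsumexp((f[:, None] - M) / reg, dim=0))
    P = torch.exp((f[:, None] + g[None, :] - M) / reg)  # transport plan
    return (P * M).sum()

x = torch.arange(5, dtype=torch.float64).reshape(-1, 1)
M = (x - x.T) ** 2
a = b = torch.full((5,), 0.2, dtype=torch.float64)

loss = sinkhorn_log_loss(a, b, M, reg=0.0005)
print(loss)  # finite; the naive kernel would have underflowed at this reg
```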
