Calculate Wasserstein distance #407

Unanswered
AdarshMJ asked this question in Q&A
Nov 19, 2022 · 1 comment · 3 replies

Hi! @rflamary

This is a really cool library. I was wondering if you could provide me with any insights on this problem -

I am trying to find the distance between two distributions, for example -

import torch
import ot

n = 100  # nb bins
x = torch.arange(n, dtype=torch.float64)
M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)), 'euclidean')
l1 = torch.zeros(100)
l2 = torch.zeros(100)
l1[0] = 1.0
l2[1] = 1.0
lambd = 1e-3
Gs1 = ot.sinkhorn(l1, l2, M, lambd, verbose=True)
print(Gs1)

[Screenshot: plot of the source and target distributions]

The output I get is -

tensor([[0.0001, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
 [0.0000, 0.0001, 0.0000, ..., 0.0000, 0.0000, 0.0000],
 [0.0000, 0.0000, 0.0001, ..., 0.0000, 0.0000, 0.0000],
 ...,
 [0.0000, 0.0000, 0.0000, ..., 0.0001, 0.0000, 0.0000],
 [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0001, 0.0000],
 [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0001]],
 dtype=torch.float64)

Since the source and target are close by, I want the Wasserstein distance to be small here.

But for distributions looking like this -

import matplotlib.pylab as pl

l1[0] = 1.0
l2[1] = 0.0   # move the target mass from bin 1 ...
l2[99] = 1.0  # ... to bin 99
lambd = 1e-3
Gs2 = ot.sinkhorn(l1, l2, M, lambd, verbose=True)
print(Gs2)
pl.figure(1, figsize=(6.4, 3))
pl.plot(x, l1, 'b', label='Source distribution')
pl.plot(x, l2, 'r', label='Target distribution')
pl.legend()

[Screenshot: plot of the source and target distributions]

The output I get is -

tensor([[0.0001, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
 [0.0000, 0.0001, 0.0000, ..., 0.0000, 0.0000, 0.0000],
 [0.0000, 0.0000, 0.0001, ..., 0.0000, 0.0000, 0.0000],
 ...,
 [0.0000, 0.0000, 0.0000, ..., 0.0001, 0.0000, 0.0000],
 [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0001, 0.0000],
 [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0001]],
 dtype=torch.float64)

which is also the same? How can I get the distance here to be larger than in the first case? Any insights are welcome. Thank you in advance!


Replies: 1 comment · 3 replies


Actually, when I run your code there is a warning that a numerical error occurred, which means that the returned plan is not a proper OT plan and violates the marginal constraints.

For instance, your solution violates the marginal constraints because the returned plan has uniform weights while your marginals are sparse.

Sinkhorn is particularly bad at solving sparse problems. For those I suggest using exact OT:

import torch
import ot

n = 100  # nb bins
x = torch.arange(n, dtype=torch.float64)
M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)), 'euclidean')
l1 = torch.zeros(100)
l2 = torch.zeros(100)
l1[0] = 1.0
l2[1] = 1.0
lambd = 1e-3  # not used by ot.emd
Gs1 = ot.emd(l1, l2, M)  # exact OT plan, no regularization
print(Gs1)

that returns

tensor([[0., 1., 0., ..., 0., 0., 0.],
 [0., 0., 0., ..., 0., 0., 0.],
 [0., 0., 0., ..., 0., 0., 0.],
 ...,
 [0., 0., 0., ..., 0., 0., 0.],
 [0., 0., 0., ..., 0., 0., 0.],
 [0., 0., 0., ..., 0., 0., 0.]], dtype=torch.float64)

and for the second problem:

tensor([[0., 0., 0., ..., 0., 0., 1.],
 [0., 0., 0., ..., 0., 0., 0.],
 [0., 0., 0., ..., 0., 0., 0.],
 ...,
 [0., 0., 0., ..., 0., 0., 0.],
 [0., 0., 0., ..., 0., 0., 0.],
 [0., 0., 0., ..., 0., 0., 0.]], dtype=torch.float64)

Both are correct solutions: in the first case the mass moves from bin 0 to bin 1, and in the second from bin 0 to bin 99.
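As a quick sanity check (a small sketch reusing the variables above; `Gs2` is assumed here to be the plan returned by `ot.emd` for the second pair of histograms), contracting each plan with the cost matrix gives the transport cost, which does differ between the two cases:

# transport cost <M, G> = sum of (distance * transported mass)
print(torch.sum(M * Gs1))  # mass moves from bin 0 to bin 1  -> 1.0
print(torch.sum(M * Gs2))  # mass moves from bin 0 to bin 99 -> 99.0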

3 replies

Hi! Thank you for your response, but I want the distance to be different when the distributions are far apart - for example here:

[Screenshot: plot of the source and target distributions]

Shouldn't the Wasserstein distance be large for this?

I get this,

tensor([[0., 0., 0., ..., 0., 0., 1.],
 [0., 0., 0., ..., 0., 0., 0.],
 [0., 0., 0., ..., 0., 0., 0.],
 ...,
 [0., 0., 0., ..., 0., 0., 0.],
 [0., 0., 0., ..., 0., 0., 0.],
 [0., 0., 0., ..., 0., 0., 0.]])

Is there a way to get a single-number distance instead of a matrix? For example, as shown here - https://github.com/dfdazac/wassdistance/blob/master/sinkhorn.ipynb


Yes, either you use `ot.emd2` directly, which returns the distance (as said in the doc), or you can compute it afterward with `torch.sum(M * Gs1)`.
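For instance, a minimal self-contained sketch (the `a`, `b_near`, `b_far` names are just illustrative, not from the snippets above):

import torch
import ot

n = 100
x = torch.arange(n, dtype=torch.float64)
M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)), 'euclidean')

a = torch.zeros(n, dtype=torch.float64)
a[0] = 1.0        # source: all mass at bin 0
b_near = torch.zeros(n, dtype=torch.float64)
b_near[1] = 1.0   # target one bin away
b_far = torch.zeros(n, dtype=torch.float64)
b_far[99] = 1.0   # target 99 bins away

# ot.emd2 returns the optimal transport cost as a scalar, not the plan
print(ot.emd2(a, b_near, M))  # 1.0
print(ot.emd2(a, b_far, M))   # 99.0

# equivalent: compute the plan with ot.emd and contract it with the cost matrix
G = ot.emd(a, b_far, M)
print(torch.sum(M * G))       # 99.0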

