Hi @rflamary,
This is a really cool library. I was wondering if you could provide me with any insights on this problem.
I am trying to find the distance between two distributions, for example:
```python
import torch
import ot  # POT: Python Optimal Transport

n = 100  # nb bins
x = torch.arange(n, dtype=torch.float64)
M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)), 'euclidean')
l1 = torch.zeros(100)
l2 = torch.zeros(100)
l1[0] = 1.0
l2[1] = 1.0
lambd = 1e-3
Gs1 = ot.sinkhorn(l1, l2, M, lambd, verbose=True)
print(Gs1)
```
The output I get is:

```
tensor([[0.0001, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0001, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0001,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.0000, 0.0000, 0.0000,  ..., 0.0001, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0001, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0001]], dtype=torch.float64)
```
Since the source and target are close together, I expect the Wasserstein distance to be small here.
But for distributions that look like this:
```python
l1[0] = 1.0
l2[1] = 0.0
l2[99] = 1.0
lambd = 1e-3
Gs2 = ot.sinkhorn(l1, l2, M, lambd, verbose=True)
print(Gs2)

pl.figure(1, figsize=(6.4, 3))  # pl is matplotlib.pylab
pl.plot(x, l1, 'b', label='Source distribution')
pl.plot(x, l2, 'r', label='Target distribution')
pl.legend()
```
The output I get is:

```
tensor([[0.0001, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0001, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0001,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.0000, 0.0000, 0.0000,  ..., 0.0001, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0001, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0001]], dtype=torch.float64)
```
which is also the same? How can I get the distance for this case to be larger than the first one? Any insights are welcome. Thank you in advance!
Actually, when I run your code there is a warning that a numerical error occurred, which means that the returned plan is not a proper OT plan and violates the marginal constraints.
In your case, the solution violates the marginal constraints because the returned plan has uniform weights while your marginals are sparse.
Sinkhorn is particularly bad at solving sparse problems. For those I suggest using exact OT:
```python
n = 100  # nb bins
x = torch.arange(n, dtype=torch.float64)
M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)), 'euclidean')
l1 = torch.zeros(100)
l2 = torch.zeros(100)
l1[0] = 1.0
l2[1] = 1.0
Gs1 = ot.emd(l1, l2, M)
print(Gs1)
```
which returns:

```
tensor([[0., 1., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], dtype=torch.float64)
```
and for the second problem:
```
tensor([[0., 0., 0.,  ..., 0., 0., 1.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], dtype=torch.float64)
```
Both are the correct solutions.
Hi! Thank you for your response, but I want the distance to be different when the distributions are far apart, as in the second example.
Shouldn't the Wasserstein distance be large in that case?
I get this:

```
tensor([[0., 0., 0.,  ..., 0., 0., 1.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
```
Is there a way to get a single-number distance instead of a matrix? For example, as shown here: https://github.com/dfdazac/wassdistance/blob/master/sinkhorn.ipynb
Yes: either use `ot.emd2` directly, which returns the distance (as said in the doc), or compute it afterward with `torch.sum(M * Gs1)`.