-
Hi, thanks for providing such an amazing toolkit. I have two distributions and want to back-propagate through the W-distance between them, and I noticed this note in the documentation:
Note that when using backends, this loss function is differentiable wrt the matrices and weights for quadratic loss.
This inspires the following questions:
-
For solvers with this note in their documentation, we just add the returned distance to the loss function, as done in the link, and the gradient dependency is set up automatically. Is that right?
-
For solvers without this note, such as ot.unbalanced.sinkhorn_knopp_unbalanced, how do we back-propagate? #329 suggests recomputing the W-distance with torch.sum(plan*M), which is differentiable. In this setting the plan matrix would be a constant, so the gradient would only flow through M and not through $a$ and $b$. Is that right? (A rough sketch of both patterns is included below.)
-
If both 1 and 2 are right, why not provide a universal API to compute the W-distance? For example, solve for the plan matrix first and then recompute the W-distance from it, which would be auto-differentiable.
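To make questions 1 and 2 concrete, here is a rough sketch of the two patterns I have in mind (toy data with the torch backend; the names and values are just for illustration):
import torch
import ot

# toy setup: two point clouds with uniform weights
a = torch.ones(64) / 64
b = torch.ones(64) / 64
x = torch.randn(64, 2, requires_grad=True)
y = torch.randn(64, 2)
M = ot.dist(x, y)  # squared Euclidean cost, differentiable wrt x

# pattern in question 1: use the loss function directly and back-propagate
loss = ot.emd2(a, b, M)
loss.backward()

# pattern in question 2: take the plan from a solver, treat it as a constant,
# and recompute the loss; the gradient would then only flow through M
plan = ot.unbalanced.sinkhorn_knopp_unbalanced(a, b, M.detach(), reg=0.1, reg_m=1)
loss2 = torch.sum(torch.as_tensor(plan) * M)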
-
Hello,
For exact OT the plan itself is not differentiable wrt the inputs, but we do have the gradient wrt the masses as a by-product of solving the problem (hence the need to use emd2, where the gradient is properly defined). In emd the OT plan is detached from the inputs, so if you recompute the loss it will only be differentiable wrt the distance matrix, not the masses.
The unbalanced problems are smooth and we only use differentiable operations in the algorithm, so you can use autodiff to compute the gradients wrt all the inputs. The plan is not detached for those functions, so you can recompute the loss and the gradient will back-propagate.
EDIT: more precisely, the unbalanced solvers are not differentiable yet, I'm sorry, but they will be very shortly because we have a PR #343 that will add this to the toolbox.
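To make the exact OT part concrete, a quick check along these lines (just a rough sketch with the torch backend and arbitrary toy data) shows the difference between the two functions:
import torch
import ot

# uniform masses and a random cost matrix as leaves of the autograd graph
a = (torch.ones(16) / 16).requires_grad_()
b = (torch.ones(16) / 16).requires_grad_()
M = ot.dist(torch.randn(16, 2), torch.randn(16, 2)).requires_grad_()

# emd2 returns the loss with gradients defined wrt the masses and the cost matrix
loss = ot.emd2(a, b, M)
loss.backward()
print(a.grad.shape, b.grad.shape, M.grad.shape)

# emd returns the plan, which is detached from the graph, so a loss recomputed
# from it only back-propagates through M
plan = ot.emd(a, b, M)
print(plan.requires_grad)  # False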
-
Thanks for your excellent answer! I would like to confirm a few details. You said we can use autodiff to compute the gradients wrt all the inputs in UOT, so I tested the following:
xn = torch.randn(size_batch, n_features)
xd = get_data(size_batch) # generate data samples
xg = G(xn)
M = ot.dist(xg, xd)
Gs = ot.unbalanced.sinkhorn_unbalanced(ab, ab, M, reg=0.1, reg_m=1)
but it returns RuntimeError: Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead.
If I use M.detach() instead, the returned Gs is a ndarray without any gradient dependency. If the plan is not detached for those functions, as mentioned in the answer, why is the returned Gs a ndarray?
-
The test code for reproduction is:
import numpy as np
import matplotlib.pyplot as pl
import matplotlib.animation as animation
import torch
from torch import nn
import ot

torch.manual_seed(1)
sigma = 0.1
n_features = 2

def get_data(n_samples):
    # noisy samples on the unit circle
    c = torch.rand(size=(n_samples, 1))
    angle = c * 2 * np.pi
    x = torch.cat((torch.cos(angle), torch.sin(angle)), 1)
    x += torch.randn(n_samples, 2) * sigma
    return x

x = get_data(500)

class Generator(torch.nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.fc1 = nn.Linear(n_features, 100)
        self.fc2 = nn.Linear(100, 200)
        self.fc3 = nn.Linear(200, n_features)
        self.relu = torch.nn.ReLU()  # instead of Heaviside step fn

    def forward(self, x):
        output = self.fc1(x)
        output = self.relu(output)  # instead of Heaviside step fn
        output = self.fc2(output)
        output = self.relu(output)
        output = self.fc3(output)
        return output

G = Generator()
optimizer = torch.optim.RMSprop(G.parameters(), lr=1e-4, eps=1e-5)

n_iter = 3000
size_batch = 128
n_visu = 100
xnvisu = torch.randn(n_visu, n_features)
xvisu = torch.zeros(n_iter, n_visu, n_features)
ab = torch.ones(size_batch) / size_batch
losses = []

for i in range(n_iter):
    xn = torch.randn(size_batch, n_features)
    xd = get_data(size_batch)
    xvisu[i, :, :] = G(xnvisu).detach()
    xg = G(xn)
    M = ot.dist(xg, xd)
    # M has to be detached, otherwise the solver raises the RuntimeError above;
    # the returned plan is then a constant wrt the generator parameters
    Gs = ot.unbalanced.sinkhorn_unbalanced(ab, ab, M.detach(), reg=0.1, reg_m=1)
    Gs = torch.tensor(Gs)
    loss = torch.sum(Gs * M)
    losses.append(float(loss.detach()))
    if i % 10 == 0:
        print("Iter: {:3d}, loss={}".format(i, losses[-1]))
    optimizer.zero_grad()  # clear accumulated gradients before the backward pass
    loss.backward()
    optimizer.step()
-
As I said in my edit, the UOT solvers don't YET allow differentiation, but it is coming soon since we have a PR and a new release in the pipeline. Sorry, I was not precise enough; I did not remember that the code was not merged yet. In the meantime feel free to use the version of POT from the PR, it should be very stable.
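Once that is merged, something like the following should work (just a sketch, assuming a version of ot.unbalanced.sinkhorn_unbalanced that includes the backend support from the PR):
import torch
import ot

# toy setup mirroring the script above
size_batch, n_features = 128, 2
ab = torch.ones(size_batch) / size_batch
xg = torch.randn(size_batch, n_features, requires_grad=True)
xd = torch.randn(size_batch, n_features)
M = ot.dist(xg, xd)

# no detach needed: the solver keeps the plan in the autograd graph
plan = ot.unbalanced.sinkhorn_unbalanced(ab, ab, M, reg=0.1, reg_m=1)
loss = torch.sum(plan * M)
loss.backward()
print(xg.grad is not None)  # True once the differentiable solver is available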
-
Thanks for your excellent answer! I chose to recompute the W-distance via gamma * M. I have also tried the branch you mentioned, which supports autodiff for UOT, but I found little gain compared with recomputing gamma * M.
Thanks a lot for your help, sincerely!