Auto-diff for Optimal Transport #356

Answered by rflamary
HowardZJU asked this question in Q&A

Hi, thanks for providing such an amazing toolkit. I have two distributions $a$ and $b$ parameterized by two neural networks, and I want to update the two networks to minimize the W-distance between $a$ and $b$. I have noticed that some OT solvers, such as emd2, carry a note in their documentation:

Note that when using backends, this loss function is differentiable wrt the matrices and weights for quadratic loss.

This raises the following questions:

  1. For solvers with this note, we simply add the returned distance to the loss function, as done in link, and the gradient dependency will be configured automatically, is that right? (A sketch of both patterns follows this list.)

  2. For solvers without this note, such as ot.unbalanced.sinkhorn_knopp_unbalanced, how do we perform backward propagation? #329 suggests recomputing the W-distance as torch.sum(plan * M), which is differentiable. In this setting, the plan matrix would be a constant and the gradient would flow through M back to $a$ and $b$, is that right?

  3. If both 1 and 2 are right, why not provide a universal API to compute the W-distance? For example, we could solve for the plan matrix first and then recompute the W-distance, which would be auto-differentiable.
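
For context, here is a minimal sketch of the two patterns from questions 1 and 2, assuming the PyTorch backend of POT (the variable names are illustrative, not from the library):

```python
import torch
import ot

n = 50
x = torch.randn(n, 2, requires_grad=True)  # e.g. the output of a generator network
y = torch.randn(n, 2)                      # target samples
a = torch.ones(n) / n                      # uniform source weights
b = torch.ones(n) / n                      # uniform target weights

# Pattern 1 (question 1): emd2 returns a loss that is differentiable wrt M
M = ot.dist(x, y)
loss = ot.emd2(a, b, M)
loss.backward()  # gradients reach x through the cost matrix M

# Pattern 2 (question 2): solve for the plan on a detached cost, then
# recompute the loss; the plan acts as a constant and gradients flow
# through M only
x.grad = None
M = ot.dist(x, y)
plan = ot.emd(a, b, M.detach())
loss2 = torch.sum(plan * M)
loss2.backward()
```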



Replies: 1 comment 4 replies


Hello,

For exact OT the plan itself is not differentiable wrt the input, but we do get the gradient wrt the masses as a by-product of solving the problem (hence the need to use emd2, where the gradient is properly defined). In emd the OT plan is detached from the input, so if you recompute the loss, the gradient will only be wrt the distance matrix, not the masses.

The unbalanced problems are smooth and we only use differentiable operations in the algorithm, so you can use autodiff to compute the gradients wrt all the inputs. The plan is not detached in those functions, so you can recompute the loss and the gradient will back-propagate.

EDIT: more precisely, the unbalanced solvers are not differentiable yet, I'm sorry, but they will be very shortly because we have a PR #343 that will add this to the toolbox.
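
To illustrate the first point, here is a minimal sketch of getting gradients wrt the masses through emd2, assuming the PyTorch backend of POT (the softmax reparameterization is just one way to keep the weights normalized):

```python
import torch
import ot

n = 20
x = torch.randn(n, 2)
y = torch.randn(n, 2)
M = ot.dist(x, y)  # fixed cost matrix

w = torch.rand(n, requires_grad=True)  # unnormalized source weights
a = torch.softmax(w, dim=0)            # differentiable normalization, sums to 1
b = torch.ones(n) / n                  # uniform target weights

loss = ot.emd2(a, b, M)  # exact OT loss, differentiable wrt the weights
loss.backward()
print(w.grad)  # the gradient wrt the masses is given by the dual potentials
```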

4 replies

Thanks for your excellent answer! I want to further confirm the details as follows. You said we can use autodiff to compute the gradients wrt all the inputs in UOT, so I tested it as follows (G, get_data, ab, size_batch and n_features are defined in the full script below):

```python
xn = torch.randn(size_batch, n_features)
xd = get_data(size_batch)  # generate data samples
xg = G(xn)                 # generator output
M = ot.dist(xg, xd)        # cost matrix, attached to the graph through xg
Gs = ot.unbalanced.sinkhorn_unbalanced(ab, ab, M, reg=0.1, reg_m=1)
```

but it raises RuntimeError: Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead.

If I pass M.detach() instead, the returned Gs is an ndarray with no gradient dependency. If the plan is not detached in those functions, as stated in the answer, why is the returned Gs an ndarray?


The test code for reproduction is:

```python
import numpy as np
import matplotlib.pyplot as pl
import matplotlib.animation as animation
import torch
from torch import nn
import ot

torch.manual_seed(1)
sigma = 0.1
n_features = 2


def get_data(n_samples):
    # sample noisy points on the unit circle
    c = torch.rand(size=(n_samples, 1))
    angle = c * 2 * np.pi
    x = torch.cat((torch.cos(angle), torch.sin(angle)), 1)
    x += torch.randn(n_samples, 2) * sigma
    return x


x = get_data(500)


class Generator(torch.nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.fc1 = nn.Linear(n_features, 100)
        self.fc2 = nn.Linear(100, 200)
        self.fc3 = nn.Linear(200, n_features)
        self.relu = torch.nn.ReLU()  # instead of Heaviside step fn

    def forward(self, x):
        output = self.fc1(x)
        output = self.relu(output)  # instead of Heaviside step fn
        output = self.fc2(output)
        output = self.relu(output)
        output = self.fc3(output)
        return output


G = Generator()
optimizer = torch.optim.RMSprop(G.parameters(), lr=1e-4, eps=1e-5)

n_iter = 3000
size_batch = 128

n_visu = 100
xnvisu = torch.randn(n_visu, n_features)
xvisu = torch.zeros(n_iter, n_visu, n_features)
ab = torch.ones(size_batch) / size_batch
losses = []

for i in range(n_iter):
    xn = torch.randn(size_batch, n_features)
    xd = get_data(size_batch)
    xvisu[i, :, :] = G(xnvisu).detach()
    xg = G(xn)
    M = ot.dist(xg, xd)
    # solve UOT on a detached cost, then recompute the loss so that
    # gradients flow back through M only
    Gs = ot.unbalanced.sinkhorn_unbalanced(ab, ab, M.detach(), reg=0.1, reg_m=1)
    Gs = torch.tensor(Gs)  # wrap the numpy plan back into a (constant) tensor
    loss = torch.sum(Gs * M)
    losses.append(float(loss.detach()))
    if i % 10 == 0:
        print("Iter: {:3d}, loss={}".format(i, losses[-1]))
    optimizer.zero_grad()  # reset accumulated gradients before backward
    loss.backward()
    optimizer.step()
```

As I said in my edit, the UOT solvers don't YET allow differentiating, but it is coming soon since we have a PR and a new release in the pipeline. Sorry, I was not precise enough; I did not remember that the code was not merged yet. In the meantime, feel free to use the version of POT from the PR, it should be very stable.
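
Once that is merged, the recomputation workaround should become unnecessary. A sketch of what the workflow should then look like, assuming the solver stops detaching the plan as described above (not tested against the PR branch):

```python
import torch
import ot

n = 16
x = torch.randn(n, 2, requires_grad=True)  # stand-in for a generator output
y = torch.randn(n, 2)
ab = torch.ones(n) / n

M = ot.dist(x, y)
# with a differentiable solver, the plan Gs stays attached to the graph
Gs = ot.unbalanced.sinkhorn_unbalanced(ab, ab, M, reg=0.1, reg_m=1)
loss = torch.sum(Gs * M)
loss.backward()  # gradients reach x through both Gs and M
```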


Thanks for your excellent answer! I chose to recalculate the W-distance via gamma * M. I have also tried the branch you mentioned, which supports auto-diff for UOT, but I found little gain compared with recalculating gamma * M.

Thanks a lot for your help, sincerely!

Answer selected by HowardZJU