
Pytorch backward gradient estimation of OT functions #501

Unanswered
bornabat asked this question in Q&A

I am currently using the ot.gromov.gwloss and ot.gromov.init_matrix functions from the OT library to compute the Gromov–Wasserstein (GW) distance between two datasets and, eventually, its gradient with respect to one of the datasets. While integrating these functions, I ran into an important question.

Since the ot.gromov.gwloss function appears to follow an abstract backend approach (via nx.backend), I have passed PyTorch tensors as inputs, hoping to preserve the computational graph and use PyTorch's autograd mechanism. The function does return a torch tensor as output, but I am uncertain whether all the operations inside ot.gromov.gwloss are PyTorch-native. I have seen methods like nx.dot, which appear to dispatch based on the backend, but I want to make sure there are no underlying numpy operations that would break the computational graph.

My questions are:

  • Can you confirm that when PyTorch tensors are passed to ot.gromov.gwloss, all internal operations (including any auxiliary functions it calls) remain fully PyTorch-native and differentiable, so that gradients obtained via .backward() are guaranteed to be correct?

  • If not, does the OT library support taking the gradient of the GW distance with respect to one of the datasets in any other way? (I have already looked at the ot.gromov.gwggrad function, but it apparently does not compute the gradient I want, i.e., with respect to one of the datasets only.)


Replies: 2 comments 1 reply


I believe the OT library doesn't have a built-in feature that directly computes the gradient of the GW distance with respect to one of the datasets. In particular, within DL frameworks like PyTorch, there isn't a straightforward way to achieve this using the library's existing functions.

The ot.gromov.gwggrad function available in the library calculates the gradient of the GW loss with respect to the coupling matrix, but it doesn't handle the computation of gradients with respect to the input datasets themselves.

1 reply

Actually, yes: POT provides the most useful gradients when using the PyTorch backend, that is, when passing PyTorch tensors to the solvers.


gwloss is used internally by the ot.gromov_wasserstein(2) solvers, and yes, it uses only PyTorch operations when computed on PyTorch tensors, so it should provide proper gradients with respect to the data (you should still check them numerically).

Still, keep in mind that gwloss returns the loss for a given OT plan, not the optimal one. For the latter, you need to use the ot.gromov_wasserstein2 function, which defines all gradients properly (also with respect to the marginal weights and the data).
