Describe the bug
The formulas in gwggrad and solve_gromov_linesearch have typos and do not match the cited references [12] and [24]. I also calculated the gradient by hand to confirm that POT has typos.
For concreteness, I'm using the following versions of the cited papers:
[12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
"Gromov-Wasserstein averaging of kernel and distance matrices."
International Conference on Machine Learning (ICML). 2016.
IN: https://proceedings.mlr.press/v48/peyre16.pdf
[24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas
"Optimal Transport for structured data with application on graphs"
International Conference on Machine Learning (ICML). 2019.
IN: https://arxiv.org/pdf/1805.09114.pdf
To Reproduce
- Run the code below.
Code sample
Notes:
- The code below is a modification of plot_gromov.py from the examples gallery. I computed the GW distance twice, once with POT as-is and once with my corrections implemented in gwggrad_mod and solve_gromov_linesearch_mod.
- The typos did not affect the value of the Gromov-Wasserstein distance in my example, but I wonder whether making sub-optimal choices in the line-search will affect the speed of convergence in more complicated calculations (see the quick comparison after the code sample).
```python
import scipy as sp
import numpy as np
import ot

# Import functions required in ot.gromov._gw
from ot.utils import list_to_array
from ot.optim import cg, solve_1d_linesearch_quad
from ot.backend import get_backend, NumpyBackend
from ot.gromov._utils import init_matrix, gwloss, gwggrad
from ot.gromov._gw import solve_gromov_linesearch

#############################################################################
#
# Sample two Gaussian distributions (2D and 3D)
# ---------------------------------------------
#############################################################################

n_samples = 30  # nb samples

mu_s = np.array([0, 0])
cov_s = np.array([[1, 0], [0, 1]])

mu_t = np.array([4, 4, 4])
cov_t = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])

np.random.seed(0)
xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s)
P = sp.linalg.sqrtm(cov_t)
xt = np.random.randn(n_samples, 3).dot(P) + mu_t

C1 = sp.spatial.distance.cdist(xs, xs)
C2 = sp.spatial.distance.cdist(xt, xt)
C1 /= C1.max()
C2 /= C2.max()

#############################################################################
#
# Parameters for dGW
# ---------------------------------------------
#############################################################################

p = ot.unif(n_samples)
q = ot.unif(n_samples)
G0 = p[:, None] * q[None, :]

loss_fun = 'square_loss'
symmetric = None
log = True
armijo = False
max_iter = 1e4
tol_rel = 1e-9
tol_abs = 1e-9

#############################################################################
#
# gwggrad and solve_gromov_linesearch with typos corrected
# ---------------------------------------------
#############################################################################


def gwggrad_mod(constC, hC1, hC2, T, nx=None):
    if nx is None:
        constC, hC1, hC2, T = list_to_array(constC, hC1, hC2, T)
        nx = get_backend(constC, hC1, hC2, T)
    return constC - 2 * nx.dot(nx.dot(hC1, T), hC2.T)


def solve_gromov_linesearch_mod(G, deltaG, cost_G, constC, C1, C2, M, reg,
                                alpha_min=None, alpha_max=None, nx=None, **kwargs):
    if nx is None:
        G, deltaG, C1, C2, M = list_to_array(G, deltaG, C1, C2, M)

        if isinstance(M, int) or isinstance(M, float):
            nx = get_backend(G, deltaG, C1, C2)
        else:
            nx = get_backend(G, deltaG, C1, C2, M)

    dot_dG = nx.dot(nx.dot(C1, deltaG), C2.T)
    dot_G = nx.dot(nx.dot(C1, G), C2.T)
    a = -2 * reg * nx.sum(dot_dG * deltaG)
    b = nx.sum(M * deltaG) + reg * (nx.sum(constC * deltaG)
                                    - 2 * nx.sum(dot_dG * G)
                                    - 2 * nx.sum(dot_G * deltaG))

    alpha = solve_1d_linesearch_quad(a, b)
    if alpha_min is not None or alpha_max is not None:
        alpha = np.clip(alpha, alpha_min, alpha_max)

    # the new cost is deduced from the line search quadratic function
    cost_G = cost_G + a * (alpha ** 2) + b * alpha

    return alpha, 1, cost_G

#############################################################################
#
# Compute Gromov-Wasserstein with modified functions
# ---------------------------------------------
#############################################################################

# cg for GW is implemented using numpy on CPU
np_ = NumpyBackend()
nx = get_backend(C1, C2, p, q)

p0, q0, C10, C20 = p, q, C1, C2

constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun, np_)

######################################################################
# Define loss function, gradient and linesearch
# ---------------------------------------------
# NOTE: Using modified gwggrad and line_search


def f(G):
    return gwloss(constC, hC1, hC2, G, np_)


def df(G):
    return gwggrad_mod(constC, hC1, hC2, G, np_)


def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
    return solve_gromov_linesearch_mod(G, deltaG, cost_G, constC, C1, C2,
                                       M=0., reg=1., nx=np_, **kwargs)
######################################################################

res_mod, log_mod = cg(p, q, 0., 1., f, df, G0, line_search, log=True,
                      numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs)

log_mod['gw_dist'] = nx.from_numpy(log_mod['loss'][-1], type_as=C1)
log_mod['u'] = nx.from_numpy(log_mod['u'], type_as=C1)
log_mod['v'] = nx.from_numpy(log_mod['v'], type_as=C1)
gw_mod = nx.from_numpy(res_mod, type_as=C1)

# Compute GW with the original function
gw0, log0 = ot.gromov.gromov_wasserstein(
    C1, C2, p, q, 'square_loss', verbose=True, log=True)

#############################################################################
#
# Compare gwggrad and solve_gromov_linesearch with their modified versions
# ---------------------------------------------
#############################################################################

G = G0
deltaG = np.random.rand(*G.shape)
cost_G = 0

grad_mod = gwggrad_mod(constC, hC1, hC2, G, np_)
grad = gwggrad(constC, hC1, hC2, G, np_)

linesearch_mod = solve_gromov_linesearch_mod(G, deltaG, cost_G, constC, C1, C2,
                                             M=0., reg=1., nx=np_)
linesearch = solve_gromov_linesearch(G, deltaG, cost_G, C1, C2,
                                     M=0., reg=1., nx=np_)

print()
print(f"dGW with func: {log0['gw_dist']}")
print(f"dGW with mods: {log_mod['gw_dist']}")
print("GW-distances agree:", log0['gw_dist'] == log_mod['gw_dist'])
print()
print('Gradients agree:', np.array_equal(grad_mod, grad))
print('Line-search results agree:', linesearch_mod == linesearch)
```
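As a quick way to compare the convergence behaviour of the two variants (and not just the final value), one could look at the loss trajectories recorded by the solver. This is a sketch that assumes the script above has just been run and that both log dictionaries expose the 'loss' list filled in by ot.optim.cg:

```python
# Compare convergence of the two variants, assuming the script above was just run.
import numpy as np

loss_orig = np.asarray(log0['loss']).ravel()    # original POT gwggrad / line-search
loss_mod = np.asarray(log_mod['loss']).ravel()  # modified versions

print("iterations (original POT functions):", len(loss_orig))
print("iterations (modified functions):    ", len(loss_mod))
print("final losses:", loss_orig[-1], loss_mod[-1])
```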
Expected behavior
The functions gwggrad and solve_gromov_linesearch should return the same results as gwggrad_mod and solve_gromov_linesearch_mod, respectively.
Environment
Output of the following code snippet:
```python
import platform; print(platform.platform())
import sys; print("Python", sys.version)
import numpy; print("NumPy", numpy.__version__)
import scipy; print("SciPy", scipy.__version__)
import ot; print("POT", ot.__version__)
```
Linux-6.5.7-100.fc37.x86_64-x86_64-with-glibc2.36
Python 3.11.5 (main, Sep 11 2023, 13:54:46) [GCC 11.2.0]
NumPy 1.24.3
SciPy 1.11.1
POT 0.9.1
Hello @mr-gomez,
Thank you for your detailed issue. I converted it into a discussion so that it will remain easily visible in POT.
Actually, there are small typing errors in the papers that you mentioned. These errors were corrected in the POT implementation by the authors/contributors. I also had to make some modifications to be able to operate on asymmetric matrices.
In order to have a clear reference with cleaned-up equations, I started a note on my website that you can find here: Note.
Conclusion: there is no mistake in the POT implementation for the gradient computation or the exact line-search. We will soon integrate an exact line-search for the KL inner loss as well.
Best,
Cédric
Hi Cédric.
Thanks for the clarifications and for posting such a detailed document. In fact, the comment in ot.gromov._utils.gwggrad regarding the typo in Peyré et al. is what prompted me to check the calculations on my own. I need to understand these functions to adapt them to another project, so I decided to open this report after my calculations didn't agree with either the paper or POT.
Regarding your computations, I now agree that the current implementation of the exact line-search step is correct. However, I'm still not sure about the gradient. My issue is that, since

- $c_{\mathbf{C}, \overline{\mathbf{C}}} = f_1(\mathbf{C}) \mathbf{p} \mathbf{1}_m^\top + \mathbf{1}_n \mathbf{q}^\top f_2(\overline{\mathbf{C}})^\top$ is also constant w.r.t. $\mathbf{T}$,
- $\mathcal{L}(\mathbf{C}, \overline{\mathbf{C}}) \otimes \mathbf{T} = c_{\mathbf{C}, \overline{\mathbf{C}}} - h_1(\mathbf{C}) \mathbf{T} h_2(\overline{\mathbf{C}})^\top$, and
- $\mathcal{E}_L^{GW}(\mathbf{C}, \overline{\mathbf{C}}, \mathbf{T}) = \langle \mathcal{L}(\mathbf{C}, \overline{\mathbf{C}}) \otimes \mathbf{T}, \mathbf{T} \rangle = \langle c_{\mathbf{C}, \overline{\mathbf{C}}}, \mathbf{T} \rangle - \langle h_1(\mathbf{C}) \mathbf{T} h_2(\overline{\mathbf{C}})^\top, \mathbf{T} \rangle$,

I find a gradient that differs from the one in POT. (From here on I write $\mathbf{D}$ for $\overline{\mathbf{C}}$ so I don't have to type \overline consistently.) We get:
- $\nabla_{\mathbf{T}} \langle c_{\mathbf{C},\mathbf{D}}, \mathbf{T} \rangle = c_{\mathbf{C},\mathbf{D}}$, because $c_{\mathbf{C},\mathbf{D}}$ is constant w.r.t. $\mathbf{T}$.
- $\frac{\partial}{\partial T_{pq}} \langle h_1(\mathbf{C}) \mathbf{T} h_2(\mathbf{D})^\top, \mathbf{T} \rangle = \frac{\partial}{\partial T_{pq}} \sum_{ijkl} h_1(C_{ik}) h_2(D_{jl}) T_{ij} T_{kl} = \sum_{kl} h_1(C_{pk}) h_2(D_{ql}) T_{kl} + \sum_{ij} h_1(C_{ip}) h_2(D_{jq}) T_{ij}$. These sums are the $(p,q)$ entries of $h_1(\mathbf{C}) \mathbf{T} h_2(\mathbf{D})^\top$ and $h_1(\mathbf{C})^\top \mathbf{T} h_2(\mathbf{D})$, respectively. Then $\nabla_{\mathbf{T}} \langle h_1(\mathbf{C}) \mathbf{T} h_2(\mathbf{D})^\top, \mathbf{T} \rangle = h_1(\mathbf{C}) \mathbf{T} h_2(\mathbf{D})^\top + h_1(\mathbf{C})^\top \mathbf{T} h_2(\mathbf{D})$.
Thus, I get that the gradient of $\mathcal{E}_L^{GW}$ is

$$c_{\mathbf{C},\mathbf{D}} - h_1(\mathbf{C}) \mathbf{T} h_2(\mathbf{D})^\top - h_1(\mathbf{C})^\top \mathbf{T} h_2(\mathbf{D})$$

in the general case, and

$$c_{\mathbf{C},\mathbf{D}} - 2\, h_1(\mathbf{C}) \mathbf{T} h_2(\mathbf{D})^\top$$

in the symmetric case. In both cases, I'm missing a factor of 2 in front of $c_{\mathbf{C},\mathbf{D}}$ compared to POT's gwggrad; this is what gwggrad_mod implements.
So, am I right to treat $c_{\mathbf{C},\mathbf{D}}$ as a constant w.r.t. $\mathbf{T}$, or should it be differentiated as well?
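For reference, here is a small self-contained finite-difference check of the second bullet above (random matrices standing in for $h_1(\mathbf{C})$ and $h_2(\mathbf{D})$; just a sanity check of the hand computation, not POT code):

```python
import numpy as np

rng = np.random.default_rng(0)
n, m = 5, 6
H1 = rng.random((n, n))   # stands in for h1(C)
H2 = rng.random((m, m))   # stands in for h2(D)
T = rng.random((n, m))

def bilinear(T):
    # <h1(C) T h2(D)^T, T>
    return np.sum((H1 @ T @ H2.T) * T)

# closed form from the hand computation above
grad_closed = H1 @ T @ H2.T + H1.T @ T @ H2

# central finite differences, entry by entry
eps, grad_fd = 1e-6, np.zeros_like(T)
for p in range(n):
    for q in range(m):
        E = np.zeros_like(T)
        E[p, q] = eps
        grad_fd[p, q] = (bilinear(T + E) - bilinear(T - E)) / (2 * eps)

print(np.allclose(grad_closed, grad_fd, atol=1e-5))  # True
```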
Best,
Mario
Hello Mario,
Indeed, roughly speaking it is a mistake to consider those as constants w.r.t. $\mathbf{T}$: the decomposition that produces $c_{\mathbf{C},\mathbf{D}}$ already uses the marginal relations $\mathbf{p} = \mathbf{T} \mathbf{1}_m$ and $\mathbf{q} = \mathbf{T}^\top \mathbf{1}_n$, so $\mathbf{T}$ is hidden inside that term.
Briefly, the theoretical justification that you are looking for is essentially the definition of the gradient of the GW objective as a function of $\mathbf{T}$: every occurrence of $\mathbf{T}$ in the objective, including the occurrences absorbed into the marginals, has to be differentiated.
(I'm struggling to get this clean in markdown.) You can find such a relation in Equation 6 of Rémi's course here. You can also find other courses, e.g. on constrained optimization, which could be useful to get the bigger picture.
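To make the hidden dependence on $\mathbf{T}$ concrete, here is a small numerical sketch (made-up symmetric matrices, uniform marginals, square loss; not POT internals): once $\mathbf{p}$ and $\mathbf{q}$ are replaced by the marginals of $\mathbf{T}$, the part of the objective built from $c_{\mathbf{C},\mathbf{D}}$ contributes $2\,c_{\mathbf{C},\mathbf{D}}$ to the gradient, not $c_{\mathbf{C},\mathbf{D}}$:

```python
import numpy as np

rng = np.random.default_rng(1)
n, m = 5, 6

# made-up symmetric "distance" matrices and uniform marginals
C = rng.random((n, n)); C = C + C.T
D = rng.random((m, m)); D = D + D.T
f1C, f2D = C**2, D**2                        # square loss: f1(a) = a^2, f2(b) = b^2

p, q = np.full(n, 1 / n), np.full(m, 1 / m)
T = p[:, None] * q[None, :]                  # a feasible coupling

# c_{C,D} = f1(C) p 1_m^T + 1_n q^T f2(D)^T
c = (f1C @ p)[:, None] + (f2D @ q)[None, :]

def const_part(T):
    # <c_{C,D}, T> with p and q replaced by the marginals of T
    pT, qT = T.sum(axis=1), T.sum(axis=0)
    return pT @ f1C @ pT + qT @ f2D @ qT

# central finite differences of the "constant" part at the feasible T
eps, grad = 1e-6, np.zeros_like(T)
for a in range(n):
    for b in range(m):
        E = np.zeros_like(T); E[a, b] = eps
        grad[a, b] = (const_part(T + E) - const_part(T - E)) / (2 * eps)

print(np.allclose(grad, 2 * c, atol=1e-5))   # True: the gradient is 2*c, not c
```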
Hope it helps,
Cédric.
Hi Cédric.
I understand that the gradient has to account for every dependence of the objective on $\mathbf{T}$.
Having said this, I want to clarify that the discrepancy in our calculations is caused by using slightly different representations of the objective $\mathcal{E}_L^{GW}$:
- In the notes you linked, you are taking $\dfrac{\partial}{\partial T_{pq}}$ of
  $$\sum_{ijkl} \left[ f_1(C_{ik}) + f_2(D_{jl}) - h_1(C_{ik}) h_2(D_{jl}) \right] T_{ij}T_{kl}. \tag{1}$$
- With some work, we can show that $\langle c_{\mathbf{C}, \mathbf{D}}, \mathbf{T} \rangle - \langle h_1(\mathbf{C}) \mathbf{T} h_2(\mathbf{D})^\top, \mathbf{T} \rangle$ equals
  $$\sum_{ijk} f_1(C_{ik}) T_{ij}p_{k} + \sum_{ijl} f_2(D_{jl}) T_{ij}q_{l} - \sum_{ijkl} h_1(C_{ik}) h_2(D_{jl}) T_{ij}T_{kl}. \tag{2}$$
Note that the two expressions agree on the feasible set, because we can expand $p_{k} = \sum_{l} T_{kl}$ and $q_{l} = \sum_{k} T_{kl}$ in (2) to recover (1).
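Writing the substitution out explicitly:

$$\sum_{ijk} f_1(C_{ik})\, T_{ij} \Big(\sum_l T_{kl}\Big) + \sum_{ijl} f_2(D_{jl})\, T_{ij} \Big(\sum_k T_{kl}\Big) - \sum_{ijkl} h_1(C_{ik})\, h_2(D_{jl})\, T_{ij} T_{kl} = \sum_{ijkl} \left[ f_1(C_{ik}) + f_2(D_{jl}) - h_1(C_{ik})\, h_2(D_{jl}) \right] T_{ij} T_{kl},$$

which is exactly (1).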
However, if we start from (1), we have to use the product rule inside all three sums when taking $\dfrac{\partial}{\partial T_{pq}}$, whereas in (2) only the last sum is quadratic in $\mathbf{T}$ once $\mathbf{p}$ and $\mathbf{q}$ are treated as constants.
Hence, my question becomes: which expression, (1) or (2), should be used to compute the gradient that the solver needs?
Thanks,
Mario
Hello Mario,
Sorry if I was not explicit enough; my previous answer was indeed not exhaustive. We need to be more rigorous here, because I believe that our two solvers are correct, but not quite for the reason you mentioned ;)
1. Solving the original GW problem. The analysis in my note aims at solving equation (1), considering the objective function $\mathcal{E}_L^{GW}(\mathbf{C}, \mathbf{D}, \mathbf{T}) = \sum_{ijkl} L(C_{ik}, D_{jl})\, T_{ij} T_{kl}$ with $L(a, b) = f_1(a) + f_2(b) - h_1(a)\, h_2(b)$, which is quadratic in $\mathbf{T}$.
Considering my previous message, you can express the GW cost, i.e. our objective function, as $\langle \mathcal{L}(\mathbf{C}, \mathbf{D}) \otimes \mathbf{T}, \mathbf{T} \rangle = \langle c_{\mathbf{C}, \mathbf{D}}, \mathbf{T} \rangle - \langle h_1(\mathbf{C})\, \mathbf{T}\, h_2(\mathbf{D})^\top, \mathbf{T} \rangle$, but this identity only holds for couplings $\mathbf{T}$ that satisfy the marginal constraints, since it is obtained by using $\mathbf{T} \mathbf{1}_m = \mathbf{p}$ and $\mathbf{T}^\top \mathbf{1}_n = \mathbf{q}$.
However, the gradient that the solver needs is the gradient of this quadratic function of $\mathbf{T}$, in which every occurrence of $\mathbf{T}$, including the ones absorbed into $\mathbf{p}$ and $\mathbf{q}$ when writing $c_{\mathbf{C},\mathbf{D}}$, must be differentiated.
Saying that $c_{\mathbf{C},\mathbf{D}}$ is constant w.r.t. $\mathbf{T}$ amounts to differentiating a different function, one where $\mathbf{p}$ and $\mathbf{q}$ have been frozen; it coincides with the GW objective on the feasible set but does not have the same gradient there.
I hope it is clear because I don't see how to put it more clearly. Maybe the following exaggerated analogy can help: consider $g(x) = x^2$ and write it as $g(x) = c\,x$ with $c = x$; if you treat $c$ as a constant when differentiating, you get $g'(x) = x$ instead of the correct $2x$.
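Concretely, differentiating (1) directly, with the product rule applied to both copies of $\mathbf{T}$, gives

$$\frac{\partial}{\partial T_{ab}} \sum_{ijkl} L(C_{ik}, D_{jl})\, T_{ij} T_{kl} = \sum_{kl} L(C_{ak}, D_{bl})\, T_{kl} + \sum_{ij} L(C_{ia}, D_{jb})\, T_{ij},$$

which for symmetric $\mathbf{C}$, $\mathbf{D}$ and feasible $\mathbf{T}$ equals $2 \big( \mathcal{L}(\mathbf{C},\mathbf{D}) \otimes \mathbf{T} \big)_{ab} = 2 \big( c_{\mathbf{C},\mathbf{D}} - h_1(\mathbf{C})\, \mathbf{T}\, h_2(\mathbf{D})^\top \big)_{ab}$: the factor of 2 in front of $c_{\mathbf{C},\mathbf{D}}$ appears automatically.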
2. Solving an equivalent problem to the GW problem. I believe that the algorithm you implemented actually comes down to solving a problem equivalent to the original one that I detailed in the note and that is implemented in POT. The true reason is that we are not talking about the same objective functions.
We have the following equivalent formulations of the GW problem (1): over the transport polytope we can equivalently minimize the objective (2) you wrote, i.e. $\langle c_{\mathbf{C},\mathbf{D}}, \mathbf{T} \rangle - \langle h_1(\mathbf{C})\, \mathbf{T}\, h_2(\mathbf{D})^\top, \mathbf{T} \rangle$ with $\mathbf{p}$ and $\mathbf{q}$ kept fixed inside $c_{\mathbf{C},\mathbf{D}}$, or the reduced objective
$$- \langle h_1(\mathbf{C})\, \mathbf{T}\, h_2(\mathbf{D})^\top, \mathbf{T} \rangle, \tag{3}$$
obtained by dropping $\langle c_{\mathbf{C},\mathbf{D}}, \mathbf{T} \rangle$, which is constant on the feasible set. All three problems have the same minimizers, but not the same gradients.
As you mentioned, we have $\mathcal{E}_L^{GW}(\mathbf{C},\mathbf{D},\mathbf{T}) = \langle c_{\mathbf{C},\mathbf{D}}, \mathbf{T} \rangle - \langle h_1(\mathbf{C})\, \mathbf{T}\, h_2(\mathbf{D})^\top, \mathbf{T} \rangle$ for every feasible $\mathbf{T}$, so minimizing (2), with $\mathbf{p}$ and $\mathbf{q}$ treated as fixed data, yields the same minimizers as (1).
This objective function does indeed match your computation: for it, your gradient would be right, but we were not talking about exactly the same objective.
Notice that we could also go further by considering the other equivalent problem (3), where the term $\langle c_{\mathbf{C},\mathbf{D}}, \mathbf{T} \rangle$ is dropped altogether, since it is constant on the feasible set.
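A quick numerical illustration that $\langle c_{\mathbf{C},\mathbf{D}}, \mathbf{T} \rangle$ only depends on the marginals of $\mathbf{T}$, so it takes the same value on every feasible coupling (again a self-contained sketch with made-up matrices, not POT internals):

```python
import numpy as np

rng = np.random.default_rng(1)
n = 20

# made-up symmetric cost matrices and uniform marginals
C = rng.random((n, n)); C = C + C.T
D = rng.random((n, n)); D = D + D.T
f1C, f2D = C**2, D**2                 # square loss: f1(a) = a^2, f2(b) = b^2
p = q = np.full(n, 1.0 / n)

# c_{C,D} = f1(C) p 1^T + 1 q^T f2(D)^T
c = (f1C @ p)[:, None] + (f2D @ q)[None, :]

# two different couplings with the same (uniform) marginals
T_prod = p[:, None] * q[None, :]      # independent coupling
T_perm = np.eye(n) / n                # permutation-like coupling

# <c, T> takes the same value on both, so dropping it does not change the minimizers
print(np.sum(c * T_prod), np.sum(c * T_perm))
```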
Best,
Cédric
Hi Cédric,
I understand the difference now. Indeed, I thought that the gradient of the original objective (1) had to coincide with the gradient of the equivalent formulation (2), but they differ because (2) treats $\mathbf{p}$ and $\mathbf{q}$ as fixed inside $c_{\mathbf{C},\mathbf{D}}$.
Thanks for your help!
Mario