semirelaxed Gromov-Wasserstein initial transport plan #707
-
Describe the bug
I’m facing an issue where the results of the semirelaxed Gromov-Wasserstein (srGW) solver differ significantly depending on the initial transport plan G0. I compared three initializations: the OT plan returned by Gromov-Wasserstein, the default (G0=None), and a matrix I built myself as the product of the uniform distributions on the source and target spaces, which is what the documentation describes for G0=None. These initial matrices appear to be quite similar, yet the resulting srGW distances differ. Can you explain why this discrepancy occurs?
Thanks a lot for your help!
To Reproduce
See code sample
Expected behavior
With the same initialization, the function should yield the same result.
Environment (please complete the following information):
- Linux
- Python version: 3.11.6
- How was POT installed: pip
Linux-6.5.0-44-generic-x86_64-with-glibc2.38
Python 3.11.6 (main, Apr 10 2024, 17:26:07) [GCC 13.2.0]
NumPy 2.1.3
SciPy 1.14.1
POT 0.9.5
Code sample
import numpy as np
from ot.gromov import (
    semirelaxed_gromov_wasserstein,
    semirelaxed_fused_gromov_wasserstein,
    gromov_wasserstein,
    fused_gromov_wasserstein,
)
import networkx
from networkx.generators.community import stochastic_block_model as sbm

# 1st graph
N2 = 30  # 30 nodes split into 3 communities of 10 nodes each
p2 = [[1.0, 0.1, 0.0], [0.1, 0.95, 0.1], [0.0, 0.1, 0.9]]
G2 = sbm(seed=0, sizes=[N2 // 3, N2 // 3, N2 // 3], p=p2)
C2 = networkx.to_numpy_array(G2)

# 2nd graph
C3 = np.identity(4)  # adjacency matrix

# Uniform distributions
h2 = np.ones(C2.shape[0]) / C2.shape[0]
h3 = np.ones(C3.shape[0]) / C3.shape[0]

# 0) GW(C2, h2, C3, h3) for reference
OT, log = gromov_wasserstein(C2, C3, h2, h3, symmetric=True, log=True)
gw = log["gw_dist"]

# 1) srGW(C2, h2, C3) initialized with the GW plan
OT_23, log_23 = semirelaxed_gromov_wasserstein(
    C2, C3, h2, symmetric=True, log=True, G0=OT
)
srgw_23 = log_23["srgw_dist"]

# 2) srGW(C2, h2, C3) with the default initialization (G0=None)
OT_23_default, log_23_default = semirelaxed_gromov_wasserstein(
    C2, C3, h2, symmetric=True, log=True, G0=None
)
srgw_23_default = log_23_default["srgw_dist"]

# 3) srGW(C2, h2, C3) with the product of marginals built explicitly
initial = np.outer(h2, h3)  # same values as the documented default for G0=None
OT_23_initial, log_23_initial = semirelaxed_gromov_wasserstein(
    C2, C3, h2, symmetric=True, log=True, G0=initial
)
srgw_23_initial = log_23_initial["srgw_dist"]

print("GW(C2, C3) = ", gw)  # GW(C2, C3) = 0.41222222222222227
print("G0 = OT, srGW(C2, h2, C3) = ", srgw_23)  # G0 = OT, srGW(C2, h2, C3) = 0.1377777777777782
print("G0 = default, srGW(C2, h2, C3) = ", srgw_23_default)  # G0 = default, srGW(C2, h2, C3) = 0.41222222222222227
print("G0 = same as default according to doc, srGW(C2, h2, C3) = ", srgw_23_initial)  # G0 = same as default according to doc, srGW(C2, h2, C3) = 0.41222222222222227
-
Hello @KetsiaGuichard,
Thank you for your feedback. Overall, the behaviour you observe comes from the solver itself and not from POT's implementation, which is correct.
Indeed, the srGW conditional gradient solver is quite sensitive to its initialisation: like GW, the problem is non-convex, and because the constraint on the target marginal is removed, the solver searches for solutions in a larger space. These aspects are partly discussed in the original srGW paper, where we addressed this task of matching to an identity matrix. In this specific case, the classical product of marginals is a local optimum where the solver gets stuck (see Section 7.5 in the Supplementary). This is not necessarily the case if the target structure is more informative, which is why we kept the product of marginals as the default in POT.
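To make the "stuck at the product of marginals" point concrete, here is a minimal NumPy sketch. It assumes the square loss and symmetric structure matrices; `srgw_gradient` is an illustrative helper written for this example, not a POT function, and the random source matrix is a stand-in for any symmetric graph. With an identity target, every row of the gradient at the product plan is constant across columns, so the linear minimisation step of conditional gradient finds no descent direction.

```python
import numpy as np


def srgw_gradient(C1, C2, T):
    """Gradient of sum_{ijkl} (C1[i,k] - C2[j,l])**2 * T[i,j] * T[k,l].

    Illustrative helper (not POT's API); valid for symmetric C1 and C2.
    """
    p = T.sum(axis=1)  # source marginal, fixed by the srGW constraint
    q = T.sum(axis=0)  # target marginal, free in srGW
    term_src = (C1 ** 2) @ p  # depends on the row index i only
    term_tgt = (C2 ** 2) @ q  # depends on the column index j only
    cross = C1 @ T @ C2.T
    return 2.0 * (term_src[:, None] + term_tgt[None, :] - 2.0 * cross)


rng = np.random.default_rng(0)
n, m = 12, 4
A = rng.random((n, n))
C1 = (A + A.T) / 2  # arbitrary symmetric source structure (assumption)
C2 = np.identity(m)  # identity target, as in the question
h1 = np.full(n, 1.0 / n)
h2 = np.full(m, 1.0 / m)

T0 = np.outer(h1, h2)  # product of marginals
grad = srgw_gradient(C1, C2, T0)

# Spread of the gradient across columns, row by row.
row_spread = grad.max(axis=1) - grad.min(axis=1)

# Conditional gradient step: each row i sends its mass h1[i] to a cheapest column.
fw_vertex_cost = (h1 * grad.min(axis=1)).sum()
fw_gap = (grad * T0).sum() - fw_vertex_cost

print("max row spread:", row_spread.max())
print("Frank-Wolfe gap:", fw_gap)
```

Both printed quantities vanish up to floating-point error, so the product plan is a stationary point for this target. A plan that already breaks the symmetry between target nodes, such as the GW plan, does not have this property.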
Unfortunately, good initialisations appear, for now, to be quite data dependent. For tasks related to graph partitioning, we implemented various initialisation strategies in ot.gromov.semirelaxed_init_plan. They can be selected directly from the ot.gromov.semirelaxed_gromov_wasserstein solver and related ones by setting G0 to a string taking its value in ["product", "random_product", "random", "fluid", "fluid_soft", "spectral", "spectral_soft", "kmeans", "kmeans_soft"].
Hope it helps.
Best,
Cédric
-
Hi @cedricvincentcuaz,
Thank you for your detailed explanation. I apologize—I must have overlooked this aspect in the original srGW paper. I’ll revisit it (especially Section 7.5) to better understand the initialization strategies you’ve highlighted.
Your clarification regarding the solver's sensitivity and the various initialization options in POT is greatly appreciated!
Best,
Ketsia