Commit 3a53dff

Authored by KrzakalaPaul, PaulKrzakala, rflamary, cedricvincentcuaz

Batch OT losses (Sinkhorn + Gromov) (#755)

* linear ot implemented
* improve stopping criterion and asymmetric case
* Add recompute_const and simplify the pipeline for the symmetric = False
* add tests
* update the examples and rename to follow the "ot.solve" naming conventions
* update releases.md
* idem
* move set_grad_enabled to backend
* set_grad_enabled for quadratic solver
* update doc
* remove useless import in doc
* Update references
* update example
* Remove classes in quadratic, move examples to backend, add potentials, remove context managers for grads. To do: improve doc and tests
* update tests
* Massive improvement of the documentation for ot.batch
* cover (almost) all ot.batch with tests
* bug in the tests
* update docstring
* highlight that ot.batch is solving the entropic version
* removing yet another error in the docstring
* Add missing parameter recompute_const
* Remove png, add all backends and gradient mode to tests
* add the missing pytest
* change .sum() into nx.sum
* add missing backend
* yet another missing nx
* remove useless squeeze and add test for non-log bregman
* remove last_step from quadratic tests
* add missing tests and improve documentation
* proper unsqueeze test
* add unsqueeze to tensorflow
* solve double backprop issue in test_gradients_torch

---------

Co-authored-by: PaulKrzakala <paul.krzakala@gmail.com>
Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
Co-authored-by: Cédric Vincent-Cuaz <cedvincentcuaz@gmail.com>
1 parent 803d2ab commit 3a53dff

17 files changed: +1940 -22 lines changed

‎.gitignore‎

Lines changed: 2 additions & 0 deletions
@@ -123,3 +123,5 @@ debug
 
 # pytest cahche
 .pytest_cache
+
+docs/source/

‎README.md‎

Lines changed: 5 additions & 0 deletions
@@ -446,3 +446,8 @@ Artificial Intelligence.
 [79] Liu, X., Bai, Y., Martín, R. D., Shi, K., Shahbazi, A., Landman, B. A., Chang, C., & Kolouri, S. (2025). [Linear Spherical Sliced Optimal Transport: A Fast Metric for Comparing Spherical Data](https://openreview.net/forum?id=fgUFZAxywx). International Conference on Learning Representations.
 
 [80] Altschuler, J., Bach, F., Rudi, A., Niles-Weed, J., [Massively scalable Sinkhorn distances via the Nyström method](https://proceedings.neurips.cc/paper_files/paper/2019/file/f55cadb97eaff2ba1980e001b0bd9842-Paper.pdf), Advances in Neural Information Processing Systems, 2019.
+
+[81] Xu, H., Luo, D., & Carin, L. (2019). [Scalable Gromov-Wasserstein learning for graph partitioning and matching](https://proceedings.neurips.cc/paper/2019/hash/6e62a992c676f611616097dbea8ea030-Abstract.html). Neural Information Processing Systems (NeurIPS).
+
+
+```

‎RELEASES.md‎

Lines changed: 1 addition & 0 deletions
@@ -26,6 +26,7 @@
 - Fix reg_div function compatibility with numpy in `ot.unbalanced.lbfgsb_unbalanced` via new function `ot.utils.fun_to_numpy` (PR #731)
 - Added to each example in the examples gallery the information about the release version in which it was introduced (PR #743)
 - Removed release information from quickstart guide (PR #744)
+- Implement batch parallel solvers in ot.batch (PR #745)
 - Update REAMDE with new API and reorganize examples (PR #754)
 
 #### Closed issues

‎docs/source/all.rst‎

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@ API and modules
 
 
 backend
+batch
 bregman
 coot
 da

‎docs/source/user_guide.rst‎

Lines changed: 3 additions & 0 deletions
@@ -1217,3 +1217,6 @@ References
 couplings <http://proceedings.mlr.press/v89/forrow19a/forrow19a.pdf>`_. In
 The 22nd International Conference on Artificial Intelligence and Statistics
 (pp. 2454-2465). PMLR.
+
+.. [41] Xu, H., Luo, D., & Carin, L. (2019). `Scalable Gromov-Wasserstein learning for graph partitioning and matching
+<https://arxiv.org/abs/1906.03666>`_\ , Advances in neural information processing systems, 32.

‎examples/backends/plot_ot_batch.py‎

Lines changed: 227 additions & 0 deletions
@@ -0,0 +1,227 @@
+"""
+=================================================
+Solving Many Optimal Transport Problems in Parallel
+=================================================
+
+In some situations, one may want to solve many OT problems with the same
+structure (same number of samples, same cost function, etc.) at the same time.
+
+In that case using a for loop to solve the problems sequentially is inefficient.
+This example shows how to use the batch solvers implemented in POT to solve
+many problems in parallel on CPU or GPU (even more efficient on GPU).
+
+"""
+
+# Author: Paul Krzakala <paul.krzakala@gmail.com>
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 1
+
+
+#############################################################################
+#
+# Computing the Cost Matrices
+# ---------------------------------------------
+#
+# We want to create a batch of optimal transport problems with
+# :math:`n` samples in :math:`d` dimensions.
+#
+# To do this, we first need to compute the cost matrices for each problem.
+#
+# .. note::
+#     A straightforward approach would be to use a Python loop and
+#     :func:`ot.dist`.
+#     However, this is inefficient when working with batches.
+#
+# Instead, you can directly use :func:`ot.batch.dist_batch`, which computes
+# all cost matrices in parallel.
+
+import ot
+import numpy as np
+
+n_problems = 4  # nb problems/batch size
+n_samples = 8  # nb samples
+dim = 2  # nb dimensions
+
+np.random.seed(0)
+samples_source = np.random.randn(n_problems, n_samples, dim)
+samples_target = samples_source + 0.1 * np.random.randn(n_problems, n_samples, dim)
+
+# Naive approach
+M_list = []
+for i in range(n_problems):
+    M_list.append(
+        ot.dist(samples_source[i], samples_target[i])
+    )  # List of cost matrices n_samples x n_samples
+# Batched approach
+M_batch = ot.batch.dist_batch(
+    samples_source, samples_target
+)  # Array of cost matrices n_problems x n_samples x n_samples
+
+for i in range(n_problems):
+    assert np.allclose(M_list[i], M_batch[i])
+
+#############################################################################
+#
+# Solving the Problems
+# ---------------------------------------------
+#
+# Once the cost matrices are computed, we can solve the corresponding
+# optimal transport problems.
+#
+# .. note::
+#     One option is to solve them sequentially with a Python loop using
+#     :func:`ot.solve`.
+#     This is simple but inefficient for large batches.
+#
+# Instead, you can use :func:`ot.batch.solve_batch`, which solves all
+# problems in parallel.
+
+reg = 1.0
+max_iter = 100
+tol = 1e-3
+
+# Naive approach
+results_values_list = []
+for i in range(n_problems):
+    res = ot.solve(M_list[i], reg=reg, max_iter=max_iter, tol=tol, reg_type="entropy")
+    results_values_list.append(res.value_linear)
+
+# Batched approach
+results_batch = ot.batch.solve_batch(
+    M=M_batch, reg=reg, max_iter=max_iter, tol=tol, reg_type="entropy"
+)
+results_values_batch = results_batch.value_linear
+
+assert np.allclose(np.array(results_values_list), results_values_batch, atol=tol * 10)
+
+#############################################################################
+#
+# Comparing Computation Time
+# ---------------------------------------------
+#
+# We now compare the runtime of the two approaches on larger problems.
+#
+# .. note::
+#     The speedup obtained with :mod:`ot.batch` can be even more
+#     significant when computations are performed on a GPU.
+
+
+from time import perf_counter
+
+n_problems = 128
+n_samples = 8
+dim = 2
+reg = 10.0
+max_iter = 1000
+tol = 1e-3
+
+samples_source = np.random.randn(n_problems, n_samples, dim)
+samples_target = samples_source + 0.1 * np.random.randn(n_problems, n_samples, dim)
+
+
+def benchmark_naive(samples_source, samples_target):
+    start = perf_counter()
+    for i in range(n_problems):
+        M = ot.dist(samples_source[i], samples_target[i])
+        res = ot.solve(M, reg=reg, max_iter=max_iter, tol=tol, reg_type="entropy")
+    end = perf_counter()
+    return end - start
+
+
+def benchmark_batch(samples_source, samples_target):
+    start = perf_counter()
+    M_batch = ot.batch.dist_batch(samples_source, samples_target)
+    res_batch = ot.batch.solve_batch(
+        M=M_batch, reg=reg, max_iter=max_iter, tol=tol, reg_type="entropy"
+    )
+    end = perf_counter()
+    return end - start
+
+
+time_naive = benchmark_naive(samples_source, samples_target)
+time_batch = benchmark_batch(samples_source, samples_target)
+
+print(f"Naive approach time: {time_naive:.4f} seconds")
+print(f"Batched approach time: {time_batch:.4f} seconds")
+
+#############################################################################
+#
+# Gromov-Wasserstein
+# ---------------------------------------------
+#
+# The :mod:`ot.batch` module also provides a batched Gromov-Wasserstein solver.
+#
+# .. note::
+#     This solver is **not** equivalent to calling :func:`ot.solve_gromov`
+#     repeatedly in a loop.
+#
+# Key differences:
+#
+# - :func:`ot.solve_gromov`
+#   Uses the conditional gradient algorithm. Each inner iteration relies on
+#   an exact EMD solver.
+#
+# - :func:`ot.batch.solve_gromov_batch`
+#   Uses a proximal variant, where each inner iteration applies entropic
+#   regularization.
+#
+# As a result:
+#
+# - :func:`ot.solve_gromov` is usually faster on CPU
+# - :func:`ot.batch.solve_gromov_batch` is slower on CPU, but provides
+#   better objective values.
+#
+# .. tip::
+#     If your data is on a GPU, :func:`ot.batch.solve_gromov_batch`
+#     is significantly faster AND provides better objective values.
+
+from ot import solve_gromov
+from ot.batch import solve_gromov_batch
+
+
+def benchmark_naive_gw(samples_source, samples_target):
+    start = perf_counter()
+    avg_value = 0
+    for i in range(n_problems):
+        C1 = ot.dist(samples_source[i], samples_source[i])
+        C2 = ot.dist(samples_target[i], samples_target[i])
+        res = solve_gromov(C1, C2, max_iter=1000, tol=tol)
+        avg_value += res.value
+    avg_value /= n_problems
+    end = perf_counter()
+    return end - start, avg_value
+
+
+def benchmark_batch_gw(samples_source, samples_target):
+    start = perf_counter()
+    C1_batch = ot.batch.dist_batch(samples_source, samples_source)
+    C2_batch = ot.batch.dist_batch(samples_target, samples_target)
+    res_batch = solve_gromov_batch(
+        C1_batch, C2_batch, reg=1, max_iter=100, max_iter_inner=50, tol=tol
+    )
+    avg_value = np.mean(res_batch.value)
+    end = perf_counter()
+    return end - start, avg_value
+
+
+time_naive_gw, avg_value_naive_gw = benchmark_naive_gw(samples_source, samples_target)
+time_batch_gw, avg_value_batch_gw = benchmark_batch_gw(samples_source, samples_target)
+
+print(f"{'Method':<20}{'Time (s)':<15}{'Avg Value':<15}")
+print(f"{'Naive GW':<20}{time_naive_gw:<15.4f}{avg_value_naive_gw:<15.4f}")
+print(f"{'Batched GW':<20}{time_batch_gw:<15.4f}{avg_value_batch_gw:<15.4f}")
+
+#############################################################################
+#
+# In summary: no more for loops!
+# ---------------------------------------------
+
+import matplotlib.pyplot as plt
+
+fig, ax = plt.subplots(figsize=(4, 4))
+ax.text(0.5, 0.5, "For", fontsize=160, ha="center", va="center", zorder=0)
+ax.axis("off")
+ax.plot([0, 1], [0, 1], color="red", linewidth=10, zorder=1)
+ax.plot([0, 1], [1, 0], color="red", linewidth=10, zorder=1)
+plt.show()
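
Reviewer note: the "loop versus batch" idea in `plot_ot_batch.py` can be illustrated without POT. The block below is a minimal NumPy sketch under stated assumptions, not the `ot.batch` implementation: it assumes squared Euclidean costs, uniform marginals, and a fixed number of plain (non log-domain) Sinkhorn iterations with no stopping criterion. The helper names `sq_dist_batch`, `sinkhorn_single`, and `sinkhorn_batch` are hypothetical and exist only for this illustration.

```python
import numpy as np


def sq_dist_batch(X, Y):
    """Squared Euclidean costs: (B, n, d) and (B, m, d) -> (B, n, m)."""
    # Broadcasting builds all pairwise differences for every problem at once.
    diff = X[:, :, None, :] - Y[:, None, :, :]
    return np.sum(diff**2, axis=-1)


def sinkhorn_single(M, reg, n_iter):
    """Plain Sinkhorn for one (n, m) cost matrix with uniform marginals."""
    n, m = M.shape
    a, b = np.full(n, 1.0 / n), np.full(m, 1.0 / m)
    K = np.exp(-M / reg)  # Gibbs kernel
    v = np.ones(m)
    for _ in range(n_iter):
        u = a / (K @ v)
        v = b / (K.T @ u)
    return u[:, None] * K * v[None, :]


def sinkhorn_batch(M, reg, n_iter):
    """Same iterations applied to a (B, n, m) stack of cost matrices at once."""
    B, n, m = M.shape
    a, b = np.full((B, n), 1.0 / n), np.full((B, m), 1.0 / m)
    K = np.exp(-M / reg)
    v = np.ones((B, m))
    for _ in range(n_iter):
        # einsum performs one matrix-vector product per problem in the batch.
        u = a / np.einsum("bij,bj->bi", K, v)
        v = b / np.einsum("bij,bi->bj", K, u)
    return u[:, :, None] * K * v[:, None, :]


# Same shapes as the first part of the example: 4 problems, 8 samples, 2 dims.
rng = np.random.default_rng(0)
X = rng.standard_normal((4, 8, 2))
Y = X + 0.1 * rng.standard_normal((4, 8, 2))

M = sq_dist_batch(X, Y)
plans = sinkhorn_batch(M, reg=1.0, n_iter=100)
values = np.sum(plans * M, axis=(1, 2))  # linear OT cost of each problem

# The batched plans agree with solving each problem in a Python loop.
for i in range(M.shape[0]):
    assert np.allclose(plans[i], sinkhorn_single(M[i], reg=1.0, n_iter=100))
```

The only structural change between the two Sinkhorn functions is a leading batch axis and `einsum` in place of `@`, which is what lets one call replace the Python loop.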

‎examples/index.rst‎

Lines changed: 0 additions & 1 deletion
@@ -33,7 +33,6 @@ Differentiable OT with PyTorch
 ../../examples/gaussian_gmm/plot_GMM_flow.py
 ../../examples/gromov/plot_gnn_TFGW.py
 
-
 Gromov-Wasserstein (GW) and Fused GW
 ------------------------------------
 

‎ot/__init__.py‎

Lines changed: 4 additions & 1 deletion
@@ -37,7 +37,6 @@
 from . import lowrank
 from . import gmm
 
-
 # OT functions
 from .lp import (
     emd,
@@ -73,6 +72,8 @@
 from .solvers import solve, solve_gromov, solve_sample
 from .lowrank import lowrank_sinkhorn
 
+from .batch import solve_batch, solve_gromov_batch
+
 # utils functions
 from .utils import dist, unif, tic, toc, toq
 
@@ -136,4 +137,6 @@
     "sliced_wasserstein_sphere_unif",
     "lowrank_sinkhorn",
     "lowrank_gromov_wasserstein_samples",
+    "solve_batch",
+    "solve_gromov_batch",
 ]
