"""
===================================================
Solving Many Optimal Transport Problems in Parallel
===================================================

In some situations, one may want to solve many OT problems with the same
structure (same number of samples, same cost function, etc.) at the same time.

In that case, using a for loop to solve the problems sequentially is
inefficient. This example shows how to use the batch solvers implemented in
POT to solve many problems in parallel on CPU or GPU (even more efficient on
GPU).

"""

# Author: Paul Krzakala <paul.krzakala@gmail.com>
# License: MIT License

# sphinx_gallery_thumbnail_number = 1

#############################################################################
#
# Computing the Cost Matrices
# ---------------------------------------------
#
# We want to create a batch of optimal transport problems with
# :math:`n` samples in :math:`d` dimensions.
#
# To do this, we first need to compute the cost matrices for each problem.
#
# .. note::
#     A straightforward approach would be to use a Python loop and
#     :func:`ot.dist`.
#     However, this is inefficient when working with batches.
#
# Instead, you can directly use :func:`ot.batch.dist_batch`, which computes
# all cost matrices in parallel.
| 38 | + |
| 39 | +import ot |
| 40 | +import numpy as np |
| 41 | + |
n_problems = 4  # number of OT problems solved at once (batch size)
n_samples = 8  # number of samples per problem
dim = 2  # dimensionality of each sample

np.random.seed(0)
samples_source = np.random.randn(n_problems, n_samples, dim)
samples_target = samples_source + 0.1 * np.random.randn(n_problems, n_samples, dim)

# Naive approach: one ot.dist call per problem, giving a list of
# (n_samples, n_samples) cost matrices.
M_list = [ot.dist(samples_source[i], samples_target[i]) for i in range(n_problems)]

# Batched approach: a single call returns all cost matrices as one array of
# shape (n_problems, n_samples, n_samples).
M_batch = ot.batch.dist_batch(samples_source, samples_target)

# Both approaches agree, problem by problem.
for M_naive, M_parallel in zip(M_list, M_batch):
    assert np.allclose(M_naive, M_parallel)
| 63 | + |
#############################################################################
#
# Solving the Problems
# ---------------------------------------------
#
# Once the cost matrices are computed, we can solve the corresponding
# optimal transport problems.
#
# .. note::
#     One option is to solve them sequentially with a Python loop using
#     :func:`ot.solve`.
#     This is simple but inefficient for large batches.
#
# Instead, you can use :func:`ot.batch.solve_batch`, which solves all
# problems in parallel.
| 79 | + |
reg = 1.0  # entropic regularization strength
max_iter = 100  # maximum number of solver iterations
tol = 1e-3  # convergence tolerance

# Naive approach: solve each problem separately, keeping the linear OT cost.
results_values_list = []
for M in M_list:
    result = ot.solve(M, reg=reg, max_iter=max_iter, tol=tol, reg_type="entropy")
    results_values_list.append(result.value_linear)

# Batched approach: a single call solves every problem at once.
results_batch = ot.batch.solve_batch(
    M=M_batch, reg=reg, max_iter=max_iter, tol=tol, reg_type="entropy"
)
results_values_batch = results_batch.value_linear

# The two sets of objective values match up to solver tolerance.
assert np.allclose(np.array(results_values_list), results_values_batch, atol=tol * 10)
| 97 | + |
#############################################################################
#
# Comparing Computation Time
# ---------------------------------------------
#
# We now compare the runtime of the two approaches on larger problems.
#
# .. note::
#     The speedup obtained with :mod:`ot.batch` can be even more
#     significant when computations are performed on a GPU.
| 109 | + |
| 110 | +from time import perf_counter |
| 111 | + |
# Larger problem sizes so the timing difference between the naive and the
# batched pipelines becomes visible.
n_problems = 128  # batch size
n_samples = 8  # samples per problem
dim = 2  # sample dimensionality
reg = 10.0  # entropic regularization strength
max_iter = 1000  # iteration budget per solve
tol = 1e-3  # convergence tolerance

# Targets are noisy copies of the sources, as in the first section.
samples_source = np.random.randn(n_problems, n_samples, dim)
samples_target = samples_source + 0.1 * np.random.randn(n_problems, n_samples, dim)
| 121 | + |
| 122 | + |
def benchmark_naive(samples_source, samples_target):
    """Return the wall-clock time of the loop-based pipeline.

    Computes one cost matrix and solves one entropic OT problem per batch
    element, sequentially. Relies on the module-level ``n_problems``, ``reg``,
    ``max_iter`` and ``tol`` settings.
    """
    t0 = perf_counter()
    for idx in range(n_problems):
        cost = ot.dist(samples_source[idx], samples_target[idx])
        # Result is discarded: only the runtime matters here.
        ot.solve(cost, reg=reg, max_iter=max_iter, tol=tol, reg_type="entropy")
    return perf_counter() - t0
| 130 | + |
| 131 | + |
def benchmark_batch(samples_source, samples_target):
    """Return the wall-clock time of the batched pipeline.

    A single :func:`ot.batch.dist_batch` call followed by a single
    :func:`ot.batch.solve_batch` call handles every problem at once. Relies on
    the module-level ``reg``, ``max_iter`` and ``tol`` settings.
    """
    t0 = perf_counter()
    costs = ot.batch.dist_batch(samples_source, samples_target)
    # Result is discarded: only the runtime matters here.
    ot.batch.solve_batch(
        M=costs, reg=reg, max_iter=max_iter, tol=tol, reg_type="entropy"
    )
    return perf_counter() - t0
| 140 | + |
| 141 | + |
# Run both pipelines on the same data and report their wall-clock times.
time_naive = benchmark_naive(samples_source, samples_target)
time_batch = benchmark_batch(samples_source, samples_target)

print(f"Naive approach time: {time_naive:.4f} seconds")
print(f"Batched approach time: {time_batch:.4f} seconds")
| 147 | + |
#############################################################################
#
# Gromov-Wasserstein
# ---------------------------------------------
#
# The :mod:`ot.batch` module also provides a batched Gromov-Wasserstein solver.
#
# .. note::
#     This solver is **not** equivalent to calling :func:`ot.solve_gromov`
#     repeatedly in a loop.
#
# Key differences:
#
# - :func:`ot.solve_gromov`
#   Uses the conditional gradient algorithm. Each inner iteration relies on
#   an exact EMD solver.
#
# - :func:`ot.batch.solve_gromov_batch`
#   Uses a proximal variant, where each inner iteration applies entropic
#   regularization.
#
# As a result:
#
# - :func:`ot.solve_gromov` is usually faster on CPU
# - :func:`ot.batch.solve_gromov_batch` is slower on CPU, but provides
#   better objective values.
#
# .. tip::
#     If your data is on a GPU, :func:`ot.batch.solve_gromov_batch`
#     is significantly faster AND provides better objective values.

| 179 | +from ot import solve_gromov |
| 180 | +from ot.batch import solve_gromov_batch |
| 181 | + |
| 182 | + |
def benchmark_naive_gw(samples_source, samples_target):
    """Time the loop-based Gromov-Wasserstein pipeline.

    Returns a ``(elapsed_seconds, mean_objective)`` pair, where the mean is
    taken over the ``n_problems`` GW objective values. Uses the module-level
    ``n_problems`` and ``tol`` settings.
    """
    t0 = perf_counter()
    total = 0
    for idx in range(n_problems):
        # Intra-domain distance matrices for source and target point clouds.
        C1 = ot.dist(samples_source[idx], samples_source[idx])
        C2 = ot.dist(samples_target[idx], samples_target[idx])
        total += solve_gromov(C1, C2, max_iter=1000, tol=tol).value
    mean_value = total / n_problems
    return perf_counter() - t0, mean_value
| 194 | + |
| 195 | + |
def benchmark_batch_gw(samples_source, samples_target):
    """Time the batched (proximal) Gromov-Wasserstein solver.

    Returns a ``(elapsed_seconds, mean_objective)`` pair, where the mean is
    taken over the batch of GW objective values. Uses the module-level
    ``tol`` setting.
    """
    t0 = perf_counter()
    # Intra-domain distance matrices, computed for the whole batch at once.
    C1_all = ot.batch.dist_batch(samples_source, samples_source)
    C2_all = ot.batch.dist_batch(samples_target, samples_target)
    result = solve_gromov_batch(
        C1_all, C2_all, reg=1, max_iter=100, max_iter_inner=50, tol=tol
    )
    mean_value = np.mean(result.value)
    return perf_counter() - t0, mean_value
| 206 | + |
| 207 | + |
# Compare wall-clock time and mean objective value for both GW pipelines.
time_naive_gw, avg_value_naive_gw = benchmark_naive_gw(samples_source, samples_target)
time_batch_gw, avg_value_batch_gw = benchmark_batch_gw(samples_source, samples_target)

print(f"{'Method':<20}{'Time (s)':<15}{'Avg Value':<15}")
for label, elapsed, value in (
    ("Naive GW", time_naive_gw, avg_value_naive_gw),
    ("Batched GW", time_batch_gw, avg_value_batch_gw),
):
    print(f"{label:<20}{elapsed:<15.4f}{value:<15.4f}")
| 214 | + |
#############################################################################
#
# In summary: no more for loops!
# ---------------------------------------------
| 219 | + |
| 220 | +import matplotlib.pyplot as plt |
| 221 | + |
# Closing gag figure: the word "For" crossed out in red — no more for loops.
fig, ax = plt.subplots(figsize=(4, 4))
ax.axis("off")
ax.text(0.5, 0.5, "For", fontsize=160, ha="center", va="center", zorder=0)
# Two diagonal red strokes drawn on top of the text (zorder=1 > 0).
ax.plot([0, 1], [0, 1], color="red", linewidth=10, zorder=1)
ax.plot([0, 1], [1, 0], color="red", linewidth=10, zorder=1)
plt.show()