Suggestion for efficient discrepancy estimation (and optimisation) for large weighted datasets #727

Answered by rflamary
tvercaut asked this question in Q&A

I am interested in maximizing the discrepancy between two large datasets where each sample is a weighted feature vector. More specifically, I would like to split a large set of feature vectors into 2 subsets and maximize the discrepancy across the two subsets (think of unsupervised 2-class clustering). The weights would act as soft assignments between the 2 groups.

The number of samples (N) could be around 1e6 and the dimension of each feature vector (D) could be around 1e2. I am thus looking for efficient approaches.
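To make this concrete, the objective I have in mind is roughly the following (with w_i in [0, 1] the soft assignment of sample x_i to the first group, and D some discrepancy between weighted point clouds):

$$
\max_{w \in [0,1]^N} \; D\!\left(\frac{\sum_i w_i \, \delta_{x_i}}{\sum_i w_i},\; \frac{\sum_i (1 - w_i) \, \delta_{x_i}}{\sum_i (1 - w_i)}\right)
$$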

I tried the sliced Wasserstein distance:
https://pythonot.github.io/gen_modules/ot.sliced.html#ot.sliced.sliced_wasserstein_distance
It handles 1e4 points well but is already getting slow at 1e5.
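For reference, I call it essentially like this (illustrative snippet with random data standing in for the real features and placeholder sizes):

import torch
import ot

N, D = 100_000, 100
X_s = torch.randn(N, D)          # features of group 1
X_t = torch.randn(N, D)          # features of group 2
w_s = torch.full((N,), 1.0 / N)  # uniform weights for now
w_t = torch.full((N,), 1.0 / N)

swd = ot.sliced.sliced_wasserstein_distance(X_s, X_t, w_s, w_t, n_projections=50)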

I also tried the MMD implementation in geomloss (geomloss.SamplesLoss(loss="gaussian")):
https://www.kernel-operations.io/geomloss/api/pytorch-api.html#geomloss.SamplesLoss
It is significantly slower than the sliced EMD.

I could work with an approximate algorithm and am not stuck on a specific notion of discrepancy at this stage (hence the test of sliced EMD and Gaussian MMD). Is there a recommended sample loss for such large-scale problems?

I haven't looked at the gradient computation yet but will eventually need the gradient of the discrepancy measure with respect to the weights.
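In principle I would expect this to work through autograd with the torch backend of POT, along the lines of this untested sketch (placeholder data, soft assignments derived from learnable logits):

import torch
import ot

N, D = 10_000, 100
X = torch.randn(N, D)
logits = torch.zeros(N, requires_grad=True)

w = torch.sigmoid(logits)
a = w / w.sum()              # soft weights of group 1
b = (1 - w) / (1 - w).sum()  # soft weights of group 2

# discrepancy between the two soft groups over the same support X
loss = -ot.sliced.sliced_wasserstein_distance(X, X, a, b, n_projections=50)
loss.backward()              # gradients end up in logits.grad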

I also considered Wasserstein Discriminant Analysis, but it doesn't seem to support soft labels, so I guess it can't be used to get a gradient with respect to the labels:
https://pythonot.github.io/gen_modules/ot.dr.html#id15


Replies: 1 comment 3 replies


If you have a very large number of points you should definitely consider minibatch OT. It consists in optimizing the expectation of the loss over minibatches with SGD. You can do that manually easily enough with sliced Wasserstein, the exact solver (ot.emd2/ot.solve_sample), or the Sinkhorn divergence from POT, which are all very efficient on small batches.
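For the weighted-split setting in the question, a minimal (untested) sketch of such a minibatch loop could look like this, with sliced Wasserstein as the per-batch loss and arbitrary batch size and learning rate:

import torch
import ot

N, D, batch_size = 1_000_000, 100, 2048
X = torch.randn(N, D)                        # full dataset (placeholder)
logits = torch.zeros(N, requires_grad=True)  # one soft assignment per sample
opt = torch.optim.Adam([logits], lr=0.1)

for it in range(1000):
    idx = torch.randint(0, N, (batch_size,))
    Xb = X[idx]
    w = torch.sigmoid(logits[idx])
    a = w / w.sum()
    b = (1 - w) / (1 - w).sum()
    # maximise the discrepancy between the two soft groups on the minibatch
    loss = -ot.sliced.sliced_wasserstein_distance(Xb, Xb, a, b, n_projections=50)
    opt.zero_grad()
    loss.backward()
    opt.step()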

3 replies

That's interesting. I was thinking of using random subsampling during training but hadn't thought of batching the distance computation. I did a quick try using torch.chunk and a for loop but this didn't lead to significant gains. Running the batch computation in parallel with torch.vmap helps a bit though: I get a factor of 2 on the forward pass with this:

import torch
import ot

def batched_swd(X_s, X_t, w_s, w_t, num_chunks=10):
    # Note that this is quick and dirty and only works if the number of samples
    # is divisible by num_chunks
    Xs2 = X_s.reshape(num_chunks, -1, X_s.shape[1])
    ws2 = w_s.reshape(num_chunks, -1)
    Xt2 = X_t.reshape(num_chunks, -1, X_t.shape[1])
    wt2 = w_t.reshape(num_chunks, -1)
    bswdfunc = torch.func.vmap(ot.sliced.sliced_wasserstein_distance, randomness="same")
    swds = bswdfunc(Xs2, Xt2, ws2, wt2)
    return swds.sum()

Again, just a quick test for now.
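For completeness, I call it like this (placeholder sizes; the per-chunk weights are not renormalised, as noted above):

N, D = 100_000, 100
X_s, X_t = torch.randn(N, D), torch.randn(N, D)
w_s = torch.full((N,), 1.0 / N)
w_t = torch.full((N,), 1.0 / N)
print(batched_swd(X_s, X_t, w_s, w_t, num_chunks=10))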


With the sliced EMD, the main computational bottleneck seems to be the sorting operation:

POT/ot/lp/solver_1d.py

Lines 113 to 121 in 8aed2d7

if require_sort:
    u_sorter = nx.argsort(u_values, 0)
    u_values = nx.take_along_axis(u_values, u_sorter, 0)
    v_sorter = nx.argsort(v_values, 0)
    v_values = nx.take_along_axis(v_values, v_sorter, 0)
    u_weights = nx.take_along_axis(u_weights, u_sorter, 0)
    v_weights = nx.take_along_axis(v_weights, v_sorter, 0)

However, the sorted projected vector is apparently "only" used through some quantiles:

POT/ot/lp/solver_1d.py

Lines 127 to 135 in 8aed2d7

u_quantiles = quantile_function(qs, u_cumweights, u_values)
v_quantiles = quantile_function(qs, v_cumweights, v_values)
qs = nx.zero_pad(qs, pad_width=[(1, 0)] + (qs.ndim - 1) * [(0, 0)])
delta = qs[1:, ...] - qs[:-1, ...]
diff_quantiles = nx.abs(u_quantiles - v_quantiles)

if p == 1:
    return nx.sum(delta * diff_quantiles, axis=0)
return nx.sum(delta * nx.power(diff_quantiles, p), axis=0)
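(For reference, this is just a discretisation of the closed form of the 1D Wasserstein distance in terms of quantile functions,

$$
W_p^p(u, v) = \int_0^1 \bigl| F_u^{-1}(q) - F_v^{-1}(q) \bigr|^p \, dq ,
$$

so the sorting is only needed to evaluate the quantile functions F_u^{-1} and F_v^{-1}.)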

Has someone tried replacing these quantile computations with approximate ones based on histograms? I haven't checked the validity of this approximation but it allows a significant speedup. Here is some quick and dirty code:

import torch
import ot

def approx_swd(
    X_s,
    X_t,
    a=None,
    b=None,
    n_projections=50,
    p=2,
    projections=None,
    seed=None,
    log=False,
    num_bins=100,
):
    X_s, X_t = ot.utils.list_to_array(X_s, X_t)

    if a is not None and b is not None and projections is None:
        nx = ot.utils.get_backend(X_s, X_t, a, b)
    elif a is not None and b is not None and projections is not None:
        nx = ot.utils.get_backend(X_s, X_t, a, b, projections)
    elif a is None and b is None and projections is not None:
        nx = ot.utils.get_backend(X_s, X_t, projections)
    else:
        nx = ot.utils.get_backend(X_s, X_t)

    n = X_s.shape[0]
    m = X_t.shape[0]

    if X_s.shape[1] != X_t.shape[1]:
        raise ValueError(
            "X_s and X_t must have the same number of dimensions {} and {} respectively given".format(
                X_s.shape[1], X_t.shape[1]
            )
        )

    if a is None:
        a = nx.full(n, 1 / n, type_as=X_s)
    if b is None:
        b = nx.full(m, 1 / m, type_as=X_s)

    d = X_s.shape[1]

    if projections is None:
        projections = ot.sliced.get_random_projections(
            d, n_projections, seed, backend=nx, type_as=X_s
        )
    else:
        n_projections = projections.shape[1]

    X_s_projections = nx.dot(X_s, projections)
    X_t_projections = nx.dot(X_t, projections)

    # Exact version would be:
    # projected_emd = ot.lp.wasserstein_1d(X_s_projections, X_t_projections, a, b, p=p)

    # Build shared histogram bins over the projected samples (torch-specific from here on)
    pminv, pmaxv = torch.aminmax(torch.cat((X_s_projections, X_t_projections)))
    boundaries = torch.linspace(
        start=pminv,
        end=pmaxv + torch.finfo(torch.float32).eps,
        steps=num_bins + 1,
        device=X_s.device,
    )
    bin_idx_s = torch.bucketize(X_s_projections, boundaries[1:])
    bin_idx_t = torch.bucketize(X_t_projections, boundaries[1:])
    centres = 0.5 * (boundaries[:-1] + boundaries[1:])[:, None].expand(-1, n_projections)

    # Per-bin weighted mean of the source projections, and per-bin total source weight
    a_rep = a[:, None].expand(-1, n_projections)
    X_s_p_aggr = torch.zeros(num_bins, n_projections, device=X_s.device).scatter_reduce(
        0, bin_idx_s, a_rep * X_s_projections, reduce="sum", include_self=False
    )
    a_aggr = torch.zeros(num_bins, n_projections, device=X_s.device).scatter_reduce(
        0, bin_idx_s, a_rep, reduce="sum", include_self=False
    )
    X_s_p_aggr /= a_aggr
    # Empty bins get the bin centre as representative point (their weight is 0 anyway)
    zidx = a_aggr <= torch.finfo(torch.float32).eps
    X_s_p_aggr[zidx] = centres[zidx]

    # Same aggregation for the target projections
    b_rep = b[:, None].expand(-1, n_projections)
    X_t_p_aggr = torch.zeros(num_bins, n_projections, device=X_s.device).scatter_reduce(
        0, bin_idx_t, b_rep * X_t_projections, reduce="sum", include_self=False
    )
    b_aggr = torch.zeros(num_bins, n_projections, device=X_s.device).scatter_reduce(
        0, bin_idx_t, b_rep, reduce="sum", include_self=False
    )
    X_t_p_aggr /= b_aggr
    zidx = b_aggr <= torch.finfo(torch.float32).eps
    X_t_p_aggr[zidx] = centres[zidx]

    # Compute the 1D EMDs on the (already sorted) per-bin representatives
    projected_emd = ot.lp.wasserstein_1d(
        X_s_p_aggr, X_t_p_aggr, a_aggr, b_aggr, p=p, require_sort=False
    )
    res = (nx.sum(projected_emd) / n_projections) ** (1.0 / p)

    if log:
        return res, {"projections": projections, "projected_emds": projected_emd}
    return res
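A quick sanity check against the exact sliced distance could look like this (random data, same projections via a fixed seed; I haven't checked how close the two values actually are):

import torch
import ot

N, D = 100_000, 100
X_s, X_t = torch.randn(N, D), torch.randn(N, D) + 0.5

exact = ot.sliced.sliced_wasserstein_distance(X_s, X_t, n_projections=50, seed=0)
approx = approx_swd(X_s, X_t, n_projections=50, seed=0, num_bins=100)
print(float(exact), float(approx))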

That's a good idea and must be much faster!
But IMHO, if you accept computing an approximation of the sliced distance, you could probably just as well do it with minibatches (possibly computed in parallel) instead of an approximation that remains O(n) rather than O(n log n). In practice people use SGD because each iteration does not require a pass through the whole dataset. But this is obviously application dependent, and studying the statistical quality of this approximation might be interesting.

Answer selected by tvercaut