I am interested in maximizing the discrepancy between two large datasets where each sample is a weighted feature vector. More specifically, I would like to split a large set of feature vectors into 2 subsets and maximize the discrepancy between the two subsets (think of unsupervised 2-class clustering). The weights would act as soft assignments between the 2 groups.
The number of samples (N) could be around 1e6 and the dimension of each feature vector (D) could be around 1e2. I am thus looking for efficient approaches.
I tried the sliced Wasserstein distance:
https://pythonot.github.io/gen_modules/ot.sliced.html#ot.sliced.sliced_wasserstein_distance
It handles 1e4 points well but 1e5 is getting slow already.
I also tried the MMD implementation in geomloss (geomloss.SamplesLoss(loss="gaussian")):
https://www.kernel-operations.io/geomloss/api/pytorch-api.html#geomloss.SamplesLoss
It is significantly slower than the sliced EMD.
I could work with an approximate algorithm and am not stuck on a specific notion of discrepancy at this stage (hence the test of sliced EMD and Gaussian MMD). Is there a recommended sample loss for such large-scale problems?
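For reference, a minimal sketch of how I set up both losses on the soft split (shapes, bandwidth and number of projections are just illustrative placeholders):

```python
import torch
import ot
from geomloss import SamplesLoss

N, D = 10_000, 100           # real problem is closer to N=1e6, D=1e2
X = torch.randn(N, D)

# Soft assignments in (0, 1); the two "groups" are weighted views of the same samples
w = torch.rand(N)
a = w / w.sum()              # weights of group 1
b = (1 - w) / (1 - w).sum()  # weights of group 2

# Sliced Wasserstein distance between the two weighted empirical measures
swd = ot.sliced.sliced_wasserstein_distance(X, X, a, b, n_projections=50)

# Gaussian MMD from geomloss, weighted call: loss(a, x, b, y)
mmd = SamplesLoss(loss="gaussian", blur=1.0)(a, X, b, X)
```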
I haven't looked at the gradient computation yet but will eventually need the gradient of the discrepancy measure with respect to the weights.
I also considered Wasserstein Discriminant Analysis, but it doesn't seem to support soft labels, so I guess it can't be used to get a gradient with respect to the labels:
https://pythonot.github.io/gen_modules/ot.dr.html#id15
-
If you have a very large number of points you should definitely consider minibatch OT. It consists in optimizing the expectation of the loss over minibatches with SGD. You can do that manually easily enough with sliced Wasserstein, the exact solver (ot.emd2/ot.solve_sample), or the Sinkhorn divergence from POT, which are all very efficient on small batches.
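A minimal sketch of such a minibatch estimate, assuming points are sampled proportionally to their weights (batch size, number of batches and the choice of ot.solve_sample are illustrative):

```python
import torch
import ot

def minibatch_ot(X_s, X_t, w_s, w_t, batch_size=1024, n_batches=10):
    # Monte-Carlo estimate of the OT discrepancy from random minibatches.
    # Points are drawn with probability proportional to their weights so the
    # soft assignments are respected on average.
    total = 0.0
    for _ in range(n_batches):
        idx_s = torch.multinomial(w_s, batch_size, replacement=True)
        idx_t = torch.multinomial(w_t, batch_size, replacement=True)
        # Exact OT on the small batch (uniform weights inside the batch)
        total = total + ot.solve_sample(X_s[idx_s], X_t[idx_t]).value
    return total / n_batches
```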
-
That's interesting. I was thinking of using random subsampling during training but hadn't thought of batching the distance computation. I did a quick try using torch.chunk and a for loop, but this didn't lead to significant gains. Running the batch computation in parallel with torch.vmap helps a bit though. I get a factor-2 speedup on the forward pass with this:
```python
import torch
import ot


def batched_swd(X_s, X_t, w_s, w_t, num_chunks=10):
    # Note that this is quick and dirty and only works if the number of
    # samples is divisible by num_chunks
    Xs2 = X_s.reshape(num_chunks, -1, X_s.shape[1])
    ws2 = w_s.reshape(num_chunks, -1)
    Xt2 = X_t.reshape(num_chunks, -1, X_t.shape[1])
    wt2 = w_t.reshape(num_chunks, -1)
    bswdfunc = torch.func.vmap(ot.sliced.sliced_wasserstein_distance, randomness="same")
    swds = bswdfunc(Xs2, Xt2, ws2, wt2)
    return swds.sum()
```
Again, just a quick test for now.
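For example, calling it on random data (num_chunks must divide the number of samples, and here the uniform weights are scaled so that each chunk sums to 1):

```python
import torch

N, D, num_chunks = 100_000, 100, 10
X_s, X_t = torch.randn(N, D), torch.randn(N, D)
w_s = torch.full((N,), num_chunks / N)
w_t = torch.full((N,), num_chunks / N)
print(batched_swd(X_s, X_t, w_s, w_t, num_chunks=num_chunks))
```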
-
With the sliced EMD, the main computational bottleneck seems to be the sorting operation (lines 113 to 121 at 8aed2d7).
However, the sorted projected vector is apparently "only" used through some quantiles (lines 127 to 135 at 8aed2d7).
Has someone tried replacing these quantile computations with approximate ones based on histograms? I haven't checked the validity of this approximation, but it allows a significant speedup. Here is some quick and dirty code:
```python
import torch
import ot


def approx_swd(
    X_s,
    X_t,
    a=None,
    b=None,
    n_projections=50,
    p=2,
    projections=None,
    seed=None,
    log=False,
    num_bins=100,
):
    X_s, X_t = ot.utils.list_to_array(X_s, X_t)

    if a is not None and b is not None and projections is None:
        nx = ot.utils.get_backend(X_s, X_t, a, b)
    elif a is not None and b is not None and projections is not None:
        nx = ot.utils.get_backend(X_s, X_t, a, b, projections)
    elif a is None and b is None and projections is not None:
        nx = ot.utils.get_backend(X_s, X_t, projections)
    else:
        nx = ot.utils.get_backend(X_s, X_t)

    n = X_s.shape[0]
    m = X_t.shape[0]

    if X_s.shape[1] != X_t.shape[1]:
        raise ValueError(
            "X_s and X_t must have the same number of dimensions {} and {} respectively given".format(
                X_s.shape[1], X_t.shape[1]
            )
        )

    if a is None:
        a = nx.full(n, 1 / n, type_as=X_s)
    if b is None:
        b = nx.full(m, 1 / m, type_as=X_s)

    d = X_s.shape[1]

    if projections is None:
        projections = ot.sliced.get_random_projections(
            d, n_projections, seed, backend=nx, type_as=X_s
        )
    else:
        n_projections = projections.shape[1]

    X_s_projections = nx.dot(X_s, projections)
    X_t_projections = nx.dot(X_t, projections)

    # projected_emd = ot.lp.wasserstein_1d(X_s_projections, X_t_projections, a, b, p=p)

    # Compute histograms: common bin edges over the range of all projected points
    pminv, pmaxv = torch.aminmax(torch.cat((X_s_projections, X_t_projections)))
    boundaries = torch.linspace(
        start=pminv,
        end=pmaxv + torch.finfo(torch.float32).eps,
        steps=num_bins + 1,
        device=X_s.device,
    )
    bin_idx_s = torch.bucketize(X_s_projections, boundaries[1:])
    bin_idx_t = torch.bucketize(X_t_projections, boundaries[1:])
    centres = 0.5 * (boundaries[:-1] + boundaries[1:])[:, None].expand(-1, n_projections)

    # Aggregate source samples per bin: weighted mean position and total mass
    a_rep = a[:, None].expand(-1, n_projections)
    X_s_p_aggr = torch.zeros(num_bins, n_projections, device=X_s.device).scatter_reduce(
        0, bin_idx_s, a_rep * X_s_projections, reduce="sum", include_self=False
    )
    a_aggr = torch.zeros(num_bins, n_projections, device=X_s.device).scatter_reduce(
        0, bin_idx_s, a_rep, reduce="sum", include_self=False
    )
    X_s_p_aggr /= a_aggr
    zidx = a_aggr <= torch.finfo(torch.float32).eps
    X_s_p_aggr[zidx] = centres[zidx]  # empty bins fall back to the bin centre

    # Same aggregation for the target samples
    b_rep = b[:, None].expand(-1, n_projections)
    X_t_p_aggr = torch.zeros(num_bins, n_projections, device=X_s.device).scatter_reduce(
        0, bin_idx_t, b_rep * X_t_projections, reduce="sum", include_self=False
    )
    b_aggr = torch.zeros(num_bins, n_projections, device=X_s.device).scatter_reduce(
        0, bin_idx_t, b_rep, reduce="sum", include_self=False
    )
    X_t_p_aggr /= b_aggr
    zidx = b_aggr <= torch.finfo(torch.float32).eps
    X_t_p_aggr[zidx] = centres[zidx]

    # Compute EMD on histograms (bin centres are already sorted, hence require_sort=False)
    projected_emd = ot.lp.wasserstein_1d(
        X_s_p_aggr, X_t_p_aggr, a_aggr, b_aggr, p=p, require_sort=False
    )

    res = (nx.sum(projected_emd) / n_projections) ** (1.0 / p)
    if log:
        return res, {"projections": projections, "projected_emds": projected_emd}
    return res
```
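A quick comparison against the exact sliced distance on random data (purely illustrative; with the same integer seed both calls should draw the same projections):

```python
import torch
import ot

N, D = 100_000, 100
X_s, X_t = torch.randn(N, D), torch.randn(N, D) + 0.5

exact = ot.sliced.sliced_wasserstein_distance(X_s, X_t, n_projections=50, seed=0)
approx = approx_swd(X_s, X_t, n_projections=50, seed=0, num_bins=100)
print(float(exact), float(approx))
```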
-
That's a good idea and must be much faster!
But IMHO, if you accept computing an approximation of sliced Wasserstein, you should probably be able to do it with batches (which can possibly be computed in parallel) instead of an approximation that remains O(n) rather than O(n log n). In practice people do SGD because each iteration does not require a pass through the whole dataset. But this is obviously application dependent, and studying the statistical quality of this approximation might be interesting.
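For completeness, a rough and untested sketch of that minibatch-SGD route applied to the original soft-split problem, taking gradients with respect to the assignment weights (batch size, optimizer, parametrization and the use of sliced Wasserstein are arbitrary choices here):

```python
import torch
import ot

N, D = 100_000, 100          # smaller than the real 1e6 just for the sketch
X = torch.randn(N, D)

# Soft assignments parametrized through a sigmoid so they stay in (0, 1)
logits = torch.zeros(N, requires_grad=True)
opt = torch.optim.Adam([logits], lr=0.05)
batch_size = 4096

for it in range(1000):
    idx = torch.randint(0, N, (batch_size,))
    w = torch.sigmoid(logits[idx])
    a = w / w.sum()              # weights of group 1 within the minibatch
    b = (1 - w) / (1 - w).sum()  # weights of group 2 within the minibatch
    # Maximize the sliced distance between the two weighted views of the
    # same minibatch, hence the minus sign
    loss = -ot.sliced.sliced_wasserstein_distance(
        X[idx], X[idx], a, b, n_projections=20
    )
    opt.zero_grad()
    loss.backward()
    opt.step()
```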