Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Using custom cost function for free_support_barycenter #735

Unanswered
mthonack asked this question in Q&A
Discussion options

Hi there, I am trying to implement a free-support barycenter function that uses a custom distance function to build its cost matrix. I want to use elastic distances of the form $d(x,y) = \frac{1}{2} \|x-y\|_2^2 + \gamma \|x-y\|_1$. To do this, I simply adapted the `ot.lp.free_support_barycenter` function and only changed the definition `M_i = dist(X, measure_locations_i)` to `M_i = cdist(X, measure_locations_i, metric=lambda u, v: cost_fn(u, v, gamma))`.
My problem is that the barycenters computed with my custom cost function do not differ in any way from the standard barycenters computed with the squared Euclidean (`sqeuclidean`) distance. I have already tried different `gamma` values and made sure that the resulting distance matrix is sufficiently different from the squared Euclidean distance matrix.
So I am wondering why it does not work. I know it has to do with the `emd()` function, which returns the same values whether I use the standard distance or my custom distance, but I cannot work out why that is or what I need to change for a proper implementation. Any help is appreciated.

Code to reproduce:

import numpy as np
import matplotlib.pyplot as plt
from scipy.spatial.distance import cdist
import ot
def plot_barycenters(samples, X_barycenter, measures=None, title=None):
    """Scatter-plot the input distributions and a barycenter on one figure.

    Parameters
    ----------
    samples : object
        Not used. Kept only so existing calls ``plot_barycenters(samples, X)``
        keep working; pass the distributions through ``measures`` instead.
    X_barycenter : array-like
        Barycenter support points; columns 0 and 1 are plotted.
    measures : list of array-like, optional
        Point clouds to draw; columns 0 and 1 of each are plotted. Defaults
        to the module-level ``measures_locations``.
    title : str, optional
        Figure title, e.g. to tell apart plots made with different costs.
    """
    if measures is None:
        # Default to the module-level input distributions (previous behaviour).
        measures = measures_locations
    _, ax = plt.subplots()
    # Use a distinct loop name so the `samples` parameter is not shadowed.
    for i, points in enumerate(measures):
        ax.scatter(points[:, 0], points[:, 1], label=f'Distribution {i+1}', alpha=0.5)
    # Overlay the barycenter in red on top of the input clouds.
    ax.scatter(X_barycenter[:, 0], X_barycenter[:, 1], c='red', label='Barycenter', marker='x')
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    if title is not None:
        ax.set_title(title)
    ax.legend()
    plt.show()
def dist_l2(x, y, *args):
    """Return the squared Euclidean distance ``||x - y||_2^2``.

    Extra positional arguments are accepted and ignored, so this can be
    called with the same argument list as parameterised costs.
    """
    delta = x - y
    return np.sum(np.square(delta))
def dist_l2_l1(x, y, gamma=0.7, *args):
    """Return the elastic cost ``||x - y||_2^2 + gamma * ||x - y||_1``.

    Note that the squared term is not halved. This equals
    ``2 * (0.5 * ||x - y||_2^2 + (gamma / 2) * ||x - y||_1)``, so ``gamma``
    here plays the role of ``2 * gamma`` in the ``1/2``-scaled form. A
    positive overall factor does not change the optimal transport plan.
    Extra positional arguments are accepted and ignored.
    """
    delta = x - y
    squared_l2 = np.sum(delta ** 2)
    l1 = np.sum(np.abs(delta))
    return squared_l2 + gamma * l1
# --- Experiment configuration ---------------------------------------------
d = 2    # number of input distributions
k = 200  # points sampled from each distribution
N = 2    # ambient dimension; 2 so everything can be scatter-plotted

# Centre of each Gaussian cloud, one per distribution.
fixed_means = [np.array([2, 10]), np.array([6, 2])]

# --- Sample the input point clouds ----------------------------------------
# Every cloud is an isotropic Gaussian (identity covariance) around its
# centre. After the loop, `samples` stays bound to the last cloud drawn.
cov = np.eye(N)
measures_locations = []
for i in range(d):
    mean = fixed_means[i]
    samples = np.random.multivariate_normal(mean, cov, k)
    measures_locations.append(samples)

# Uniform probability weights on each cloud's k points.
measures_weights = [np.full(k, 1.0 / k) for _ in range(d)]
def custom_free_support_barycenter(measures_locations, measures_weights, X_init,
                                   cost_fn, gamma=0.5, A=None,
                                   a=None, weights=None, numItermax=100, stopThr=1e-7,
                                   verbose=False, numThreads=1,):
    """Compute a free-support barycenter with a user-chosen ground cost.

    Adapted from ``ot.lp.free_support_barycenter``; the intended change is
    only how the cost matrix between the current support ``X`` and each
    input measure is built. Each iteration solves one exact OT problem per
    input measure with ``ot.lp.emd`` and then moves every barycenter point
    to the weighted average of the input points it is coupled to.

    Important: that position update (the barycentric projection
    ``X_j = sum_i w_i * sum_l T_i[j, l] * Y_i[l] / a_j``) minimises the
    barycenter objective for fixed couplings only when the ground cost is
    the squared Euclidean distance. With any other ``cost_fn``, the custom
    cost affects the couplings ``T_i`` alone, while the update is unchanged.
    The output is therefore in general *not* the barycenter for that cost,
    and it may look nearly identical to the squared-Euclidean result. A
    correct solver for another cost must replace the averaging step with
    the minimiser over ``x`` of ``sum_i w_i sum_l T_i[j, l] * cost(x, Y_i[l])``.

    Parameters
    ----------
    measures_locations : list of array-like
        Support points of each input measure, shape ``(n_i, d)``.
    measures_weights : list of array-like
        Weights of each input measure, shape ``(n_i,)``.
    X_init : array-like, shape (k, d)
        Initial barycenter support; also fixes ``k`` and ``d``.
    cost_fn : str or callable
        A string is passed to ``scipy.spatial.distance.cdist`` as the metric
        name (e.g. ``'sqeuclidean'``). A callable is invoked as
        ``cost_fn(u, v, gamma)`` for every pair of points.
    gamma : float, default 0.5
        Forwarded to a callable ``cost_fn``; ignored for string metrics.
    A : None
        Unused; kept so existing calls remain valid.
    a : array-like, shape (k,), optional
        Weights of the barycenter points. Defaults to uniform ``1/k``.
    weights : array-like, shape (N,), optional
        Weight of each input measure. Defaults to uniform ``1/N``.
    numItermax : int, default 100
        Maximum number of outer iterations.
    stopThr : float, default 1e-7
        Stop once the squared displacement of the support is at most this.
    verbose : bool, default False
        Print the squared displacement after each iteration.
    numThreads : int, default 1
        Forwarded to ``ot.lp.emd``.

    Returns
    -------
    X : array, shape (k, d)
        Final barycenter support.
    """
    nx = ot.backend.get_backend(*measures_locations, *measures_weights, X_init)
    N = len(measures_locations)
    k = X_init.shape[0]
    d = X_init.shape[1]
    if a is None:
        # `a` is both the source marginal passed to emd and the normaliser in
        # the position update, so the uniform default must be stored in `a`.
        a = nx.ones((k,), type_as=X_init) / k
    if weights is None:
        weights = nx.ones((N,), type_as=X_init) / N

    # Resolve the metric once rather than building a new closure for every
    # measure in every iteration. String metrics go straight to cdist.
    if isinstance(cost_fn, str):
        metric = cost_fn
    else:
        def metric(u, v):
            return cost_fn(u, v, gamma)

    X = X_init.copy()
    iter_count = 0
    # Start above the threshold so the loop runs at least once (if numItermax > 0).
    displacement_square_norm = stopThr + 1.0
    while displacement_square_norm > stopThr and iter_count < numItermax:
        T_sum = nx.zeros((k, d), type_as=X_init)
        for measure_locations_i, measure_weights_i, weight_i in zip(
            measures_locations, measures_weights, weights
        ):
            # Cost between the current barycenter support and this measure's points.
            M_i = cdist(X, measure_locations_i, metric=metric)
            T_i = ot.lp.emd(a, measure_weights_i, M_i, numThreads=numThreads)
            # Barycentric projection: the squared-Euclidean-optimal update
            # (see the docstring for why this ignores `cost_fn`).
            T_sum = T_sum + weight_i * 1.0 / a[:, None] * nx.dot(T_i, measure_locations_i)
        displacement_square_norm = nx.sum((T_sum - X) ** 2)
        X = T_sum
        if verbose:
            print("iteration {:d}, displacement_square_norm={:.6f}".format(iter_count, displacement_square_norm))
        iter_count += 1
    return X
# Shared starting support for both runs: k points from a standard normal in R^N.
X_init = np.random.normal(0.0, 1.0, (k, N))
# Uniform weights on the barycenter points, identical for both runs.
barycenter_weights = np.ones((k,)) / k

# Reference run: squared Euclidean ground cost (what ot.lp.free_support_barycenter uses).
X_barycenter = custom_free_support_barycenter(measures_locations, measures_weights, X_init,
                                              cost_fn='sqeuclidean',
                                              a=barycenter_weights, verbose=False)
# Elastic cost from dist_l2_l1; at the scale of these data, gamma=1000 makes
# the L1 term dominate the cost matrix.
X_barycenter_l2_l1 = custom_free_support_barycenter(measures_locations, measures_weights, X_init,
                                                    cost_fn=dist_l2_l1, gamma=1000,
                                                    a=barycenter_weights, verbose=False)

# Quantify the difference instead of relying only on comparing the plots by eye.
print("max abs difference between barycenters: {:.6g}".format(
    np.max(np.abs(X_barycenter - X_barycenter_l2_l1))))

# Plot each barycenter over all input distributions (not just the last cloud).
plot_barycenters(measures_locations, X_barycenter)
plot_barycenters(measures_locations, X_barycenter_l2_l1)
You must be logged in to vote

Replies: 0 comments

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Category
Q&A
Labels
None yet
1 participant

AltStyle によって変換されたページ (->オリジナル) /