-
Couldn't load subscription status.
- Fork 537
-
Hi there, I am trying to implement a free-support barycenter function that uses custom distance functions for its calculation. I want to use elastic distance functions within the framework of the `ot.lp.free_support_barycenter` function, and the only change I made was to replace the definition `M_i = dist(X, measure_locations_i)` with `M_i = cdist(X, measure_locations_i, metric=lambda u, v: cost_fn(u, v, gamma))`.
My problem is that the barycenters computed with my custom cost function do not differ in any way from the standard barycenters computed with the squared Euclidean (`sqeuclidean`) distance. I have already tried different `gamma` values and made sure that the distance matrix is sufficiently different from the squared Euclidean distance matrix.
So I was wondering why it does not work. I know that it has to do with the `emd()` function, which returns the same values regardless of whether the standard or the custom distance is used, but I cannot work out why that is or what I can do to implement this properly. Any help is appreciated.
Code to reproduce:
import numpy as np
import matplotlib.pyplot as plt
from scipy.spatial.distance import cdist
import ot
def plot_barycenters(samples, X_barycenter, measures=None, title=None):
    """Scatter-plot each input distribution together with a barycenter.

    Parameters
    ----------
    samples : array
        Not read. Kept only so existing calls of the form
        ``plot_barycenters(samples, X)`` keep working; the point clouds that
        get plotted come from ``measures``.
    X_barycenter : array
        Barycenter support points; columns 0 and 1 are plotted.
    measures : list of arrays, optional
        Point clouds to plot, one per distribution (columns 0 and 1 are
        plotted). Defaults to the module-level ``measures_locations``.
    title : str, optional
        Axes title, e.g. to tell plots for different cost functions apart.
    """
    if measures is None:
        measures = measures_locations

    _, ax = plt.subplots()

    # One scatter per input distribution. A distinct loop name is used so the
    # ``samples`` parameter is not shadowed.
    for i, points in enumerate(measures):
        ax.scatter(points[:, 0], points[:, 1], label=f'Distribution {i+1}', alpha=0.5)

    # The barycenter is drawn on top.
    ax.scatter(X_barycenter[:, 0], X_barycenter[:, 1], c='red', label='Barycenter', marker='x')

    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    if title is not None:
        ax.set_title(title)
    ax.legend()
    plt.show()
def dist_l2(x, y, *args):
    """Squared Euclidean distance between two points; extra positional args are ignored."""
    delta = x - y
    return np.sum(delta ** 2)


def dist_l2_l1(x, y, gamma=0.7, *args):
    """Squared Euclidean distance plus ``gamma`` times the L1 (Manhattan) distance."""
    delta = x - y
    l1_part = np.sum(np.abs(delta))
    return dist_l2(x, y) + gamma * l1_part
# Parameters
d = 2  # number of distributions
k = 200  # number of samples per distribution
N = 2  # dimensionality of each sample (2 so the results can be plotted in 2D)
SEED = 0  # fixed seed so the example is reproducible run to run

# Seed NumPy's global RNG. This covers the samples drawn here and any later
# np.random.* calls in this script, so the two barycenter runs can be compared
# on identical data across executions.
np.random.seed(SEED)

# Fixed means, one per distribution
fixed_means = [np.array([2, 10]), np.array([6, 2])]

# Generate measure locations for d distributions.
# NOTE: ``samples`` intentionally remains a module-level name after the loop
# (it holds the last distribution); the plotting calls at the end pass it.
measures_locations = []
for i in range(d):
    mean = fixed_means[i]
    cov = np.eye(N)  # identity covariance matrix
    samples = np.random.multivariate_normal(mean, cov, k)
    measures_locations.append(samples)

# Uniform weights for every distribution
measures_weights = [np.ones((k,)) / k for _ in range(d)]
def _minimize_support_locations(X_start, plans, measures_locations, weights, cost_matrix):
    """Move each support point to a numerical minimiser of its plan-weighted cost.

    For support point ``r`` this minimises
    ``sum_i weights[i] * sum_j plans[i][r, j] * cost(x, measures_locations[i][j])``
    over ``x``, starting from ``X_start[r]``. It uses Nelder-Mead, which needs no
    gradient, so non-smooth costs such as an L1 term are accepted. The optimiser
    is local; treat the result as an approximate minimiser.
    """
    from scipy.optimize import minimize  # only needed for this update rule

    X_new = np.array(X_start, dtype=float, copy=True)
    for r in range(X_new.shape[0]):
        # Keep only the target points that row r actually sends mass to.
        terms = []
        for T_i, Y_i, w_i in zip(plans, measures_locations, weights):
            row = np.asarray(T_i[r])
            support = row > 0
            if np.any(support):
                terms.append((w_i * row[support], np.asarray(Y_i)[support]))
        if not terms:
            continue

        def objective(x, terms=terms):
            return sum(np.dot(mass, cost_matrix(x[None, :], Y)[0]) for mass, Y in terms)

        X_new[r] = minimize(objective, X_new[r], method="Nelder-Mead").x
    return X_new


def custom_free_support_barycenter(measures_locations, measures_weights, X_init,
                                   cost_fn, gamma=0.5, A=None,
                                   a=None, weights=None, numItermax=100, stopThr=1e-7,
                                   verbose=False, numThreads=1,
                                   location_update="projection"):
    """Free-support barycenter with a configurable ground cost.

    Follows the structure of ``ot.lp.free_support_barycenter``. It alternates
    between two steps:

    1. Solve one exact OT problem per input measure with ``ot.lp.emd``, using
       the cost matrix between the current support ``X`` and that measure.
    2. Move the support points.

    NOTE: with the default ``location_update="projection"``, step 2 is the
    barycentric projection ``X <- sum_i w_i * diag(1/a) @ T_i @ Y_i``.

    - That update is the exact minimiser of the transport cost only for the
      squared Euclidean cost.
    - With any other ``cost_fn``, the custom cost only changes which points
      get matched (the plans ``T_i``). The locations are still moved as if the
      cost were squared Euclidean.
    - Pass ``location_update="minimize"`` to instead move each point to a
      numerical minimiser of its plan-weighted cost under ``cost_fn``.

    Parameters
    ----------
    measures_locations : list of arrays
        Support points ``Y_i`` of each input measure, shape ``(n_i, d)``.
    measures_weights : list of arrays
        Mass of each input measure's points, shape ``(n_i,)``. Passed to
        ``emd`` as the target marginal.
    X_init : array, shape (k, d)
        Initial barycenter support. It is copied, not modified.
    cost_fn : str or callable
        If a string, it is used as the ``cdist`` metric name
        (e.g. ``'sqeuclidean'``). Otherwise it is called as
        ``cost_fn(u, v, gamma)`` for pairs of points.
    gamma : float
        Extra argument forwarded to a callable ``cost_fn``.
    A : ignored
        Accepted for signature compatibility; not used.
    a : array, shape (k,), optional
        Mass of each barycenter support point. Defaults to uniform ``1/k``.
    weights : array, shape (len(measures_locations),), optional
        Weight of each input measure. Defaults to uniform.
    numItermax : int
        Maximum number of iterations.
    stopThr : float
        Stop once the squared displacement of ``X`` between iterations is
        ``<= stopThr``.
    verbose : bool
        Print the displacement at every iteration.
    numThreads : int
        Forwarded to ``ot.lp.emd``.
    location_update : {"projection", "minimize"}
        Rule used in step 2 (see the note above).

    Returns
    -------
    X : array, shape (k, d)
        Barycenter support after the last iteration.

    Raises
    ------
    ValueError
        If ``location_update`` is not a recognised rule.
    """
    if location_update not in ("projection", "minimize"):
        raise ValueError(
            "location_update must be 'projection' or 'minimize', got {!r}".format(location_update)
        )

    nx = ot.backend.get_backend(*measures_locations, *measures_weights, X_init)

    n_measures = len(measures_locations)
    k = X_init.shape[0]
    d = X_init.shape[1]

    # Default barycenter masses must be bound to ``a``: ``a`` is passed to emd
    # and used to normalise the projection below.
    if a is None:
        a = nx.ones((k,), type_as=X_init) / k
    if weights is None:
        weights = nx.ones((n_measures,), type_as=X_init) / n_measures

    # Build the cost-matrix function once instead of re-checking the type of
    # cost_fn for every measure in every iteration.
    if isinstance(cost_fn, str):
        def cost_matrix(X_, Y_):
            return cdist(X_, Y_, metric=cost_fn)
    else:
        def cost_matrix(X_, Y_):
            return cdist(X_, Y_, metric=lambda u, v: cost_fn(u, v, gamma))

    X = X_init.copy()
    for iter_count in range(numItermax):
        T_sum = nx.zeros((k, d), type_as=X_init)
        plans = []
        for measure_locations_i, measure_weights_i, weight_i in zip(
            measures_locations, measures_weights, weights
        ):
            M_i = cost_matrix(X, measure_locations_i)
            T_i = ot.lp.emd(a, measure_weights_i, M_i, numThreads=numThreads)
            plans.append(T_i)
            # Barycentric projection: T_i @ Y_i divided by a[r] is the mean of
            # the points that support point r is matched to (optimal only for
            # the squared Euclidean cost).
            T_sum = T_sum + weight_i * 1.0 / a[:, None] * nx.dot(T_i, measure_locations_i)

        if location_update == "minimize":
            # Start each point's minimisation from its projected position.
            X_new = _minimize_support_locations(
                T_sum, plans, measures_locations, weights, cost_matrix
            )
        else:
            X_new = T_sum

        displacement_square_norm = nx.sum((X_new - X) ** 2)
        X = X_new
        if verbose:
            print("iteration {:d}, displacement_square_norm={:.6f}".format(iter_count, displacement_square_norm))
        # Written as "not >" so a NaN displacement also stops the loop.
        if not displacement_square_norm > stopThr:
            break
    return X
# Starting support for both runs: k points drawn from a standard normal in N dims.
X_init = np.random.normal(0.0, 1.0, (k, N))

# Reference barycenter under the built-in squared Euclidean cost.
X_barycenter = custom_free_support_barycenter(
    measures_locations,
    measures_weights,
    X_init,
    cost_fn='sqeuclidean',
    a=np.ones((k,)) / k,
    verbose=False,
)

# Same problem and same start, but with the custom squared-L2 + gamma * L1 cost.
X_barycenter_l2_l1 = custom_free_support_barycenter(
    measures_locations,
    measures_weights,
    X_init,
    cost_fn=dist_l2_l1,
    gamma=1000,
    a=np.ones((k,)) / k,
    verbose=False,
)

# Show the two results one after the other.
plot_barycenters(samples, X_barycenter)
plot_barycenters(samples, X_barycenter_l2_l1)
Beta Was this translation helpful? Give feedback.