Hi,
I would like to identify a transport plan from two known grids (an input grid and an output grid) and apply it to a new grid, but it does not work:
import numpy as np
import ot

def compute_transport_plan(tensor1, tensor2):
    """
    Compute the optimal transport plan from tensor1 to tensor2.
    """
    tensor1 = tensor1.flatten()
    tensor2 = tensor2.flatten()
    # Assuming uniform weights for the discrete distributions
    a = np.ones(len(tensor1)) / len(tensor1)
    b = np.ones(len(tensor2)) / len(tensor2)
    M = ot.dist(tensor1[:, None], tensor2[:, None], metric='euclidean')
    # Compute the optimal transport plan
    transport_plan = ot.emd(a, b, M)
    return transport_plan

def apply_transport_plan(tensor, transport_plan, target_shape):
    """
    Apply the transport plan to transform the tensor.
    """
    tensor = tensor.flatten()
    # Multiplication might work without reshape depending on tensor shapes
    transformed_tensor = np.dot(transport_plan, tensor)
    # Reshape to target shape only if necessary
    if len(transformed_tensor.shape) != len(target_shape):
        transformed_tensor = transformed_tensor.reshape(target_shape)
    return transformed_tensor

# Example usage
tensor1 = np.array([[1, 2, 3], [4, 5, 6]])
tensor2 = np.array([[7, 8, 9], [10, 11, 12]])
new_tensor = np.array([[13, 14, 15], [16, 17, 18]])

# Compute Optimal Transport Plan
transport_plan = compute_transport_plan(tensor1, tensor2)
print(f"Optimal Transport Plan: \n{transport_plan}")

# Apply the Transport Plan to a New Tensor
transformed_tensor = apply_transport_plan(new_tensor, transport_plan, tensor2.shape)
print(f"Transformed Tensor: \n{transformed_tensor}")
I get this result:
Optimal Transport Plan:
[[0. 0. 0. 0. 0.16666667 0. ]
[0. 0.16666667 0. 0. 0. 0. ]
[0. 0. 0.16666667 0. 0. 0. ]
[0. 0. 0. 0. 0. 0.16666667]
[0.16666667 0. 0. 0. 0. 0. ]
[0. 0. 0. 0.16666667 0. 0. ]]
Transformed Tensor:
[[2.83333333 2.33333333 2.5 ]
[3. 2.16666667 2.66666667]]
but I expected [[19, 20, 21], [22, 23, 24]].
Hello,
First, note that the matrix M given to ot.emd has size n_s x n_t, where n_s and n_t are the numbers of source and target samples respectively. From your tensors it seems that you have two samples in R^3, so I'm not sure what you are trying to do with the flatten, which creates 6 samples (in R^1) out of your 2D matrices.
See the following example for samples in 2D:
https://pythonot.github.io/auto_examples/plot_OT_2D_samples.html#sphx-glr-auto-examples-plot-ot-2d-samples-py
Second, the OT plan can indeed be used to transport points, using what is called barycentric mapping. But it can only be applied to the samples used to estimate the OT plan. You cannot apply the OT plan to new samples, which is probably why you are not happy with the result.