Closed
@samuelbx
Description
Describe the bug
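`TorchBackend.sqrtm` relies on `torch.linalg.eigh`, whose backward pass is only well-defined when the eigenvalues are distinct. With repeated eigenvalues (e.g. the identity matrix), the gradient through the eigenvectors is NaN or infinite, as explained in the PyTorch documentation for `torch.linalg.eigh`.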
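For context, here is a short first-order perturbation sketch (standard matrix-calculus results, not taken from the PyTorch or POT source). Write the symmetric input as $A = U \Lambda U^\top$ with eigenvectors $u_j$ and eigenvalues $\lambda_j$, and let $S = A^{1/2}$. The eigenvector derivative divides by eigenvalue gaps, while the derivative of the square root itself, obtained by differentiating $S^2 = A$, only divides by sums of square roots:

```math
\begin{aligned}
\mathrm{d}u_j &= \sum_{i \neq j} \frac{u_i^\top (\mathrm{d}A)\, u_j}{\lambda_j - \lambda_i}\, u_i, \\
S\,\mathrm{d}S + \mathrm{d}S\,S &= \mathrm{d}A
\quad\Longrightarrow\quad
\bigl(U^\top \mathrm{d}S\, U\bigr)_{ij} = \frac{\bigl(U^\top \mathrm{d}A\, U\bigr)_{ij}}{\sqrt{\lambda_i} + \sqrt{\lambda_j}}.
\end{aligned}
```

For the reproduction case $A = I$, every gap $\lambda_j - \lambda_i$ is zero, while every $\sqrt{\lambda_i} + \sqrt{\lambda_j}$ equals 2. So the NaN comes from differentiating through the eigenvectors, not from the matrix square root being non-differentiable at $I$.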
To Reproduce
```python
import torch
from ot.backend import TorchBackend

torch.set_default_dtype(torch.float64)
torch.autograd.set_detect_anomaly(True)

nx = TorchBackend()
A = torch.eye(3, dtype=torch.float64, requires_grad=True)
nx.sqrtm(A)[0, 1].backward()
print('OK')
```
Output: RuntimeError: Function 'LinalgEighBackward0' returned nan values in its 0th output.
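One possible workaround, sketched here under stated assumptions and not POT's actual fix, is to wrap the eigendecomposition in a custom `torch.autograd.Function` whose backward solves the Sylvester equation $S X + X S = \bar{G}$ in the eigenbasis instead of differentiating through `eigh`. The class name `SqrtmPSD` is hypothetical. The sketch assumes a symmetric positive-definite input, since a zero eigenvalue would make the denominator vanish.

```python
import torch


class SqrtmPSD(torch.autograd.Function):
    """Hypothetical square root of a symmetric positive-definite matrix.

    Forward: eigendecomposition A = U diag(l) U^T, then S = U diag(sqrt(l)) U^T.
    Backward: solves S X + X S = G in the eigenbasis, which only divides by
    sqrt(l_i) + sqrt(l_j), never by eigenvalue gaps l_j - l_i.
    """

    @staticmethod
    def forward(ctx, a):
        eigvals, eigvecs = torch.linalg.eigh(a)
        # Clamp tiny negative round-off before taking the square root.
        sqrt_vals = eigvals.clamp(min=0).sqrt()
        ctx.save_for_backward(eigvecs, sqrt_vals)
        # Scale column j of U by sqrt(l_j), then multiply by U^T.
        return (eigvecs * sqrt_vals.unsqueeze(-2)) @ eigvecs.transpose(-2, -1)

    @staticmethod
    def backward(ctx, grad_out):
        eigvecs, sqrt_vals = ctx.saved_tensors
        # Rotate the incoming gradient into the eigenbasis: U^T G U.
        g = eigvecs.transpose(-2, -1) @ grad_out @ eigvecs
        # sqrt(l_i) + sqrt(l_j) > 0 for SPD inputs, even with repeated eigenvalues.
        denom = sqrt_vals.unsqueeze(-1) + sqrt_vals.unsqueeze(-2)
        # Divide elementwise and rotate back: U ((U^T G U) / denom) U^T.
        return eigvecs @ (g / denom) @ eigvecs.transpose(-2, -1)


# Same failing case as in the report, with anomaly detection still enabled.
torch.autograd.set_detect_anomaly(True)
A = torch.eye(3, dtype=torch.float64, requires_grad=True)
SqrtmPSD.apply(A)[0, 1].backward()
print(A.grad)
```

Here the gradient is finite: 0.5 in entry (0, 1) and 0 elsewhere, up to round-off, because at $A = I$ the derivative of $A^{1/2}$ in any direction $\mathrm{d}A$ is $\mathrm{d}A / 2$. Only the backward pass changes. The forward result is the same eigendecomposition-based square root.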
Environment (please complete the following information):
- OS (e.g. MacOS, Windows, Linux): MacOS
- Python version: 9
- How was POT installed (source, pip, conda): pip