
Jax backend: jax.errors.TracerArrayConversionError #625

Open
@act65

Description

Describe the bug

As I understand it, POT's backend switching (which picks a backend from the input array types) should let me use this library with JAX. However, I get a `jax.errors.TracerArrayConversionError`. It appears to come from POT converting the inputs to NumPy (not `jax.numpy`) inside the backend, even though I pass only `jax.numpy` arrays.

To Reproduce

```python
import jax.numpy as jnp
from jax import random, grad
import ot as pot

key = random.PRNGKey(0)
B = 10
key, subkey = random.split(key)
x = random.normal(subkey, (B, 1))
key, subkey = random.split(key)
y = random.normal(subkey, (B, 1))

def loss_fn(x, y):
    costs = jnp.linalg.norm(x[:, None] - y[None, :], axis=-1)**2
    pi = pot.emd(
        jnp.ones(B) / B,
        jnp.ones(B) / B,
        costs)
    return jnp.sum(pi * costs)

g = grad(loss_fn)(x, y)
print(g)
```

(Note: the problem isn't specific to `grad`; it also occurs with `vmap`, `jit`, ...)

```
Traceback (most recent call last):
  File "/home/telfaralex/Documents/phdv2/code/sinterp/sinterp/tests/test_couplings.py", line 58, in test_grad
    g = grad(loss_fn)(x, y)
        ^^^^^^^^^^^^^^^^^^^
  File "/home/telfaralex/Documents/phdv2/code/sinterp/sinterp/tests/test_couplings.py", line 51, in loss_fn
    pi = ot_fn(
         ^^^^^^
  File "/home/telfaralex/miniconda3/lib/python3.11/site-packages/ot/lp/__init__.py", line 318, in emd
    M, a, b = nx.to_numpy(M, a, b)
              ^^^^^^^^^^^^^^^^^^^^
  File "/home/telfaralex/miniconda3/lib/python3.11/site-packages/ot/backend.py", line 260, in to_numpy
    return [self._to_numpy(array) for array in arrays]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/telfaralex/miniconda3/lib/python3.11/site-packages/ot/backend.py", line 260, in <listcomp>
    return [self._to_numpy(array) for array in arrays]
            ^^^^^^^^^^^^^^^^^^^^^
  File "/home/telfaralex/miniconda3/lib/python3.11/site-packages/ot/backend.py", line 1439, in _to_numpy
    return np.array(a)
           ^^^^^^^^^^^
jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float32[10,10]
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
```
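The last frame isolates the root cause: POT's JAX backend ends up calling `np.array` on its inputs, and a NumPy conversion needs concrete values, which a traced array under `grad`, `jit` or `vmap` does not have. A minimal sketch reproducing only that step without POT (`to_numpy_like_pot` is a hypothetical name that mirrors the `_to_numpy` frame above):

```python
import numpy as np
import jax
import jax.numpy as jnp


def to_numpy_like_pot(a):
    # Mirrors the last traceback frame: `return np.array(a)`.
    # Under a JAX transformation `a` is an abstract tracer with no concrete values.
    return np.array(a)


# Outside any transformation the conversion succeeds.
host = to_numpy_like_pot(jnp.ones((10, 10)))

# Inside jit (and likewise grad or vmap) the same call raises.
try:
    jax.jit(to_numpy_like_pot)(jnp.ones((10, 10)))
    raised = False
except jax.errors.TracerArrayConversionError:
    raised = True
print(raised)
```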

Environment:

  • OS (e.g. MacOS, Windows, Linux): Linux
  • Python version: 3.11.4
  • How was POT installed (source, pip, conda): pip (POT v0.9.3)
