
I'm trying to vmap a function. My understanding is that anywhere I would write a for loop or list comprehension, I should consider using vmap instead. I have a few points of confusion:

  1. Does vmap need fixed sizes for everything through the function(s) being vmapped?
  2. Does vmap try to JIT my function behind the scenes? (I ask because point 1 is behavior I expect from JIT. I didn't expect it from vmap, but I don't know vmap well.)
  3. If vmap is jit-ing something, how would one use something like static arguments with vmap?
  4. What is the best practice for dealing with extraneous information? (E.g., if some outputs have size a and some have size b, do you just make an array of size max(a, b) and ignore the extra values?)

The reason I'm asking is that vmap, like jit, seems to run into all sorts of ConcretizationTypeError and seems (not 100% clear yet) to need constant-sized arrays everywhere. I associate this behavior with any function I'm trying to jit, but not necessarily with any function I write in JAX.

Nimantha
asked Nov 5, 2023 at 22:33
3 Comments
  • I'd suggest splitting the post into multiple independent questions. Commented Nov 6, 2023 at 10:06
  • For question 1, I recommend reading the Dynamic Shapes section of "JAX - The Sharp Bits" in the JAX documentation. Commented Nov 6, 2023 at 10:07
  • For question 3, you can set the in_axes entry corresponding to the static argument to None. More information can be found in the discussion here. Commented Nov 6, 2023 at 10:12

2 Answers


Does vmap need fixed sizes for everything through the function(s) being vmapped?

Yes. vmap, like all JAX transformations, requires every array defined in the function to have a static shape, that is, a shape that does not depend on the values of the inputs.
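To make the static-shape requirement concrete, here is a minimal sketch. The function names and sample data are made up for illustration. Boolean-mask indexing produces an output whose length depends on the data, so it works eagerly but fails under vmap, while a `jnp.where` formulation keeps the shape fixed.

```python
import jax
import jax.numpy as jnp

def positives_sum_dynamic(x):
    # Boolean-mask indexing: the result length depends on the values in x.
    return x[x > 0].sum()

def positives_sum_static(x):
    # jnp.where keeps the shape of x and zeroes out the unwanted entries.
    return jnp.where(x > 0, x, 0.0).sum()

batch = jnp.array([[1.0, -2.0, 3.0],
                   [-1.0, -1.0, 4.0]])

# Outside of any transformation the dynamic version runs fine.
eager_result = positives_sum_dynamic(batch[0])

# Under vmap the mask is a tracer, so its value-dependent shape is unknown.
dynamic_failed = False
try:
    jax.vmap(positives_sum_dynamic)(batch)
except Exception as err:
    dynamic_failed = True
    print("vmap failed with", type(err).__name__)

# The static-shape version maps over the batch without trouble.
static_result = jax.vmap(positives_sum_static)(batch)
print(static_result)
```

This is the same kind of failure one sees under jit, which is likely why the two feel similar even though vmap does not compile anything.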

Does vmap try to JIT my function behind the scenes? (I ask because point 1 is behavior I expect from JIT. I didn't expect it from vmap, but I don't know vmap well.)

No, vmap does not jit-compile a function by default, although you can always compose the two if you wish (e.g. jit(vmap(f))).
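As a small illustration of composing the two, the sketch below applies vmap alone and then jit around vmap to a toy function. `squared_norm` and the sample array are hypothetical.

```python
import jax
import jax.numpy as jnp

def squared_norm(v):
    # Works on a single 1-D vector.
    return jnp.dot(v, v)

vectors = jnp.arange(6.0).reshape(3, 2)  # three vectors of length 2

batched = jax.vmap(squared_norm)               # vectorized only, no compilation
batched_jit = jax.jit(jax.vmap(squared_norm))  # vectorized, then jit-compiled

out_plain = batched(vectors)
out_jit = batched_jit(vectors)
print(out_plain, out_jit)
```

Both calls return the same values. The only difference is that the second one is traced once and compiled for the given input shape.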

If vmap is jit-ing something, how would one use something like static arguments with vmap?

As mentioned, vmap is unrelated to jit, but the analogue of jit's static_argnums is passing None in in_axes, which keeps the argument unmapped and therefore static within the transformation.
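A minimal sketch of in_axes with None, using hypothetical function names. The unmapped argument is passed through unchanged to every call, so under plain vmap (without an enclosing jit) it can even be used to determine a shape.

```python
import jax
import jax.numpy as jnp

def scale_and_shift(x, scale, offset):
    return x * scale + offset

def tile_row(x, n):
    # n is unmapped, so it reaches the function as a plain Python int
    # and can be used as a shape.
    return x * jnp.ones(n)

xs = jnp.array([1.0, 2.0, 3.0])
offsets = jnp.array([10.0, 20.0, 30.0])

# Map over xs and offsets (axis 0); share the same scale across all calls.
out = jax.vmap(scale_and_shift, in_axes=(0, None, 0))(xs, 2.0, offsets)

# Map over xs only; n = 4 is shared and fixes the per-example output length.
tiled = jax.vmap(tile_row, in_axes=(0, None))(xs, 4)
print(out, tiled.shape)
```

If the whole thing is later wrapped in jit, the shape-determining argument would additionally need to be marked static there, for example with static_argnums.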

What is the best practice for dealing with extraneous information? (E.g., if some outputs have size a and some have size b, do you just make an array of size max(a, b) and ignore the extra values?)
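One common pattern for this situation, offered here as a sketch rather than as the answerer's recommendation, is to pad every row to a common length and carry a boolean mask that marks the valid entries. The `masked_mean` function and the ragged sample data are assumptions made up for the example.

```python
import jax
import jax.numpy as jnp
import numpy as np

def masked_mean(values, mask):
    # Ignore padded entries: they contribute 0 to the sum and 0 to the count.
    total = jnp.sum(jnp.where(mask, values, 0.0))
    count = jnp.maximum(jnp.sum(mask), 1)  # guard against an all-padding row
    return total / count

ragged = [[1.0, 2.0, 3.0], [4.0, 5.0]]  # rows of different lengths
max_len = max(len(row) for row in ragged)

# Build the padded array and mask with NumPy, outside any JAX transformation.
padded = np.zeros((len(ragged), max_len), dtype=np.float32)
mask = np.zeros((len(ragged), max_len), dtype=bool)
for i, row in enumerate(ragged):
    padded[i, :len(row)] = row
    mask[i, :len(row)] = True

means = jax.vmap(masked_mean)(jnp.asarray(padded), jnp.asarray(mask))
print(means)
```

The padding value itself does not matter as long as the mask is applied before any reduction.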

blackgreen
answered Nov 6, 2023 at 14:55

3 Comments

Does vmap jit its argument? That is, is vmap(f) equivalent to vmap(jit(f)) (as opposed to the jit(vmap(f)) in this answer)?
No, vmap does not result in JIT compilation. If you want your vmapped function to be JIT-compiled, you must explicitly wrap it in jit.
Thanks! One more question: Does jit(vmap(f)) ~ jit(vmap(jit(f)))? That is, does it make any difference (speed or otherwise) to jit inside the vmap if one jits outside the vmap?

A section of my code now looks like:

import jax
import jax.numpy as jnp

# my_func, parallel_axes, inds_1, inds_2, num_items, batch_size and
# other_relevant_inputs are defined elsewhere.
vmapped_f = jax.vmap(my_func, in_axes=parallel_axes, out_axes=0)
n_batches = -(-num_items // batch_size)  # ceiling division (round up)

# Convert the index lists once, outside the loop.
inds_1, inds_2 = jnp.asarray(inds_1), jnp.asarray(inds_2)

all_vals = []
for i in range(n_batches):
    start = i * batch_size
    stop = min(num_items, start + batch_size)  # the last batch may be shorter
    f_vals = vmapped_f(inds_1[start:stop], inds_2[start:stop], other_relevant_inputs)
    all_vals.extend(f_vals.tolist())

The vmapped function takes all of my data plus the indices of the data to use. Every index batch has the same size except possibly the last one, so if you wrap the function in jit it only needs to compile twice.
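If the second compilation for the shorter final batch matters, one option is to pad the index arrays up to a multiple of the batch size and discard the padded results afterwards. The following is only a sketch under assumptions: `pair_product` is a hypothetical stand-in for my_func, and `data` and the index arrays are made-up sample inputs.

```python
import jax
import jax.numpy as jnp

def pair_product(i, j, data):
    # Hypothetical stand-in for my_func: combine two indexed entries.
    return data[i] * data[j]

# Map over both index arguments; share the full data array across calls.
batched = jax.jit(jax.vmap(pair_product, in_axes=(0, 0, None)))

data = jnp.arange(10.0)
inds_1 = jnp.array([0, 1, 2, 3, 4])
inds_2 = jnp.array([5, 6, 7, 8, 9])
batch_size = 2
num_items = inds_1.shape[0]

# Pad with index 0 (any valid index works) so every batch has batch_size items.
pad = (-num_items) % batch_size
inds_1_padded = jnp.pad(inds_1, (0, pad))
inds_2_padded = jnp.pad(inds_2, (0, pad))

chunks = []
for start in range(0, num_items + pad, batch_size):
    stop = start + batch_size
    chunks.append(batched(inds_1_padded[start:stop], inds_2_padded[start:stop], data))

# Drop the results that came from padding.
all_vals = jnp.concatenate(chunks)[:num_items]
print(all_vals)
```

Because every call sees the same input shapes, jit traces and compiles the function a single time. The cost is a small amount of wasted work on the padded entries.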

Nimantha
answered Nov 24, 2023 at 5:55

1 Comment

I will note that in my particular case this runs slowly, but I don't think that's a result of the batching vmap, but just what I'm vmapping in my super-specific case.
