Closed
Labels: bug (Something isn't working)
Description
`xmap` appears to turn static arguments into traced `ShapedArray`s. This causes problems when an argument needs to stay static, for example so that a Python `if` can test it.

The following code demonstrates the issue, first using `vmap` (no problem), then replacing it with `xmap`, which fails with `Abstract tracer value encountered where concrete value is expected` when testing the value of `cond` at the very beginning of the `func` function:
import jax
import jax.numpy as jnp
from jax.experimental.maps import xmap

# dummy function: branches on `cond`, so `cond` must be a concrete Python value
def func(cond, data):
    if cond:
        return data
    else:
        return data

# vectorisation: `cond` is unmapped, `data` is mapped along its first axis
func_vmap = jax.vmap(func, in_axes=[None, 0], out_axes=0)
func_xmap = xmap(func, in_axes=[[...], ['axis']], out_axes=['axis'])

# jit compiling, with `cond` marked as static
func_vmap_jit = jax.jit(func_vmap, static_argnames=['cond'])
func_xmap_jit = jax.jit(func_xmap, static_argnames=['cond'])

# running
cond = True
data = jnp.ones(100)
out_vmap = func_vmap_jit(cond, data)  # works
out_xmap = func_xmap_jit(cond, data)  # fails: Abstract tracer value encountered where concrete value is expected
This happens with the latest version of JAX (0.3.13).
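The same error can be reproduced without `xmap`: any argument a transform receives as an array becomes a tracer, so a Python `if` on it cannot be evaluated. The minimal sketch below uses only `jax.jit` and `jax.vmap` (since `jax.experimental.maps.xmap` may not exist in every JAX version). It shows both the error and one possible workaround: binding `cond` with `functools.partial` before vectorising, so it never passes through the transform. The helpers `traced_cond_raises` and `make_mapped`, and the `-data` branch, are hypothetical illustrations, not part of the original report; that the same `partial` trick also avoids the `xmap` failure is an assumption, not verified here.

```python
import functools

import jax
import jax.numpy as jnp


def func(cond, data):
    # Python-level branch: `cond` must be a concrete value, not a tracer.
    if cond:
        return data
    else:
        return -data  # hypothetical: differs from the report so the branch is observable


def traced_cond_raises(data):
    """Return True if passing `cond` as a traced (non-static) argument raises."""
    try:
        jax.jit(func)(True, data)  # `cond` is traced as a boolean scalar here
    except jax.errors.ConcretizationTypeError:
        return True
    return False


def make_mapped(cond):
    # Workaround sketch (hypothetical helper): bind `cond` before vectorising,
    # so vmap and jit only ever see `data` as an array argument.
    return jax.jit(jax.vmap(functools.partial(func, cond)))


data = jnp.ones(4)
print(traced_cond_raises(data))
print(make_mapped(True)(data))
print(make_mapped(False)(data))
```

Because `cond` is a plain Python value captured by `functools.partial`, each call to `make_mapped` compiles a separate function per value of `cond`, which is the same trade-off `static_argnames` makes under `jax.jit`.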