
xmap doesn't preserve static argnums #10741

@nestordemeure

Description


xmap appears to turn static arguments into traced ShapedArrays. This is a problem whenever an argument has to stay static, for example because the function branches on it with a Python if.
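For context, here is a minimal sketch of why a traced argument breaks a Python-level branch. It uses plain jax.jit rather than xmap. The function body (returning -data in the else branch) and the small input are illustrative choices, not taken from the report. When cond is marked static, jit passes it through as a concrete Python bool. When it is traced, "if cond" sees an abstract value and JAX raises a ConcretizationTypeError, which is the same class of error reported below for xmap.

```python
import jax
import jax.numpy as jnp

def func(cond, data):
    # Python-level branch: requires a concrete bool, not a tracer
    if cond:
        return data
    else:
        return -data

data = jnp.ones(3)

# cond is static: jit treats it as a compile-time constant, so the branch works
ok = jax.jit(func, static_argnames=['cond'])(True, data)

# cond is traced: `if cond` sees an abstract value and raises
try:
    jax.jit(func)(True, data)
    traced_failed = False
except jax.errors.ConcretizationTypeError:
    traced_failed = True
```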

The following code demonstrates the issue. It first uses vmap, which works, and then replaces it with xmap, which fails with "Abstract tracer value encountered where concrete value is expected" when func tests the value of cond on its first line:

import jax
import jax.numpy as jnp
from jax.experimental.maps import xmap

# dummy function
def func(cond, data):
    if cond:
        return data
    else:
        return data

# vectorisation
func_vmap = jax.vmap(func, in_axes=[None, 0], out_axes=0)
func_xmap = xmap(func, in_axes=[[...], ['axis']], out_axes=['axis'])

# jit compiling
func_vmap_jit = jax.jit(func_vmap, static_argnames=['cond'])
func_xmap_jit = jax.jit(func_xmap, static_argnames=['cond'])

# running
cond = True
data = jnp.ones(100)
out_vmap = func_vmap_jit(cond, data)
out_xmap = func_xmap_jit(cond, data)

This happens with the latest version of JAX (0.3.13).
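A possible workaround, sketched under the assumption that the problem is the mapping transform receiving cond as one of its inputs: keep cond static for jit, and bind it with functools.partial before mapping, so the mapped function only ever receives data. The sketch below uses jax.vmap so that it runs as written. Applying the same binding before xmap (for example xmap(functools.partial(func, cond), in_axes=[['axis']], out_axes=['axis'])) is an untested assumption. The wrapper name mapped and the else branch returning -data are illustrative only.

```python
import functools
import jax
import jax.numpy as jnp

def func(cond, data):
    if cond:
        return data
    else:
        return -data

@functools.partial(jax.jit, static_argnames=['cond'])
def mapped(cond, data):
    # cond is a concrete Python bool here because jit treats it as static.
    # Binding it with partial means the mapping transform only sees `data`.
    per_example = functools.partial(func, cond)
    return jax.vmap(per_example)(data)

data = jnp.arange(4.0)
out_true = mapped(True, data)    # each static value of cond triggers its own compilation
out_false = mapped(False, data)
```

Because cond is static, each distinct value compiles its own version of mapped, which is cheap for a boolean flag.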

Labels: bug
