Closed
Labels: bug (Something isn't working)
Description
There appears to be a bug in how `jax.lax.scan`, `jax.array_ref`, and `jnp.array()` interact.
The following code raises `ValueError: JaxprInputEffect Read<2> is invalid`:
```python
import jax
import jax.numpy as jnp


@jax.jit
def bar(x, w):
    def scan_fn(x, _):
        c = jnp.array([])  # <- using jnp.zeros((0,)) works fine!
        o = w[...] @ x
        x = jnp.concatenate([x, c], axis=-1)
        return x, None

    x, _ = jax.lax.scan(scan_fn, x, None, length=10)
    return x


@jax.jit
def foo(w):
    return bar(
        jnp.zeros((1,)),
        w,
    )


foo(jax.array_ref(jnp.eye(1)))
```
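Since the comment in the repro says `jnp.zeros((0,))` avoids the error, one quick sanity check is to confirm that the two constructions produce the same value type, and then compare how each is staged into a jaxpr. The sketch below does that with `jax.make_jaxpr`; the helper names `with_array` and `with_zeros` are hypothetical. A plausible (unconfirmed) guess is that `jnp.array([])` is embedded as a captured constant while `jnp.zeros((0,))` is traced as an operation, which could affect how effect input indices are counted inside the scan body.

```python
import jax
import jax.numpy as jnp


def with_array():
    # Empty array built from a Python list, as in the failing repro.
    return jnp.array([])


def with_zeros():
    # Empty array built with jnp.zeros, which the report says works.
    return jnp.zeros((0,))


a, z = with_array(), with_zeros()
# Both should be empty 1-D arrays with the same default float dtype.
print(a.shape, a.dtype, z.shape, z.dtype)

# Print the staged programs side by side; any difference in how the
# empty array is introduced (constant vs. traced op) shows up here.
print(jax.make_jaxpr(with_array)())
print(jax.make_jaxpr(with_zeros)())
```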
The problem appears to be specific to `jnp.array([])`: replacing it with `jnp.zeros((0,))` avoids the error. This suggests that #16370 may not actually be resolved.
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.7.2
jaxlib: 0.7.2
numpy: 2.3.3
python: 3.13.5 (main, Jul 11 2025, 22:45:47) [Clang 20.1.4 ]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='asahi', release='6.16.5-asahi', version='#1-NixOS SMP Tue Jan 1 00:00:00 UTC 1980', machine='aarch64')