
Bug handling array_ref with jax.lax.scan and literals #32399

@pfrommerd

Description

There appears to be a bug in the interaction between `jax.lax.scan`, `array_ref`, and `jnp.array()`. The following code raises `ValueError: JaxprInputEffect Read<2> is invalid`:

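```python
import jax
import jax.numpy as jnp


@jax.jit
def bar(x, w):
    # w is an array_ref that is closed over by the scan body.
    def scan_fn(x, _):
        c = jnp.array([])  # <- using jnp.zeros((0,)) works fine!
        # Read the ref inside the scan body. The result is unused, but the
        # read is presumably what introduces the Read effect named in the error.
        o = w[...] @ x
        # Concatenating an empty array leaves the carry's shape unchanged.
        x = jnp.concatenate([x, c], axis=-1)
        return x, None

    x, _ = jax.lax.scan(scan_fn, x, None, length=10)
    return x


@jax.jit
def foo(w):
    return bar(
        jnp.zeros((1,)),
        w,
    )


# Raises: ValueError: JaxprInputEffect Read<2> is invalid.
foo(jax.array_ref(jnp.eye(1)))
```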

The problem appears to be specific to `jnp.array([])`: replacing it with `jnp.zeros((0,))` avoids the error. This suggests that #16370 may not actually be resolved.
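To narrow this down, it may help to compare the jaxprs produced by the two variants of the scan body. The sketch below is a hypothetical diagnostic, not part of the original report: `body_with_literal`, `body_with_zeros`, and `make_scan` are illustrative names, and the `array_ref` argument is deliberately left out so the snippet runs on any recent JAX release. The working assumption is that `jnp.array([])` is captured as a constant in the traced jaxpr while `jnp.zeros((0,))` is emitted as an ordinary operation; the printed output has not been checked against jax 0.7.2.

```python
import jax
import jax.numpy as jnp


# Hypothetical scan body using the list literal (the failing case in the report).
def body_with_literal(x, _):
    c = jnp.array([])  # empty float32 array built from a Python list
    return jnp.concatenate([x, c], axis=-1), None


# Hypothetical scan body using jnp.zeros (the case reported to work).
def body_with_zeros(x, _):
    c = jnp.zeros((0,))  # empty float32 array built by a JAX op
    return jnp.concatenate([x, c], axis=-1), None


# Wraps a body in the same scan configuration as the report (no xs, length=10).
def make_scan(body):
    def f(x):
        y, _ = jax.lax.scan(body, x, None, length=10)
        return y
    return f


x0 = jnp.zeros((1,))
# Print both jaxprs side by side to look for an extra captured constant
# in the scan body of the literal variant.
print(jax.make_jaxpr(make_scan(body_with_literal))(x0))
print(jax.make_jaxpr(make_scan(body_with_zeros))(x0))
```

Both variants compute the same value, so any difference in the printed jaxprs comes from how the empty array is introduced. If only the literal variant shows an extra constant input to the `scan` body, that would point toward constant hoisting as the place where the effect's input index goes stale; this is a guess, not a confirmed diagnosis.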

System info (python version, jaxlib version, accelerator, etc.)

jax: 0.7.2
jaxlib: 0.7.2
numpy: 2.3.3
python: 3.13.5 (main, Jul 11 2025, 22:45:47) [Clang 20.1.4 ]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='asahi', release='6.16.5-asahi', version='#1-NixOS SMP Tue Jan 1 00:00:00 UTC 1980', machine='aarch64')

Labels: bug