jax.custom_jvp(nondiff_argnums=...) crashes inside lax.cond #9374

@patrick-kidger

The following program crashes with jax._src.errors.UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with shape () and dtype float32 to escape.

import jax
import jax.lax as lax
import functools as ft

@ft.partial(jax.custom_jvp, nondiff_argnums=(0,))
def f(x, y):
    print("f")
    return y

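@f.defjvp
def f_jvp(x, primals, tangents):
    # With nondiff_argnums=(0,), x is passed first; primals and tangents
    # are tuples holding only the remaining (differentiable) argument y.
    print("f_jvp")
    x + 1  # Crashes on this line
    (y,) = primals
    (tang_y,) = tangents
    return y, tang_y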

def g(y, x):
    return lax.cond(x < y, f, lambda _x, _y: _y, x, y)

jax.grad(g)(1.0, 1.0)

For reference, the above program prints

f
f_jvp

(and then crashes.)


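I think what is going on is that the "caching" of the nondiff_argnums argument x between f and f_jvp, combined with the tracing that cond performs on its function-valued arguments, is triggering the unexpected-tracer check. Presumably x is a tracer belonging to cond's branch trace: it is stored when f is traced and then used again when f_jvp runs later, after that trace has already finished.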
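
A minimal sketch of one possible workaround, assuming x does not need a derivative of its own: make x an ordinary argument of the custom_jvp function instead of a nondiff_argnums entry, so that cond and the JVP machinery handle it like any other operand. This variant of f, f_jvp and g is illustrative only and has not been checked against the JAX version reported here; the rule simply ignores x's tangent.

```python
import jax
import jax.lax as lax

# Hypothetical variant of f: x is an ordinary argument rather than a
# nondiff_argnums entry, so JAX traces it like any other cond operand.
@jax.custom_jvp
def f(x, y):
    return y

@f.defjvp
def f_jvp(primals, tangents):
    x, y = primals
    _, tang_y = tangents  # f does not depend on x, so x's tangent is ignored
    _ = x + 1  # x is now a regular primal, so using it here should be fine
    return y, tang_y

def g(y, x):
    return lax.cond(x < y, f, lambda _x, _y: _y, x, y)

print(jax.grad(g)(1.0, 1.0))  # derivative of g with respect to y
```

Since f returns y unchanged, the gradient with respect to y should be 1.0 on either branch of the cond.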

Incidentally, it's interesting (and perhaps not terribly efficient) that both f and f_jvp are called: my mental model was that only f_jvp would be called when differentiating a program. I suspect this is due to how lax.cond traces its function-valued arguments.
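
To see which functions actually run, the sketch below records calls in a list instead of printing. The function h and the calls list are hypothetical. Our expectation is that differentiating h directly runs h_jvp, whereas placing h inside lax.cond also runs h, because cond first traces each branch into a jaxpr (executing h's primal body) and only afterwards differentiates that jaxpr (executing h_jvp). The exact call lists are an implementation detail and may vary between JAX versions.

```python
import jax
import jax.lax as lax

calls = []  # records which Python functions JAX executes while tracing

@jax.custom_jvp
def h(y):
    calls.append("h")
    return y

@h.defjvp
def h_jvp(primals, tangents):
    calls.append("h_jvp")
    (y,), (tang_y,) = primals, tangents
    return y, tang_y

# 1) Differentiate h directly.
calls.clear()
jax.grad(h)(1.0)
direct_calls = list(calls)

# 2) Differentiate h when it is one branch of lax.cond.
calls.clear()
jax.grad(lambda y: lax.cond(y > 0, h, lambda v: v, y))(1.0)
cond_calls = list(calls)

print("direct:", direct_calls)
print("inside cond:", cond_calls)
```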

JAX version: 0.2.27

Labels: better_errors (Improve the error reporting), bug (Something isn't working)
