Description
The following program crashes with jax._src.errors.UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with shape () and dtype float32 to escape.
import jax
import jax.lax as lax
import functools as ft

@ft.partial(jax.custom_jvp, nondiff_argnums=(0,))
def f(x, y):
    print("f")
    return y

@f.defjvp
def f_jvp(x, y, tang_y):
    print("f_jvp")
    x + 1  # Crashes on this line
    (y,) = y
    (tang_y,) = tang_y
    return y, tang_y

def g(y, x):
    return lax.cond(x < y, f, lambda _x, _y: _y, x, y)

jax.grad(g)(1.0, 1.0)
For reference, the above program prints:
f
f_jvp
(and then crashes.)
I think what is going on is that the "caching" of the nondiff_argnum x between f and f_jvp, combined with the tracing performed by cond on its function-valued arguments, is triggering the unexpected tracer machinery.
Incidentally, it's interesting (and perhaps not terribly efficient?) that both f and f_jvp are called. My mental model was that only f_jvp would be called when autodifferentiating a program. I think this is probably due to some aspect of how lax.cond traces its function-valued arguments.
JAX version 0.22.7