Closed
Labels: bug
Description
```python
import jax

def f(a_bool, y):
    if a_bool:
        return y + 1
    else:
        return y

jax.jit(jax.remat(f), static_argnums=0)(True, 1)
```
Results in:
```
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected (in `bool`).
Use transformation parameters such as `static_argnums` for `jit` to avoid tracing input values.
See `https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error`.
Encountered value: Traced<ShapedArray(bool[], weak_type=True):JaxprTrace(level=-1/2)>
```
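A minimal workaround sketch, assuming a recent JAX in which `jax.remat` is an alias of `jax.checkpoint`. Variant 1 applies remat inside the jitted function and closes over the flag, so the flag is never an input to the rematerialized function (`g`). Variant 2 assumes your JAX version's `jax.checkpoint` accepts `static_argnums` and marks the flag static for remat as well as for jit (`h`). The names `g`, `h` and `f_ckpt` are hypothetical, and both are user-side workarounds rather than a change to the `full_raise` behaviour in `core.py`.

```python
import functools

import jax


def f(a_bool, y):
    if a_bool:
        return y + 1
    else:
        return y


# Variant 1: a_bool is static for jit, so inside g it is a plain Python bool.
# The lambda closes over it, so checkpoint (remat) only traces y.
@functools.partial(jax.jit, static_argnums=0)
def g(a_bool, y):
    return jax.checkpoint(lambda y_: f(a_bool, y_))(y)


# Variant 2 (assumes jax.checkpoint supports static_argnums in your version):
# declare argument 0 static for remat too, not only for jit.
f_ckpt = jax.checkpoint(f, static_argnums=(0,))
h = jax.jit(f_ckpt, static_argnums=0)

print(g(True, 1), g(False, 1))
print(h(True, 1), h(False, 1))
```

Variant 1 does not depend on how remat treats its inputs, since the flag never becomes one of its arguments; variant 2 keeps the call site unchanged but relies on `static_argnums` support in `jax.checkpoint`.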
I think this arises from the `full_raise` that occurs here when processing `remat_call_p`, which raises the value to a `JaxprTrace` even though we told `jit` we don't want it traced:
https://github.com/google/jax/blob/77901e9fa71f5b23066c70132a983ae57f655b39/jax/core.py#L1001
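To see the raise directly, a small probe can record the type of the flag inside the function body with and without remat. This is an illustrative sketch: `probe` and `seen` are hypothetical names, `jax.checkpoint` is used as the current name for `jax.remat`, and the tracer class recorded in the second call depends on the JAX version, so it may not be the `JaxprTrace` tracer shown in the error above.

```python
import jax

seen = []


def probe(a_bool, y):
    # Record the type of the flag as seen while the body is being traced.
    seen.append(type(a_bool))
    return y


# Without remat: jit keeps argument 0 static, so the body sees a Python bool.
jax.jit(probe, static_argnums=0)(True, 1)

# With remat: the same static argument reaches the body as an abstract tracer.
jax.jit(jax.checkpoint(probe), static_argnums=0)(True, 1)

for t in seen:
    print(t)
```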
This also applies to user-defined call primitives that use `core.call_bind`, resulting in unnecessary workarounds such as this one in Haiku:
https://github.com/deepmind/dm-haiku/blob/49b21f7192dfdb3dc0a49cc097c8d3b0ccabb107/haiku/_src/named_call.py#L101-L109