
core.call_bind aggressively raises args to top trace #2833

@trevorcai

import jax

def f(a_bool, y):
  if a_bool:
    return y + 1
  else:
    return y

jax.jit(jax.remat(f), static_argnums=0)(True, 1)

Results in:

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected (in `bool`).
Use transformation parameters such as `static_argnums` for `jit` to avoid tracing input values.
See `https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error`.
Encountered value: Traced<ShapedArray(bool[], weak_type=True):JaxprTrace(level=-1/2)>

I think this comes from the full_raise performed when processing remat_call_p. It raises every argument, including the static bool, to the JaxprTrace, even though static_argnums told jit not to trace that argument:
https://github.com/google/jax/blob/77901e9fa71f5b23066c70132a983ae57f655b39/jax/core.py#L1001
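To make the suspected mechanism concrete, here is a deliberately simplified toy model in plain Python. It is not JAX's implementation: ToyTracer, raise_all and raise_tracers_only are hypothetical names. It contrasts raising every argument to the top trace (the behavior described above) with lifting only values that are already tracers. Under the first policy the static bool turns into an abstract value and the Python if fails, mirroring the error.

```python
from dataclasses import dataclass
from typing import Any, List


@dataclass
class ToyTracer:
  """Hypothetical stand-in for an abstract tracer with no concrete value."""
  level: int

  def __bool__(self):
    # Mimics the ConcretizationTypeError: an abstract value has no truth value.
    raise TypeError("abstract tracer value encountered where concrete value is expected")


def raise_all(args: List[Any], top_level: int) -> List[Any]:
  # Toy version of the eager policy: every argument, including plain Python
  # constants such as a static bool, becomes a tracer of the top trace.
  return [ToyTracer(top_level) for _ in args]


def raise_tracers_only(args: List[Any], top_level: int) -> List[Any]:
  # Alternative policy: lift only values that are already tracers and
  # leave concrete Python values untouched.
  return [ToyTracer(top_level) if isinstance(a, ToyTracer) else a for a in args]


def f(a_bool, y):
  # Branches on a Python bool, like the repro above.
  return "then-branch" if a_bool else "else-branch"


args = [True, ToyTracer(level=1)]  # static bool, traced y

print(f(*raise_tracers_only(args, top_level=2)))  # prints then-branch

try:
  f(*raise_all(args, top_level=2))
except TypeError as err:
  print("raise_all failed:", err)
```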

This also affects user-defined call primitives built on core.call_bind, which forces unnecessary workarounds such as this one in Haiku:
https://github.com/deepmind/dm-haiku/blob/49b21f7192dfdb3dc0a49cc097c8d3b0ccabb107/haiku/_src/named_call.py#L101-L109
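One way to sidestep the problem from user code is to bind the static argument with functools.partial before applying remat, so the Python bool is never passed through the remat primitive at all. This is a sketch under that assumption, not the fix adopted in JAX and not necessarily what the Haiku code linked above does. It uses jax.checkpoint, which jax.remat aliases; the wrapper name remat_with_static_bool is hypothetical.

```python
import functools

import jax


def f(a_bool, y):
  if a_bool:
    return y + 1
  else:
    return y


def remat_with_static_bool(a_bool, y):
  # Hypothetical wrapper: close over the concrete Python bool *before*
  # applying remat, so only y is passed through the remat primitive.
  return jax.checkpoint(functools.partial(f, a_bool))(y)


# a_bool is static for jit, so inside the wrapper it is still a plain Python bool.
g = jax.jit(remat_with_static_bool, static_argnums=0)

print(g(True, 1))   # prints 2
print(g(False, 1))  # prints 1
```

Recent releases of jax.checkpoint also accept a static_argnums parameter for the same purpose; check the documentation for the JAX version in use.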

Labels: bug
