
Some tracer level is None when vmap'ing a custom_jvp #3822

@brianwa84

Description

```python
!pip install -U -q jax jaxlib
!pip freeze

import jax
import jax.numpy as np

alpha = np.float32(2.)

def sample(seed):
  # `f` closes over `seed`, which is a batched tracer under `jax.vmap`.
  @jax.custom_jvp
  def f(alpha):
    return jax.random.gamma(seed, alpha, shape=[])

  @f.defjvp
  def f_jvp(primals, tangents):
    # custom_jvp rules receive tuples of primals and tangents.
    alpha, = primals
    dalpha, = tangents
    sample = f(alpha)
    partial_alpha = jax.lax.random_gamma_grad(alpha, sample)
    return sample, partial_alpha * dalpha
  return f(alpha)

jax.vmap(sample)(jax.random.split(jax.random.PRNGKey(1), 3))
```

=>

```
...
jax==0.1.72
jaxlib==0.1.51
...
/usr/local/lib/python3.6/dist-packages/jax/core.py in <listcomp>(.0)
   1078   todo = []
   1079   while True:
-> 1080     tracers = [x for x in outs if isinstance(x, Tracer) and x._trace.level > level]
   1081     if tracers:
   1082       ans = max(tracers, key=lambda x: x._trace.level)

TypeError: '>' not supported between instances of 'int' and 'NoneType'
```
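A possible workaround, sketched under assumptions and not taken from the report: pass the key to the `custom_jvp` function as an explicit argument instead of closing over it, so `vmap` batches it like any other input rather than seeing it as a closed-over tracer. This assumes a recent JAX in which `custom_jvp` functions accept integer (uint32 key) arguments and hand the rule a `float0` zero tangent for them; the name `gamma_sample` is hypothetical. (Recent JAX versions already differentiate `jax.random.gamma` with respect to `alpha`, so the custom rule here mainly illustrates the pattern.)

```python
import jax
import jax.numpy as jnp

# Hypothetical helper: the key is an explicit argument, not a closure variable.
@jax.custom_jvp
def gamma_sample(seed, alpha):
    return jax.random.gamma(seed, alpha, shape=())

@gamma_sample.defjvp
def gamma_sample_jvp(primals, tangents):
    seed, alpha = primals
    # The key is integer-valued, so its tangent is a float0 zero; ignore it.
    _, dalpha = tangents
    sample = gamma_sample(seed, alpha)
    # Implicit reparameterization gradient d(sample)/d(alpha).
    partial_alpha = jax.lax.random_gamma_grad(alpha, sample)
    return sample, partial_alpha * dalpha

alpha = jnp.float32(2.)
seeds = jax.random.split(jax.random.PRNGKey(1), 3)

# Batch over the keys only; alpha is shared across the batch.
samples = jax.vmap(gamma_sample, in_axes=(0, None))(seeds, alpha)

# Per-key gradient with respect to alpha (argument index 1).
grads = jax.vmap(jax.grad(gamma_sample, argnums=1), in_axes=(0, None))(seeds, alpha)
```

Because the key now flows through the function's inputs, both the batched primal call and the batched gradient go through the custom rule without any closed-over tracer.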
