Closed

Labels: better_errors (Improve the error reporting), bug (Something isn't working), documentation
Description
```python
!pip install -U -q jax jaxlib
!pip freeze

import jax
import jax.numpy as np

alpha = np.float32(2.)

def sample(seed):
  # `f` closes over `seed`, which is a batch tracer under `jax.vmap`
  @jax.custom_jvp
  def f(alpha):
    return jax.random.gamma(seed, alpha, shape=[])

  @f.defjvp
  def f_jvp(primals, tangents):
    # the JVP rule receives tuples of primals and tangents
    alpha, = primals
    dalpha, = tangents
    sample = f(alpha)
    partial_alpha = jax.lax.random_gamma_grad(alpha, sample)
    return sample, partial_alpha * dalpha

  return f(alpha)

jax.vmap(sample)(jax.random.split(jax.random.PRNGKey(1), 3))
```
Output (abridged):

```
...
jax==0.1.72
jaxlib==0.1.51
...
/usr/local/lib/python3.6/dist-packages/jax/core.py in <listcomp>(.0)
   1078   todo = []
   1079   while True:
-> 1080     tracers = [x for x in outs if isinstance(x, Tracer) and x._trace.level > level]
   1081     if tracers:
   1082       ans = max(tracers, key=lambda x: x._trace.level)

TypeError: '>' not supported between instances of 'int' and 'NoneType'
```
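A possible workaround, sketched here as an assumption rather than the fix adopted for this issue: instead of letting the `custom_jvp` function close over the batched `seed`, pass the PRNG key as an explicit argument so `jax.vmap` sees it as an ordinary mapped input. The names `gamma_sample` and `gamma_sample_jvp` are hypothetical, and the sketch assumes a recent JAX in which integer-dtype inputs of a `custom_jvp` function (such as a raw `PRNGKey`) receive `float0` tangents that the rule can ignore.

```python
import jax
import jax.numpy as jnp

# Hypothetical workaround: the key is an explicit argument of the custom_jvp
# function instead of a variable captured from the enclosing scope.
@jax.custom_jvp
def gamma_sample(key, alpha):
    return jax.random.gamma(key, alpha, shape=())

@gamma_sample.defjvp
def gamma_sample_jvp(primals, tangents):
    key, alpha = primals
    _, dalpha = tangents  # the key's tangent (float0 zeros) is ignored
    sample = gamma_sample(key, alpha)
    # implicit reparameterization: d(sample)/d(alpha) for a fixed key
    partial_alpha = jax.lax.random_gamma_grad(alpha, sample)
    return sample, partial_alpha * dalpha

alpha = jnp.float32(2.0)
keys = jax.random.split(jax.random.PRNGKey(1), 3)

# map over the keys only; alpha is shared across the batch
samples = jax.vmap(gamma_sample, in_axes=(0, None))(keys, alpha)
grads = jax.vmap(jax.grad(gamma_sample, argnums=1), in_axes=(0, None))(keys, alpha)
```

Because the key now enters `gamma_sample` as a regular argument, no tracer from the outer `vmap` is captured by the closure, and `jax.grad` differentiates only with respect to `alpha` (`argnums=1`).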