Closed
Labels: bug (Something isn't working)
Description
jax.random.poisson does not broadcast to the shape of the rate parameter. This is a problem for me: I need to JIT-compile a function that samples from a Poisson distribution, but because JAX doesn't broadcast to the shape of the rate parameter, I end up passing the rate's shape to poisson explicitly, and this gives a TracerArrayConversionError.
By contrast, np.random.poisson broadcasts to the shape of the rate parameter.
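For reference, here is a short sketch of the NumPy behavior described above, using the legacy global `np.random` API. With `size` left at its default, NumPy draws one sample per element of `lam`, so the result takes the shape of the rate array.

```python
import numpy as np

rate = np.ones((5, 5, 5))
# size=None (the default): one sample per element of lam,
# so the output shape matches the shape of the rate array.
samples = np.random.poisson(lam=rate)
print(samples.shape)  # (5, 5, 5)
```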
import jax
import jax.numpy as jnp
key = jax.random.PRNGKey(0)
rate = jnp.ones((5,5,5))
jax.random.poisson(key, lam=rate)
Output:
~/.conda/envs/jax_env/lib/python3.9/site-packages/jax/_src/random.py in poisson(key, lam, shape, dtype)
1242 shape = core.canonicalize_shape(shape)
1243 if np.shape(lam) != shape:
-> 1244 lam = jnp.broadcast_to(lam, shape)
1245 lam = lax.convert_element_type(lam, np.float32)
1246 return _poisson(key, lam, shape, dtype)
~/.conda/envs/jax_env/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py in broadcast_to(arr, shape)
1830 shape_tail = shape[nlead:]
1831 compatible = _all(core.symbolic_equal_one_of_dim(arr_d, [1, shape_d])
-> 1832 for arr_d, shape_d in safe_zip(arr_shape, shape_tail))
1833 if nlead < 0 or not compatible:
1834 msg = "Incompatible shapes for broadcasting: {} and requested shape {}"
~/.conda/envs/jax_env/lib/python3.9/site-packages/jax/_src/util.py in safe_zip(*args)
31 n = len(args[0])
32 for arg in args[1:]:
---> 33 assert len(arg) == n, 'length mismatch: {}'.format(list(map(len, args)))
34 return list(zip(*args))
35
AssertionError: length mismatch: [3, 0]
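Reading the traceback, `shape` appears to default to `()` in this version. That means `jnp.broadcast_to` is asked to turn a rank-3 `lam` into a scalar shape, which fails with the length mismatch `[3, 0]`. Below is a minimal workaround sketch, assuming it is acceptable to pass `shape` explicitly. The helper `poisson_broadcast` is hypothetical and not part of `jax.random`: when no shape is given, it falls back to `lam`'s own shape, mirroring `np.random.poisson`. Inside `jax.jit`, `rate.shape` is a static tuple of Python ints rather than a traced value, so forwarding it as `shape` should not require converting a tracer to an array.

```python
import jax
import jax.numpy as jnp


def poisson_broadcast(key, lam, shape=None):
    # Hypothetical helper (not a jax.random API): if no shape is given,
    # use lam's own shape, like np.random.poisson with size=None.
    if shape is None:
        shape = jnp.shape(lam)
    return jax.random.poisson(key, lam, shape=shape)


@jax.jit
def sample(key, rate):
    # During tracing, rate.shape is a plain tuple of Python ints,
    # so it can be passed as the static `shape` argument.
    return poisson_broadcast(key, rate)


key = jax.random.PRNGKey(0)
rate = jnp.ones((5, 5, 5))
samples = sample(key, rate)
print(samples.shape)  # (5, 5, 5)
```

Newer JAX releases may already default `shape` to the shape of `lam`; the `jax.random.poisson` docstring for the installed version is the place to confirm this.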
Tagging @shoyer, since I think he wrote the Poisson sampling code. Hope that's not too obnoxious :)
Link to the original PR that added the Poisson sampler: #2805