jax.random.poisson does not broadcast to shape of rate param #7861

@bantin

jax.random.poisson does not broadcast to the shape of the rate parameter. This is a problem for me because I need to JIT compile a function that samples from a Poisson distribution. Since JAX doesn't broadcast to the shape of the rate parameter, I end up passing the rate's shape to jax.random.poisson explicitly, and this gives a TracerArrayConversionError.

By contrast, np.random.poisson broadcasts to the shape of the rate parameter.
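For comparison, here is a minimal sketch of the NumPy behavior described above. The variable names are illustrative only; the point is that the output takes the shape of `lam` when no `size` is given.

```python
import numpy as np

# A rate array with the same shape as in the JAX reproduction in this report.
rate = np.ones((5, 5, 5))

# No size argument: NumPy draws one sample per entry of `lam`.
samples = np.random.poisson(lam=rate)

print(samples.shape)  # (5, 5, 5)
```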

import jax
import jax.numpy as jnp
key = jax.random.PRNGKey(0)
rate = jnp.ones((5,5,5))
jax.random.poisson(key, lam=rate)  # no shape argument passed

Output:

~/.conda/envs/jax_env/lib/python3.9/site-packages/jax/_src/random.py in poisson(key, lam, shape, dtype)
   1242   shape = core.canonicalize_shape(shape)
   1243   if np.shape(lam) != shape:
-> 1244     lam = jnp.broadcast_to(lam, shape)
   1245   lam = lax.convert_element_type(lam, np.float32)
   1246   return _poisson(key, lam, shape, dtype)

~/.conda/envs/jax_env/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py in broadcast_to(arr, shape)
   1830     shape_tail = shape[nlead:]
   1831     compatible = _all(core.symbolic_equal_one_of_dim(arr_d, [1, shape_d])
-> 1832                       for arr_d, shape_d in safe_zip(arr_shape, shape_tail))
   1833     if nlead < 0 or not compatible:
   1834       msg = "Incompatible shapes for broadcasting: {} and requested shape {}"

~/.conda/envs/jax_env/lib/python3.9/site-packages/jax/_src/util.py in safe_zip(*args)
     31   n = len(args[0])
     32   for arg in args[1:]:
---> 33     assert len(arg) == n, 'length mismatch: {}'.format(list(map(len, args)))
     34   return list(zip(*args))
     35 

AssertionError: length mismatch: [3, 0]
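The `length mismatch: [3, 0]` in the traceback appears to compare the 3-dimensional shape of `lam` with an empty requested shape, which suggests `shape` defaults to `()` rather than to the shape of `lam`. A possible workaround, sketched below, is to pass the rate's shape explicitly. This assumes the shape is known at trace time (the `.shape` attribute of a traced array is a plain Python tuple), and `poisson_like_rate` is a hypothetical helper name, not part of JAX.

```python
import jax
import jax.numpy as jnp


def poisson_like_rate(key, lam):
    """Hypothetical helper: draw one Poisson sample per entry of `lam`."""
    lam = jnp.asarray(lam)
    # lam.shape is a static tuple even when lam is a tracer under jit,
    # so it can be passed as the `shape` argument.
    return jax.random.poisson(key, lam=lam, shape=lam.shape)


# Assumption: the rate array has a fixed shape for each compiled call.
sample_jit = jax.jit(poisson_like_rate)

key = jax.random.PRNGKey(0)
rate = jnp.ones((5, 5, 5))
samples = sample_jit(key, rate)
```

If the shape passed in is itself computed from traced values rather than read from `.shape`, this sketch would not apply, since `shape` has to be concrete.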

Tagging @shoyer since I think he wrote the Poisson sampling code, hope that's not too obnoxious :)
The original PR that added the Poisson sampler is #2805.

Labels: bug