diff --git a/jax/random.py b/jax/random.py
index 53e6f1affe7a..00fd41ca6a75 100644
--- a/jax/random.py
+++ b/jax/random.py
@@ -226,19 +226,21 @@ def _check_shape(name, shape):
     raise ValueError(msg.format(name, shape))
 
 
-def uniform(key, shape, dtype=onp.float32, minval=0., maxval=1.):
+def uniform(key, shape=(), dtype=onp.float64, minval=0., maxval=1.):
   """Sample uniform random values in [minval, maxval) with given shape/dtype.
 
   Args:
     key: a PRNGKey used as the random key.
     shape: a tuple of nonnegative integers representing the shape.
-    dtype: optional, a float dtype for the returned values (default float32).
+    dtype: optional, a float dtype for the returned values (default float64 if
+      jax_enable_x64 is true, otherwise float32).
     minval: optional, a minimum (inclusive) value for the range (default 0).
     maxval: optional, a maximum (exclusive) value for the range (default 1).
 
   Returns:
     A random array with the specified shape and dtype.
   """
+  dtype = xla_bridge.canonicalize_dtype(dtype)
   return _uniform(key, shape, dtype, minval, maxval)
 
 @partial(jit, static_argnums=(1, 2))
@@ -247,7 +249,6 @@ def _uniform(key, shape, dtype, minval, maxval):
   if not onp.issubdtype(dtype, onp.floating):
     raise TypeError("uniform only accepts floating point dtypes.")
 
-  dtype = xla_bridge.canonicalize_dtype(dtype)
   minval = lax.convert_element_type(minval, dtype)
   maxval = lax.convert_element_type(maxval, dtype)
   finfo = onp.finfo(dtype)
@@ -271,7 +272,7 @@ def _uniform(key, shape, dtype, minval, maxval):
       lax.reshape(floats * (maxval - minval) + minval, shape))
 
 
-def randint(key, shape, minval, maxval, dtype=onp.int32):
+def randint(key, shape, minval, maxval, dtype=onp.int64):
   """Sample uniform random values in [minval, maxval) with given shape/dtype.
 
   Args:
@@ -281,20 +282,21 @@ def randint(key, shape, minval, maxval, dtype=onp.int32):
       (inclusive) value for the range.
     maxval: int or array of ints broadcast-compatible with ``shape``, a maximum
       (exclusive) value for the range.
-    dtype: optional, an int dtype for the returned values (default int32).
+    dtype: optional, an int dtype for the returned values (default int64 if
+      jax_enable_x64 is true, otherwise int32).
 
   Returns:
     A random array with the specified shape and dtype.
   """
+  dtype = xla_bridge.canonicalize_dtype(dtype)
   return _randint(key, shape, minval, maxval, dtype)
 
 @partial(jit, static_argnums=(1, 4))
-def _randint(key, shape, minval, maxval, dtype=onp.int32):
+def _randint(key, shape, minval, maxval, dtype):
   _check_shape("randint", shape)
   if not onp.issubdtype(dtype, onp.integer):
     raise TypeError("randint only accepts integer dtypes.")
 
-  dtype = xla_bridge.canonicalize_dtype(dtype)
   minval = lax.convert_element_type(minval, dtype)
   maxval = lax.convert_element_type(maxval, dtype)
   nbits = onp.iinfo(dtype).bits
@@ -371,17 +373,19 @@ def _shuffle(key, x, axis):
   return x
 
 
-def normal(key, shape, dtype=onp.float32):
+def normal(key, shape=(), dtype=onp.float64):
   """Sample standard normal random values with given shape and float dtype.
 
   Args:
     key: a PRNGKey used as the random key.
     shape: a tuple of nonnegative integers representing the shape.
-    dtype: optional, a float dtype for the returned values (default float32).
+    dtype: optional, a float dtype for the returned values (default float64 if
+      jax_enable_x64 is true, otherwise float32).
 
   Returns:
     A random array with the specified shape and dtype.
   """
+  dtype = xla_bridge.canonicalize_dtype(dtype)
   return _normal(key, shape, dtype)
 
 @partial(jit, static_argnums=(1, 2))
@@ -398,28 +402,31 @@ def bernoulli(key, p=onp.float32(0.5), shape=()):
 
   Args:
     key: a PRNGKey used as the random key.
-    p: optional, an array-like broadcastable to `shape` for the mean of the
-      random variables (default 0.5).
+    p: optional, an array-like of floating dtype broadcastable to `shape` for
+      the mean of the random variables (default 0.5).
     shape: optional, a tuple of nonnegative integers representing the shape
       (default scalar).
 
   Returns:
     A random array with the specified shape and boolean dtype.
   """
+  dtype = xla_bridge.canonicalize_dtype(lax.dtype(p))
+  if not onp.issubdtype(dtype, onp.floating):
+    msg = "bernoulli probability `p` must have a floating dtype, got {}."
+    raise TypeError(msg.format(dtype))
+  p = lax.convert_element_type(p, dtype)
   return _bernoulli(key, p, shape)
 
 @partial(jit, static_argnums=(2,))
 def _bernoulli(key, p, shape):
   _check_shape("bernoulli", shape)
   shape = shape or onp.shape(p)
-  if not onp.issubdtype(onp.float32, lax.dtype(p)):
-    p = lax.convert_element_type(p, onp.float32)
   if onp.shape(p) != shape:
     p = np.broadcast_to(p, shape)
   return lax.lt(uniform(key, shape, lax.dtype(p)), p)
 
 
-def beta(key, a, b, shape=(), dtype=onp.float32):
-  """Sample Bernoulli random values with given shape and mean.
+def beta(key, a, b, shape=(), dtype=onp.float64):
+  """Sample Beta random values with given shape and float dtype.
 
   Args:
@@ -430,11 +437,13 @@ def beta(key, a, b, shape=(), dtype=onp.float32):
       beta of the random variables.
     shape: optional, a tuple of nonnegative integers representing the shape
       (default scalar).
-    dtype: optional, a float dtype for the returned values (default float32).
+    dtype: optional, a float dtype for the returned values (default float64 if
+      jax_enable_x64 is true, otherwise float32).
 
   Returns:
     A random array with the specified shape and dtype.
   """
+  dtype = xla_bridge.canonicalize_dtype(dtype)
   return _beta(key, a, b, shape, dtype)
 
 @partial(jit, static_argnums=(3, 4))
@@ -449,18 +458,20 @@ def _beta(key, a, b, shape, dtype):
   return gamma_a / (gamma_a + gamma_b)
 
 
-def cauchy(key, shape=(), dtype=onp.float32):
+def cauchy(key, shape=(), dtype=onp.float64):
   """Sample Cauchy random values with given shape and float dtype.
 
   Args:
     key: a PRNGKey used as the random key.
     shape: optional, a tuple of nonnegative integers representing the shape
       (default scalar).
-    dtype: optional, a float dtype for the returned values (default float32).
+    dtype: optional, a float dtype for the returned values (default float64 if
+      jax_enable_x64 is true, otherwise float32).
 
   Returns:
     A random array with the specified shape and dtype.
   """
+  dtype = xla_bridge.canonicalize_dtype(dtype)
   return _cauchy(key, shape, dtype)
 
 @partial(jit, static_argnums=(1, 2))
@@ -471,7 +482,7 @@ def _cauchy(key, shape, dtype):
   return lax.tan(lax.mul(pi, lax.sub(u, _constant_like(u, 0.5))))
 
 
-def dirichlet(key, alpha, shape=(), dtype=onp.float32):
-  """Sample Cauchy random values with given shape and float dtype.
+def dirichlet(key, alpha, shape=(), dtype=onp.float64):
+  """Sample Dirichlet random values with given shape and float dtype.
 
   Args:
@@ -480,11 +491,13 @@ def dirichlet(key, alpha, shape=(), dtype=onp.float32):
       used as the concentration parameter of the random variables.
     shape: optional, a tuple of nonnegative integers representing the batch
       shape (defaults to `alpha.shape[:-1]`).
-    dtype: optional, a float dtype for the returned values (default float32).
+    dtype: optional, a float dtype for the returned values (default float64 if
+      jax_enable_x64 is true, otherwise float32).
 
   Returns:
     A random array with the specified shape and dtype.
   """
+  dtype = xla_bridge.canonicalize_dtype(dtype)
   return _dirichlet(key, alpha, shape, dtype)
 
 @partial(jit, static_argnums=(2, 3))
@@ -496,18 +509,20 @@ def _dirichlet(key, alpha, shape, dtype):
   return gamma_samples / np.sum(gamma_samples, axis=-1, keepdims=True)
 
 
-def exponential(key, shape=(), dtype=onp.float32):
+def exponential(key, shape=(), dtype=onp.float64):
   """Sample Exponential random values with given shape and float dtype.
 
   Args:
     key: a PRNGKey used as the random key.
     shape: optional, a tuple of nonnegative integers representing the shape
       (default scalar).
-    dtype: optional, a float dtype for the returned values (default float32).
+    dtype: optional, a float dtype for the returned values (default float64 if
+      jax_enable_x64 is true, otherwise float32).
 
   Returns:
     A random array with the specified shape and dtype.
   """
+  dtype = xla_bridge.canonicalize_dtype(dtype)
   return _exponential(key, shape, dtype)
 
 @partial(jit, static_argnums=(1, 2))
@@ -568,7 +583,7 @@ def _body_fn(kXVU):
   return lax.select(lax.eq(z, zero), onp.finfo(z.dtype).tiny, z)
 
 
-def gamma(key, a, shape=(), dtype=onp.float32):
+def gamma(key, a, shape=(), dtype=onp.float64):
   """Sample Gamma random values with given shape and float dtype.
 
   Args:
@@ -577,15 +592,17 @@ def gamma(key, a, shape=(), dtype=onp.float32):
       of the random variables.
     shape: optional, a tuple of nonnegative integers representing the shape
       (default scalar).
-    dtype: optional, a float dtype for the returned values (default float32).
+    dtype: optional, a float dtype for the returned values (default float64 if
+      jax_enable_x64 is true, otherwise float32).
 
   Returns:
     A random array with the specified shape and dtype.
   """
+  dtype = xla_bridge.canonicalize_dtype(dtype)
   return _gamma(key, a, shape, dtype)
 
 @partial(jit, static_argnums=(2, 3))
-def _gamma(key, a, shape=(), dtype=onp.float32):
+def _gamma(key, a, shape, dtype):
   _check_shape("gamma", shape)
   a = lax.convert_element_type(a, dtype)
   shape = shape or onp.shape(a)
@@ -597,18 +614,20 @@ def _gamma(key, a, shape=(), dtype=onp.float32):
   return np.reshape(samples, shape)
 
 
-def gumbel(key, shape=(), dtype=onp.float32):
+def gumbel(key, shape=(), dtype=onp.float64):
   """Sample Gumbel random values with given shape and float dtype.
 
   Args:
     key: a PRNGKey used as the random key.
     shape: optional, a tuple of nonnegative integers representing the shape
       (default scalar).
-    dtype: optional, a float dtype for the returned values (default float32).
+    dtype: optional, a float dtype for the returned values (default float64 if
+      jax_enable_x64 is true, otherwise float32).
 
   Returns:
     A random array with the specified shape and dtype.
   """
+  dtype = xla_bridge.canonicalize_dtype(dtype)
   return _gumbel(key, shape, dtype)
 
 @partial(jit, static_argnums=(1, 2))
@@ -617,18 +636,20 @@ def _gumbel(key, shape, dtype):
   return -np.log(-np.log(uniform(key, shape, dtype)))
 
 
-def laplace(key, shape=(), dtype=onp.float32):
+def laplace(key, shape=(), dtype=onp.float64):
   """Sample Laplace random values with given shape and float dtype.
 
   Args:
     key: a PRNGKey used as the random key.
     shape: optional, a tuple of nonnegative integers representing the shape
       (default scalar).
-    dtype: optional, a float dtype for the returned values (default float32).
+    dtype: optional, a float dtype for the returned values (default float64 if
+      jax_enable_x64 is true, otherwise float32).
 
   Returns:
     A random array with the specified shape and dtype.
   """
+  dtype = xla_bridge.canonicalize_dtype(dtype)
   return _laplace(key, shape, dtype)
 
 @partial(jit, static_argnums=(1, 2))
@@ -638,7 +659,7 @@ def _laplace(key, shape, dtype):
   return lax.mul(lax.sign(u), lax.log1p(lax.neg(lax.abs(u))))
 
 
-def pareto(key, b, shape=(), dtype=onp.float32):
+def pareto(key, b, shape=(), dtype=onp.float64):
   """Sample Pareto random values with given shape and float dtype.
 
   Args:
@@ -647,11 +668,13 @@ def pareto(key, b, shape=(), dtype=onp.float32):
       of the random variables.
     shape: optional, a tuple of nonnegative integers representing the shape
       (default scalar).
-    dtype: optional, a float dtype for the returned values (default float32).
+    dtype: optional, a float dtype for the returned values (default float64 if
+      jax_enable_x64 is true, otherwise float32).
 
   Returns:
     A random array with the specified shape and dtype.
   """
+  dtype = xla_bridge.canonicalize_dtype(dtype)
   return _pareto(key, b, shape, dtype)
 
 @partial(jit, static_argnums=(2, 3))
@@ -665,7 +688,7 @@ def _pareto(key, b, shape, dtype):
   return lax.exp(lax.div(e, b))
 
 
-def t(key, df, shape=(), dtype=onp.float32):
+def t(key, df, shape=(), dtype=onp.float64):
   """Sample Student's t random values with given shape and float dtype.
 
   Args:
@@ -674,11 +697,13 @@ def t(key, df, shape=(), dtype=onp.float32):
       of the random variables.
     shape: optional, a tuple of nonnegative integers representing the shape
       (default scalar).
-    dtype: optional, a float dtype for the returned values (default float32).
+    dtype: optional, a float dtype for the returned values (default float64 if
+      jax_enable_x64 is true, otherwise float32).
 
   Returns:
     A random array with the specified shape and dtype.
   """
+  dtype = xla_bridge.canonicalize_dtype(dtype)
   return _t(key, df, shape, dtype)
 
 @partial(jit, static_argnums=(2, 3))
diff --git a/tests/random_test.py b/tests/random_test.py
index a9b20582d0a3..7b0a635eb610 100644
--- a/tests/random_test.py
+++ b/tests/random_test.py
@@ -358,6 +358,14 @@ def feature_map(n, d, sigma=1.0, seed=123):
     self.assertRaisesRegex(ValueError, re.compile(r'.*requires a concrete.*'),
                            lambda: feature_map(5, 3))
 
+  def testIssue756(self):
+    key = random.PRNGKey(0)
+    w = random.normal(key, ())
+    if FLAGS.jax_enable_x64:
+      self.assertEqual(onp.result_type(w), onp.float64)
+    else:
+      self.assertEqual(onp.result_type(w), onp.float32)
+
 
 if __name__ == "__main__":
   absltest.main()