From 93e201f85ba2721ef51ac7f06d0e30c794290f37 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Wed, 22 May 2019 16:22:12 -0700 Subject: [PATCH 1/2] make jax.random default dtypes 64-bit fixes #756 --- jax/random.py | 82 ++++++++++++++++++++++++++++---------------- tests/random_test.py | 8 +++++ 2 files changed, 60 insertions(+), 30 deletions(-) diff --git a/jax/random.py b/jax/random.py index 53e6f1affe7a..8d2da4ff0d08 100644 --- a/jax/random.py +++ b/jax/random.py @@ -226,19 +226,21 @@ def _check_shape(name, shape): raise ValueError(msg.format(name, shape)) -def uniform(key, shape, dtype=onp.float32, minval=0., maxval=1.): +def uniform(key, shape, dtype=onp.float64, minval=0., maxval=1.): """Sample uniform random values in [minval, maxval) with given shape/dtype. Args: key: a PRNGKey used as the random key. shape: a tuple of nonnegative integers representing the shape. - dtype: optional, a float dtype for the returned values (default float32). + dtype: optional, a float dtype for the returned values (default float64 if + jax_enable_x64 is true, otherwise float32). minval: optional, a minimum (inclusive) value for the range (default 0). maxval: optional, a maximum (exclusive) value for the range (default 1). Returns: A random array with the specified shape and dtype. """ + dtype = xla_bridge.canonicalize_dtype(dtype) return _uniform(key, shape, dtype, minval, maxval) @partial(jit, static_argnums=(1, 2)) @@ -247,7 +249,6 @@ def _uniform(key, shape, dtype, minval, maxval): if not onp.issubdtype(dtype, onp.floating): raise TypeError("uniform only accepts floating point dtypes.") - dtype = xla_bridge.canonicalize_dtype(dtype) minval = lax.convert_element_type(minval, dtype) maxval = lax.convert_element_type(maxval, dtype) finfo = onp.finfo(dtype) @@ -271,7 +272,7 @@ def _uniform(key, shape, dtype, minval, maxval): lax.reshape(floats * (maxval - minval) + minval, shape)) -def randint(key, shape, minval, maxval, dtype=onp.int32): +def randint(key, shape, minval, maxval, dtype=onp.int64): """Sample uniform random values in [minval, maxval) with given shape/dtype. Args: @@ -281,20 +282,21 @@ def randint(key, shape, minval, maxval, dtype=onp.int32): (inclusive) value for the range. maxval: int or array of ints broadcast-compatible with ``shape``, a maximum (exclusive) value for the range. - dtype: optional, an int dtype for the returned values (default int32). + dtype: optional, an int dtype for the returned values (default int64 if + jax_enable_x64 is true, otherwise int32). Returns: A random array with the specified shape and dtype. """ + dtype = xla_bridge.canonicalize_dtype(dtype) return _randint(key, shape, minval, maxval, dtype) @partial(jit, static_argnums=(1, 4)) -def _randint(key, shape, minval, maxval, dtype=onp.int32): +def _randint(key, shape, minval, maxval, dtype): _check_shape("randint", shape) if not onp.issubdtype(dtype, onp.integer): raise TypeError("randint only accepts integer dtypes.") - dtype = xla_bridge.canonicalize_dtype(dtype) minval = lax.convert_element_type(minval, dtype) maxval = lax.convert_element_type(maxval, dtype) nbits = onp.iinfo(dtype).bits @@ -371,17 +373,19 @@ def _shuffle(key, x, axis): return x -def normal(key, shape, dtype=onp.float32): +def normal(key, shape, dtype=onp.float64): """Sample standard normal random values with given shape and float dtype. Args: key: a PRNGKey used as the random key. shape: a tuple of nonnegative integers representing the shape. - dtype: optional, a float dtype for the returned values (default float32). + dtype: optional, a float dtype for the returned values (default float64 if + jax_enable_x64 is true, otherwise float32). Returns: A random array with the specified shape and dtype. """ + dtype = xla_bridge.canonicalize_dtype(dtype) return _normal(key, shape, dtype) @partial(jit, static_argnums=(1, 2)) @@ -412,14 +416,14 @@ def bernoulli(key, p=onp.float32(0.5), shape=()): def _bernoulli(key, p, shape): _check_shape("bernoulli", shape) shape = shape or onp.shape(p) - if not onp.issubdtype(onp.float32, lax.dtype(p)): - p = lax.convert_element_type(p, onp.float32) + if not onp.issubdtype(onp.float64, lax.dtype(p)): + p = lax.convert_element_type(p, onp.float64) if onp.shape(p) != shape: p = np.broadcast_to(p, shape) return lax.lt(uniform(key, shape, lax.dtype(p)), p) -def beta(key, a, b, shape=(), dtype=onp.float32): +def beta(key, a, b, shape=(), dtype=onp.float64): """Sample Bernoulli random values with given shape and mean. Args: @@ -430,11 +434,13 @@ def beta(key, a, b, shape=(), dtype=onp.float32): beta of the random variables. shape: optional, a tuple of nonnegative integers representing the shape (default scalar). - dtype: optional, a float dtype for the returned values (default float32). + dtype: optional, a float dtype for the returned values (default float64 if + jax_enable_x64 is true, otherwise float32). Returns: A random array with the specified shape and dtype. """ + dtype = xla_bridge.canonicalize_dtype(dtype) return _beta(key, a, b, shape, dtype) @partial(jit, static_argnums=(3, 4)) @@ -449,18 +455,20 @@ def _beta(key, a, b, shape, dtype): return gamma_a / (gamma_a + gamma_b) -def cauchy(key, shape=(), dtype=onp.float32): +def cauchy(key, shape=(), dtype=onp.float64): """Sample Cauchy random values with given shape and float dtype. Args: key: a PRNGKey used as the random key. shape: optional, a tuple of nonnegative integers representing the shape (default scalar). - dtype: optional, a float dtype for the returned values (default float32). + dtype: optional, a float dtype for the returned values (default float64 if + jax_enable_x64 is true, otherwise float32). Returns: A random array with the specified shape and dtype. """ + dtype = xla_bridge.canonicalize_dtype(dtype) return _cauchy(key, shape, dtype) @partial(jit, static_argnums=(1, 2)) @@ -471,7 +479,7 @@ def _cauchy(key, shape, dtype): return lax.tan(lax.mul(pi, lax.sub(u, _constant_like(u, 0.5)))) -def dirichlet(key, alpha, shape=(), dtype=onp.float32): +def dirichlet(key, alpha, shape=(), dtype=onp.float64): """Sample Cauchy random values with given shape and float dtype. Args: @@ -480,11 +488,13 @@ def dirichlet(key, alpha, shape=(), dtype=onp.float32): used as the concentration parameter of the random variables. shape: optional, a tuple of nonnegative integers representing the batch shape (defaults to `alpha.shape[:-1]`). - dtype: optional, a float dtype for the returned values (default float32). + dtype: optional, a float dtype for the returned values (default float64 if + jax_enable_x64 is true, otherwise float32). Returns: A random array with the specified shape and dtype. """ + dtype = xla_bridge.canonicalize_dtype(dtype) return _dirichlet(key, alpha, shape, dtype) @partial(jit, static_argnums=(2, 3)) @@ -496,18 +506,20 @@ def _dirichlet(key, alpha, shape, dtype): return gamma_samples / np.sum(gamma_samples, axis=-1, keepdims=True) -def exponential(key, shape=(), dtype=onp.float32): +def exponential(key, shape=(), dtype=onp.float64): """Sample Exponential random values with given shape and float dtype. Args: key: a PRNGKey used as the random key. shape: optional, a tuple of nonnegative integers representing the shape (default scalar). - dtype: optional, a float dtype for the returned values (default float32). + dtype: optional, a float dtype for the returned values (default float64 if + jax_enable_x64 is true, otherwise float32). Returns: A random array with the specified shape and dtype. """ + dtype = xla_bridge.canonicalize_dtype(dtype) return _exponential(key, shape, dtype) @partial(jit, static_argnums=(1, 2)) @@ -568,7 +580,7 @@ def _body_fn(kXVU): return lax.select(lax.eq(z, zero), onp.finfo(z.dtype).tiny, z) -def gamma(key, a, shape=(), dtype=onp.float32): +def gamma(key, a, shape=(), dtype=onp.float64): """Sample Gamma random values with given shape and float dtype. Args: @@ -577,15 +589,17 @@ def gamma(key, a, shape=(), dtype=onp.float32): of the random variables. shape: optional, a tuple of nonnegative integers representing the shape (default scalar). - dtype: optional, a float dtype for the returned values (default float32). + dtype: optional, a float dtype for the returned values (default float64 if + jax_enable_x64 is true, otherwise float32). Returns: A random array with the specified shape and dtype. """ + dtype = xla_bridge.canonicalize_dtype(dtype) return _gamma(key, a, shape, dtype) @partial(jit, static_argnums=(2, 3)) -def _gamma(key, a, shape=(), dtype=onp.float32): +def _gamma(key, a, shape, dtype): _check_shape("gamma", shape) a = lax.convert_element_type(a, dtype) shape = shape or onp.shape(a) @@ -597,18 +611,20 @@ def _gamma(key, a, shape=(), dtype=onp.float32): return np.reshape(samples, shape) -def gumbel(key, shape=(), dtype=onp.float32): +def gumbel(key, shape=(), dtype=onp.float64): """Sample Gumbel random values with given shape and float dtype. Args: key: a PRNGKey used as the random key. shape: optional, a tuple of nonnegative integers representing the shape (default scalar). - dtype: optional, a float dtype for the returned values (default float32). + dtype: optional, a float dtype for the returned values (default float64 if + jax_enable_x64 is true, otherwise float32). Returns: A random array with the specified shape and dtype. """ + dtype = xla_bridge.canonicalize_dtype(dtype) return _gumbel(key, shape, dtype) @partial(jit, static_argnums=(1, 2)) @@ -617,18 +633,20 @@ def _gumbel(key, shape, dtype): return -np.log(-np.log(uniform(key, shape, dtype))) -def laplace(key, shape=(), dtype=onp.float32): +def laplace(key, shape=(), dtype=onp.float64): """Sample Laplace random values with given shape and float dtype. Args: key: a PRNGKey used as the random key. shape: optional, a tuple of nonnegative integers representing the shape (default scalar). - dtype: optional, a float dtype for the returned values (default float32). + dtype: optional, a float dtype for the returned values (default float64 if + jax_enable_x64 is true, otherwise float32). Returns: A random array with the specified shape and dtype. """ + dtype = xla_bridge.canonicalize_dtype(dtype) return _laplace(key, shape, dtype) @partial(jit, static_argnums=(1, 2)) @@ -638,7 +656,7 @@ def _laplace(key, shape, dtype): return lax.mul(lax.sign(u), lax.log1p(lax.neg(lax.abs(u)))) -def pareto(key, b, shape=(), dtype=onp.float32): +def pareto(key, b, shape=(), dtype=onp.float64): """Sample Pareto random values with given shape and float dtype. Args: @@ -647,11 +665,13 @@ def pareto(key, b, shape=(), dtype=onp.float32): of the random variables. shape: optional, a tuple of nonnegative integers representing the shape (default scalar). - dtype: optional, a float dtype for the returned values (default float32). + dtype: optional, a float dtype for the returned values (default float64 if + jax_enable_x64 is true, otherwise float32). Returns: A random array with the specified shape and dtype. """ + dtype = xla_bridge.canonicalize_dtype(dtype) return _pareto(key, b, shape, dtype) @partial(jit, static_argnums=(2, 3)) @@ -665,7 +685,7 @@ def _pareto(key, b, shape, dtype): return lax.exp(lax.div(e, b)) -def t(key, df, shape=(), dtype=onp.float32): +def t(key, df, shape=(), dtype=onp.float64): """Sample Student's t random values with given shape and float dtype. Args: @@ -674,11 +694,13 @@ def t(key, df, shape=(), dtype=onp.float32): of the random variables. shape: optional, a tuple of nonnegative integers representing the shape (default scalar). - dtype: optional, a float dtype for the returned values (default float32). + dtype: optional, a float dtype for the returned values (default float64 if + jax_enable_x64 is true, otherwise float32). Returns: A random array with the specified shape and dtype. """ + dtype = xla_bridge.canonicalize_dtype(dtype) return _t(key, df, shape, dtype) @partial(jit, static_argnums=(2, 3)) diff --git a/tests/random_test.py b/tests/random_test.py index a9b20582d0a3..7b0a635eb610 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -358,6 +358,14 @@ def feature_map(n, d, sigma=1.0, seed=123): self.assertRaisesRegex(ValueError, re.compile(r'.*requires a concrete.*'), lambda: feature_map(5, 3)) + def testIssue756(self): + key = random.PRNGKey(0) + w = random.normal(key, ()) + if FLAGS.jax_enable_x64: + self.assertEqual(onp.result_type(w), onp.float64) + else: + self.assertEqual(onp.result_type(w), onp.float32) + if __name__ == "__main__": absltest.main() From c743c14c8f1458f57650be20a4e4320ba1a42392 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Wed, 22 May 2019 20:06:12 -0700 Subject: [PATCH 2/2] address reviewer comments --- jax/random.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/jax/random.py b/jax/random.py index 8d2da4ff0d08..00fd41ca6a75 100644 --- a/jax/random.py +++ b/jax/random.py @@ -226,7 +226,7 @@ def _check_shape(name, shape): raise ValueError(msg.format(name, shape)) -def uniform(key, shape, dtype=onp.float64, minval=0., maxval=1.): +def uniform(key, shape=(), dtype=onp.float64, minval=0., maxval=1.): """Sample uniform random values in [minval, maxval) with given shape/dtype. Args: @@ -373,7 +373,7 @@ def _shuffle(key, x, axis): return x -def normal(key, shape, dtype=onp.float64): +def normal(key, shape=(), dtype=onp.float64): """Sample standard normal random values with given shape and float dtype. Args: @@ -402,22 +402,25 @@ def bernoulli(key, p=onp.float32(0.5), shape=()): Args: key: a PRNGKey used as the random key. - p: optional, an array-like broadcastable to `shape` for the mean of the - random variables (default 0.5). + p: optional, an array-like of floating dtype broadcastable to `shape` for + the mean of the random variables (default 0.5). shape: optional, a tuple of nonnegative integers representing the shape (default scalar). Returns: A random array with the specified shape and boolean dtype. """ + dtype = xla_bridge.canonicalize_dtype(lax.dtype(p)) + if not onp.issubdtype(dtype, onp.floating): + msg = "bernoulli probability `p` must have a floating dtype, got {}." + raise TypeError(msg.format(dtype)) + p = lax.convert_element_type(p, dtype) return _bernoulli(key, p, shape) @partial(jit, static_argnums=(2,)) def _bernoulli(key, p, shape): _check_shape("bernoulli", shape) shape = shape or onp.shape(p) - if not onp.issubdtype(onp.float64, lax.dtype(p)): - p = lax.convert_element_type(p, onp.float64) if onp.shape(p) != shape: p = np.broadcast_to(p, shape) return lax.lt(uniform(key, shape, lax.dtype(p)), p)