Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
make jax.random default dtypes 64-bit
fixes #756
  • Loading branch information
mattjj committed May 22, 2019
commit 93e201f85ba2721ef51ac7f06d0e30c794290f37
82 changes: 52 additions & 30 deletions jax/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,19 +226,21 @@ def _check_shape(name, shape):
raise ValueError(msg.format(name, shape))


def uniform(key, shape, dtype=onp.float32, minval=0., maxval=1.):
def uniform(key, shape, dtype=onp.float64, minval=0., maxval=1.):
"""Sample uniform random values in [minval, maxval) with given shape/dtype.

Args:
key: a PRNGKey used as the random key.
shape: a tuple of nonnegative integers representing the shape.
dtype: optional, a float dtype for the returned values (default float32).
dtype: optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).
minval: optional, a minimum (inclusive) value for the range (default 0).
maxval: optional, a maximum (exclusive) value for the range (default 1).

Returns:
A random array with the specified shape and dtype.
"""
dtype = xla_bridge.canonicalize_dtype(dtype)
return _uniform(key, shape, dtype, minval, maxval)

@partial(jit, static_argnums=(1, 2))
Expand All @@ -247,7 +249,6 @@ def _uniform(key, shape, dtype, minval, maxval):
if not onp.issubdtype(dtype, onp.floating):
raise TypeError("uniform only accepts floating point dtypes.")

dtype = xla_bridge.canonicalize_dtype(dtype)
minval = lax.convert_element_type(minval, dtype)
maxval = lax.convert_element_type(maxval, dtype)
finfo = onp.finfo(dtype)
Expand All @@ -271,7 +272,7 @@ def _uniform(key, shape, dtype, minval, maxval):
lax.reshape(floats * (maxval - minval) + minval, shape))


def randint(key, shape, minval, maxval, dtype=onp.int32):
def randint(key, shape, minval, maxval, dtype=onp.int64):
"""Sample uniform random values in [minval, maxval) with given shape/dtype.

Args:
Expand All @@ -281,20 +282,21 @@ def randint(key, shape, minval, maxval, dtype=onp.int32):
(inclusive) value for the range.
maxval: int or array of ints broadcast-compatible with ``shape``, a maximum
(exclusive) value for the range.
dtype: optional, an int dtype for the returned values (default int32).
dtype: optional, an int dtype for the returned values (default int64 if
jax_enable_x64 is true, otherwise int32).

Returns:
A random array with the specified shape and dtype.
"""
dtype = xla_bridge.canonicalize_dtype(dtype)
return _randint(key, shape, minval, maxval, dtype)

@partial(jit, static_argnums=(1, 4))
def _randint(key, shape, minval, maxval, dtype=onp.int32):
def _randint(key, shape, minval, maxval, dtype):
_check_shape("randint", shape)
if not onp.issubdtype(dtype, onp.integer):
raise TypeError("randint only accepts integer dtypes.")

dtype = xla_bridge.canonicalize_dtype(dtype)
minval = lax.convert_element_type(minval, dtype)
maxval = lax.convert_element_type(maxval, dtype)
nbits = onp.iinfo(dtype).bits
Expand Down Expand Up @@ -371,17 +373,19 @@ def _shuffle(key, x, axis):
return x


def normal(key, shape, dtype=onp.float32):
def normal(key, shape, dtype=onp.float64):
"""Sample standard normal random values with given shape and float dtype.

Args:
key: a PRNGKey used as the random key.
shape: a tuple of nonnegative integers representing the shape.
dtype: optional, a float dtype for the returned values (default float32).
dtype: optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).

Returns:
A random array with the specified shape and dtype.
"""
dtype = xla_bridge.canonicalize_dtype(dtype)
return _normal(key, shape, dtype)

@partial(jit, static_argnums=(1, 2))
Expand Down Expand Up @@ -412,14 +416,14 @@ def bernoulli(key, p=onp.float32(0.5), shape=()):
def _bernoulli(key, p, shape):
_check_shape("bernoulli", shape)
shape = shape or onp.shape(p)
if not onp.issubdtype(onp.float32, lax.dtype(p)):
p = lax.convert_element_type(p, onp.float32)
if not onp.issubdtype(onp.float64, lax.dtype(p)):
p = lax.convert_element_type(p, onp.float64)
if onp.shape(p) != shape:
p = np.broadcast_to(p, shape)
return lax.lt(uniform(key, shape, lax.dtype(p)), p)


def beta(key, a, b, shape=(), dtype=onp.float32):
def beta(key, a, b, shape=(), dtype=onp.float64):
"""Sample Bernoulli random values with given shape and mean.

Args:
Expand All @@ -430,11 +434,13 @@ def beta(key, a, b, shape=(), dtype=onp.float32):
beta of the random variables.
shape: optional, a tuple of nonnegative integers representing the shape
(default scalar).
dtype: optional, a float dtype for the returned values (default float32).
dtype: optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).

Returns:
A random array with the specified shape and dtype.
"""
dtype = xla_bridge.canonicalize_dtype(dtype)
return _beta(key, a, b, shape, dtype)

@partial(jit, static_argnums=(3, 4))
Expand All @@ -449,18 +455,20 @@ def _beta(key, a, b, shape, dtype):
return gamma_a / (gamma_a + gamma_b)


def cauchy(key, shape=(), dtype=onp.float32):
def cauchy(key, shape=(), dtype=onp.float64):
"""Sample Cauchy random values with given shape and float dtype.

Args:
key: a PRNGKey used as the random key.
shape: optional, a tuple of nonnegative integers representing the shape
(default scalar).
dtype: optional, a float dtype for the returned values (default float32).
dtype: optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).

Returns:
A random array with the specified shape and dtype.
"""
dtype = xla_bridge.canonicalize_dtype(dtype)
return _cauchy(key, shape, dtype)

@partial(jit, static_argnums=(1, 2))
Expand All @@ -471,7 +479,7 @@ def _cauchy(key, shape, dtype):
return lax.tan(lax.mul(pi, lax.sub(u, _constant_like(u, 0.5))))


def dirichlet(key, alpha, shape=(), dtype=onp.float32):
def dirichlet(key, alpha, shape=(), dtype=onp.float64):
"""Sample Cauchy random values with given shape and float dtype.

Args:
Expand All @@ -480,11 +488,13 @@ def dirichlet(key, alpha, shape=(), dtype=onp.float32):
used as the concentration parameter of the random variables.
shape: optional, a tuple of nonnegative integers representing the batch
shape (defaults to `alpha.shape[:-1]`).
dtype: optional, a float dtype for the returned values (default float32).
dtype: optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).

Returns:
A random array with the specified shape and dtype.
"""
dtype = xla_bridge.canonicalize_dtype(dtype)
return _dirichlet(key, alpha, shape, dtype)

@partial(jit, static_argnums=(2, 3))
Expand All @@ -496,18 +506,20 @@ def _dirichlet(key, alpha, shape, dtype):
return gamma_samples / np.sum(gamma_samples, axis=-1, keepdims=True)


def exponential(key, shape=(), dtype=onp.float32):
def exponential(key, shape=(), dtype=onp.float64):
"""Sample Exponential random values with given shape and float dtype.

Args:
key: a PRNGKey used as the random key.
shape: optional, a tuple of nonnegative integers representing the shape
(default scalar).
dtype: optional, a float dtype for the returned values (default float32).
dtype: optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).

Returns:
A random array with the specified shape and dtype.
"""
dtype = xla_bridge.canonicalize_dtype(dtype)
return _exponential(key, shape, dtype)

@partial(jit, static_argnums=(1, 2))
Expand Down Expand Up @@ -568,7 +580,7 @@ def _body_fn(kXVU):
return lax.select(lax.eq(z, zero), onp.finfo(z.dtype).tiny, z)


def gamma(key, a, shape=(), dtype=onp.float32):
def gamma(key, a, shape=(), dtype=onp.float64):
"""Sample Gamma random values with given shape and float dtype.

Args:
Expand All @@ -577,15 +589,17 @@ def gamma(key, a, shape=(), dtype=onp.float32):
of the random variables.
shape: optional, a tuple of nonnegative integers representing the shape
(default scalar).
dtype: optional, a float dtype for the returned values (default float32).
dtype: optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).

Returns:
A random array with the specified shape and dtype.
"""
dtype = xla_bridge.canonicalize_dtype(dtype)
return _gamma(key, a, shape, dtype)

@partial(jit, static_argnums=(2, 3))
def _gamma(key, a, shape=(), dtype=onp.float32):
def _gamma(key, a, shape, dtype):
_check_shape("gamma", shape)
a = lax.convert_element_type(a, dtype)
shape = shape or onp.shape(a)
Expand All @@ -597,18 +611,20 @@ def _gamma(key, a, shape=(), dtype=onp.float32):
return np.reshape(samples, shape)


def gumbel(key, shape=(), dtype=onp.float32):
def gumbel(key, shape=(), dtype=onp.float64):
"""Sample Gumbel random values with given shape and float dtype.

Args:
key: a PRNGKey used as the random key.
shape: optional, a tuple of nonnegative integers representing the shape
(default scalar).
dtype: optional, a float dtype for the returned values (default float32).
dtype: optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).

Returns:
A random array with the specified shape and dtype.
"""
dtype = xla_bridge.canonicalize_dtype(dtype)
return _gumbel(key, shape, dtype)

@partial(jit, static_argnums=(1, 2))
Expand All @@ -617,18 +633,20 @@ def _gumbel(key, shape, dtype):
return -np.log(-np.log(uniform(key, shape, dtype)))


def laplace(key, shape=(), dtype=onp.float32):
def laplace(key, shape=(), dtype=onp.float64):
"""Sample Laplace random values with given shape and float dtype.

Args:
key: a PRNGKey used as the random key.
shape: optional, a tuple of nonnegative integers representing the shape
(default scalar).
dtype: optional, a float dtype for the returned values (default float32).
dtype: optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).

Returns:
A random array with the specified shape and dtype.
"""
dtype = xla_bridge.canonicalize_dtype(dtype)
return _laplace(key, shape, dtype)

@partial(jit, static_argnums=(1, 2))
Expand All @@ -638,7 +656,7 @@ def _laplace(key, shape, dtype):
return lax.mul(lax.sign(u), lax.log1p(lax.neg(lax.abs(u))))


def pareto(key, b, shape=(), dtype=onp.float32):
def pareto(key, b, shape=(), dtype=onp.float64):
"""Sample Pareto random values with given shape and float dtype.

Args:
Expand All @@ -647,11 +665,13 @@ def pareto(key, b, shape=(), dtype=onp.float32):
of the random variables.
shape: optional, a tuple of nonnegative integers representing the shape
(default scalar).
dtype: optional, a float dtype for the returned values (default float32).
dtype: optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).

Returns:
A random array with the specified shape and dtype.
"""
dtype = xla_bridge.canonicalize_dtype(dtype)
return _pareto(key, b, shape, dtype)

@partial(jit, static_argnums=(2, 3))
Expand All @@ -665,7 +685,7 @@ def _pareto(key, b, shape, dtype):
return lax.exp(lax.div(e, b))


def t(key, df, shape=(), dtype=onp.float32):
def t(key, df, shape=(), dtype=onp.float64):
"""Sample Student's t random values with given shape and float dtype.

Args:
Expand All @@ -674,11 +694,13 @@ def t(key, df, shape=(), dtype=onp.float32):
of the random variables.
shape: optional, a tuple of nonnegative integers representing the shape
(default scalar).
dtype: optional, a float dtype for the returned values (default float32).
dtype: optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).

Returns:
A random array with the specified shape and dtype.
"""
dtype = xla_bridge.canonicalize_dtype(dtype)
return _t(key, df, shape, dtype)

@partial(jit, static_argnums=(2, 3))
Expand Down
8 changes: 8 additions & 0 deletions tests/random_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,14 @@ def feature_map(n, d, sigma=1.0, seed=123):
self.assertRaisesRegex(ValueError, re.compile(r'.*requires a concrete.*'),
lambda: feature_map(5, 3))

def testIssue756(self):
  """Regression test for issue #756.

  The default dtype of ``random.normal`` must track the ``jax_enable_x64``
  flag: float64 when the flag is set, float32 otherwise.
  """
  sample = random.normal(random.PRNGKey(0), ())
  expected_dtype = onp.float64 if FLAGS.jax_enable_x64 else onp.float32
  self.assertEqual(onp.result_type(sample), expected_dtype)


if __name__ == "__main__":
absltest.main()