Status: Closed
Labels: bug (Something isn't working)
Description
This code:
```python
import jax.numpy as jnp
jnp.zeros(5).at[::2].set(1)
```
results in the following warning:
```
UserWarning: Explicitly requested dtype <class 'jax.numpy.lax_numpy.int64'> requested in arange is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
```
No warning is issued when doing integer indexing or slice indexing with stride 1.
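The difference between the strided and stride-1 cases may come from how each slice is lowered. Below is a pure-Python sketch of the general idea; it is an assumption and not JAX's actual `_index_to_gather` code, and the helper name `lower_slice` is hypothetical. A stride-1 slice is a contiguous window that a start offset and a size can describe. A slice with any other stride needs an explicit array of indices, which matches the `arange(start, limit, stride, dtype=index_dtype)` call in the traceback.

```python
import numpy as np

def lower_slice(length, s):
    """Hypothetical helper: describe how slice `s` of an axis of size `length`
    could be lowered for a gather/scatter.

    Returns ("window", start, size) for stride-1 slices, or
    ("indices", index_array) for any other stride.
    """
    # slice.indices normalizes None/negative bounds against the axis length.
    start, stop, step = s.indices(length)
    if step == 1:
        # Contiguous: a start offset and a size suffice, no index array needed.
        return ("window", start, max(stop - start, 0))
    # Strided: every index is materialized explicitly. This is the step where
    # an arange-like call, and therefore its requested dtype, is involved.
    return ("indices", np.arange(start, stop, step, dtype=np.int32))

print(lower_slice(5, slice(None, None, 2)))  # strided, like jnp.zeros(5).at[::2]
print(lower_slice(5, slice(1, 4)))           # stride 1, no index array
```

In this sketch only the strided branch builds an index array. That fits the observation that plain `[1:4]`-style slices never reach `arange`.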
When raising warnings as errors:
```python
import warnings
warnings.simplefilter("error")
import jax.numpy as jnp
jnp.zeros(5).at[::2].set(1)
```
This is the stack trace:
```
---------------------------------------------------------------------------
UserWarning Traceback (most recent call last)
<ipython-input-7-058b40a9a665> in <module>
----> 1 jnp.zeros(5).at[::2].set(1)
[...]/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in set(self, values)
3929 See :mod:`jax.ops` for details.
3930 """
-> 3931 return ops.index_update(self.array, self.index, values)
3932
3933 def add(self, values):
[...]/lib/python3.8/site-packages/jax/ops/scatter.py in index_update(x, idx, y)
279 [1., 1., 1., 6., 6., 6.]], dtype=float32)
280 """
--> 281 return _scatter_update(x, idx, y, lax.scatter)
282
283 def segment_sum(data, segment_ids, num_segments=None):
[...]/lib/python3.8/site-packages/jax/ops/scatter.py in _scatter_update(x, idx, y, scatter_op)
45 # is more or less a transpose of the gather equivalent.
46 treedef, static_idx, dynamic_idx = np._split_index_for_jit(idx)
---> 47 return _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx)
48
49
[...]/lib/python3.8/site-packages/jax/ops/scatter.py in _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx)
56
57 idx = np._merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
---> 58 indexer = np._index_to_gather(np.shape(x), idx)
59
60 # Broadcast `y` to the slice output shape.
[...]/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in _index_to_gather(x_shape, idx)
3367 start_index_map.append(x_axis)
3368 else:
-> 3369 i = arange(start, limit, stride, dtype=index_dtype)
3370 size = i.shape[0]
3371 slice_shape.append(size)
[...]/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in arange(start, stop, step, dtype)
2114 @_wraps(onp.arange)
2115 def arange(start, stop=None, step=None, dtype=None):
-> 2116 lax._check_user_dtype_supported(dtype, "arange")
2117 if stop is None and step is None:
2118 dtype = dtype or _dtype(start)
[...]/lib/python3.8/site-packages/jax/lax/lax.py in _check_user_dtype_supported(dtype, fun_name)
5102 fun_name = "requested in {}".format(fun_name) if fun_name else ""
5103 truncated_dtype = dtypes.canonicalize_dtype(dtype).name
-> 5104 warnings.warn(msg.format(dtype, fun_name , truncated_dtype))
UserWarning: Explicitly requested dtype <class 'jax.numpy.lax_numpy.int64'> requested in arange is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
```
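The last frame shows where the message originates. `arange` is called with an explicit 64-bit `dtype` while x64 mode is off, so `_check_user_dtype_supported` warns that the dtype will be truncated to `int32`. Here is a simplified, hypothetical stand-in for that check. The name `check_user_dtype_supported`, its `x64_enabled` flag, and the truncation table are assumptions for illustration and are not JAX's API. The sketch also shows why `warnings.simplefilter("error")` turns the message into the exception seen in the traceback.

```python
import warnings
import numpy as np

# Hypothetical table: what each 64-bit dtype is truncated to when x64 is off.
_TRUNCATE = {
    np.dtype("int64"): np.dtype("int32"),
    np.dtype("uint64"): np.dtype("uint32"),
    np.dtype("float64"): np.dtype("float32"),
    np.dtype("complex128"): np.dtype("complex64"),
}

def check_user_dtype_supported(dtype, fun_name, x64_enabled=False):
    """Simplified sketch: warn if an explicitly requested dtype would be truncated."""
    if dtype is None or x64_enabled:
        return  # Nothing explicitly requested, or 64-bit types are available.
    requested = np.dtype(dtype)
    truncated = _TRUNCATE.get(requested)
    if truncated is not None:
        warnings.warn(
            f"Explicitly requested dtype {requested} requested in {fun_name} "
            f"is not available, and will be truncated to dtype {truncated.name}.",
            UserWarning,
        )

# An explicit int32 request passes silently.
check_user_dtype_supported(np.int32, "arange")

# With warnings escalated to errors (as in the report), an int64 request raises.
with warnings.catch_warnings():
    warnings.simplefilter("error")
    try:
        check_user_dtype_supported(np.int64, "arange")
    except UserWarning as err:
        print("raised:", err)
```

In this model the warning depends only on the *requested* dtype, not on the array size. A 5-element array therefore triggers the warning as soon as `index_dtype` is a 64-bit type.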
The warning looks innocuous, but `index_dtype` is computed here:
https://github.com/google/jax/blob/a2c06d6113ea02075bfbc924d2d6d8fd39c2f6d3/jax/numpy/lax_numpy.py#L3278
Why would that code ever produce `int64` in this scenario, for a 5-element array with x64 disabled?
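One way to avoid requesting a 64-bit index type needlessly is sketched here. It assumes that 64-bit indices are only meant for axes too large for `int32`, and it never asks for `int64` when x64 is disabled. `choose_index_dtype` is a hypothetical helper and is not the code at the linked line, which may select the dtype by a different rule.

```python
import numpy as np

INT32_MAX = np.iinfo(np.int32).max

def choose_index_dtype(shape, x64_enabled=False):
    """Hypothetical: pick an index dtype for an array of the given shape."""
    largest = max(shape, default=0)
    if largest <= INT32_MAX:
        # 32-bit indices can address every axis, e.g. for shape (5,).
        return np.dtype(np.int32)
    if not x64_enabled:
        # Requesting int64 here would only be truncated (and warned about),
        # so ask for int32 directly instead.
        return np.dtype(np.int32)
    return np.dtype(np.int64)

print(choose_index_dtype((5,)))
print(choose_index_dtype((2**31,), x64_enabled=True))
```

With a rule like this, the strided update on `jnp.zeros(5)` would request `int32` from `arange` and the truncation check would stay silent.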