Indexing with stride != 1 results in warning about int64 #2795

@kohr-h

This code:

```python
import jax.numpy as jnp
jnp.zeros(5).at[::2].set(1)
```

results in a warning:

```
UserWarning: Explicitly requested dtype <class 'jax.numpy.lax_numpy.int64'> requested in arange is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
```

No warning is issued when doing integer indexing or slice indexing with stride 1.
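For context, here is a rough sketch of why only strided slices seem to reach `arange`. Judging from the `_index_to_gather` frame in the traceback below, a unit-stride slice can be described by a start offset plus a size, while any other stride needs an explicit array of positions. The helper `slice_to_gather` is hypothetical plain Python, not JAX code; it only illustrates that branch.

```python
def slice_to_gather(s: slice, dim: int) -> dict:
    """Hypothetical helper: describe how a 1-D slice could be lowered to a gather.

    Loosely mirrors the branch structure suggested by `_index_to_gather` in the
    traceback: a unit-stride slice needs only a start offset and a window size,
    while any other stride needs an explicit array of positions (which JAX
    builds with `arange`, the call that emits the warning).
    """
    # Normalize None and negative bounds against the axis length.
    start, stop, step = s.indices(dim)
    positions = list(range(start, stop, step))
    if step == 1:
        # Contiguous window: one start index plus a size is enough.
        return {"kind": "contiguous", "start": start, "size": len(positions)}
    # Strided: every selected position has to be materialized as an index.
    return {"kind": "strided", "indices": positions}


# Usage with the axis length from the repro, jnp.zeros(5):
print(slice_to_gather(slice(None, None, 2), 5))  # {'kind': 'strided', 'indices': [0, 2, 4]}
print(slice_to_gather(slice(1, 4), 5))           # {'kind': 'contiguous', 'start': 1, 'size': 3}
```

For `[::2]` on a length-5 axis the strided branch yields `[0, 2, 4]`, the same positions `arange(0, 5, 2)` would produce.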

When turning warnings into errors:

```python
import warnings
warnings.simplefilter("error")
import jax.numpy as jnp
jnp.zeros(5).at[::2].set(1)
```

this is the stack trace:

```
---------------------------------------------------------------------------
UserWarning                               Traceback (most recent call last)
<ipython-input-7-058b40a9a665> in <module>
----> 1 jnp.zeros(5).at[::2].set(1)

[...]/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in set(self, values)
   3929     See :mod:`jax.ops` for details.
   3930     """
-> 3931     return ops.index_update(self.array, self.index, values)
   3932 
   3933   def add(self, values):

[...]/lib/python3.8/site-packages/jax/ops/scatter.py in index_update(x, idx, y)
    279          [1., 1., 1., 6., 6., 6.]], dtype=float32)
    280   """
--> 281   return _scatter_update(x, idx, y, lax.scatter)
    282 
    283 def segment_sum(data, segment_ids, num_segments=None):

[...]/lib/python3.8/site-packages/jax/ops/scatter.py in _scatter_update(x, idx, y, scatter_op)
     45   # is more or less a transpose of the gather equivalent.
     46   treedef, static_idx, dynamic_idx = np._split_index_for_jit(idx)
---> 47   return _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx)
     48 
     49 

[...]/lib/python3.8/site-packages/jax/ops/scatter.py in _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx)
     56 
     57   idx = np._merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
---> 58   indexer = np._index_to_gather(np.shape(x), idx)
     59 
     60   # Broadcast `y` to the slice output shape.

[...]/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in _index_to_gather(x_shape, idx)
   3367         start_index_map.append(x_axis)
   3368       else:
-> 3369         i = arange(start, limit, stride, dtype=index_dtype)
   3370         size = i.shape[0]
   3371         slice_shape.append(size)

[...]/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in arange(start, stop, step, dtype)
   2114 @_wraps(onp.arange)
   2115 def arange(start, stop=None, step=None, dtype=None):
-> 2116   lax._check_user_dtype_supported(dtype, "arange")
   2117   if stop is None and step is None:
   2118     dtype = dtype or _dtype(start)

[...]/lib/python3.8/site-packages/jax/lax/lax.py in _check_user_dtype_supported(dtype, fun_name)
   5102     fun_name = "requested in {}".format(fun_name) if fun_name else ""
   5103     truncated_dtype = dtypes.canonicalize_dtype(dtype).name
-> 5104     warnings.warn(msg.format(dtype, fun_name , truncated_dtype))

UserWarning: Explicitly requested dtype <class 'jax.numpy.lax_numpy.int64'> requested in arange is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
```

The warning looks innocuous, but the code that computes `index_dtype` is here:
https://github.com/google/jax/blob/a2c06d6113ea02075bfbc924d2d6d8fd39c2f6d3/jax/numpy/lax_numpy.py#L3278

Why would `index_dtype` ever be int64 in this scenario?
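The warning itself is raised by `_check_user_dtype_supported`, the last frame in the traceback. A minimal sketch of that check, assuming it simply compares the requested dtype with its canonicalized form (64-bit types map to 32-bit ones while `jax_enable_x64` is off), shows why any internal call that passes an int64 `dtype` to `arange` warns even though the user never asked for 64 bits. The names below are hypothetical stand-ins in plain Python, with dtypes represented as strings.

```python
import warnings

# Hypothetical stand-in for canonicalization while x64 is disabled:
# 64-bit types are mapped to their 32-bit counterparts.
_X64_TO_X32 = {
    "int64": "int32",
    "uint64": "uint32",
    "float64": "float32",
    "complex128": "complex64",
}


def canonicalize_dtype(name: str, enable_x64: bool = False) -> str:
    """Hypothetical helper: the dtype name that would actually be used."""
    return name if enable_x64 else _X64_TO_X32.get(name, name)


def check_user_dtype_supported(name, fun_name="", enable_x64=False):
    """Hypothetical sketch of the check in the last traceback frame:
    warn only when an explicitly requested dtype would be truncated."""
    if name is None:
        return  # no dtype requested, so nothing can be truncated
    truncated = canonicalize_dtype(name, enable_x64)
    if truncated != name:
        where = "requested in {}".format(fun_name) if fun_name else ""
        warnings.warn(
            "Explicitly requested dtype {} {} is not available, "
            "and will be truncated to dtype {}.".format(name, where, truncated),
            UserWarning,
        )


# An internal arange-style call with "int64" warns; "int32" does not.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    check_user_dtype_supported("int32", "arange")  # silent
    check_user_dtype_supported("int64", "arange")  # warns
print(len(caught))  # 1
```

Under this reading, the check only reports the problem; the open question is still why `index_dtype` is int64 for a 5-element array in the first place.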


bugSomething isn't working
