Skip to content

Commit c7cbdd9

Browse files
committed
initial tests of full-featured typing.Array
1 parent e855a9c commit c7cbdd9

File tree

18 files changed

+291
-333
lines changed

18 files changed

+291
-333
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ repos:
1818
hooks:
1919
- id: mypy
2020
files: (jax/|tests/typing_test\.py)
21+
exclude: jax/_src/basearray.py # Use pyi instead
2122
additional_dependencies: [types-requests==2.27.16, jaxlib==0.3.5]
2223

2324
- repo: https://github.com/mwouts/jupytext

jax/_src/basearray.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# Copyright 2022 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import abc
16+
17+
class Array(abc.ABC):
18+
"""Array base class for JAX
19+
20+
`jax.Array` is meant as the public interface for instance checks and type
21+
annotation of JAX array objects.
22+
"""
23+
# Note: no abstract methods are defined in this base class; the associated .pyi
24+
# file contains the type signatures used for static type checking.
25+
26+
# The `at` property must be defined here because lax_numpy.py overwrites its docstring.
27+
@property
28+
def at(self):
29+
raise NotImplementedError("property must be defined in subclasses")

jax/_src/basearray.pyi

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
# Copyright 2022 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import abc
16+
from typing import Any, Optional, Tuple, Union
17+
import numpy as np
18+
19+
20+
class Array(abc.ABC):
21+
dtype: np.dtype
22+
ndim: int
23+
size: int
24+
aval: Any
25+
26+
def __init__(self, shape, dtype=None, buffer=None, offset=0, strides=None,
27+
order=None):
28+
raise TypeError("jax.numpy.ndarray() should not be instantiated explicitly."
29+
" Use jax.numpy.array, or jax.numpy.zeros instead.")
30+
31+
def __getitem__(self, key, indices_are_sorted=False,
32+
unique_indices=False) -> Any: ...
33+
def __setitem__(self, key, value) -> Any: ...
34+
def __len__(self) -> Any: ...
35+
def __iter__(self) -> Any: ...
36+
def __reversed__(self) -> Any: ...
37+
38+
# Comparisons
39+
def __lt__(self, other) -> Any: ...
40+
def __le__(self, other) -> Any: ...
41+
def __eq__(self, other) -> Any: ...
42+
def __ne__(self, other) -> Any: ...
43+
def __gt__(self, other) -> Any: ...
44+
def __ge__(self, other) -> Any: ...
45+
46+
# Unary arithmetic
47+
48+
def __neg__(self) -> Any: ...
49+
def __pos__(self) -> Any: ...
50+
def __abs__(self) -> Any: ...
51+
def __invert__(self) -> Any: ...
52+
53+
# Binary arithmetic
54+
55+
def __add__(self, other) -> Any: ...
56+
def __sub__(self, other) -> Any: ...
57+
def __mul__(self, other) -> Any: ...
58+
def __matmul__(self, other) -> Any: ...
59+
def __truediv__(self, other) -> Any: ...
60+
def __floordiv__(self, other) -> Any: ...
61+
def __mod__(self, other) -> Any: ...
62+
def __divmod__(self, other) -> Any: ...
63+
def __pow__(self, other) -> Any: ...
64+
def __lshift__(self, other) -> Any: ...
65+
def __rshift__(self, other) -> Any: ...
66+
def __and__(self, other) -> Any: ...
67+
def __xor__(self, other) -> Any: ...
68+
def __or__(self, other) -> Any: ...
69+
70+
def __radd__(self, other) -> Any: ...
71+
def __rsub__(self, other) -> Any: ...
72+
def __rmul__(self, other) -> Any: ...
73+
def __rmatmul__(self, other) -> Any: ...
74+
def __rtruediv__(self, other) -> Any: ...
75+
def __rfloordiv__(self, other) -> Any: ...
76+
def __rmod__(self, other) -> Any: ...
77+
def __rdivmod__(self, other) -> Any: ...
78+
def __rpow__(self, other) -> Any: ...
79+
def __rlshift__(self, other) -> Any: ...
80+
def __rrshift__(self, other) -> Any: ...
81+
def __rand__(self, other) -> Any: ...
82+
def __rxor__(self, other) -> Any: ...
83+
def __ror__(self, other) -> Any: ...
84+
85+
def __bool__(self) -> Any: ...
86+
def __complex__(self) -> Any: ...
87+
def __int__(self) -> Any: ...
88+
def __float__(self) -> Any: ...
89+
def __round__(self, ndigits=None) -> Any: ...
90+
91+
def __index__(self) -> Any: ...
92+
93+
# np.ndarray methods:
94+
def all(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
95+
keepdims=None) -> Any: ...
96+
def any(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
97+
keepdims=None) -> Any: ...
98+
def argmax(self, axis: Optional[int] = None, out=None, keepdims=None) -> Any: ...
99+
def argmin(self, axis: Optional[int] = None, out=None, keepdims=None) -> Any: ...
100+
def argpartition(self, kth, axis=-1, kind='introselect', order=None) -> Any: ...
101+
def argsort(self, axis: Optional[int] = -1, kind='quicksort', order=None) -> Any: ...
102+
def astype(self, dtype) -> Any: ...
103+
def choose(self, choices, out=None, mode='raise') -> Any: ...
104+
def clip(self, min=None, max=None, out=None) -> Any: ...
105+
def compress(self, condition, axis: Optional[int] = None, out=None) -> Any: ...
106+
def conj(self) -> Any: ...
107+
def conjugate(self) -> Any: ...
108+
def copy(self) -> Any: ...
109+
def cumprod(self, axis: Optional[Union[int, Tuple[int, ...]]] = None,
110+
dtype=None, out=None) -> Any: ...
111+
def cumsum(self, axis: Optional[Union[int, Tuple[int, ...]]] = None,
112+
dtype=None, out=None) -> Any: ...
113+
def diagonal(self, offset=0, axis1: int = 0, axis2: int = 1) -> Any: ...
114+
def dot(self, b, *, precision=None) -> Any: ...
115+
def flatten(self) -> Any: ...
116+
@property
117+
def imag(self) -> Any: ...
118+
def item(self, *args) -> Any: ...
119+
def max(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
120+
keepdims=None, initial=None, where=None) -> Any: ...
121+
def mean(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
122+
out=None, keepdims=False, *, where=None,) -> Any: ...
123+
def min(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
124+
keepdims=None, initial=None, where=None) -> Any: ...
125+
@property
126+
def nbytes(self) -> Any: ...
127+
def nonzero(self, *, size=None, fill_value=None) -> Any: ...
128+
def prod(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
129+
out=None, keepdims=None, initial=None, where=None) -> Any: ...
130+
def ptp(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
131+
keepdims=False,) -> Any: ...
132+
def ravel(self, order='C') -> Any: ...
133+
@property
134+
def real(self) -> Any: ...
135+
def repeat(self, repeats, axis: Optional[int] = None, *,
136+
total_repeat_length=None) -> Any: ...
137+
def reshape(self, *args, order='C') -> Any: ...
138+
def round(self, decimals=0, out=None) -> Any: ...
139+
def searchsorted(self, v, side='left', sorter=None) -> Any: ...
140+
def sort(self, axis: Optional[int] = -1, kind='quicksort', order=None) -> Any: ...
141+
def squeeze(self, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Any: ...
142+
def std(self, axis: Optional[Union[int, Tuple[int, ...]]] = None,
143+
dtype=None, out=None, ddof=0, keepdims=False, *, where=None) -> Any: ...
144+
def sum(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
145+
out=None, keepdims=None, initial=None, where=None) -> Any: ...
146+
def swapaxes(self, axis1: int, axis2: int) -> Any: ...
147+
def take(self, indices, axis: Optional[int] = None, out=None,
148+
mode=None) -> Any: ...
149+
def tobytes(self, order='C') -> Any: ...
150+
def tolist(self) -> Any: ...
151+
def trace(self, offset=0, axis1: int = 0, axis2: int = 1, dtype=None,
152+
out=None) -> Any: ...
153+
def transpose(self, *args) -> Any: ...
154+
def var(self, axis: Optional[Union[int, Tuple[int, ...]]] = None,
155+
dtype=None, out=None, ddof=0, keepdims=False, *, where=None) -> Any: ...
156+
def view(self, dtype=None, type=None) -> Any: ...
157+
158+
# Even though we don't always support the NumPy array protocol (e.g., for
159+
# tracer types), we must declare it for type-checking purposes so that Array
160+
# satisfies NumPy's ArrayLike protocol.
161+
def __array__(self) -> Any: ...
162+
163+
def __dlpack__(self) -> Any: ...
164+
165+
# JAX extensions
166+
@property
167+
def at(self) -> Any: ...
168+
@property
169+
def shape(self) -> Tuple[int, ...]: ...
170+
@property
171+
def weak_type(self) -> bool: ...

jax/_src/checkify.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from jax._src.lax import control_flow as cf
3838
from jax._src.config import config
3939
from jax import lax
40+
from jax._src.typing import Array
4041
from jax._src.util import (as_hashable_function, unzip2, split_list, safe_map,
4142
safe_zip)
4243

@@ -62,9 +63,9 @@ def setnewattr(obj, name, val):
6263

6364
## Error value data type and functional assert.
6465

65-
Bool = Union[bool, core.Tracer]
66-
Int = Union[int, core.Tracer]
67-
Payload = Union[np.ndarray, jnp.ndarray, core.Tracer]
66+
Bool = Union[bool, Array]
67+
Int = Union[int, Array]
68+
Payload = Union[np.ndarray, Array]
6869

6970
# For now, the payload needs to be a fixed-size array: 3 int32s, used for the
7071
# OOB message.

jax/_src/device_array.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from jax._src import profiler
3131
from jax._src.lib import xla_client as xc
3232
import jax._src.util as util
33+
from jax._src.typing import Array
3334

3435
### device-persistent data
3536

@@ -332,7 +333,9 @@ class DeletedBuffer(object): pass
332333
deleted_buffer = DeletedBuffer()
333334

334335

336+
Array.register(DeviceArray)
335337
device_array_types: List[type] = [xc.Buffer, _DeviceArray]
336338
for _device_array in device_array_types:
337339
core.literalable_types.add(_device_array)
338340
core.pytype_aval_mappings[_device_array] = abstract_arrays.canonical_concrete_aval
341+
Array.register(_device_array)

jax/_src/lax/lax.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import itertools
2020
import operator
2121
from typing import (Any, Callable, Optional, Sequence, Tuple, List, Dict,
22-
TypeVar, Union, cast as type_cast)
22+
TypeVar, Union, cast as type_cast, overload)
2323
import warnings
2424

2525
import numpy as np
@@ -1105,6 +1105,14 @@ def _reduce_and(operand: ArrayLike, axes: Sequence[int]) -> Array:
11051105
def _reduce_xor(operand: ArrayLike, axes: Sequence[int]) -> Array:
11061106
return reduce_xor_p.bind(operand, axes=tuple(axes))
11071107

1108+
@overload
1109+
def sort(operand: Array, dimension: int = -1,
1110+
is_stable: bool = True, num_keys: int = 1) -> Array: ...
1111+
1112+
@overload
1113+
def sort(operand: Sequence[Array], dimension: int = -1,
1114+
is_stable: bool = True, num_keys: int = 1) -> Tuple[Array, ...]: ...
1115+
11081116
def sort(operand: Union[Array, Sequence[Array]], dimension: int = -1,
11091117
is_stable: bool = True, num_keys: int = 1) -> Union[Array, Tuple[Array, ...]]:
11101118
"""Wraps XLA's `Sort
@@ -1303,7 +1311,7 @@ def expand_dims(array: ArrayLike, dimensions: Sequence[int]) -> Array:
13031311

13041312
### convenience wrappers around traceables
13051313

1306-
def full_like(x: Array, fill_value: ArrayLike, dtype: Optional[DTypeLike] = None,
1314+
def full_like(x: ArrayLike, fill_value: ArrayLike, dtype: Optional[DTypeLike] = None,
13071315
shape: Optional[Shape] = None) -> Array:
13081316
"""Create a full array like np.full based on the example array `x`.
13091317
@@ -1317,7 +1325,7 @@ def full_like(x: Array, fill_value: ArrayLike, dtype: Optional[DTypeLike] = None
13171325
An ndarray with the same shape as `x` with its entries set equal to
13181326
`fill_value`, similar to the output of np.full.
13191327
"""
1320-
from jax.experimental import sharding, array
1328+
from jax.experimental import array
13211329

13221330
fill_shape = np.shape(x) if shape is None else canonicalize_shape(shape)
13231331
weak_type = dtype is None and dtypes.is_weakly_typed(x)
@@ -1330,11 +1338,11 @@ def full_like(x: Array, fill_value: ArrayLike, dtype: Optional[DTypeLike] = None
13301338
# probably in the form of a primitive like `val = match_sharding_p.bind(x, val)`
13311339
# (so it works in staged-out code as well as 'eager' code). Related to
13321340
# equi-sharding.
1333-
if (config.jax_array and hasattr(x, 'sharding') and
1334-
not dispatch.is_single_device_sharding(x.sharding) and
1335-
not isinstance(x.sharding, sharding.PmapSharding)):
1336-
return array.make_array_from_callback(
1337-
fill_shape, x.sharding, lambda idx: val[idx]) # type: ignore[arg-type]
1341+
if config.jax_array and hasattr(x, 'sharding'):
1342+
sharding = x.sharding # type: ignore[union-attr]
1343+
if (not dispatch.is_single_device_sharding(sharding) and
1344+
not isinstance(sharding, sharding.PmapSharding)):
1345+
return array.make_array_from_callback(fill_shape, sharding, lambda idx: val[idx])
13381346
return val
13391347

13401348

jax/_src/lax/slicing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,12 @@
4040
from jax._src.lib.mlir.dialects import mhlo
4141
from jax._src.lib import xla_bridge
4242
from jax._src.lib import xla_client
43+
from jax._src.typing import Shape
4344

4445
xb = xla_bridge
4546
xc = xla_client
4647

4748
Array = Any
48-
Shape = core.Shape
4949

5050
map, unsafe_map = safe_map, map
5151
zip, unsafe_zip = safe_zip, zip

jax/_src/numpy/lax_numpy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3595,12 +3595,12 @@ def replace(tup, val):
35953595
j += 1
35963596

35973597

3598-
gather_indices = lax.concatenate(gather_indices, dimension=j)
3598+
gather_indices_arr = lax.concatenate(gather_indices, dimension=j)
35993599
dnums = lax.GatherDimensionNumbers(
36003600
offset_dims=tuple(offset_dims),
36013601
collapsed_slice_dims=tuple(collapsed_slice_dims),
36023602
start_index_map=tuple(start_index_map))
3603-
return lax.gather(arr, gather_indices, dnums, tuple(slice_sizes),
3603+
return lax.gather(arr, gather_indices_arr, dnums, tuple(slice_sizes),
36043604
mode="fill" if mode is None else mode)
36053605

36063606
### Indexing

0 commit comments

Comments
 (0)