|
| 1 | +# Copyright 2022 Google LLC |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# https://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +import abc |
| 16 | +from typing import Any, Optional, Tuple, Union |
| 17 | +import numpy as np |
| 18 | + |
| 19 | + |
class Array(abc.ABC):
  """Declares, for type checking, the interface shared by JAX array types.

  Every method below is a declaration only: its body is ``...``, so calling
  it on this class directly returns ``None``. Concrete behavior is presumably
  supplied by subclasses (the runtime array type and the tracer types that
  the ``__array__`` comment mentions); confirm against those implementations.

  The class cannot be instantiated directly; ``__init__`` always raises.

  Because ``__eq__`` is defined below without a ``__hash__``, Python sets
  ``Array.__hash__`` to ``None``. Instances of this class, and of subclasses
  that do not define ``__hash__`` themselves, are therefore unhashable.
  """
  dtype: np.dtype
  ndim: int
  size: int
  aval: Any

  def __init__(self, shape, dtype=None, buffer=None, offset=0, strides=None,
               order=None):
    """Always raises: arrays are not constructed by calling this class.

    The parameter list mirrors the ``numpy.ndarray`` constructor's. The
    arguments are accepted only so that such a call reaches this error.

    Raises:
      TypeError: unconditionally, pointing users at the factory functions.
    """
    raise TypeError("jax.numpy.ndarray() should not be instantiated explicitly."
                    " Use jax.numpy.array, or jax.numpy.zeros instead.")

  # Indexing and iteration. The extra ``__getitem__`` flags are part of the
  # declared signature; their meaning is defined by the implementing subclass.
  def __getitem__(self, key, indices_are_sorted=False,
                  unique_indices=False) -> Any: ...
  def __setitem__(self, key, value) -> Any: ...
  # Python requires ``__len__`` to return an int.
  def __len__(self) -> int: ...
  def __iter__(self) -> Any: ...
  def __reversed__(self) -> Any: ...

  # Comparisons
  def __lt__(self, other) -> Any: ...
  def __le__(self, other) -> Any: ...
  def __eq__(self, other) -> Any: ...
  def __ne__(self, other) -> Any: ...
  def __gt__(self, other) -> Any: ...
  def __ge__(self, other) -> Any: ...

  # Unary arithmetic

  def __neg__(self) -> Any: ...
  def __pos__(self) -> Any: ...
  def __abs__(self) -> Any: ...
  def __invert__(self) -> Any: ...

  # Binary arithmetic

  def __add__(self, other) -> Any: ...
  def __sub__(self, other) -> Any: ...
  def __mul__(self, other) -> Any: ...
  def __matmul__(self, other) -> Any: ...
  def __truediv__(self, other) -> Any: ...
  def __floordiv__(self, other) -> Any: ...
  def __mod__(self, other) -> Any: ...
  def __divmod__(self, other) -> Any: ...
  def __pow__(self, other) -> Any: ...
  def __lshift__(self, other) -> Any: ...
  def __rshift__(self, other) -> Any: ...
  def __and__(self, other) -> Any: ...
  def __xor__(self, other) -> Any: ...
  def __or__(self, other) -> Any: ...

  # Reflected binary arithmetic

  def __radd__(self, other) -> Any: ...
  def __rsub__(self, other) -> Any: ...
  def __rmul__(self, other) -> Any: ...
  def __rmatmul__(self, other) -> Any: ...
  def __rtruediv__(self, other) -> Any: ...
  def __rfloordiv__(self, other) -> Any: ...
  def __rmod__(self, other) -> Any: ...
  def __rdivmod__(self, other) -> Any: ...
  def __rpow__(self, other) -> Any: ...
  def __rlshift__(self, other) -> Any: ...
  def __rrshift__(self, other) -> Any: ...
  def __rand__(self, other) -> Any: ...
  def __rxor__(self, other) -> Any: ...
  def __ror__(self, other) -> Any: ...

  # Conversions to Python scalars. The data model requires each of these to
  # return exactly the annotated builtin type; any other result raises a
  # TypeError at the call site (bool(x), int(x), ...). The annotations say so.
  def __bool__(self) -> bool: ...
  def __complex__(self) -> complex: ...
  def __int__(self) -> int: ...
  def __float__(self) -> float: ...
  def __round__(self, ndigits=None) -> Any: ...

  # Used by operator.index() and slicing; must return an int.
  def __index__(self) -> int: ...

  # np.ndarray methods:
  def all(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
          keepdims=None) -> Any: ...
  def any(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
          keepdims=None) -> Any: ...
  def argmax(self, axis: Optional[int] = None, out=None, keepdims=None) -> Any: ...
  def argmin(self, axis: Optional[int] = None, out=None, keepdims=None) -> Any: ...
  def argpartition(self, kth, axis=-1, kind='introselect', order=None) -> Any: ...
  def argsort(self, axis: Optional[int] = -1, kind='quicksort', order=None) -> Any: ...
  def astype(self, dtype) -> Any: ...
  def choose(self, choices, out=None, mode='raise') -> Any: ...
  def clip(self, min=None, max=None, out=None) -> Any: ...
  def compress(self, condition, axis: Optional[int] = None, out=None) -> Any: ...
  def conj(self) -> Any: ...
  def conjugate(self) -> Any: ...
  def copy(self) -> Any: ...
  def cumprod(self, axis: Optional[Union[int, Tuple[int, ...]]] = None,
              dtype=None, out=None) -> Any: ...
  def cumsum(self, axis: Optional[Union[int, Tuple[int, ...]]] = None,
             dtype=None, out=None) -> Any: ...
  def diagonal(self, offset=0, axis1: int = 0, axis2: int = 1) -> Any: ...
  def dot(self, b, *, precision=None) -> Any: ...
  def flatten(self) -> Any: ...
  @property
  def imag(self) -> Any: ...
  def item(self, *args) -> Any: ...
  def max(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
          keepdims=None, initial=None, where=None) -> Any: ...
  def mean(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
           out=None, keepdims=False, *, where=None) -> Any: ...
  def min(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
          keepdims=None, initial=None, where=None) -> Any: ...
  @property
  def nbytes(self) -> Any: ...
  def nonzero(self, *, size=None, fill_value=None) -> Any: ...
  def prod(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
           out=None, keepdims=None, initial=None, where=None) -> Any: ...
  def ptp(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
          keepdims=False) -> Any: ...
  def ravel(self, order='C') -> Any: ...
  @property
  def real(self) -> Any: ...
  def repeat(self, repeats, axis: Optional[int] = None, *,
             total_repeat_length=None) -> Any: ...
  def reshape(self, *args, order='C') -> Any: ...
  def round(self, decimals=0, out=None) -> Any: ...
  def searchsorted(self, v, side='left', sorter=None) -> Any: ...
  def sort(self, axis: Optional[int] = -1, kind='quicksort', order=None) -> Any: ...
  def squeeze(self, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Any: ...
  def std(self, axis: Optional[Union[int, Tuple[int, ...]]] = None,
          dtype=None, out=None, ddof=0, keepdims=False, *, where=None) -> Any: ...
  def sum(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
          out=None, keepdims=None, initial=None, where=None) -> Any: ...
  def swapaxes(self, axis1: int, axis2: int) -> Any: ...
  def take(self, indices, axis: Optional[int] = None, out=None,
           mode=None) -> Any: ...
  def tobytes(self, order='C') -> Any: ...
  def tolist(self) -> Any: ...
  def trace(self, offset=0, axis1: int = 0, axis2: int = 1, dtype=None,
            out=None) -> Any: ...
  def transpose(self, *args) -> Any: ...
  def var(self, axis: Optional[Union[int, Tuple[int, ...]]] = None,
          dtype=None, out=None, ddof=0, keepdims=False, *, where=None) -> Any: ...
  def view(self, dtype=None, type=None) -> Any: ...

  # Not every array type supports the NumPy array protocol (tracer types,
  # for example, do not always). It is still declared here so that, for type
  # checking, Array satisfies NumPy's ArrayLike protocol.
  def __array__(self) -> Any: ...

  def __dlpack__(self) -> Any: ...

  # JAX extensions
  @property
  def at(self) -> Any: ...
  @property
  def shape(self) -> Tuple[int, ...]: ...
  @property
  def weak_type(self) -> bool: ...
0 commit comments