Skip to content

Commit 1e5ec78

Browse files
committed
TMP
1 parent 6f13e47 commit 1e5ec78

File tree

3 files changed

+23
-4
lines changed

3 files changed

+23
-4
lines changed

jax/core.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -50,6 +50,7 @@
5050
from jax._src.lib import jax_jit
5151
from jax._src import traceback_util
5252
from jax._src.typing import DimSize, Shape
53+
from jax._src import typing
5354
traceback_util.register_exclusion(__file__)
5455

5556
zip, unsafe_zip = safe_zip, zip
@@ -530,9 +531,9 @@ def escaped_tracer_error(tracer, detail=None):
530531
msg += f'Detail: {detail}'
531532
return UnexpectedTracerError(msg)
532533

533-
class Tracer:
534+
class Tracer(typing.Array):
534535
__array_priority__ = 1000
535-
__slots__ = ['_trace', '__weakref__', '_line_info']
536+
__slots__ = ['_trace', '_line_info']
536537

537538
def __array__(self, *args, **kw):
538539
raise TracerArrayConversionError(self)

jax/experimental/array.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -95,7 +95,7 @@ def _reconstruct_array(fun, args, arr_state, aval_state):
9595

9696

9797
@pxla.use_cpp_class(xc.Array if xc._version >= 92 else None)
98-
class Array:
98+
class Array(typing.Array):
9999
# TODO(yashkatariya): Add __slots__ here.
100100

101101
@pxla.use_cpp_method

tests/typing_test.py

Lines changed: 19 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -17,14 +17,17 @@
1717
This test is meant to be both a runtime test and a static type annotation test,
1818
so it should be checked with pytype/mypy as well as being run with pytest.
1919
"""
20-
from typing import Union
20+
from typing import Any, Optional, Union
2121

2222
import jax
23+
from jax import core
2324
from jax._src import test_util as jtu
2425
from jax._src import typing
2526
from jax import lax
2627
import jax.numpy as jnp
2728

29+
from jax.experimental.array import Array as ArrayImpl
30+
2831
from absl.testing import absltest
2932
import numpy as np
3033

@@ -109,6 +112,21 @@ def is_array(x: typing.ArrayLike) -> Union[bool, typing.Array]:
109112
self.assertTrue(jax.jit(is_array)(x))
110113
self.assertTrue(jnp.all(jax.vmap(is_array)(x)))
111114

115+
def testAnnotations(self):
116+
# This test is mainly meant for static type checking: we want to ensure that
117+
# Tracer and ArrayImpl are valid as typing.Array (i.e. a checker accepts
# returning either of them from a function annotated -> typing.Array).
118+
119+
def f(x: Any) -> Optional[typing.Array]:
# Returns x unchanged when it is a core.Tracer or an ArrayImpl; any other
# input falls through to None. The two isinstance branches exist so the
# static checker sees each concrete type being returned as typing.Array.
120+
if isinstance(x, core.Tracer):
121+
return x
122+
elif isinstance(x, ArrayImpl):
123+
return x
124+
else:
125+
return None
126+
127+
x = jnp.arange(10)
128+
# NOTE(review): f only passes x through if jnp.arange returns an ArrayImpl
# here (e.g. when the jax.Array implementation is enabled); otherwise y is
# None and the comparison below fails — confirm the expected configuration.
y = f(x)
129+
self.assertArraysEqual(x, y)
112130

113131
if __name__ == '__main__':
114132
absltest.main(testLoader=jtu.JaxTestLoader())

0 commit comments

Comments
 (0)