From 6f50677ace0dbf227f575bd45ab895ed0c77d450 Mon Sep 17 00:00:00 2001 From: archis Date: Tue, 21 Mar 2023 08:20:09 -0700 Subject: [PATCH 01/65] hilbert transform --- jax/_src/scipy/signal.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/jax/_src/scipy/signal.py b/jax/_src/scipy/signal.py index 0e5f7aa31ee5..61b0e0647d42 100644 --- a/jax/_src/scipy/signal.py +++ b/jax/_src/scipy/signal.py @@ -637,3 +637,31 @@ def istft(Zxx: Array, fs: ArrayLike = 1.0, window: str = 'hann', time = jnp.arange(x.shape[0], dtype=np.finfo(x.dtype).dtype) / fs return time, x + +@_wraps(osp_signal.hilbert) +def hilbert(x, N=None, axis=-1): + x = jnp.asarray(x) + if jnp.iscomplexobj(x): + raise ValueError("x must be real.") + if N is None: + N = x.shape[axis] + if N <= 0: + raise ValueError("N must be positive.") + + Xf = jnp.fft.fft(x, N, axis=axis) + if N % 2 == 0: + h = jnp.concatenate([jnp.ones(1), jnp.ones(N // 2 - 2) * 2, jnp.ones(1), jnp.zeros(N // 2)]) + # h[0] = h[N // 2] = 1 + # h[1:N // 2] = 2 + else: + h = jnp.concatenate([jnp.ones(1), jnp.ones((N + 1) // 2 - 1) * 2, jnp.zeros(N // 2)]) + # h[0] = 1 + # h[1:(N + 1) // 2] = 2 + + if x.ndim > 1: + raise NotImplementedError("x must be 1D.") + # ind = [np.newaxis] * x.ndim + # ind[axis] = slice(None) + # h = h[tuple(ind)] + x = jnp.fft.ifft(Xf * h, axis=axis) + return x \ No newline at end of file From 2bdf6c9ccad4b36b392783dddb8a4ad9aa9e8960 Mon Sep 17 00:00:00 2001 From: archis Date: Tue, 21 Mar 2023 08:33:48 -0700 Subject: [PATCH 02/65] newline --- jax/_src/scipy/signal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/scipy/signal.py b/jax/_src/scipy/signal.py index 61b0e0647d42..3c6bb8e1be46 100644 --- a/jax/_src/scipy/signal.py +++ b/jax/_src/scipy/signal.py @@ -664,4 +664,4 @@ def hilbert(x, N=None, axis=-1): # ind[axis] = slice(None) # h = h[tuple(ind)] x = jnp.fft.ifft(Xf * h, axis=axis) - return x \ No newline at end of file + return x From 
cf5ddb002f3a3e133f28aa17a3f4b17f02fa1094 Mon Sep 17 00:00:00 2001 From: archis Date: Fri, 24 Mar 2023 09:00:53 -0700 Subject: [PATCH 03/65] dtype errors --- jax/scipy/signal.py | 1 + tests/scipy_signal_test.py | 22 +++++++++++++++++++++- 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/jax/scipy/signal.py b/jax/scipy/signal.py index a0fe5987c3f6..60c00dac84a3 100644 --- a/jax/scipy/signal.py +++ b/jax/scipy/signal.py @@ -25,4 +25,5 @@ istft as istft, stft as stft, welch as welch, + hilbert as hilbert, ) diff --git a/tests/scipy_signal_test.py b/tests/scipy_signal_test.py index 622e36ac430e..1780280f0a66 100644 --- a/tests/scipy_signal_test.py +++ b/tests/scipy_signal_test.py @@ -56,7 +56,8 @@ ((65, 24), 24, 7, -2, -1), ] - +float_dtypes = jtu.dtypes.floating +real_dtypes = float_dtypes + jtu.dtypes.integer + jtu.dtypes.boolean default_dtypes = jtu.dtypes.floating + jtu.dtypes.integer + jtu.dtypes.complex _TPU_FFT_TOL = 0.15 @@ -390,5 +391,24 @@ def testIstftAgainstNumpy(self, *, shape, dtype, fs, window, nperseg, check_dtypes=False) self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol) + @jtu.sample_product( + shape=onedim_shapes, + N=[None, 1, 7, 13, 20], + axis=[-1, 0], + dtype=real_dtypes, + ) + def testHilbert(self, shape, N, axis, dtype): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: (rng(shape, dtype),) + name = 'hilbert' + jnp_op = getattr(jsp_signal, name) + osp_op = getattr(osp_signal, name) + jnp_fn = lambda a: jnp_op(a, N=N, axis=axis) + osp_fn = lambda a: osp_op(a, N=N, axis=axis) + # Numpy promotes to complex128 aggressively. 
+ self._CheckAgainstNumpy(osp_fn, jnp_fn, args_maker, check_dtypes=False, + tol=1e-4) + self._CompileAndCheck(jnp_op, args_maker) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From e120f13845ef62c487eaf391a085c8580fe1c5f4 Mon Sep 17 00:00:00 2001 From: archis Date: Fri, 24 Mar 2023 12:45:18 -0700 Subject: [PATCH 04/65] passing tests --- jax/_src/scipy/signal.py | 18 ++++++++++-------- tests/scipy_signal_test.py | 14 ++++---------- 2 files changed, 14 insertions(+), 18 deletions(-) diff --git a/jax/_src/scipy/signal.py b/jax/_src/scipy/signal.py index 3c6bb8e1be46..20fa66030259 100644 --- a/jax/_src/scipy/signal.py +++ b/jax/_src/scipy/signal.py @@ -639,8 +639,14 @@ def istft(Zxx: Array, fs: ArrayLike = 1.0, window: str = 'hann', return time, x @_wraps(osp_signal.hilbert) -def hilbert(x, N=None, axis=-1): +def hilbert(x: Array, N: int = None, axis: int = -1): + check_arraylike('hilbert', x) x = jnp.asarray(x) + if x.ndim > 1: + raise NotImplementedError("x must be 1D.") + # ind = [np.newaxis] * x.ndim + # ind[axis] = slice(None) + # h = h[tuple(ind)] if jnp.iscomplexobj(x): raise ValueError("x must be real.") if N is None: @@ -650,18 +656,14 @@ def hilbert(x, N=None, axis=-1): Xf = jnp.fft.fft(x, N, axis=axis) if N % 2 == 0: - h = jnp.concatenate([jnp.ones(1), jnp.ones(N // 2 - 2) * 2, jnp.ones(1), jnp.zeros(N // 2)]) + h = jnp.zeros(N, Xf.dtype).at[0].set(1).at[1:N // 2].set(2).at[N // 2].set(1) # h[0] = h[N // 2] = 1 # h[1:N // 2] = 2 else: - h = jnp.concatenate([jnp.ones(1), jnp.ones((N + 1) // 2 - 1) * 2, jnp.zeros(N // 2)]) + h = jnp.zeros(N, Xf.dtype).at[0].set(1).at[1:(N+1) // 2].set(2) # h[0] = 1 # h[1:(N + 1) // 2] = 2 - if x.ndim > 1: - raise NotImplementedError("x must be 1D.") - # ind = [np.newaxis] * x.ndim - # ind[axis] = slice(None) - # h = h[tuple(ind)] + x = jnp.fft.ifft(Xf * h, axis=axis) return x diff --git a/tests/scipy_signal_test.py b/tests/scipy_signal_test.py index 1780280f0a66..276bc1352582 100644 --- 
a/tests/scipy_signal_test.py +++ b/tests/scipy_signal_test.py @@ -392,23 +392,17 @@ def testIstftAgainstNumpy(self, *, shape, dtype, fs, window, nperseg, self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol) @jtu.sample_product( - shape=onedim_shapes, - N=[None, 1, 7, 13, 20], - axis=[-1, 0], + shape=onedim_shapes, N=[None, 1, 7, 13, 20], axis=[-1, 0], dtype=real_dtypes, ) def testHilbert(self, shape, N, axis, dtype): rng = jtu.rand_default(self.rng()) args_maker = lambda: (rng(shape, dtype),) - name = 'hilbert' - jnp_op = getattr(jsp_signal, name) - osp_op = getattr(osp_signal, name) - jnp_fn = lambda a: jnp_op(a, N=N, axis=axis) - osp_fn = lambda a: osp_op(a, N=N, axis=axis) - # Numpy promotes to complex128 aggressively. + jnp_fn = lambda a: jsp_signal.hilbert(a, N=N, axis=axis) + osp_fn = lambda a: osp_signal.hilbert(a, N=N, axis=axis) self._CheckAgainstNumpy(osp_fn, jnp_fn, args_maker, check_dtypes=False, tol=1e-4) - self._CompileAndCheck(jnp_op, args_maker) + self._CompileAndCheck(jnp_fn, args_maker) if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From 0fd7920e74991ae55cf6edbfbff6cf5811b3d184 Mon Sep 17 00:00:00 2001 From: archis Date: Fri, 24 Mar 2023 13:00:06 -0700 Subject: [PATCH 05/65] linters --- jax/_src/scipy/signal.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/jax/_src/scipy/signal.py b/jax/_src/scipy/signal.py index 20fa66030259..86efdef7b619 100644 --- a/jax/_src/scipy/signal.py +++ b/jax/_src/scipy/signal.py @@ -639,7 +639,7 @@ def istft(Zxx: Array, fs: ArrayLike = 1.0, window: str = 'hann', return time, x @_wraps(osp_signal.hilbert) -def hilbert(x: Array, N: int = None, axis: int = -1): +def hilbert(x: Array, N: Union[int, None] = None, axis: int = -1): check_arraylike('hilbert', x) x = jnp.asarray(x) if x.ndim > 1: @@ -664,6 +664,5 @@ def hilbert(x: Array, N: int = None, axis: int = -1): # h[0] = 1 # h[1:(N + 1) // 2] = 2 - x = jnp.fft.ifft(Xf * h, axis=axis) return x From 
abf8947d1aa26f25f3fa2bd648498d7a333f583e Mon Sep 17 00:00:00 2001 From: archis Date: Fri, 24 Mar 2023 13:02:31 -0700 Subject: [PATCH 06/65] Optional int --- jax/_src/scipy/signal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/scipy/signal.py b/jax/_src/scipy/signal.py index 86efdef7b619..3a7618b72509 100644 --- a/jax/_src/scipy/signal.py +++ b/jax/_src/scipy/signal.py @@ -639,7 +639,7 @@ def istft(Zxx: Array, fs: ArrayLike = 1.0, window: str = 'hann', return time, x @_wraps(osp_signal.hilbert) -def hilbert(x: Array, N: Union[int, None] = None, axis: int = -1): +def hilbert(x: Array, N: Optional[int] = None, axis: int = -1): check_arraylike('hilbert', x) x = jnp.asarray(x) if x.ndim > 1: From cbc25dc0acc24bd17d34beb6e4add44433255713 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 21 Mar 2023 08:39:46 -0700 Subject: [PATCH 07/65] Raise a better error message when there is a device assignment mismatch via the apply_primitive route. PiperOrigin-RevId: 518282464 --- jax/_src/dispatch.py | 19 +++++++++++++++++-- jax/_src/pjit.py | 27 ++++++++++++++++++--------- tests/pjit_test.py | 12 ++++++++++++ 3 files changed, 47 insertions(+), 11 deletions(-) diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 817fb91b2076..c3e09a7e5d2d 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -39,6 +39,7 @@ from jax._src import core from jax._src import dtypes from jax._src import linear_util as lu +from jax._src import api_util from jax._src import path from jax._src import profiler from jax._src import source_info_util @@ -110,10 +111,24 @@ def arg_spec(x: Any) -> ArgSpec: def apply_primitive(prim, *args, **params): """Impl rule that compiles and runs a single primitive 'prim' using XLA.""" - compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), - **params) + from jax._src import pjit + + try: + compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), + **params) + except 
pxla.DeviceAssignmentMismatchError as e: + fails, = e.args + # TODO(yashkatariya): Thread through a signature_fun via every primitive + # using apply_primtive so that the error message has the right argument + # name instead of `args[0]`, etc. + arg_names = api_util._arg_names(prim.impl, args, {}, (), ()) + msg = pjit._device_assignment_mismatch_error( + prim.name, fails, args, 'jit', arg_names) + raise ValueError(msg) from None + return compiled_fun(*args) + def simple_impl(prim): prim.def_impl(partial(apply_primitive, prim)) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index d44abc4be5df..a6b4a77f10a8 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -141,13 +141,13 @@ def _find_arg_mismatch(arg_list, fails, fun_name): break return mismatched_args_msg - -def _device_assignment_mismatch_error(fun, fails, in_tree, args_flat, api_name): +# TODO(yashkatariya): Try to use debug_info that is populated in +# common_infer_params. +def _get_arg_names(fun, in_tree, args_flat): sig = _try_infer_args(fun, in_tree) - args = tree_unflatten(in_tree, args_flat) - args_aug = generate_key_paths(args) + args_aug = generate_key_paths(tree_unflatten(in_tree, args_flat)) - arg_list = [] + arg_names = [] for arg_key, val in args_aug: ak, *rem_keys = arg_key if sig is not None: @@ -155,10 +155,17 @@ def _device_assignment_mismatch_error(fun, fails, in_tree, args_flat, api_name): arg_name = f'{list(sig.arguments.keys())[ak.idx]}{loc}' else: arg_name = '' - da = val.sharding._device_assignment if hasattr(val, 'sharding') else None - arg_list.append((arg_name, da, shaped_abstractify(val))) + arg_names.append(arg_name) + return arg_names + + +def _device_assignment_mismatch_error(fun_name, fails, args_flat, api_name, + arg_names): + arg_list = [] + for a, n in safe_zip(args_flat, arg_names): + da = a.sharding._device_assignment if hasattr(a, 'sharding') else None + arg_list.append((n, da, shaped_abstractify(a))) - fun_name = getattr(fun, '__qualname__', getattr(fun, '__name__', 
str(fun))) mismatched_args_msg = _find_arg_mismatch(arg_list, fails, fun_name) if len(mismatched_args_msg) == 2: @@ -186,8 +193,10 @@ def _python_pjit_helper(fun, infer_params_fn, *args, **kwargs): except pxla.DeviceAssignmentMismatchError as e: fails, = e.args api_name = 'jit' if params['resource_env'] is None else 'pjit' + arg_names = _get_arg_names(fun, in_tree, args_flat) + fun_name = getattr(fun, '__qualname__', getattr(fun, '__name__', str(fun))) msg = _device_assignment_mismatch_error( - fun, fails, in_tree, args_flat, api_name) + fun_name, fails, args_flat, api_name, arg_names) raise ValueError(msg) from None outs = tree_unflatten(out_tree, out_flat) return outs, out_flat, out_tree, args_flat diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 625de5eee536..72d335bee75f 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -2865,6 +2865,18 @@ def f(x): self.assertListEqual(sorted([d.id for d in out1.devices()]), [d.id for d in dev1]) + def test_device_assignment_mismatch_apply_primitive(self): + if jax.device_count() < 2: + self.skipTest("Requires >=2 devices.") + arr = jax.device_put(np.arange(8), jax.devices()[0]) + arr2 = jax.device_put(np.arange(8), jax.devices()[1]) + with self.assertRaisesRegex( + ValueError, + "Received incompatible devices for jitted computation. Got argument " + r"args\[0\] of concatenate with shape int.*\[8\].*and argument " + r"args\[1\].*"): + jnp.concatenate([arr, arr2]) + class TempSharding(Sharding): From ab916570abd3c4343c8089f2e6f3e4366b1c985e Mon Sep 17 00:00:00 2001 From: Frederic Bastien Date: Tue, 21 Mar 2023 07:58:27 -0700 Subject: [PATCH 08/65] Fix inspect_array_sharding with grad. 
--- jax/_src/debugging.py | 2 +- tests/debugging_primitives_test.py | 10 ++++++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/jax/_src/debugging.py b/jax/_src/debugging.py index dff9813134fb..57ce4f3c7bbc 100644 --- a/jax/_src/debugging.py +++ b/jax/_src/debugging.py @@ -294,7 +294,7 @@ def _inspect_sharding_batching_rule(args, _, *, callback): _inspect_sharding_batching_rule) def _inspect_sharding_jvp_rule(primals, _, **params): - return inspect_sharding_p.bind(*primals, **params) + return inspect_sharding_p.bind(*primals, **params), [] ad.primitive_jvps[inspect_sharding_p] = _inspect_sharding_jvp_rule sharding_callbacks = weakref.WeakValueDictionary() # type: ignore diff --git a/tests/debugging_primitives_test.py b/tests/debugging_primitives_test.py index ae688dd8402d..9f070d632ffe 100644 --- a/tests/debugging_primitives_test.py +++ b/tests/debugging_primitives_test.py @@ -1184,14 +1184,20 @@ def _cb(sd): self.assertIsInstance(sd, jax.sharding.Sharding) self.assertLen(sd.device_set, 1) - def f(x): + def f_(x): debugging.inspect_array_sharding(x, callback=_cb) return jnp.square(x) - f = jax.jit(f) + f = jax.jit(f_) f(np.arange(8, dtype=jnp.int32)) self.assertTrue(is_called) + # Test in grad + is_called = False + f = jax.jit(jax.grad(lambda x: f_(x).sum())) + f(np.arange(8, dtype=jnp.float32)) + self.assertTrue(is_called) + if not rich: del VisualizeShardingTest From c0923124fa06d9bbd095ff678ab898b430dfce7e Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 21 Mar 2023 10:12:26 -0700 Subject: [PATCH 09/65] Relax type annotations on lax slicing functions. 
PiperOrigin-RevId: 518308356 --- jax/_src/lax/slicing.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index 89236b7d4797..719322bf23a2 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -111,7 +111,7 @@ def dynamic_slice( slice_sizes=tuple(static_sizes)) -def dynamic_update_slice(operand: Array, update: ArrayLike, +def dynamic_update_slice(operand: Union[Array, np.ndarray], update: ArrayLike, start_indices: Union[Array, Sequence[ArrayLike]]) -> Array: """Wraps XLA's `DynamicUpdateSlice `_ @@ -640,7 +640,7 @@ def index_take(src: Array, idxs: Array, axes: Sequence[int]) -> Array: ### convenience wrappers around traceables -def slice_in_dim(operand: Array, start_index: Optional[int], +def slice_in_dim(operand: Union[Array, np.ndarray], start_index: Optional[int], limit_index: Optional[int], stride: int = 1, axis: int = 0) -> Array: """Convenience wrapper around slice applying to only one dimension.""" @@ -669,7 +669,7 @@ def slice_in_dim(operand: Array, start_index: Optional[int], return slice(operand, start_indices, limit_indices, strides) -def index_in_dim(operand: Array, index: int, axis: int = 0, +def index_in_dim(operand: Union[Array, np.ndarray], index: int, axis: int = 0, keepdims: bool = True) -> Array: """Convenience wrapper around slice to perform int indexing.""" index, axis = core._canonicalize_dimension(index), int(axis) @@ -685,7 +685,8 @@ def index_in_dim(operand: Array, index: int, axis: int = 0, return lax.squeeze(result, (axis,)) -def dynamic_slice_in_dim(operand: Array, start_index: ArrayLike, +def dynamic_slice_in_dim(operand: Union[Array, np.ndarray], + start_index: ArrayLike, slice_size: int, axis: int = 0) -> Array: """Convenience wrapper around dynamic_slice applying to one dimension.""" start_indices: List[ArrayLike] = [lax._const(start_index, 0)] * operand.ndim @@ -697,7 +698,8 @@ def dynamic_slice_in_dim(operand: Array, start_index: 
ArrayLike, return dynamic_slice(operand, start_indices, slice_sizes) -def dynamic_index_in_dim(operand: Array, index: Union[int, Array], +def dynamic_index_in_dim(operand: Union[Array, np.ndarray], + index: Union[int, Array], axis: int = 0, keepdims: bool = True) -> Array: """Convenience wrapper around dynamic_slice to perform int indexing.""" result = dynamic_slice_in_dim(operand, index, 1, axis) @@ -707,7 +709,8 @@ def dynamic_index_in_dim(operand: Array, index: Union[int, Array], return lax.squeeze(result, (axis,)) -def dynamic_update_slice_in_dim(operand: Array, update: ArrayLike, +def dynamic_update_slice_in_dim(operand: Union[Array, np.ndarray], + update: ArrayLike, start_index: ArrayLike, axis: int) -> Array: """Convenience wrapper around :func:`dynamic_update_slice` to update a slice in a single ``axis``. @@ -718,7 +721,8 @@ def dynamic_update_slice_in_dim(operand: Array, update: ArrayLike, return dynamic_update_slice(operand, update, start_indices) -def dynamic_update_index_in_dim(operand: Array, update: ArrayLike, index: ArrayLike, +def dynamic_update_index_in_dim(operand: Union[Array, np.ndarray], + update: ArrayLike, index: ArrayLike, axis: int) -> Array: """Convenience wrapper around :func:`dynamic_update_slice` to update a slice of size 1 in a single ``axis``. From 8e2e2f5d2ca1f7ad7eec82aae27bbf7a2e440200 Mon Sep 17 00:00:00 2001 From: George Necula Date: Tue, 21 Mar 2023 13:20:41 -0700 Subject: [PATCH 10/65] [jax2tf] Clean up the jax2tf sharding_tests. The sharding tests are very important because jax2tf must do non-trivial manipulation of shardings, e.g., to wrap the inputs and the outputs with sharding annotations. Merged two test classes ShardedJitHloTest and Sharding. One of the classes was just checking the annotations in the TF HLO, the other one was just running the code and comparing the results. Now we do both in one test. Refactored the code to log JAX HLO and to check the occurrences of annotations in TF HLO. 
Now we support checking that the occurrence count is equal to a value or greater or equal to a value. Added more annotation checking. This makes the test more informative (because there is no other good way to check that sharding was applied correctly). But this makes it also possible that the test will fail when we change the JAX lowering. PiperOrigin-RevId: 518362978 --- .../jax2tf/tests/sharding_test.py | 635 +++++++++--------- 1 file changed, 322 insertions(+), 313 deletions(-) diff --git a/jax/experimental/jax2tf/tests/sharding_test.py b/jax/experimental/jax2tf/tests/sharding_test.py index 7d451b4c2289..4d8fab47c831 100644 --- a/jax/experimental/jax2tf/tests/sharding_test.py +++ b/jax/experimental/jax2tf/tests/sharding_test.py @@ -44,10 +44,10 @@ # Must come after initializing the flags from jax.experimental.jax2tf.tests import tf_test_util -from jax.experimental.jax2tf.tests.jax2tf_limitations import Jax2TfLimitation - prev_xla_flags = None + + def setUpModule(): global prev_xla_flags prev_xla_flags = os.getenv("XLA_FLAGS") @@ -70,224 +70,18 @@ def tearDownModule(): jtu.restore_spmd_lowering_flag() -def check_sharding_annotations(test, - f_jax, - args: Sequence[Any], - *, - num_replicas=1, - num_partitions=2, - num_variables=0, - native_serialization="default", - checks=()): - """Log the HLO generated from f_jax and its conversion. - - Ideally this would check the sharding of intermediate results in JAX and - TF, but this has turned out to be very brittle and broke down for - StableHLO lowering (the sharding annotation are now binary-encoded - attributes). We kept the logging aspect of this function, which should - help some debugging. - - Args: - checks: a list of tuples, with a regular expression and a number of times - it is expected to occur in the TF-generated HLO. 
- """ - if jtu.device_under_test() == "gpu": - raise unittest.SkipTest("Sharding HLO tests not useful for GPU") - - jax_comp = f_jax.lower(*args).compiler_ir(dialect="mhlo") - jax_hlo = str(jax_comp) - logging.info("[%s] got JAX HLO %s", test._testMethodName, jax_hlo) - - # We only dump JAX optimized code on the TPU - if jtu.device_under_test() == "tpu": - backend = xla_bridge.get_backend() - device_assignment = np.arange(num_partitions * num_replicas) - device_assignment = np.reshape(device_assignment, (-1, num_partitions)) - use_spmd_partitioning = num_partitions > 1 - compile_options = xla_bridge.get_compile_options( - num_replicas=num_replicas, - num_partitions=num_partitions, - device_assignment=device_assignment, - use_spmd_partitioning=use_spmd_partitioning, - ) - jax_optimized_hlo = backend.compile( - jax_hlo, compile_options).hlo_modules()[0].to_string() - logging.info("[%s] got JAX optimized HLO for platform %s %s", - test._testMethodName, backend.platform, jax_optimized_hlo) - - f_tf_base = jax2tf.convert(f_jax, with_gradient=False, - native_serialization=native_serialization) - if num_variables > 0: - args_vars = [tf.Variable(a) for a in args[:num_variables]] - args = args[:num_variables] - f_tf = lambda *inputs: f_tf_base(*args_vars, *inputs) - else: - f_tf = f_tf_base - f_tf_fun = tf.function(f_tf, jit_compile=True, autograph=False) - logging.info("[%s] Got TF graph %s", - test._testMethodName, - f_tf_fun.get_concrete_function(*args).graph.as_graph_def()) - device_name = f"/device:{jtu.device_under_test().upper()}:0" - tf_hlo_generator = f_tf_fun.experimental_get_compiler_ir(*args) - tf_hlo = tf_hlo_generator(stage="hlo", device_name=device_name) - logging.info("[%s] got TF HLO %s", test._testMethodName, tf_hlo) - tf_optimized_hlo = tf_hlo_generator(stage="optimized_hlo", - device_name=device_name) - logging.info("[%s] got TF optimized HLO for %s: %s", test._testMethodName, - device_name, tf_optimized_hlo) - - for check_re, check_count in checks: - 
count = len(re.findall(check_re, tf_hlo)) - test.assertEqual( - count, check_count, - (f"regular expression `{check_re}` expected to occur {check_count}" - f" but occurs {count} times in the TF HLO.\nThis is the TF HLO:\n{tf_hlo}")) - - -class ShardedJitHloTest(tf_test_util.JaxToTfTestCase): +class ShardingTest(tf_test_util.JaxToTfTestCase): """Tests that inspect the HLO for the sharding annotations. - These tests can run on any device. - """ - - - @jtu.with_mesh([("x", 2)]) - def test_pjit_basic1D(self): - - @partial(pjit.pjit, - in_shardings=(P("x"), P("x")), - out_shardings=None) - def jax_func(x, y): - return x + y - - shape = (8, 10) - x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape) - hlo = jax_func.lower(x, x).compiler_ir(dialect="hlo").as_hlo_text() - logging.info("HLO is %s", hlo) - logging.info("JAXPR is %s", jax.make_jaxpr(jax_func)(x, x)) - check_sharding_annotations(self, - jax_func, [x, x], - num_partitions=2) - - @jtu.with_mesh([("x", 2)]) - def test_pjit_basic1D_variable(self): - # The first argument is a tf.Variable - @partial(pjit.pjit, - in_shardings=(P("x"), P("x")), - out_shardings=None) - def jax_func(x, y): - return x + y - - shape = (8, 10) - x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape) - hlo = jax_func.lower(x, x).compiler_ir(dialect="hlo").as_hlo_text() - logging.info("HLO is %s", hlo) - logging.info("JAXPR is %s", jax.make_jaxpr(jax_func)(x, x)) - check_sharding_annotations(self, - jax_func, [x, x], - num_partitions=2, - num_variables=1) - - @jtu.with_mesh([("x", 2), ("y", 2)]) - def test_pjit_basic2D(self): - @partial(pjit.pjit, - in_shardings=(P(None, "x", "y"), P("y")), - out_shardings=P("x")) - def jax_func(x, y): - return x @ y - - x_shape = (8, 6, 4) - y_shape = (4, 2) - x = jnp.arange(np.prod(x_shape), dtype=np.float32).reshape(x_shape) - y = jnp.arange(np.prod(y_shape), dtype=np.float32).reshape(y_shape) - check_sharding_annotations(self, - jax_func, - [x, y], - num_partitions=4) - - 
@jtu.with_mesh([("x", 2), ("y", 2)]) - def test_pjit_TwoMeshAxisSharding(self): - @partial(pjit.pjit, - in_shardings=P(("x", "y"),), - out_shardings=P(("x", "y"),)) - def jax_func(x, y): - return x @ y - - x_shape = (24, 8) - y_shape = (8, 2) - x = jnp.arange(np.prod(x_shape), dtype=np.float32).reshape(x_shape) - y = jnp.arange(np.prod(y_shape), dtype=np.float32).reshape(y_shape) - check_sharding_annotations(self, - jax_func, - [x, y], - num_partitions=4) - - @parameterized.named_parameters( - dict( testcase_name=f"_nested_pjit={nested_pjit}", nested_pjit=nested_pjit) - for nested_pjit in (True, False) - ) - @jtu.with_mesh([("x", 2), ("y", 1)]) - def test_pjit_ShardingConstraint(self, nested_pjit=True): - @partial(pjit.pjit, in_shardings=None, - out_shardings=None) - def jax_func(x): # x: f32[12, 8] - y = jnp.cos(x) - if nested_pjit: - y = pjit.pjit(lambda y: y, in_shardings=P("x", "y"), - out_shardings=P("x", "y"))(y) - else: - y = pjit.with_sharding_constraint(y, P("x", "y")) - return jnp.sin(y) # res: f32[6, 8] - - shape = (12, 8) - x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape) - count_sharding = 2 if nested_pjit else 1 - check_sharding_annotations(self, - jax_func, [x], - num_partitions=2, - checks=[ - (r"custom_call_target.*Sharding.*sharding.*devices=\[2,1\]", count_sharding), - (r"custom_call_target.*Sharding.*sharding.*replicated", 2), - ]) - - @parameterized.named_parameters( - dict( testcase_name=f"_nested_pjit={nested_pjit}", nested_pjit=nested_pjit) - for nested_pjit in (True, False) - ) - @jtu.with_mesh([("x", 2), ("y", 1)]) - def test_pjit_ShardingConstraintReplicated(self, nested_pjit=True): - shape = (12, 8) - @partial(pjit.pjit, in_shardings=(P("x", "y"),), - out_shardings=P("y", "x")) - def jax_func(x): - y = jnp.cos(x) - if nested_pjit: - y = pjit.pjit(lambda y: y, in_shardings=None, out_shardings=None)(y) - else: - y = pjit.with_sharding_constraint(y, None) - return jnp.sin(y) - - x = np.arange(np.prod(shape), 
dtype=np.float32).reshape(shape) - count_replicated = 2 if nested_pjit else 1 - check_sharding_annotations( - self, - jax_func, [x], - num_partitions=2, - checks=[ - (r"custom_call_target.*Sharding.*sharding.*devices=\[2,1\]", 1), - (r"custom_call_target.*Sharding.*sharding.*replicated", count_replicated), - (r"custom_call_target.*Sharding.*sharding.*devices=\[1,2\]", 1) - ]) - - -class ShardingTest(tf_test_util.JaxToTfTestCase): - """ To verify that the tests do run indeed on multiple devices you can run perftools/gputools/profiler/jfprof.sh jax/experimental/jax2tf/tests:sharding_test_tpu -- -c opt --test_filter=ShardingTest.test_shmap_all_to_all --test_arg=--vmodule=jax2tf=3 -- """ def setUp(self): super().setUp() + if jtu.device_under_test() == "gpu": + raise unittest.SkipTest("Sharding HLO tests not useful for GPU") + if len(jax.devices()) < 2: raise unittest.SkipTest("Test requires at least 2 local devices") self.devices = np.array(jax.devices()[:2]) # use 2 devices @@ -299,74 +93,274 @@ def setUp(self): self.topology = tf.tpu.experimental.initialize_tpu_system(resolver) else: self.topology = None + def log_jax_hlo(self, f_jax, args: Sequence[Any], *, + num_replicas=1, num_partitions=2): + """Log the HLO generated from JAX before and after optimizations""" + jax_comp = f_jax.lower(*args).compiler_ir(dialect="mhlo") + jax_hlo = str(jax_comp) + logging.info("[%s] got JAX HLO %s", self._testMethodName, jax_hlo) + + # We only dump JAX optimized code on the TPU + if jtu.device_under_test() == "tpu": + backend = xla_bridge.get_backend() + device_assignment = np.arange(num_partitions * num_replicas) + device_assignment = np.reshape(device_assignment, (-1, num_partitions)) + use_spmd_partitioning = num_partitions > 1 + compile_options = xla_bridge.get_compile_options( + num_replicas=num_replicas, + num_partitions=num_partitions, + device_assignment=device_assignment, + use_spmd_partitioning=use_spmd_partitioning, + ) + jax_optimized_hlo = backend.compile( + jax_hlo, 
compile_options).hlo_modules()[0].to_string() + logging.info("[%s] got JAX optimized HLO for platform %s %s", + self._testMethodName, backend.platform, jax_optimized_hlo) def device_assignment(self, computation_shape=(1, 1, 1, 2), num_replicas=1): self.assertEqual(jtu.device_under_test(), "tpu") return tf.tpu.experimental.DeviceAssignment.build( - self.topology, computation_shape=computation_shape, - num_replicas=num_replicas) + self.topology, computation_shape=computation_shape, + num_replicas=num_replicas) + + def tf_hlo(self, f_tf, args_tf: Sequence[Any]) -> str: + """Get the unoptimized HLO from TF""" + f_tf_fun = tf.function(f_tf, autograph=False, jit_compile=True) + logging.info("[%s] Got TF graph %s", + self._testMethodName, + f_tf_fun.get_concrete_function(*args_tf).graph.as_graph_def()) + device_name = f"/device:{jtu.device_under_test().upper()}:0" + tf_hlo_generator = f_tf_fun.experimental_get_compiler_ir(*args_tf) + tf_hlo = tf_hlo_generator(stage="hlo", device_name=device_name) + logging.info("[%s] got TF HLO %s", self._testMethodName, tf_hlo) + tf_optimized_hlo = tf_hlo_generator(stage="optimized_hlo", + device_name=device_name) + logging.info("[%s] got TF optimized HLO for %s: %s", self._testMethodName, + device_name, tf_optimized_hlo) + # Before we check, we drop the metadata= at the end of tf_hlo + return re.sub(r'metadata=.*', '', tf_hlo) + + + def GEQ(self, value): + # Construct an expected >= value. See `check_sharding`. + return (">=", value) + + def check_sharding(self, f_tf, args_tf: Sequence[Any], *, + checks=()): + """Check the sharding in TF. + + Args: + f_tf: the TF callable + args_tf: the TF args + checks: a list of tuples. The first element is a regular expression, the + second element is an integer representing the expected number of + occurrences of the regular expression in the TF HLO. As a special case, + the second element can be the result of `self.GEQ(v)` to check that + the number of occurrences is greater or equal to a value. 
+ """ + tf_hlo = self.tf_hlo(f_tf, args_tf) + for check_re, expected_count in checks: + count = len(re.findall(check_re, tf_hlo)) + if isinstance(expected_count, int): + self.assertEqual( + count, expected_count, + (f"regular expression `{check_re}` expected to occur " + f"{expected_count} times but occurs {count} times in " + f"the TF HLO.\nThis is the TF HLO:\n{tf_hlo}")) + elif isinstance(expected_count, tuple) and expected_count[0] == ">=": + self.assertGreaterEqual( + count, expected_count[1], + (f"regular expression `{check_re}` expected to occur " + f"at least {expected_count[1]} times but occurs {count} times in " + f"the TF HLO.\nThis is the TF HLO:\n{tf_hlo}")) + else: + assert False - def test_pjit_basic1D(self): - @partial(pjit.pjit, in_shardings=(P("x"),), - out_shardings=None) - def f_jax(a): - return a + a + @parameterized.named_parameters( + dict(testcase_name=f"_in_shardings={in_shardings}_out_shardings={out_shardings}", + in_shardings=in_shardings, out_shardings=out_shardings) + for in_shardings in ("missing", None, "P") + for out_shardings in ("missing", None, "P") + ) + @jtu.with_mesh([("x", 2)]) + def test_pjit_basic(self, in_shardings=None, out_shardings="missing"): + # Ensure that we can distinguish the inputs and outputs by shape + def f_jax(x): # f32[10,20] -> f32[20,10] + return jnp.sin(x.T) - shape = (8, 10) - a = np.arange(np.prod(shape), dtype=np.float32).reshape(shape) + pjit_kwargs = {} + if in_shardings != "missing": + pjit_kwargs["in_shardings"] = (P(None, "x") if in_shardings == "P" else None) + if out_shardings != "missing": + pjit_kwargs["out_shardings"] = (P("x", None) if out_shardings == "P" else None) + f_jax = pjit.pjit(f_jax, **pjit_kwargs) + + x_shape = (10, 20) + x = np.arange(np.prod(x_shape), dtype=np.float32).reshape(x_shape) + + self.log_jax_hlo(f_jax, [x], num_partitions=2) @tf.function(autograph=False, jit_compile=True) - def f_tf(a): - f_converted = jax2tf.convert(f_jax, - native_serialization=True) + def f_tf(x): + 
f_converted = jax2tf.convert(f_jax) if jtu.device_under_test() == "tpu": - res = tf.compat.v1.tpu.rewrite( - f_converted, [tf.convert_to_tensor(a)], + return tf.compat.v1.tpu.rewrite( + f_converted, [tf.convert_to_tensor(x)], device_assignment=self.device_assignment( computation_shape=[1, 1, 1, 2], ))[0] else: - res = f_converted(a) - return res + return f_converted(x) - with Mesh(self.devices, axis_names=("x",)): - res_jax = f_jax(a) - res_tf = f_tf(a) - self.assertAllClose(res_tf.numpy(), res_jax) + # Annotation count for the input + count_in_P = 1 if in_shardings == "P" else 0 + if config.jax2tf_default_native_serialization: + # With native serialization even unspecified in_shardings turn into replicated + count_in_replicated = 1 if in_shardings in [None, "missing"] else 0 + else: + count_in_replicated = 1 if in_shardings is None else 0 + # Annotation count for the output + count_out_P = 1 if out_shardings == "P" else 0 + count_out_replicated = 1 if out_shardings is None else 0 - def test_pjit_closed_over_const(self): + self.check_sharding( + jax2tf.convert(f_jax), [x], + checks=[ + # The argument + (r"f32\[10,20\].*custom_call_target.*Sharding.*sharding.*devices=\[1,2\]", + count_in_P), + (r"f32\[10,20\].*custom_call_target.*Sharding.*sharding.*replicated", + count_in_replicated), + # The result + (r"f32\[20,10\].*custom_call_target.*Sharding.*sharding.*devices=\[2,1\]", + count_out_P), + (r"f32\[20,10\].*custom_call_target.*Sharding.*sharding.*replicated", + count_out_replicated), + # No other shardings + (r"custom_call_target.*Sharding", + count_in_P + count_in_replicated + count_out_P + count_out_replicated), + ]) + + res_jax = f_jax(x) + res_tf = f_tf(x) + self.assertAllClose(res_tf.numpy(), res_jax) - const = jnp.full((4, 3), 7, dtype=np.float32) - a = np.ones((4, 3), dtype=np.float32) - b = np.ones((1, 1), dtype=np.float32) - @partial(pjit.pjit, in_shardings=(P("x"), None), + @jtu.with_mesh([("x", 2)]) + def test_pjit_variable_arg(self): + # The first 
argument is a tf.Variable + @partial(pjit.pjit, in_shardings=(P(None, "x"), P("x", None)), out_shardings=None) - def f_jax(a, b): - return a + b * const + def f_jax(x, y): # f32[10,20] , f32[20,30] -> f32[10,30] + return x @ y + + shape_x = (10, 20) + x = np.arange(np.prod(shape_x), dtype=np.float32).reshape(shape_x) + shape_y = (20, 30) + y = np.arange(np.prod(shape_y), dtype=np.float32).reshape(shape_y) + + self.log_jax_hlo(f_jax, [x, y], num_partitions=2) + + x_v = tf.Variable(x) + f_tf = lambda y: jax2tf.convert(f_jax)(x_v, y) + + self.check_sharding( + f_tf, [y], + checks=[ + # The variable argument + (r"f32\[10,20\].*custom_call_target.*Sharding.*sharding.*devices=\[1,2\]", 1), + # The y argument + (r"f32\[20,30\].*custom_call_target.*Sharding.*sharding.*devices=\[2,1\]", 1), + # The output sharding + (r"f32\[10,30\].*custom_call_target.*Sharding.*sharding.*replicated", 1), + # No other annotations + (r"custom_call_target.*Sharding", 3) + ]) + + + @jtu.with_mesh([("x", 2)]) + def test_pjit_closed_over_const(self): + x = np.ones((10, 20), dtype=np.float32) + const = jnp.full((10, 20), 7, dtype=np.float32) + + @partial(pjit.pjit, in_shardings=(P("x"),), out_shardings=None) + def f_jax(x): # f32[10,20] -> f32[20,10] + return (x * const).T @tf.function(autograph=False, jit_compile=True) - def f_tf(a, b): - f_converted = jax2tf.convert(f_jax, native_serialization=True) + def f_tf(x): + f_converted = jax2tf.convert(f_jax) if jtu.device_under_test() == "tpu": - res = tf.compat.v1.tpu.rewrite( - f_converted, [tf.convert_to_tensor(a), tf.convert_to_tensor(b)], + return tf.compat.v1.tpu.rewrite( + f_converted, [tf.convert_to_tensor(x)], device_assignment=self.device_assignment( computation_shape=[1, 1, 1, 2]) - )[0] + )[0] else: - res = f_converted(a, b) - return res + return f_converted(x) - with Mesh(self.devices, axis_names=("x",)): - res_jax = f_jax(a, b) - res_tf = f_tf(a, b) - self.assertAllClose(res_tf, res_jax) + self.check_sharding( + jax2tf.convert(f_jax), 
[x], + checks=[ + # x + (r"f32\[10,20\].*custom_call_target.*Sharding.*sharding.*devices=\[2,1\]", + 1), + # The result + (r"f32\[20,10\].*custom_call_target.*Sharding.*sharding.*replicated", + self.GEQ(1)), + ]) + + res_jax = f_jax(x) + res_tf = f_tf(x) + self.assertAllClose(res_tf, res_jax) + + @parameterized.named_parameters( + dict(testcase_name=f"_nested_pjit={nested_pjit}_constraint={constraint=}", + nested_pjit=nested_pjit) + # We add a constraint either with a nested pjit or with a sharding_constraint + for nested_pjit in (True, False) + for constraint in (None, "P") + ) + @jtu.with_mesh([("x", 2)]) + def test_pjit_sharding_constraint(self, nested_pjit=True, constraint="P"): + constraint_sharding = P("x", None) if constraint == "P" else None + @partial(pjit.pjit, in_shardings=None, + out_shardings=None) + def f_jax(x): # x: f32[10, 20] + y = jnp.concatenate([x, x], axis=1) # y: f32[10, 40] + if nested_pjit: + y = pjit.pjit(lambda y: y, in_shardings=constraint_sharding, + out_shardings=constraint_sharding)(y) + else: + y = pjit.with_sharding_constraint(y, constraint_sharding) + return jnp.concatenate([y, y], axis=1) # res: f32[10, 80] + + shape = (10, 20) + x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape) + + self.log_jax_hlo(f_jax, [x], num_partitions=2) + f_tf = jax2tf.convert(f_jax) + + # If we use a pjit then we see two constraints, otherwise only 1 + count_inner_sharding = 2 if nested_pjit else 1 + self.check_sharding( + f_tf, [x], + checks=[ + # The input argument + (r"f32\[10,20\].*custom_call_target.*Sharding.*sharding.*replicated", 1), + # The y argument + (r"f32\[10,40\].*custom_call_target.*Sharding.*sharding.*devices=\[2,1\]", + count_inner_sharding), + # The output sharding + (r"f32\[10,80\].*custom_call_target.*Sharding.*sharding.*replicated", 1), + # No other annotations + (r"custom_call_target.*Sharding", 2 + count_inner_sharding) + ]) @parameterized.named_parameters( - dict( 
testcase_name=f"_kind={kind}_in_shardings={in_shardings}_out_shardings={out_shardings}", - kind=kind, in_shardings=in_shardings, out_shardings=out_shardings) + dict(testcase_name=f"_kind={kind}_in_shardings={in_shardings}_out_shardings={out_shardings}", + kind=kind, in_shardings=in_shardings, out_shardings=out_shardings) for kind in ("pjit", "jit", "sharding_constraint") for in_shardings in ( ("none", "P") if kind == "sharding_constraint" else @@ -381,46 +375,46 @@ def test_pjit_error_inner_sharding(self, kind="pjit", in_shardings="P", # Check that we raise an error if there is no top-level pjit but we convert # a function with non-replicated shardings (with native lowering). shardings_map = dict(none=None, P=P("x")) - def f_jax(a): + + def f_jax(x): if kind == "pjit": pjit_kwargs = {} if in_shardings != "unspecified": pjit_kwargs["in_shardings"] = shardings_map[in_shardings] if out_shardings != "unspecified": pjit_kwargs["out_shardings"] = shardings_map[out_shardings] - res = pjit.pjit(lambda a: a * 2., **pjit_kwargs)(a) + res = pjit.pjit(lambda x: x * 2., **pjit_kwargs)(x) elif kind == "jit": - res = jax.jit(lambda a: a * 2.)(a) + res = jax.jit(lambda x: x * 2.)(x) elif kind == "sharding_constraint": - res = pjit.with_sharding_constraint(a * 2., shardings_map[in_shardings]) + res = pjit.with_sharding_constraint(x * 2., shardings_map[in_shardings]) else: assert False return res expect_error = (in_shardings == "P" or out_shardings == "P") shape = (8, 10) - a = np.arange(np.prod(shape), dtype=np.float32).reshape(shape) + x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape) - # TODO(necula): on TPU this function gets executed on the CPU device and - # fails the platform check. I turned off strict_checks to work around this! 
- f_tf = tf.function(jax2tf.convert(f_jax, native_serialization=True, - native_serialization_strict_checks=False), - autograph=False, jit_compile=False) - with Mesh(self.devices, axis_names=("x",)): - with contextlib.ExitStack() as stack: - if expect_error: - stack.enter_context(self.assertRaisesRegex(ValueError, - "Lowered function does not have a top-level pjit but it has non-replicated sharding annotations")) - f_tf(a) + f_tf = tf.function(jax2tf.convert(f_jax, native_serialization=True), + autograph=False, jit_compile=True) + with contextlib.ExitStack() as stack: + if expect_error: + stack.enter_context(self.assertRaisesRegex( + ValueError, + "Lowered function does not have a top-level pjit but it has non-replicated sharding annotations")) + with Mesh(self.devices, axis_names=("x",)): + f_tf(x) @parameterized.named_parameters( - dict( testcase_name=f"_func={func}", func=func) + dict(testcase_name=f"_func={func}", func=func) for func in ("pjit_sharded", "pjit_replicated", "nested_pjit_sharded", "nested_pjit_replicated") ) def test_pjit_eager_error(self, func="pjit_sharded"): if config.jax2tf_default_native_serialization: raise unittest.SkipTest("There is no error in eager mode for native serialization") + # Define some test functions @partial(pjit.pjit, in_shardings=(P("x"),), out_shardings=None) @@ -461,48 +455,59 @@ def f_nested_pjit_replicated(a): def test_xmap_basic(self): devices = np.reshape(self.devices, (1, 2)) - - f_jax = xmap(lambda a, b: (a * 2, b * 4), + ashape = (16, 8, 5) + a = np.arange(np.prod(ashape), dtype=np.float32).reshape(ashape) + bshape = (2, 7) + b = np.arange(np.prod(bshape), dtype=np.float32).reshape(bshape) + + # f32[16,8,5], f32[2,7] -> f32[16,8,10], f32[2,28] + # f_jax: f32[5], f32[7] -> f32[10], f32[28] + f_jax = xmap(lambda a, b: (jnp.concatenate([a, a], axis=0) * 2., + jnp.concatenate([b, b, b, b], axis=0) * 4.), in_axes=({0: 'a', 1: 'b'}, ['c', ...]), out_axes=({0: 'a', 1: 'b'}, ['c', ...]), axis_resources={'a': 'x', 'b': 'y', 
'c': 'x'}) @tf.function(autograph=False, jit_compile=True) def f_tf(a, b): + # xmap works only with native serialization f_converted = jax2tf.convert(f_jax, native_serialization=True) if jtu.device_under_test() == "tpu": res = tf.compat.v1.tpu.rewrite( f_converted, [tf.convert_to_tensor(a), tf.convert_to_tensor(b)], device_assignment=self.device_assignment( computation_shape=[1, 1, 1, 2]) - ) - res = (res[0], res[1]) + ) + return (res[0], res[1]) else: - res = f_converted(a, b) - return res + return f_converted(a, b) with Mesh(devices, ('x', 'y')): - ashape = (16, 8, 5) - a = np.arange(np.prod(ashape)).reshape(ashape) - bshape = (2, 7) - b = np.arange(np.prod(bshape)).reshape(bshape) - res_jax = f_jax(a, b) - self.assertAllClose(res_jax, (a * 2, b * 4)) - - # jax2tf for xmap works only with native lowering - check_sharding_annotations(self, f_jax, [a, b], - native_serialization=True) + self.assertAllClose(res_jax, (jnp.concatenate([a, a], axis=2) * 2., + jnp.concatenate([b, b, b, b], axis=1) * 4.)) res_tf = f_tf(a, b) self.assertAllClose(res_tf, res_jax) + self.check_sharding( + jax2tf.convert(f_jax, native_serialization=True), [a, b], + checks=[ + (r"f32\[16,8,5\].*custom_call_target.*Sharding.*sharding.*devices=\[1,2,1\]", 1), + # The output sharding + (r"f32\[2,7\].*custom_call_target.*Sharding.*sharding.*replicated", 1), + (r"f32\[2,28\].*custom_call_target.*Sharding.*sharding.*replicated", 1), + ]) + def test_xmap_collective_reduce(self): devices = np.reshape(self.devices, (1, 2)) - - f_jax = xmap(lambda a, b: (lax.psum(a * 2, 'a'), b * 4), - in_axes=(['a', 'b', ...], {0: 'c'}), - out_axes=(['b', ...], {0: 'c'}), - axis_resources={'a': 'x', 'b': 'y', 'c': 'x'}) + ashape = (16, 8, 5) + a = np.arange(np.prod(ashape), dtype=np.float32).reshape(ashape) + bshape = (2, 7) + b = np.arange(np.prod(bshape), dtype=np.float32).reshape(bshape) + f_jax = xmap(lambda a, b: (lax.psum(a * 2., 'a'), b * 4.), + in_axes=(['a', 'b', ...], {0: 'c'}), + out_axes=(['b', ...], {0: 
'c'}), + axis_resources={'a': 'x', 'b': 'y', 'c': 'x'}) @tf.function(autograph=False, jit_compile=True) def f_tf(a, b): @@ -512,31 +517,32 @@ def f_tf(a, b): f_converted, [tf.convert_to_tensor(a), tf.convert_to_tensor(b)], device_assignment=self.device_assignment( computation_shape=[1, 1, 1, 2]) - ) - res = (res[0], res[1]) + ) + return (res[0], res[1]) else: - res = f_converted(a, b) - return res + return f_converted(a, b) with Mesh(devices, ('x', 'y')): - ashape = (16, 8, 5) - a = np.arange(np.prod(ashape)).reshape(ashape) - bshape = (2, 7) - b = np.arange(np.prod(bshape)).reshape(bshape) res_jax = f_jax(a, b) - self.assertAllClose(res_jax, ((a * 2).sum(0), b * 4)) - - check_sharding_annotations(self, f_jax, [a, b], - native_serialization=True) + self.assertAllClose(res_jax, ((a * 2.).sum(0), b * 4.)) res_tf = f_tf(a, b) self.assertAllClose(res_tf, res_jax) + self.check_sharding( + jax2tf.convert(f_jax, native_serialization=True), [a, b], + checks=[ + (r"f32\[16,8,5\].*custom_call_target.*Sharding.*sharding.*devices=\[1,2,1\]", 1), + (r"f32\[2,7\].*custom_call_target.*Sharding.*sharding.*replicated", 2), + (r"f32\[8,5\].*custom_call_target.*Sharding.*sharding.*devices=\[2,1\]", 1), + ]) @jtu.ignore_warning(category=UserWarning, message="all_to_all .* are only implemented properly for TPUs and GPUs .*") def test_shmap_all_to_all(self): if jtu.device_under_test() == "cpu": raise unittest.SkipTest("TODO(b/268295912): ShardingRemover crash") + mesh = Mesh(self.devices, axis_names=('x')) + a = np.arange(np.prod(4 * 4), dtype=np.float32).reshape((4, 4)) @partial(pjit.pjit, in_shardings=(P('x', None),), out_shardings=P(None, 'x')) @@ -549,17 +555,15 @@ def f_jax(b): # b: f32[2, 4] def f_tf(a): f_converted = jax2tf.convert(f_jax, native_serialization=True) if jtu.device_under_test() == "tpu": - res = tf.compat.v1.tpu.rewrite( + return tf.compat.v1.tpu.rewrite( f_converted, [tf.convert_to_tensor(a)], device_assignment=self.device_assignment( computation_shape=[1, 1, 1, 
2]) - )[0] + )[0] else: - res = f_converted(a) - return res + return f_converted(a) with mesh: - a = np.arange(np.prod(4 * 4)).reshape((4, 4)) res_jax = f_jax(a) # res: f32[2, 8] b0, b1 = np.split(a, 2, axis=0) # The shard_map in_specs splits on axis 0 b00, b01 = np.split(b0, 2, axis=1) # split_axis=1 @@ -571,6 +575,11 @@ def f_tf(a): res_tf = f_tf(a) self.assertAllClose(res_tf, res_jax) + # TODO(b/274648842): Failed to GetCompilerIr + # self.check_sharding( + # jax2tf.convert(f_jax, native_serialization=True), [a], + # checks=[]) + @unittest.skip("TODO(b/268295912): ShardingRemover crash") def test_repro_xla_bug_shmap_collective_permute(self): mesh = Mesh(self.devices, axis_names=('x')) @@ -592,9 +601,6 @@ def f_jax(b): # b: f32[2, 4] expected = np.concatenate([b0, b1], axis=0) # out_specs concatenates on axis 0 self.assertAllClose(res_jax, expected) - check_sharding_annotations(self, f_jax, [a], - native_serialization=True, - num_partitions=2, num_replicas=1) # XLA bug: invoke the f_tf without tpu.replicate f_tf = tf.function( jax2tf.convert(f_jax, native_serialization=True), @@ -607,6 +613,7 @@ def test_shmap_collective_permute(self): if jtu.device_under_test() == "cpu": raise unittest.SkipTest("TODO(b/268295912): ShardingRemover crash") mesh = Mesh(self.devices, axis_names=('x')) + a = np.arange(np.prod(4 * 4), dtype=np.float32).reshape((4, 4)) @partial(pjit.pjit, in_shardings=(P('x', None),), out_shardings=P('x', None)) @@ -625,13 +632,12 @@ def f_tf(a): f_converted, [tf.convert_to_tensor(a)], device_assignment=self.device_assignment( computation_shape=[1, 1, 1, 2]) - )[0] + )[0] else: res = f_converted(a) return res with mesh: - a = np.arange(np.prod(4 * 4)).reshape((4, 4)) res_jax = f_jax(a) b0, b1 = np.split(a, 2, axis=0) # The shard_map splits on axis 0 b0, b1 = b1, b0 @@ -639,7 +645,10 @@ def f_tf(a): self.assertAllClose(res_jax, expected) res_tf = f_tf(a) self.assertAllClose(res_tf, expected) - + # TODO(b/274648842): Failed to GetCompilerIr + # 
self.check_sharding( + # jax2tf.convert(f_jax, native_serialization=True), [a], + # checks=[]) if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From 41c1d93686d25c57c907c69efb9cc8dab032b252 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 21 Mar 2023 13:41:42 -0700 Subject: [PATCH 11/65] Remove the config.jax_array and jax_jit_pjit_api_merge flag usage since those are always True PiperOrigin-RevId: 518368963 --- tests/api_test.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/tests/api_test.py b/tests/api_test.py index d7bc434fb515..8a3baf8ee048 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -1130,9 +1130,6 @@ def f(x, y, *args, **kwargs): @parameterized.parameters([0, 2, [(0, 2)]]) def test_jit_lower_arg_info_static_argnums(self, static_argnums): - if not config.jax_array or not jax.config.jax_jit_pjit_api_merge: - raise unittest.SkipTest("test only applies after jit-pjit api merge") - def f(x, y, *args, **kwargs): return y['hi'] + args[1] + sum(kwargs.values()) @@ -1148,9 +1145,6 @@ def f(x, y, *args, **kwargs): @parameterized.parameters(['a', 'b', [('a', 'b')]]) def test_jit_lower_arg_info_static_argnames(self, static_argnames): - if not config.jax_array or not jax.config.jax_jit_pjit_api_merge: - raise unittest.SkipTest("test only applies after jit-pjit api merge") - def f(x, y, *args, **kwargs): return y['hi'] + args[1] + kwargs['z'] + kwargs['w'] From 8717494afbed95e648178fa05ca7d5cf007d3531 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 21 Mar 2023 13:53:20 -0700 Subject: [PATCH 12/65] Document ShapeDtypeStruct --- docs/jax.rst | 1 + jax/_src/api.py | 41 +++++++++++++++++++---------------------- 2 files changed, 20 insertions(+), 22 deletions(-) diff --git a/docs/jax.rst b/docs/jax.rst index b3ec03a58e55..51ce45d44815 100644 --- a/docs/jax.rst +++ b/docs/jax.rst @@ -69,6 +69,7 @@ Just-in-time compilation (:code:`jit`) xla_computation make_jaxpr eval_shape + ShapeDtypeStruct device_put 
device_put_replicated device_put_sharded diff --git a/jax/_src/api.py b/jax/_src/api.py index 8476fe6acce7..c614d0914eb8 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -2704,6 +2704,16 @@ def device_get(x: Any): class ShapeDtypeStruct: + """A container for the shape, dtype, and other static attributes of an array. + + ``ShapeDtypeStruct`` is often used in conjunction with :func:`jax.eval_shape`. + + Args: + shape: a sequence of integers representing an array shape + dtype: a dtype-like object + named_shape: (optional) a dictionary representing a named shape + sharding: (optional) a :class:`jax.Sharding` object + """ __slots__ = ["shape", "dtype", "named_shape", "sharding"] def __init__(self, shape, dtype, named_shape=None, sharding=None): self.shape = tuple(shape) @@ -2764,20 +2774,9 @@ def eval_shape(fun: Callable, *args, **kwargs): def eval_shape(fun, *args, **kwargs): out = fun(*args, **kwargs) + shape_dtype_struct = lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype) return jax.tree_util.tree_map(shape_dtype_struct, out) - def shape_dtype_struct(x): - return ShapeDtypeStruct(x.shape, x.dtype) - - class ShapeDtypeStruct: - __slots__ = ["shape", "dtype"] - def __init__(self, shape, dtype): - self.shape = shape - self.dtype = dtype - - In particular, the output is a pytree of objects that have ``shape`` and - ``dtype`` attributes, but nothing else about them is guaranteed by the API. - But instead of applying ``fun`` directly, which might be expensive, it uses JAX's abstract interpretation machinery to evaluate the shapes without doing any FLOPs. @@ -2790,26 +2789,24 @@ def __init__(self, shape, dtype): *args: a positional argument tuple of arrays, scalars, or (nested) standard Python containers (tuples, lists, dicts, namedtuples, i.e. pytrees) of those types. Since only the ``shape`` and ``dtype`` attributes are - accessed, only values that duck-type arrays are required, rather than real - ndarrays. 
The duck-typed objects cannot be namedtuples because those are - treated as standard Python containers. See the example below. + accessed, one can use :class:`jax.ShapeDtypeStruct` or another container + that duck-types as ndarrays (note however that duck-typed objects cannot + be namedtuples because those are treated as standard Python containers). **kwargs: a keyword argument dict of arrays, scalars, or (nested) standard Python containers (pytrees) of those types. As in ``args``, array values need only be duck-typed to have ``shape`` and ``dtype`` attributes. + Returns: + out: a nested PyTree containing :class:`jax.ShapeDtypeStruct` objects as leaves. + For example: >>> import jax >>> import jax.numpy as jnp >>> >>> f = lambda A, x: jnp.tanh(jnp.dot(A, x)) - >>> class MyArgArray(object): - ... def __init__(self, shape, dtype): - ... self.shape = shape - ... self.dtype = jnp.dtype(dtype) - ... - >>> A = MyArgArray((2000, 3000), jnp.float32) - >>> x = MyArgArray((3000, 1000), jnp.float32) + >>> A = jax.ShapeDtypeStruct((2000, 3000), jnp.float32) + >>> x = jax.ShapeDtypeStruct((3000, 1000), jnp.float32) >>> out = jax.eval_shape(f, A, x) # no FLOPs performed >>> print(out.shape) (2000, 1000) From fa19118ebaa6ac7530c11f6ae0b40dc980795a18 Mon Sep 17 00:00:00 2001 From: Anish Tondwalkar Date: Tue, 21 Mar 2023 16:25:43 -0700 Subject: [PATCH 13/65] [jax2tf] Add back_compat test for LuDecomposition PiperOrigin-RevId: 518412422 --- jax/experimental/jax2tf/jax2tf.py | 8 +++ .../jax2tf/tests/back_compat_test.py | 14 ++++- .../tests/back_compat_testdata/tpu_Lu.py | 57 +++++++++++++++++++ .../jax2tf/tests/primitives_test.py | 4 -- 4 files changed, 78 insertions(+), 5 deletions(-) create mode 100644 jax/experimental/jax2tf/tests/back_compat_testdata/tpu_Lu.py diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 099745882c27..476fe65f0afa 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -119,6 +119,14 @@ 
"cusolver_geqrf", "cusolver_orgqr", # qr and svd on TPU "Qr", "ProductOfElementaryHouseholderReflectors", + # TODO(atondwal, necula): add back_compat tests for lu on CPU/GPU + # # lu on CPU + # "lapack_sgetrf" , "lapack_dgetrf" , "lapack_cgetrf" , "lapack_zgetrf", + # # lu on GPU + # "cublas_getrf_batched", "cusolver_getrf", + # "hipblas_getrf_batched", "hipsolver_getrf", + # lu on TPU + "LuDecomposition", ] def _sanitize_scope_name(name): diff --git a/jax/experimental/jax2tf/tests/back_compat_test.py b/jax/experimental/jax2tf/tests/back_compat_test.py index 8d52ae6cf275..84f69ea376c0 100644 --- a/jax/experimental/jax2tf/tests/back_compat_test.py +++ b/jax/experimental/jax2tf/tests/back_compat_test.py @@ -86,6 +86,7 @@ def func(...): ... from jax.experimental.jax2tf.tests.back_compat_testdata import cuda_cusolver_syev from jax.experimental.jax2tf.tests.back_compat_testdata import cuda_threefry2x32 from jax.experimental.jax2tf.tests.back_compat_testdata import tpu_Eigh +from jax.experimental.jax2tf.tests.back_compat_testdata import tpu_Lu from jax.experimental.jax2tf.tests.back_compat_testdata import tpu_Qr from jax.experimental.jax2tf.tests.back_compat_testdata import tpu_Sharding @@ -315,7 +316,7 @@ def test_custom_call_coverage(self): cpu_ducc_fft.data_2023_03_17, cpu_lapack_syev.data_2023_03_17, cpu_lapack_geqrf.data_2023_03_17, cuda_threefry2x32.data_2023_03_15, cuda_cusolver_geqrf.data_2023_03_18, cuda_cusolver_syev.data_2023_03_17, - tpu_Eigh.data, tpu_Qr.data_2023_03_17, tpu_Sharding.data_2023_03_16] + tpu_Eigh.data, tpu_Lu.data_2023_03_21, tpu_Qr.data_2023_03_17, tpu_Sharding.data_2023_03_16] covering_testdatas = itertools.chain( *[load_testdata_nested(d) for d in covering_testdatas]) covered_targets = set() @@ -419,6 +420,17 @@ def test_tpu_Qr(self): data = load_testdata(tpu_Qr.data_2023_03_17) self.run_one_test(func, data, rtol=1e-3) + @staticmethod + def lu_harness(shape, dtype): + operand = jnp.reshape(jnp.arange(np.prod(shape), dtype=dtype), shape) + 
return lax.linalg.lu(operand) + + def test_tpu_Lu(self): + # For lax.linalg.lu + func = lambda: CompatTest.lu_harness((3, 3), np.float32) + data = load_testdata(tpu_Lu.data_2023_03_21) + self.run_one_test(func, data, rtol=1e-3) + def test_cu_threefry2x32(self): def func(x): return jax.random.uniform(x, (2, 4), dtype=np.float32) diff --git a/jax/experimental/jax2tf/tests/back_compat_testdata/tpu_Lu.py b/jax/experimental/jax2tf/tests/back_compat_testdata/tpu_Lu.py new file mode 100644 index 000000000000..1c8a6b072872 --- /dev/null +++ b/jax/experimental/jax2tf/tests/back_compat_testdata/tpu_Lu.py @@ -0,0 +1,57 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# flake8: noqa + +import datetime +from numpy import array, float32, int32 + +# Pasted from the test output (see back_compat_test.py module docstring) +data_2023_03_21 = dict( + testdata_version=1, + platform='tpu', + custom_call_targets=['LuDecomposition'], + serialized_date=datetime.date(2023, 3, 21), + inputs=(), + expected_outputs=(array([[6. , 7. , 8. ], + [0. , 1. , 2. ], + [0.5, 0.5, 0. 
]], dtype=float32), array([2, 2, 2], dtype=int32), array([2, 0, 1], dtype=int32)), + mlir_module_text=""" +module @jit__lambda_ { + func.func public @main() -> (tensor<3x3xf32> {jax.result_info = "[0]"}, tensor<3xi32> {jax.result_info = "[1]"}, tensor<3xi32> {jax.result_info = "[2]"}) { + %0 = stablehlo.iota dim = 0 : tensor<9xf32> + %1 = stablehlo.reshape %0 : (tensor<9xf32>) -> tensor<3x3xf32> + %2:3 = call @lu(%1) : (tensor<3x3xf32>) -> (tensor<3x3xf32>, tensor<3xi32>, tensor<3xi32>) + return %2#0, %2#1, %2#2 : tensor<3x3xf32>, tensor<3xi32>, tensor<3xi32> + } + func.func private @lu(%arg0: tensor<3x3xf32>) -> (tensor<3x3xf32>, tensor<3xi32>, tensor<3xi32>) { + %0 = call @xla_fallback_lu(%arg0) : (tensor<3x3xf32>) -> tuple, tensor<3xi32>, tensor<3xi32>> + %1 = stablehlo.get_tuple_element %0[0] : (tuple, tensor<3xi32>, tensor<3xi32>>) -> tensor<3x3xf32> + %2 = stablehlo.get_tuple_element %0[1] : (tuple, tensor<3xi32>, tensor<3xi32>>) -> tensor<3xi32> + %3 = stablehlo.get_tuple_element %0[2] : (tuple, tensor<3xi32>, tensor<3xi32>>) -> tensor<3xi32> + return %1, %2, %3 : tensor<3x3xf32>, tensor<3xi32>, tensor<3xi32> + } + func.func private @xla_fallback_lu(%arg0: tensor<3x3xf32>) -> tuple, tensor<3xi32>, tensor<3xi32>> { + %0 = stablehlo.custom_call @LuDecomposition(%arg0) {xla_shape = "(f32[3,3]{1,0}, s32[3]{0}, s32[3]{0})"} : (tensor<3x3xf32>) -> tuple, tensor<3xi32>, tensor<3xi32>> + %1 = stablehlo.get_tuple_element %0[0] : (tuple, tensor<3xi32>, tensor<3xi32>>) -> tensor<3x3xf32> + %2 = stablehlo.get_tuple_element %0[1] : (tuple, tensor<3xi32>, tensor<3xi32>>) -> tensor<3xi32> + %3 = stablehlo.get_tuple_element %0[2] : (tuple, tensor<3xi32>, tensor<3xi32>>) -> tensor<3xi32> + %4 = stablehlo.tuple %1, %2, %3 {xla_shape = "(f32[3,3]{1,0}, s32[3]{0}, s32[3]{0})"} : tuple, tensor<3xi32>, tensor<3xi32>> + return %4 : tuple, tensor<3xi32>, tensor<3xi32>> + } +} +""", + 
mlir_module_serialized=b"ML\xefR\x03MLIRxxx-trunk\x00\x01!\x05\x01\x05\x01\x03\x05\x03\x11\x07\t\x0b\r\x0f\x11\x13\x15\x03\xbd\x99\x15\x01e\x07\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x13\x0b\x17\x13\x0b33\x0b\x173S\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x13\x0f\x0b\x13\x13\x0b\x0f\x0b\x0f\x0b\x13\x035\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x17\x13\x0b\x13\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0f\x0f\x0f\x03\x15\x13\x17\x07\x17\x07\x1b\x1f\x17\x13\x07\x02f\x04\x1f\x1d')\x05\x17\x05\x19\x05\x1b\x05\x1d\x05\x1f\x05!\x05#\x05%\x03\x03\x0f\x91\x03\x03\x0f\x93\x03\x03\x0f\x95\x05'\x17\x11\xaa\x06\x01\x03\x03\x05!\x05)\x03\x0b\x07e\tq\x0bs\x05\x81\r\x83\x03\x0b\x07e\t\x85\x0be\x05i\rk\x05+\x17\x11\xae\x06\x01\x03\x0b\x07e\t\x87\x0be\x05m\rk\x03\x13/\x891\x8b3\x8d5e7\x8f9e;e=e\x13o\x05-\x05/\x051\x053\x055\x057\x059\x05;\x1dA\x01\x05=\x1dE\x01\x05?\x1dI\x01\x05A\x1dM\x01\x05C\x03\x03\x13o\x1dS\x01\x05E\x03\x03\x1bm\x03\x03Y\x97\x05G\x1d]\x1d\x05I\x1da\x1d\x05K\x03\x03\x1bi\x03\x01\x1dM\x1dO\x1dQ\x1dS\x1dU#\x0b\x03\x07uy}\r\x03gw\x1dW\r\x03g{\x1dY\r\x03g\x7f\x1d[\x1d]\x1d_#\r#\x0f\x0b\x03\x1da\x1dc\x05\x01\x13\x05\x01\x13\x05\x05\x13\x05\t\x13\x13\x01)\x03\r\x05)\x05\r\r\t\x1b/\x07\x03\x01\x01\t\x11\x01\x07\x03\x01\x01\x11\x03\x03\x07\x03\x01\x01\x11\x03\x03\x03\x07)\x03%\t\x1d\x04n\x02\x05\x01\x11\x01\x1f\x07\x03\x01\r\x05\x11\x01#\x05\x03\x0b\x11\x0f\x03[W\x03\x11\x11\x06_\x03\x03\x03\x01\t\x07\x03c\x07\x03\x01\x01\x03\x03\x07\x04\x01\x07\x05\x07\t\x05\x11\x03%\x05\x03\x0b\x17\x03\x03\x01\t\x07\x03U\x03\x07\x03\x01\x03\x07\x03\x15\x03\x03\x03\x03\x03\x07\x03\x17\x03\x01\x03\x03\x03\x07\x03\x19\x03\x01\x03\x03\x07\x04\x03\x07\x05\x07\t\x05\x11\x01+\x05\x03\r\x1b\x03\x03\x01\x0b\x07?-\x03\x07\x03\x01\x03\x07C\x15\x03\x03\x03\x03\x03\x07G\x17\x03\x01\x03\x03\x03\x07K\x19\x03\x01\x03\x03\r\x07QO\x03\x07\x07\x05\x07\t\x07\x04\x01\x03\x0b\x06\x03\x01\x05\x01\x00\x06\re!\x03\x0f\x0b\t\t\tM!\x11\x07!\x85\x87\x1f\x11)))\x1d\x1f/!!)#\x1f\x197\x1b\x0f\x1
5\x83\r\x1f\x15\x1d\x15\x13\x17\x11\x13\x1f\x11\x15\x11+\x0f\x0b\x11builtin\x00vhlo\x00module\x00get_tuple_element_v1\x00func_v1\x00return_v1\x00call_v1\x00custom_call_v1\x00tuple_v1\x00iota_v1\x00reshape_v1\x00sym_name\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00index\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00xla_shape\x00callee\x00jit__lambda_\x00jit()/jit(main)/lu\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00custom-call.2\x00get-tuple-element.3\x00get-tuple-element.4\x00get-tuple-element.5\x00tuple.6\x00iota_dimension\x00jit()/jit(main)/iota[dtype=float32 shape=(9,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]\x00jax.result_info\x00lu\x00private\x00xla_fallback_lu\x00(f32[3,3]{1,0}, s32[3]{0}, s32[3]{0})\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00LuDecomposition\x00", + xla_call_module_version=4, +) # End paste diff --git a/jax/experimental/jax2tf/tests/primitives_test.py b/jax/experimental/jax2tf/tests/primitives_test.py index c1e64c57a0da..449d36de8b5e 100644 --- a/jax/experimental/jax2tf/tests/primitives_test.py +++ b/jax/experimental/jax2tf/tests/primitives_test.py @@ -146,10 +146,6 @@ def skipCustomCallTest(target: str): skipCustomCallTest("lapack_sgetrf, lapack_dgetrf") elif device == "tpu": - if "lu_shape" in harness.fullname: - skipCustomCallTest("LuDecomposition") - if "custom_linear_solve_" in harness.fullname: - skipCustomCallTest("LuDecomposition") if "approx_top_k_large=True" in harness.fullname: skipCustomCallTest("PartialReduce") # ApproxTopK From ddab64a581a67f87ebb7d82523617487cdd0be92 Mon Sep 17 00:00:00 2001 From: Anish Tondwalkar Date: Tue, 21 Mar 2023 16:48:57 -0700 Subject: [PATCH 14/65] LuDecomposition moved from fallback path to custom_call PiperOrigin-RevId: 518417901 --- jax/_src/lax/linalg.py | 19 ++++++++++++++++--- 1 file changed, 16 
insertions(+), 3 deletions(-) diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index f36cca56a330..3b95c4136cca 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -1212,8 +1212,21 @@ def _lu_cpu_gpu_lowering(getrf_impl, ctx, operand): return [lu, pivot, perm] -def _lu_tpu_translation_rule(ctx, avals_in, avals_out, operand): - return xops.LU(operand) +def _lu_tpu_lowering_rule(ctx, operand): + + op = hlo.CustomCallOp( + [ir.TupleType.get_tuple([mlir.aval_to_ir_type(ctx.avals_out[0]), + mlir.aval_to_ir_type(ctx.avals_out[1]), + mlir.aval_to_ir_type(ctx.avals_out[2])])], + [operand], + call_target_name=ir.StringAttr.get("LuDecomposition"), + has_side_effect=ir.BoolAttr.get(False), + ) + return ( + hlo.GetTupleElementOp(op, 0).result, + hlo.GetTupleElementOp(op, 1).result, + hlo.GetTupleElementOp(op, 2).result, + ) lu_p = Primitive('lu') @@ -1235,7 +1248,7 @@ def _lu_tpu_translation_rule(ctx, avals_in, avals_out, operand): lu_p, partial(_lu_cpu_gpu_lowering, gpu_solver.rocm_getrf), platform='rocm') -xla.register_translation(lu_p, _lu_tpu_translation_rule, platform='tpu') +mlir.register_lowering(lu_p, _lu_tpu_lowering_rule, platform='tpu') @partial(vectorize, excluded={3}, signature='(n,n),(n),(n,k)->(n,k)') From 48db6c80a953f4ab27f83f79f6bc28767ac2d26d Mon Sep 17 00:00:00 2001 From: Jieying Luo Date: Tue, 21 Mar 2023 16:52:49 -0700 Subject: [PATCH 15/65] [PJRT C API] Add parsing PJRT client create options from json file. 
PiperOrigin-RevId: 518418760 --- jax/_src/xla_bridge.py | 54 ++++++++++++++++--- tests/BUILD | 1 + .../testdata/example_pjrt_plugin_config.json | 9 ++++ tests/xla_bridge_test.py | 38 +++++++++++-- 4 files changed, 92 insertions(+), 10 deletions(-) create mode 100644 tests/testdata/example_pjrt_plugin_config.json diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index 2e0f9a1a2c56..ad814b9c7a01 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -20,11 +20,13 @@ """ from functools import partial, lru_cache +import io +import json import logging import os import platform as py_platform import threading -from typing import Any, Callable, Dict, List, Optional, Union, Tuple +from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union import warnings import numpy as np @@ -286,19 +288,59 @@ def _get_pjrt_plugin_names_and_library_paths( return pjrt_plugins +def _get_pjrt_plugin_config( + json_path: str, +) -> Tuple[str, Optional[Mapping[str, Union[str, int, List[int], float]]]]: + """Gets PJRT plugin configuration from a json file. + + The json file needs to have a "library_path" field for the plugin library + path. It can have an optional "create_option" field for the options used when + creating a PJRT plugin client. The value of "create_option" is key-value + pairs. Please see xla_client._NameValueMapping for the supported types of + values. + """ + with io.open(json_path, 'r') as f: + config = json.load(f) + if 'library_path' not in config.keys(): + raise ValueError( + 'PJRT plugin config file should contain "library_path" field.' + ) + return (config['library_path'], config.get('create_options')) + + def register_pjrt_plugin_factories(plugins_from_env: str) -> None: """Registers backend factories for PJRT plugins. A backend factory will be registered for every PJRT plugin in the input string, in the format of 'name1:path1,name2:path2' ('name1;path1,name2;path2' - for windows). 
TPU PJRT plugin will be loaded and registered separately in - make_tpu_client. + for windows). The path can be a path to the plugin library or a path to the + plugin configuration json file. The json file needs to have a "library_path" + field for the plugin library path. It can have an optional "create_option" + field for the options used when creating a PJRT plugin client. The value of + "create_option" is key-value pairs. Please see xla_client._NameValueMapping + for the supported types of values. + + TPU PJRT plugin will be loaded and registered separately in make_tpu_client. """ - def make_factory(name, path): + def make_factory(name: str, path: str): def factory(): - xla_client.load_pjrt_plugin_dynamically(name, path) - return xla_client.make_c_api_client(name) + if path.endswith('.json'): + library_path, options = _get_pjrt_plugin_config(path) + else: + library_path = path + options = None + + xla_client.load_pjrt_plugin_dynamically(name, library_path) + if lib.xla_extension_version >= 134: + return xla_client.make_c_api_client(name, options) + else: + if options: + raise ValueError( + 'Setting PJRT plugin options through json file requires' + ' jaxlib.xla_extension_version >= 134.' 
+ ) + return xla_client.make_c_api_client(name) return factory diff --git a/tests/BUILD b/tests/BUILD index fde4db9e6cd9..dd692e7e12a1 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -848,6 +848,7 @@ py_test( py_test( name = "xla_bridge_test", srcs = ["xla_bridge_test.py"], + data = ["testdata/example_pjrt_plugin_config.json"], deps = [ "//jax", "//jax:test_util", diff --git a/tests/testdata/example_pjrt_plugin_config.json b/tests/testdata/example_pjrt_plugin_config.json new file mode 100644 index 000000000000..2a195727c38e --- /dev/null +++ b/tests/testdata/example_pjrt_plugin_config.json @@ -0,0 +1,9 @@ +{ + "library_path": "/path/pjrt_plugin_name1.so", + "create_options": { + "int_option": 64, + "int_list_option": [32, 64], + "string_option": "string", + "float_option": 1.0 + } +} diff --git a/tests/xla_bridge_test.py b/tests/xla_bridge_test.py index db35aeac1f37..468cd1e4bb0f 100644 --- a/tests/xla_bridge_test.py +++ b/tests/xla_bridge_test.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import os import time import warnings @@ -90,9 +91,6 @@ def _mock_tpu_client(): xb.tpu_client_timer_callback(0.01) def test_register_plugin(self): - if xc._version < 126: - return - with self.assertLogs(level="WARNING") as log_output: xb.register_pjrt_plugin_factories("name1:path1,name2:path2,name3") client_factory, priotiy = xb._backend_factories["name1"] @@ -111,7 +109,39 @@ def test_register_plugin(self): self.assertIn("name2", xb._backend_factories) self.assertEqual(priotiy, 400) mock_load_plugin.assert_called_once_with("name1", "path1") - mock_make.assert_called_once_with("name1") + if xc._version >= 134: + mock_make.assert_called_once_with("name1", None) + else: + mock_make.assert_called_once_with("name1") + + def test_register_plugin_with_config(self): + if xc._version < 134: + return + test_json_file_path = os.path.join( + os.path.dirname(__file__), "testdata/example_pjrt_plugin_config.json" + ) + xb.register_pjrt_plugin_factories(f"name1:{test_json_file_path}") + client_factory, priority = xb._backend_factories["name1"] + with mock.patch.object(xc, "make_c_api_client", autospec=True) as mock_make: + with mock.patch.object( + xc, "load_pjrt_plugin_dynamically", autospec=True + ) as mock_load_plugin: + client_factory() + + self.assertIn("name1", xb._backend_factories) + self.assertEqual(priority, 400) + mock_load_plugin.assert_called_once_with( + "name1", "/path/pjrt_plugin_name1.so" + ) + mock_make.assert_called_once_with( + "name1", + { + "int_option": 64, + "int_list_option": [32, 64], + "string_option": "string", + "float_option": 1.0, + }, + ) class GetBackendTest(jtu.JaxTestCase): From ffc8a3477e4de31f29ca77f3ec01f6ce339769bf Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Tue, 14 Mar 2023 05:05:49 +0000 Subject: [PATCH 16/65] `custom_vjp` symbolic zeros support --- jax/_src/checkify.py | 11 +- jax/_src/core.py | 3 +- jax/_src/custom_derivatives.py | 115 +++++++++++++------ jax/_src/interpreters/ad.py | 31 +++-- jax/_src/interpreters/batching.py | 29 
+++-- jax/_src/interpreters/pxla.py | 7 +- jax/experimental/jax2tf/jax2tf.py | 5 +- jax/interpreters/partial_eval.py | 47 +++++--- tests/api_test.py | 183 +++++++++++++++++++++++++++++- 9 files changed, 350 insertions(+), 81 deletions(-) diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index 80ac676d9d23..e250e240c18f 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -915,7 +915,8 @@ def jvp(*xs): error_checks[custom_derivatives.custom_jvp_call_p] = custom_jvp_call_rule def custom_vjp_call_jaxpr_rule(in_err, enabled_errors, *in_vals, fun_jaxpr, - fwd_jaxpr_thunk, num_consts, bwd, out_trees): + fwd_jaxpr_thunk, num_consts, bwd, out_trees, + symbolic_zeros): err_vals, err_tree = jtu.tree_flatten(in_err) fun = lu.wrap_init( functools.partial(checkify_jaxpr_flat, fun_jaxpr.jaxpr, @@ -923,15 +924,17 @@ def custom_vjp_call_jaxpr_rule(in_err, enabled_errors, *in_vals, fun_jaxpr, fun, fun_metadata = _flatten_and_get_error_metadata_thunk(fun) @lu.wrap_init - def fwd(*xs): + def fwd(*args): # TODO(lenamartens, sharadmv): why not checkify here? 
- fwd_jaxpr, fwd_consts = fwd_jaxpr_thunk() + xs, zeros = args[::2], args[1::2] + fwd_jaxpr, fwd_consts = fwd_jaxpr_thunk(*zeros) xs_without_consts = xs[num_consts:] return core.eval_jaxpr(fwd_jaxpr, fwd_consts, *xs_without_consts) fwd, fwd_out_tree = flatten_fun_output(fwd) all_outs = custom_derivatives.custom_vjp_call_p.bind( - fun, fwd, bwd, *err_vals, *in_vals, out_trees=out_trees) + fun, fwd, bwd, *err_vals, *in_vals, out_trees=out_trees, + symbolic_zeros=symbolic_zeros) fst, out_metadata = lu.merge_linear_aux(fun_metadata, fwd_out_tree) if fst: err_and_out_tree, _ = out_metadata diff --git a/jax/_src/core.py b/jax/_src/core.py index 95d2f38e4d4a..d16e0f4baef2 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -511,7 +511,8 @@ def process_custom_transpose(self, prim, call, tracers, **params): "to handle custom_transpose_call primitives") raise NotImplementedError(msg) - def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, out_trees): + def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, + out_trees, symbolic_zeros): msg = (f"{type(self)} must override process_custom_vjp_call " "to handle custom_vjp primitives") raise NotImplementedError(msg) diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index e93132a7806f..66c527081001 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import dataclasses from functools import update_wrapper, reduce, partial import inspect from typing import (Callable, Generic, List, Optional, Sequence, Tuple, TypeVar, Any) @@ -31,8 +32,8 @@ from jax._src import effects from jax._src import linear_util as lu from jax._src import traceback_util -from jax._src.ad_util import (Zero, SymbolicZero, zeros_like_aval, - stop_gradient_p) +from jax._src.ad_util import ( + stop_gradient_p, SymbolicZero, Zero, zeros_like_aval) from jax._src.api_util import argnums_partial, flatten_fun_nokwargs from jax._src.core import raise_to_shaped from jax._src.interpreters import ad @@ -163,12 +164,13 @@ def defjvp(self, and the second element is the tangent output. Elements of the input and output tuples may be arrays or any nested tuples/lists/dicts thereof. symbolic_zeros: boolean, indicating whether the rule should be passed - objects representing static symbolic zeros in its tangent tuple - argument; otherwise, only standard JAX types (e.g. array-likes) are - passed. Setting this option to True allows a JVP rule to detect whether - certain inputs are not involved in differentiation, but at the cost of - needing special handling for these objects (which e.g. can't be passed - into jax.numpy functions). Default False. + objects representing static symbolic zeros in its tangent argument in + correspondence with unperturbed values; otherwise, only standard JAX + types (e.g. array-likes) are passed. Setting this option to ``True`` + allows a JVP rule to detect whether certain inputs are not involved in + differentiation, but at the cost of needing special handling for these + objects (which e.g. can't be passed into jax.numpy functions). Default + ``False``. Returns: None. 
@@ -480,12 +482,15 @@ def __init__(self, self.nondiff_argnums = nondiff_argnums self.fwd: Optional[Callable[..., Tuple[ReturnValue, Any]]] = None self.bwd: Optional[Callable[..., Tuple[Any, ...]]] = None + self.symbolic_zeros = False __getattr__ = custom_api_util.forward_attr def defvjp(self, fwd: Callable[..., Tuple[ReturnValue, Any]], - bwd: Callable[..., Tuple[Any, ...]]) -> None: + bwd: Callable[..., Tuple[Any, ...]], + symbolic_zeros: bool = False, + ) -> None: """Define a custom VJP rule for the function represented by this instance. Args: @@ -506,6 +511,27 @@ def defvjp(self, function, and the tuple elements may be arrays or nested tuples/lists/dicts thereof so as to match the structure of the primal input arguments. + symbolic_zeros: boolean, indicating whether to indicate symbolic zeros in + the ``fwd`` and ``bwd`` rules. Setting this option to ``True`` allows + custom derivative rules to detect when certain inputs, and when certain + cotangent outputs, are not involved in differentiation. If ``True``: + + * ``fwd`` must accept, for each leaf value ``x`` in the pytree + comprising an argument to the original function, a pair ``(x, zero)``, + where ``x`` is the original argument and ``zero`` is a boolean. The + ``zero`` part indicates whether or not the argument is not involved in + differentiation (i.e., whether the corresponding Jacobian "column" is + zero). + + * ``bwd`` will be passed objects representing static symbolic zeros in + its cotangent argument in correspondence with unperturbed values; + otherwise, only standard JAX types (e.g. array-likes) are passed. + + Setting this option to ``True`` allows these rules to detect whether + certain inputs and outputs are not involved in differentiation, but at + the cost of special handling: the signature of ``fwd`` changes, and + ``bwd`` receives objects that, for instance, cannot be passed to + ``jax.numpy`` functions. Default ``False``. Returns: None. 
@@ -527,6 +553,7 @@ def f_bwd(res, g): """ self.fwd = fwd self.bwd = bwd + self.symbolic_zeros = symbolic_zeros @traceback_util.api_boundary def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable=invalid-annotation @@ -534,7 +561,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable if not self.fwd or not self.bwd: msg = f"No VJP defined for custom_vjp function {primal_name} using defvjp." raise AttributeError(msg) - fwd_name = getattr(self.fwd, '__name__', str(self.fwd)) + fwd_name = getattr(self.fwd, '__name__', str(self.fwd)) args = _resolve_kwargs(self.fun, args, kwargs) if config.jax_enable_custom_vjp_by_custom_transpose: if self.nondiff_argnums: @@ -555,6 +582,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable else: f_, dyn_args = lu.wrap_init(self.fun), args fwd, bwd = lu.wrap_init(self.fwd), lu.wrap_init(self.bwd) + fwd = _project_fwd(fwd, self.symbolic_zeros) args_flat, in_tree = tree_flatten(dyn_args) in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat] flat_fun, out_type = _flatten_fun_nokwargs(f_, in_tree) @@ -562,10 +590,22 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable out_type) flat_bwd = _flatten_bwd(bwd, in_tree, in_avals, out_trees).call_wrapped out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd, - *args_flat, out_trees=out_trees) + *args_flat, out_trees=out_trees, + symbolic_zeros=self.symbolic_zeros) _, (out_tree, _) = lu.merge_linear_aux(out_type, out_trees) return tree_unflatten(out_tree, out_flat) +@dataclasses.dataclass +class ZeroTagged: + val: Any + zero: bool + +@lu.transformation +def _project_fwd(symbolic_zeros, *args, **kwargs): + project_leaf = ((lambda x: (x.val, x.zero)) if symbolic_zeros else + (lambda x: x.val)) + yield (yield tree_map(project_leaf, (args, kwargs))) + def _check_for_tracers(x): for leaf in tree_leaves(x): if isinstance(x, core.Tracer): @@ -579,8 +619,10 @@ def 
_check_for_tracers(x): raise UnexpectedTracerError(msg) @lu.transformation_with_aux -def _flatten_fwd(primal_name, fwd_name, in_tree, maybe_out_type, *args): - py_args = tree_unflatten(in_tree, args) +def _flatten_fwd(primal_name, fwd_name, in_tree, maybe_out_type, + *args): + tagged_args = [ZeroTagged(x, z) for x, z in zip(args[::2], args[1::2])] + py_args = tree_unflatten(in_tree, tagged_args) pair_out = yield py_args, {} if not isinstance(pair_out, (list, tuple)) or len(pair_out) != 2: msg = (f"Custom VJP fwd rule {fwd_name} for function {primal_name} " @@ -672,7 +714,7 @@ def _flatten_bwd(in_tree, in_avals, out_trees, *args): class CustomVJPCallPrimitive(core.CallPrimitive): initial_style: core.Primitive - def bind(self, fun, fwd, bwd, *args, out_trees): + def bind(self, fun, fwd, bwd, *args, out_trees, symbolic_zeros): args = map(core.full_lower, args) top_trace = core.find_top_trace(args) fun, env_trace_todo1 = process_env_traces( @@ -682,7 +724,8 @@ def bind(self, fun, fwd, bwd, *args, out_trees): tracers = map(top_trace.full_raise, args) # type: ignore bwd_ = lambda *args: bwd(*args) outs = top_trace.process_custom_vjp_call(self, fun, fwd, bwd_, tracers, - out_trees=out_trees) + out_trees=out_trees, + symbolic_zeros=symbolic_zeros) fst, env_trace_todo = lu.merge_linear_aux(env_trace_todo1, env_trace_todo2) if fst: return core.apply_todos(env_trace_todo, map(core.full_lower, outs)) @@ -748,31 +791,34 @@ def _custom_vjp_call_jaxpr_abstract_eval(*_, fun_jaxpr, **__): def _custom_vjp_call_jaxpr_jvp( primals, tangents, *, fun_jaxpr: core.ClosedJaxpr, - fwd_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]], - bwd: Callable, out_trees: Callable, num_consts: int): + fwd_jaxpr_thunk: Callable[..., Tuple[core.Jaxpr, Sequence[Any]]], + num_consts: int, bwd: Callable, out_trees: Callable, symbolic_zeros: bool): _, args = split_list(primals, [num_consts]) consts_dot, args_dot = split_list(tangents, [num_consts]) if any(type(t) is not Zero for t in consts_dot): 
raise ad.CustomVJPException() - fwd_jaxpr, fwd_consts = fwd_jaxpr_thunk() # consts can be tracers! + zeros = [type(t) is Zero for t in args_dot] + fwd_jaxpr, fwd_consts = fwd_jaxpr_thunk(*zeros) # consts can be tracers! out_tree, res_tree = out_trees() + res_and_primals_out = core.eval_jaxpr(fwd_jaxpr, fwd_consts, *args) + res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves]) + avals_out = [raise_to_shaped(core.get_aval(x)) for x in primals_out] args_dot = map(ad.instantiate_zeros, args_dot) # Cast float0 to zeros with the primal dtype because custom vjp rules don't # currently handle float0s args_dot = map(ad.replace_float0s, args, args_dot) - res_and_primals_out = core.eval_jaxpr(fwd_jaxpr, fwd_consts, *args) - res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves]) - avals_out = [raise_to_shaped(core.get_aval(x)) for x in primals_out] tangents_out = ad.custom_lin_p.bind( - *res, *args_dot, num_res=res_tree.num_leaves, bwd=bwd, out_avals=avals_out) + *res, *args_dot, num_res=res_tree.num_leaves, bwd=bwd, + out_avals=avals_out, symbolic_zeros=symbolic_zeros) tangents_out = map(ad.recast_to_float0, primals_out, tangents_out) return primals_out, tangents_out ad.primitive_jvps[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_jvp -def _custom_vjp_call_jaxpr_vmap(spmd_axis_name, - axis_size, axis_name, main_type, args, in_dims, *, fun_jaxpr: core.ClosedJaxpr, - fwd_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]], - bwd: Callable, out_trees: Callable, num_consts: int): +def _custom_vjp_call_jaxpr_vmap( + spmd_axis_name, axis_size, axis_name, main_type, args, in_dims, *, + fun_jaxpr: core.ClosedJaxpr, + fwd_jaxpr_thunk: Callable[..., Tuple[core.Jaxpr, Sequence[Any]]], + num_consts: int, bwd: Callable, out_trees: Callable, symbolic_zeros: bool): args = [batching.moveaxis(x, d, 0) if d is not not_mapped and d != 0 else x for x, d in zip(args, in_dims)] @@ -785,8 +831,8 @@ def _custom_vjp_call_jaxpr_vmap(spmd_axis_name, 
out_dims2 = [] @pe._memoize - def batched_fwd_jaxpr_thunk(): - fwd_jaxpr = core.ClosedJaxpr(*fwd_jaxpr_thunk()) # consts can be tracers + def batched_fwd_jaxpr_thunk(*zeros): + fwd_jaxpr = core.ClosedJaxpr(*fwd_jaxpr_thunk(*zeros)) # consts can be tracers batched_fwd_jaxpr, out_batched = batching.batch_jaxpr( fwd_jaxpr, axis_size, args_batched, False, axis_name, spmd_axis_name, main_type) @@ -795,17 +841,20 @@ def batched_fwd_jaxpr_thunk(): fwd_args_batched = [0 if b else not_mapped for b in args_batched] fwd_out_dims = lambda: out_dims2[0] - batched_bwd = batching.batch_custom_vjp_bwd(bwd, axis_name, axis_size, fwd_out_dims, - fwd_args_batched, main_type, spmd_axis_name) + batched_bwd = batching.batch_custom_vjp_bwd( + bwd, axis_name, axis_size, fwd_out_dims, fwd_args_batched, main_type, + spmd_axis_name) batched_outs = custom_vjp_call_jaxpr_p.bind( *args, fun_jaxpr=batched_fun_jaxpr, fwd_jaxpr_thunk=batched_fwd_jaxpr_thunk, bwd=batched_bwd, - out_trees=out_trees, num_consts=num_consts) + num_consts=num_consts, out_trees=out_trees, symbolic_zeros=symbolic_zeros) out_dims = out_dims2[0] if out_dims2 else out_dims1 return batched_outs, out_dims -batching.spmd_axis_primitive_batchers[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_vmap -batching.axis_primitive_batchers[custom_vjp_call_jaxpr_p] = partial(_custom_vjp_call_jaxpr_vmap, None) +batching.spmd_axis_primitive_batchers[custom_vjp_call_jaxpr_p] = \ + _custom_vjp_call_jaxpr_vmap +batching.axis_primitive_batchers[custom_vjp_call_jaxpr_p] = partial( + _custom_vjp_call_jaxpr_vmap, None) xla.register_initial_style_primitive(custom_vjp_call_jaxpr_p) diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index bdcf3b961e51..8e4c41434110 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -27,9 +27,9 @@ from jax._src import core from jax._src import source_info_util from jax._src.ad_util import ( - add_jaxvals, add_jaxvals_p, zeros_like_jaxval, zeros_like_aval, - 
zeros_like_p, Zero, replace_internal_symbolic_zeros, - replace_rule_output_symbolic_zeros) + add_jaxvals, add_jaxvals_p, replace_internal_symbolic_zeros, + replace_rule_output_symbolic_zeros, Zero, zeros_like_aval, + zeros_like_jaxval, zeros_like_p) from jax._src.api_util import flatten_fun, flatten_fun_nokwargs from jax._src.core import (Trace, Tracer, get_aval, call_p, Primitive, Literal, raise_to_shaped) @@ -387,16 +387,24 @@ def process_custom_jvp_call(self, _, __, f_jvp, tracers, *, symbolic_zeros): def post_process_custom_jvp_call(self, out_tracers, _): raise CustomJVPException() - def process_custom_vjp_call(self, _, __, fwd, bwd, tracers, out_trees): + def process_custom_vjp_call(self, _, __, fwd, bwd, tracers, out_trees, + symbolic_zeros): primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers) - tangents_in = map(instantiate_zeros, tangents_in) - res_and_primals_out = fwd.call_wrapped(*map(core.full_lower, primals_in)) + fwd_in = [(core.full_lower(p), type(t) is Zero) + for p, t in zip(primals_in, tangents_in)] + fwd_in = [x for pair in fwd_in for x in pair] # flatten + res_and_primals_out = fwd.call_wrapped(*fwd_in) out_tree, res_tree = out_trees() res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves]) avals_out = [raise_to_shaped(core.get_aval(x)) for x in primals_out] + # We don't need to handle any symbolic zeros on tangents_in or + # tangents_out below, because custom_lin_p is never executed and + # doesn't correspond to any custom user rule. + # TODO(frostig,mattjj): avoid instantiating zeros when we don't have to! 
+ tangents_in = map(instantiate_zeros, tangents_in) tangents_out = custom_lin_p.bind( *res, *tangents_in, num_res=res_tree.num_leaves, bwd=bwd, - out_avals=avals_out) + out_avals=avals_out, symbolic_zeros=symbolic_zeros) tangents_out = map(recast_to_float0, primals_out, tangents_out) return map(partial(JVPTracer, self), primals_out, tangents_out) @@ -745,10 +753,15 @@ def raise_custom_vjp_error_on_jvp(*_, **__): "function.") custom_lin_p.def_impl(raise_custom_vjp_error_on_jvp) -def _custom_lin_transpose(cts_out, *invals, num_res, bwd, out_avals): +def _custom_lin_transpose(cts_out, *invals, num_res, bwd, out_avals, + symbolic_zeros): res, _ = split_list(invals, [num_res]) - cts_out = map(instantiate_zeros_aval, out_avals, cts_out) + if symbolic_zeros: + cts_out = map(replace_internal_symbolic_zeros, cts_out) + else: + cts_out = map(instantiate_zeros_aval, out_avals, cts_out) cts_in = bwd(*res, *cts_out) + cts_in = map(replace_rule_output_symbolic_zeros, cts_in) return [None] * num_res + list(cts_in) primitive_transposes[custom_lin_p] = _custom_lin_transpose diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index 323104692178..c9100f141c6a 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -23,19 +23,19 @@ import jax from jax.config import config +from jax.interpreters import partial_eval as pe from jax._src import core from jax._src import source_info_util -from jax._src.core import raise_to_shaped, Trace, Tracer, AxisName -from jax._src.tree_util import (tree_unflatten, tree_flatten, - register_pytree_node) +from jax._src import linear_util as lu from jax._src.ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval, zeros_like_p, Zero, SymbolicZero, replace_rule_output_symbolic_zeros, instantiate) -from jax._src import linear_util as lu +from jax._src.core import raise_to_shaped, Trace, Tracer, AxisName +from jax._src.tree_util import (tree_unflatten, tree_flatten, + 
register_pytree_node) from jax._src.util import (unzip2, unzip3, safe_map, safe_zip, split_list, canonicalize_axis, moveaxis, as_hashable_function, curry, memoize, weakref_lru_cache) -from jax.interpreters import partial_eval as pe Array = Any map, unsafe_map = safe_map, map @@ -473,16 +473,19 @@ def todo(vals): return map(partial(BatchTracer, trace), vals, dims, srcs) return vals, todo - def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, *, out_trees): # pytype: disable=signature-mismatch + def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, *, out_trees, + symbolic_zeros): # pytype: disable=signature-mismatch in_vals, in_dims = unzip2((t.val, t.batch_dim) for t in tracers) axis_size, = {x.shape[d] for x, d in zip(in_vals, in_dims) if d is not not_mapped} + fwd_in_dims = [d for in_dim in in_dims for d in [in_dim, not_mapped]] fun, out_dims1 = batch_subtrace(fun, self.main, in_dims) - fwd, out_dims2 = batch_subtrace(fwd, self.main, in_dims) + fwd, out_dims2 = batch_subtrace(fwd, self.main, fwd_in_dims) bwd = batch_custom_vjp_bwd(bwd, self.axis_name, axis_size, out_dims2, in_dims, self.main.trace_type, self.spmd_axis_name) - out_vals = prim.bind(fun, fwd, bwd, *in_vals, out_trees=out_trees) + out_vals = prim.bind(fun, fwd, bwd, *in_vals, out_trees=out_trees, + symbolic_zeros=symbolic_zeros) fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2) if not fst: _, res_tree = out_trees() @@ -779,8 +782,14 @@ def batch_custom_jvp_subtrace(main, in_dims, *in_vals): def batch_custom_vjp_bwd(bwd, axis_name, axis_size, in_dims, out_dim_dests, main_type, spmd_axis_name): def new_bwd(*args): + in_dims_ = in_dims() if callable(in_dims) else in_dims + args = [SymbolicZero(core.mapped_aval(axis_size, dim, x.aval)) + if type(x) is SymbolicZero else x + for x, dim in zip(args, in_dims_)] + in_dims_ = [None if type(x) is SymbolicZero else d + for x, d in zip(args, in_dims_)] bwd_, out_dims_thunk = batch_subtrace(lu.wrap_init(bwd)) - bwd_ = 
_batch_outer(bwd_, axis_name, axis_size, in_dims, main_type, + bwd_ = _batch_outer(bwd_, axis_name, axis_size, in_dims_, main_type, spmd_axis_name) bwd_ = _match_axes_and_sum(bwd_, axis_size, axis_name, out_dims_thunk, out_dim_dests) @@ -797,7 +806,7 @@ def _match_axes_and_sum(axis_size, axis_name, out_dims_thunk, out_dim_dests, *in def _matchaxis_symbolic_zeros(axis_name, sz, name, src, dst, x, sum_match=False): # Just like `matchaxis`, but handles symbolic zeros using ad_util.py # TODO(mattjj): dedup with matchaxis - if isinstance(x, Zero): + if isinstance(x, (Zero, SymbolicZero)): if src == dst: return x elif type(src) == type(dst) == int: diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 983245752ac7..86c1afd0d68e 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -945,10 +945,11 @@ def process_custom_jvp_call(self, primitive, fun, jvp, tracers, *, return self.process_primitive(fake_primitive, tracers, {}) def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, - out_trees): + out_trees, symbolic_zeros): bind = HashableFunction( - lambda *args, **kwargs: primitive.bind(fun, fwd, bwd, *args, - out_trees=out_trees, **kwargs), + lambda *args, **kwargs: primitive.bind( + fun, fwd, bwd, *args, out_trees=out_trees, + symbolic_zeros=symbolic_zeros, **kwargs), (primitive, fun, fwd, bwd)) fake_primitive = FakePrimitive(multiple_results=True, bind=bind) return self.process_primitive(fake_primitive, tracers, {}) diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 476fe65f0afa..d9f4570c9f05 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -1491,11 +1491,12 @@ def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): def post_process_custom_jvp_call(self, out_tracers, _): assert False # unreachable assuming jax2tf runs with clean trace state - def process_custom_vjp_call(self, prim, fun, fwd, bwd, 
tracers, out_trees): + def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees, + symbolic_zeros): # Drop the custom differentiation rule and act like a call primitive. This # behavior is desirable because jax2tf stages code out of the JAX system, so # there are no more JAX differentiation transformations to be applied. - del fwd, bwd, out_trees # Unused. + del fwd, bwd, out_trees, symbolic_zeros # Unused. return self.process_call(core.call_p, fun, tracers, {}) def post_process_custom_vjp_call(self, out_tracers, _): diff --git a/jax/interpreters/partial_eval.py b/jax/interpreters/partial_eval.py index b6d44923a5ad..f20097fbd842 100644 --- a/jax/interpreters/partial_eval.py +++ b/jax/interpreters/partial_eval.py @@ -496,14 +496,16 @@ def process_custom_transpose(self, prim, call, tracers, **params): for t in out_tracers: t.recipe = eqn return out_tracers - def process_custom_vjp_call(self, prim, f, fwd, bwd, tracers, out_trees): + def process_custom_vjp_call(self, prim, f, fwd, bwd, tracers, out_trees, + symbolic_zeros): # TODO(mattjj): after old remat is deleted, make this method trivial. # Because we instantiate all tracers, in_knowns is all False. 
tracers = map(self.instantiate_const_abstracted, tracers) in_knowns, in_avals, () = partition_pvals([t.pval for t in tracers]) f = trace_to_subjaxpr_nounits(f, self.main, True) f, aux = partial_eval_wrapper_nounits(f, tuple(in_knowns), tuple(in_avals)) - out_flat = prim.bind(f, fwd, bwd, out_trees=out_trees) + out_flat = prim.bind(f, fwd, bwd, out_trees=out_trees, + symbolic_zeros=symbolic_zeros) out_knowns, out_avals, jaxpr, env = aux() out_consts, res = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)]) res_tracers = map(self.new_instantiated_const, res) @@ -513,8 +515,9 @@ def process_custom_vjp_call(self, prim, f, fwd, bwd, tracers, out_trees): closed_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(jaxpr), ()) @_memoize - def fwd_jaxpr_thunk(): - fwd_ = trace_to_subjaxpr_nounits(fwd, self.main, True) + def fwd_jaxpr_thunk(*zeros): + fwd_ = _interleave_fun(fwd, zeros) + fwd_ = trace_to_subjaxpr_nounits(fwd_, self.main, True) fwd_, aux = partial_eval_wrapper_nounits( fwd_, tuple(in_knowns), tuple(in_avals)) with core.new_sublevel(): @@ -531,7 +534,8 @@ def fwd_jaxpr_thunk(): dict(fun_jaxpr=closed_jaxpr, fwd_jaxpr_thunk=fwd_jaxpr_thunk, num_consts=len(res) + len(env), - bwd=bwd, out_trees=out_trees), + bwd=bwd, out_trees=out_trees, + symbolic_zeros=symbolic_zeros), jaxpr.effects, source) for t in out_tracers: t.recipe = eqn return merge_lists(out_knowns, out_tracers, out_consts) @@ -1905,23 +1909,29 @@ def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): def post_process_custom_jvp_call(self, out_tracers, _): assert False # unreachable - def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees): + def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees, + symbolic_zeros): in_avals = [t.aval for t in tracers] with core.new_sublevel(): fun_jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, self.main, in_avals) closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ()) + 
main_ = ref(self.main) - fwd_jaxpr_thunk = _memoize( - lambda: trace_to_subjaxpr_dynamic(fwd, main_(), in_avals)[::2]) + @_memoize + def fwd_jaxpr_from_zeros(*zeros): + fwd_ = _interleave_fun(fwd, zeros) + return trace_to_subjaxpr_dynamic(fwd_, main_(), in_avals)[::2] + out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals] invars = map(self.getvar, tracers) constvars = map(self.getvar, map(self.instantiate_const, consts)) outvars = map(self.makevar, out_tracers) eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim.initial_style, dict(fun_jaxpr=closed_fun_jaxpr, - fwd_jaxpr_thunk=fwd_jaxpr_thunk, + fwd_jaxpr_thunk=fwd_jaxpr_from_zeros, num_consts=len(consts), - bwd=bwd, out_trees=out_trees), + bwd=bwd, out_trees=out_trees, + symbolic_zeros=symbolic_zeros), fun_jaxpr.effects, source_info_util.current()) self.frame.add_eqn(eqn) @@ -1970,18 +1980,23 @@ def process_custom_transpose(self, prim, call, tracers, custom_staging_rules: Dict[Primitive, Callable] = {} -def _memoize(thunk): - cell = [] +@lu.transformation +def _interleave_fun(every_others, *args, **kwargs): + args_ = [x for pair in zip(args, every_others) for x in pair] + yield (yield (args_, kwargs)) + +def _memoize(fn): + cells = {} saved_state = [core.thread_local_state.trace_state.copy()] - def memoized(): - if not cell: + def memoized(*args): + if args not in cells: prev_state = core.thread_local_state.trace_state core.thread_local_state.trace_state = saved_state.pop() try: - cell.append(thunk()) + cells[args] = fn(*args) finally: core.thread_local_state.trace_state = prev_state - return cell[0] + return cells[args] return memoized # TODO(mattjj): remove this DebugInfo and helper functions, replace with diff --git a/tests/api_test.py b/tests/api_test.py index 8a3baf8ee048..e615bb678b81 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -7308,10 +7308,10 @@ def run(primal_ins, tangent_ins): primal_outs, tangent_outs = run(primal_ins, tangent_ins) primal_out1, primal_out2 = primal_outs 
tangent_out1, tangent_out2 = tangent_outs - scalar_dtype = jax.Array if maybe_jit or maybe_vmap else float - self.assertIsInstance(primal_out1, scalar_dtype) + scalar_type = jax.Array if maybe_jit or maybe_vmap else float + self.assertIsInstance(primal_out1, scalar_type) self.assertAllClose(primal_out1, 5.) - self.assertIsInstance(tangent_out1, scalar_dtype) + self.assertIsInstance(tangent_out1, scalar_type) self.assertAllClose(tangent_out1, 91.) self.assertIsInstance(primal_out2, jax.Array) self.assertArraysAllClose(primal_out2, jnp.array([7., 9.])) @@ -8389,6 +8389,183 @@ def f(x): f_vjp(jnp.array([3.])) f_vjp(jnp.array([3.])) # doesn't crash + def test_symbolic_zero_custom_vjp_basic(self): + @jax.custom_vjp + def f(x, y, z): + return x, x + + def fwd(x, y, z): + self.assertFalse(x[1]) + self.assertTrue(y[1]) + self.assertTrue(z[1]) + return (x[0], x[0]), None + + def fwd_all(x, y, z): + self.assertFalse(x[1]) + self.assertFalse(y[1]) + self.assertFalse(z[1]) + return (x[0], x[0]), None + + def bwd_all(_, g): + x1, x2 = g + self.assertFalse(type(x1) is custom_derivatives_public.SymbolicZero) + self.assertFalse(type(x2) is custom_derivatives_public.SymbolicZero) + return x1, x1, x2 + + def bwd_fst(_, g): + x1, x2 = g + self.assertFalse(type(x1) is custom_derivatives_public.SymbolicZero) + self.assertIs(type(x2), custom_derivatives_public.SymbolicZero) + return x1, x1, x2 + + def bwd_snd(_, g): + x1, x2 = g + self.assertIs(type(x1), custom_derivatives_public.SymbolicZero) + self.assertFalse(type(x2) is custom_derivatives_public.SymbolicZero) + return x1, x1, x2 + + x, y, z = 4., 5., 6. + i = np.array(7, np.int32) + zero = np.array(0.) 
+ + f.defvjp(fwd, bwd_all, symbolic_zeros=True) + h = jax.jit(f) + jax.jacrev(h)(x, y, z) + jax.jacrev(lambda x: h(x, y, z))(x) + jax.jacrev(h, argnums=(0, 1, 2), allow_int=True)(x, i, i) + + f.defvjp(fwd_all, bwd_fst, symbolic_zeros=True) + fst_f = lambda *xs: f(*xs)[0] + _, vjp = jax.vjp(fst_f, x, y, z) + _, _, gz = vjp(x) + self.assertArraysAllClose(gz, zero) + + f.defvjp(fwd_all, bwd_snd, symbolic_zeros=True) + snd_f = lambda *xs: f(*xs)[1] + _, vjp = jax.vjp(snd_f, x, y, z) + gx, gy, _ = vjp(x) + self.assertArraysAllClose(gx, zero) + self.assertArraysAllClose(gy, zero) + + f.defvjp(fwd, bwd_snd, symbolic_zeros=True) + _, vjp = jax.vjp(lambda x: snd_f(x, y, z), x) + gx, = vjp(x) + self.assertArraysAllClose(gx, zero) + + @parameterized.named_parameters( + ('jit_vmap', True, True), + ('jit', True, False), + ('vmap', False, True), + ('', False, False), + ) + def test_symbolic_zero_custom_vjp(self, maybe_jit, maybe_vmap): + # below: + # * static_scalar will be static in and out + # * static_array will be static in, but dynamic out + # * dyn_scalar and dyn_array will be dynamic in and out + + ZERO = custom_derivatives_public.SymbolicZero + + def f(static_scalar, static_array, dyn_scalar, dyn_array): + out1 = static_scalar + dyn_scalar + out2 = static_array + dyn_array + return static_scalar, static_array, out1, out2 + + def _pack(x): + return lax.broadcast(x, (1,)) + + def _unpack(x): + (x,) = x + return x + + def _vmap(fun): + def _fun(*args): + args = tree_util.tree_map(_pack, args) + out = jax.vmap(fun)(*args) + out = tree_util.tree_map(_unpack, out) + return out + return _fun + + f = api.custom_vjp(f) + + def fwd(*args): + xs, zeros = [x[0] for x in args], [x[1] for x in args] + self.assertTrue(zeros[0]) + self.assertTrue(zeros[1]) + self.assertFalse(zeros[2]) + self.assertFalse(zeros[3]) + return f(*xs), xs + + def bwd(res, g): + static_scalar, *_ = res + t_static, t_static_arr, t_dyn_scalar, t_dyn_array = g + self.assertIs(type(t_static), ZERO) + 
self.assertFalse(type(t_static_arr) is ZERO) + self.assertFalse(type(t_dyn_scalar) is ZERO) + self.assertFalse(type(t_dyn_array) is ZERO) + self.assertEqual(t_static.shape, ()) + self.assertEqual(t_static_arr.shape, (2,)) + return (static_scalar + 90, + t_static_arr + 91, + t_dyn_scalar + 92, + t_dyn_array + 93) + + f.defvjp(fwd, bwd, symbolic_zeros=True) + + def g(dyn_scalar, dyn_array): + if maybe_vmap: + f_ = _vmap(f) + else: + f_ = f + outs = f_(1., jnp.array([2., 3.]), dyn_scalar, dyn_array) + return outs[1:] + + def run(primal_ins, cotangent_outs): + primal_outs, vjp = jax.vjp(g, *primal_ins) + cotangent_ins = vjp(cotangent_outs) + return primal_outs, cotangent_ins + + if maybe_jit: + run = jax.jit(run) + + scalar_type = jax.Array if maybe_jit or maybe_vmap else float + primal_ins = (4., jnp.array([5., 6.])) + cotangent_outs = (jnp.array([10., 11.]), 7., jnp.array([8., 9.])) + primal_outs, cotangent_ins = run(primal_ins, cotangent_outs) + + primal_out1, primal_out2, primal_out3 = primal_outs + self.assertIsInstance(primal_out1, jax.Array) + self.assertAllClose(primal_out1, jnp.array([2., 3.])) + self.assertIsInstance(primal_out2, scalar_type) + self.assertAllClose(primal_out2, 5.) + self.assertIsInstance(primal_out3, jax.Array) + self.assertAllClose(primal_out3, jnp.array([7., 9.])) + + ct_in1, ct_in2 = cotangent_ins + self.assertIsInstance(ct_in1, scalar_type) + self.assertAllClose(ct_in1, 99.) 
+ self.assertIsInstance(ct_in2, jax.Array) + self.assertArraysAllClose(ct_in2, jnp.array([101., 102.])) + + def test_symbolic_zero_custom_vjp_vmap_output(self): + @api.custom_vjp + def f(x, y): + return x, y + + def fwd(x, y): + (x, x0), (y, y0) = x, y + self.assertFalse(x0) + self.assertTrue(y0) + return f(x, y), None + + def bwd(_, g): + ct_x, ct_y = g + #import ipdb; ipdb.set_trace() + self.assertIs(type(ct_y), custom_derivatives_public.SymbolicZero) + return g + + f.defvjp(fwd, bwd, symbolic_zeros=True) + jax.grad(lambda x, y: jax.vmap(f)(x, y)[0].sum())(jnp.ones(3), jnp.ones(3)) def transpose_unary(f, x_example): def transposed(y): From ce3f5343477c1a9d20b139638f077ac71a5d43c9 Mon Sep 17 00:00:00 2001 From: George Necula Date: Tue, 21 Mar 2023 06:01:10 +0100 Subject: [PATCH 17/65] [jax2tf] Fix grad of pjit in native lowering. Since jax2tf.convert is called recursively for the purpose of serializing the vjp function, we must ensure that if the primal function is a pjit with shardings then the vjp function is also converted as a pjit. Without this fix the serialization with gradients of a pjit function will fail with an error that there are shardings but no pjit at the top level. --- jax/_src/pjit.py | 2 +- jax/experimental/jax2tf/jax2tf.py | 89 ++++++++++++++----- jax/experimental/jax2tf/tests/jax2tf_test.py | 2 +- .../jax2tf/tests/sharding_test.py | 84 ++++++++++++++++- 4 files changed, 150 insertions(+), 27 deletions(-) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index a6b4a77f10a8..20a5ab0c7249 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -784,7 +784,7 @@ def flatten_axis_resources(what, tree, shardings, tupled_args): axis_tree = shardings - # Because ecause we only have the `tree` treedef and not the full pytree here, + # Because we only have the `tree` treedef and not the full pytree here, # we construct a dummy tree to compare against. Revise this in callers?
dummy_tree = tree_unflatten(tree, [PytreeLeaf()] * tree.num_leaves) errors = prefix_errors(axis_tree, dummy_tree) diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index d9f4570c9f05..53bedbf1057a 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -430,16 +430,15 @@ def converted_fun_tf(*args_tf: TfVal, **kwargs_tf: TfVal) -> TfVal: lowering_platform = native_serialization_platforms[0] else: lowering_platform = None - exported: Exported = serialize_native( + exported: Optional[Exported] = serialize_native( fun_flat_jax, args_avals_flat, lowering_platform=lowering_platform, strict_checks=native_serialization_strict_checks) - def run_fun_flat_as_tf( args_flat_tf: Sequence[TfVal] ) -> Tuple[Tuple[TfVal, ...], Tuple[core.ShapedArray, ...]]: outs_tf, out_avals = run_exported_as_tf( - args_avals_flat, args_flat_tf, exported, + args_avals_flat, args_flat_tf, exported, # type: ignore native_serialization_strict_checks) return outs_tf, out_avals else: @@ -448,6 +447,7 @@ def run_fun_flat_as_tf( dim_values, _ = _interpret_fun_jax(get_dim_values_jax, args_flat_tf, args_avals_flat, name_stack) shape_env = zip(dim_vars, dim_values) # type: ignore + exported = None def run_fun_flat_as_tf( args_flat_tf: Sequence[TfVal] ) -> Tuple[Tuple[TfVal, ...], Tuple[core.ShapedArray, ...]]: @@ -477,7 +477,7 @@ def run_fun_flat_as_tf( def converted_fun_flat_with_custom_gradient_tf(*args_flat_tf: TfVal) -> TfVal: outs_tf, out_avals = run_fun_flat_as_tf(args_flat_tf) return (tuple(outs_tf), - make_custom_gradient_fn_tf( + _make_custom_gradient_fn_tf( fun_flat_jax=fun_flat_jax, args_flat_tf=args_flat_tf, args_avals_flat=args_avals_flat, @@ -485,7 +485,8 @@ def converted_fun_flat_with_custom_gradient_tf(*args_flat_tf: TfVal) -> TfVal: out_avals=out_avals, native_serialization=native_serialization, native_serialization_platforms=native_serialization_platforms, - 
native_serialization_strict_checks=native_serialization_strict_checks)) + native_serialization_strict_checks=native_serialization_strict_checks, + exported_primal=exported)) out_flat_tf = converted_fun_flat_with_custom_gradient_tf(*args_flat_tf) else: @@ -599,17 +600,19 @@ def preprocess_arg_tf(arg_idx: int, return arg_tf, arg_aval -# Prepare the grad_fn for tf.custom_gradient. -def make_custom_gradient_fn_tf(*, - fun_flat_jax: Callable, - args_flat_tf: Sequence[TfVal], - polymorphic_shapes_flat: Sequence[str], - args_avals_flat: Sequence[core.ShapedArray], - out_avals: Sequence[core.ShapedArray], - native_serialization: Union[str, bool], - native_serialization_platforms: Sequence[str], - native_serialization_strict_checks: bool -): +def _make_custom_gradient_fn_tf(*, + fun_flat_jax: Callable, + args_flat_tf: Sequence[TfVal], + polymorphic_shapes_flat: Sequence[str], + args_avals_flat: Sequence[core.ShapedArray], + out_avals: Sequence[core.ShapedArray], + native_serialization: Union[str, bool], + native_serialization_platforms: Sequence[str], + native_serialization_strict_checks: bool, + exported_primal: Optional["Exported"]): + """Prepares the TF function to be used with tf.custom_gradient. + + """ def grad_fn_tf(*out_cts_flat_tf: TfVal, variables=None): @@ -659,6 +662,45 @@ def fix_in_ct(in_ct_jax, arg_aval: core.ShapedArray): in_cts_fixed_flat_jax = tuple(map(fix_in_ct, in_cts_flat_jax, args_avals_flat)) return in_cts_fixed_flat_jax + if exported_primal is not None: + # Native lowering + all_in_shardings = [pxla._UNSPECIFIED] * len(exported_primal.in_avals) + for idx, in_s in zip(sorted(exported_primal.module_kept_var_idx), + exported_primal.in_shardings): + all_in_shardings[idx] = in_s # type: ignore + all_shardings = all_in_shardings + list(exported_primal.out_shardings) + # We cannot mix unspecified and specified shardings. 
Make the unspecified + # ones replicated + specified_shardings = [ + s for s in all_shardings if not pxla._is_unspecified(s)] + if 0 < len(specified_shardings) < len(all_shardings): + # There are some specified, but not all + in_s = specified_shardings[0] # pjit will enforce that all have same devices + assert isinstance(in_s, sharding.XLACompatibleSharding) + replicated_s = sharding.GSPMDSharding.get_replicated(in_s._device_assignment) + all_shardings = [ + s if not pxla._is_unspecified(s) else replicated_s + for s in all_shardings] + # Since fun_vjp_jax takes two tuples of arguments we must split the in_shardings + vjp_in_args_shardings, vjp_in_out_ct_shardings = util.split_list(all_shardings, + [len(exported_primal.in_avals)]) + # pjit front-end does not like all-unspecified + if all(pxla._is_unspecified(s) for s in vjp_in_args_shardings): + vjp_in_args_shardings = pxla._UNSPECIFIED + else: + vjp_in_args_shardings = tuple(vjp_in_args_shardings) + if all(pxla._is_unspecified(s) for s in vjp_in_out_ct_shardings): + vjp_in_out_ct_shardings = pxla._UNSPECIFIED + else: + vjp_in_out_ct_shardings = tuple(vjp_in_out_ct_shardings) + + if pxla._is_unspecified(vjp_in_args_shardings) and pxla._is_unspecified(vjp_in_args_shardings): + vjp_in_shardings = pxla._UNSPECIFIED + else: + vjp_in_shardings = (vjp_in_args_shardings, vjp_in_out_ct_shardings) + fun_vjp_jax = pjit.pjit(fun_vjp_jax, + in_shardings=vjp_in_shardings, + out_shardings=vjp_in_args_shardings) # TODO: enable higher-order gradients with tf.name_scope("jax2tf_vjp"): in_cts_flat = convert( @@ -707,15 +749,16 @@ class Exported: """Represents a lowered and serialized module.""" in_avals: Sequence[core.ShapedArray] out_avals: Sequence[core.ShapedArray] - in_shardings: Optional[Sequence[Any]] - out_shardings: Optional[Sequence[Any]] + # The in_shardings reflect only the module_ket_var_idx + in_shardings: Sequence[Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue]] + out_shardings: 
Sequence[Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue]] lowering_platform: str # One of "tpu", "cpu", "cuda", "rocm" mlir_module: mlir.ir.Module mlir_module_serialized: bytes # VHLO bytecode format xla_call_module_version: int # Follows the versions of XlaCallModule - module_kept_var_idx: Sequence[bool] # Specifies if an argument is kept in the - # lowering. As long as `out_avals`. + module_kept_var_idx: Sequence[int] # Specifies if an argument is kept in the + # lowering. As long as `out_avals`. dim_args_spec: Sequence[str] def serialize_native(fun_jax: Callable, @@ -767,7 +810,7 @@ def serialize_native(fun_jax: Callable, raise NotImplementedError("host_callbacks are not yet implemented for the jax2tf native lowering") if "kept_var_idx" in lowered.compile_args: - module_kept_var_idx = lowered.compile_args["kept_var_idx"] + module_kept_var_idx = tuple(sorted(lowered.compile_args["kept_var_idx"])) else: # For pmap module_kept_var_idx = tuple(range(len(args_avals))) @@ -837,8 +880,8 @@ def serialize_native(fun_jax: Callable, return Exported( in_avals=args_avals, out_avals=out_avals, - in_shardings=lowered.compile_args.get("in_shardings"), - out_shardings=lowered.compile_args.get("out_shardings"), + in_shardings=lowered.compile_args["in_shardings"], + out_shardings=lowered.compile_args["out_shardings"], lowering_platform=lowering_platform or default_jax_backend(), mlir_module=mlir_module, mlir_module_serialized=mlir_module_serialized, diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py index 1d719e679faa..79fa468b2694 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_test.py +++ b/jax/experimental/jax2tf/tests/jax2tf_test.py @@ -1101,7 +1101,7 @@ def test_error_disallowed_custom_call(self): "Cannot serialize code with custom calls whose targets .*"): jax2tf.convert( lambda a, b: jax.lax.linalg.triangular_solve(a, b, left_side=True), - experimental_native_lowering=True)(a, b) + 
native_serialization=True)(a, b) def test_op_metadata_simple(self): self.skipTest("include_xla_op_metadata not yet enabled") diff --git a/jax/experimental/jax2tf/tests/sharding_test.py b/jax/experimental/jax2tf/tests/sharding_test.py index 4d8fab47c831..5ad24d058bdd 100644 --- a/jax/experimental/jax2tf/tests/sharding_test.py +++ b/jax/experimental/jax2tf/tests/sharding_test.py @@ -358,6 +358,58 @@ def f_jax(x): # x: f32[10, 20] (r"custom_call_target.*Sharding", 2 + count_inner_sharding) ]) + + @parameterized.named_parameters( + dict(testcase_name=f"_in_shardings={in_shardings}_out_shardings={out_shardings}", + in_shardings=in_shardings, out_shardings=out_shardings) + for in_shardings in ("missing", None, "P") + for out_shardings in ("missing", None, "P") + ) + @jtu.with_mesh([("x", 2)]) + def test_grad_pjit(self, in_shardings="missing", out_shardings="None"): + def f_jax(x): # x: f32[10,20] -> f32[20,10] + return jnp.sin(x.T) + + pjit_kwargs = {} + if in_shardings != "missing": + pjit_kwargs["in_shardings"] = (P(None, "x") if in_shardings == "P" else None) + if out_shardings != "missing": + pjit_kwargs["out_shardings"] = (P("x", None) if out_shardings == "P" else None) + f_jax = pjit.pjit(f_jax, **pjit_kwargs) + x_shape = (10, 20) + x = np.arange(np.prod(x_shape), dtype=np.float32).reshape(x_shape) + + def f_grad_tf(x_v, res_ct): + with tf.GradientTape(persistent=True) as tape: + tape.watch(x_v) + res_tf = jax2tf.convert(f_jax)(x_v) + return tape.gradient(res_tf, x_v, output_gradients=res_ct) + + # Annotation count for the primal input and the grad output + count_in_P = self.GEQ(2) if in_shardings == "P" else 0 + if config.jax2tf_default_native_serialization: + # With native serialization even unspecified in_shardings turn into replicated + count_in_replicated = self.GEQ(2) if in_shardings in [None, "missing"] else 0 + else: + count_in_replicated = self.GEQ(2) if in_shardings is None else 0 + # Annotation count for the contangent input + count_out_P = self.GEQ(1) 
if out_shardings == "P" else 0 + if config.jax2tf_default_native_serialization: + # With native serialization even unspecified in_shardings turn into replicated + count_out_replicated = self.GEQ(1) if out_shardings in [None, "missing"] else 0 + else: + count_out_replicated = self.GEQ(1) if out_shardings is None else 0 + + self.check_sharding(f_grad_tf, [x, x.T], + checks=[ + # The input primal argument, and the output grad + (r"f32\[10,20\].*custom_call_target.*Sharding.*sharding.*devices=\[1,2\]", count_in_P), + (r"f32\[10,20\].*custom_call_target.*Sharding.*sharding.*replicated", count_in_replicated), + # The primal result, and the input cotangent + (r"f32\[20,10\].*custom_call_target.*Sharding.*sharding.*devices=\[2,1\]", count_out_P), + (r"f32\[20,10\].*custom_call_target.*Sharding.*sharding.*replicated", count_out_replicated), + ]) + @parameterized.named_parameters( dict(testcase_name=f"_kind={kind}_in_shardings={in_shardings}_out_shardings={out_shardings}", kind=kind, in_shardings=in_shardings, out_shardings=out_shardings) @@ -460,8 +512,8 @@ def test_xmap_basic(self): bshape = (2, 7) b = np.arange(np.prod(bshape), dtype=np.float32).reshape(bshape) - # f32[16,8,5], f32[2,7] -> f32[16,8,10], f32[2,28] - # f_jax: f32[5], f32[7] -> f32[10], f32[28] + # f_jax: f32[16,8,5], f32[2,7] -> f32[16,8,10], f32[2,28] + # lambda ...: f32[5], f32[7] -> f32[10], f32[28] f_jax = xmap(lambda a, b: (jnp.concatenate([a, a], axis=0) * 2., jnp.concatenate([b, b, b, b], axis=0) * 4.), in_axes=({0: 'a', 1: 'b'}, ['c', ...]), @@ -535,6 +587,34 @@ def f_tf(a, b): (r"f32\[8,5\].*custom_call_target.*Sharding.*sharding.*devices=\[2,1\]", 1), ]) + def test_grad_xmap(self): + devices = np.reshape(self.devices, (1, 2)) + ashape = (16, 8, 5) + a = np.arange(np.prod(ashape), dtype=np.float32).reshape(ashape) + + # f_jax: f32[16,8,5]-> f32[16,8,10] + # lambda ...: f32[5]-> f32[10] + f_jax = xmap(lambda a: jnp.concatenate([a, a], axis=0) * 2., + in_axes=({0: 'a', 1: 'b'}), + out_axes={0: 'a', 
1: 'b'}, + axis_resources={'a': 'x', 'b': 'y'}) + + def f_grad_tf(a, res_ct): + with tf.GradientTape(persistent=True) as tape: + tape.watch(a) + res_tf = jax2tf.convert(f_jax, native_serialization=True)(a) + return tape.gradient(res_tf, a, output_gradients=res_ct) + + + with Mesh(devices, ('x', 'y')): + self.check_sharding(f_grad_tf, [a, np.concatenate([a, a], axis=2)], + checks=[ + # Primal input and grad output + (r"f32\[16,8,5\].*custom_call_target.*Sharding.*sharding.*devices=\[1,2,1\]", self.GEQ(2)), + # Input cotangent + (r"f32\[16,8,10\].*custom_call_target.*Sharding.*sharding.*devices=\[1,2,1\]", self.GEQ(1)), + ]) + @jtu.ignore_warning(category=UserWarning, message="all_to_all .* are only implemented properly for TPUs and GPUs .*") def test_shmap_all_to_all(self): From 4aa8ae941b175a31f9001b342e123fac72d1cec4 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 22 Mar 2023 08:18:05 -0700 Subject: [PATCH 18/65] Fix mypy failures in jax2tf. PiperOrigin-RevId: 518572905 --- jax/experimental/jax2tf/jax2tf.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 53bedbf1057a..4afed82bc698 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -682,6 +682,9 @@ def fix_in_ct(in_ct_jax, arg_aval: core.ShapedArray): s if not pxla._is_unspecified(s) else replicated_s for s in all_shardings] # Since fun_vjp_jax takes two tuples of arguments we must split the in_shardings + vjp_in_args_shardings: Any + vjp_in_out_ct_shardings: Any + vjp_in_shardings: Any vjp_in_args_shardings, vjp_in_out_ct_shardings = util.split_list(all_shardings, [len(exported_primal.in_avals)]) # pjit front-end does not like all-unspecified From f4a40dc6c74cea625b1e2e971b3d54963b14b51c Mon Sep 17 00:00:00 2001 From: jiayaobo Date: Tue, 21 Mar 2023 11:13:40 +0800 Subject: [PATCH 19/65] add wald random generator add wald to random.py --- docs/jax.random.rst | 1 + jax/_src/random.py | 56 
++++++++++++++++++++++++++++++++++++++++++++ jax/random.py | 1 + tests/random_test.py | 15 ++++++++++++ 4 files changed, 73 insertions(+) diff --git a/docs/jax.random.rst b/docs/jax.random.rst index 69981e4747f6..1d7f68e39e9c 100644 --- a/docs/jax.random.rst +++ b/docs/jax.random.rst @@ -49,5 +49,6 @@ List of Available Functions t truncated_normal uniform + wald weibull_min diff --git a/jax/_src/random.py b/jax/_src/random.py index d594fc4317a7..c163d6df3c0d 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -1880,3 +1880,59 @@ def _rayleigh(key, scale, shape, dtype) -> Array: sqrt_u = lax.sqrt(lax.mul(log_u, n_two)) ray = lax.mul(scale, sqrt_u) return ray + +def wald(key: KeyArray, + mean: RealArray, + scale: RealArray, + shape: Optional[Shape] = None, + dtype: DTypeLikeFloat = dtypes.float_) -> Array: + """Sample Wald random values with given shape and float dtype. + + Args: + key: a PRNG key used as the random key. + mean: a float or array of floats broadcast-compatible with ``shape`` + representing the mean parameter of the distribution. + scale: a float or array of floats broadcast-compatible with ``shape`` + representing the scale parameter of the distribution. + shape: optional, a tuple of nonnegative integers specifying the result + shape. Must be broadcast-compatible with ``mean`` and ``scale``. The default + (None) produces a result shape equal to ``lax.broadcast_shapes(np.shape(mean), np.shape(scale))``. + dtype: optional, a float dtype for the returned values (default float64 if + jax_enable_x64 is true, otherwise float32). + + Returns: + A random array with the specified dtype and with shape given by ``shape`` if + ``shape`` is not None, or else by ``mean.shape`` and ``scale.shape``. 
+ """ + key, _ = _check_prng_key(key) + if not dtypes.issubdtype(dtype, np.floating): + raise ValueError("dtype argument to `wald` must be a float " + f"dtype, got {dtype}") + dtype = dtypes.canonicalize_dtype(dtype) + if shape is not None: + shape = core.canonicalize_shape(shape) + return _wald(key, mean, scale, shape, dtype) + +@partial(jit, static_argnums=(3, 4), inline=True) +def _wald(key, mean, scale, shape, dtype) -> Array: + if shape is None: + shape = lax.broadcast_shapes(np.shape(mean), np.shape(scale)) + else: + _check_shape("wald", shape, np.shape(mean), np.shape(scale)) + k1, k2 = _split(key, 2) + mean = mean.astype(dtype) + scale = scale.astype(dtype) + mean = jnp.broadcast_to(mean, shape) + scale = jnp.broadcast_to(scale, shape) + v = normal(k1, shape, dtype) + z = uniform(k2, shape, dtype) + two = _lax_const(mean, 2) + y = lax.integer_pow(v, 2) + y_sq = lax.integer_pow(y, 2) + mean_sq = lax.integer_pow(mean, 2) + mean_two = lax.mul(mean, two) + scale_two = lax.mul(scale, two) + sqrt_term = lax.sqrt(mean_two * scale_two * y + mean_sq * y_sq) + x = mean + mean_sq * y / scale_two - mean / scale_two * sqrt_term + w = lax.select(lax.le(z, mean / (mean + x)), x, mean_sq / x) + return w diff --git a/jax/random.py b/jax/random.py index 5e6b6e179222..8a84098f753c 100644 --- a/jax/random.py +++ b/jax/random.py @@ -189,5 +189,6 @@ truncated_normal as truncated_normal, uniform as uniform, unsafe_rbg_key as unsafe_rbg_key, + wald as wald, weibull_min as weibull_min, ) diff --git a/tests/random_test.py b/tests/random_test.py index cba3ccf4835d..9245e832e47d 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -1537,6 +1537,21 @@ def testRayleigh(self, scale, dtype): for samples in [uncompiled_samples, compiled_samples]: self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.rayleigh(scale=scale).cdf) + @jtu.sample_product( + mean= [0.2, 1., 2., 10. ,100.], + scale= [0.2, 1., 2., 10. 
,100.], + dtype=jtu.dtypes.floating) + def testWald(self, mean, scale, dtype): + key = self.seed_prng(0) + rand = lambda key: random.wald(key, mean, scale, shape = (10000, ), dtype = dtype) + crand = jax.jit(rand) + + uncompiled_samples = rand(key) + compiled_samples = crand(key) + + for samples in [uncompiled_samples, compiled_samples]: + self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.invgauss(mu=mean / scale, scale = scale).cdf) + class KeyArrayTest(jtu.JaxTestCase): # Key arrays involve: # * a Python key array type, backed by an underlying uint32 "base" array, From 92e79b36cec14112a5f320847efbf67d3d5f3968 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 22 Mar 2023 09:55:19 -0700 Subject: [PATCH 20/65] Revert: `custom_vjp` symbolic zeros support PiperOrigin-RevId: 518597609 --- jax/_src/checkify.py | 11 +- jax/_src/core.py | 3 +- jax/_src/custom_derivatives.py | 115 ++++++------------- jax/_src/interpreters/ad.py | 31 ++--- jax/_src/interpreters/batching.py | 29 ++--- jax/_src/interpreters/pxla.py | 7 +- jax/experimental/jax2tf/jax2tf.py | 5 +- jax/interpreters/partial_eval.py | 47 +++----- tests/api_test.py | 183 +----------------------------- 9 files changed, 81 insertions(+), 350 deletions(-) diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index e250e240c18f..80ac676d9d23 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -915,8 +915,7 @@ def jvp(*xs): error_checks[custom_derivatives.custom_jvp_call_p] = custom_jvp_call_rule def custom_vjp_call_jaxpr_rule(in_err, enabled_errors, *in_vals, fun_jaxpr, - fwd_jaxpr_thunk, num_consts, bwd, out_trees, - symbolic_zeros): + fwd_jaxpr_thunk, num_consts, bwd, out_trees): err_vals, err_tree = jtu.tree_flatten(in_err) fun = lu.wrap_init( functools.partial(checkify_jaxpr_flat, fun_jaxpr.jaxpr, @@ -924,17 +923,15 @@ def custom_vjp_call_jaxpr_rule(in_err, enabled_errors, *in_vals, fun_jaxpr, fun, fun_metadata = _flatten_and_get_error_metadata_thunk(fun) @lu.wrap_init - def fwd(*args): + 
def fwd(*xs): # TODO(lenamartens, sharadmv): why not checkify here? - xs, zeros = args[::2], args[1::2] - fwd_jaxpr, fwd_consts = fwd_jaxpr_thunk(*zeros) + fwd_jaxpr, fwd_consts = fwd_jaxpr_thunk() xs_without_consts = xs[num_consts:] return core.eval_jaxpr(fwd_jaxpr, fwd_consts, *xs_without_consts) fwd, fwd_out_tree = flatten_fun_output(fwd) all_outs = custom_derivatives.custom_vjp_call_p.bind( - fun, fwd, bwd, *err_vals, *in_vals, out_trees=out_trees, - symbolic_zeros=symbolic_zeros) + fun, fwd, bwd, *err_vals, *in_vals, out_trees=out_trees) fst, out_metadata = lu.merge_linear_aux(fun_metadata, fwd_out_tree) if fst: err_and_out_tree, _ = out_metadata diff --git a/jax/_src/core.py b/jax/_src/core.py index d16e0f4baef2..95d2f38e4d4a 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -511,8 +511,7 @@ def process_custom_transpose(self, prim, call, tracers, **params): "to handle custom_transpose_call primitives") raise NotImplementedError(msg) - def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, - out_trees, symbolic_zeros): + def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, out_trees): msg = (f"{type(self)} must override process_custom_vjp_call " "to handle custom_vjp primitives") raise NotImplementedError(msg) diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index 66c527081001..e93132a7806f 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import dataclasses from functools import update_wrapper, reduce, partial import inspect from typing import (Callable, Generic, List, Optional, Sequence, Tuple, TypeVar, Any) @@ -32,8 +31,8 @@ from jax._src import effects from jax._src import linear_util as lu from jax._src import traceback_util -from jax._src.ad_util import ( - stop_gradient_p, SymbolicZero, Zero, zeros_like_aval) +from jax._src.ad_util import (Zero, SymbolicZero, zeros_like_aval, + stop_gradient_p) from jax._src.api_util import argnums_partial, flatten_fun_nokwargs from jax._src.core import raise_to_shaped from jax._src.interpreters import ad @@ -164,13 +163,12 @@ def defjvp(self, and the second element is the tangent output. Elements of the input and output tuples may be arrays or any nested tuples/lists/dicts thereof. symbolic_zeros: boolean, indicating whether the rule should be passed - objects representing static symbolic zeros in its tangent argument in - correspondence with unperturbed values; otherwise, only standard JAX - types (e.g. array-likes) are passed. Setting this option to ``True`` - allows a JVP rule to detect whether certain inputs are not involved in - differentiation, but at the cost of needing special handling for these - objects (which e.g. can't be passed into jax.numpy functions). Default - ``False``. + objects representing static symbolic zeros in its tangent tuple + argument; otherwise, only standard JAX types (e.g. array-likes) are + passed. Setting this option to True allows a JVP rule to detect whether + certain inputs are not involved in differentiation, but at the cost of + needing special handling for these objects (which e.g. can't be passed + into jax.numpy functions). Default False. Returns: None. 
@@ -482,15 +480,12 @@ def __init__(self, self.nondiff_argnums = nondiff_argnums self.fwd: Optional[Callable[..., Tuple[ReturnValue, Any]]] = None self.bwd: Optional[Callable[..., Tuple[Any, ...]]] = None - self.symbolic_zeros = False __getattr__ = custom_api_util.forward_attr def defvjp(self, fwd: Callable[..., Tuple[ReturnValue, Any]], - bwd: Callable[..., Tuple[Any, ...]], - symbolic_zeros: bool = False, - ) -> None: + bwd: Callable[..., Tuple[Any, ...]]) -> None: """Define a custom VJP rule for the function represented by this instance. Args: @@ -511,27 +506,6 @@ def defvjp(self, function, and the tuple elements may be arrays or nested tuples/lists/dicts thereof so as to match the structure of the primal input arguments. - symbolic_zeros: boolean, indicating whether to indicate symbolic zeros in - the ``fwd`` and ``bwd`` rules. Setting this option to ``True`` allows - custom derivative rules to detect when certain inputs, and when certain - cotangent outputs, are not involved in differentiation. If ``True``: - - * ``fwd`` must accept, for each leaf value ``x`` in the pytree - comprising an argument to the original function, a pair ``(x, zero)``, - where ``x`` is the original argument and ``zero`` is a boolean. The - ``zero`` part indicates whether or not the argument is not involved in - differentiation (i.e., whether the corresponding Jacobian "column" is - zero). - - * ``bwd`` will be passed objects representing static symbolic zeros in - its cotangent argument in correspondence with unperturbed values; - otherwise, only standard JAX types (e.g. array-likes) are passed. - - Setting this option to ``True`` allows these rules to detect whether - certain inputs and outputs are not involved in differentiation, but at - the cost of special handling: the signature of ``fwd`` changes, and - ``bwd`` receives objects that, for instance, cannot be passed to - ``jax.numpy`` functions. Default ``False``. Returns: None. 
@@ -553,7 +527,6 @@ def f_bwd(res, g): """ self.fwd = fwd self.bwd = bwd - self.symbolic_zeros = symbolic_zeros @traceback_util.api_boundary def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable=invalid-annotation @@ -561,7 +534,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable if not self.fwd or not self.bwd: msg = f"No VJP defined for custom_vjp function {primal_name} using defvjp." raise AttributeError(msg) - fwd_name = getattr(self.fwd, '__name__', str(self.fwd)) + fwd_name = getattr(self.fwd, '__name__', str(self.fwd)) args = _resolve_kwargs(self.fun, args, kwargs) if config.jax_enable_custom_vjp_by_custom_transpose: if self.nondiff_argnums: @@ -582,7 +555,6 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable else: f_, dyn_args = lu.wrap_init(self.fun), args fwd, bwd = lu.wrap_init(self.fwd), lu.wrap_init(self.bwd) - fwd = _project_fwd(fwd, self.symbolic_zeros) args_flat, in_tree = tree_flatten(dyn_args) in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat] flat_fun, out_type = _flatten_fun_nokwargs(f_, in_tree) @@ -590,22 +562,10 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable out_type) flat_bwd = _flatten_bwd(bwd, in_tree, in_avals, out_trees).call_wrapped out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd, - *args_flat, out_trees=out_trees, - symbolic_zeros=self.symbolic_zeros) + *args_flat, out_trees=out_trees) _, (out_tree, _) = lu.merge_linear_aux(out_type, out_trees) return tree_unflatten(out_tree, out_flat) -@dataclasses.dataclass -class ZeroTagged: - val: Any - zero: bool - -@lu.transformation -def _project_fwd(symbolic_zeros, *args, **kwargs): - project_leaf = ((lambda x: (x.val, x.zero)) if symbolic_zeros else - (lambda x: x.val)) - yield (yield tree_map(project_leaf, (args, kwargs))) - def _check_for_tracers(x): for leaf in tree_leaves(x): if isinstance(x, core.Tracer): @@ -619,10 +579,8 @@ def 
_check_for_tracers(x): raise UnexpectedTracerError(msg) @lu.transformation_with_aux -def _flatten_fwd(primal_name, fwd_name, in_tree, maybe_out_type, - *args): - tagged_args = [ZeroTagged(x, z) for x, z in zip(args[::2], args[1::2])] - py_args = tree_unflatten(in_tree, tagged_args) +def _flatten_fwd(primal_name, fwd_name, in_tree, maybe_out_type, *args): + py_args = tree_unflatten(in_tree, args) pair_out = yield py_args, {} if not isinstance(pair_out, (list, tuple)) or len(pair_out) != 2: msg = (f"Custom VJP fwd rule {fwd_name} for function {primal_name} " @@ -714,7 +672,7 @@ def _flatten_bwd(in_tree, in_avals, out_trees, *args): class CustomVJPCallPrimitive(core.CallPrimitive): initial_style: core.Primitive - def bind(self, fun, fwd, bwd, *args, out_trees, symbolic_zeros): + def bind(self, fun, fwd, bwd, *args, out_trees): args = map(core.full_lower, args) top_trace = core.find_top_trace(args) fun, env_trace_todo1 = process_env_traces( @@ -724,8 +682,7 @@ def bind(self, fun, fwd, bwd, *args, out_trees, symbolic_zeros): tracers = map(top_trace.full_raise, args) # type: ignore bwd_ = lambda *args: bwd(*args) outs = top_trace.process_custom_vjp_call(self, fun, fwd, bwd_, tracers, - out_trees=out_trees, - symbolic_zeros=symbolic_zeros) + out_trees=out_trees) fst, env_trace_todo = lu.merge_linear_aux(env_trace_todo1, env_trace_todo2) if fst: return core.apply_todos(env_trace_todo, map(core.full_lower, outs)) @@ -791,34 +748,31 @@ def _custom_vjp_call_jaxpr_abstract_eval(*_, fun_jaxpr, **__): def _custom_vjp_call_jaxpr_jvp( primals, tangents, *, fun_jaxpr: core.ClosedJaxpr, - fwd_jaxpr_thunk: Callable[..., Tuple[core.Jaxpr, Sequence[Any]]], - num_consts: int, bwd: Callable, out_trees: Callable, symbolic_zeros: bool): + fwd_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]], + bwd: Callable, out_trees: Callable, num_consts: int): _, args = split_list(primals, [num_consts]) consts_dot, args_dot = split_list(tangents, [num_consts]) if any(type(t) is not Zero for t 
in consts_dot): raise ad.CustomVJPException() - zeros = [type(t) is Zero for t in args_dot] - fwd_jaxpr, fwd_consts = fwd_jaxpr_thunk(*zeros) # consts can be tracers! + fwd_jaxpr, fwd_consts = fwd_jaxpr_thunk() # consts can be tracers! out_tree, res_tree = out_trees() - res_and_primals_out = core.eval_jaxpr(fwd_jaxpr, fwd_consts, *args) - res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves]) - avals_out = [raise_to_shaped(core.get_aval(x)) for x in primals_out] args_dot = map(ad.instantiate_zeros, args_dot) # Cast float0 to zeros with the primal dtype because custom vjp rules don't # currently handle float0s args_dot = map(ad.replace_float0s, args, args_dot) + res_and_primals_out = core.eval_jaxpr(fwd_jaxpr, fwd_consts, *args) + res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves]) + avals_out = [raise_to_shaped(core.get_aval(x)) for x in primals_out] tangents_out = ad.custom_lin_p.bind( - *res, *args_dot, num_res=res_tree.num_leaves, bwd=bwd, - out_avals=avals_out, symbolic_zeros=symbolic_zeros) + *res, *args_dot, num_res=res_tree.num_leaves, bwd=bwd, out_avals=avals_out) tangents_out = map(ad.recast_to_float0, primals_out, tangents_out) return primals_out, tangents_out ad.primitive_jvps[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_jvp -def _custom_vjp_call_jaxpr_vmap( - spmd_axis_name, axis_size, axis_name, main_type, args, in_dims, *, - fun_jaxpr: core.ClosedJaxpr, - fwd_jaxpr_thunk: Callable[..., Tuple[core.Jaxpr, Sequence[Any]]], - num_consts: int, bwd: Callable, out_trees: Callable, symbolic_zeros: bool): +def _custom_vjp_call_jaxpr_vmap(spmd_axis_name, + axis_size, axis_name, main_type, args, in_dims, *, fun_jaxpr: core.ClosedJaxpr, + fwd_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]], + bwd: Callable, out_trees: Callable, num_consts: int): args = [batching.moveaxis(x, d, 0) if d is not not_mapped and d != 0 else x for x, d in zip(args, in_dims)] @@ -831,8 +785,8 @@ def _custom_vjp_call_jaxpr_vmap( 
out_dims2 = [] @pe._memoize - def batched_fwd_jaxpr_thunk(*zeros): - fwd_jaxpr = core.ClosedJaxpr(*fwd_jaxpr_thunk(*zeros)) # consts can be tracers + def batched_fwd_jaxpr_thunk(): + fwd_jaxpr = core.ClosedJaxpr(*fwd_jaxpr_thunk()) # consts can be tracers batched_fwd_jaxpr, out_batched = batching.batch_jaxpr( fwd_jaxpr, axis_size, args_batched, False, axis_name, spmd_axis_name, main_type) @@ -841,20 +795,17 @@ def batched_fwd_jaxpr_thunk(*zeros): fwd_args_batched = [0 if b else not_mapped for b in args_batched] fwd_out_dims = lambda: out_dims2[0] - batched_bwd = batching.batch_custom_vjp_bwd( - bwd, axis_name, axis_size, fwd_out_dims, fwd_args_batched, main_type, - spmd_axis_name) + batched_bwd = batching.batch_custom_vjp_bwd(bwd, axis_name, axis_size, fwd_out_dims, + fwd_args_batched, main_type, spmd_axis_name) batched_outs = custom_vjp_call_jaxpr_p.bind( *args, fun_jaxpr=batched_fun_jaxpr, fwd_jaxpr_thunk=batched_fwd_jaxpr_thunk, bwd=batched_bwd, - num_consts=num_consts, out_trees=out_trees, symbolic_zeros=symbolic_zeros) + out_trees=out_trees, num_consts=num_consts) out_dims = out_dims2[0] if out_dims2 else out_dims1 return batched_outs, out_dims -batching.spmd_axis_primitive_batchers[custom_vjp_call_jaxpr_p] = \ - _custom_vjp_call_jaxpr_vmap -batching.axis_primitive_batchers[custom_vjp_call_jaxpr_p] = partial( - _custom_vjp_call_jaxpr_vmap, None) +batching.spmd_axis_primitive_batchers[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_vmap +batching.axis_primitive_batchers[custom_vjp_call_jaxpr_p] = partial(_custom_vjp_call_jaxpr_vmap, None) xla.register_initial_style_primitive(custom_vjp_call_jaxpr_p) diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index 8e4c41434110..bdcf3b961e51 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -27,9 +27,9 @@ from jax._src import core from jax._src import source_info_util from jax._src.ad_util import ( - add_jaxvals, add_jaxvals_p, replace_internal_symbolic_zeros, - 
replace_rule_output_symbolic_zeros, Zero, zeros_like_aval, - zeros_like_jaxval, zeros_like_p) + add_jaxvals, add_jaxvals_p, zeros_like_jaxval, zeros_like_aval, + zeros_like_p, Zero, replace_internal_symbolic_zeros, + replace_rule_output_symbolic_zeros) from jax._src.api_util import flatten_fun, flatten_fun_nokwargs from jax._src.core import (Trace, Tracer, get_aval, call_p, Primitive, Literal, raise_to_shaped) @@ -387,24 +387,16 @@ def process_custom_jvp_call(self, _, __, f_jvp, tracers, *, symbolic_zeros): def post_process_custom_jvp_call(self, out_tracers, _): raise CustomJVPException() - def process_custom_vjp_call(self, _, __, fwd, bwd, tracers, out_trees, - symbolic_zeros): + def process_custom_vjp_call(self, _, __, fwd, bwd, tracers, out_trees): primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers) - fwd_in = [(core.full_lower(p), type(t) is Zero) - for p, t in zip(primals_in, tangents_in)] - fwd_in = [x for pair in fwd_in for x in pair] # flatten - res_and_primals_out = fwd.call_wrapped(*fwd_in) + tangents_in = map(instantiate_zeros, tangents_in) + res_and_primals_out = fwd.call_wrapped(*map(core.full_lower, primals_in)) out_tree, res_tree = out_trees() res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves]) avals_out = [raise_to_shaped(core.get_aval(x)) for x in primals_out] - # We don't need to handle any symbolic zeros on tangents_in or - # tangents_out below, because custom_lin_p is never executed and - # doesn't correspond to any custom user rule. - # TODO(frostig,mattjj): avoid instantiating zeros when we don't have to! 
- tangents_in = map(instantiate_zeros, tangents_in) tangents_out = custom_lin_p.bind( *res, *tangents_in, num_res=res_tree.num_leaves, bwd=bwd, - out_avals=avals_out, symbolic_zeros=symbolic_zeros) + out_avals=avals_out) tangents_out = map(recast_to_float0, primals_out, tangents_out) return map(partial(JVPTracer, self), primals_out, tangents_out) @@ -753,15 +745,10 @@ def raise_custom_vjp_error_on_jvp(*_, **__): "function.") custom_lin_p.def_impl(raise_custom_vjp_error_on_jvp) -def _custom_lin_transpose(cts_out, *invals, num_res, bwd, out_avals, - symbolic_zeros): +def _custom_lin_transpose(cts_out, *invals, num_res, bwd, out_avals): res, _ = split_list(invals, [num_res]) - if symbolic_zeros: - cts_out = map(replace_internal_symbolic_zeros, cts_out) - else: - cts_out = map(instantiate_zeros_aval, out_avals, cts_out) + cts_out = map(instantiate_zeros_aval, out_avals, cts_out) cts_in = bwd(*res, *cts_out) - cts_in = map(replace_rule_output_symbolic_zeros, cts_in) return [None] * num_res + list(cts_in) primitive_transposes[custom_lin_p] = _custom_lin_transpose diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index c9100f141c6a..323104692178 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -23,19 +23,19 @@ import jax from jax.config import config -from jax.interpreters import partial_eval as pe from jax._src import core from jax._src import source_info_util -from jax._src import linear_util as lu -from jax._src.ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval, - zeros_like_p, Zero, SymbolicZero, - replace_rule_output_symbolic_zeros, instantiate) from jax._src.core import raise_to_shaped, Trace, Tracer, AxisName from jax._src.tree_util import (tree_unflatten, tree_flatten, register_pytree_node) +from jax._src.ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval, + zeros_like_p, Zero, SymbolicZero, + replace_rule_output_symbolic_zeros, instantiate) +from jax._src import 
linear_util as lu from jax._src.util import (unzip2, unzip3, safe_map, safe_zip, split_list, canonicalize_axis, moveaxis, as_hashable_function, curry, memoize, weakref_lru_cache) +from jax.interpreters import partial_eval as pe Array = Any map, unsafe_map = safe_map, map @@ -473,19 +473,16 @@ def todo(vals): return map(partial(BatchTracer, trace), vals, dims, srcs) return vals, todo - def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, *, out_trees, - symbolic_zeros): # pytype: disable=signature-mismatch + def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, *, out_trees): # pytype: disable=signature-mismatch in_vals, in_dims = unzip2((t.val, t.batch_dim) for t in tracers) axis_size, = {x.shape[d] for x, d in zip(in_vals, in_dims) if d is not not_mapped} - fwd_in_dims = [d for in_dim in in_dims for d in [in_dim, not_mapped]] fun, out_dims1 = batch_subtrace(fun, self.main, in_dims) - fwd, out_dims2 = batch_subtrace(fwd, self.main, fwd_in_dims) + fwd, out_dims2 = batch_subtrace(fwd, self.main, in_dims) bwd = batch_custom_vjp_bwd(bwd, self.axis_name, axis_size, out_dims2, in_dims, self.main.trace_type, self.spmd_axis_name) - out_vals = prim.bind(fun, fwd, bwd, *in_vals, out_trees=out_trees, - symbolic_zeros=symbolic_zeros) + out_vals = prim.bind(fun, fwd, bwd, *in_vals, out_trees=out_trees) fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2) if not fst: _, res_tree = out_trees() @@ -782,14 +779,8 @@ def batch_custom_jvp_subtrace(main, in_dims, *in_vals): def batch_custom_vjp_bwd(bwd, axis_name, axis_size, in_dims, out_dim_dests, main_type, spmd_axis_name): def new_bwd(*args): - in_dims_ = in_dims() if callable(in_dims) else in_dims - args = [SymbolicZero(core.mapped_aval(axis_size, dim, x.aval)) - if type(x) is SymbolicZero else x - for x, dim in zip(args, in_dims_)] - in_dims_ = [None if type(x) is SymbolicZero else d - for x, d in zip(args, in_dims_)] bwd_, out_dims_thunk = batch_subtrace(lu.wrap_init(bwd)) - bwd_ = _batch_outer(bwd_, 
axis_name, axis_size, in_dims_, main_type, + bwd_ = _batch_outer(bwd_, axis_name, axis_size, in_dims, main_type, spmd_axis_name) bwd_ = _match_axes_and_sum(bwd_, axis_size, axis_name, out_dims_thunk, out_dim_dests) @@ -806,7 +797,7 @@ def _match_axes_and_sum(axis_size, axis_name, out_dims_thunk, out_dim_dests, *in def _matchaxis_symbolic_zeros(axis_name, sz, name, src, dst, x, sum_match=False): # Just like `matchaxis`, but handles symbolic zeros using ad_util.py # TODO(mattjj): dedup with matchaxis - if isinstance(x, (Zero, SymbolicZero)): + if isinstance(x, Zero): if src == dst: return x elif type(src) == type(dst) == int: diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 86c1afd0d68e..983245752ac7 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -945,11 +945,10 @@ def process_custom_jvp_call(self, primitive, fun, jvp, tracers, *, return self.process_primitive(fake_primitive, tracers, {}) def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, - out_trees, symbolic_zeros): + out_trees): bind = HashableFunction( - lambda *args, **kwargs: primitive.bind( - fun, fwd, bwd, *args, out_trees=out_trees, - symbolic_zeros=symbolic_zeros, **kwargs), + lambda *args, **kwargs: primitive.bind(fun, fwd, bwd, *args, + out_trees=out_trees, **kwargs), (primitive, fun, fwd, bwd)) fake_primitive = FakePrimitive(multiple_results=True, bind=bind) return self.process_primitive(fake_primitive, tracers, {}) diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 4afed82bc698..37b65bf85cdc 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -1537,12 +1537,11 @@ def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): def post_process_custom_jvp_call(self, out_tracers, _): assert False # unreachable assuming jax2tf runs with clean trace state - def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees, - 
symbolic_zeros): + def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees): # Drop the custom differentiation rule and act like a call primitive. This # behavior is desirable because jax2tf stages code out of the JAX system, so # there are no more JAX differentiation transformations to be applied. - del fwd, bwd, out_trees, symbolic_zeros # Unused. + del fwd, bwd, out_trees # Unused. return self.process_call(core.call_p, fun, tracers, {}) def post_process_custom_vjp_call(self, out_tracers, _): diff --git a/jax/interpreters/partial_eval.py b/jax/interpreters/partial_eval.py index f20097fbd842..b6d44923a5ad 100644 --- a/jax/interpreters/partial_eval.py +++ b/jax/interpreters/partial_eval.py @@ -496,16 +496,14 @@ def process_custom_transpose(self, prim, call, tracers, **params): for t in out_tracers: t.recipe = eqn return out_tracers - def process_custom_vjp_call(self, prim, f, fwd, bwd, tracers, out_trees, - symbolic_zeros): + def process_custom_vjp_call(self, prim, f, fwd, bwd, tracers, out_trees): # TODO(mattjj): after old remat is deleted, make this method trivial. # Because we instantiate all tracers, in_knowns is all False. 
tracers = map(self.instantiate_const_abstracted, tracers) in_knowns, in_avals, () = partition_pvals([t.pval for t in tracers]) f = trace_to_subjaxpr_nounits(f, self.main, True) f, aux = partial_eval_wrapper_nounits(f, tuple(in_knowns), tuple(in_avals)) - out_flat = prim.bind(f, fwd, bwd, out_trees=out_trees, - symbolic_zeros=symbolic_zeros) + out_flat = prim.bind(f, fwd, bwd, out_trees=out_trees) out_knowns, out_avals, jaxpr, env = aux() out_consts, res = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)]) res_tracers = map(self.new_instantiated_const, res) @@ -515,9 +513,8 @@ def process_custom_vjp_call(self, prim, f, fwd, bwd, tracers, out_trees, closed_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(jaxpr), ()) @_memoize - def fwd_jaxpr_thunk(*zeros): - fwd_ = _interleave_fun(fwd, zeros) - fwd_ = trace_to_subjaxpr_nounits(fwd_, self.main, True) + def fwd_jaxpr_thunk(): + fwd_ = trace_to_subjaxpr_nounits(fwd, self.main, True) fwd_, aux = partial_eval_wrapper_nounits( fwd_, tuple(in_knowns), tuple(in_avals)) with core.new_sublevel(): @@ -534,8 +531,7 @@ def fwd_jaxpr_thunk(*zeros): dict(fun_jaxpr=closed_jaxpr, fwd_jaxpr_thunk=fwd_jaxpr_thunk, num_consts=len(res) + len(env), - bwd=bwd, out_trees=out_trees, - symbolic_zeros=symbolic_zeros), + bwd=bwd, out_trees=out_trees), jaxpr.effects, source) for t in out_tracers: t.recipe = eqn return merge_lists(out_knowns, out_tracers, out_consts) @@ -1909,29 +1905,23 @@ def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): def post_process_custom_jvp_call(self, out_tracers, _): assert False # unreachable - def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees, - symbolic_zeros): + def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees): in_avals = [t.aval for t in tracers] with core.new_sublevel(): fun_jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, self.main, in_avals) closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ()) 
- main_ = ref(self.main) - @_memoize - def fwd_jaxpr_from_zeros(*zeros): - fwd_ = _interleave_fun(fwd, zeros) - return trace_to_subjaxpr_dynamic(fwd_, main_(), in_avals)[::2] - + fwd_jaxpr_thunk = _memoize( + lambda: trace_to_subjaxpr_dynamic(fwd, main_(), in_avals)[::2]) out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals] invars = map(self.getvar, tracers) constvars = map(self.getvar, map(self.instantiate_const, consts)) outvars = map(self.makevar, out_tracers) eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim.initial_style, dict(fun_jaxpr=closed_fun_jaxpr, - fwd_jaxpr_thunk=fwd_jaxpr_from_zeros, + fwd_jaxpr_thunk=fwd_jaxpr_thunk, num_consts=len(consts), - bwd=bwd, out_trees=out_trees, - symbolic_zeros=symbolic_zeros), + bwd=bwd, out_trees=out_trees), fun_jaxpr.effects, source_info_util.current()) self.frame.add_eqn(eqn) @@ -1980,23 +1970,18 @@ def process_custom_transpose(self, prim, call, tracers, custom_staging_rules: Dict[Primitive, Callable] = {} -@lu.transformation -def _interleave_fun(every_others, *args, **kwargs): - args_ = [x for pair in zip(args, every_others) for x in pair] - yield (yield (args_, kwargs)) - -def _memoize(fn): - cells = {} +def _memoize(thunk): + cell = [] saved_state = [core.thread_local_state.trace_state.copy()] - def memoized(*args): - if args not in cells: + def memoized(): + if not cell: prev_state = core.thread_local_state.trace_state core.thread_local_state.trace_state = saved_state.pop() try: - cells[args] = fn(*args) + cell.append(thunk()) finally: core.thread_local_state.trace_state = prev_state - return cells[args] + return cell[0] return memoized # TODO(mattjj): remove this DebugInfo and helper functions, replace with diff --git a/tests/api_test.py b/tests/api_test.py index e615bb678b81..8a3baf8ee048 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -7308,10 +7308,10 @@ def run(primal_ins, tangent_ins): primal_outs, tangent_outs = run(primal_ins, tangent_ins) primal_out1, primal_out2 = primal_outs 
tangent_out1, tangent_out2 = tangent_outs - scalar_type = jax.Array if maybe_jit or maybe_vmap else float - self.assertIsInstance(primal_out1, scalar_type) + scalar_dtype = jax.Array if maybe_jit or maybe_vmap else float + self.assertIsInstance(primal_out1, scalar_dtype) self.assertAllClose(primal_out1, 5.) - self.assertIsInstance(tangent_out1, scalar_type) + self.assertIsInstance(tangent_out1, scalar_dtype) self.assertAllClose(tangent_out1, 91.) self.assertIsInstance(primal_out2, jax.Array) self.assertArraysAllClose(primal_out2, jnp.array([7., 9.])) @@ -8389,183 +8389,6 @@ def f(x): f_vjp(jnp.array([3.])) f_vjp(jnp.array([3.])) # doesn't crash - def test_symbolic_zero_custom_vjp_basic(self): - @jax.custom_vjp - def f(x, y, z): - return x, x - - def fwd(x, y, z): - self.assertFalse(x[1]) - self.assertTrue(y[1]) - self.assertTrue(z[1]) - return (x[0], x[0]), None - - def fwd_all(x, y, z): - self.assertFalse(x[1]) - self.assertFalse(y[1]) - self.assertFalse(z[1]) - return (x[0], x[0]), None - - def bwd_all(_, g): - x1, x2 = g - self.assertFalse(type(x1) is custom_derivatives_public.SymbolicZero) - self.assertFalse(type(x2) is custom_derivatives_public.SymbolicZero) - return x1, x1, x2 - - def bwd_fst(_, g): - x1, x2 = g - self.assertFalse(type(x1) is custom_derivatives_public.SymbolicZero) - self.assertIs(type(x2), custom_derivatives_public.SymbolicZero) - return x1, x1, x2 - - def bwd_snd(_, g): - x1, x2 = g - self.assertIs(type(x1), custom_derivatives_public.SymbolicZero) - self.assertFalse(type(x2) is custom_derivatives_public.SymbolicZero) - return x1, x1, x2 - - x, y, z = 4., 5., 6. - i = np.array(7, np.int32) - zero = np.array(0.) 
- - f.defvjp(fwd, bwd_all, symbolic_zeros=True) - h = jax.jit(f) - jax.jacrev(h)(x, y, z) - jax.jacrev(lambda x: h(x, y, z))(x) - jax.jacrev(h, argnums=(0, 1, 2), allow_int=True)(x, i, i) - - f.defvjp(fwd_all, bwd_fst, symbolic_zeros=True) - fst_f = lambda *xs: f(*xs)[0] - _, vjp = jax.vjp(fst_f, x, y, z) - _, _, gz = vjp(x) - self.assertArraysAllClose(gz, zero) - - f.defvjp(fwd_all, bwd_snd, symbolic_zeros=True) - snd_f = lambda *xs: f(*xs)[1] - _, vjp = jax.vjp(snd_f, x, y, z) - gx, gy, _ = vjp(x) - self.assertArraysAllClose(gx, zero) - self.assertArraysAllClose(gy, zero) - - f.defvjp(fwd, bwd_snd, symbolic_zeros=True) - _, vjp = jax.vjp(lambda x: snd_f(x, y, z), x) - gx, = vjp(x) - self.assertArraysAllClose(gx, zero) - - @parameterized.named_parameters( - ('jit_vmap', True, True), - ('jit', True, False), - ('vmap', False, True), - ('', False, False), - ) - def test_symbolic_zero_custom_vjp(self, maybe_jit, maybe_vmap): - # below: - # * static_scalar will be static in and out - # * static_array will be static in, but dynamic out - # * dyn_scalar and dyn_array will be dynamic in and out - - ZERO = custom_derivatives_public.SymbolicZero - - def f(static_scalar, static_array, dyn_scalar, dyn_array): - out1 = static_scalar + dyn_scalar - out2 = static_array + dyn_array - return static_scalar, static_array, out1, out2 - - def _pack(x): - return lax.broadcast(x, (1,)) - - def _unpack(x): - (x,) = x - return x - - def _vmap(fun): - def _fun(*args): - args = tree_util.tree_map(_pack, args) - out = jax.vmap(fun)(*args) - out = tree_util.tree_map(_unpack, out) - return out - return _fun - - f = api.custom_vjp(f) - - def fwd(*args): - xs, zeros = [x[0] for x in args], [x[1] for x in args] - self.assertTrue(zeros[0]) - self.assertTrue(zeros[1]) - self.assertFalse(zeros[2]) - self.assertFalse(zeros[3]) - return f(*xs), xs - - def bwd(res, g): - static_scalar, *_ = res - t_static, t_static_arr, t_dyn_scalar, t_dyn_array = g - self.assertIs(type(t_static), ZERO) - 
self.assertFalse(type(t_static_arr) is ZERO) - self.assertFalse(type(t_dyn_scalar) is ZERO) - self.assertFalse(type(t_dyn_array) is ZERO) - self.assertEqual(t_static.shape, ()) - self.assertEqual(t_static_arr.shape, (2,)) - return (static_scalar + 90, - t_static_arr + 91, - t_dyn_scalar + 92, - t_dyn_array + 93) - - f.defvjp(fwd, bwd, symbolic_zeros=True) - - def g(dyn_scalar, dyn_array): - if maybe_vmap: - f_ = _vmap(f) - else: - f_ = f - outs = f_(1., jnp.array([2., 3.]), dyn_scalar, dyn_array) - return outs[1:] - - def run(primal_ins, cotangent_outs): - primal_outs, vjp = jax.vjp(g, *primal_ins) - cotangent_ins = vjp(cotangent_outs) - return primal_outs, cotangent_ins - - if maybe_jit: - run = jax.jit(run) - - scalar_type = jax.Array if maybe_jit or maybe_vmap else float - primal_ins = (4., jnp.array([5., 6.])) - cotangent_outs = (jnp.array([10., 11.]), 7., jnp.array([8., 9.])) - primal_outs, cotangent_ins = run(primal_ins, cotangent_outs) - - primal_out1, primal_out2, primal_out3 = primal_outs - self.assertIsInstance(primal_out1, jax.Array) - self.assertAllClose(primal_out1, jnp.array([2., 3.])) - self.assertIsInstance(primal_out2, scalar_type) - self.assertAllClose(primal_out2, 5.) - self.assertIsInstance(primal_out3, jax.Array) - self.assertAllClose(primal_out3, jnp.array([7., 9.])) - - ct_in1, ct_in2 = cotangent_ins - self.assertIsInstance(ct_in1, scalar_type) - self.assertAllClose(ct_in1, 99.) 
- self.assertIsInstance(ct_in2, jax.Array) - self.assertArraysAllClose(ct_in2, jnp.array([101., 102.])) - - def test_symbolic_zero_custom_vjp_vmap_output(self): - @api.custom_vjp - def f(x, y): - return x, y - - def fwd(x, y): - (x, x0), (y, y0) = x, y - self.assertFalse(x0) - self.assertTrue(y0) - return f(x, y), None - - def bwd(_, g): - ct_x, ct_y = g - #import ipdb; ipdb.set_trace() - self.assertIs(type(ct_y), custom_derivatives_public.SymbolicZero) - return g - - f.defvjp(fwd, bwd, symbolic_zeros=True) - jax.grad(lambda x, y: jax.vmap(f)(x, y)[0].sum())(jnp.ones(3), jnp.ones(3)) def transpose_unary(f, x_example): def transposed(y): From fcac7b4e557a67f1ac785719b8faf1cec6c04c20 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 22 Mar 2023 10:44:50 -0700 Subject: [PATCH 21/65] jax.random: remove scale from wald function --- jax/_src/random.py | 26 +++++++++----------------- tests/random_test.py | 7 +++---- 2 files changed, 12 insertions(+), 21 deletions(-) diff --git a/jax/_src/random.py b/jax/_src/random.py index c163d6df3c0d..156ab7cacc2c 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -1883,7 +1883,6 @@ def _rayleigh(key, scale, shape, dtype) -> Array: def wald(key: KeyArray, mean: RealArray, - scale: RealArray, shape: Optional[Shape] = None, dtype: DTypeLikeFloat = dtypes.float_) -> Array: """Sample Wald random values with given shape and float dtype. @@ -1892,11 +1891,9 @@ def wald(key: KeyArray, key: a PRNG key used as the random key. mean: a float or array of floats broadcast-compatible with ``shape`` representing the mean parameter of the distribution. - scale: a float or array of floats broadcast-compatible with ``shape`` - representing the scale parameter of the distribution. shape: optional, a tuple of nonnegative integers specifying the result - shape. Must be broadcast-compatible with ``mean`` and ``scale``. The default - (None) produces a result shape equal to ``lax.broadcast_shapes(np.shape(mean), np.shape(scale))``. + shape. 
Must be broadcast-compatible with ``mean``. The default + (None) produces a result shape equal to ``np.shape(mean)``. dtype: optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32). @@ -1911,28 +1908,23 @@ def wald(key: KeyArray, dtype = dtypes.canonicalize_dtype(dtype) if shape is not None: shape = core.canonicalize_shape(shape) - return _wald(key, mean, scale, shape, dtype) + return _wald(key, mean, shape, dtype) -@partial(jit, static_argnums=(3, 4), inline=True) -def _wald(key, mean, scale, shape, dtype) -> Array: +@partial(jit, static_argnums=(2, 3), inline=True) +def _wald(key, mean, shape, dtype) -> Array: if shape is None: - shape = lax.broadcast_shapes(np.shape(mean), np.shape(scale)) + shape = np.shape(mean) else: - _check_shape("wald", shape, np.shape(mean), np.shape(scale)) + _check_shape("wald", shape, np.shape(mean)) k1, k2 = _split(key, 2) mean = mean.astype(dtype) - scale = scale.astype(dtype) mean = jnp.broadcast_to(mean, shape) - scale = jnp.broadcast_to(scale, shape) v = normal(k1, shape, dtype) z = uniform(k2, shape, dtype) - two = _lax_const(mean, 2) y = lax.integer_pow(v, 2) y_sq = lax.integer_pow(y, 2) mean_sq = lax.integer_pow(mean, 2) - mean_two = lax.mul(mean, two) - scale_two = lax.mul(scale, two) - sqrt_term = lax.sqrt(mean_two * scale_two * y + mean_sq * y_sq) - x = mean + mean_sq * y / scale_two - mean / scale_two * sqrt_term + sqrt_term = lax.sqrt(4 * mean * y + mean_sq * y_sq) + x = mean + mean_sq * y / 2 - mean / 2 * sqrt_term w = lax.select(lax.le(z, mean / (mean + x)), x, mean_sq / x) return w diff --git a/tests/random_test.py b/tests/random_test.py index 9245e832e47d..cdf46e6ade27 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -1539,18 +1539,17 @@ def testRayleigh(self, scale, dtype): @jtu.sample_product( mean= [0.2, 1., 2., 10. ,100.], - scale= [0.2, 1., 2., 10. 
,100.], dtype=jtu.dtypes.floating) - def testWald(self, mean, scale, dtype): + def testWald(self, mean, dtype): key = self.seed_prng(0) - rand = lambda key: random.wald(key, mean, scale, shape = (10000, ), dtype = dtype) + rand = lambda key: random.wald(key, mean, shape=(10000, ), dtype=dtype) crand = jax.jit(rand) uncompiled_samples = rand(key) compiled_samples = crand(key) for samples in [uncompiled_samples, compiled_samples]: - self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.invgauss(mu=mean / scale, scale = scale).cdf) + self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.invgauss(mu=mean).cdf) class KeyArrayTest(jtu.JaxTestCase): # Key arrays involve: From 78488f00dec09f4facab7c186bd3ff3146395211 Mon Sep 17 00:00:00 2001 From: Kevin Gleason Date: Wed, 22 Mar 2023 12:00:29 -0700 Subject: [PATCH 22/65] Improve handling of dynamic shapes in jax native serialization PiperOrigin-RevId: 518634912 --- .../jax2tf/tests/shape_poly_test.py | 22 +++++++++++++------ 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/jax/experimental/jax2tf/tests/shape_poly_test.py b/jax/experimental/jax2tf/tests/shape_poly_test.py index ab9a5f86a6ce..388e2fc1a35b 100644 --- a/jax/experimental/jax2tf/tests/shape_poly_test.py +++ b/jax/experimental/jax2tf/tests/shape_poly_test.py @@ -36,6 +36,7 @@ from jax._src import util from jax._src.lax import lax as lax_internal from jax._src.lax import control_flow as lax_control_flow +from jax._src.lib import xla_client import numpy as np from jax.experimental.jax2tf.tests import tf_test_util @@ -2634,18 +2635,25 @@ def test_harness(self, harness: PolyHarness): # Set of harness.group_name that are unsupported in serialization require_stablehlo_feature_support = { - # Tan and TopK require additional support for dynamic shape lowering + # Tan (b/274462307) and TopK (openxla/stablehlo#1255) require support. 
"vmap_tan", "vmap_top_k", - # Filter CHLO decompositions that produce shape dialect ops - "vmap_acosh", "vmap_asin", "vmap_asinh", "vmap_atan", "vmap_atanh", - "vmap_bessel_i1e", "vmap_cosh", "vmap_digamma", "vmap_erf", - "vmap_erfc", "vmap_lgamma", "vmap_nextafter", - "vmap_nextafter_broadcasting", "vmap_sinh", + # Crash due to openxla/stablehlo#1328 + "vmap_random_randint", "vmap_random_uniform" } if harness.group_name in require_stablehlo_feature_support: raise unittest.SkipTest( "native lowering with shape polymorphism requires additional StableHLO feature support") - + # API version 47 supports CHLO ops that decompose into shape dialect ops + if xla_client.mlir_api_version < 47: + require_stablehlo_feature_support_shape_dialect = { + "vmap_acosh", "vmap_asin", "vmap_asinh", "vmap_atan", "vmap_atanh", + "vmap_bessel_i1e", "vmap_cosh", "vmap_digamma", "vmap_erf", + "vmap_erfc", "vmap_lgamma", "vmap_nextafter", + "vmap_nextafter_broadcasting", "vmap_sinh" + } + if harness.group_name in require_stablehlo_feature_support_shape_dialect: + raise unittest.SkipTest( + "native lowering with shape polymorphism requires additional StableHLO feature support") if (jtu.device_under_test() == "tpu" and harness.fullname in [ "jnp.cumsum_reduce_axis=poly", From c5ba4d3daf3aa2b45383733afbab70c39903ceab Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Fri, 17 Mar 2023 17:45:41 -0700 Subject: [PATCH 23/65] make mlir arg and result names work with pmap This is a follow-up on #15080 to restore (and indeed fix!) how pmap builds a jaxpr with debug info (i.e. parameter names and result paths). The difference with the machinery in #15080 is just to deal with pmap being final-style (i.e. build the jaxpr at the last second, well after pytrees have been flattened away and transformations have been applied), whereas the machinery for pjit in imagine, plumbing for the former is a bit more long-range and subtle. 
The main idea here is that we need to annotate and maintain debug info on the lu.WrappedFun instance, which we first form at the api.py level, then pass through transformations (which can either update or drop debug info), then finally hand off to the impl rule to be traced to a jaxpr. It makes sense as an annotation, parallel with the in_type annotation used for dynamic shapes, because the debug info has to be updated as transformations are applied, since they might e.g. add tangent inputs and outputs. In more detail: with an initial-style higher-order primitive (like pjit), a jaxpr is formed immediately. Transformations, like autodiff, are jaxpr-to-jaxpr, and so those transformations (like ad.jvp_jaxpr) need to return a new jaxpr either with updated debug info or no debug info at all. (The initial implementation in #15080 doesn't provide updated debug info in any of those jaxpr-to-jaxpr transformation functions, so the debug info is only applied to the jaxpr and then lowered to MLIR when the pjit is at the top level.) For final-style, like pmap here, instead of transformations being jaxpr-to-jaxpr, they're WrappedFun-to-WrappedFun. And so, analogously, transformations, like ad.JVPTrace.process_map, would need to produce a WrappedFun with updated debug info or no debug info at all. (Also analogously to #15080, this PR only implements enough for the debug info to be preserved for top-level pmaps.) This PR doesn't yet delete the trace-time debug info in partial_eval.py. But that'll happen too! 
--- jax/_src/api.py | 6 ++++- jax/_src/api_util.py | 44 +++++++++++++++++++---------------- jax/_src/interpreters/pxla.py | 20 ++++++---------- jax/_src/linear_util.py | 39 ++++++++++++++++++++++++------- tests/pmap_test.py | 2 -- 5 files changed, 66 insertions(+), 45 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index c614d0914eb8..6ea645c28199 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -59,7 +59,7 @@ rebase_donate_argnums, _ensure_index, _ensure_index_tuple, shaped_abstractify, _ensure_str_tuple, argnames_partial_except, validate_argnames, validate_argnums, check_callable, resolve_argnums, - FLAGS) + debug_info, result_paths, debug_info_final, FLAGS) from jax._src.lax import lax as lax_internal from jax._src.lib import jax_jit from jax._src.lib import xla_client as xc @@ -1634,6 +1634,8 @@ def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple, if in_devices is not None and len(in_devices) == 0: raise ValueError("'devices' argument to pmap must be non-empty, or None.") + dbg = debug_info('pmap', fun, args, kwargs, static_broadcasted_tuple, ()) + f = lu.wrap_init(fun) if static_broadcasted_tuple: if max(static_broadcasted_tuple) >= len(args): @@ -1671,7 +1673,9 @@ def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple, kws=True)) local_axis_size = _mapped_axis_size(fun, in_tree, args, in_axes_flat, "pmap") + f, res_paths = result_paths(f) flat_fun, out_tree = flatten_fun(f, in_tree) + flat_fun = debug_info_final(flat_fun, dbg, res_paths) if any(out_axis is None for out_axis in tree_flatten(out_axes)): raise NotImplementedError("None out_axes in pmap are not supported yet") diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py index 8cce1351ad75..e88df6d0687b 100644 --- a/jax/_src/api_util.py +++ b/jax/_src/api_util.py @@ -15,8 +15,8 @@ import inspect import operator from functools import partial -from typing import (Any, Callable, Dict, Iterable, List, NamedTuple, Optional, - Sequence, Set, Tuple, Union,) +from 
typing import (Any, Callable, Dict, Iterable, List, Optional, Sequence, + Set, Tuple, Union) import warnings import numpy as np @@ -30,7 +30,9 @@ treedef_children, generate_key_paths, keystr) from jax._src.tree_util import _replace_nones from jax._src import linear_util as lu -from jax._src.util import safe_map, WrapKwArgs, Hashable, Unhashable +from jax._src.linear_util import TracingDebugInfo +from jax._src.util import (safe_map, WrapKwArgs, Hashable, HashableFunction, + Unhashable) from jax._src.config import flags, bool_env from jax._src import traceback_util traceback_util.register_exclusion(__file__) @@ -585,14 +587,6 @@ def api_hook(fun, tag: str): return fun -class TracingDebugInfo(NamedTuple): - # Packages up trace/staging-time debug info about a func and its parameters, - # formed just before staging to a jaxpr and read in trace-time error messages. - # TODO(mattjj): delete partial_eval.DebugInfo, replace all uses with this cls - traced_for: str # e.g. 'jit', 'scan', etc - func_src_info: str # e.g. f'{fun.__name__} at {filename}:{lineno}' - arg_names: Tuple[str, ...] # e.g. ('args[0]', ... 
) - def debug_info(traced_for: str, fun: Callable, args: Tuple[Any], kwargs: Dict[str, Any], static_argnums: Tuple[int, ...], static_argnames: Tuple[str, ...]) -> Optional[TracingDebugInfo]: @@ -600,7 +594,7 @@ def debug_info(traced_for: str, fun: Callable, args: Tuple[Any], src = fun_sourceinfo(fun) arg_names = _arg_names(fun, args, kwargs, static_argnums, static_argnames) if src is None or arg_names is None: return None - return TracingDebugInfo(traced_for, src, arg_names) + return TracingDebugInfo(traced_for, src, arg_names, None) # TODO(mattjj): make this function internal to this module def fun_sourceinfo(fun: Callable) -> Optional[str]: @@ -635,13 +629,23 @@ def result_paths(*args, **kwargs): yield ans, [keystr(path) for path, _ in generate_key_paths(ans)] def jaxpr_debug_info(jaxpr: core.Jaxpr, trace_debug: Optional[TracingDebugInfo], - result_paths: Optional[Tuple[Optional[str], ...]] + result_paths: Optional[Tuple[Optional[str], ...]] = None, ) -> core.Jaxpr: """Add debug info to jaxpr, given trace-time debug info and result paths.""" - if trace_debug is not None and result_paths is not None: - debug_info = core.JaxprDebugInfo( - trace_debug.traced_for, trace_debug.func_src_info, - trace_debug.arg_names, result_paths) - else: - debug_info = None - return jaxpr.replace(debug_info=debug_info) if debug_info else jaxpr + if trace_debug is None: + return jaxpr + assert (result_paths is not None) ^ (trace_debug.result_paths is not None) + if result_paths is None: + result_paths = trace_debug.result_paths() # type: ignore + debug_info = core.JaxprDebugInfo( + trace_debug.traced_for, trace_debug.func_src_info, + trace_debug.arg_names, result_paths) + return jaxpr.replace(debug_info=debug_info) + +def debug_info_final(f: lu.WrappedFun, dbg: Optional[TracingDebugInfo], + res_paths: Callable[[], Tuple[str, ...]]) -> lu.WrappedFun: + "Attach trace-time debug info and result paths lazy thunk to an lu.WrappedFun" + if dbg is None: return f + assert dbg.result_paths is 
None + res_paths_ = HashableFunction(res_paths, closure=()) + return lu.add_debug_info(f, dbg._replace(result_paths=res_paths_)) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 983245752ac7..02075f154c13 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -1116,6 +1116,7 @@ def stage_parallel_callable( event=dispatch.JAXPR_TRACE_EVENT): jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final( fun, global_sharded_avals, pe.debug_info_final(fun, "pmap")) + jaxpr = api_util.jaxpr_debug_info(jaxpr, fun.debug_info) jaxpr = dispatch.apply_outfeed_rewriter(jaxpr) assert len(out_sharded_avals) == len(pci.out_axes), ( @@ -1264,7 +1265,6 @@ def lower_parallel_callable( raise ValueError("Ordered effects not supported in `pmap`.") unordered_effects = list( effects.ordered_effects.filter_not_in(closed_jaxpr.effects)) - arg_names, result_names = _debug_names(jaxpr.debug_info) lowering_result = mlir.lower_jaxpr_to_module( module_name, closed_jaxpr, @@ -1278,7 +1278,8 @@ def lower_parallel_callable( replicated_args=replicated_args, arg_shardings=_shardings_to_mlir_shardings(parts.arg_parts), result_shardings=_shardings_to_mlir_shardings(parts.out_parts), - arg_names=arg_names, result_names=result_names) + arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names, + result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths) module, keepalive, host_callbacks = ( lowering_result.module, lowering_result.keepalive, lowering_result.host_callbacks) @@ -2564,7 +2565,6 @@ def lower_sharding_computation( unordered_effects = list( effects.ordered_effects.filter_not_in(closed_jaxpr.effects)) ordered_effects = list(effects.ordered_effects.filter_in(closed_jaxpr.effects)) - arg_names, result_names = _debug_names(jaxpr.debug_info) lowering_result = mlir.lower_jaxpr_to_module( module_name, closed_jaxpr, @@ -2579,7 +2579,8 @@ def lower_sharding_computation( replicated_args=replicated_args, arg_shardings=in_op_shardings, 
result_shardings=out_op_shardings, - arg_names=arg_names, result_names=result_names) + arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names, + result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths) module, keepalive, host_callbacks = ( lowering_result.module, lowering_result.keepalive, @@ -2750,7 +2751,6 @@ def lower_mesh_computation( closed_jaxpr.effects)) ordered_effects = list(effects.ordered_effects.filter_in( closed_jaxpr.effects)) - arg_names, result_names = _debug_names(jaxpr.debug_info) lowering_result = mlir.lower_jaxpr_to_module( module_name, closed_jaxpr, @@ -2764,8 +2764,8 @@ def lower_mesh_computation( replicated_args=replicated_args, arg_shardings=in_partitions, result_shardings=out_partitions, - arg_names=arg_names, - result_names=result_names) + arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names, + result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths) module, keepalive, host_callbacks = ( lowering_result.module, lowering_result.keepalive, lowering_result.host_callbacks) @@ -2791,12 +2791,6 @@ def lower_mesh_computation( device_assignment=list(mesh.devices.flat), committed=True) -def _debug_names( - dbg: Optional[core.JaxprDebugInfo] -) -> Union[Tuple[None, None], - Tuple[Sequence[Optional[str]], Sequence[Optional[str]]]]: - return (None, None) if dbg is None else (dbg.arg_names, dbg.result_paths) - class MeshComputation(stages.XlaLowering): _hlo: Optional[ir.Module] _executable: Optional[MeshExecutable] diff --git a/jax/_src/linear_util.py b/jax/_src/linear_util.py index 6f755e724f1b..bda8e73e18b6 100644 --- a/jax/_src/linear_util.py +++ b/jax/_src/linear_util.py @@ -64,7 +64,7 @@ def trans1(static_arg, *dynamic_args, **kwargs): from __future__ import annotations from functools import partial -from typing import Any, Tuple, Callable +from typing import Any, Tuple, Callable, Optional, NamedTuple import weakref from jax.tree_util import tree_map @@ -124,14 +124,15 @@ class WrappedFun: params: extra parameters to pass 
as keyword arguments to `f`, along with the transformed keyword arguments. """ - __slots__ = ("f", "transforms", "stores", "params", "in_type") + __slots__ = ("f", "transforms", "stores", "params", "in_type", "debug_info") - def __init__(self, f, transforms, stores, params, in_type): + def __init__(self, f, transforms, stores, params, in_type, debug_info): self.f = f self.transforms = transforms self.stores = stores self.params = params self.in_type = in_type + self.debug_info = debug_info @property def __name__(self): @@ -140,7 +141,7 @@ def __name__(self): def wrap(self, gen, gen_static_args, out_store) -> WrappedFun: """Add another transform and its store.""" return WrappedFun(self.f, ((gen, gen_static_args),) + self.transforms, - (out_store,) + self.stores, self.params, None) + (out_store,) + self.stores, self.params, None, None) def populate_stores(self, stores): """Copy the values from the `stores` into `self.stores`.""" @@ -199,11 +200,13 @@ def transform_to_str(x): return "Wrapped function:\n" + '\n'.join(transformation_stack) + '\nCore: ' + fun_name(self.f) + '\n' def __hash__(self): - return hash((self.f, self.transforms, self.params, self.in_type)) + return hash((self.f, self.transforms, self.params, self.in_type, + self.debug_info)) def __eq__(self, other): return (self.f == other.f and self.transforms == other.transforms and - self.params == other.params and self.in_type == other.in_type) + self.params == other.params and self.in_type == other.in_type and + self.debug_info == other.debug_info) @curry def transformation(gen, fun: WrappedFun, *gen_static_args) -> WrappedFun: @@ -231,15 +234,15 @@ def fun_name(f): def wrap_init(f, params=None) -> WrappedFun: """Wraps function `f` as a `WrappedFun`, suitable for transformation.""" params = () if params is None else tuple(sorted(params.items())) - return WrappedFun(f, (), (), params, None) + return WrappedFun(f, (), (), params, None, None) -def annotate(f: WrappedFun, in_type: core.InputType) -> WrappedFun: 
+def annotate(f: WrappedFun, in_type: Optional[core.InputType]) -> WrappedFun: assert f.in_type is None if in_type is None: return f _check_input_type(in_type) - return WrappedFun(f.f, f.transforms, f.stores, f.params, in_type) + return WrappedFun(f.f, f.transforms, f.stores, f.params, in_type, f.debug_info) def _check_input_type(in_type: core.InputType) -> None: # Check that in_type is syntactically well-formed @@ -271,6 +274,24 @@ def valid_size(d) -> bool: assert all(provided) +class TracingDebugInfo(NamedTuple): + # Packages up trace/staging-time debug info about a func and its parameters, + # formed just before staging to a jaxpr and read in trace-time error messages. + # TODO(mattjj): delete partial_eval.DebugInfo, replace all uses with this cls + traced_for: str # e.g. 'jit', 'scan', etc + func_src_info: str # e.g. f'{fun.__name__} at {filename}:{lineno}' + arg_names: Tuple[str, ...] # e.g. ('args[0]', ... ) + result_paths: Optional[Callable[[], Tuple[str, ...]]] + +def add_debug_info(f: WrappedFun, debug_info: Optional[TracingDebugInfo] + ) -> WrappedFun: + """Produce a new WrappedFun with debug_info attached.""" + assert f.debug_info is None + if debug_info is None: + return f + return WrappedFun(f.f, f.transforms, f.stores, f.params, f.in_type, debug_info) + + def cache(call: Callable): """Memoization decorator for functions taking a WrappedFun as first argument. 
diff --git a/tests/pmap_test.py b/tests/pmap_test.py index 4d8c5a36f1f4..b0bf9deaef13 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -2036,7 +2036,6 @@ def test_remat_of_pmap_policy(self, remat): self.assertEqual(jaxpr_text.count(' cos '), 2) def test_pmap_lower_arg_info(self): - raise SkipTest("arg info not plumbed to pmap yet") # TODO(mattjj) def f(x, y, *args, **kwargs): return y['hi'] + args[1] + sum(kwargs.values()) @@ -2052,7 +2051,6 @@ def f(x, y, *args, **kwargs): self.assertIn("kwargs['w']", mhlo_str) def test_pmap_lower_result_info(self): - raise SkipTest("arg info not plumbed to pmap yet") # TODO(mattjj) def f(x, y, z): return {'a': x, 'b': [y]} From 499372d4161c28211a85d5d443e053fa894eafd1 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 22 Mar 2023 10:55:30 -0700 Subject: [PATCH 24/65] DOC: add formulae for distributions in jax.random --- jax/_src/random.py | 229 +++++++++++++++++++++++++++++++++++++++------ 1 file changed, 200 insertions(+), 29 deletions(-) diff --git a/jax/_src/random.py b/jax/_src/random.py index 156ab7cacc2c..e2799dda357c 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -552,7 +552,14 @@ def choice(key: KeyArray, def normal(key: KeyArray, shape: Union[Shape, NamedShape] = (), dtype: DTypeLikeFloat = dtypes.float_) -> Array: - """Sample standard normal random values with given shape and float dtype. + r"""Sample standard normal random values with given shape and float dtype. + + The values are returned according to the probability density function: + + .. math:: + f(x) = \frac{1}{\sqrt{2\pi}}e^{-x^2/2} + + on the domain :math:`-\infty < x < \infty` Args: key: a PRNG key used as the random key. @@ -600,7 +607,15 @@ def multivariate_normal(key: KeyArray, shape: Optional[Shape] = None, dtype: DTypeLikeFloat = None, method: str = 'cholesky') -> Array: - """Sample multivariate normal random values with given mean and covariance. 
+ r"""Sample multivariate normal random values with given mean and covariance. + + The values are returned according to the probability density function: + + .. math:: + f(x;\mu, \Sigma) = (2\pi)^{-k/2} \det(\Sigma)^{-1}e^{-\frac{1}{2}(x - \mu)^T \Sigma^{-1} (x - \mu)} + + where :math:`k` is the dimension, :math:`\mu` is the mean (given by ``mean``) and + :math:`\Sigma` is the covariance matrix (given by ``cov``). Args: key: a PRNG key used as the random key. @@ -673,7 +688,14 @@ def truncated_normal(key: KeyArray, upper: RealArray, shape: Optional[Union[Shape, NamedShape]] = None, dtype: DTypeLikeFloat = dtypes.float_) -> Array: - """Sample truncated standard normal random values with given shape and dtype. + r"""Sample truncated standard normal random values with given shape and dtype. + + The values are returned according to the probability density function: + + .. math:: + f(x) \propto e^{-x^2/2} + + on the domain :math:`\rm{lower} < x < \rm{upper}`. Args: key: a PRNG key used as the random key. @@ -729,7 +751,14 @@ def _truncated_normal(key, lower, upper, shape, dtype) -> Array: def bernoulli(key: KeyArray, p: RealArray = np.float32(0.5), shape: Optional[Union[Shape, NamedShape]] = None) -> Array: - """Sample Bernoulli random values with given shape and mean. + r"""Sample Bernoulli random values with given shape and mean. + + The values are distributed according to the probability mass function: + + .. math:: + f(k; p) = p^k(1 - p)^{1 - k} + + where :math:`k \in \{0, 1\}` and :math:`0 \le p \le 1`. Args: key: a PRNG key used as the random key. @@ -769,7 +798,14 @@ def beta(key: KeyArray, b: RealArray, shape: Optional[Shape] = None, dtype: DTypeLikeFloat = dtypes.float_) -> Array: - """Sample Beta random values with given shape and float dtype. + r"""Sample Beta random values with given shape and float dtype. + + The values are distributed according to the probability density function: + + .. 
math:: + f(x;a,b) \propto x^{a - 1}(1 - x)^{b - 1} + + on the domain :math:`0 \le x \le 1`. Args: key: a PRNG key used as the random key. @@ -820,7 +856,14 @@ def _beta(key, a, b, shape, dtype) -> Array: def cauchy(key: KeyArray, shape: Shape = (), dtype: DTypeLikeFloat = dtypes.float_) -> Array: - """Sample Cauchy random values with given shape and float dtype. + r"""Sample Cauchy random values with given shape and float dtype. + + The values are distributed according to the probability density function: + + .. math:: + f(x) \propto \frac{1}{x^2 + 1} + + on the domain :math:`-\infty < x < \infty` Args: key: a PRNG key used as the random key. @@ -852,7 +895,19 @@ def dirichlet(key: KeyArray, alpha: RealArray, shape: Optional[Shape] = None, dtype: DTypeLikeFloat = dtypes.float_) -> Array: - """Sample Dirichlet random values with given shape and float dtype. + r"""Sample Dirichlet random values with given shape and float dtype. + + The values are distributed according the the probability density function: + + .. math:: + f(\{x_i\}; \{\alpha_i\}) = \propto \prod_{i=1}^k x_i^{\alpha_i} + + Where :math:`k` is the dimension, and :math:`\{x_i\}` satisfies + + .. math:: + \sum_{i=1}^k x_i = 1 + + and :math:`0 \le x_i \le 1` for all :math:`x_i`. Args: key: a PRNG key used as the random key. @@ -910,7 +965,14 @@ def _softmax(x, axis) -> Array: def exponential(key: KeyArray, shape: Shape = (), dtype: DTypeLikeFloat = dtypes.float_) -> Array: - """Sample Exponential random values with given shape and float dtype. + r"""Sample Exponential random values with given shape and float dtype. + + The values are distributed according the the probability density function: + + .. math:: + f(x) = e^{-x} + + on the domain :math:`0 \le x < \infty`. Args: key: a PRNG key used as the random key. 
@@ -1074,9 +1136,16 @@ def gamma(key: KeyArray, a: RealArray, shape: Optional[Shape] = None, dtype: DTypeLikeFloat = dtypes.float_) -> Array: - """Sample Gamma random values with given shape and float dtype. + r"""Sample Gamma random values with given shape and float dtype. + + The values are distributed according the the probability density function: - This implements the standard gamma density, with a unit scale/rate parameter. + .. math:: + f(x;a) \propto x^{a - 1} e^{-x} + + on the domain :math:`0 \le x < \infty`, with :math:`a > 0`. + + This is the standard gamma density, with a unit scale/rate parameter. Dividing the sample output by the rate is equivalent to sampling from *gamma(a, rate)*, and multiplying the sample output by the scale is equivalent to sampling from *gamma(a, scale)*. @@ -1254,7 +1323,14 @@ def poisson(key: KeyArray, lam: RealArray, shape: Optional[Shape] = None, dtype: DTypeLikeInt = dtypes.int_) -> Array: - """Sample Poisson random values with given shape and integer dtype. + r"""Sample Poisson random values with given shape and integer dtype. + + The values are distributed according to the probability mass function: + + .. math:: + f(k; \lambda) = \frac{\lambda^k e^{-\lambda}}{k!} + + Where `k` is a non-negative integer and :math:`\lambda > 0`. Args: key: a PRNG key used as the random key. @@ -1291,6 +1367,11 @@ def gumbel(key: KeyArray, dtype: DTypeLikeFloat = dtypes.float_) -> Array: """Sample Gumbel random values with given shape and float dtype. + The values are distributed according to the probability density function: + + .. math:: + f(x) = e^{-(x + e^{-x})} + Args: key: a PRNG key used as the random key. shape: optional, a tuple of nonnegative integers representing the result @@ -1361,7 +1442,12 @@ def categorical(key: KeyArray, def laplace(key: KeyArray, shape: Shape = (), dtype: DTypeLikeFloat = dtypes.float_) -> Array: - """Sample Laplace random values with given shape and float dtype. 
+ r"""Sample Laplace random values with given shape and float dtype. + + The values are distributed according to the probability density function: + + .. math:: + f(x) = \frac{1}{2}e^{-|x|} Args: key: a PRNG key used as the random key. @@ -1392,7 +1478,12 @@ def _laplace(key, shape, dtype) -> Array: def logistic(key: KeyArray, shape: Shape = (), dtype: DTypeLikeFloat = dtypes.float_) -> Array: - """Sample logistic random values with given shape and float dtype. + r"""Sample logistic random values with given shape and float dtype. + + The values are distributed according to the probability density function: + + .. math:: + f(x) = \frac{e^{-x}}{(1 + e^{-x})^2} Args: key: a PRNG key used as the random key. @@ -1423,7 +1514,14 @@ def pareto(key: KeyArray, b: RealArray, shape: Optional[Shape] = None, dtype: DTypeLikeFloat = dtypes.float_) -> Array: - """Sample Pareto random values with given shape and float dtype. + r"""Sample Pareto random values with given shape and float dtype. + + The values are distributed according to the probability density function: + + .. math:: + f(x; b) = b / x^{b + 1} + + on the domain :math:`0 \le x < \infty` with :math:`b > 0` Args: key: a PRNG key used as the random key. @@ -1464,12 +1562,19 @@ def t(key: KeyArray, df: RealArray, shape: Shape = (), dtype: DTypeLikeFloat = dtypes.float_) -> Array: - """Sample Student's t random values with given shape and float dtype. + r"""Sample Student's t random values with given shape and float dtype. + + The values are distributed according to the probability density function: + + .. math:: + f(t; \nu) \propto \left(1 + \frac{t^2}{\nu}\right)^{-(\nu + 1)/2} + + Where :math:`\nu > 0` is the degrees of freedom, given by the parameter ``df``. Args: key: a PRNG key used as the random key. df: a float or array of floats broadcast-compatible with ``shape`` - representing the parameter of the distribution. + representing the degrees of freedom parameter of the distribution. 
shape: optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with ``df``. The default (None) produces a result shape equal to ``df.shape``. @@ -1508,7 +1613,15 @@ def chisquare(key: KeyArray, df: RealArray, shape: Optional[Shape] = None, dtype: DTypeLikeFloat = dtypes.float_) -> Array: - """Sample Chisquare random values with given shape and float dtype. + r"""Sample Chisquare random values with given shape and float dtype. + + The values are distributed according to the probability density function: + + .. math:: + f(x; \nu) \propto x^{k/2 - 1}e^{-x/2} + + on the domain :math:`0 < x < \infty`, where :math:`\nu > 0` represents the + degrees of freedom, given by the parameter ``df``. Args: key: a PRNG key used as the random key. @@ -1552,7 +1665,17 @@ def f(key: KeyArray, dfden: RealArray, shape: Optional[Shape] = None, dtype: DTypeLikeFloat = dtypes.float_) -> Array: - """Sample F-distribution random values with given shape and float dtype. + r"""Sample F-distribution random values with given shape and float dtype. + + The values are distributed according to the probability density function: + + .. math:: + f(x; \nu) \propto x^{\nu_1/2 - 1}\left(1 + \frac{\nu_1}{\nu_2}x\right)^{ + -(\nu_1 + \nu_2) / 2} + + on the domain :math:`0 < x < \infty`. Here :math:`\nu_1` is the degrees of + freedom of the numerator (``dfnum``), and :math:`\nu_2` is the degrees of + freedom of the denominator (``dfden``). Args: key: a PRNG key used as the random key. @@ -1603,7 +1726,14 @@ def _f(key, dfnum, dfden, shape, dtype) -> Array: def rademacher(key: KeyArray, shape: Shape, dtype: DTypeLikeInt = dtypes.int_) -> Array: - """Sample from a Rademacher distribution. + r"""Sample from a Rademacher distribution. + + The values are distributed according to the probability mass function: + + .. math:: + f(k) = \frac{1}{2}(\delta(k - 1) + \delta(k + 1)) + + on the domain :math:`k \in \{-1, 1}`, where `\delta(x)` is the dirac delta function. 
Args: key: a PRNG key. @@ -1630,9 +1760,14 @@ def _rademacher(key, shape, dtype) -> Array: def maxwell(key: KeyArray, shape: Shape = (), dtype: DTypeLikeFloat = dtypes.float_) -> Array: - """Sample from a one sided Maxwell distribution. + r"""Sample from a one sided Maxwell distribution. + + The values are distributed according to the probability density function: + + .. math:: + f(x) \propto x^2 e^{-x^2 / 2} - The scipy counterpart is `scipy.stats.maxwell`. + on the domain :math:`0 \le x < \infty`. Args: key: a PRNG key. @@ -1666,10 +1801,15 @@ def double_sided_maxwell(key: KeyArray, scale: RealArray, shape: Shape = (), dtype: DTypeLikeFloat = dtypes.float_) -> Array: - """Sample from a double sided Maxwell distribution. + r"""Sample from a double sided Maxwell distribution. + + The values are distributed according to the probability density function: + + .. math:: + f(x;\mu,\sigma) \propto z^2 e^{-z^2 / 2} - Samples using: - loc + scale* sgn(U-0.5)* one_sided_maxwell U~Unif; + where :math:`z = (x - \mu) / \sigma`, with the center :math:`\mu` specified by + ``loc`` and the scale :math:`\sigma` specified by ``scale``. Args: key: a PRNG key. @@ -1712,9 +1852,15 @@ def weibull_min(key: KeyArray, concentration: RealArray, shape: Shape = (), dtype: DTypeLikeFloat = dtypes.float_) -> Array: - """Sample from a Weibull distribution. + r"""Sample from a Weibull distribution. - The scipy counterpart is `scipy.stats.weibull_min`. + The values are distributed according to the probability density function: + + .. math:: + f(x;\sigma,c) \propto x^{c - 1} \exp(-(x / \sigma)^c) + + on the domain :math:`0 < x < \infty`, where :math:`c > 0` is the concentration + parameter, and :math:`\sigma > 0` is the scale parameter. Args: key: a PRNG key. @@ -1788,7 +1934,15 @@ def generalized_normal( shape: Shape = (), dtype: DTypeLikeFloat = dtypes.float_ ) -> Array: - """Sample from the generalized normal distribution. + r"""Sample from the generalized normal distribution. 
+ + The values are returned according to the probability density function: + + .. math:: + f(x;p) \propto e^{-|x|^p} + + on the domain :math:`-\infty < x < \infty`, where :math:`p > 0` is the + shape parameter. Args: key: a PRNG key used as the random key. @@ -1842,7 +1996,15 @@ def rayleigh(key: KeyArray, scale: RealArray, shape: Optional[Shape] = None, dtype: DTypeLikeFloat = dtypes.float_) -> Array: - """Sample Rayleigh random values with given shape and float dtype. + r"""Sample Rayleigh random values with given shape and float dtype. + + The values are returned according to the probability density function: + + .. math:: + f(x;\sigma) \propto xe^{-x^2/(2\sigma^2)} + + on the domain :math:`-\infty < x < \infty`, and where `\sigma > 0` is the scale + parameter of the distribution. Args: key: a PRNG key used as the random key. @@ -1885,7 +2047,16 @@ def wald(key: KeyArray, mean: RealArray, shape: Optional[Shape] = None, dtype: DTypeLikeFloat = dtypes.float_) -> Array: - """Sample Wald random values with given shape and float dtype. + r"""Sample Wald random values with given shape and float dtype. + + The values are returned according to the probability density function: + + .. math:: + f(x;\mu) = \frac{1}{\sqrt{2\pi x^3}} \exp\left(-\frac{(x - \mu)^2}{2\mu^2 x}\right) + + on the domain :math:`-\infty < x < \infty`, and where :math:`\mu > 0` is the location + parameter of the distribution. + Args: key: a PRNG key used as the random key. From aa46778d98da3b995e4b81d6ce8811ab02f1a8de Mon Sep 17 00:00:00 2001 From: Frederic Bastien Date: Wed, 22 Mar 2023 12:30:48 -0700 Subject: [PATCH 25/65] WAR the dependency issue in the nightly CI container. 
--- .github/workflows/slurm_job_scripts/run_e2e_t5x_tests.sub | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/slurm_job_scripts/run_e2e_t5x_tests.sub b/.github/workflows/slurm_job_scripts/run_e2e_t5x_tests.sub index 47dc5a217135..444f9c691d97 100644 --- a/.github/workflows/slurm_job_scripts/run_e2e_t5x_tests.sub +++ b/.github/workflows/slurm_job_scripts/run_e2e_t5x_tests.sub @@ -40,12 +40,16 @@ EXPORTS="--export=ALL,PYTHONPATH=${T5X_DIR}" #------------------------------------------------------------------------------- # Setup command to be run before the actual pytest command +# We remove cudf from the container as a temporary WAR to a protobuf +# dependency issue between cudf and tensorflow-cpu that t5x try to +# install. read -r -d '' setup_cmd < ${E2E_TESTS_WORKSPACE_DIR}/hostname.txt From f9d73cb1c57af510d7c9cfd239908ae36bdf892d Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Wed, 22 Mar 2023 21:43:51 +0000 Subject: [PATCH 26/65] Add print statement to help debug spurious test failure --- tests/compilation_cache_test.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/compilation_cache_test.py b/tests/compilation_cache_test.py index 00d9317c6dbb..12c277701ade 100644 --- a/tests/compilation_cache_test.py +++ b/tests/compilation_cache_test.py @@ -339,6 +339,8 @@ def test_cache_read_warning(self): warnings.catch_warnings(record=True) as w: mock_get.side_effect = RuntimeError("test error") self.assertEqual(f(2), 4) + if len(w) > 1: + print("Warnings:", [str(w_) for w_ in w], flush=True) self.assertLen(w, 1) self.assertIn( "Error reading persistent compilation cache entry " From 3039951aadb6999096288dd882b71e21f859cf73 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Wed, 22 Mar 2023 13:37:40 -0700 Subject: [PATCH 27/65] add experimental jax_log_checkpoint_residuals option The main idea here is to improve tooling for knowing what residuals are being saved and why. There's a lot more that can be done here (e.g. 
naming the arguments, explaining what JVP rule produced these residuals, explaining what consumed them, etc) but this is a start. Co-authored-by: Qiao Zhang --- jax/_src/ad_checkpoint.py | 32 ++++++++++++++++++++++--- jax/_src/config.py | 7 ++++++ tests/api_test.py | 50 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 86 insertions(+), 3 deletions(-) diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index 3ca8b9c8b333..57ebd0e90e1b 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -13,6 +13,7 @@ # limitations under the License. from functools import partial +import logging from typing import (Any, Callable, FrozenSet, List, Optional, Sequence, Tuple, Union) import types @@ -47,6 +48,8 @@ map = safe_map zip = safe_zip +logger = logging.getLogger(__name__) + allowed_effects: effects.EffectTypeSet = effects.remat_allowed_effects ### Policies @@ -392,6 +395,12 @@ def f_(*args): jaxpr_, out_shape = out jaxpr = jaxpr_.jaxpr out_tree = lambda: tree_structure(out_shape) + assert len(jaxpr.invars) == len(in_leaves) + dbg = pe.debug_info(f, in_tree, out_tree, True, "saved_residuals") + arg_info = pe.arg_info_all(dbg) + return _saved_residuals(jaxpr, arg_info) + +def _saved_residuals(jaxpr, arg_info) -> List[Tuple[core.AbstractValue, str]]: res_lits = [x for x in jaxpr.outvars if isinstance(x, core.Literal)] res_vars = {x for x in jaxpr.outvars if not isinstance(x, core.Literal)} @@ -404,9 +413,6 @@ def f_(*args): if v in res_vars: results.append((v.aval, 'from a constant')) - assert len(jaxpr.invars) == len(in_leaves) - dbg = pe.debug_info(f, in_tree, out_tree, True, "saved_residuals") - arg_info = pe.arg_info_all(dbg) for i, v in enumerate(jaxpr.invars): if v in res_vars: if arg_info is not None: @@ -509,6 +515,26 @@ def remat_partial_eval(trace, *tracers, jaxpr, **params): recipe = pe.new_eqn_recipe(in_jaxpr_tracers, out_jaxpr_tracers, remat_p, new_params, jaxpr_unknown.effects, source_info_util.current()) + + # log info 
about saved residuals + try: + _, staged_unk = partition_list(in_used_staged, in_unknowns) + res_invars, _ = partition_list(staged_unk, jaxpr_unknown.invars[num_res:]) + res_outvars = jaxpr_known.outvars[len(jaxpr_known.outvars) - num_res:] + body_res = _saved_residuals(jaxpr_known.replace(outvars=res_outvars), None) + logger.log(logging.WARNING if jax.config.jax_log_checkpoint_residuals + else logging.DEBUG, + 'remat-decorated function ' + + 'saving inputs with shapes:\n' * bool(res_invars) + + ' %s\n' * len(res_invars) + + 'and ' * bool(res_invars) * bool(body_res) + + 'saving these intermediates:\n' * bool(body_res) + + ' %s from %s\n' * len(body_res), + *[v.aval.str_short() for v in res_invars], + *[elt for (a, s) in body_res for elt in [a.str_short(), s]]) + except: + pass # just don't log anything on failure + for t in out_jaxpr_tracers: t.recipe = recipe # zip together known and unknown outputs diff --git a/jax/_src/config.py b/jax/_src/config.py index a0446c5ead40..8d8f6b4c5f79 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -749,6 +749,13 @@ def update_thread_local_jit_state(**kw): 'option is set, the log level is WARNING; otherwise the level is ' 'DEBUG.')) +log_compiles = config.define_bool_state( + name='jax_log_checkpoint_residuals', + default=False, + help=('Log a message every time jax.checkpoint (aka jax.remat) is ' + 'partially evaluated (e.g. for autodiff), printing what residuals ' + 'are saved.')) + parallel_functions_output_gda = config.define_bool_state( name='jax_parallel_functions_output_gda', default=False, diff --git a/tests/api_test.py b/tests/api_test.py index 8a3baf8ee048..3bf6572bb570 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -5715,6 +5715,56 @@ def f(x): self.assertIn(' sin ', str(jaxpr)) self.assertIn(' cos ', str(jaxpr)) + def test_remat_residual_logging(self): + def f(x): + x = jnp.sin(x) + x = jnp.cos(x.sum()) + return x + + x = jnp.arange(3.) 
+ + f1 = jax.remat(f) + f2 = jax.remat(f, policy=lambda *_, **__: True) + f3 = jax.remat(f, policy=lambda p, *_, **__: str(p) == 'cos') + + prev_level = logging.get_verbosity() + try: + logging.set_verbosity('DEBUG') + with self.assertLogs(level=logging.DEBUG) as l: + jax.grad(f1)(x) + finally: + logging.set_verbosity(prev_level) + self.assertTrue(any('remat-decorated function saving inputs with shapes:' + in line for line in l.output)) + self.assertFalse(any('intermediates' in line for line in l.output)) + + prev_level = logging.get_verbosity() + try: + logging.set_verbosity('DEBUG') + with self.assertLogs(level=logging.DEBUG) as l: + jax.grad(f2)(x) + finally: + logging.set_verbosity(prev_level) + self.assertFalse(any('saving inputs' in line for line in l.output)) + self.assertTrue(any('remat-decorated function saving these intermediates:' + in line for line in l.output)) + self.assertTrue(any(' sin ' in line for line in l.output)) + self.assertTrue(any(' cos ' in line for line in l.output)) + + prev_level = logging.get_verbosity() + try: + logging.set_verbosity('DEBUG') + with self.assertLogs(level=logging.DEBUG) as l: + jax.grad(f3)(x) + finally: + logging.set_verbosity(prev_level) + self.assertTrue(any('remat-decorated function saving inputs with shapes:' + in line for line in l.output)) + self.assertTrue(any('and saving these intermediates:' + in line for line in l.output)) + self.assertFalse(any(' sin ' in line for line in l.output)) + self.assertTrue(any(' cos ' in line for line in l.output)) + class JaxprTest(jtu.JaxTestCase): From 4a27af37eeeff9fe79172e3f87ffdb25e11ba4d5 Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Wed, 22 Mar 2023 17:22:39 -0700 Subject: [PATCH 28/65] Redefine `compile_and_serialize` as `serialize(lowered.compile())`. This has the downside of keeping around the UnloadedMeshComputation, but it makes the serialize() API easier to understand. 
PiperOrigin-RevId: 518715469 --- jax/_src/interpreters/pxla.py | 123 +++++++++++++---------- jax/experimental/serialize_executable.py | 37 +++---- tests/array_test.py | 6 +- 3 files changed, 86 insertions(+), 80 deletions(-) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 02075f154c13..1c8a37c581a8 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -1299,9 +1299,6 @@ def __init__(self, hlo: ir.Module, **compile_args): self._hlo = hlo self.compile_args = compile_args - def _compile_unloaded(self) -> Union[UnloadedPmapExecutable, PmapExecutable]: - return UnloadedPmapExecutable.from_hlo(self._hlo, **self.compile_args) - # -- stages.XlaLowering overrides def hlo(self) -> xc.XlaComputation: @@ -1319,10 +1316,8 @@ def stablehlo(self) -> ir.Module: @profiler.annotate_function def compile(self) -> PmapExecutable: if self._executable is None: - executable = self._compile_unloaded() - if isinstance(executable, UnloadedPmapExecutable): - executable = executable.load() - self._executable = executable + self._executable = UnloadedPmapExecutable.from_hlo( + self._hlo, **self.compile_args) return self._executable @@ -1471,9 +1466,9 @@ def from_hlo(xla_computation, ordered_effects=ordered_effects, keepalive=keepalive, host_callbacks=host_callbacks, - ) + ).load() - def load(self) -> PmapExecutable: + def build_execute_fun(self): input_indices = [] for aval, spec in safe_zip(self.local_input_avals, self.input_shardings): assert isinstance(spec, sharding_impls.PmapSharding), spec @@ -1489,10 +1484,13 @@ def load(self) -> PmapExecutable: self.ordered_effects, self.keepalive, bool(self.host_callbacks), set(range(len(input_indices)))) + return execute_fun + + def load(self) -> PmapExecutable: fingerprint = getattr(self.compiled, "fingerprint", None) - return PmapExecutable(self.compiled, execute_fun, fingerprint, - self.local_input_avals) + return PmapExecutable(self.compiled, self.build_execute_fun, fingerprint, + 
self.local_input_avals, self) def _compile_replicated_pmap_executable_from_hlo( @@ -1507,17 +1505,27 @@ def _compile_replicated_pmap_executable_from_hlo( in_indices=input_indices, in_shardings=in_shardings, kept_var_idx=set(range(len(pci.avals))), out_handler=handle_outs) # TODO(frostig): need `compile_replicated` to give us the XLA executable - return PmapExecutable(None, execute_fun, None, pci.avals) + return PmapExecutable(None, lambda: execute_fun, None, pci.avals, None) class PmapExecutable(stages.XlaExecutable): - __slots__ = ["xla_executable", "unsafe_call", "fingerprint", "in_avals"] + __slots__ = ["xla_executable", "_unsafe_call", "build_unsafe_call", + "fingerprint", "in_avals", "_unloaded_executable"] - def __init__(self, xla_executable, unsafe_call, fingerprint, in_avals): + def __init__(self, xla_executable, build_unsafe_call, fingerprint, + in_avals, unloaded_executable): self.xla_executable = xla_executable - self.unsafe_call = unsafe_call + self._unsafe_call = None + self.build_unsafe_call = build_unsafe_call self.fingerprint = fingerprint self.in_avals = in_avals + self._unloaded_executable = unloaded_executable + + @property + def unsafe_call(self) -> Callable[..., Any]: + if self._unsafe_call is None: + self._unsafe_call = self.build_unsafe_call() + return self._unsafe_call # -- stages.XlaExecutable overrides @@ -1529,7 +1537,7 @@ def call(self, *args): # TODO(frostig): do we need to check sharding and sharded avals? 
arg_avals = map(xla.abstractify, args) check_arg_avals_for_call(self.in_avals, arg_avals) - return self.unsafe_call(*args) + return self.unsafe_call(*args) # pylint: disable=not-callable def _get_pmap_sharding(devices, specs): @@ -2804,21 +2812,6 @@ def __init__(self, name: str, hlo: Optional[ir.Module], self.compile_args = compile_args self._executable = None - def _compile_unloaded( - self, - _allow_propagation_to_outputs: Optional[Sequence[bool]] = None, - _allow_compile_replicated: bool = True - ) -> Union[UnloadedMeshExecutable, MeshExecutable]: - if self.is_trivial: - return MeshExecutable.from_trivial_jaxpr(**self.compile_args) - else: - return UnloadedMeshExecutable.from_hlo( - self._name, - self._hlo, - **self.compile_args, - _allow_propagation_to_outputs=_allow_propagation_to_outputs, - _allow_compile_replicated=_allow_compile_replicated) # type: ignore - # -- stages.XlaLowering overrides def hlo(self) -> xc.XlaComputation: @@ -2841,11 +2834,16 @@ def compile(self, _allow_propagation_to_outputs: Optional[Sequence[bool]] = None, _allow_compile_replicated: bool = True) -> MeshExecutable: if self._executable is None: - executable = self._compile_unloaded( - _allow_propagation_to_outputs, _allow_compile_replicated) - if isinstance(executable, UnloadedMeshExecutable): - executable = executable.load() - self._executable = executable + if self.is_trivial: + self._executable = MeshExecutable.from_trivial_jaxpr( + **self.compile_args) + else: + self._executable = UnloadedMeshExecutable.from_hlo( + self._name, + self._hlo, + **self.compile_args, + _allow_propagation_to_outputs=_allow_propagation_to_outputs, + _allow_compile_replicated=_allow_compile_replicated) return self._executable def cost_analysis(self) -> Dict[str, float]: @@ -2952,7 +2950,7 @@ class UnloadedMeshExecutable: kept_var_idx: Set[int] auto_spmd_lowering: bool - def load(self) -> MeshExecutable: + def build_unsafe_call(self): input_indices = _get_input_indices(self.input_avals, 
self.input_shardings) handle_args = InputsHandler(self.xla_executable.local_devices(), self.input_shardings, input_indices) @@ -2964,11 +2962,14 @@ def load(self) -> MeshExecutable: self.xla_executable, self.name, self.backend, handle_args, handle_outs, self.unordered_effects, self.ordered_effects, self.keepalive, bool(self.host_callbacks), self.kept_var_idx) + return unsafe_call - return MeshExecutable(self.xla_executable, unsafe_call, self.input_avals, + def load(self) -> MeshExecutable: + return MeshExecutable(self.xla_executable, self.build_unsafe_call, + self.input_avals, self.input_shardings, self.output_shardings, self.auto_spmd_lowering, self.kept_var_idx, - self.device_assignment) + self.device_assignment, self) # May return a MeshExecutable in the compile_replicated case. @staticmethod @@ -2996,7 +2997,7 @@ def from_hlo(name: str, device_assignment: Sequence[xc.Device], committed: bool, pmap_nreps: int = 1 - ) -> Union[MeshExecutable, UnloadedMeshExecutable]: + ) -> MeshExecutable: dev: np.ndarray if auto_spmd_lowering: @@ -3119,7 +3120,7 @@ def from_hlo(name: str, keepalive=keepalive, host_callbacks=host_callbacks, kept_var_idx=kept_var_idx, - auto_spmd_lowering=auto_spmd_lowering) + auto_spmd_lowering=auto_spmd_lowering).load() class MeshExecutableFastpathData(NamedTuple): @@ -3134,24 +3135,35 @@ class MeshExecutableFastpathData(NamedTuple): class MeshExecutable(stages.XlaExecutable): __slots__ = [ - "xla_executable", "unsafe_call", "in_avals", "_in_shardings", - "_out_shardings", "_auto_spmd_lowering", "_kept_var_idx", - "_device_assignment" + "xla_executable", "_unsafe_call", + "build_unsafe_call", "in_avals", + "_in_shardings", "_out_shardings", + "_auto_spmd_lowering", "_kept_var_idx", + "_device_assignment", + "_unloaded_executable", ] - def __init__(self, xla_executable, unsafe_call, in_avals, in_shardings, + def __init__(self, xla_executable, build_unsafe_call, in_avals, in_shardings, out_shardings, auto_spmd_lowering, kept_var_idx, - 
device_assignment): + device_assignment, unloaded_executable=None): self.xla_executable = xla_executable - self.unsafe_call = unsafe_call + self.build_unsafe_call = build_unsafe_call # in_avals is a list of global and local avals. Aval is global if input # is a GDA or jax.Array else local. self.in_avals = in_avals + self._unsafe_call = None self._in_shardings = in_shardings self._out_shardings = out_shardings self._auto_spmd_lowering = auto_spmd_lowering self._kept_var_idx = kept_var_idx self._device_assignment = device_assignment + self._unloaded_executable = unloaded_executable + + @property + def unsafe_call(self) -> Callable[..., Any]: + if self._unsafe_call is None: + self._unsafe_call = self.build_unsafe_call() + return self._unsafe_call @staticmethod def from_trivial_jaxpr(jaxpr, consts, global_in_avals, global_out_avals, @@ -3174,8 +3186,9 @@ def from_trivial_jaxpr(jaxpr, consts, global_in_avals, global_out_avals, [False] * len(global_out_avals)) unsafe_call = partial(_execute_trivial, jaxpr, consts, handle_ins, handle_outs, kept_var_idx) - return MeshExecutable(None, unsafe_call, global_in_avals, in_shardings, - out_shardings, False, kept_var_idx, device_assignment) + return MeshExecutable(None, lambda: unsafe_call, global_in_avals, + in_shardings, out_shardings, False, kept_var_idx, + device_assignment, None) # -- stages.XlaExecutable overrides @@ -3189,7 +3202,7 @@ def call(self, *args): check_arg_avals_for_call(ref_avals, arg_avals) # Check the GDA sharding and the input sharding. 
check_gda_or_array_xla_sharding_match(kept_args, self._in_shardings) - return self.unsafe_call(*args) + return self.unsafe_call(*args) # pylint: disable=not-callable def input_shardings(self) -> Sequence[sharding_impls.XLACompatibleSharding]: return self._in_shardings @@ -3297,9 +3310,9 @@ def _compile_replicated_mesh_executable_from_hlo( out_avals=global_out_avals, out_shardings=out_shardings, committed=committed, pmap_nreps=pmap_nreps) xla_executable = None - return MeshExecutable(xla_executable, unsafe_call, global_in_avals, + return MeshExecutable(xla_executable, lambda: unsafe_call, global_in_avals, in_shardings, out_shardings, auto_spmd_lowering, - kept_var_idx, device_assignment) + kept_var_idx, device_assignment, None) def _compile_replicated_mesh_executable_from_trivial_jaxpr( @@ -3320,9 +3333,9 @@ def _compile_replicated_mesh_executable_from_trivial_jaxpr( in_indices=input_indices, in_shardings=in_shardings, kept_var_idx=kept_var_idx, out_handler=handle_outs, out_shardings=out_shardings, pmap_nreps=pmap_nreps) - return MeshExecutable(None, unsafe_call, global_in_avals, in_shardings, - out_shardings, False, kept_var_idx, - device_assignment) + return MeshExecutable(None, lambda: unsafe_call, global_in_avals, + in_shardings, out_shardings, False, kept_var_idx, + device_assignment, None) @lru_cache() diff --git a/jax/experimental/serialize_executable.py b/jax/experimental/serialize_executable.py index d7b3a6fbaef7..75b63a09d58b 100644 --- a/jax/experimental/serialize_executable.py +++ b/jax/experimental/serialize_executable.py @@ -21,46 +21,39 @@ from jax._src.lib import xla_client as xc -def compile_and_serialize(lowered: jax.stages.Lowered): - """Compiles a lowered executable, and then serializes the resulting binary. +def serialize(compiled: jax.stages.Compiled): + """Serializes a compiled binary. Because pytrees are not serializable, they are returned so that the user can handle them properly. 
""" - - from jax.interpreters import pxla - - if isinstance(lowered._lowering, pxla.MeshComputation): - kw = dict(_allow_propagation_to_outputs=[ - pxla._is_unspecified(o) - for o in lowered._lowering.compile_args['out_shardings']]) - else: - kw = {} - - unloaded_compilation = lowered._lowering._compile_unloaded(**kw) - args_info_flat, in_tree = jax.tree_util.tree_flatten(lowered.args_info) + unloaded_executable = getattr(compiled._executable, + '_unloaded_executable', None) + if unloaded_executable is None: + raise ValueError("Compilation does not support serialization") + args_info_flat, in_tree = jax.tree_util.tree_flatten(compiled.args_info) with io.BytesIO() as file: _JaxPjrtPickler(file).dump( - (unloaded_compilation, args_info_flat, lowered._no_kwargs)) - return file.getvalue(), in_tree, lowered.out_tree + (unloaded_executable, args_info_flat, compiled._no_kwargs)) + return file.getvalue(), in_tree, compiled.out_tree -def load_compiled(serialized, - in_tree, - out_tree, - backend: Optional[Union[str, xc.Client]] = None): +def deserialize_and_load(serialized, + in_tree, + out_tree, + backend: Optional[Union[str, xc.Client]] = None): """Constructs a jax.stages.Compiled from a serialized executable.""" if backend is None or isinstance(backend, str): backend = jax.devices(backend)[0].client - (unloaded_compilation, args_info_flat, + (unloaded_executable, args_info_flat, no_kwargs) = _JaxPjrtUnpickler(io.BytesIO(serialized), backend).load() args_info = in_tree.unflatten(args_info_flat) - loaded_compiled_obj = unloaded_compilation.load() + loaded_compiled_obj = unloaded_executable.load() return jax.stages.Compiled( loaded_compiled_obj, args_info, out_tree, no_kwargs=no_kwargs) diff --git a/tests/array_test.py b/tests/array_test.py index a04e2dae546f..632e15f843f7 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -32,7 +32,7 @@ from jax.interpreters import pxla from jax.experimental.pjit import pjit from jax.experimental.serialize_executable import ( - 
compile_and_serialize, load_compiled) + serialize, deserialize_and_load) from jax.experimental import multihost_utils from jax.sharding import PartitionSpec as P from jax._src import array @@ -1033,8 +1033,8 @@ def fun(x): ).lower(core.ShapedArray(shape=(8, 8), dtype=np.float32)) def verify_serialization(lowered): - serialized, in_tree, out_tree = compile_and_serialize(lowered) - compiled = load_compiled(serialized, in_tree, out_tree) + serialized, in_tree, out_tree = serialize(lowered.compile()) + compiled = deserialize_and_load(serialized, in_tree, out_tree) self.assertEqual(compiled.as_text(), lowered.compile().as_text()) verify_serialization(lowered) From 6fee63c1879e302ce62c0e0038ab58cfd09d6656 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Tue, 21 Mar 2023 21:43:20 -0700 Subject: [PATCH 29/65] enable pjit recursive typechecking Give pjit_p a custom typecheck rule, which basically just calls the core._check_call utility (which was made for xla_call_p and core.call_p). This revealed the need for a slight generalization of the custom_typecheck rule signature, for better "context-aware" printing of jaxpr type errors: the rules should have a `ctx_factory` first argument. **The reason this PR touches so many files is just that it makes the trivial tweaks to all existing typecheck rules to accomodate that new signature.** I didn't adapt any other higher-order primitives' rules to actually use the context, but presumably errors for HOPs like scan would be improved by using it. Follow-up work! It's key that core._check_call works with dynamic shapes; this PR is soon to be followed by some djax+pjit PRs! 
--- jax/_src/core.py | 5 +++-- jax/_src/custom_derivatives.py | 4 ++-- jax/_src/custom_transpose.py | 2 +- jax/_src/lax/control_flow/conditionals.py | 8 +++++--- jax/_src/lax/control_flow/loops.py | 8 +++++--- jax/_src/lax/lax.py | 6 +++--- jax/_src/lax/slicing.py | 2 +- jax/_src/maps.py | 2 +- jax/_src/pjit.py | 6 +++++- jax/experimental/shard_map.py | 2 +- tests/core_test.py | 9 +++++++++ 11 files changed, 36 insertions(+), 18 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index 95d2f38e4d4a..68f131697382 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -2627,7 +2627,7 @@ class JaxprTypeError(TypeError): pass custom_typechecks: Dict[Primitive, Callable] = {} -def _check_closed_call(*in_atoms, call_jaxpr): +def _check_closed_call(_, *in_atoms, call_jaxpr): in_avals = [x.aval for x in in_atoms] if list(in_avals) != list(call_jaxpr.in_avals): raise JaxprTypeError("Closed call in_avals mismatch") @@ -2726,7 +2726,8 @@ def write(v: Var, a: AbstractValue) -> None: # Compute the type of the primitive application. if prim in custom_typechecks: - out_type, eqn_effects = custom_typechecks[prim](*in_atoms, **eqn.params) + out_type, eqn_effects = custom_typechecks[prim]( + ctx_factory, *in_atoms, **eqn.params) elif prim.call_primitive: out_type, eqn_effects = _check_call(ctx_factory, prim, in_atoms, eqn.params) diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index e93132a7806f..d0087b5a9d9c 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -403,8 +403,8 @@ def process_env_traces(primitive, level: int, jvp_was_run: bool, *args): custom_jvp_call_p = CustomJVPCallPrimitive('custom_jvp_call') -def _custom_jvp_call_typecheck(*in_avals, call_jaxpr, jvp_jaxpr_thunk, num_consts, - symbolic_zeros): +def _custom_jvp_call_typecheck(_, *in_avals, call_jaxpr, jvp_jaxpr_thunk, + num_consts, symbolic_zeros): # TODO(mattjj): could do more checking here... 
del in_avals, jvp_jaxpr_thunk, num_consts disallowed_effects = allowed_effects.filter_not_in(call_jaxpr.effects) diff --git a/jax/_src/custom_transpose.py b/jax/_src/custom_transpose.py index c70034b551c5..8bb21f4f21bc 100644 --- a/jax/_src/custom_transpose.py +++ b/jax/_src/custom_transpose.py @@ -179,7 +179,7 @@ def get_bind_params(self, params): # TODO(frostig,mattjj): reinstate checks -def custom_transpose_typecheck(*in_atoms, out_types, **params): +def custom_transpose_typecheck(_, *in_atoms, out_types, **params): del in_atoms, params return out_types, core.no_effects diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index 649079cd9e5a..ce748fc4788c 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -733,7 +733,9 @@ def _cond_axis_substitution(params, subst, traverse): branches = tuple(core.subst_axis_names_jaxpr(jaxpr, subst) for jaxpr in params['branches']) return dict(params, branches=branches) -def _cond_typecheck(*in_atoms, branches, linear): +def _cond_typecheck(bind_time, *in_atoms, branches, linear): + if not bind_time: + _, *in_atoms = in_atoms avals = [x.aval for x in in_atoms] tc = partial(_typecheck_param, 'cond') tc(branches, 'branches', 'tuple of ClosedJaxpr', @@ -794,7 +796,7 @@ def cond_bind(*args, branches, linear): if config.jax_enable_checks: avals = map(core.get_aval, args) in_atoms = [core.Var(0, '', a) for a in avals] # dummies - _cond_typecheck(*in_atoms, branches=branches, linear=linear) + _cond_typecheck(True, *in_atoms, branches=branches, linear=linear) for jaxpr in branches: core.check_jaxpr(jaxpr.jaxpr) return core.AxisPrimitive.bind(cond_p, *args, branches=branches, linear=linear) @@ -810,7 +812,7 @@ def cond_bind(*args, branches, linear): batching.spmd_axis_primitive_batchers[cond_p] = _cond_batching_rule batching.axis_primitive_batchers[cond_p] = partial(_cond_batching_rule, None) xla.register_initial_style_primitive(cond_p) 
-core.custom_typechecks[cond_p] = _cond_typecheck +core.custom_typechecks[cond_p] = partial(_cond_typecheck, False) core.axis_substitution_rules[cond_p] = _cond_axis_substitution pe.partial_eval_jaxpr_custom_rules[cond_p] = _cond_partial_eval_custom pe.dce_rules[cond_p] = _cond_dce_rule diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 98eb20d25c4f..79c92873bab9 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -908,8 +908,10 @@ def known(*ins_known): new_vars = [*new_inst, *intensive_res, *extensive_res] return eqn_known, eqn_staged, unks_out, inst_out, new_vars -def _scan_typecheck(bind_time, *in_atoms, reverse, length, num_consts, num_carry, - jaxpr, linear, unroll): +def _scan_typecheck(bind_time, *in_atoms, reverse, length, num_consts, + num_carry, jaxpr, linear, unroll): + if not bind_time: + _, *in_atoms = in_atoms avals = [x.aval for x in in_atoms] tc = partial(_typecheck_param, 'scan') tc(reverse, 'reverse', 'bool', type(reverse) is bool) @@ -1546,7 +1548,7 @@ def fun(*args): ctx.set_tokens_out(mlir.TokenSet(zip(body_effects, tokens))) return z -def _while_typecheck(*in_atoms, cond_jaxpr, body_jaxpr, cond_nconsts, +def _while_typecheck(_, *in_atoms, cond_jaxpr, body_jaxpr, cond_nconsts, body_nconsts): # TODO(frostig,mattjj): check cond_jaxpr, body_jaxpr types joined_effects = _join_while_effects(body_jaxpr, cond_jaxpr, body_nconsts, diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index b95c8b5f91d4..327f64edce53 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -2812,7 +2812,7 @@ def _broadcast_in_dim_shape_rule(operand, *, shape, broadcast_dimensions): return shape def _broadcast_in_dim_typecheck_rule( - operand, *dyn_shape, shape, broadcast_dimensions): + _, operand, *dyn_shape, shape, broadcast_dimensions): if not dyn_shape: out_aval, effects = broadcast_in_dim_p.abstract_eval( operand.aval, shape=shape, broadcast_dimensions=broadcast_dimensions) @@ 
-3271,7 +3271,7 @@ def _reshape_shape_rule(operand, *, new_sizes, dimensions): raise TypeError(msg.format(dimensions, np.shape(operand))) return tuple(new_sizes) -def _reshape_typecheck_rule(operand, *dyn_shape, new_sizes, dimensions): +def _reshape_typecheck_rule(_, operand, *dyn_shape, new_sizes, dimensions): if not dyn_shape: out_aval, effects = reshape_p.abstract_eval( operand.aval, new_sizes=new_sizes, dimensions=dimensions) @@ -4506,7 +4506,7 @@ def _iota_staging_rule(trace, *dyn_shape, dtype, shape, dimension): return _dyn_shape_staging_rule(trace, iota_p, aval, *dyn_shape, **params) pe.custom_staging_rules[iota_p] = _iota_staging_rule -def _iota_typecheck_rule(*dyn_shape, dtype, shape, dimension): +def _iota_typecheck_rule(_, *dyn_shape, dtype, shape, dimension): if not dyn_shape: out_aval, effects = iota_p.abstract_eval( dtype=dtype, shape=shape, dimension=dimension) diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index 719322bf23a2..1e8e32e49687 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -916,7 +916,7 @@ def _dynamic_slice_staging_rule(trace, x, *starts_and_dyn_sizes, slice_sizes): *starts_and_dyn_sizes, slice_sizes=slice_sizes) -def _dynamic_slice_typecheck_rule(x, *starts_and_dyn_sizes, slice_sizes): +def _dynamic_slice_typecheck_rule(_, x, *starts_and_dyn_sizes, slice_sizes): start_indices, dyn = util.split_list(starts_and_dyn_sizes, [x.aval.ndim]) if not dyn: out_aval, effects = dynamic_slice_p.abstract_eval( diff --git a/jax/_src/maps.py b/jax/_src/maps.py index 48ddb8c4272b..a16d881cd9aa 100644 --- a/jax/_src/maps.py +++ b/jax/_src/maps.py @@ -889,7 +889,7 @@ def unmap_zero(zero, axes): def _typecheck_xmap( - *in_atoms, call_jaxpr, name, in_axes, out_axes, donated_invars, + _, *in_atoms, call_jaxpr, name, in_axes, out_axes, donated_invars, global_axis_sizes, axis_resources, resource_env, backend, spmd_in_axes, spmd_out_axes): in_avals = [x.aval for x in in_atoms] diff --git a/jax/_src/pjit.py 
b/jax/_src/pjit.py index 20a5ab0c7249..490de581b80c 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -1343,9 +1343,13 @@ def pjit_staging_rule(trace, *args, **params): return core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args) else: return trace.default_process_primitive(pjit_p, args, params) - pe.custom_staging_rules[pjit_p] = pjit_staging_rule +def _pjit_typecheck(ctx_factory, *in_atoms, jaxpr, **params): + return core._check_call(ctx_factory, pjit_p, in_atoms, + dict(params, call_jaxpr=jaxpr.jaxpr)) +core.custom_typechecks[pjit_p] = _pjit_typecheck + def _pjit_abstract_eval(*args, jaxpr, out_shardings, resource_env, **_): return jaxpr.out_avals, jaxpr.effects diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 36ce1b6f0cff..b8fc67fd27ed 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -402,7 +402,7 @@ def _unshard_aval(mesh: Mesh, names: AxisNames, aval: core.AbstractValue # Type-checking -def _shard_map_typecheck(*in_atoms, jaxpr, mesh, in_names, out_names, +def _shard_map_typecheck(_, *in_atoms, jaxpr, mesh, in_names, out_names, check_rep): for v, x, in_name in zip(jaxpr.invars, in_atoms, in_names): if not core.typecompat(v.aval, _shard_aval(mesh, in_name, x.aval)): diff --git a/tests/core_test.py b/tests/core_test.py index bd8b6ead219c..b6fa3492028f 100644 --- a/tests/core_test.py +++ b/tests/core_test.py @@ -428,6 +428,15 @@ def test_check_jaxpr_cond_correct(self): jaxpr = make_jaxpr(lambda x: lax.switch(0, [jnp.sin, jnp.cos], x))(1.).jaxpr core.check_jaxpr(jaxpr) + def test_check_jaxpr_jit_invalid(self): + jaxpr = make_jaxpr(jax.jit(lambda x, y: x + 1))(1., 2.).jaxpr + pjit_eqn, = jaxpr.eqns + jaxpr._eqns[0] = pjit_eqn._replace(invars=()) + self.assertRaisesRegex( + core.JaxprTypeError, + '0 operands cannot call jaxpr with 2 inputs', + lambda: core.check_jaxpr(jaxpr)) + def test_check_jaxpr_cond_invalid(self): jaxpr = make_jaxpr(lambda x: lax.switch(0, [jnp.sin, jnp.cos], x))(1.).jaxpr cond 
= next(eqn for eqn in jaxpr.eqns if eqn.primitive.name == 'cond') From dd9f1783787a064f3bc676a50810ba021b43a5cd Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 22 Mar 2023 22:06:22 -0700 Subject: [PATCH 30/65] fix typo: "one of more" -> "one or more" PiperOrigin-RevId: 518762341 --- jax/experimental/jax2tf/examples/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/experimental/jax2tf/examples/README.md b/jax/experimental/jax2tf/examples/README.md index 83f18aff21d4..b049798e7e15 100644 --- a/jax/experimental/jax2tf/examples/README.md +++ b/jax/experimental/jax2tf/examples/README.md @@ -126,7 +126,7 @@ following sequence of steps: * train an MNIST model, and obtain a pair of an inference function and the parameters. - * convert the inference function with jax2tf, for one of more batch sizes. + * convert the inference function with jax2tf, for one or more batch sizes. * save a SavedModel and dump its contents. * reload the SavedModel and run it with TensorFlow to test that the inference function produces the same results as the JAX inference function. 
From b0e0a94ea212539c37092cc42627f1fef9d15cce Mon Sep 17 00:00:00 2001 From: Etienne Pot Date: Thu, 23 Mar 2023 02:31:27 -0700 Subject: [PATCH 31/65] Fix isinstance(k, PRNGKeyArray) on PRNGKeyArray subclasses PiperOrigin-RevId: 518803946 --- jax/_src/prng.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 8185ca4046d4..730eb9d3c8a7 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -116,12 +116,12 @@ def _check_prng_key_data(impl, key_data: jax.Array): class PRNGKeyArrayMeta(abc.ABCMeta): """Metaclass for overriding PRNGKeyArray isinstance checks.""" - def __instancecheck__(self, instance): + def __instancecheck__(cls, instance): try: return (isinstance(instance.aval, core.ShapedArray) and type(instance.aval.dtype) is KeyTy) except AttributeError: - super().__instancecheck__(instance) + return super().__instancecheck__(instance) class PRNGKeyArray(metaclass=PRNGKeyArrayMeta): From a9b43102462eb369b71363a75b2469ef426fb948 Mon Sep 17 00:00:00 2001 From: George Necula Date: Thu, 23 Mar 2023 03:16:07 -0700 Subject: [PATCH 32/65] [jax2tf] Fix tests broken by upgrade of XlaCallModule PiperOrigin-RevId: 518811580 --- jax/experimental/jax2tf/tests/jax2tf_test.py | 50 +++++--------------- 1 file changed, 13 insertions(+), 37 deletions(-) diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py index 79fa468b2694..2df6d4a61098 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_test.py +++ b/jax/experimental/jax2tf/tests/jax2tf_test.py @@ -1554,7 +1554,7 @@ def get_serialized_computation( abstracted_axes: Optional[Tuple[Dict[int, str]]] = None, use_pjit: bool = False, in_shardings = None, - out_shardings = None) -> str: + out_shardings = None) -> Tuple[str, int]: if use_pjit: assert not abstracted_axes lowered = pjit.pjit( @@ -1564,7 +1564,7 @@ def get_serialized_computation( lowered = jax.jit(f_jax, abstracted_axes=abstracted_axes).lower(*args) 
stablehlo_module_text = mlir.module_to_string(lowered._lowering.stablehlo()) logging.info(f'Serialized ir.Module = {stablehlo_module_text}') - return stablehlo_module_text + return stablehlo_module_text, 3 class XlaCallModuleTest(tf_test_util.JaxToTfTestCase): @@ -1577,9 +1577,10 @@ def f_jax(x): x = np.ones((2, 3), dtype=np.float32) jax_res = f_jax(x) + module, version = get_serialized_computation(f_jax, x) res = tfxla.call_module([x], - version=2, - module=get_serialized_computation(f_jax, x), + version=version, + module=module, Tout=[jax_res.dtype], Sout=[jax_res.shape]) self.assertAllClose(tf.nest.map_structure(lambda t: t.numpy(), res), @@ -1595,9 +1596,10 @@ def f_jax(count, x): x = np.ones((2, 3), dtype=np.float32) jax_res = f_jax(count, x) + module, version = get_serialized_computation(f_jax, count, x) res = tfxla.call_module([count, x], - version=2, - module=get_serialized_computation(f_jax, count, x), + version=version, + module=module, Tout=[jax_res.dtype], Sout=[jax_res.shape]) self.assertAllClose(tf.nest.map_structure(lambda t: t.numpy(), res), @@ -1612,11 +1614,11 @@ def f_jax(x1, x2): x2 = np.ones((3, 4), dtype=np.float32) jax_res = f_jax(x1, x2) - + module, version = get_serialized_computation(f_jax, x1, x2) def f_tf(x1_tf, x2_tf): return tfxla.call_module([x1_tf, x2_tf], - version=2, - module=get_serialized_computation(f_jax, x1, x2), + version=version, + module=module, Tout=[jax_res[0].dtype, jax_res[1].dtype], Sout=[jax_res[0].shape, jax_res[1].shape]) @@ -1624,32 +1626,6 @@ def f_tf(x1_tf, x2_tf): self.assertAllClose(tf.nest.map_structure(lambda t: t.numpy(), res), jax_res) - @unittest.skip("TODO(necula): 'dynamic_iota' op can't be translated to XLA HLO") - def test_shape_poly_arange(self): - if not config.jax_dynamic_shapes: - raise unittest.SkipTest("jax_dynamic_shapes must be enabled") - def f_jax(x): # x: f32[b] - return jnp.arange(x.shape[0]) + x - - x1 = np.ones((5,), dtype=np.float32) - jax_res = f_jax(x1) - - def f_tf(x1_tf): - return 
tfxla.call_module([x1_tf], - version=2, - module=get_serialized_computation( - f_jax, x1, - abstracted_axes=({ - 0: 'b' - },)), - Tout=[jax_res.dtype], - Sout=[jax_res.shape], - dim_args_spec=('0.0',)) - - res = tf.function(f_tf, jit_compile=True, autograph=False)(x1) - self.assertAllClose( - tf.nest.map_structure(lambda t: t.numpy(), res), jax_res) - @jtu.with_mesh([("x", 2)]) def test_pjit_basic1D(self): @@ -1665,7 +1641,7 @@ def func_jax(x, y): in_shardings=in_axis_resources, out_shardings=out_axis_resources, )(x, x) - module = get_serialized_computation( + module, version = get_serialized_computation( func_jax, x, x, @@ -1675,7 +1651,7 @@ def func_jax(x, y): def f_tf(x_tf, y_tf): return tfxla.call_module([x_tf, y_tf], - version=2, + version=version, module=module, Tout=[x.dtype], Sout=[x.shape]) From 8f6e3c462a1855b036544762a45eab2c5d6e3dd6 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 23 Mar 2023 05:12:57 -0700 Subject: [PATCH 33/65] [XLA:Python] Change JAX and the XLA Python extension to get NumPy bfloat16/float8 types from ml_dtypes. 
PiperOrigin-RevId: 518830467 --- .pre-commit-config.yaml | 2 +- CHANGELOG.md | 10 +++ jax/_src/dtypes.py | 123 +++++++++++++--------------------- jax/_src/interpreters/mlir.py | 2 +- jaxlib/setup.py | 2 +- setup.py | 1 + 6 files changed, 60 insertions(+), 80 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index fe37847a9752..4210995e6b13 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -19,7 +19,7 @@ repos: - id: mypy files: (jax/|tests/typing_test\.py) exclude: jax/_src/basearray.py # Use pyi instead - additional_dependencies: [types-requests==2.28.11, jaxlib==0.4.6] + additional_dependencies: [types-requests==2.28.11, jaxlib==0.4.6, ml_dtypes==0.0.3] - repo: https://github.com/mwouts/jupytext rev: v1.14.4 diff --git a/CHANGELOG.md b/CHANGELOG.md index 3a4c88ed7cd9..16b1e6be4a77 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,10 @@ Remember to align the itemized text with the first line of an item within a list See [documentation](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md). As part of this change the config flag `--jax2tf_default_experimental_native_lowering` has been renamed to `--jax2tf_native_serialization`. + * JAX now depends on `ml_dtypes`, which contains definitions of NumPy types + like bfloat16. These definitions were previously internal to JAX, but have + been split into a separate package to facilitate sharing them with other + projects. * Deprecations * The type `jax.numpy.DeviceArray` is deprecated. Use `jax.Array` instead, @@ -33,6 +37,12 @@ Remember to align the itemized text with the first line of an item within a list ## jaxlib 0.4.7 +Changes: + * jaxlib now depends on `ml_dtypes`, which contains definitions of NumPy types + like bfloat16. These definitions were previously internal to JAX, but have + been split into a separate package to facilitate sharing them with other + projects. 
+ ## jax 0.4.6 (Mar 9, 2023) * Changes diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index 2637926b381c..993676035a0a 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -22,13 +22,14 @@ import builtins import functools -from typing import cast, overload, Any, Dict, List, Literal, Optional, Set, Tuple, Union +from typing import (cast, overload, Any, Dict, List, Literal, Optional, Set, + Tuple, Type, Union) import warnings +import ml_dtypes import numpy as np from jax._src.config import flags, config -from jax._src.lib import xla_client from jax._src.typing import DType, DTypeLike, OpaqueDType from jax._src import traceback_util @@ -37,15 +38,13 @@ FLAGS = flags.FLAGS # fp8 support -_fp8_enabled = xla_client._version >= 117 -if _fp8_enabled: - float8_e4m3fn: type = xla_client.float8_e4m3fn # pytype: disable=annotation-type-mismatch # typed-numpy - float8_e5m2: type = xla_client.float8_e5m2 # pytype: disable=annotation-type-mismatch # typed-numpy - _float8_e4m3fn_dtype: np.dtype = np.dtype(float8_e4m3fn) - _float8_e5m2_dtype: np.dtype = np.dtype(float8_e5m2) +float8_e4m3fn: Type[np.generic] = ml_dtypes.float8_e4m3fn +float8_e5m2: Type[np.generic] = ml_dtypes.float8_e5m2 +_float8_e4m3fn_dtype: np.dtype = np.dtype(float8_e4m3fn) +_float8_e5m2_dtype: np.dtype = np.dtype(float8_e5m2) # bfloat16 support -bfloat16: type = xla_client.bfloat16 # pytype: disable=annotation-type-mismatch # typed-numpy +bfloat16: Type[np.generic] = ml_dtypes.bfloat16 _bfloat16_dtype: np.dtype = np.dtype(bfloat16) # Default types. 
@@ -139,9 +138,9 @@ def scalar_type_of(x: Any) -> type: typ = dtype(x) if typ == bfloat16: return float - elif _fp8_enabled and typ == float8_e4m3fn: + elif typ == float8_e4m3fn: return float - elif _fp8_enabled and typ == float8_e5m2: + elif typ == float8_e5m2: return float elif np.issubdtype(typ, np.bool_): return bool @@ -352,15 +351,14 @@ def __new__(cls, dtype): if _bfloat16_dtype not in cls._finfo_cache: cls._finfo_cache[_bfloat16_dtype] = cls._bfloat16_finfo() return cls._finfo_cache[_bfloat16_dtype] - if _fp8_enabled: - if isinstance(dtype, str) and dtype == 'float8_e4m3fn' or dtype == _float8_e4m3fn_dtype: - if _float8_e4m3fn_dtype not in cls._finfo_cache: - cls._finfo_cache[_float8_e4m3fn_dtype] = cls._float8_e4m3fn_finfo() - return cls._finfo_cache[_float8_e4m3fn_dtype] - if isinstance(dtype, str) and dtype == 'float8_e5m2' or dtype == _float8_e5m2_dtype: - if _float8_e5m2_dtype not in cls._finfo_cache: - cls._finfo_cache[_float8_e5m2_dtype] = cls._float8_e5m2_finfo() - return cls._finfo_cache[_float8_e5m2_dtype] + if isinstance(dtype, str) and dtype == 'float8_e4m3fn' or dtype == _float8_e4m3fn_dtype: + if _float8_e4m3fn_dtype not in cls._finfo_cache: + cls._finfo_cache[_float8_e4m3fn_dtype] = cls._float8_e4m3fn_finfo() + return cls._finfo_cache[_float8_e4m3fn_dtype] + if isinstance(dtype, str) and dtype == 'float8_e5m2' or dtype == _float8_e5m2_dtype: + if _float8_e5m2_dtype not in cls._finfo_cache: + cls._finfo_cache[_float8_e5m2_dtype] = cls._float8_e5m2_finfo() + return cls._finfo_cache[_float8_e5m2_dtype] return super().__new__(cls, dtype) def _issubclass(a: Any, b: Any) -> bool: @@ -380,21 +378,20 @@ def issubdtype(a: DTypeLike, b: DTypeLike) -> bool: This is like :func:`numpy.issubdtype`, but can handle dtype extensions such as :obj:`jax.dtypes.bfloat16`. 
'""" - if _fp8_enabled: - if a == "float8_e4m3fn": - a = float8_e4m3fn - if a == float8_e4m3fn: - if isinstance(b, np.dtype): - return b == _float8_e4m3fn_dtype - else: - return b in [float8_e4m3fn, np.floating, np.inexact, np.number] - if a == "float8_e5m2": - a = float8_e5m2 - if a == float8_e5m2: - if isinstance(b, np.dtype): - return b == _float8_e5m2_dtype - else: - return b in [float8_e5m2, np.floating, np.inexact, np.number] + if a == "float8_e4m3fn": + a = float8_e4m3fn + if a == float8_e4m3fn: + if isinstance(b, np.dtype): + return b == _float8_e4m3fn_dtype + else: + return b in [float8_e4m3fn, np.floating, np.inexact, np.number] + if a == "float8_e5m2": + a = float8_e5m2 + if a == float8_e5m2: + if isinstance(b, np.dtype): + return b == _float8_e5m2_dtype + else: + return b in [float8_e5m2, np.floating, np.inexact, np.number] if a == "bfloat16": a = bfloat16 if a == bfloat16: @@ -432,22 +429,14 @@ def issubdtype(a: DTypeLike, b: DTypeLike) -> bool: np.dtype('int64'), ] _float_types: List[JAXType] -if _fp8_enabled: - _float_types = [ - np.dtype(float8_e4m3fn), - np.dtype(float8_e5m2), - np.dtype(bfloat16), - np.dtype('float16'), - np.dtype('float32'), - np.dtype('float64'), - ] -else: - _float_types = [ - np.dtype(bfloat16), - np.dtype('float16'), - np.dtype('float32'), - np.dtype('float64'), - ] +_float_types = [ + np.dtype(float8_e4m3fn), + np.dtype(float8_e5m2), + np.dtype(bfloat16), + np.dtype('float16'), + np.dtype('float32'), + np.dtype('float64'), +] _complex_types: List[JAXType] = [ np.dtype('complex64'), np.dtype('complex128'), @@ -460,9 +449,7 @@ def _jax_type(dtype: DType, weak_type: bool) -> JAXType: if weak_type: if dtype == bool: return dtype - if _fp8_enabled and dtype in [_float8_e4m3fn_dtype, _float8_e5m2_dtype]: - return float - if dtype == _bfloat16_dtype: + if dtype in [_float8_e4m3fn_dtype, _float8_e5m2_dtype, _bfloat16_dtype]: return float return type(dtype.type(0).item()) return dtype @@ -478,26 +465,15 @@ def 
_type_promotion_lattice(jax_numpy_dtype_promotion: str) -> Dict[JAXType, Lis """ b1, = _bool_types u1, u2, u4, u8, i1, i2, i4, i8 = _int_types - if _fp8_enabled: - f1_e4m3fn, f1_e5m2, bf, f2, f4, f8 = _float_types # pytype: disable=bad-unpacking - else: - bf, f2, f4, f8 = _float_types # pytype: disable=bad-unpacking + f1_e4m3fn, f1_e5m2, bf, f2, f4, f8 = _float_types c4, c8 = _complex_types i_, f_, c_ = _weak_types if jax_numpy_dtype_promotion == 'standard': - if _fp8_enabled: - return { - b1: [i_], - u1: [i2, u2], u2: [i4, u4], u4: [i8, u8], u8: [f_], - i_: [u1, i1], i1: [i2], i2: [i4], i4: [i8], i8: [f_], - f_: [f1_e4m3fn, f1_e5m2, bf, f2, c_], f1_e4m3fn: [], f1_e5m2: [], bf: [f4], f2: [f4], f4: [f8, c4], f8: [c8], - c_: [c4], c4: [c8], c8: [], - } return { b1: [i_], u1: [i2, u2], u2: [i4, u4], u4: [i8, u8], u8: [f_], i_: [u1, i1], i1: [i2], i2: [i4], i4: [i8], i8: [f_], - f_: [bf, f2, c_], bf: [f4], f2: [f4], f4: [f8, c4], f8: [c8], + f_: [f1_e4m3fn, f1_e5m2, bf, f2, c_], f1_e4m3fn: [], f1_e5m2: [], bf: [f4], f2: [f4], f4: [f8, c4], f8: [c8], c_: [c4], c4: [c8], c8: [], } elif jax_numpy_dtype_promotion == 'strict': @@ -684,12 +660,8 @@ def result_type(*args: Any, return_weak_type_flag: bool = False) -> Union[DType, raise ValueError("at least one array or dtype is required") dtype, weak_type = _lattice_result_type(*(float_ if arg is None else arg for arg in args)) if weak_type: - if _fp8_enabled: - dtype = canonicalize_dtype( - _default_types['f' if dtype in [_float8_e4m3fn_dtype, _float8_e5m2_dtype, _bfloat16_dtype] else dtype.kind]) - else: - dtype = canonicalize_dtype( - _default_types['f' if dtype == _bfloat16_dtype else dtype.kind]) + dtype = canonicalize_dtype( + _default_types['f' if dtype in [_float8_e4m3fn_dtype, _float8_e5m2_dtype, _bfloat16_dtype] else dtype.kind]) else: dtype = canonicalize_dtype(dtype) return (dtype, weak_type) if return_weak_type_flag else dtype @@ -699,10 +671,7 @@ def check_user_dtype_supported(dtype, fun_name=None): if 
isinstance(dtype, type) and dtype in {bool, int, float, builtins.complex}: return np_dtype = np.dtype(dtype) - if _fp8_enabled: - is_custom_dtype = np_dtype.type in [float8_e4m3fn, float8_e5m2, bfloat16] - else: - is_custom_dtype = np_dtype.type in [bfloat16] + is_custom_dtype = np_dtype.type in [float8_e4m3fn, float8_e5m2, bfloat16] if np_dtype.kind not in "biufc" and not is_custom_dtype: msg = f"JAX only supports number and bool dtypes, got dtype {dtype}" msg += f" in {fun_name}" if fun_name else "" diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 4b6187ee5b6e..bcbab7767024 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -299,7 +299,7 @@ def _ndarray_constant_handler(val: np.ndarray, canonicalize_types np.float16, np.float32, np.float64, np.complex64, np.complex128, np.bool_, np.longlong, dtypes.bfloat16]: - register_constant_handler(_scalar_type, _ndarray_constant_handler) + register_constant_handler(_scalar_type, _ndarray_constant_handler) # type: ignore def _python_scalar_handler(dtype, val, canonicalize_dtypes): return _numpy_array_constant(np.array(val, dtype), canonicalize_dtypes) diff --git a/jaxlib/setup.py b/jaxlib/setup.py index 7d09dce476c8..1c607dfd3535 100644 --- a/jaxlib/setup.py +++ b/jaxlib/setup.py @@ -46,7 +46,7 @@ def has_ext_modules(self): author_email='jax-dev@google.com', packages=['jaxlib', 'jaxlib.xla_extension'], python_requires='>=3.8', - install_requires=['scipy>=1.5', 'numpy>=1.20'], + install_requires=['scipy>=1.5', 'numpy>=1.20', 'ml_dtypes>=0.0.3'], url='https://github.com/google/jax', license='Apache-2.0', classifiers=[ diff --git a/setup.py b/setup.py index a025aa23fae4..2c8b186e5f05 100644 --- a/setup.py +++ b/setup.py @@ -64,6 +64,7 @@ def generate_proto(source): package_data={'jax': ['py.typed', "*.pyi", "**/*.pyi"]}, python_requires='>=3.8', install_requires=[ + 'ml_dtypes>=0.0.3', 'numpy>=1.20', 'opt_einsum', 'scipy>=1.5', From 
c68c3d312ef68ee5e070353f089794bb245b55e0 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 21 Mar 2023 09:36:29 -0700 Subject: [PATCH 34/65] jnp.mean: fix incorrect return value for large arrays --- jax/_src/numpy/reductions.py | 10 +++++----- tests/lax_numpy_test.py | 7 +++++++ 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index 37d6d3fcd9c6..4363fbc3253f 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -327,7 +327,11 @@ def _mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None, out: None = None, keepdims: bool = False, *, where: Optional[ArrayLike] = None) -> Array: check_arraylike("mean", a) - dtypes.check_user_dtype_supported(dtype, "mean") + if dtype is None: + dtype = dtypes.to_inexact_dtype(dtypes.dtype(a)) + else: + dtypes.check_user_dtype_supported(dtype, "mean") + dtype = dtypes.canonicalize_dtype(dtype) if out is not None: raise NotImplementedError("The 'out' argument to jnp.mean is not supported.") @@ -339,10 +343,6 @@ def _mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None, else: normalizer = sum(_broadcast_to(where, np.shape(a)), axis, dtype=dtype, keepdims=keepdims) - if dtype is None: - dtype = dtypes.to_inexact_dtype(dtypes.dtype(a)) - dtype = dtypes.canonicalize_dtype(dtype) - return lax.div( sum(a, axis, dtype=dtype, keepdims=keepdims, where=where), lax.convert_element_type(normalizer, dtype)) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 782ee78ebcd3..42edd02c51ce 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -5075,6 +5075,13 @@ def testFromString(self): actual = jnp.fromstring(s, sep=',', dtype=int) self.assertArraysEqual(expected, actual) + def testMeanLargeArray(self): + # https://github.com/google/jax/issues/15068 + raise unittest.SkipTest("test is slow, but it passes!") + x = jnp.ones((16, 32, 1280, 4096), dtype='int8') + self.assertEqual(1.0, jnp.mean(x)) + 
self.assertEqual(1.0, jnp.mean(x, where=True)) + # Most grad tests are at the lax level (see lax_test.py), but we add some here # as needed for e.g. particular compound ops of interest. From 75fcc3aa33d955c7e04471dd9c4a202f652143c6 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 23 Mar 2023 07:25:01 -0700 Subject: [PATCH 35/65] [XLA:Python] Allow passing ExecutableBuildOptions to outfeed receiver. Outfeed receiver compiles computations (during shutdown), and if the correct options aren't provided, then it may not be able to do things like find ptxas for CUDA builds. Plumb the executable build options through from Python. PiperOrigin-RevId: 518852909 --- jax/experimental/host_callback.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/jax/experimental/host_callback.py b/jax/experimental/host_callback.py index b3773f010d98..47791ff7c7f4 100644 --- a/jax/experimental/host_callback.py +++ b/jax/experimental/host_callback.py @@ -1980,9 +1980,16 @@ def _initialize_outfeed_receiver( device_repr = ", ".join([str(d) for d in devices_with_outfeed]) logger.debug("Starting outfeed_receiver for %s. 
max_callback_queue_size_bytes=%s", device_repr, max_callback_queue_size_bytes) - _callback_handler_data.receiver = outfeed_receiver_module.start( - _callback_input_received, tuple(clients_with_outfeed), - max_callback_queue_size_bytes) + if jaxlib.xla_extension_version > 143: + # TODO(phawkins): remove type:ignore after minimum jaxlib version bump + _callback_handler_data.receiver = outfeed_receiver_module.start( + _callback_input_received, tuple(clients_with_outfeed), + max_callback_queue_size_bytes, + xb.get_compile_options(1, 1).executable_build_options) # type:ignore + else: + _callback_handler_data.receiver = outfeed_receiver_module.start( + _callback_input_received, tuple(clients_with_outfeed), + max_callback_queue_size_bytes) def exit_handler(): # Prevent logging usage during compilation, gives errors under pytest From 0e2cf94ade137ccd577ad65d1a4d1318e3616e4f Mon Sep 17 00:00:00 2001 From: Frederic Bastien Date: Thu, 23 Mar 2023 07:44:40 -0700 Subject: [PATCH 36/65] Add missing file --- .github/workflows/cat_slurm_logs.py | 45 +++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) create mode 100644 .github/workflows/cat_slurm_logs.py diff --git a/.github/workflows/cat_slurm_logs.py b/.github/workflows/cat_slurm_logs.py new file mode 100644 index 000000000000..0479a4723302 --- /dev/null +++ b/.github/workflows/cat_slurm_logs.py @@ -0,0 +1,45 @@ +# Copyright 2022 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Script used in the nightly-ci-multiprocess-gpu workflow to process logs.""" + +import argparse +import os +from typing import List + +ISSUE_FORMAT = """\ +
Failure summary {name} + +``` +{content} +``` + +
+""" + +def main(logfiles: List[str], outfile: str): + print(f"extracting content of {logfiles}") + print(f"and writing to {outfile}") + with open(outfile, 'w') as f: + for logfile in logfiles: + content = open(logfile).read() + f.write(ISSUE_FORMAT.format(name=os.path.basename(logfile), content=content)) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("logfiles", nargs="+", help="The path to the input logfiles") + parser.add_argument("--outfile", help="The path to the parsed output file to be created.", + default="parsed_logs.txt") + args = parser.parse_args() + main(logfiles=args.logfiles, outfile=args.outfile) From 8f4b8a0e9c32981cb7f0b29ad8801d3573181b63 Mon Sep 17 00:00:00 2001 From: Misha <48orusef@gmail.com> Date: Sun, 12 Mar 2023 06:53:09 +0100 Subject: [PATCH 37/65] Fix loc and scale parameters in scipy.logistic. Add CDF and SF for several distributions. --- docs/jax.scipy.rst | 17 +- jax/_src/scipy/stats/beta.py | 30 +++- jax/_src/scipy/stats/cauchy.py | 44 ++++- jax/_src/scipy/stats/chi2.py | 49 ++++-- jax/_src/scipy/stats/gamma.py | 26 ++- jax/_src/scipy/stats/logistic.py | 35 ++-- jax/_src/scipy/stats/norm.py | 12 ++ jax/_src/scipy/stats/t.py | 1 + jax/scipy/stats/beta.py | 3 + jax/scipy/stats/cauchy.py | 5 + jax/scipy/stats/chi2.py | 3 + jax/scipy/stats/gamma.py | 3 + jax/scipy/stats/norm.py | 2 + tests/scipy_stats_test.py | 269 ++++++++++++++++++++++++++++--- 14 files changed, 453 insertions(+), 46 deletions(-) diff --git a/docs/jax.scipy.rst b/docs/jax.scipy.rst index 5d798f76cceb..cad15d09fc66 100644 --- a/docs/jax.scipy.rst +++ b/docs/jax.scipy.rst @@ -172,6 +172,9 @@ jax.scipy.stats.beta logpdf pdf + cdf + logcdf + sf jax.scipy.stats.betabinom ~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -192,6 +195,11 @@ jax.scipy.stats.cauchy logpdf pdf + cdf + logcdf + sf + isf + ppf jax.scipy.stats.chi2 ~~~~~~~~~~~~~~~~~~~~ @@ -202,7 +210,9 @@ jax.scipy.stats.chi2 logpdf pdf - + cdf + logcdf + sf jax.scipy.stats.dirichlet @@ -232,6 
+242,9 @@ jax.scipy.stats.gamma logpdf pdf + cdf + logcdf + sf jax.scipy.stats.gennorm ~~~~~~~~~~~~~~~~~~~~~~~ @@ -296,6 +309,8 @@ jax.scipy.stats.norm logpdf pdf ppf + sf + isf jax.scipy.stats.pareto ~~~~~~~~~~~~~~~~~~~~~~ diff --git a/jax/_src/scipy/stats/beta.py b/jax/_src/scipy/stats/beta.py index 73ba2725da68..cf8ae2e3b8e9 100644 --- a/jax/_src/scipy/stats/beta.py +++ b/jax/_src/scipy/stats/beta.py @@ -19,7 +19,7 @@ from jax._src.lax.lax import _const as _lax_const from jax._src.numpy.util import _wraps, promote_args_inexact from jax._src.typing import Array, ArrayLike -from jax.scipy.special import betaln, xlogy, xlog1py +from jax.scipy.special import betaln, betainc, xlogy, xlog1py @_wraps(osp_stats.beta.logpdf, update_doc=False) @@ -40,3 +40,31 @@ def logpdf(x: ArrayLike, a: ArrayLike, b: ArrayLike, def pdf(x: ArrayLike, a: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: return lax.exp(logpdf(x, a, b, loc, scale)) + + +@_wraps(osp_stats.beta.cdf, update_doc=False) +def cdf(x: ArrayLike, a: ArrayLike, b: ArrayLike, + loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + x, a, b, loc, scale = promote_args_inexact("beta.cdf", x, a, b, loc, scale) + return betainc( + a, + b, + lax.clamp( + _lax_const(x, 0), + lax.div(lax.sub(x, loc), scale), + _lax_const(x, 1), + ) + ) + + +@_wraps(osp_stats.beta.logcdf, update_doc=False) +def logcdf(x: ArrayLike, a: ArrayLike, b: ArrayLike, + loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + return lax.log(cdf(x, a, b, loc, scale)) + + +@_wraps(osp_stats.beta.sf, update_doc=False) +def sf(x: ArrayLike, a: ArrayLike, b: ArrayLike, + loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + cdf_result = cdf(x, a, b, loc, scale) + return lax.sub(_lax_const(cdf_result, 1), cdf_result) diff --git a/jax/_src/scipy/stats/cauchy.py b/jax/_src/scipy/stats/cauchy.py index 169cc6f85aee..426b1eec0a07 100644 --- a/jax/_src/scipy/stats/cauchy.py +++ b/jax/_src/scipy/stats/cauchy.py @@ -18,8 +18,8 @@ from jax 
import lax from jax._src.lax.lax import _const as _lax_const -from jax._src.numpy.util import _wraps -from jax._src.numpy.util import promote_args_inexact +from jax._src.numpy.util import _wraps, promote_args_inexact +from jax.numpy import arctan from jax._src.typing import Array, ArrayLike @@ -31,6 +31,46 @@ def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: normalize_term = lax.log(lax.mul(pi, scale)) return lax.neg(lax.add(normalize_term, lax.log1p(lax.mul(scaled_x, scaled_x)))) + @_wraps(osp_stats.cauchy.pdf, update_doc=False) def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: return lax.exp(logpdf(x, loc, scale)) + + + +@_wraps(osp_stats.cauchy.cdf, update_doc=False) +def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + x, loc, scale = promote_args_inexact("cauchy.cdf", x, loc, scale) + pi = _lax_const(x, np.pi) + scaled_x = lax.div(lax.sub(x, loc), scale) + return lax.add(_lax_const(x, 0.5), lax.mul(lax.div(_lax_const(x, 1.), pi), arctan(scaled_x))) + + +@_wraps(osp_stats.cauchy.logcdf, update_doc=False) +def logcdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + return lax.log(cdf(x, loc, scale)) + + +@_wraps(osp_stats.cauchy.sf, update_doc=False) +def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + x, = promote_args_inexact("cauchy.sf", x) + cdf_result = cdf(x, loc, scale) + return lax.sub(_lax_const(cdf_result, 1), cdf_result) + + +@_wraps(osp_stats.cauchy.isf, update_doc=False) +def isf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + q, loc, scale = promote_args_inexact("cauchy.isf", q, loc, scale) + pi = _lax_const(q, np.pi) + half_pi = _lax_const(q, np.pi / 2) + unscaled = lax.tan(lax.sub(half_pi, lax.mul(pi, q))) + return lax.add(lax.mul(unscaled, scale), loc) + + +@_wraps(osp_stats.cauchy.ppf, update_doc=False) +def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + q, loc, scale = 
promote_args_inexact("cauchy.ppf", q, loc, scale) + pi = _lax_const(q, np.pi) + half_pi = _lax_const(q, np.pi / 2) + unscaled = lax.tan(lax.sub(lax.mul(pi, q), half_pi)) + return lax.add(lax.mul(unscaled, scale), loc) diff --git a/jax/_src/scipy/stats/chi2.py b/jax/_src/scipy/stats/chi2.py index f82d8de02c23..912f225befe8 100644 --- a/jax/_src/scipy/stats/chi2.py +++ b/jax/_src/scipy/stats/chi2.py @@ -20,23 +20,52 @@ from jax._src.lax.lax import _const as _lax_const from jax._src.numpy.util import _wraps, promote_args_inexact from jax._src.typing import Array, ArrayLike +from jax.scipy.special import gammainc @_wraps(osp_stats.chi2.logpdf, update_doc=False) def logpdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: - x, df, loc, scale = promote_args_inexact("chi2.logpdf", x, df, loc, scale) - one = _lax_const(x, 1) - two = _lax_const(x, 2) - y = lax.div(lax.sub(x, loc), scale) - df_on_two = lax.div(df, two) + x, df, loc, scale = promote_args_inexact("chi2.logpdf", x, df, loc, scale) + one = _lax_const(x, 1) + two = _lax_const(x, 2) + y = lax.div(lax.sub(x, loc), scale) + df_on_two = lax.div(df, two) - kernel = lax.sub(lax.mul(lax.sub(df_on_two, one), lax.log(y)), lax.div(y,two)) + kernel = lax.sub(lax.mul(lax.sub(df_on_two, one), lax.log(y)), lax.div(y,two)) - nrml_cnst = lax.neg(lax.add(lax.lgamma(df_on_two),lax.div(lax.mul(lax.log(two), df),two))) + nrml_cnst = lax.neg(lax.add(lax.lgamma(df_on_two),lax.div(lax.mul(lax.log(two), df),two))) - log_probs = lax.add(lax.sub(nrml_cnst, lax.log(scale)), kernel) - return jnp.where(lax.lt(x, loc), -jnp.inf, log_probs) + log_probs = lax.add(lax.sub(nrml_cnst, lax.log(scale)), kernel) + return jnp.where(lax.lt(x, loc), -jnp.inf, log_probs) @_wraps(osp_stats.chi2.pdf, update_doc=False) def pdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: - return lax.exp(logpdf(x, df, loc, scale)) + return lax.exp(logpdf(x, df, loc, scale)) + + +@_wraps(osp_stats.chi2.cdf, 
update_doc=False) +def cdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + x, df, loc, scale = promote_args_inexact("chi2.cdf", x, df, loc, scale) + two = _lax_const(scale, 2) + return gammainc( + lax.div(df, two), + lax.clamp( + _lax_const(x, 0), + lax.div( + lax.sub(x, loc), + lax.mul(scale, two), + ), + _lax_const(x, jnp.inf), + ), + ) + + +@_wraps(osp_stats.chi2.logcdf, update_doc=False) +def logcdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + return lax.log(cdf(x, df, loc, scale)) + + +@_wraps(osp_stats.chi2.sf, update_doc=False) +def sf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + cdf_result = cdf(x, df, loc, scale) + return lax.sub(_lax_const(cdf_result, 1), cdf_result) diff --git a/jax/_src/scipy/stats/gamma.py b/jax/_src/scipy/stats/gamma.py index ab9f475964fc..dcfb9439a6e1 100644 --- a/jax/_src/scipy/stats/gamma.py +++ b/jax/_src/scipy/stats/gamma.py @@ -19,7 +19,7 @@ from jax._src.lax.lax import _const as _lax_const from jax._src.numpy.util import _wraps, promote_args_inexact from jax._src.typing import Array, ArrayLike -from jax.scipy.special import gammaln, xlogy +from jax.scipy.special import gammaln, xlogy, gammainc, gammaincc @_wraps(osp_stats.gamma.logpdf, update_doc=False) @@ -35,3 +35,27 @@ def logpdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) @_wraps(osp_stats.gamma.pdf, update_doc=False) def pdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: return lax.exp(logpdf(x, a, loc, scale)) + + +@_wraps(osp_stats.gamma.cdf, update_doc=False) +def cdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + x, a, loc, scale = promote_args_inexact("gamma.cdf", x, a, loc, scale) + return gammainc( + a, + lax.clamp( + _lax_const(x, 0), + lax.div(lax.sub(x, loc), scale), + _lax_const(x, jnp.inf), + ) + ) + + +@_wraps(osp_stats.gamma.logcdf, update_doc=False) +def 
logcdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + return lax.log(cdf(x, a, loc, scale)) + + +@_wraps(osp_stats.gamma.sf, update_doc=False) +def sf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + x, a, loc, scale = promote_args_inexact("gamma.sf", x, a, loc, scale) + return gammaincc(a, lax.div(lax.sub(x, loc), scale)) diff --git a/jax/_src/scipy/stats/logistic.py b/jax/_src/scipy/stats/logistic.py index effdf21b70b6..67901e83fb7d 100644 --- a/jax/_src/scipy/stats/logistic.py +++ b/jax/_src/scipy/stats/logistic.py @@ -23,29 +23,38 @@ @_wraps(osp_stats.logistic.logpdf, update_doc=False) -def logpdf(x: ArrayLike) -> Array: - x, = promote_args_inexact("logistic.logpdf", x) +def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + x, loc, scale = promote_args_inexact("logistic.logpdf", x, loc, scale) + x = lax.div(lax.sub(x, loc), scale) two = _lax_const(x, 2) half_x = lax.div(x, two) - return lax.mul(lax.neg(two), jnp.logaddexp(half_x, lax.neg(half_x))) + return lax.sub(lax.mul(lax.neg(two), jnp.logaddexp(half_x, lax.neg(half_x))), lax.log(scale)) @_wraps(osp_stats.logistic.pdf, update_doc=False) -def pdf(x: ArrayLike) -> Array: - return lax.exp(logpdf(x)) +def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + return lax.exp(logpdf(x, loc, scale)) + @_wraps(osp_stats.logistic.ppf, update_doc=False) -def ppf(x): - return logit(x) +def ppf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + x, loc, scale = promote_args_inexact("logistic.ppf", x, loc, scale) + return lax.add(lax.mul(logit(x), scale), loc) + @_wraps(osp_stats.logistic.sf, update_doc=False) -def sf(x): - return expit(lax.neg(x)) +def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + x, loc, scale = promote_args_inexact("logistic.sf", x, loc, scale) + return expit(lax.neg(lax.div(lax.sub(x, loc), scale))) + @_wraps(osp_stats.logistic.isf, update_doc=False) 
-def isf(x): - return -logit(x) +def isf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + x, loc, scale = promote_args_inexact("logistic.isf", x, loc, scale) + return lax.add(lax.mul(lax.neg(logit(x)), scale), loc) + @_wraps(osp_stats.logistic.cdf, update_doc=False) -def cdf(x): - return expit(x) +def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + x, loc, scale = promote_args_inexact("logistic.cdf", x, loc, scale) + return expit(lax.div(lax.sub(x, loc), scale)) diff --git a/jax/_src/scipy/stats/norm.py b/jax/_src/scipy/stats/norm.py index 011b566de20f..4c72bcac5654 100644 --- a/jax/_src/scipy/stats/norm.py +++ b/jax/_src/scipy/stats/norm.py @@ -24,6 +24,7 @@ from jax._src.typing import Array, ArrayLike from jax.scipy import special + @_wraps(osp_stats.norm.logpdf, update_doc=False) def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: x, loc, scale = promote_args_inexact("norm.logpdf", x, loc, scale) @@ -54,3 +55,14 @@ def logcdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: @_wraps(osp_stats.norm.ppf, update_doc=False) def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: return jnp.asarray(special.ndtri(q) * scale + loc, float) + + +@_wraps(osp_stats.norm.sf, update_doc=False) +def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + cdf_result = cdf(x, loc, scale) + return lax.sub(_lax_const(cdf_result, 1), cdf_result) + + +@_wraps(osp_stats.norm.isf, update_doc=False) +def isf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + return ppf(lax.sub(_lax_const(q, 1), q), loc, scale) diff --git a/jax/_src/scipy/stats/t.py b/jax/_src/scipy/stats/t.py index efc615554bc2..5a54f2bf5578 100644 --- a/jax/_src/scipy/stats/t.py +++ b/jax/_src/scipy/stats/t.py @@ -36,6 +36,7 @@ def logpdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1 quadratic = lax.div(lax.mul(scaled_x, scaled_x), df) return 
lax.neg(lax.add(normalize_term, lax.mul(df_plus_one_over_two, lax.log1p(quadratic)))) + @_wraps(osp_stats.t.pdf, update_doc=False) def pdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: return lax.exp(logpdf(x, df, loc, scale)) diff --git a/jax/scipy/stats/beta.py b/jax/scipy/stats/beta.py index 30653f0c3c92..963181fa0226 100644 --- a/jax/scipy/stats/beta.py +++ b/jax/scipy/stats/beta.py @@ -18,4 +18,7 @@ from jax._src.scipy.stats.beta import ( logpdf as logpdf, pdf as pdf, + cdf as cdf, + logcdf as logcdf, + sf as sf, ) diff --git a/jax/scipy/stats/cauchy.py b/jax/scipy/stats/cauchy.py index af6c1ba490a0..b3b0d994c865 100644 --- a/jax/scipy/stats/cauchy.py +++ b/jax/scipy/stats/cauchy.py @@ -18,4 +18,9 @@ from jax._src.scipy.stats.cauchy import ( logpdf as logpdf, pdf as pdf, + cdf as cdf, + logcdf as logcdf, + sf as sf, + isf as isf, + ppf as ppf, ) diff --git a/jax/scipy/stats/chi2.py b/jax/scipy/stats/chi2.py index 349c0c7cc96e..9cb28c8a616b 100644 --- a/jax/scipy/stats/chi2.py +++ b/jax/scipy/stats/chi2.py @@ -18,4 +18,7 @@ from jax._src.scipy.stats.chi2 import ( logpdf as logpdf, pdf as pdf, + cdf as cdf, + logcdf as logcdf, + sf as sf, ) diff --git a/jax/scipy/stats/gamma.py b/jax/scipy/stats/gamma.py index 3c518f43f400..268fc4fa03de 100644 --- a/jax/scipy/stats/gamma.py +++ b/jax/scipy/stats/gamma.py @@ -18,4 +18,7 @@ from jax._src.scipy.stats.gamma import ( logpdf as logpdf, pdf as pdf, + cdf as cdf, + logcdf as logcdf, + sf as sf, ) diff --git a/jax/scipy/stats/norm.py b/jax/scipy/stats/norm.py index fd9506d7c18d..c6b85f25dfd4 100644 --- a/jax/scipy/stats/norm.py +++ b/jax/scipy/stats/norm.py @@ -21,4 +21,6 @@ logpdf as logpdf, pdf as pdf, ppf as ppf, + sf as sf, + isf as isf, ) diff --git a/tests/scipy_stats_test.py b/tests/scipy_stats_test.py index 0b5984d58b36..fb7a5af9a998 100644 --- a/tests/scipy_stats_test.py +++ b/tests/scipy_stats_test.py @@ -190,7 +190,7 @@ def args_maker(): with 
jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=5e-4) - self._CompileAndCheck(lax_fun, args_maker) + self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4) @genNamedParametersNArgs(3) def testGeomLogPmf(self, shapes, dtypes): @@ -226,6 +226,38 @@ def args_maker(): self._CompileAndCheck(lax_fun, args_maker, rtol={np.float32: 2e-3, np.float64: 1e-4}) + @genNamedParametersNArgs(5) + def testBetaLogCdf(self, shapes, dtypes): + rng = jtu.rand_positive(self.rng()) + scipy_fun = osp_stats.beta.logcdf + lax_fun = lsp_stats.beta.logcdf + + def args_maker(): + x, a, b, loc, scale = map(rng, shapes, dtypes) + return [x, a, b, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=1e-3) + self._CompileAndCheck(lax_fun, args_maker, + rtol={np.float32: 2e-3, np.float64: 1e-4}) + + @genNamedParametersNArgs(5) + def testBetaSf(self, shapes, dtypes): + rng = jtu.rand_positive(self.rng()) + scipy_fun = osp_stats.beta.sf + lax_fun = lsp_stats.beta.sf + + def args_maker(): + x, a, b, loc, scale = map(rng, shapes, dtypes) + return [x, a, b, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=1e-3) + self._CompileAndCheck(lax_fun, args_maker, + rtol={np.float32: 2e-3, np.float64: 1e-4}) + def testBetaLogPdfZero(self): # Regression test for https://github.com/google/jax/issues/7645 a = b = 1. 
@@ -250,6 +282,80 @@ def args_maker(): tol=1e-4) self._CompileAndCheck(lax_fun, args_maker) + @genNamedParametersNArgs(3) + def testCauchyLogCdf(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + scipy_fun = osp_stats.cauchy.logcdf + lax_fun = lsp_stats.cauchy.logcdf + + def args_maker(): + x, loc, scale = map(rng, shapes, dtypes) + # clipping to ensure that scale is not too low + scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype) + return [x, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=1e-4) + self._CompileAndCheck(lax_fun, args_maker) + + @genNamedParametersNArgs(3) + def testCauchySf(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + scipy_fun = osp_stats.cauchy.sf + lax_fun = lsp_stats.cauchy.sf + + def args_maker(): + x, loc, scale = map(rng, shapes, dtypes) + # clipping to ensure that scale is not too low + scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype) + return [x, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=1e-4) + self._CompileAndCheck(lax_fun, args_maker) + + @genNamedParametersNArgs(3) + def testCauchyIsf(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + scipy_fun = osp_stats.cauchy.isf + lax_fun = lsp_stats.cauchy.isf + + def args_maker(): + q, loc, scale = map(rng, shapes, dtypes) + # clipping to ensure that q is in desired range + # since lax.tan and numpy.tan work different near divergence points + q = np.clip(q, 5e-3, 1 - 5e-3).astype(q.dtype) + # clipping to ensure that scale is not too low + scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype) + return [q, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=2e-4) + 
self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4) + + @genNamedParametersNArgs(3) + def testCauchyPpf(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + scipy_fun = osp_stats.cauchy.ppf + lax_fun = lsp_stats.cauchy.ppf + + def args_maker(): + q, loc, scale = map(rng, shapes, dtypes) + # clipping to ensure that q is in desired + # since lax.tan and numpy.tan work different near divergence points + q = np.clip(q, 5e-3, 1 - 5e-3).astype(q.dtype) + # clipping to ensure that scale is not too low + scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype) + return [q, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=2e-4) + self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4) + @jtu.sample_product( shapes=[ [x_shape, alpha_shape] @@ -326,6 +432,37 @@ def testGammaLogPdfZero(self): self.assertAllClose( osp_stats.gamma.pdf(0.0, 1.0), lsp_stats.gamma.pdf(0.0, 1.0), atol=1E-6) + @genNamedParametersNArgs(4) + def testGammaLogCdf(self, shapes, dtypes): + rng = jtu.rand_positive(self.rng()) + scipy_fun = osp_stats.gamma.logcdf + lax_fun = lsp_stats.gamma.logcdf + + def args_maker(): + x, a, loc, scale = map(rng, shapes, dtypes) + x = np.clip(x, 0, None).astype(x.dtype) + return [x, a, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=5e-4) + self._CompileAndCheck(lax_fun, args_maker) + + @genNamedParametersNArgs(4) + def testGammaLogSf(self, shapes, dtypes): + rng = jtu.rand_positive(self.rng()) + scipy_fun = osp_stats.gamma.sf + lax_fun = lsp_stats.gamma.sf + + def args_maker(): + x, a, loc, scale = map(rng, shapes, dtypes) + return [x, a, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=5e-4) + self._CompileAndCheck(lax_fun, args_maker) + 
@genNamedParametersNArgs(2) def testGenNormLogPdf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) @@ -411,32 +548,39 @@ def args_maker(): tol={np.float32: 1e-5, np.float64: 1e-6}) self._CompileAndCheck(lax_fun, args_maker) - @genNamedParametersNArgs(1) + @genNamedParametersNArgs(3) def testLogisticCdf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.logistic.cdf lax_fun = lsp_stats.logistic.cdf def args_maker(): - return list(map(rng, shapes, dtypes)) + x, loc, scale = map(rng, shapes, dtypes) + # ensure that scale is not too low + scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype) + return [x, loc, scale] with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=3e-5) self._CompileAndCheck(lax_fun, args_maker) - @genNamedParametersNArgs(1) + @genNamedParametersNArgs(3) def testLogisticLogpdf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.logistic.logpdf lax_fun = lsp_stats.logistic.logpdf def args_maker(): - return list(map(rng, shapes, dtypes)) + x, loc, scale = map(rng, shapes, dtypes) + # ensure that scale is not too low + scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype) + return [x, loc, scale] - self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, - tol=1e-3) - self._CompileAndCheck(lax_fun, args_maker) + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=1e-3) + self._CompileAndCheck(lax_fun, args_maker) def testLogisticLogpdfOverflow(self): # Regression test for https://github.com/google/jax/issues/10219 @@ -445,31 +589,56 @@ def testLogisticLogpdfOverflow(self): lsp_stats.logistic.logpdf(np.array([-100, 100], np.float32)), check_dtypes=False) - @genNamedParametersNArgs(1) + @genNamedParametersNArgs(3) def testLogisticPpf(self, shapes, dtypes): rng = 
jtu.rand_default(self.rng()) scipy_fun = osp_stats.logistic.ppf lax_fun = lsp_stats.logistic.ppf def args_maker(): - return list(map(rng, shapes, dtypes)) + x, loc, scale = map(rng, shapes, dtypes) + # ensure that scale is not too low + scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype) + return [x, loc, scale] - self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, - tol=1e-4) - self._CompileAndCheck(lax_fun, args_maker) + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=1e-4) + self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4) - @genNamedParametersNArgs(1) + @genNamedParametersNArgs(3) def testLogisticSf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.logistic.sf lax_fun = lsp_stats.logistic.sf def args_maker(): - return list(map(rng, shapes, dtypes)) + x, loc, scale = map(rng, shapes, dtypes) + # ensure that scale is not too low + scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype) + return [x, loc, scale] - self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, - tol=2e-5) - self._CompileAndCheck(lax_fun, args_maker) + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=2e-5) + self._CompileAndCheck(lax_fun, args_maker) + + @genNamedParametersNArgs(3) + def testLogisticIsf(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + scipy_fun = osp_stats.logistic.isf + lax_fun = lsp_stats.logistic.isf + + def args_maker(): + x, loc, scale = map(rng, shapes, dtypes) + # ensure that scale is not too low + scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype) + return [x, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=1e-4) + self._CompileAndCheck(lax_fun, args_maker, 
rtol=3e-4) @genNamedParametersNArgs(3) def testNormLogPdf(self, shapes, dtypes): @@ -524,6 +693,22 @@ def args_maker(): tol=1e-6) self._CompileAndCheck(lax_fun, args_maker) + @genNamedParametersNArgs(3) + def testNormSf(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + scipy_fun = osp_stats.norm.sf + lax_fun = lsp_stats.norm.sf + + def args_maker(): + x, loc, scale = map(rng, shapes, dtypes) + # clipping to ensure that scale is not too low + scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype) + return [x, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=1e-6) + self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(3) def testNormPpf(self, shapes, dtypes): @@ -543,6 +728,24 @@ def args_maker(): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, tol=1e-4) self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4) + @genNamedParametersNArgs(3) + def testNormIsf(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + scipy_fun = osp_stats.norm.isf + lax_fun = lsp_stats.norm.isf + + def args_maker(): + q, loc, scale = map(rng, shapes, dtypes) + # ensure probability is between 0 and 1: + q = np.clip(np.abs(q / 3), a_min=None, a_max=1).astype(q.dtype) + # clipping to ensure that scale is not too low + scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype) + return [q, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, tol=1e-4) + self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4) + @genNamedParametersNArgs(5) def testTruncnormLogPdf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) @@ -716,6 +919,36 @@ def args_maker(): tol=5e-4) self._CompileAndCheck(lax_fun, args_maker) + @genNamedParametersNArgs(4) + def testChi2LogCdf(self, shapes, dtypes): + rng = jtu.rand_positive(self.rng()) + scipy_fun = osp_stats.chi2.logcdf 
+ lax_fun = lsp_stats.chi2.logcdf + + def args_maker(): + x, df, loc, scale = map(rng, shapes, dtypes) + return [x, df, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=5e-4) + self._CompileAndCheck(lax_fun, args_maker) + + @genNamedParametersNArgs(4) + def testChi2Sf(self, shapes, dtypes): + rng = jtu.rand_positive(self.rng()) + scipy_fun = osp_stats.chi2.sf + lax_fun = lsp_stats.chi2.sf + + def args_maker(): + x, df, loc, scale = map(rng, shapes, dtypes) + return [x, df, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=5e-4) + self._CompileAndCheck(lax_fun, args_maker) + @genNamedParametersNArgs(5) def testBetaBinomLogPmf(self, shapes, dtypes): rng = jtu.rand_positive(self.rng()) From 5b30f8e9cd5ae91490017bae00fc46975e1b299a Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 23 Mar 2023 10:29:11 -0700 Subject: [PATCH 38/65] Move jax._src.typing into a separate Bazel target. 
PiperOrigin-RevId: 518899136 --- jax/BUILD | 10 +++++++++- jax/_src/core.py | 2 +- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/jax/BUILD b/jax/BUILD index 0cd20ce9d72a..85b296056c5b 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -117,7 +117,6 @@ py_library_providing_imports_info( "_src/random.py", "_src/sharding_impls.py", "_src/stages.py", - "_src/typing.py", ] + glob( [ "*.py", @@ -185,6 +184,7 @@ py_library_providing_imports_info( ":source_info_util", ":traceback_util", ":tree_util", + ":typing", ":util", ":version", ":xla_bridge", @@ -334,6 +334,14 @@ pytype_library( ], ) +pytype_library( + name = "typing", + srcs = [ + "_src/typing.py", + ], + deps = [":basearray"] + py_deps("numpy"), +) + pytype_library( name = "util", srcs = ["_src/util.py"], diff --git a/jax/_src/core.py b/jax/_src/core.py index 68f131697382..61d5fcdbdcc5 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -2022,7 +2022,7 @@ def _canonicalize_dimension(dim: DimSize) -> DimSize: else: raise type_error -def canonicalize_shape(shape: Shape, context: str="") -> Shape: +def canonicalize_shape(shape: Shape, context: str="") -> Tuple[Any, ...]: """Canonicalizes and checks for errors in a user-provided shape value. 
Args: From 7a326b3886e17b869b3049b8b35d835056d57cdc Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 23 Mar 2023 11:43:49 -0700 Subject: [PATCH 39/65] Deprecated xla_call_p since it has been replaced with pjit.pjit_p PiperOrigin-RevId: 518921538 --- jax/_src/interpreters/mlir.py | 19 ++-------- jax/_src/interpreters/pxla.py | 14 ++----- jax/_src/interpreters/xla.py | 55 +++------------------------- jax/_src/maps.py | 2 +- jax/experimental/host_callback.py | 28 +++----------- jax/experimental/jax2tf/jax2tf.py | 23 +----------- jax/experimental/jet.py | 7 ---- jax/experimental/shard_map.py | 23 +----------- jax/experimental/sparse/transform.py | 21 +---------- jax/interpreters/mlir.py | 2 +- jax/interpreters/xla.py | 11 +++++- tests/jaxpr_effects_test.py | 2 +- tests/name_stack_test.py | 2 +- 13 files changed, 36 insertions(+), 173 deletions(-) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index bcbab7767024..76a9f6fa5a32 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -1276,20 +1276,7 @@ def _call_lowering(fn_name, stack_name, call_jaxpr, backend, ctx, avals_in, tokens_out = tokens_in.update_tokens(TokenSet(zip(effects, tokens))) return out_nodes, tokens_out -def _xla_call_lower(ctx, *args, - backend=None, name, call_jaxpr, donated_invars, inline=None, - device=None, keep_unused=None): - del device, donated_invars, inline, keep_unused # Ignored. 
- out_nodes, tokens = _call_lowering( - name, util.wrap_name(name, "jit"), call_jaxpr, backend, - ctx.module_context, ctx.avals_in, ctx.avals_out, ctx.tokens_in, - *args, dim_var_values=ctx.dim_var_values) - ctx.set_tokens_out(tokens) - return out_nodes - -register_lowering(xla.xla_call_p, _xla_call_lower) - -def _core_call_lowering(ctx, *args, name, backend=None, call_jaxpr): +def core_call_lowering(ctx, *args, name, backend=None, call_jaxpr): out_nodes, tokens = _call_lowering( name, name, call_jaxpr, backend, ctx.module_context, ctx.avals_in, ctx.avals_out, ctx.tokens_in, *args, @@ -1297,9 +1284,9 @@ def _core_call_lowering(ctx, *args, name, backend=None, call_jaxpr): ctx.set_tokens_out(tokens) return out_nodes -register_lowering(core.call_p, partial(_core_call_lowering, name="core_call")) +register_lowering(core.call_p, partial(core_call_lowering, name="core_call")) register_lowering(core.closed_call_p, - partial(_core_call_lowering, name="core_closed_call")) + partial(core_call_lowering, name="core_closed_call")) def broadcast_in_dim(ctx: LoweringRuleContext, op, aval_out: core.AbstractValue, *, broadcast_dimensions) -> ir.Value: diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 1c8a37c581a8..2ad6ef83c0e2 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -902,12 +902,7 @@ def process_primitive(self, primitive, tracers, params): return MapTracer(self, outvals, out_shard_axes) def process_call(self, call_primitive, fun, tracers, params): - if call_primitive is not xla.xla_call_p: raise NotImplementedError - bind = HashableFunction( - lambda *args, **kwargs: call_primitive.bind(fun, *args, **kwargs), - (call_primitive, fun)) - fake_primitive = FakePrimitive(multiple_results=True, bind=bind) - return self.process_primitive(fake_primitive, tracers, params) + raise NotImplementedError def process_map(self, call_primitive, fun, tracers, params): if params['devices'] is not None: @@ -1998,15 +1993,14 @@ 
def _pmap_dce_rule(used_outputs, eqn): # Set param update handlers to update `donated_invars` just like xla_call_p -pe.call_param_updaters[xla_pmap_p] = pe.call_param_updaters[xla.xla_call_p] +pe.call_param_updaters[xla_pmap_p] = xla.xla_call_partial_eval_update_params pe.partial_eval_jaxpr_custom_rules[xla_pmap_p] = \ partial(pe.call_partial_eval_custom_rule, 'call_jaxpr', _pmap_partial_eval_custom_params_updater, res_aval=_pmap_partial_eval_custom_res_maker) pe.dce_rules[xla_pmap_p] = _pmap_dce_rule -ad.call_param_updaters[xla_pmap_p] = ad.call_param_updaters[xla.xla_call_p] -ad.call_transpose_param_updaters[xla_pmap_p] = \ - ad.call_transpose_param_updaters[xla.xla_call_p] +ad.call_param_updaters[xla_pmap_p] = xla.xla_call_jvp_update_params +ad.call_transpose_param_updaters[xla_pmap_p] = xla.xla_call_transpose_update_params ad.primitive_transposes[xla_pmap_p] = partial(ad.map_transpose, xla_pmap_p) diff --git a/jax/_src/interpreters/xla.py b/jax/_src/interpreters/xla.py index 826b94cd2e9b..6623069b2486 100644 --- a/jax/_src/interpreters/xla.py +++ b/jax/_src/interpreters/xla.py @@ -28,17 +28,14 @@ import numpy as np from jax.config import config -from jax.interpreters import partial_eval as pe from jax._src import core from jax._src import device_array from jax._src import dtypes -from jax._src import pretty_printer as pp from jax._src import source_info_util from jax._src.abstract_arrays import numpy_scalar_types from jax._src.core import ConcreteArray, ShapedArray -from jax._src.interpreters import ad -from jax._src.util import (safe_zip, safe_map, partition_list) +from jax._src.util import safe_zip, safe_map from jax._src.typing import Shape @@ -157,7 +154,6 @@ def aval_to_xla_shapes(aval: core.AbstractValue) -> Sequence[xc.Shape]: xla_shape_handlers[core.AbstractToken] = lambda _: (xc.Shape.token_shape(),) - # IR constants # TODO(mattjj): try to remove this canonicalize_dtype stuff @@ -348,11 +344,11 @@ def jaxpr_collectives(jaxpr): ### xla_call underlying 
jit - +# TODO(yashkatariya): Remove after 1 month from March 23, 2023. xla_call_p: core.CallPrimitive = core.CallPrimitive('xla_call') -xla_call = xla_call_p.bind -def _xla_call_partial_eval_update_params( + +def xla_call_partial_eval_update_params( params: core.ParamDict, kept_inputs: Sequence[bool], num_new_inputs: int ) -> core.ParamDict: donated_invars = params['donated_invars'] @@ -366,57 +362,18 @@ def _xla_call_partial_eval_update_params( # Any new inputs are prepended to the left, so mark those as not donated. donated_invars = [False] * num_new_inputs + donated_invars return dict(params, donated_invars=tuple(donated_invars)) -pe.call_param_updaters[xla_call_p] = _xla_call_partial_eval_update_params -def _xla_call_jvp_update_params(params, nz_tangents): +def xla_call_jvp_update_params(params, nz_tangents): donated_invars = params['donated_invars'] donated_tangents = [d for d, nz in zip(donated_invars, nz_tangents) if nz] new_donated_invars = (*donated_invars, *donated_tangents) return dict(params, donated_invars=new_donated_invars) -ad.call_param_updaters[xla_call_p] = _xla_call_jvp_update_params -def _xla_call_transpose_update_params(params, undef_primals, nonzero_cts): +def xla_call_transpose_update_params(params, undef_primals, nonzero_cts): donated_invars = params['donated_invars'] donated_primals = [d for d, u in zip(donated_invars, undef_primals) if not u] donated_cotangents = [False for nz in nonzero_cts if nz] return dict(params, donated_invars=(*donated_primals, *donated_cotangents)) -ad.call_transpose_param_updaters[xla_call_p] = _xla_call_transpose_update_params - - -ad.primitive_transposes[xla_call_p] = partial(ad.call_transpose, xla_call_p) - - -def _xla_call_partial_eval_custom_params_updater( - unks_in: Sequence[bool], inst_in: Sequence[bool], - kept_outs_known: Sequence[bool], kept_outs_staged: Sequence[bool], - num_res: int, params_known: dict, params_staged: dict - ) -> Tuple[dict, dict]: - # pruned inputs to jaxpr_known according to 
unks_in, so prune donated_invars - donated_known, _ = partition_list(unks_in, params_known['donated_invars']) - new_params_known = dict(params_known, donated_invars=tuple(donated_known)) - # added num_res new inputs to jaxpr_staged, so extend donated_invars - _, donated_staged_ = partition_list(inst_in, params_staged['donated_invars']) - donated_staged = [False] * num_res + donated_staged_ - new_params_staged = dict(params_staged, donated_invars=tuple(donated_staged)) - return new_params_known, new_params_staged -pe.partial_eval_jaxpr_custom_rules[xla_call_p] = \ - partial(pe.call_partial_eval_custom_rule, 'call_jaxpr', - _xla_call_partial_eval_custom_params_updater) -pe.dce_rules[xla_call_p] = pe.dce_jaxpr_call_rule - -pe.padding_rules[xla_call_p] = partial(pe.call_padding_rule, xla_call_p) - - -def _pp_xla_call(eqn: core.JaxprEqn, context: core.JaxprPpContext, - settings: core.JaxprPpSettings, - ) -> pp.Doc: - printed_params = {k:v for k, v in eqn.params.items() if - k == 'call_jaxpr' or k == 'name' or - k == 'backend' and v is not None or - k == 'device' and v is not None or - k == 'donated_invars' and any(v)} - return core._pp_eqn(eqn.replace(params=printed_params), context, settings) -core.pp_eqn_rules[xla_call_p] = _pp_xla_call ### translation tables diff --git a/jax/_src/maps.py b/jax/_src/maps.py index a16d881cd9aa..7e11f950e574 100644 --- a/jax/_src/maps.py +++ b/jax/_src/maps.py @@ -846,7 +846,7 @@ def shadowed_subst(name): # NOTE: We don't have to handle spmd_{in|out}_axes here, because # SPMD batching always gets involved as the last transform before XLA translation ad.JVPTrace.process_xmap = ad.JVPTrace.process_call # type: ignore -ad.call_param_updaters[xmap_p] = ad.call_param_updaters[xla.xla_call_p] +ad.call_param_updaters[xmap_p] = xla.xla_call_jvp_update_params def _xmap_transpose(params, call_jaxpr, args, cts_in, cts_in_avals, reduce_axes): all_args, in_tree_def = tree_flatten(((), args, cts_in)) # empty consts diff --git 
a/jax/experimental/host_callback.py b/jax/experimental/host_callback.py index 47791ff7c7f4..6e3aaa9be23d 100644 --- a/jax/experimental/host_callback.py +++ b/jax/experimental/host_callback.py @@ -1655,16 +1655,6 @@ def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn], jaxpr=new_jaxpr, num_carry=num_carry + 2, linear=linear[0:nr_const_and_carry] + (False, False) + linear[nr_const_and_carry:]))) - elif eqn.primitive is xla.xla_call_p: - call_jaxpr = cast(core.Jaxpr, eqn.params["call_jaxpr"]) - eqns.append( - eqn.replace( - invars=eqn.invars + [input_token_var, input_itoken_var], - outvars=eqn.outvars + [output_token_var, output_itoken_var], - params=dict( - eqn.params, - call_jaxpr=_rewrite_jaxpr(call_jaxpr, True, True), - donated_invars=eqn.params["donated_invars"] + (False, False)))) elif eqn.primitive is pxla.xla_pmap_p: # We broadcast the input token into an array of tokens call_jaxpr = cast(core.Jaxpr, eqn.params["call_jaxpr"]) @@ -1762,12 +1752,10 @@ def _rewrite_while_outfeed_cond(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn], eqns.append( core.new_jaxpr_eqn( eqn.invars[0:cond_nconsts] + carry_invars + [input_token_var, input_itoken_var], - pred1_and_token1, xla.xla_call_p, + pred1_and_token1, core.call_p, dict( call_jaxpr=transformed_cond_jaxpr.jaxpr, - name="cond_before", - donated_invars=(False,) * len(transformed_cond_jaxpr.in_avals), - inline=False), + name="cond_before"), transformed_cond_jaxpr.jaxpr.effects, eqn.source_info)) # Make a new cond "lambda pred, carry, token, itoken: pred" @@ -1808,22 +1796,18 @@ def _rewrite_while_outfeed_cond(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn], new_body_invars_body_constvars + new_body_invars_carry + [new_body_invars_token, new_body_invars_itoken], new_body_carry2 + [new_body_token2, new_body_itoken2], - xla.xla_call_p, + core.call_p, dict( call_jaxpr=transformed_body_jaxpr.jaxpr, - name="body", - donated_invars=(False,) * len(transformed_body_jaxpr.in_avals), - inline=False), + name="body"), 
transformed_body_jaxpr.effects, eqn.source_info), core.new_jaxpr_eqn( new_body_invars_cond_constvars + new_body_carry2 + [new_body_token2, new_body_itoken2], - [new_body_pred2, new_body_token3, new_body_itoken3], xla.xla_call_p, + [new_body_pred2, new_body_token3, new_body_itoken3], core.call_p, dict( call_jaxpr=transformed_cond_jaxpr.jaxpr, - name="cond_body", - donated_invars=(False,) * len(transformed_cond_jaxpr.in_avals), - inline=False), + name="cond_body"), transformed_cond_jaxpr.effects, eqn.source_info) ] diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 37b65bf85cdc..37a33b4000ad 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -1479,28 +1479,9 @@ def process_call(self, call_primitive: core.Primitive, fun: lu.WrappedFun, avals: Sequence[core.ShapedArray] = tuple(t.aval for t in tracers) interpreted_fun = _interpret_subtrace(fun, self.main, avals) extra_name_stack = None - if call_primitive == xla.xla_call_p: - extra_name_stack = util.wrap_name(params["name"], "jit") with _extended_name_stack(extra_name_stack): with core.new_sublevel(): - if call_primitive == xla.xla_call_p: - if _WRAP_JAX_JIT_WITH_TF_FUNCTION: - # Make a nested tf.function(jit_compile=True) - store_tf_res_avals: Sequence[core.ShapedArray] = [] - def f_tf(*tf_args): - nonlocal store_tf_res_avals - tf_res_out: Sequence[Tuple[TfVal, core.ShapedArray]] = \ - _call_wrapped_with_new_constant_cache(interpreted_fun, tf_args, - fresh_constant_cache=False) - tf_res_vals, tf_res_avals = util.unzip2(tf_res_out) - store_tf_res_avals = tf_res_avals - return tf_res_vals - tf_vals_out = tf.function(f_tf, autograph=False, jit_compile=True)(*vals) - vals_out = zip(tf_vals_out, store_tf_res_avals) - else: - vals_out = interpreted_fun.call_wrapped(*vals) - else: - vals_out = interpreted_fun.call_wrapped(*vals) + vals_out = interpreted_fun.call_wrapped(*vals) return [TensorFlowTracer(self, v, a) for v, a in vals_out] def 
post_process_call(self, call_primitive: core.Primitive, @@ -1572,7 +1553,7 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs): # Call primitives are inlined -for unexpected in [core.call_p, xla.xla_call_p, maps.xmap_p]: +for unexpected in [core.call_p, maps.xmap_p]: tf_impl[unexpected] = partial(_unexpected_primitive, unexpected) # Primitives that are not yet implemented must be explicitly declared here. diff --git a/jax/experimental/jet.py b/jax/experimental/jet.py index 54323fe2dc41..f20682a63da0 100644 --- a/jax/experimental/jet.py +++ b/jax/experimental/jet.py @@ -267,13 +267,6 @@ class ZeroSeries: pass call_param_updaters = {} -def _xla_call_param_updater(params, num_inputs): - donated_invars = params['donated_invars'] - if any(donated_invars): - raise NotImplementedError("donated_invars not supported with jet") - return dict(params, donated_invars=(False,) * num_inputs) -call_param_updaters[xla.xla_call_p] = _xla_call_param_updater - ### rule definitions diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index b8fc67fd27ed..e2eb8085403a 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -581,25 +581,8 @@ def process_primitive(self, prim, tracers, params): return ShardMapTracer(self, out_rep, out_vals) def process_call(self, call_primitive, fun, tracers, params): - if call_primitive is not xla.xla_call_p: raise NotImplementedError - fun, jaxpr = _grab_jaxpr_shadily(fun) # TODO remove with initial-style jit - bind = partial(call_primitive.bind, fun) # TODO caching (compat w/ jaxpr()) - fake_primitive = pxla.FakePrimitive(multiple_results=True, bind=bind) - _rep_rules[fake_primitive] = lambda *_, **__: set() # pytype: disable=container-type-mismatch - out_tracers_ = self.process_primitive(fake_primitive, tracers, params) - out_vals = [t.val for t in out_tracers_] - if self.check: - out_rep = _output_rep(self.mesh, jaxpr(), [t.rep for t in tracers]) - else: - out_rep = [set()] * len(out_vals) - 
return map(partial(ShardMapTracer, self), out_rep, out_vals) + raise NotImplementedError -@lu.transformation_with_aux -def _grab_jaxpr_shadily(*args): - out = yield args, {} - main = core.thread_local_state.trace_state.trace_stack.dynamic # forgive me - jaxpr, _ = main.jaxpr_stack[-1].to_jaxpr(out) - yield out, jaxpr class ShardMapTracer(core.Tracer): rep: Set[AxisName] @@ -711,10 +694,6 @@ def _axis_index_rule(mesh, *, axis_name): def _pjit_rule(mesh, *in_rep, jaxpr, **kwargs): return _output_rep(mesh, jaxpr.jaxpr, in_rep) -@register_rule(xla.xla_call_p) -def _jit_rule(mesh, *in_rep, jaxpr, **kwargs): - return _output_rep(mesh, jaxpr, in_rep) - @register_rule(debugging.debug_callback_p) def _debug_callback_rule(mesh, *in_rep, **_): return [] diff --git a/jax/experimental/sparse/transform.py b/jax/experimental/sparse/transform.py index 67b3fdcc5a09..d79455e9a288 100644 --- a/jax/experimental/sparse/transform.py +++ b/jax/experimental/sparse/transform.py @@ -421,14 +421,7 @@ def write(var: core.Var, a: SparsifyValue) -> None: _raise_unimplemented_primitive(prim) out = sparse_rules_bcoo[prim](spenv, *invals, **eqn.params) else: - if prim is xla.xla_call_p: - # TODO(vanderplas,frostig): workaround for binding call primitives - # within a jaxpr interpreter - params = eqn.params.copy() - fun = lu.wrap_init(core.jaxpr_as_fun(pe.ClosedJaxpr(params.pop('call_jaxpr'), ()))) - out_bufs = prim.bind(fun, *(spenv.data(val) for val in invals), **params) - else: - out_bufs = prim.bind(*(spenv.data(val) for val in invals), **eqn.params) + out_bufs = prim.bind(*(spenv.data(val) for val in invals), **eqn.params) out_bufs = out_bufs if prim.multiple_results else [out_bufs] out = [] for buf, outvar in safe_zip(out_bufs, eqn.outvars): @@ -759,18 +752,6 @@ def _while_sparse(spenv, *spvalues, cond_jaxpr, cond_nconsts, body_jaxpr, body_n sparse_rules_bcoo[lax.while_p] = _while_sparse -def _xla_call_sparse(spenv, *spvalues, call_jaxpr, donated_invars, **params): - if any(donated_invars): - 
raise NotImplementedError("sparse xla_call with donated_invars") - sp_call_jaxpr, out_tree = _sparsify_jaxpr(spenv, pe.ClosedJaxpr(call_jaxpr, ()), *spvalues) - fun = lu.wrap_init(core.jaxpr_as_fun(sp_call_jaxpr)) - args_flat, _ = tree_flatten(spvalues_to_arrays(spenv, spvalues)) - donated_invars = tuple(False for arg in args_flat) - out_flat = xla.xla_call_p.bind(fun, *args_flat, donated_invars=donated_invars, **params) - return arrays_to_spvalues(spenv, tree_unflatten(out_tree, out_flat)) - -sparse_rules_bcoo[xla.xla_call_p] = _xla_call_sparse - def _pjit_sparse(spenv, *spvalues, jaxpr, in_shardings, out_shardings, resource_env, donated_invars, name, keep_unused, inline): diff --git a/jax/interpreters/mlir.py b/jax/interpreters/mlir.py index 1499d07d8fcd..2eca53a83e5d 100644 --- a/jax/interpreters/mlir.py +++ b/jax/interpreters/mlir.py @@ -34,9 +34,9 @@ _call_lowering as _call_lowering, _lowerings as _lowerings, _platform_specific_lowerings as _platform_specific_lowerings, - _xla_call_lower as _xla_call_lower, aval_to_ir_type as aval_to_ir_type, aval_to_ir_types as aval_to_ir_types, + core_call_lowering as core_call_lowering, dense_bool_elements as dense_bool_elements, dense_int_elements as dense_int_elements, dtype_to_ir_type as dtype_to_ir_type, diff --git a/jax/interpreters/xla.py b/jax/interpreters/xla.py index e0da29e7ec2d..b008f09bef19 100644 --- a/jax/interpreters/xla.py +++ b/jax/interpreters/xla.py @@ -29,8 +29,7 @@ register_translation as register_translation, sharding_to_proto as sharding_to_proto, translations as translations, - xla_call as xla_call, - xla_call_p as xla_call_p, + xla_call_p as _deprecated_xla_call_p, xla_destructure as xla_destructure, xla_shape_handlers as xla_shape_handlers, device_put as _deprecated_device_put, @@ -83,6 +82,13 @@ ), _deprecated_device_put, ), + "xla_call_p": ( + ( + "jax.interpreters.xla.xla_call_p is deprecated. Please use" + " jax.experimental.pjit.pjit_p instead." 
+ ), + _deprecated_xla_call_p, + ), } from jax._src.deprecations import deprecation_getattr as _deprecation_getattr @@ -98,4 +104,5 @@ from jax._src.interpreters.xla import ( device_put as device_put, ) + from jax._src.interpreters.xla import xla_call_p as xla_call_p del typing diff --git a/tests/jaxpr_effects_test.py b/tests/jaxpr_effects_test.py index f3845320eeff..72626a6aec2b 100644 --- a/tests/jaxpr_effects_test.py +++ b/tests/jaxpr_effects_test.py @@ -205,7 +205,7 @@ def f_(x): self.assertIn(foo_effect, jaxpr.jaxpr.effects) self.assertIn(bar_effect, jaxpr.jaxpr.effects) - def test_xla_call_primitive_inherits_effects(self): + def test_jit_primitive_inherits_effects(self): @jax.jit def f(x): diff --git a/tests/name_stack_test.py b/tests/name_stack_test.py index 182cd8b3a137..b3e4341803ba 100644 --- a/tests/name_stack_test.py +++ b/tests/name_stack_test.py @@ -100,7 +100,7 @@ def _f(x): hlo_text = _get_hlo(f)(2) self.assertIn('foo/jit(core_call)/bar', hlo_text) - def test_xla_call_primitive_jaxpr_should_not_store_outer_name_stack(self): + def test_jit_jaxpr_should_not_store_outer_name_stack(self): @jax.named_scope('foo') def f(x): @jax.jit From 6ee65983163f153ff285ea91a215934200e16a7f Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 23 Mar 2023 13:56:11 -0700 Subject: [PATCH 40/65] Fix mypy issue in jax/experimental/jet.py --- jax/experimental/jet.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/jax/experimental/jet.py b/jax/experimental/jet.py index f20682a63da0..c35f4b09f383 100644 --- a/jax/experimental/jet.py +++ b/jax/experimental/jet.py @@ -52,18 +52,16 @@ `outstanding primitive rules `__. 
""" -from typing import Callable, Any, Tuple +from typing import Any, Callable, Dict, Tuple from functools import partial import numpy as np -import jax from jax import lax -from jax.interpreters import xla import jax.numpy as jnp from jax.experimental import pjit -from jax.interpreters import partial_eval as pe, pxla +from jax.interpreters import partial_eval as pe from jax.tree_util import (register_pytree_node, tree_structure, treedef_is_leaf, tree_flatten, tree_unflatten,) @@ -265,7 +263,7 @@ class ZeroSeries: pass register_pytree_node(ZeroSeries, lambda z: ((), None), lambda _, xs: zero_series) -call_param_updaters = {} +call_param_updaters: Dict[core.Primitive, Callable[..., Any]] = {} ### rule definitions From 180f12e05f86342cca060f9780dffcfef792809a Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Thu, 23 Mar 2023 21:17:54 +0000 Subject: [PATCH 41/65] add trailing-whitespace pre-commit hook --- .pre-commit-config.yaml | 7 +++++++ .../jax2tf/tests/back_compat_testdata/cpu_lapack_syev.py | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4210995e6b13..422566a8a008 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,6 +8,13 @@ # 'pre-commit run --all' repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.3.0 + hooks: + - id: trailing-whitespace + # only include python files + files: \.py$ + - repo: https://github.com/pycqa/flake8 rev: '6.0.0' hooks: diff --git a/jax/experimental/jax2tf/tests/back_compat_testdata/cpu_lapack_syev.py b/jax/experimental/jax2tf/tests/back_compat_testdata/cpu_lapack_syev.py index 06c329bedeec..3d3b23e083e2 100644 --- a/jax/experimental/jax2tf/tests/back_compat_testdata/cpu_lapack_syev.py +++ b/jax/experimental/jax2tf/tests/back_compat_testdata/cpu_lapack_syev.py @@ -381,5 +381,5 @@ """, 
mlir_module_serialized=b"ML\xefR\x03MLIRxxx-trunk\x00\x015\x05\x01\x05\x01\x03\x05\x03%\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f!#%')\x03\xc6\x02\x1e\x02?\x01\xa9\x0f\x17\x13\x0b\x17\x0b\x07\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0b\x13\x0f\x0b\x0b\x17\x0f\x13\x13\x0b33\x0b\x0f\x0b\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x13\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x13\x0b\x13\x0b\x0b\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x13\x13\x13\x13\x1b\x17\x03a\x0f/\x0b\x0b\x0f\x0b\x0bO\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0bOOO\x1f\x1f\x0b\x0b\x0b\x0b\x1b\x0f\x17#\x0f\x0f\x0f\x0f\x0f\x0f\x0bOO//\x01\x07\x17\x17\x17\x03?\x17\x0f\x07\x0f\x07\x13\x07\x07\x0b\x17\x17\x07\x17\x13\x17\x17\x13\x0f\x17\x17\x13\x17#\x13\x13\x13\x0f\x17\x13\x13\x13\x02\x96\n\x1d\x83\x03\x17\x13\xf6\x04\x01\x03\x03\x15\xd3\x05+\x17\x13\xf2\x04\x01\x05-\x1f\x1d9\x03\x05/\x051\x053\x055\x057\x059\x05;\x03\x03!\xcf\x05=\x03\x03\x07\xd1\x1d?\x03\x05?\x05A\x17\x13\xea\x04\x01\x1d}\t\x03\x03\x07\xdf\x03\x03\x113\x05C\x03\x0b\x17\xad\x19\xb9\x1b\xbb\x11\xc5\x1d\xc7\x03\x0b\x17\xb1\x19\xcb\x1b\xb1\x11\xb3\x1d\xcd\x05E\x1d=\x03\x05G\x05I\x03\x03!\xd5\x1dE\x03\x05K\x03\x05'\xb5)\xd7\x1dK\x03\x05M\x03\x03\x07\xd9\x1dQ\x03\x05O\x1dU\x03\x05Q\x1dY+\x05S\x1d]+\x05U\x03\x03a\xdb\x05W\x1de\t\x05Y\x1di\t\x05[\x1dm\t\x05]\x1dq\t\x05_\x1du\t\x05a\x1dy\t\x05c\x03\x03\x07\xdd\x05e\x03\x03\x81\xb3\x05g\x05i\x03\x03\x07\xe1\x03\x11\x89\xe3\x8b\xe5\x8d\xe7\x8f\xad\x91\xe9\x93\xeb\x95\xed\x97\xf1\x05k\x05m\x05o\x05q\x05s\x05u\x05w\x05y\x03\x03\x0b\xf3\x03\x03\x0b\xf5\x03\x03\x0b\xf7\x03\x03\x0b\xf9\x03\x03\x0b\xfb\x03\x03\x0b\xfd\x03\x05'\xb5)\xff\x03\x03\x07\x02\x02\x1f/\x01\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1d{\x03\x03\xc9\x1d}\t\x07\x1f1!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00#%\x03\x05\xbd\xc1\r\x03\xaf\xbf\x1d\x7f\r\x03\xaf\xc3\x1d\x81\x1d\x83\x1d\x85\r\x01#'\x1d\x87\x13\r\x01\x1f\x03\t\x00\x00\x00\x00\x1f)\x01\x
13\r\x05\x07\x05\x1f\x07!\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x1b!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x07!\x00\x00\x00\x00\x00\x00\x00@\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x03\t\x01\x00\x00\x00\x1f\x03\t\x08\x00\x00\x00\x0b\x05\x1d\x89\x1d\x8b\x05\x01\x03\t\xa9\xa9\xa9\xb7\x03\x03\xef\x15\x03\x01\r\x01\x03\r\xb7\xab\xa9\xab\xab\xab\x13\x05\x01\x13\x05\x05\x13\x05\t\x13\x05\r\x13\x05\x11\x13\x05\x15\x07\x01\x1f\x07!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x1b!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f#\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f=\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x03\x15\x06\x02\x03\x03\x07\n\x02\x03\x03\x15\x0e\x02)\x05!!\x11)\x01\x05\x1b)\x01\x11\x0b)\x03!\t\x1d\x01\x03\t)\x05!!\x05)\x05!!\t\x13)\x05!!\x0f)\x03\t\r)\x03\x8a\x02\x11)\x03J\x05\t)\x03\xad\x05)\x01\t\x11\x01\x05\x01\x0b\x11\x03\x01\x03\x01)\x03\x01\r)\x03\x02\x02\x11/\r\x01\x0b\x03\x1d\x1f!)\x03\x01\x17)\x03\t\x17)\x03\x05\x17)\x01\x0f)\x05\x05\x05\x0f)\x03\x05\x0f)\x03!\x0f)\x03\x05\r\x04\xda\x05\x05\x01\x11\r1\x07\x03\x01\t\r\x11\r5\x05\x03G\x91\t\x03W\x1f\x03+\x15\x06[\x03\x01\x03\x01\x17\x07c_\x03\x01\x03\x03\x19\x06g\x03\x15\x03\x05\x1b\x06k\x03\x15\x03\x05\x1d\x06o\x03\x15\x03\t\x1f\x06s\x03\x01\x05\x07\x0b\x0f\x06w\x03\x01\x05\x03\r\x05\x03\r{\x03\x07\x03\x07-\x05\x03\x01\x03\x11!\x06-\x03\x01\x05\x0f\x13#\x07\x0f\x7f\x03\x01\x03\x15\x05\x03\x01/\x03\x03\x05\x03\x01/\x03\x03\x05\x03\x01\x85\x03\x03%\x07\x01\x87\x03-\t\x19\x1b\x1d\x17\x07\x07\x01\x99\x03\x01\x03\x1f\x07\x07\x01\x9b\x03\x0b\x03\x1f\x07\x07\x01\x9d\x03\x03\x03\x1f\x07\x07\x01\x9f\x03\x1d\x03\x1f\x07\x07\x01\xa1\x03\x1f\x03\x1f\x07\x07\x01\xa3\x03!\x03\x1f\x05\x03\x01#\x03\x03\x03\x07\x01\x05\x03\x03\x03-\x11\x07\x01\xa5\x035\x05%/\x03\x07\x01\x05\x037\x031\x05\x03\x01\xa7\x03\x07\x03\x07\x01\x05\x03\x01\x035\x03\x07\x01\x12\x02\x03\x19\x033\x0b\x06\x01\x03\x01\x079!7\x03\x07\x01\x05\x039\x031\x05
\x03\x01\x16\x02\x03#\x03\x07\x01\x05\x03\x0b\x03?\x03\x07\x01\x1a\x02\x03;\x03=\x0b\x06\x01\x03\x0b\x07C#A\x13\x04\r\x05;E\r\x11\x0f7\x05\x03\x15+\x03\x01\r\t\x03;\x1f\x03\x13\x05\x03\x0f#\x03\x03\x03\x07%\x05\x03\x13\x03\x05\x0f\x06%\x03\x13\x05\x03\x07\t\x03CA\x03\x13\x11\x07IG\x03\x19\x05\t\x0b\x05\x03\x0fM\x03\x07\x03\x07O\x05\x03\x01\x03\x0f\x0b\x06S\x03\x01\x07\r\x01\x11\x13\x04\x0f\x03\x13\x06\x03\x01\x05\x01\x00J\x1c\x8d\x1d\x03\x11\x0f\x0b\t\t\x0b!\x1f/!!)#\x1f\x19\x7f\x0f99A9;;m\x19\x85\x8fW\xb3K\x9bM\x9b\x96\x04\x1b+\x1b\x1f\x1f\x15\x1d\x15+\x83\x13\r\r\x1f\x11\x15\x17\x15\x11\x11\x1b\x17\x15\x17\x0f\x11\x15\x11+\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00get_tuple_element_v1\x00iota_v1\x00select_v1\x00func_v1\x00add_v1\x00compare_v1\x00return_v1\x00reshape_v1\x00transpose_v1\x00real_v1\x00imag_v1\x00negate_v1\x00complex_v1\x00divide_v1\x00call_v1\x00custom_call_v1\x00value\x00index\x00sym_name\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00compare_type\x00comparison_direction\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=tril in_positional_semantics=(<_PositionalSemantics.GLOBAL: 1>,) out_positional_semantics=_PositionalSemantics.GLOBAL keep_unused=False inline=False]\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=0]\x00jit()/jit(main)/jit(tril)/add\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=1]\x00jit()/jit(main)/jit(tril)/ge\x00jit()/jit(main)/jit(tril)/broadcast_in_dim[shape=(8, 8) broadcast_dimensions=()]\x00jit()/jit(main)/jit(tril)/select_n\x00jit()/jit(main)/iota[dtype=complex128 shape=(64,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(8, 8) dimensions=None]\x00permutation\x00jit()/jit(main)/transpose[permutation=(1, 
0)]\x00jit()/jit(main)/real\x00jit()/jit(main)/imag\x00jit()/jit(main)/neg\x00jit()/jit(main)/complex\x00jit()/jit(main)/add\x00jit()/jit(main)/div\x00callee\x00jit()/jit(main)/eigh[lower=True sort_eigenvalues=True]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.result_info\x00tril\x00[0]\x00[1]\x00main\x00public\x00private\x00\x00lapack_zheevd\x00", xla_call_module_version=4, - ), # End paste + ), # End paste ) \ No newline at end of file From fdea9e6d1ce3884c24916db28d5c90f86c9773ae Mon Sep 17 00:00:00 2001 From: George Necula Date: Thu, 23 Mar 2023 14:51:20 -0700 Subject: [PATCH 42/65] [jax2tf] Minor addition to the documentation PiperOrigin-RevId: 518969936 --- jax/experimental/jax2tf/README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/jax/experimental/jax2tf/README.md b/jax/experimental/jax2tf/README.md index 0fe9e0960f44..4e046ff5f4cf 100644 --- a/jax/experimental/jax2tf/README.md +++ b/jax/experimental/jax2tf/README.md @@ -35,7 +35,8 @@ This has several advantages: primitives at all data types. * uses standard native code paths in each framework, and thus it is easier to trust that the semantics and performance stays faithful to the native - semantics, across platforms. + semantics, across platforms. Has optional checking that the code runs on + the platform for which it was serialized. * the metadata associated with the operations, e.g., source location, is identical to what native execution uses. 
From 1e356cf19c2612736fc909f01f9df4602bdd7116 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 23 Mar 2023 10:57:53 -0700 Subject: [PATCH 43/65] internal: refactor array methods into separate private submodule --- jax/_src/numpy/array_methods.py | 817 ++++++++++++++++++++++++++++++++ jax/_src/numpy/lax_numpy.py | 771 +----------------------------- jax/_src/prng.py | 2 +- jax/numpy/__init__.py | 5 + 4 files changed, 826 insertions(+), 769 deletions(-) create mode 100644 jax/_src/numpy/array_methods.py diff --git a/jax/_src/numpy/array_methods.py b/jax/_src/numpy/array_methods.py new file mode 100644 index 000000000000..364d2f72875b --- /dev/null +++ b/jax/_src/numpy/array_methods.py @@ -0,0 +1,817 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pytype: skip-file +# mypy: disable-error-code=has-type +"""Define methods which are dynamically added to JAX's Arrays and Tracers. + +This is done dynamically in order to avoid circular imports. 
+""" + +__all__ = ['register_jax_array_methods'] + +from functools import partial, wraps +import inspect +from typing import Any, List, Optional, Tuple, Union +import warnings + +import numpy as np +import jax +from jax import lax +from jax._src import core +from jax._src import dtypes +from jax._src import device_array +from jax._src.api_util import _ensure_index_tuple +from jax._src.array import ArrayImpl +from jax._src.lax import lax as lax_internal +from jax._src.numpy import lax_numpy +from jax._src.numpy import reductions +from jax._src.numpy import ufuncs +from jax._src.numpy import util +from jax._src.ops import scatter +from jax._src.typing import Array, ArrayLike, DimSize, DTypeLike, Shape +from jax._src.util import safe_zip + + +### add method and operator overloads to arraylike classes + +# We add operator overloads to DeviceArray and ShapedArray. These method and +# operator overloads mainly just forward calls to the corresponding lax_numpy +# functions, which can themselves handle instances from any of these classes. + + +def _astype(arr: ArrayLike, dtype: DTypeLike) -> Array: + """Copy the array and cast to a specified dtype. + + This is implemeted via :func:`jax.lax.convert_element_type`, which may + have slightly different behavior than :meth:`numpy.ndarray.astype` in + some cases. In particular, the details of float-to-int and int-to-float + casts are implementation dependent. 
+ """ + if dtype is None: + dtype = dtypes.canonicalize_dtype(lax_numpy.float_) + dtypes.check_user_dtype_supported(dtype, "astype") + return lax.convert_element_type(arr, dtype) + + +def _nbytes(arr: ArrayLike) -> int: + """Total bytes consumed by the elements of the array.""" + return np.size(arr) * dtypes.dtype(arr, canonicalize=True).itemsize + + +def _item(a: Array) -> Any: + """Copy an element of an array to a standard Python scalar and return it.""" + if dtypes.issubdtype(a.dtype, np.complexfloating): + return complex(a) + elif dtypes.issubdtype(a.dtype, np.floating): + return float(a) + elif dtypes.issubdtype(a.dtype, np.integer): + return int(a) + elif dtypes.issubdtype(a.dtype, np.bool_): + return bool(a) + else: + raise TypeError(a.dtype) + + +def _itemsize(arr: ArrayLike) -> int: + """Length of one array element in bytes.""" + return dtypes.dtype(arr, canonicalize=True).itemsize + + +def _clip(number: ArrayLike, + min: Optional[ArrayLike] = None, max: Optional[ArrayLike] = None, + out: None = None) -> Array: + """Return an array whose values are limited to a specified range. + + Refer to :func:`jax.numpy.clip` for full documentation.""" + return lax_numpy.clip(number, a_min=min, a_max=max, out=out) + + +def _transpose(a: Array, *args: Any) -> Array: + """Returns a view of the array with axes transposed. + + Refer to :func:`jax.numpy.transpose` for full documentation. + """ + if not args: + axis = None + elif len(args) == 1: + axis = args[0] if args[0] is None else _ensure_index_tuple(args[0]) + else: + axis = _ensure_index_tuple(args) + return lax_numpy.transpose(a, axis) + + +def _compute_newshape(a: ArrayLike, newshape: Union[DimSize, Shape]) -> Shape: + """Fixes a -1 value in newshape, if present.""" + # other errors, like having more than one -1, are caught downstream, in + # reshape_shape_rule. 
+ try: + iter(newshape) # type: ignore[arg-type] + except: + iterable = False + else: + iterable = True + newshape = core.canonicalize_shape(newshape if iterable else [newshape]) # type: ignore[arg-type] + return tuple(- core.divide_shape_sizes(np.shape(a), newshape) + if core.symbolic_equal_dim(d, -1) else d + for d in newshape) + + +def _reshape(a: Array, *args: Any, order: str = "C") -> Array: + """Returns an array containing the same data with a new shape. + + Refer to :func:`jax.numpy.reshape` for full documentation. + """ + newshape = _compute_newshape(a, args[0] if len(args) == 1 else args) + if order == "C": + return lax.reshape(a, newshape, None) + elif order == "F": + dims = list(range(a.ndim)[::-1]) + return lax.reshape(a, newshape[::-1], dims).T + elif order == "A": + raise NotImplementedError("np.reshape order=A is not implemented.") + else: + raise ValueError(f"Unexpected value for 'order' argument: {order}.") + + +def _view(arr: Array, dtype: DTypeLike = None, type: None = None) -> Array: + """Return a bitwise copy of the array, viewed as a new dtype. + + This is fuller-featured wrapper around :func:`jax.lax.bitcast_convert_type`. + + If the source and target dtype have the same bitwidth, the result has the same + shape as the input array. If the bitwidth of the target dtype is different + from the source, the size of the last axis of the result is adjusted + accordingly. + + >>> jnp.zeros([1,2,3], dtype=jnp.int16).view(jnp.int8).shape + (1, 2, 6) + >>> jnp.zeros([1,2,4], dtype=jnp.int8).view(jnp.int16).shape + (1, 2, 2) + + Conversions involving booleans are not well-defined in all situations. With + regards to the shape of result as explained above, booleans are treated as + having a bitwidth of 8. However, when converting to a boolean array, the input + should only contain 0 or 1 bytes. Otherwise, results may be unpredictable or + may change depending on how the result is used. 
+ + This conversion is guaranteed and safe: + >>> jnp.array([1, 0, 1], dtype=jnp.int8).view(jnp.bool_) + Array([ True, False, True], dtype=bool) + + However, there are no guarantees about the results of any expression involving + a view such as this: `jnp.array([1, 2, 3], dtype=jnp.int8).view(jnp.bool_)`. + In particular, the results may change between JAX releases and depending on + the platform. To safely convert such an array to a boolean array, compare it + with `0`: + + >>> jnp.array([1, 2, 0], dtype=jnp.int8) != 0 + Array([ True, True, False], dtype=bool) + """ + if type is not None: + raise NotImplementedError("`type` argument of array.view() is not supported.") + + util.check_arraylike("view", arr) + arr = lax_numpy.asarray(arr) + + dtypes.check_user_dtype_supported(dtype, "view") + dtype = dtypes.canonicalize_dtype(dtype) + + if arr.ndim == 0: + if arr.dtype.itemsize != dtype.itemsize: + raise ValueError("view() of a 0d array is only supported if the itemsize is unchanged.") + return _view(lax.expand_dims(arr, (0,)), dtype).squeeze() + + if (arr.shape[-1] * arr.dtype.itemsize) % dtype.itemsize != 0: + raise ValueError("When changing to a larger dtype, its size must be a divisor " + "of the total size in bytes of the last axis of the array.") + + if arr.dtype == dtype: + return arr + + # lax.bitcast_convert_type does not support bool or complex; in these cases we + # cast to a compatible type and recursively call _view for simplicity. 
+ if arr.dtype == bool: + return _view(arr.astype('uint8'), dtype) + + if lax_numpy.issubdtype(arr.dtype, np.complexfloating): + new_shape = (*arr.shape[:-1], arr.shape[-1] * 2) + new_dtype = lax_numpy.finfo(arr.dtype).dtype + arr = (lax_numpy.zeros(new_shape, new_dtype) + .at[..., 0::2].set(arr.real) + .at[..., 1::2].set(arr.imag)) + return _view(arr, dtype) + + if dtype == bool: + return _view(arr, np.uint8).astype(bool) + + if lax_numpy.issubdtype(dtype, np.complexfloating): + out = _view(arr, lax_numpy.finfo(dtype).dtype).astype(dtype) + return out[..., 0::2] + 1j * out[..., 1::2] + + # lax.bitcast_convert_type adds or subtracts dimensions depending on the + # relative bitwidths of the dtypes; we account for that with reshapes. + if arr.dtype.itemsize < dtype.itemsize: + factor = dtype.itemsize // arr.dtype.itemsize + arr = arr.reshape(*arr.shape[:-1], arr.shape[-1] // factor, factor) + return lax.bitcast_convert_type(arr, dtype) + + if arr.dtype.itemsize > dtype.itemsize: + out = lax.bitcast_convert_type(arr, dtype) + return out.reshape(*out.shape[:-2], out.shape[-2] * out.shape[-1]) + + return lax.bitcast_convert_type(arr, dtype) + + +def _notimplemented_flat(self): + raise NotImplementedError("JAX DeviceArrays do not implement the arr.flat property: " + "consider arr.flatten() instead.") + +_accepted_binop_types = (int, float, complex, np.generic, np.ndarray, Array) +_rejected_binop_types = (list, tuple, set, dict) + +def _defer_to_unrecognized_arg(opchar, binary_op, swap=False): + # Ensure that other array types have the chance to override arithmetic. 
+ def deferring_binary_op(self, other): + if hasattr(other, '__jax_array__'): + other = other.__jax_array__() + args = (other, self) if swap else (self, other) + if isinstance(other, _accepted_binop_types): + return binary_op(*args) + if isinstance(other, _rejected_binop_types): + raise TypeError(f"unsupported operand type(s) for {opchar}: " + f"{type(args[0]).__name__!r} and {type(args[1]).__name__!r}") + return NotImplemented + return deferring_binary_op + +def _unimplemented_setitem(self, i, x): + msg = ("'{}' object does not support item assignment. JAX arrays are " + "immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` " + "or another .at[] method: " + "https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html") + raise TypeError(msg.format(type(self))) + +def _operator_round(number: ArrayLike, ndigits: Optional[int] = None) -> Array: + out = lax_numpy.round(number, decimals=ndigits or 0) + # If `ndigits` is None, for a builtin float round(7.5) returns an integer. + return out.astype(int) if ndigits is None else out + +def _copy(self: Array) -> Array: + return self.copy() + +def _deepcopy(self: Array, memo: Any) -> Array: + del memo # unused + return self.copy() + + +# Experimental support for NumPy's module dispatch with NEP-37. +# Currently requires https://github.com/seberg/numpy-dispatch +_JAX_ARRAY_TYPES = (device_array.DeviceArray, core.Tracer, ArrayImpl) +_HANDLED_ARRAY_TYPES = _JAX_ARRAY_TYPES + (np.ndarray,) + +def __array_module__(self, types): + if all(issubclass(t, _HANDLED_ARRAY_TYPES) for t in types): + return jax.numpy + else: + return NotImplemented + + +def _compress_method(a: ArrayLike, condition: ArrayLike, + axis: Optional[int] = None, out: None = None) -> Array: + """Return selected slices of this array along given axis. + + Refer to :func:`jax.numpy.compress` for full documentation.""" + return lax_numpy.jaxcompress(condition, a, axis, out) + + +@util._wraps(lax.broadcast, lax_description=""" +Deprecated. 
Use :func:`jax.lax.broadcast` instead. +""") +def _deprecated_broadcast(*args, **kwargs): + warnings.warn( + "The arr.broadcast() method is deprecated. Use jax.lax.broadcast instead.", + category=FutureWarning) + return lax.broadcast(*args, **kwargs) + + +@util._wraps(lax.broadcast, lax_description=""" +Deprecated. Use :func:`jax.lax.broadcast_in_dim` instead. +""") +def _deprecated_broadcast_in_dim(*args, **kwargs): + warnings.warn( + "The arr.broadcast_in_dim() method is deprecated. Use jax.lax.broadcast_in_dim instead.", + category=FutureWarning) + return lax.broadcast_in_dim(*args, **kwargs) + + +@util._wraps(lax.broadcast, lax_description=""" +Deprecated. Use :func:`jax.numpy.split` instead. +""") +def _deprecated_split(*args, **kwargs): + warnings.warn( + "The arr.split() method is deprecated. Use jax.numpy.split instead.", + category=FutureWarning) + return lax_numpy.split(*args, **kwargs) + + +@core.stash_axis_env() +@partial(jax.jit, static_argnums=(1,2,3)) +def _multi_slice(arr: ArrayLike, + start_indices: Tuple[Tuple[int, ...]], + limit_indices: Tuple[Tuple[int, ...]], + removed_dims: Tuple[Tuple[int, ...]]) -> List[Array]: + """Extracts multiple slices from `arr`. + + This is used to shard DeviceArray arguments to pmap. It's implemented as a + DeviceArray method here to avoid circular imports. + """ + results: List[Array] = [] + for starts, limits, removed in safe_zip(start_indices, limit_indices, removed_dims): + sliced = lax.slice(arr, starts, limits) + if removed: + sliced = lax.squeeze(sliced, removed) + results.append(sliced) + return results + +# The next two functions are related to iter(device_array), implemented here to +# avoid circular imports. 
+@jax.jit +def _unstack(x: Array) -> List[Array]: + return [lax.index_in_dim(x, i, keepdims=False) for i in range(x.shape[0])] + +def _chunk_iter(x, size): + if size > x.shape[0]: + yield x + else: + num_chunks, tail = ufuncs.divmod(x.shape[0], size) + for i in range(num_chunks): + yield lax.dynamic_slice_in_dim(x, i * size, size) + if tail: + yield lax.dynamic_slice_in_dim(x, num_chunks * size, tail) + +# Syntactic sugar for scatter operations. +class _IndexUpdateHelper: + # Note: this docstring will appear as the docstring for the `at` property. + """Helper property for index update functionality. + + The ``at`` property provides a functionally pure equivalent of in-place + array modificatons. + + In particular: + + ============================== ================================ + Alternate syntax Equivalent In-place expression + ============================== ================================ + ``x = x.at[idx].set(y)`` ``x[idx] = y`` + ``x = x.at[idx].add(y)`` ``x[idx] += y`` + ``x = x.at[idx].multiply(y)`` ``x[idx] *= y`` + ``x = x.at[idx].divide(y)`` ``x[idx] /= y`` + ``x = x.at[idx].power(y)`` ``x[idx] **= y`` + ``x = x.at[idx].min(y)`` ``x[idx] = minimum(x[idx], y)`` + ``x = x.at[idx].max(y)`` ``x[idx] = maximum(x[idx], y)`` + ``x = x.at[idx].apply(ufunc)`` ``ufunc.at(x, idx)`` + ``x = x.at[idx].get()`` ``x = x[idx]`` + ============================== ================================ + + None of the ``x.at`` expressions modify the original ``x``; instead they return + a modified copy of ``x``. However, inside a :py:func:`~jax.jit` compiled function, + expressions like :code:`x = x.at[idx].set(y)` are guaranteed to be applied in-place. + + Unlike NumPy in-place operations such as :code:`x[idx] += y`, if multiple + indices refer to the same location, all updates will be applied (NumPy would + only apply the last update, rather than applying all updates.) 
The order + in which conflicting updates are applied is implementation-defined and may be + nondeterministic (e.g., due to concurrency on some hardware platforms). + + By default, JAX assumes that all indices are in-bounds. Alternative out-of-bound + index semantics can be specified via the ``mode`` parameter (see below). + + Arguments + --------- + mode : str + Specify out-of-bound indexing mode. Options are: + + - ``"promise_in_bounds"``: (default) The user promises that indices are in bounds. + No additional checking will be performed. In practice, this means that + out-of-bounds indices in ``get()`` will be clipped, and out-of-bounds indices + in ``set()``, ``add()``, etc. will be dropped. + - ``"clip"``: clamp out of bounds indices into valid range. + - ``"drop"``: ignore out-of-bound indices. + - ``"fill"``: alias for ``"drop"``. For `get()`, the optional ``fill_value`` + argument specifies the value that will be returned. + + See :class:`jax.lax.GatherScatterMode` for more details. + + indices_are_sorted : bool + If True, the implementation will assume that the indices passed to ``at[]`` + are sorted in ascending order, which can lead to more efficient execution + on some backends. + unique_indices : bool + If True, the implementation will assume that the indices passed to ``at[]`` + are unique, which can result in more efficient execution on some backends. + fill_value : Any + Only applies to the ``get()`` method: the fill value to return for out-of-bounds + slices when `mode` is ``'fill'``. Ignored otherwise. Defaults to ``NaN`` for + inexact types, the largest negative value for signed types, the largest positive + value for unsigned types, and ``True`` for booleans. 
+ + Examples + -------- + >>> x = jnp.arange(5.0) + >>> x + Array([0., 1., 2., 3., 4.], dtype=float32) + >>> x.at[2].add(10) + Array([ 0., 1., 12., 3., 4.], dtype=float32) + >>> x.at[10].add(10) # out-of-bounds indices are ignored + Array([0., 1., 2., 3., 4.], dtype=float32) + >>> x.at[20].add(10, mode='clip') + Array([ 0., 1., 2., 3., 14.], dtype=float32) + >>> x.at[2].get() + Array(2., dtype=float32) + >>> x.at[20].get() # out-of-bounds indices clipped + Array(4., dtype=float32) + >>> x.at[20].get(mode='fill') # out-of-bounds indices filled with NaN + Array(nan, dtype=float32) + >>> x.at[20].get(mode='fill', fill_value=-1) # custom fill value + Array(-1., dtype=float32) + """ + __slots__ = ("array",) + + def __init__(self, array): + self.array = array + + def __getitem__(self, index): + return _IndexUpdateRef(self.array, index) + + def __repr__(self): + return f"_IndexUpdateHelper({repr(self.array)})" + + +# TODO(jakevdp): remove these deprecation warnings after June 2023 +def allow_pass_by_position_with_warning(f): + @wraps(f) + def wrapped(*args, **kwargs): + sig = inspect.signature(f) + try: + sig.bind(*args, **kwargs) + except TypeError: + argspec = inspect.getfullargspec(f) + n_positional = len(argspec.args) + keywords = argspec.kwonlyargs[:len(args) - n_positional] + warnings.warn( + f"jnp.ndarray.at[...].{f.__name__}: Passing '{keywords[0]}' by position is deprecated. " + f"Pass by keyword instead", category=FutureWarning, stacklevel=2) + converted_kwargs = dict(zip(keywords, args[n_positional:])) + return f(*args[:n_positional], **converted_kwargs, **kwargs) + else: + return f(*args, **kwargs) + return wrapped + + +class _IndexUpdateRef: + """Helper object to call indexed update functions for an (advanced) index. + + This object references a source array and a specific indexer into that array. + Methods on this object return copies of the source array that have been + modified at the positions specified by the indexer. 
+ """ + __slots__ = ("array", "index") + + def __init__(self, array, index): + self.array = array + self.index = index + + def __repr__(self): + return f"_IndexUpdateRef({repr(self.array)}, {repr(self.index)})" + + @allow_pass_by_position_with_warning + def get(self, *, indices_are_sorted=False, unique_indices=False, + mode=None, fill_value=None): + """Equivalent to ``x[idx]``. + + Returns the value of ``x`` that would result from the NumPy-style + :mod:indexing ` ``x[idx]``. This function differs from + the usual array indexing syntax in that it allows additional keyword + arguments ``indices_are_sorted`` and ``unique_indices`` to be passed. + + See :mod:`jax.ops` for details. + """ + return lax_numpy._rewriting_take(self.array, self.index, + indices_are_sorted=indices_are_sorted, + unique_indices=unique_indices, mode=mode, + fill_value=fill_value) + + @allow_pass_by_position_with_warning + def set(self, values, *, indices_are_sorted=False, unique_indices=False, + mode=None): + """Pure equivalent of ``x[idx] = y``. + + Returns the value of ``x`` that would result from the NumPy-style + :mod:`indexed assignment ` ``x[idx] = y``. + + See :mod:`jax.ops` for details. + """ + return scatter._scatter_update(self.array, self.index, values, lax.scatter, + indices_are_sorted=indices_are_sorted, + unique_indices=unique_indices, mode=mode) + + @allow_pass_by_position_with_warning + def apply(self, func, *, indices_are_sorted=False, unique_indices=False, + mode=None): + """Pure equivalent of ``func.at(x, idx)`` for a unary ufunc ``func``. + + Returns the value of ``x`` that would result from applying the unary + function ``func`` to ``x`` at the given indices. This is similar to + ``x.at[idx].set(func(x[idx]))``, but differs in the case of repeated indices: + in ``x.at[idx].apply(func)``, repeated indices result in the function being + applied multiple times. + + Note that in the current implementation, ``scatter_apply`` is not compatible + with automatic differentiation. 
+ + See :mod:`jax.ops` for details. + """ + def _scatter_apply(x, indices, _, dims, **kwargs): + return lax.scatter_apply(x, indices, func, dims, **kwargs) + return scatter._scatter_update(self.array, self.index, + lax_internal._zero(self.array.dtype), + _scatter_apply, + indices_are_sorted=indices_are_sorted, + unique_indices=unique_indices, mode=mode) + + @allow_pass_by_position_with_warning + def add(self, values, *, indices_are_sorted=False, unique_indices=False, + mode=None): + """Pure equivalent of ``x[idx] += y``. + + Returns the value of ``x`` that would result from the NumPy-style + :mod:indexed assignment ` ``x[idx] += y``. + + See :mod:`jax.ops` for details. + """ + return scatter._scatter_update(self.array, self.index, values, + lax.scatter_add, + indices_are_sorted=indices_are_sorted, + unique_indices=unique_indices, mode=mode) + + @allow_pass_by_position_with_warning + def multiply(self, values, *, indices_are_sorted=False, unique_indices=False, + mode=None): + """Pure equivalent of ``x[idx] *= y``. + + Returns the value of ``x`` that would result from the NumPy-style + :mod:indexed assignment ` ``x[idx] *= y``. + + See :mod:`jax.ops` for details. + """ + return scatter._scatter_update(self.array, self.index, values, + lax.scatter_mul, + indices_are_sorted=indices_are_sorted, + unique_indices=unique_indices, + mode=mode) + mul = multiply + + @allow_pass_by_position_with_warning + def divide(self, values, *, indices_are_sorted=False, unique_indices=False, + mode=None): + """Pure equivalent of ``x[idx] /= y``. + + Returns the value of ``x`` that would result from the NumPy-style + :mod:indexed assignment ` ``x[idx] /= y``. + + See :mod:`jax.ops` for details. 
+ """ + return ufuncs.divide( + self.array, + scatter._scatter_update(lax_numpy.ones_like(self.array), self.index, values, + lax.scatter_mul, + indices_are_sorted=indices_are_sorted, + unique_indices=unique_indices, mode=mode)) + + @allow_pass_by_position_with_warning + def power(self, values, *, indices_are_sorted=False, unique_indices=False, + mode=None): + """Pure equivalent of ``x[idx] **= y``. + + Returns the value of ``x`` that would result from the NumPy-style + :mod:indexed assignment ` ``x[idx] **= y``. + + See :mod:`jax.ops` for details. + """ + return ufuncs.power( + self.array, + scatter._scatter_update(lax_numpy.ones_like(self.array), self.index, values, + lax.scatter_mul, + indices_are_sorted=indices_are_sorted, + unique_indices=unique_indices, mode=mode)) + + @allow_pass_by_position_with_warning + def min(self, values, *, indices_are_sorted=False, unique_indices=False, + mode=None): + """Pure equivalent of ``x[idx] = minimum(x[idx], y)``. + + Returns the value of ``x`` that would result from the NumPy-style + :mod:indexed assignment ` + ``x[idx] = minimum(x[idx], y)``. + + See :mod:`jax.ops` for details. + """ + return scatter._scatter_update(self.array, self.index, values, + lax.scatter_min, + indices_are_sorted=indices_are_sorted, + unique_indices=unique_indices, mode=mode) + + @allow_pass_by_position_with_warning + def max(self, values, *, indices_are_sorted=False, unique_indices=False, + mode=None): + """Pure equivalent of ``x[idx] = maximum(x[idx], y)``. + + Returns the value of ``x`` that would result from the NumPy-style + :mod:indexed assignment ` + ``x[idx] = maximum(x[idx], y)``. + + See :mod:`jax.ops` for details. 
+ """ + return scatter._scatter_update(self.array, self.index, values, + lax.scatter_max, + indices_are_sorted=indices_are_sorted, + unique_indices=unique_indices, mode=mode) + +_array_operators = { + "getitem": lax_numpy._rewriting_take, + "setitem": _unimplemented_setitem, + "copy": _copy, + "deepcopy": _deepcopy, + "neg": ufuncs.negative, + "pos": ufuncs.positive, + "eq": _defer_to_unrecognized_arg("==", ufuncs.equal), + "ne": _defer_to_unrecognized_arg("!=", ufuncs.not_equal), + "lt": _defer_to_unrecognized_arg("<", ufuncs.less), + "le": _defer_to_unrecognized_arg("<=", ufuncs.less_equal), + "gt": _defer_to_unrecognized_arg(">", ufuncs.greater), + "ge": _defer_to_unrecognized_arg(">=", ufuncs.greater_equal), + "abs": ufuncs.abs, + "add": _defer_to_unrecognized_arg("+", ufuncs.add), + "radd": _defer_to_unrecognized_arg("+", ufuncs.add, swap=True), + "sub": _defer_to_unrecognized_arg("-", ufuncs.subtract), + "rsub": _defer_to_unrecognized_arg("-", ufuncs.subtract, swap=True), + "mul": _defer_to_unrecognized_arg("*", ufuncs.multiply), + "rmul": _defer_to_unrecognized_arg("*", ufuncs.multiply, swap=True), + "div": _defer_to_unrecognized_arg("/", ufuncs.divide), + "rdiv": _defer_to_unrecognized_arg("/", ufuncs.divide, swap=True), + "truediv": _defer_to_unrecognized_arg("/", ufuncs.true_divide), + "rtruediv": _defer_to_unrecognized_arg("/", ufuncs.true_divide, swap=True), + "floordiv": _defer_to_unrecognized_arg("//", ufuncs.floor_divide), + "rfloordiv": _defer_to_unrecognized_arg("//", ufuncs.floor_divide, swap=True), + "divmod": _defer_to_unrecognized_arg("divmod", ufuncs.divmod), + "rdivmod": _defer_to_unrecognized_arg("divmod", ufuncs.divmod, swap=True), + "mod": _defer_to_unrecognized_arg("%", ufuncs.mod), + "rmod": _defer_to_unrecognized_arg("%", ufuncs.mod, swap=True), + "pow": _defer_to_unrecognized_arg("**", ufuncs.power), + "rpow": _defer_to_unrecognized_arg("**", ufuncs.power, swap=True), + "matmul": _defer_to_unrecognized_arg("@", lax_numpy.matmul), + 
"rmatmul": _defer_to_unrecognized_arg("@", lax_numpy.matmul, swap=True), + "and": _defer_to_unrecognized_arg("&", ufuncs.bitwise_and), + "rand": _defer_to_unrecognized_arg("&", ufuncs.bitwise_and, swap=True), + "or": _defer_to_unrecognized_arg("|", ufuncs.bitwise_or), + "ror": _defer_to_unrecognized_arg("|", ufuncs.bitwise_or, swap=True), + "xor": _defer_to_unrecognized_arg("^", ufuncs.bitwise_xor), + "rxor": _defer_to_unrecognized_arg("^", ufuncs.bitwise_xor, swap=True), + "invert": ufuncs.bitwise_not, + "lshift": _defer_to_unrecognized_arg("<<", ufuncs.left_shift), + "rshift": _defer_to_unrecognized_arg(">>", ufuncs.right_shift), + "rlshift": _defer_to_unrecognized_arg("<<", ufuncs.left_shift, swap=True), + "rrshift": _defer_to_unrecognized_arg(">>", ufuncs.right_shift, swap=True), + "round": _operator_round, +} + +_array_methods = { + "all": reductions.all, + "any": reductions.any, + "argmax": lax_numpy.argmax, + "argmin": lax_numpy.argmin, + "argpartition": lax_numpy.argpartition, + "argsort": lax_numpy.argsort, + "astype": _astype, + "choose": lax_numpy.choose, + "clip": _clip, + "conj": ufuncs.conj, + "conjugate": ufuncs.conjugate, + "compress": _compress_method, + "copy": lax_numpy.copy, + "cumprod": reductions.cumprod, + "cumsum": reductions.cumsum, + "diagonal": lax_numpy.diagonal, + "dot": lax_numpy.dot, + "flatten": lax_numpy.ravel, + "item": _item, + "max": reductions.max, + "mean": reductions.mean, + "min": reductions.min, + "nonzero": lax_numpy.nonzero, + "prod": reductions.prod, + "ptp": reductions.ptp, + "ravel": lax_numpy.ravel, + "repeat": lax_numpy.repeat, + "reshape": _reshape, + "round": round, + "searchsorted": lax_numpy.searchsorted, + "sort": lax_numpy.sort, + "squeeze": lax_numpy.squeeze, + "std": reductions.std, + "sum": reductions.sum, + "swapaxes": lax_numpy.swapaxes, + "take": lax_numpy.take, + "trace": lax_numpy.trace, + "transpose": _transpose, + "var": reductions.var, + "view": _view, + + # Methods exposed in order to avoid circular 
imports + "_split": lax_numpy.split, # used in jacfwd/jacrev + "_multi_slice": _multi_slice, # used in pxla for sharding + + # Deprecated methods. + # TODO(jakevdp): remove these after June 2023 + "broadcast": _deprecated_broadcast, + "broadcast_in_dim": _deprecated_broadcast_in_dim, + "split": _deprecated_split, +} + +_impl_only_array_methods = { + "_chunk_iter": _chunk_iter, + "_unstack": _unstack, +} + +_array_properties = { + "flat": _notimplemented_flat, + "T": lax_numpy.transpose, + "real": ufuncs.real, + "imag": ufuncs.imag, + "nbytes": _nbytes, + "itemsize": _itemsize, + "at": _IndexUpdateHelper, +} + +def _set_shaped_array_attributes(shaped_array): + # Set up operator, method, and property forwarding on Tracer instances + # containing + # ShapedArray avals by following the forwarding conventions for Tracer. + # Forward operators using a single-underscore-prefix naming convention: + for operator_name, function in _array_operators.items(): + setattr(shaped_array, f"_{operator_name}", staticmethod(function)) + # Forward methods and properties using core.{aval_method, aval_property}: + for method_name, method in _array_methods.items(): + setattr(shaped_array, method_name, core.aval_method(method)) + for prop_name, prop in _array_properties.items(): + setattr(shaped_array, prop_name, core.aval_property(prop)) + setattr(shaped_array, "_array_module", staticmethod(__array_module__)) + + +def _set_device_array_base_attributes(device_array, include=None, exclude=None): + # Forward operators, methods, and properties on DeviceArray to lax_numpy + # functions (with no Tracers involved; this forwarding is direct) + def maybe_setattr(attr_name, target): + if exclude is not None and attr_name in exclude: + return + if not include or attr_name in include: + setattr(device_array, attr_name, target) + + for operator_name, function in _array_operators.items(): + maybe_setattr(f"__{operator_name}__", function) + for method_name, method in _array_methods.items(): + 
maybe_setattr(method_name, method) + for prop_name, prop in _array_properties.items(): + maybe_setattr(prop_name, property(prop)) + + for name, func in _impl_only_array_methods.items(): + setattr(device_array, name, func) + +def _set_device_array_attributes(device_array): + setattr(device_array, "__array_module__", __array_module__) + + +def register_jax_array_methods(): + """Call this function once to register methods of JAX arrays""" + _set_shaped_array_attributes(core.ShapedArray) + _set_shaped_array_attributes(core.DShapedArray) + + _set_device_array_base_attributes(device_array.DeviceArray) + _set_device_array_base_attributes(ArrayImpl, exclude={'__getitem__'}) + + for t in device_array.device_array_types: + _set_device_array_attributes(t) + _set_device_array_attributes(ArrayImpl) + + Array.at.__doc__ = _IndexUpdateHelper.__doc__ diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 71f10ba076d9..118aab5095f5 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -26,8 +26,7 @@ import builtins import collections -from functools import partial, wraps -import inspect +from functools import partial import math import operator import types @@ -44,7 +43,6 @@ from jax import jit from jax import errors from jax import lax -from jax._src.interpreters import pxla from jax.tree_util import tree_leaves, tree_flatten, tree_map from jax._src import api_util @@ -52,17 +50,15 @@ from jax._src import device_array from jax._src import dtypes from jax._src.api_util import _ensure_index_tuple -from jax._src.core import ShapedArray, DShapedArray, ConcreteArray +from jax._src.core import ShapedArray, ConcreteArray from jax._src.lax.lax import (_array_copy, _sort_lt_comparator, _sort_le_comparator, PrecisionLike) from jax._src.lax import lax as lax_internal -from jax._src.lib import pmap_lib from jax._src.lib import xla_client from jax._src.numpy import reductions from jax._src.numpy import ufuncs from jax._src.numpy import util from 
jax._src.numpy.vectorize import vectorize -from jax._src.ops import scatter from jax._src.typing import Array, ArrayLike, DimSize, DType, DTypeLike, Shape from jax._src.util import (unzip2, subvals, safe_zip, ceil_of_ratio, partition_list, @@ -756,64 +752,8 @@ def reshape(a: ArrayLike, newshape: Union[DimSize, Shape], order: str = "C") -> # forward to method for ndarrays return a.reshape(newshape, order=order) # type: ignore[call-overload,union-attr] except AttributeError: - return _reshape(asarray(a), newshape, order=order) + return asarray(a).reshape(newshape, order=order) -def _compute_newshape(a: ArrayLike, newshape: Union[DimSize, Shape]) -> Shape: - """Fixes a -1 value in newshape, if present.""" - # other errors, like having more than one -1, are caught downstream, in - # reshape_shape_rule. - try: - iter(newshape) # type: ignore[arg-type] - except: - iterable = False - else: - iterable = True - newshape = core.canonicalize_shape(newshape if iterable else [newshape]) # type: ignore[arg-type] - return tuple(- core.divide_shape_sizes(np.shape(a), newshape) - if core.symbolic_equal_dim(d, -1) else d - for d in newshape) - -def _item(a: Array) -> Any: - """Copy an element of an array to a standard Python scalar and return it.""" - if dtypes.issubdtype(a.dtype, np.complexfloating): - return complex(a) - elif dtypes.issubdtype(a.dtype, np.floating): - return float(a) - elif dtypes.issubdtype(a.dtype, np.integer): - return int(a) - elif dtypes.issubdtype(a.dtype, np.bool_): - return bool(a) - else: - raise TypeError(a.dtype) - -def _reshape(a: Array, *args: Any, order: str = "C") -> Array: - """Returns an array containing the same data with a new shape. - - Refer to :func:`jax.numpy.reshape` for full documentation. 
- """ - newshape = _compute_newshape(a, args[0] if len(args) == 1 else args) - if order == "C": - return lax.reshape(a, newshape, None) - elif order == "F": - dims = list(range(ndim(a))[::-1]) - return lax.reshape(a, newshape[::-1], dims).T - elif order == "A": - raise NotImplementedError("np.reshape order=A is not implemented.") - else: - raise ValueError(f"Unexpected value for 'order' argument: {order}.") - -def _transpose(a: Array, *args: Any) -> Array: - """Returns a view of the array with axes transposed. - - Refer to :func:`jax.numpy.transpose` for full documentation. - """ - if not args: - axis = None - elif len(args) == 1: - axis = args[0] if args[0] is None else _ensure_index_tuple(args[0]) - else: - axis = _ensure_index_tuple(args) - return transpose(a, axis) @util._wraps(np.ravel, lax_description=_ARRAY_VIEW_DOC) @partial(jit, static_argnames=('order',), inline=True) @@ -4964,133 +4904,6 @@ def nanmedian(a: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None, method='midpoint') -def _astype(arr: ArrayLike, dtype: DTypeLike) -> Array: - """Copy the array and cast to a specified dtype. - - This is implemeted via :func:`jax.lax.convert_element_type`, which may - have slightly different behavior than :meth:`numpy.ndarray.astype` in - some cases. In particular, the details of float-to-int and int-to-float - casts are implementation dependent. 
- """ - if dtype is None: - dtype = dtypes.canonicalize_dtype(float_) - dtypes.check_user_dtype_supported(dtype, "astype") - return lax.convert_element_type(arr, dtype) - - -def _nbytes(arr: ArrayLike) -> int: - """Total bytes consumed by the elements of the array.""" - return size(arr) * _dtype(arr).itemsize - - -def _itemsize(arr: ArrayLike) -> int: - """Length of one array element in bytes.""" - return _dtype(arr).itemsize - - -def _clip(number: ArrayLike, - min: Optional[ArrayLike] = None, max: Optional[ArrayLike] = None, # noqa: F811 - out: None = None) -> Array: - """Return an array whose values are limited to a specified range. - - Refer to :func:`jax.numpy.clip` for full documentation.""" - return clip(number, a_min=min, a_max=max, out=out) - - -def _view(arr: Array, dtype: DTypeLike = None, type: None = None) -> Array: - """Return a bitwise copy of the array, viewed as a new dtype. - - This is fuller-featured wrapper around :func:`jax.lax.bitcast_convert_type`. - - If the source and target dtype have the same bitwidth, the result has the same - shape as the input array. If the bitwidth of the target dtype is different - from the source, the size of the last axis of the result is adjusted - accordingly. - - >>> jnp.zeros([1,2,3], dtype=jnp.int16).view(jnp.int8).shape - (1, 2, 6) - >>> jnp.zeros([1,2,4], dtype=jnp.int8).view(jnp.int16).shape - (1, 2, 2) - - Conversions involving booleans are not well-defined in all situations. With - regards to the shape of result as explained above, booleans are treated as - having a bitwidth of 8. However, when converting to a boolean array, the input - should only contain 0 or 1 bytes. Otherwise, results may be unpredictable or - may change depending on how the result is used. 
- - This conversion is guaranteed and safe: - >>> jnp.array([1, 0, 1], dtype=jnp.int8).view(jnp.bool_) - Array([ True, False, True], dtype=bool) - - However, there are no guarantees about the results of any expression involving - a view such as this: `jnp.array([1, 2, 3], dtype=jnp.int8).view(jnp.bool_)`. - In particular, the results may change between JAX releases and depending on - the platform. To safely convert such an array to a boolean array, compare it - with `0`: - - >>> jnp.array([1, 2, 0], dtype=jnp.int8) != 0 - Array([ True, True, False], dtype=bool) - """ - if type is not None: - raise NotImplementedError("`type` argument of array.view() is not supported.") - - util.check_arraylike("view", arr) - arr = asarray(arr) - - dtypes.check_user_dtype_supported(dtype, "view") - dtype = dtypes.canonicalize_dtype(dtype) - - if arr.ndim == 0: - if arr.dtype.itemsize != dtype.itemsize: - raise ValueError("view() of a 0d array is only supported if the itemsize is unchanged.") - return _view(lax.expand_dims(arr, (0,)), dtype).squeeze() - - if (arr.shape[-1] * arr.dtype.itemsize) % dtype.itemsize != 0: - raise ValueError("When changing to a larger dtype, its size must be a divisor " - "of the total size in bytes of the last axis of the array.") - - if arr.dtype == dtype: - return arr - - # lax.bitcast_convert_type does not support bool or complex; in these cases we - # cast to a compatible type and recursively call _view for simplicity. 
- if arr.dtype == bool: - return _view(arr.astype('uint8'), dtype) - - if issubdtype(arr.dtype, complexfloating): - new_shape = (*arr.shape[:-1], arr.shape[-1] * 2) - new_dtype = finfo(arr.dtype).dtype - arr = (zeros(new_shape, new_dtype) - .at[..., 0::2].set(arr.real) - .at[..., 1::2].set(arr.imag)) - return _view(arr, dtype) - - if dtype == bool: - return _view(arr, uint8).astype(bool) - - if issubdtype(dtype, complexfloating): - out = _view(arr, finfo(dtype).dtype).astype(dtype) - return out[..., 0::2] + 1j * out[..., 1::2] - - # lax.bitcast_convert_type adds or subtracts dimensions depending on the - # relative bitwidths of the dtypes; we account for that with reshapes. - if arr.dtype.itemsize < dtype.itemsize: - factor = dtype.itemsize // arr.dtype.itemsize - arr = arr.reshape(*arr.shape[:-1], arr.shape[-1] // factor, factor) - return lax.bitcast_convert_type(arr, dtype) - - if arr.dtype.itemsize > dtype.itemsize: - out = lax.bitcast_convert_type(arr, dtype) - return out.reshape(*out.shape[:-2], out.shape[-2] * out.shape[-1]) - - return lax.bitcast_convert_type(arr, dtype) - - -def _notimplemented_flat(self): - raise NotImplementedError("JAX DeviceArrays do not implement the arr.flat property: " - "consider arr.flatten() instead.") - - @util._wraps(np.place, lax_description=""" Numpy function :func:`numpy.place` is not available in JAX and will raise a :class:`NotImplementedError`, because ``np.place`` modifies its arguments in-place, @@ -5115,581 +4928,3 @@ def put(*args, **kwargs): "jax.numpy.put is not implemented because JAX arrays cannot be modified in-place. " "For functional approaches to updating array values, see jax.numpy.ndarray.at: " "https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html.") - - -### add method and operator overloads to arraylike classes - -# We add operator overloads to DeviceArray and ShapedArray. 
These method and -# operator overloads mainly just forward calls to the corresponding lax_numpy -# functions, which can themselves handle instances from any of these classes. - -_scalar_types = (int, float, complex, np.generic) -_accepted_binop_types = (int, float, complex, np.generic, np.ndarray, Array) -_rejected_binop_types = (list, tuple, set, dict) - -def _defer_to_unrecognized_arg(opchar, binary_op, swap=False): - # Ensure that other array types have the chance to override arithmetic. - def deferring_binary_op(self, other): - if hasattr(other, '__jax_array__'): - other = other.__jax_array__() - args = (other, self) if swap else (self, other) - if isinstance(other, _accepted_binop_types): - return binary_op(*args) - if isinstance(other, _rejected_binop_types): - raise TypeError(f"unsupported operand type(s) for {opchar}: " - f"{type(args[0]).__name__!r} and {type(args[1]).__name__!r}") - return NotImplemented - return deferring_binary_op - -def _unimplemented_setitem(self, i, x): - msg = ("'{}' object does not support item assignment. JAX arrays are " - "immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` " - "or another .at[] method: " - "https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html") - raise TypeError(msg.format(type(self))) - -def _operator_round(number: ArrayLike, ndigits: Optional[int] = None) -> Array: - out = round(number, decimals=ndigits or 0) - # If `ndigits` is None, for a builtin float round(7.5) returns an integer. - return out.astype(int) if ndigits is None else out - -def _copy(self: Array) -> Array: - return self.copy() - -def _deepcopy(self: Array, memo: Any) -> Array: - del memo # unused - return self.copy() - - -# Experimental support for NumPy's module dispatch with NEP-37. 
-# Currently requires https://github.com/seberg/numpy-dispatch -_JAX_ARRAY_TYPES = (device_array.DeviceArray, core.Tracer, ArrayImpl) -_HANDLED_ARRAY_TYPES = _JAX_ARRAY_TYPES + (np.ndarray,) - -def __array_module__(self, types): - if _all(issubclass(t, _HANDLED_ARRAY_TYPES) for t in types): - return jax.numpy - else: - return NotImplemented - - -def _compress_method(a: ArrayLike, condition: ArrayLike, - axis: Optional[int] = None, out: None = None) -> Array: - """Return selected slices of this array along given axis. - - Refer to :func:`jax.numpy.compress` for full documentation.""" - return compress(condition, a, axis, out) - - -@util._wraps(lax.broadcast, lax_description=""" -Deprecated. Use :func:`jax.lax.broadcast` instead. -""") -def _deprecated_broadcast(*args, **kwargs): - warnings.warn( - "The arr.broadcast() method is deprecated. Use jax.lax.broadcast instead.", - category=FutureWarning) - return lax.broadcast(*args, **kwargs) - - -@util._wraps(lax.broadcast, lax_description=""" -Deprecated. Use :func:`jax.lax.broadcast_in_dim` instead. -""") -def _deprecated_broadcast_in_dim(*args, **kwargs): - warnings.warn( - "The arr.broadcast_in_dim() method is deprecated. Use jax.lax.broadcast_in_dim instead.", - category=FutureWarning) - return lax.broadcast_in_dim(*args, **kwargs) - - -@util._wraps(lax.broadcast, lax_description=""" -Deprecated. Use :func:`jax.numpy.split` instead. -""") -def _deprecated_split(*args, **kwargs): - warnings.warn( - "The arr.split() method is deprecated. Use jax.numpy.split instead.", - category=FutureWarning) - return split(*args, **kwargs) - - -@core.stash_axis_env() -@partial(jit, static_argnums=(1,2,3)) -def _multi_slice(arr: ArrayLike, - start_indices: Tuple[Tuple[int, ...]], - limit_indices: Tuple[Tuple[int, ...]], - removed_dims: Tuple[Tuple[int, ...]]) -> List[Array]: - """Extracts multiple slices from `arr`. - - This is used to shard DeviceArray arguments to pmap. 
It's implemented as a - DeviceArray method here to avoid circular imports. - """ - results: List[Array] = [] - for starts, limits, removed in safe_zip(start_indices, limit_indices, removed_dims): - sliced = lax.slice(arr, starts, limits) - if removed: - sliced = lax.squeeze(sliced, removed) - results.append(sliced) - return results - -# The next two functions are related to iter(device_array), implemented here to -# avoid circular imports. -@jit -def _unstack(x: Array) -> List[Array]: - return [lax.index_in_dim(x, i, keepdims=False) for i in range(x.shape[0])] -setattr(device_array.DeviceArray, "_unstack", _unstack) -setattr(ArrayImpl, '_unstack', _unstack) - -def _chunk_iter(x, size): - if size > x.shape[0]: - yield x - else: - num_chunks, tail = ufuncs.divmod(x.shape[0], size) - for i in range(num_chunks): - yield lax.dynamic_slice_in_dim(x, i * size, size) - if tail: - yield lax.dynamic_slice_in_dim(x, num_chunks * size, tail) -setattr(device_array.DeviceArray, "_chunk_iter", _chunk_iter) -setattr(ArrayImpl, '_chunk_iter', _chunk_iter) - -# Syntactic sugar for scatter operations. -class _IndexUpdateHelper: - # Note: this docstring will appear as the docstring for the `at` property. - """Helper property for index update functionality. - - The ``at`` property provides a functionally pure equivalent of in-place - array modificatons. 
- - In particular: - - ============================== ================================ - Alternate syntax Equivalent In-place expression - ============================== ================================ - ``x = x.at[idx].set(y)`` ``x[idx] = y`` - ``x = x.at[idx].add(y)`` ``x[idx] += y`` - ``x = x.at[idx].multiply(y)`` ``x[idx] *= y`` - ``x = x.at[idx].divide(y)`` ``x[idx] /= y`` - ``x = x.at[idx].power(y)`` ``x[idx] **= y`` - ``x = x.at[idx].min(y)`` ``x[idx] = minimum(x[idx], y)`` - ``x = x.at[idx].max(y)`` ``x[idx] = maximum(x[idx], y)`` - ``x = x.at[idx].apply(ufunc)`` ``ufunc.at(x, idx)`` - ``x = x.at[idx].get()`` ``x = x[idx]`` - ============================== ================================ - - None of the ``x.at`` expressions modify the original ``x``; instead they return - a modified copy of ``x``. However, inside a :py:func:`~jax.jit` compiled function, - expressions like :code:`x = x.at[idx].set(y)` are guaranteed to be applied in-place. - - Unlike NumPy in-place operations such as :code:`x[idx] += y`, if multiple - indices refer to the same location, all updates will be applied (NumPy would - only apply the last update, rather than applying all updates.) The order - in which conflicting updates are applied is implementation-defined and may be - nondeterministic (e.g., due to concurrency on some hardware platforms). - - By default, JAX assumes that all indices are in-bounds. Alternative out-of-bound - index semantics can be specified via the ``mode`` parameter (see below). - - Arguments - --------- - mode : str - Specify out-of-bound indexing mode. Options are: - - - ``"promise_in_bounds"``: (default) The user promises that indices are in bounds. - No additional checking will be performed. In practice, this means that - out-of-bounds indices in ``get()`` will be clipped, and out-of-bounds indices - in ``set()``, ``add()``, etc. will be dropped. - - ``"clip"``: clamp out of bounds indices into valid range. - - ``"drop"``: ignore out-of-bound indices. 
- - ``"fill"``: alias for ``"drop"``. For `get()`, the optional ``fill_value`` - argument specifies the value that will be returned. - - See :class:`jax.lax.GatherScatterMode` for more details. - - indices_are_sorted : bool - If True, the implementation will assume that the indices passed to ``at[]`` - are sorted in ascending order, which can lead to more efficient execution - on some backends. - unique_indices : bool - If True, the implementation will assume that the indices passed to ``at[]`` - are unique, which can result in more efficient execution on some backends. - fill_value : Any - Only applies to the ``get()`` method: the fill value to return for out-of-bounds - slices when `mode` is ``'fill'``. Ignored otherwise. Defaults to ``NaN`` for - inexact types, the largest negative value for signed types, the largest positive - value for unsigned types, and ``True`` for booleans. - - Examples - -------- - >>> x = jnp.arange(5.0) - >>> x - Array([0., 1., 2., 3., 4.], dtype=float32) - >>> x.at[2].add(10) - Array([ 0., 1., 12., 3., 4.], dtype=float32) - >>> x.at[10].add(10) # out-of-bounds indices are ignored - Array([0., 1., 2., 3., 4.], dtype=float32) - >>> x.at[20].add(10, mode='clip') - Array([ 0., 1., 2., 3., 14.], dtype=float32) - >>> x.at[2].get() - Array(2., dtype=float32) - >>> x.at[20].get() # out-of-bounds indices clipped - Array(4., dtype=float32) - >>> x.at[20].get(mode='fill') # out-of-bounds indices filled with NaN - Array(nan, dtype=float32) - >>> x.at[20].get(mode='fill', fill_value=-1) # custom fill value - Array(-1., dtype=float32) - """ - __slots__ = ("array",) - - def __init__(self, array): - self.array = array - - def __getitem__(self, index): - return _IndexUpdateRef(self.array, index) - - def __repr__(self): - return f"_IndexUpdateHelper({repr(self.array)})" -Array.at.__doc__ = _IndexUpdateHelper.__doc__ - - -# TODO(jakevdp): remove these deprecation warnings after June 2023 -def allow_pass_by_position_with_warning(f): - @wraps(f) - def 
wrapped(*args, **kwargs): - sig = inspect.signature(f) - try: - sig.bind(*args, **kwargs) - except TypeError: - argspec = inspect.getfullargspec(f) - n_positional = len(argspec.args) - keywords = argspec.kwonlyargs[:len(args) - n_positional] - warnings.warn( - f"jnp.ndarray.at[...].{f.__name__}: Passing '{keywords[0]}' by position is deprecated. " - f"Pass by keyword instead", category=FutureWarning, stacklevel=2) - converted_kwargs = dict(zip(keywords, args[n_positional:])) - return f(*args[:n_positional], **converted_kwargs, **kwargs) - else: - return f(*args, **kwargs) - return wrapped - - -class _IndexUpdateRef: - """Helper object to call indexed update functions for an (advanced) index. - - This object references a source array and a specific indexer into that array. - Methods on this object return copies of the source array that have been - modified at the positions specified by the indexer. - """ - __slots__ = ("array", "index") - - def __init__(self, array, index): - self.array = array - self.index = index - - def __repr__(self): - return f"_IndexUpdateRef({repr(self.array)}, {repr(self.index)})" - - @allow_pass_by_position_with_warning - def get(self, *, indices_are_sorted=False, unique_indices=False, - mode=None, fill_value=None): - """Equivalent to ``x[idx]``. - - Returns the value of ``x`` that would result from the NumPy-style - :mod:indexing ` ``x[idx]``. This function differs from - the usual array indexing syntax in that it allows additional keyword - arguments ``indices_are_sorted`` and ``unique_indices`` to be passed. - - See :mod:`jax.ops` for details. - """ - return _rewriting_take(self.array, self.index, - indices_are_sorted=indices_are_sorted, - unique_indices=unique_indices, mode=mode, - fill_value=fill_value) - - @allow_pass_by_position_with_warning - def set(self, values, *, indices_are_sorted=False, unique_indices=False, - mode=None): - """Pure equivalent of ``x[idx] = y``. 
- - Returns the value of ``x`` that would result from the NumPy-style - :mod:`indexed assignment ` ``x[idx] = y``. - - See :mod:`jax.ops` for details. - """ - return scatter._scatter_update(self.array, self.index, values, lax.scatter, - indices_are_sorted=indices_are_sorted, - unique_indices=unique_indices, mode=mode) - - @allow_pass_by_position_with_warning - def apply(self, func, *, indices_are_sorted=False, unique_indices=False, - mode=None): - """Pure equivalent of ``func.at(x, idx)`` for a unary ufunc ``func``. - - Returns the value of ``x`` that would result from applying the unary - function ``func`` to ``x`` at the given indices. This is similar to - ``x.at[idx].set(func(x[idx]))``, but differs in the case of repeated indices: - in ``x.at[idx].apply(func)``, repeated indices result in the function being - applied multiple times. - - Note that in the current implementation, ``scatter_apply`` is not compatible - with automatic differentiation. - - See :mod:`jax.ops` for details. - """ - def _scatter_apply(x, indices, _, dims, **kwargs): - return lax.scatter_apply(x, indices, func, dims, **kwargs) - return scatter._scatter_update(self.array, self.index, - lax_internal._zero(self.array.dtype), - _scatter_apply, - indices_are_sorted=indices_are_sorted, - unique_indices=unique_indices, mode=mode) - - @allow_pass_by_position_with_warning - def add(self, values, *, indices_are_sorted=False, unique_indices=False, - mode=None): - """Pure equivalent of ``x[idx] += y``. - - Returns the value of ``x`` that would result from the NumPy-style - :mod:indexed assignment ` ``x[idx] += y``. - - See :mod:`jax.ops` for details. - """ - return scatter._scatter_update(self.array, self.index, values, - lax.scatter_add, - indices_are_sorted=indices_are_sorted, - unique_indices=unique_indices, mode=mode) - - @allow_pass_by_position_with_warning - def multiply(self, values, *, indices_are_sorted=False, unique_indices=False, - mode=None): - """Pure equivalent of ``x[idx] *= y``. 
- - Returns the value of ``x`` that would result from the NumPy-style - :mod:indexed assignment ` ``x[idx] *= y``. - - See :mod:`jax.ops` for details. - """ - return scatter._scatter_update(self.array, self.index, values, - lax.scatter_mul, - indices_are_sorted=indices_are_sorted, - unique_indices=unique_indices, - mode=mode) - mul = multiply - - @allow_pass_by_position_with_warning - def divide(self, values, *, indices_are_sorted=False, unique_indices=False, - mode=None): - """Pure equivalent of ``x[idx] /= y``. - - Returns the value of ``x`` that would result from the NumPy-style - :mod:indexed assignment ` ``x[idx] /= y``. - - See :mod:`jax.ops` for details. - """ - return ufuncs.divide( - self.array, - scatter._scatter_update(ones_like(self.array), self.index, values, - lax.scatter_mul, - indices_are_sorted=indices_are_sorted, - unique_indices=unique_indices, mode=mode)) - - @allow_pass_by_position_with_warning - def power(self, values, *, indices_are_sorted=False, unique_indices=False, - mode=None): - """Pure equivalent of ``x[idx] **= y``. - - Returns the value of ``x`` that would result from the NumPy-style - :mod:indexed assignment ` ``x[idx] **= y``. - - See :mod:`jax.ops` for details. - """ - return ufuncs.power( - self.array, - scatter._scatter_update(ones_like(self.array), self.index, values, - lax.scatter_mul, - indices_are_sorted=indices_are_sorted, - unique_indices=unique_indices, mode=mode)) - - @allow_pass_by_position_with_warning - def min(self, values, *, indices_are_sorted=False, unique_indices=False, - mode=None): - """Pure equivalent of ``x[idx] = minimum(x[idx], y)``. - - Returns the value of ``x`` that would result from the NumPy-style - :mod:indexed assignment ` - ``x[idx] = minimum(x[idx], y)``. - - See :mod:`jax.ops` for details. 
- """ - return scatter._scatter_update(self.array, self.index, values, - lax.scatter_min, - indices_are_sorted=indices_are_sorted, - unique_indices=unique_indices, mode=mode) - - @allow_pass_by_position_with_warning - def max(self, values, *, indices_are_sorted=False, unique_indices=False, - mode=None): - """Pure equivalent of ``x[idx] = maximum(x[idx], y)``. - - Returns the value of ``x`` that would result from the NumPy-style - :mod:indexed assignment ` - ``x[idx] = maximum(x[idx], y)``. - - See :mod:`jax.ops` for details. - """ - return scatter._scatter_update(self.array, self.index, values, - lax.scatter_max, - indices_are_sorted=indices_are_sorted, - unique_indices=unique_indices, mode=mode) - -_array_operators = { - "getitem": _rewriting_take, - "setitem": _unimplemented_setitem, - "copy": _copy, - "deepcopy": _deepcopy, - "neg": ufuncs.negative, - "pos": ufuncs.positive, - "eq": _defer_to_unrecognized_arg("==", ufuncs.equal), - "ne": _defer_to_unrecognized_arg("!=", ufuncs.not_equal), - "lt": _defer_to_unrecognized_arg("<", ufuncs.less), - "le": _defer_to_unrecognized_arg("<=", ufuncs.less_equal), - "gt": _defer_to_unrecognized_arg(">", ufuncs.greater), - "ge": _defer_to_unrecognized_arg(">=", ufuncs.greater_equal), - "abs": ufuncs.abs, - "add": _defer_to_unrecognized_arg("+", ufuncs.add), - "radd": _defer_to_unrecognized_arg("+", ufuncs.add, swap=True), - "sub": _defer_to_unrecognized_arg("-", ufuncs.subtract), - "rsub": _defer_to_unrecognized_arg("-", ufuncs.subtract, swap=True), - "mul": _defer_to_unrecognized_arg("*", ufuncs.multiply), - "rmul": _defer_to_unrecognized_arg("*", ufuncs.multiply, swap=True), - "div": _defer_to_unrecognized_arg("/", ufuncs.divide), - "rdiv": _defer_to_unrecognized_arg("/", ufuncs.divide, swap=True), - "truediv": _defer_to_unrecognized_arg("/", ufuncs.true_divide), - "rtruediv": _defer_to_unrecognized_arg("/", ufuncs.true_divide, swap=True), - "floordiv": _defer_to_unrecognized_arg("//", ufuncs.floor_divide), - "rfloordiv": 
_defer_to_unrecognized_arg("//", ufuncs.floor_divide, swap=True), - "divmod": _defer_to_unrecognized_arg("divmod", ufuncs.divmod), - "rdivmod": _defer_to_unrecognized_arg("divmod", ufuncs.divmod, swap=True), - "mod": _defer_to_unrecognized_arg("%", ufuncs.mod), - "rmod": _defer_to_unrecognized_arg("%", ufuncs.mod, swap=True), - "pow": _defer_to_unrecognized_arg("**", ufuncs.power), - "rpow": _defer_to_unrecognized_arg("**", ufuncs.power, swap=True), - "matmul": _defer_to_unrecognized_arg("@", matmul), - "rmatmul": _defer_to_unrecognized_arg("@", matmul, swap=True), - "and": _defer_to_unrecognized_arg("&", ufuncs.bitwise_and), - "rand": _defer_to_unrecognized_arg("&", ufuncs.bitwise_and, swap=True), - "or": _defer_to_unrecognized_arg("|", ufuncs.bitwise_or), - "ror": _defer_to_unrecognized_arg("|", ufuncs.bitwise_or, swap=True), - "xor": _defer_to_unrecognized_arg("^", ufuncs.bitwise_xor), - "rxor": _defer_to_unrecognized_arg("^", ufuncs.bitwise_xor, swap=True), - "invert": ufuncs.bitwise_not, - "lshift": _defer_to_unrecognized_arg("<<", ufuncs.left_shift), - "rshift": _defer_to_unrecognized_arg(">>", ufuncs.right_shift), - "rlshift": _defer_to_unrecognized_arg("<<", ufuncs.left_shift, swap=True), - "rrshift": _defer_to_unrecognized_arg(">>", ufuncs.right_shift, swap=True), - "round": _operator_round, -} - -_array_methods = { - "all": reductions.all, - "any": reductions.any, - "argmax": argmax, - "argmin": argmin, - "argpartition": argpartition, - "argsort": argsort, - "astype": _astype, - "choose": choose, - "clip": _clip, - "conj": ufuncs.conj, - "conjugate": ufuncs.conjugate, - "compress": _compress_method, - "copy": copy, - "cumprod": reductions.cumprod, - "cumsum": reductions.cumsum, - "diagonal": diagonal, - "dot": dot, - "flatten": ravel, - "item": _item, - "max": reductions.max, - "mean": reductions.mean, - "min": reductions.min, - "nonzero": nonzero, - "prod": reductions.prod, - "ptp": reductions.ptp, - "ravel": ravel, - "repeat": repeat, - "reshape": 
_reshape, - "round": round, - "searchsorted": searchsorted, - "sort": sort, - "squeeze": squeeze, - "std": reductions.std, - "sum": reductions.sum, - "swapaxes": swapaxes, - "take": take, - "trace": trace, - "transpose": _transpose, - "var": reductions.var, - "view": _view, - - # Methods exposed in order to avoid circular imports - "_split": split, # used in jacfwd/jacrev - "_multi_slice": _multi_slice, # used in pxla for sharding - - # Deprecated methods. - # TODO(jakevdp): remove these after June 2023 - "broadcast": _deprecated_broadcast, - "broadcast_in_dim": _deprecated_broadcast_in_dim, - "split": _deprecated_split, -} - -_array_properties = { - "flat": _notimplemented_flat, - "T": transpose, - "real": ufuncs.real, - "imag": ufuncs.imag, - "nbytes": _nbytes, - "itemsize": _itemsize, - "at": _IndexUpdateHelper, -} - -def _set_shaped_array_attributes(shaped_array): - # Set up operator, method, and property forwarding on Tracer instances - # containing - # ShapedArray avals by following the forwarding conventions for Tracer. 
- # Forward operators using a single-underscore-prefix naming convention: - for operator_name, function in _array_operators.items(): - setattr(shaped_array, f"_{operator_name}", staticmethod(function)) - # Forward methods and properties using core.{aval_method, aval_property}: - for method_name, method in _array_methods.items(): - setattr(shaped_array, method_name, core.aval_method(method)) - for prop_name, prop in _array_properties.items(): - setattr(shaped_array, prop_name, core.aval_property(prop)) - setattr(shaped_array, "_array_module", staticmethod(__array_module__)) - -_set_shaped_array_attributes(ShapedArray) -_set_shaped_array_attributes(DShapedArray) - - -def _set_device_array_base_attributes(device_array, include=None, exclude=None): - # Forward operators, methods, and properties on DeviceArray to lax_numpy - # functions (with no Tracers involved; this forwarding is direct) - def maybe_setattr(attr_name, target): - if exclude is not None and attr_name in exclude: - return - if not include or attr_name in include: - setattr(device_array, attr_name, target) - - for operator_name, function in _array_operators.items(): - maybe_setattr(f"__{operator_name}__", function) - for method_name, method in _array_methods.items(): - maybe_setattr(method_name, method) - for prop_name, prop in _array_properties.items(): - maybe_setattr(prop_name, property(prop)) - -_set_device_array_base_attributes(device_array.DeviceArray) -_set_device_array_base_attributes(ArrayImpl, exclude={'__getitem__'}) - -def _set_device_array_attributes(device_array): - setattr(device_array, "__array_module__", __array_module__) - -for t in device_array.device_array_types: - _set_device_array_attributes(t) -_set_device_array_attributes(ArrayImpl) diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 730eb9d3c8a7..a966b81aed87 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -42,7 +42,7 @@ from jax._src.lax import utils as lax_utils from jax._src.lib import gpu_prng from 
jax._src.lib.mlir.dialects import hlo -from jax._src.numpy.lax_numpy import _set_device_array_base_attributes +from jax._src.numpy.array_methods import _set_device_array_base_attributes from jax._src.numpy.util import _register_stackable from jax._src.sharding_impls import ( NamedSharding, PmapSharding, GSPMDSharding) diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index 229a99a953e2..63b39f07b408 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -420,6 +420,11 @@ from jax._src.numpy.vectorize import vectorize as vectorize +# Dynamically register numpy-style methods on JAX arrays. +from jax._src.numpy.array_methods import register_jax_array_methods +register_jax_array_methods() +del register_jax_array_methods + # Deprecations From 23b0743b62c61536d9b69b8c7dfb7660fc92fe52 Mon Sep 17 00:00:00 2001 From: Anish Tondwalkar Date: Thu, 23 Mar 2023 15:20:37 -0700 Subject: [PATCH 44/65] Refactor special functions into their own module. We're going to want to decompose these using series and continued fraction representations, and for that we'll need control flow PiperOrigin-RevId: 518977008 --- jax/_src/lax/lax.py | 141 -------------------------------- jax/_src/lax/special.py | 176 ++++++++++++++++++++++++++++++++++++++++ jax/lax/__init__.py | 50 ++++++------ 3 files changed, 202 insertions(+), 165 deletions(-) create mode 100644 jax/_src/lax/special.py diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 327f64edce53..ebb2caf6680f 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -60,7 +60,6 @@ standard_multi_result_abstract_eval, standard_named_shape_rule, standard_primitive, - standard_translate, ) from jax._src.lib import pytree from jax._src import xla_bridge @@ -325,59 +324,6 @@ def atan2(x: ArrayLike, y: ArrayLike) -> Array: :math:`\mathrm{atan}({x \over y})`.""" return atan2_p.bind(x, y) -def betainc(a: ArrayLike, b: ArrayLike, x: ArrayLike) -> Array: - r"""Elementwise regularized incomplete beta integral.""" - return 
regularized_incomplete_beta_p.bind(a, b, x) - -def lgamma(x: ArrayLike) -> Array: - r"""Elementwise log gamma: :math:`\mathrm{log}(\Gamma(x))`.""" - return lgamma_p.bind(x) - -def digamma(x: ArrayLike) -> Array: - r"""Elementwise digamma: :math:`\psi(x)`.""" - return digamma_p.bind(x) - -def igamma(a: ArrayLike, x: ArrayLike) -> Array: - r"""Elementwise regularized incomplete gamma function.""" - return igamma_p.bind(a, x) - -def igammac(a: ArrayLike, x: ArrayLike) -> Array: - r"""Elementwise complementary regularized incomplete gamma function.""" - return igammac_p.bind(a, x) - -def igamma_grad_a(a: ArrayLike, x: ArrayLike) -> Array: - r"""Elementwise derivative of the regularized incomplete gamma function.""" - return igamma_grad_a_p.bind(a, x) - -def random_gamma_grad(a: ArrayLike, x: ArrayLike) -> Array: - r"""Elementwise derivative of samples from `Gamma(a, 1)`.""" - return random_gamma_grad_p.bind(a, x) - -def bessel_i0e(x: ArrayLike) -> Array: - r"""Exponentially scaled modified Bessel function of order 0: - :math:`\mathrm{i0e}(x) = e^{-|x|} \mathrm{i0}(x)` - """ - return bessel_i0e_p.bind(x) - -def bessel_i1e(x: ArrayLike) -> Array: - r"""Exponentially scaled modified Bessel function of order 1: - :math:`\mathrm{i1e}(x) = e^{-|x|} \mathrm{i1}(x)` - """ - return bessel_i1e_p.bind(x) - -def erf(x: ArrayLike) -> Array: - r"""Elementwise error function: :math:`\mathrm{erf}(x)`.""" - return erf_p.bind(x) - -def erfc(x: ArrayLike) -> Array: - r"""Elementwise complementary error function: - :math:`\mathrm{erfc}(x) = 1 - \mathrm{erf}(x)`.""" - return erfc_p.bind(x) - -def erf_inv(x: ArrayLike) -> Array: - r"""Elementwise inverse error function: :math:`\mathrm{erf}^{-1}(x)`.""" - return erf_inv_p.bind(x) - def real(x: ArrayLike) -> Array: r"""Elementwise extract real part: :math:`\mathrm{Re}(x)`. 
@@ -1898,93 +1844,6 @@ def atan_impl(x): lambda g, x: mul(reciprocal(_one(x) + x), div(g, (_one(x) - x)))) mlir.register_lowering(atanh_p, partial(_nary_lower_hlo, chlo.AtanhOp)) -regularized_incomplete_beta_p = standard_naryop( - [_float, _float, _float], 'regularized_incomplete_beta') -xla.register_translation( - regularized_incomplete_beta_p, - partial(_broadcast_translate, xops.RegularizedIncompleteBeta)) - -def betainc_gradx(g, a, b, x): - lbeta = lgamma(a) + lgamma(b) - lgamma(a + b) - partial_x = exp((b - 1) * log1p(-x) + - (a - 1) * log(x) - lbeta) - return partial_x * g - -def betainc_grad_not_implemented(g, a, b, x): - raise ValueError("Betainc gradient with respect to a and b not supported.") - -ad.defjvp(regularized_incomplete_beta_p, - betainc_grad_not_implemented, - betainc_grad_not_implemented, - betainc_gradx) - -lgamma_p = standard_unop(_float, 'lgamma') -ad.defjvp(lgamma_p, lambda g, x: mul(g, digamma(x))) -mlir.register_lowering(lgamma_p, partial(_nary_lower_hlo, chlo.LgammaOp)) - -digamma_p = standard_unop(_float, 'digamma') -mlir.register_lowering(digamma_p, partial(_nary_lower_hlo, chlo.DigammaOp)) - -igamma_p = standard_naryop([_float, _float], 'igamma') -xla.register_translation(igamma_p, partial(_broadcast_translate, xops.Igamma)) -igamma_grad_a_p = standard_naryop([_float, _float], 'igamma_grad_a') -xla.register_translation(igamma_grad_a_p, - partial(_broadcast_translate, xops.IgammaGradA)) - -def igamma_gradx(g, a, x): - return g * exp(-x + (a - _ones(a)) * log(x) - lgamma(a)) - -def igamma_grada(g, a, x): - return g * igamma_grad_a(a, x) - -ad.defjvp(igamma_p, igamma_grada, igamma_gradx) - -igammac_p = standard_naryop([_float, _float], 'igammac') -xla.register_translation(igammac_p, partial(_broadcast_translate, xops.Igammac)) - -def igammac_gradx(g, a, x): - return -igamma_gradx(g, a, x) - -def igammac_grada(g, a, x): - return -igamma_grada(g, a, x) - -ad.defjvp(igammac_p, igammac_grada, igammac_gradx) - -random_gamma_grad_p = 
standard_naryop([_float, _float], 'random_gamma_grad') -xla.register_translation(random_gamma_grad_p, - partial(_broadcast_translate, xops.RandomGammaGrad)) -bessel_i0e_p = standard_unop(_float, 'bessel_i0e') -xla.register_translation(bessel_i0e_p, standard_translate(bessel_i0e_p)) -ad.defjvp2(bessel_i0e_p, lambda g, y, x: g * (bessel_i1e(x) - sign(x) * y)) - -bessel_i1e_p = standard_unop(_float, 'bessel_i1e') -mlir.register_lowering(bessel_i1e_p, - partial(_nary_lower_hlo, chlo.BesselI1eOp)) - -def _bessel_i1e_jvp(g, y, x): - eps = dtypes.finfo(_dtype(x)).eps - x_is_not_tiny = abs(x) > eps - safe_x = select(x_is_not_tiny, x, full_like(x, eps)) - dy_dx = bessel_i0e(safe_x) - y * (sign(safe_x) + reciprocal(safe_x)) - dy_dx = select(x_is_not_tiny, dy_dx, full_like(x, 0.5)) - return g * dy_dx -ad.defjvp2(bessel_i1e_p, _bessel_i1e_jvp) - -erf_p = standard_unop(_float, 'erf') -ad.defjvp(erf_p, lambda g, x: mul(_const(x, 2. / np.sqrt(np.pi)), - mul(g, exp(neg(square(x)))))) -mlir.register_lowering(erf_p, partial(_nary_lower_hlo, chlo.ErfOp)) - -erfc_p = standard_unop(_float, 'erfc') -ad.defjvp(erfc_p, lambda g, x: mul(_const(x, -2. / np.sqrt(np.pi)), - mul(g, exp(neg(square(x)))))) -mlir.register_lowering(erfc_p, partial(_nary_lower_hlo, chlo.ErfcOp)) - -erf_inv_p = standard_unop(_float, 'erf_inv') -ad.defjvp2(erf_inv_p, lambda g, ans, x: mul(_const(x, np.sqrt(np.pi) / 2.), - mul(g, exp(square(ans))))) -mlir.register_lowering(erf_inv_p, partial(_nary_lower_hlo, chlo.ErfInvOp)) - real_p = unop(_complex_basetype, _complex, 'real') ad.deflinear2(real_p, lambda t, _: [complex(t, np.zeros((), _dtype(t)))]) mlir.register_lowering(real_p, partial(_nary_lower_hlo, hlo.RealOp)) diff --git a/jax/_src/lax/special.py b/jax/_src/lax/special.py new file mode 100644 index 000000000000..57b812cf0f4a --- /dev/null +++ b/jax/_src/lax/special.py @@ -0,0 +1,176 @@ +# Copyright 2023 The JAX Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +""" Special functions + +LAX decompositions for special functions into their StableHLO counterparts. +""" + +from functools import partial + +from jax._src.lax.lax import (exp, full_like, log, log1p, mul, neg, np, reciprocal, + select, sign, square, standard_naryop, standard_unop, + xla, xops, + _broadcast_translate, _const, _dtype, _float, + _nary_lower_hlo, _ones) +from jax._src.lax.utils import (standard_translate) + +from jax._src import dtypes +from jax._src.interpreters import ad +from jax._src.interpreters import mlir +from jax._src.lib.mlir.dialects import chlo +from jax._src.typing import Array, ArrayLike + +def betainc(a: ArrayLike, b: ArrayLike, x: ArrayLike) -> Array: + r"""Elementwise regularized incomplete beta integral.""" + return regularized_incomplete_beta_p.bind(a, b, x) + +def lgamma(x: ArrayLike) -> Array: + r"""Elementwise log gamma: :math:`\mathrm{log}(\Gamma(x))`.""" + return lgamma_p.bind(x) + +def digamma(x: ArrayLike) -> Array: + r"""Elementwise digamma: :math:`\psi(x)`.""" + return digamma_p.bind(x) + +def igamma(a: ArrayLike, x: ArrayLike) -> Array: + r"""Elementwise regularized incomplete gamma function.""" + return igamma_p.bind(a, x) + +def igammac(a: ArrayLike, x: ArrayLike) -> Array: + r"""Elementwise complementary regularized incomplete gamma function.""" + return igammac_p.bind(a, x) + +def igamma_grad_a(a: ArrayLike, x: ArrayLike) -> Array: + r"""Elementwise derivative of the 
regularized incomplete gamma function.""" + return igamma_grad_a_p.bind(a, x) + +def random_gamma_grad(a: ArrayLike, x: ArrayLike) -> Array: + r"""Elementwise derivative of samples from `Gamma(a, 1)`.""" + return random_gamma_grad_p.bind(a, x) + +def bessel_i0e(x: ArrayLike) -> Array: + r"""Exponentially scaled modified Bessel function of order 0: + :math:`\mathrm{i0e}(x) = e^{-|x|} \mathrm{i0}(x)` + """ + return bessel_i0e_p.bind(x) + +def bessel_i1e(x: ArrayLike) -> Array: + r"""Exponentially scaled modified Bessel function of order 1: + :math:`\mathrm{i1e}(x) = e^{-|x|} \mathrm{i1}(x)` + """ + return bessel_i1e_p.bind(x) + +def erf(x: ArrayLike) -> Array: + r"""Elementwise error function: :math:`\mathrm{erf}(x)`.""" + return erf_p.bind(x) + +def erfc(x: ArrayLike) -> Array: + r"""Elementwise complementary error function: + :math:`\mathrm{erfc}(x) = 1 - \mathrm{erf}(x)`.""" + return erfc_p.bind(x) + +def erf_inv(x: ArrayLike) -> Array: + r"""Elementwise inverse error function: :math:`\mathrm{erf}^{-1}(x)`.""" + return erf_inv_p.bind(x) + + +regularized_incomplete_beta_p = standard_naryop( + [_float, _float, _float], 'regularized_incomplete_beta') +xla.register_translation( + regularized_incomplete_beta_p, + partial(_broadcast_translate, xops.RegularizedIncompleteBeta)) + +def betainc_gradx(g, a, b, x): + lbeta = lgamma(a) + lgamma(b) - lgamma(a + b) + partial_x = exp((b - 1) * log1p(-x) + + (a - 1) * log(x) - lbeta) + return partial_x * g + +def betainc_grad_not_implemented(g, a, b, x): + raise ValueError("Betainc gradient with respect to a and b not supported.") + +ad.defjvp(regularized_incomplete_beta_p, + betainc_grad_not_implemented, + betainc_grad_not_implemented, + betainc_gradx) + +def igamma_gradx(g, a, x): + return g * exp(-x + (a - _ones(a)) * log(x) - lgamma(a)) + +def igamma_grada(g, a, x): + return g * igamma_grad_a(a, x) + +def igammac_gradx(g, a, x): + return -igamma_gradx(g, a, x) + +def igammac_grada(g, a, x): + return -igamma_grada(g, a, x) + 
+lgamma_p = standard_unop(_float, 'lgamma') +ad.defjvp(lgamma_p, lambda g, x: mul(g, digamma(x))) +mlir.register_lowering(lgamma_p, partial(_nary_lower_hlo, chlo.LgammaOp)) + +digamma_p = standard_unop(_float, 'digamma') +mlir.register_lowering(digamma_p, partial(_nary_lower_hlo, chlo.DigammaOp)) + +igamma_p = standard_naryop([_float, _float], 'igamma') +xla.register_translation(igamma_p, partial(_broadcast_translate, xops.Igamma)) + +igamma_grad_a_p = standard_naryop([_float, _float], 'igamma_grad_a') +xla.register_translation(igamma_grad_a_p, + partial(_broadcast_translate, xops.IgammaGradA)) + +ad.defjvp(igamma_p, igamma_grada, igamma_gradx) + +igammac_p = standard_naryop([_float, _float], 'igammac') +xla.register_translation(igammac_p, partial(_broadcast_translate, xops.Igammac)) + +ad.defjvp(igammac_p, igammac_grada, igammac_gradx) + +random_gamma_grad_p = standard_naryop([_float, _float], 'random_gamma_grad') +xla.register_translation(random_gamma_grad_p, + partial(_broadcast_translate, xops.RandomGammaGrad)) + +bessel_i0e_p = standard_unop(_float, 'bessel_i0e') +xla.register_translation(bessel_i0e_p, standard_translate(bessel_i0e_p)) +ad.defjvp2(bessel_i0e_p, lambda g, y, x: g * (bessel_i1e(x) - sign(x) * y)) + +bessel_i1e_p = standard_unop(_float, 'bessel_i1e') +mlir.register_lowering(bessel_i1e_p, + partial(_nary_lower_hlo, chlo.BesselI1eOp)) + +def _bessel_i1e_jvp(g, y, x): + eps = dtypes.finfo(_dtype(x)).eps + x_is_not_tiny = abs(x) > eps + safe_x = select(x_is_not_tiny, x, full_like(x, eps)) + dy_dx = bessel_i0e(safe_x) - y * (sign(safe_x) + reciprocal(safe_x)) + dy_dx = select(x_is_not_tiny, dy_dx, full_like(x, 0.5)) + return g * dy_dx +ad.defjvp2(bessel_i1e_p, _bessel_i1e_jvp) + +erf_p = standard_unop(_float, 'erf') +ad.defjvp(erf_p, lambda g, x: mul(_const(x, 2. 
/ np.sqrt(np.pi)), + mul(g, exp(neg(square(x)))))) +mlir.register_lowering(erf_p, partial(_nary_lower_hlo, chlo.ErfOp)) + +erfc_p = standard_unop(_float, 'erfc') +ad.defjvp(erfc_p, lambda g, x: mul(_const(x, -2. / np.sqrt(np.pi)), + mul(g, exp(neg(square(x)))))) +mlir.register_lowering(erfc_p, partial(_nary_lower_hlo, chlo.ErfcOp)) + +erf_inv_p = standard_unop(_float, 'erf_inv') +ad.defjvp2(erf_inv_p, lambda g, ans, x: mul(_const(x, np.sqrt(np.pi) / 2.), + mul(g, exp(square(ans))))) +mlir.register_lowering(erf_inv_p, partial(_nary_lower_hlo, chlo.ErfInvOp)) diff --git a/jax/lax/__init__.py b/jax/lax/__init__.py index 601e7d4029d9..82ac7f53c8c4 100644 --- a/jax/lax/__init__.py +++ b/jax/lax/__init__.py @@ -46,11 +46,6 @@ atanh as atanh, atanh_p as atanh_p, batch_matmul as batch_matmul, - bessel_i0e as bessel_i0e, - bessel_i0e_p as bessel_i0e_p, - bessel_i1e as bessel_i1e, - bessel_i1e_p as bessel_i1e_p, - betainc as betainc, bitcast_convert_type as bitcast_convert_type, bitcast_convert_type_p as bitcast_convert_type_p, bitwise_and as bitwise_and, @@ -87,8 +82,6 @@ cosh_p as cosh_p, create_token as create_token, create_token_p as create_token_p, - digamma as digamma, - digamma_p as digamma_p, div as div, div_p as div_p, dot as dot, @@ -98,12 +91,6 @@ dtypes as dtypes, eq as eq, eq_p as eq_p, - erf as erf, - erf_inv as erf_inv, - erf_inv_p as erf_inv_p, - erf_p as erf_p, - erfc as erfc, - erfc_p as erfc_p, exp as exp, exp_p as exp_p, expand_dims as expand_dims, @@ -117,12 +104,6 @@ ge_p as ge_p, gt as gt, gt_p as gt_p, - igamma as igamma, - igamma_grad_a as igamma_grad_a, - igamma_grad_a_p as igamma_grad_a_p, - igamma_p as igamma_p, - igammac as igammac, - igammac_p as igammac_p, imag as imag, imag_p as imag_p, infeed as infeed, @@ -136,8 +117,6 @@ itertools as itertools, le as le, le_p as le_p, - lgamma as lgamma, - lgamma_p as lgamma_p, log as log, log1p as log1p, log1p_p as log1p_p, @@ -171,8 +150,6 @@ population_count_p as population_count_p, pow as pow, pow_p as 
pow_p, - random_gamma_grad as random_gamma_grad, - random_gamma_grad_p as random_gamma_grad_p, real as real, real_p as real_p, reciprocal as reciprocal, @@ -187,7 +164,6 @@ reduce_prod_p as reduce_prod_p, reduce_sum_p as reduce_sum_p, reduce_xor_p as reduce_xor_p, - regularized_incomplete_beta_p as regularized_incomplete_beta_p, rem as rem, rem_p as rem_p, reshape as reshape, @@ -246,6 +222,32 @@ xor_p as xor_p, zeros_like_array as zeros_like_array, ) +from jax._src.lax.special import ( + bessel_i0e as bessel_i0e, + bessel_i0e_p as bessel_i0e_p, + bessel_i1e as bessel_i1e, + bessel_i1e_p as bessel_i1e_p, + betainc as betainc, + digamma as digamma, + digamma_p as digamma_p, + erf as erf, + erfc as erfc, + erfc_p as erfc_p, + erf_inv as erf_inv, + erf_inv_p as erf_inv_p, + erf_p as erf_p, + igamma as igamma, + igammac as igammac, + igammac_p as igammac_p, + igamma_grad_a as igamma_grad_a, + igamma_grad_a_p as igamma_grad_a_p, + igamma_p as igamma_p, + lgamma as lgamma, + lgamma_p as lgamma_p, + random_gamma_grad as random_gamma_grad, + random_gamma_grad_p as random_gamma_grad_p, + regularized_incomplete_beta_p as regularized_incomplete_beta_p, +) from jax._src.lax.slicing import ( GatherDimensionNumbers as GatherDimensionNumbers, GatherScatterMode as GatherScatterMode, From 071d9b946a4f7f5bebe9854e2e5b68419e2979e5 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Thu, 23 Mar 2023 14:39:40 -0700 Subject: [PATCH 45/65] improve scan error messages --- jax/_src/lax/control_flow/common.py | 1 - jax/_src/lax/control_flow/loops.py | 101 +++++++++++++++++++------- jax/_src/tree_util.py | 64 ++++++++++++++++- tests/lax_control_flow_test.py | 107 ++++++++++++++++++++++------ 4 files changed, 223 insertions(+), 50 deletions(-) diff --git a/jax/_src/lax/control_flow/common.py b/jax/_src/lax/control_flow/common.py index f0828a798c3c..f5d0c4490866 100644 --- a/jax/_src/lax/control_flow/common.py +++ b/jax/_src/lax/control_flow/common.py @@ -107,7 +107,6 @@ def 
_check_tree_and_avals(what, tree1, avals1, tree2, avals2): tree_unflatten(tree2, avals2)) raise TypeError(f"{what} must have identical types, got\n{diff}.") - def _check_tree(func_name, expected_name, actual_tree, expected_tree, has_aux=False): if has_aux: actual_tree_children = actual_tree.children() diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 79c92873bab9..8a9828c6ef90 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -13,9 +13,9 @@ # limitations under the License. """Module for the loop primitives.""" from functools import partial +import inspect import itertools import operator - from typing import Any, Callable, List, Optional, Sequence, Tuple, TypeVar import jax @@ -30,7 +30,8 @@ from jax.interpreters import partial_eval as pe from jax.interpreters import xla from jax.tree_util import (tree_flatten, tree_unflatten, treedef_is_leaf, - tree_map) + tree_map, tree_flatten_with_path, keystr) +from jax._src.tree_util import equality_errors from jax._src import ad_checkpoint from jax._src import ad_util from jax._src import api @@ -45,26 +46,13 @@ from jax._src.lib.mlir.dialects import hlo from jax._src.numpy.ufuncs import logaddexp from jax._src.traceback_util import api_boundary -from jax._src.util import ( - partition_list, - safe_map, - safe_zip, - split_list, - unzip2, - weakref_lru_cache, - ) +from jax._src.util import (partition_list, safe_map, safe_zip, split_list, + unzip2, weakref_lru_cache) import numpy as np from jax._src.lax.control_flow.common import ( - _abstractify, - _avals_short, - _check_tree_and_avals, - _initial_style_jaxpr, - _make_closed_jaxpr, - _prune_zeros, - _typecheck_param, - allowed_effects, - ) + _abstractify, _avals_short, _check_tree_and_avals, _initial_style_jaxpr, + _make_closed_jaxpr, _prune_zeros, _typecheck_param, allowed_effects) _map = safe_map zip = safe_zip @@ -260,14 +248,11 @@ def _create_jaxpr(init): init_flat, carry_avals, 
carry_avals_out, init_tree, *rest = _create_jaxpr(init) new_init_flat, changed = _promote_weak_typed_inputs(init_flat, carry_avals, carry_avals_out) if changed: - new_init = tree_unflatten(init_tree, new_init_flat) - init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(new_init) + init = tree_unflatten(init_tree, new_init_flat) + init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init) in_flat, jaxpr, consts, out_tree, out_tree_children = rest - _check_tree_and_avals("scan carry output and input", - # Extract the subtree and avals for the first element of the return tuple - out_tree_children[0], carry_avals_out, - init_tree, carry_avals) + _check_scan_carry_type(f, init, out_tree_children[0], carry_avals_out) disallowed_effects = allowed_effects.filter_not_in(jaxpr.effects) if disallowed_effects: raise NotImplementedError( @@ -280,6 +265,71 @@ def _create_jaxpr(init): unroll=unroll) return tree_unflatten(out_tree, out) +def _check_scan_carry_type(body_fun, in_carry, out_carry_tree, out_avals): + try: + sig = inspect.signature(body_fun) + except (ValueError, TypeError): + sig = None + carry_name = sig and list(sig.parameters)[0] + if carry_name: + component = lambda p: (f'the input carry component {carry_name}{keystr(p)}' + if p else f'the input carry {carry_name}') + else: + component = lambda p: (f'the input carry at path {keystr(p)}' + if p else 'the input carry') + leaves_and_paths, in_carry_tree = tree_flatten_with_path(in_carry) + paths, in_carry_flat = unzip2(leaves_and_paths) + in_avals = _map(_abstractify, in_carry_flat) + if in_carry_tree != out_carry_tree: + try: + out_carry = tree_unflatten(out_carry_tree, out_avals) + except: + out_carry = None + + if out_carry is None: + differences = [f'the input tree structure is:\n{in_carry_tree}\n', + f'the output tree structure is:\n{out_carry_tree}\n'] + else: + differences = '\n'.join( + f' * {component(path)} is a {thing1} but the corresponding component ' + f'of the 
carry output is a {thing2}, so {explanation}\n' + for path, thing1, thing2, explanation + in equality_errors(in_carry, out_carry)) + raise TypeError( + "Scanned function carry input and carry output must have the same " + "pytree structure, but they differ:\n" + f"{differences}\n" + "Revise the scanned function so that its output is a pair where the " + "first element has the same pytree structure as the first argument." + ) + if not all(_map(core.typematch, in_avals, out_avals)): + differences = '\n'.join( + f' * {component(path)} has type {in_aval.str_short()}' + ' but the corresponding output carry component has type ' + f'{out_aval.str_short()}{_aval_mismatch_extra(in_aval, out_aval)}\n' + for path, in_aval, out_aval in zip(paths, in_avals, out_avals) + if not core.typematch(in_aval, out_aval)) + raise TypeError( + "Scanned function carry input and carry output must have equal types " + "(e.g. shapes and dtypes of arrays), " + "but they differ:\n" + f"{differences}\n" + "Revise the scanned function so that all output types (e.g. shapes " + "and dtypes) match the corresponding input types." 
+ ) + +def _aval_mismatch_extra(a1: core.AbstractValue, a2: core.AbstractValue) -> str: + assert not core.typematch(a1, a2) + if isinstance(a1, core.ShapedArray) and isinstance(a2, core.ShapedArray): + dtype_mismatch = a1.dtype != a2.dtype + shape_mismatch = a1.shape != a2.shape + return (', so ' * (dtype_mismatch or shape_mismatch) + + 'the dtypes do not match' * dtype_mismatch + + ' and also ' * (dtype_mismatch and shape_mismatch) + + 'the shapes do not match' * shape_mismatch) + return '' + + def _scan_impl_unrolled(*args, reverse, length, num_consts, num_carry, linear, f_impl, x_avals, y_avals): consts, init, xs = split_list(args, [num_consts, num_carry]) @@ -1111,6 +1161,7 @@ def _create_jaxpr(init_val): body_nconsts=len(body_consts), body_jaxpr=body_jaxpr) return tree_unflatten(body_tree, outs) + def _join_while_effects(body_jaxpr, cond_jaxpr, body_nconsts, cond_nconsts ) -> effects.Effects: joined_effects = set() diff --git a/jax/_src/tree_util.py b/jax/_src/tree_util.py index 212f338491f6..1cc830036d1a 100644 --- a/jax/_src/tree_util.py +++ b/jax/_src/tree_util.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import collections from dataclasses import dataclass @@ -427,6 +428,67 @@ def prefix_errors(prefix_tree: Any, full_tree: Any, ) -> List[Callable[[str], ValueError]]: return list(_prefix_error((), prefix_tree, full_tree, is_leaf)) +def equality_errors( + tree1: Any, tree2: Any, is_leaf: Optional[Callable[[Any], bool]] = None, +) -> Iterable[Tuple[KeyPath, str, str, str]]: + yield from _equality_errors((), tree1, tree2, is_leaf) + +# TODO(mattjj): maybe share some logic with _prefix_error? +def _equality_errors(path, t1, t2, is_leaf): + # If both are leaves, this isn't a structure equality error. 
+ if (treedef_is_strict_leaf(tree_structure(t1, is_leaf=is_leaf)) and + treedef_is_strict_leaf(tree_structure(t2, is_leaf=is_leaf))): return + + # The trees may disagree because they are different types: + if type(t1) != type(t2): + yield path, str(type(t1)), str(type(t2)), 'their Python types differ' + return # no more errors to find + + # Or they may disagree because their roots have different numbers or keys of + # children (with special-case handling of list/tuple): + if isinstance(t1, (list, tuple)): + assert type(t1) == type(t2) + if len(t1) != len(t2): + yield (path, + f'{type(t1).__name__} of length {len(t1)}', + f'{type(t2).__name__} of length {len(t2)}', + 'the lengths do not match') + return # no more errors to find + t1_children, t1_meta = flatten_one_level(t1) + t2_children, t2_meta = flatten_one_level(t2) + t1_keys, t2_keys = _child_keys(t1), _child_keys(t2) + try: + diff = ' '.join(repr(k.key) for k in + set(t1_keys).symmetric_difference(set(t2_keys))) + except: + diff = '' + if len(t1_children) != len(t2_children): + yield (path, + f'{type(t1)} with {len(t1_children)} child' + f'{"ren" if len(t1_children) > 1 else ""}', + f'{type(t2)} with {len(t2_children)} child' + f'{"ren" if len(t2_children) > 1 else ""}', + 'the numbers of children do not match' + + (diff and f', with the symmetric difference of key sets: {{{diff}}}') + ) + return # no more errors to find + + # Or they may disagree if their roots have different pytree metadata: + if t1_meta != t2_meta: + yield (path, + f'{type(t1)} with pytree metadata {t1_meta}', + f'{type(t2)} with pytree metadata {t2_meta}', + 'the pytree node metadata does not match') + return # no more errors to find + + # If the root types and numbers of children agree, there must be a mismatch in + # a subtree, so recurse: + assert t1_keys == t2_keys, \ + f"equal pytree nodes gave different tree keys: {t1_keys} and {t2_keys}" + for k, c1, c2 in zip(t1_keys, t1_children, t2_children): + yield from _equality_errors((*path, 
k), c1, c2, is_leaf) + + # TODO(ivyzheng): Remove old APIs when all users migrated. class _DeprecatedKeyPathEntry(NamedTuple): @@ -800,7 +862,7 @@ def _prefix_error(key_path: KeyPath, prefix_tree: Any, full_tree: Any, ("equal pytree nodes gave differing prefix_tree_keys: " f"{prefix_tree_keys} and {full_tree_keys}") for k, t1, t2 in zip(prefix_tree_keys, prefix_tree_children, full_tree_children): - yield from _prefix_error(tuple((*key_path, k)), t1, t2) + yield from _prefix_error((*key_path, k), t1, t2) # TODO(jakevdp) remove these deprecated wrappers & their imports in jax/__init__.py diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index 21110613d4e7..0031476b3d5b 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -1762,31 +1762,92 @@ def plus_one(p, iter_idx): 'scan got value with no leading axis to scan over.*', lambda: lax.scan(plus_one, p0, list(range(5)))) - def testScanTypeErrors(self): - """Test typing error messages for scan.""" - a = jnp.arange(5) - # Body output not a tuple - with self.assertRaisesRegex(TypeError, + def testScanBodyOutputError(self): + with self.assertRaisesRegex( + TypeError, re.escape("scan body output must be a pair, got ShapedArray(float32[]).")): - lax.scan(lambda c, x: np.float32(0.), 0, a) - with self.assertRaisesRegex(TypeError, - re.escape("scan carry output and input must have same type structure, " - f"got {tree_util.tree_structure((0, 0, 0,))} " - f"and {tree_util.tree_structure((1, (2, 3)))}")): - lax.scan(lambda c, x: ((0, 0, 0), x), (1, (2, 3)), a) - with self.assertRaisesRegex(TypeError, - re.escape("scan carry output and input must have same type structure, " - f"got {tree_util.tree_structure(a)} and {tree_util.tree_structure(None)}.")): - lax.scan(lambda c, x: (0, x), None, a) - with self.assertRaisesWithLiteralMatch( + lax.scan(lambda c, x: np.float32(0.), 0, jnp.arange(5.)) + + def testScanBodyCarryPytreeMismatchErrors(self): + with self.assertRaisesRegex( 
TypeError, - "scan carry output and input must have identical types, got\n" - "DIFFERENT ShapedArray(int32[]) vs. ShapedArray(float32[])."): - lax.scan(lambda c, x: (np.int32(0), x), np.float32(1.0), a) - with self.assertRaisesRegex(TypeError, - re.escape("scan carry output and input must have same type structure, " - f"got {tree_util.tree_structure(a)} and {tree_util.tree_structure((1, 2))}.")): - lax.scan(lambda c, x: (0, x), (1, 2), a) + re.escape("Scanned function carry input and carry output must have " + "the same pytree structure, but they differ:\n" + " * the input carry c is a tuple of length 2")): + lax.scan(lambda c, x: ((0, 0, 0), x), (1, (2, 3)), jnp.arange(5.)) + + with self.assertRaisesRegex( + TypeError, + re.escape("Scanned function carry input and carry output must have the " + "same pytree structure, but they differ:\n" + " * the input carry x is a tuple of length 2")): + lax.scan(lambda x, _: ((x[0].astype('float32'),), None), + (jnp.array(0, 'int32'),) * 2, None, length=1) + + with self.assertRaisesRegex( + TypeError, + re.escape("Scanned function carry input and carry output must have the " + "same pytree structure, but they differ:\n" + " * the input carry x is a but the corres")): + jax.lax.scan(lambda x, _: ([x[0].astype('float32'),] * 2, None), + (jnp.array(0, 'int32'),) * 2, None, length=1) + + with self.assertRaisesRegex( + TypeError, + re.escape("Scanned function carry input and carry output must have the " + "same pytree structure, but they differ:\n" + " * the input carry x is a with 1 child but")): + jax.lax.scan(lambda x, _: ({'a': x['a'], 'b': x['a']}, None), + {'a': jnp.array(0, 'int32')}, None, length=1) + + with self.assertRaisesRegex( + TypeError, + re.escape("Scanned function carry input and carry output must have the " + "same pytree structure, but they differ:\n" + " * the input carry component x[0] is a with " + "1 child but the corresponding component of the carry " + "output is a with 2 children")): + jax.lax.scan(lambda 
x, _: (({'a': x[0]['a'], 'b': x[0]['a']},) * 2, None), + ({'a': jnp.array(0, 'int32')},) * 2, None, length=1) + + def testScanBodyCarryTypeMismatchErrors(self): + with self.assertRaisesRegex( + TypeError, + re.escape("Scanned function carry input and carry output must have equal " + "types (e.g. shapes and dtypes of arrays), but they differ:\n" + " * the input carry x has type int32[] but the corresponding " + "output carry component has type float32[], so the dtypes do " + "not match" + )): + jax.lax.scan(lambda x, _: (x.astype('float32'), None), + jnp.array(0, 'int32'), None, length=1) + + with self.assertRaisesRegex( + TypeError, + re.escape("Scanned function carry input and carry output must have equal " + "types (e.g. shapes and dtypes of arrays), but they differ:\n" + " * the input carry component x[1] has type int32[] but the " + "corresponding output carry component has type float32[], " + "so the dtypes do not match" + )): + jax.lax.scan(lambda x, _: ((x[0], x[1].astype('float32')), None), + (jnp.array(0, 'int32'),) * 2, None, length=1) + + with self.assertRaisesRegex( + TypeError, + re.escape("Scanned function carry input and carry output must have equal " + "types (e.g. 
shapes and dtypes of arrays), but they differ:\n" + " * the input carry component x[0] has type int32[] but the " + "corresponding output carry component has type float32[], " + "so the dtypes do not match\n\n" + " * the input carry component x[1] has type int32[] but the " + "corresponding output carry component has type float32[1,1], " + "so the dtypes do not match and also the shapes do not match" + )): + jax.lax.scan(lambda x, _: ((x[0].astype('float32'), + x[1].astype('float32').reshape(1, 1), + x[2]), None), + (jnp.array(0, 'int32'),) * 3, None, length=1) @parameterized.named_parameters( {"testcase_name": f"_{scan_name}", From 175cd37a929d47e5193b0b65633e0b44c30d97f0 Mon Sep 17 00:00:00 2001 From: George Necula Date: Thu, 23 Mar 2023 15:42:05 -0700 Subject: [PATCH 46/65] [jax2tf] Create a jax_export library with JAX-only pieces for native serialization This is a pure refactor, no functionality should change. PiperOrigin-RevId: 518982222 --- jax/experimental/jax2tf/BUILD | 16 +- jax/experimental/jax2tf/jax2tf.py | 245 +-------------- jax/experimental/jax2tf/jax_export.py | 278 ++++++++++++++++++ .../jax2tf/tests/back_compat_test.py | 6 +- jax/experimental/jax2tf/tests/jax2tf_test.py | 3 +- 5 files changed, 303 insertions(+), 245 deletions(-) create mode 100644 jax/experimental/jax2tf/jax_export.py diff --git a/jax/experimental/jax2tf/BUILD b/jax/experimental/jax2tf/BUILD index 40bce8bc304d..499a446c04ca 100644 --- a/jax/experimental/jax2tf/BUILD +++ b/jax/experimental/jax2tf/BUILD @@ -30,18 +30,32 @@ py_library( deps = [":jax2tf_internal"], ) +py_library( + name = "jax_export", + srcs = [ + "jax_export.py", + "shape_poly.py", + ], + srcs_version = "PY3", + # TODO: b/255503696: enable pytype + tags = ["pytype_unchecked_annotations"], + deps = [ + "//jax", + ] + py_deps("numpy"), +) + py_library( name = "jax2tf_internal", srcs = [ "call_tf.py", "impl_no_xla.py", "jax2tf.py", - "shape_poly.py", ], srcs_version = "PY3", # TODO: b/255503696: enable pytype tags = 
["pytype_unchecked_annotations"], deps = [ "//jax", + ":jax_export", ] + py_deps("numpy") + py_deps("tensorflow_core") + jax2tf_deps, ) diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 37a33b4000ad..4d5092478ca4 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -20,7 +20,7 @@ import re import threading from typing import ( - Any, Callable, Dict, Iterable, List, Optional, Sequence, Set, Tuple, Union, + Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union, cast) import warnings @@ -38,6 +38,7 @@ from jax.experimental import maps from jax.experimental.jax2tf import shape_poly from jax.experimental.jax2tf import impl_no_xla +from jax.experimental.jax2tf import jax_export from jax.interpreters import xla from jax._src import ad_checkpoint @@ -53,7 +54,6 @@ from jax._src import random as random_internal from jax._src import source_info_util from jax._src import util -from jax._src import xla_bridge as xb from jax._src.interpreters import ad from jax._src.interpreters import mlir from jax._src.interpreters import pxla @@ -63,7 +63,6 @@ from jax._src.lax import slicing as lax_slicing from jax._src.lax import windowed_reductions as lax_windowed_reductions from jax._src.lib import xla_client -from jax._src.lib.mlir.dialects import stablehlo from jax._src.numpy.ufuncs import logaddexp import tensorflow as tf # type: ignore[import] @@ -100,34 +99,6 @@ map = util.safe_map zip = util.safe_zip -# These are the JAX custom call target names that are guaranteed to be stable. -# They are tested by back_compat_test.py. 
-_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE = [ - "Sharding", "SPMDFullToShardShape", "SPMDShardToFullShape", - "ducc_fft", "cu_threefry2x32", - # eigh on CPU - "lapack_ssyevd", "lapack_dsyevd", "lapack_cheevd", "lapack_zheevd", - # eigh on GPU - "cusolver_syevj", "cusolver_syevd", - # eigh on TPU - "Eigh", - # qr on CPU - "lapack_sgeqrf", "lapack_dgeqrf", "lapack_cgeqrf", "lapack_zgeqrf", - "lapack_sorgqr", "lapack_dorgqr", "lapack_cungqr", "lapack_zungqr", - # qr on GPU - "cusolver_geqrf", "cublas_geqrf_batched", - "cusolver_geqrf", "cusolver_orgqr", - # qr and svd on TPU - "Qr", "ProductOfElementaryHouseholderReflectors", - # TODO(atondwal, necula): add back_compat tests for lu on CPU/GPU - # # lu on CPU - # "lapack_sgetrf" , "lapack_dgetrf" , "lapack_cgetrf" , "lapack_zgetrf", - # # lu on GPU - # "cublas_getrf_batched", "cusolver_getrf", - # "hipblas_getrf_batched", "hipsolver_getrf", - # lu on TPU - "LuDecomposition", -] def _sanitize_scope_name(name): scope_name = _INVALID_SCOPE_CHAR.sub("_", name) @@ -430,7 +401,7 @@ def converted_fun_tf(*args_tf: TfVal, **kwargs_tf: TfVal) -> TfVal: lowering_platform = native_serialization_platforms[0] else: lowering_platform = None - exported: Optional[Exported] = serialize_native( + exported: Optional[jax_export.Exported] = jax_export.serialize_native( fun_flat_jax, args_avals_flat, lowering_platform=lowering_platform, strict_checks=native_serialization_strict_checks) @@ -609,7 +580,7 @@ def _make_custom_gradient_fn_tf(*, native_serialization: Union[str, bool], native_serialization_platforms: Sequence[str], native_serialization_strict_checks: bool, - exported_primal: Optional["Exported"]): + exported_primal: Optional[jax_export.Exported]): """Prepares the TF function to be used with tf.custom_gradient. 
""" @@ -747,155 +718,10 @@ def _interpret_fun_jax( return util.unzip2(out_vals) -@dataclasses.dataclass -class Exported: - """Represents a lowered and serialized module.""" - in_avals: Sequence[core.ShapedArray] - out_avals: Sequence[core.ShapedArray] - # The in_shardings reflect only the module_ket_var_idx - in_shardings: Sequence[Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue]] - out_shardings: Sequence[Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue]] - lowering_platform: str # One of "tpu", "cpu", "cuda", "rocm" - - mlir_module: mlir.ir.Module - mlir_module_serialized: bytes # VHLO bytecode format - xla_call_module_version: int # Follows the versions of XlaCallModule - module_kept_var_idx: Sequence[int] # Specifies if an argument is kept in the - # lowering. As long as `out_avals`. - dim_args_spec: Sequence[str] - -def serialize_native(fun_jax: Callable, - args_avals: Sequence[core.ShapedArray], *, - lowering_platform: Optional[str], - strict_checks: bool) -> Exported: - arg_specs_jax = [ - jax.ShapeDtypeStruct(aval.shape, aval.dtype, named_shape=aval.named_shape) - for aval in args_avals - ] - - if not hasattr(fun_jax, "lower"): - # We support convert(pjit(f_jax)) and convert(jit(f_jax)) but also - # convert(f_jax), in which case a "jit" is implied. In that case we raise - # an error if the lowered function contains non-replicated sharding annotations. - fun_jax_lower = jax.jit(fun_jax).lower - allow_non_replicated_sharding = False - else: - # If we have a pjit or pmap already we do not wrap with another, and we - # allow shardings. 
- fun_jax_lower = fun_jax.lower - allow_non_replicated_sharding = True - - lowered = fun_jax_lower( - *arg_specs_jax, - _experimental_lowering_platform=lowering_platform)._lowering # type: ignore - - mlir_module = lowered.stablehlo() - - if xla_client.mlir_api_version >= 46: - xla_call_module_version = 4 - mlir_str = mlir.module_to_bytecode(mlir_module) - target_version = stablehlo.get_earliest_forward_compatible_version() - mlir_module_serialized = xla_client._xla.mlir.serialize_portable_artifact( - mlir_str, target_version) - else: - xla_call_module_version = 3 - mlir_module_serialized = mlir.module_to_bytecode(mlir_module) - - # Figure out the result types and shapes - if "global_out_avals" in lowered.compile_args: - # This is currently the case for pjit - out_avals = lowered.compile_args["global_out_avals"] - elif "shards" in lowered.compile_args: # for PmapComputation - out_avals = lowered.compile_args["shards"].out_sharded_avals - else: - out_avals = lowered.compile_args["out_avals"] - if lowered.compile_args["host_callbacks"]: - raise NotImplementedError("host_callbacks are not yet implemented for the jax2tf native lowering") - - if "kept_var_idx" in lowered.compile_args: - module_kept_var_idx = tuple(sorted(lowered.compile_args["kept_var_idx"])) - else: - # For pmap - module_kept_var_idx = tuple(range(len(args_avals))) - - # We must compute the dim_args_spec: for each dimension variable, encode how - # to compute its value from the shape of the explicit arguments. E.g., "2.1" - # denotes args_tf[2].shape[1]. The order of the dimension variables must match - # the order of the first N arguments of the lowered function. - # If we use --jax_dynamic_shapes, the dimension variables are listed in the - # order in which they are encountered by scanning the arguments and their - # shapes in order. Otherwise, the dimension variables are passed in the - # alphabetical order of their names. 
- dim_args_spec_dict: Dict[str, str] = {} # map dim var name to dim_args_spec - dim_vars_order: List[str] = [] - all_dim_vars: Set[str] = set() - current_kept_arg_idx = -1 # The index among the kept arguments - for arg_idx, aval in enumerate(args_avals): - is_kept = arg_idx in module_kept_var_idx - if is_kept: - current_kept_arg_idx += 1 - - for axis_idx, d in enumerate(aval.shape): - if not core.is_constant_dim(d): - # We collect dimension variables even from dropped args - all_dim_vars = all_dim_vars.union(d.get_vars()) - if not is_kept: continue - d_var = d.to_var() - # We can compute dim vars only from trivial polynomials - if d_var is None: continue - if not d_var in dim_args_spec_dict: - dim_vars_order.append(d_var) - dim_args_spec_dict[d_var] = f"{current_kept_arg_idx}.{axis_idx}" - - if all_dim_vars: - dim_args_spec_set = set(dim_vars_order) - if dim_args_spec_set != all_dim_vars: - missing = all_dim_vars.difference(dim_args_spec_set) - args_list = [f" Arg[{arg_idx}] - {'KEPT ' if arg_idx in module_kept_var_idx else 'DROPPED'}: {aval}" - for arg_idx, aval in enumerate(args_avals)] - raise ValueError( - "The following dimension variables cannot be computed from the static " - f"shapes of the kept lowered arguments: {missing}. These are the " - "argument shapes:\n" + - "\n".join(args_list) + - "\n" - "Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#dimension-variables-must-be-solvable-from-the-input-shapes for more details.") - - if config.jax_dynamic_shapes: - # In the order we have seen them - dim_args_spec = [dim_args_spec_dict[d_var] for d_var in dim_vars_order] - else: - # In sorted order by name - dim_args_spec = [dim_args_spec_dict[d_var] for d_var in sorted(dim_vars_order)] - else: - dim_args_spec = [] - - # Log and then check the module. 
- if logging.vlog_is_on(3): - mlir_module_text = mlir.module_to_string(mlir_module) - logmsg = f"version={xla_call_module_version} lowering_platform={lowering_platform}, dim_args_spec=" + ", ".join(dim_args_spec) - logging.vlog(3, "Lowered JAX module: %s\n%s", logmsg, mlir_module_text) - - check_module(mlir_module, - allow_non_replicated_sharding=allow_non_replicated_sharding, - allow_all_custom_calls=not strict_checks) - - return Exported( - in_avals=args_avals, - out_avals=out_avals, - in_shardings=lowered.compile_args["in_shardings"], - out_shardings=lowered.compile_args["out_shardings"], - lowering_platform=lowering_platform or default_jax_backend(), - mlir_module=mlir_module, - mlir_module_serialized=mlir_module_serialized, - module_kept_var_idx=module_kept_var_idx, - xla_call_module_version=xla_call_module_version, - dim_args_spec=dim_args_spec - ) def run_exported_as_tf(args_avals: Sequence[core.ShapedArray], args_tf: Sequence[TfVal], - exported: Exported, + exported: jax_export.Exported, native_serialization_strict_checks: bool, ) -> Tuple[Tuple[TfVal, ...], Tuple[core.ShapedArray, ...]]: """Runs the `exported` as an XlaCallModule TF op.""" @@ -962,64 +788,6 @@ def _convert_res(res_val, res_jax_type): return res, tuple(exported.out_avals) -def check_module(mod: mlir.ir.Module, *, - allow_non_replicated_sharding: bool, - allow_all_custom_calls: bool): - """Run a number of checks on the module. - - Args: - allow_non_replicated_sharding: whether the module is allowed to contain - non_replicated sharding annotations. - allow_all_custom_calls: whether we should allow all custom calls, or - only those who we have explicitly marked as stable. 
- """ - sharding_attr = mlir.ir.StringAttr.get("Sharding", mod.context) - allowed_custom_call_targets_attrs = [ - mlir.ir.StringAttr.get(target, mod.context) - for target in _CUSTOM_CALL_TARGETS_GUARANTEED_STABLE] - disallowed_custom_call_ops: List[str] = [] - def check_sharding(op_str: str, loc: mlir.ir.Location): - # Check the shardings in an operation or attribute (`op_str`) - if not allow_non_replicated_sharding: - m = re.search(r'mhlo.sharding\s*=\s*"([^"]+)"', op_str) - if m and m.group(1) not in ["{replicated}", ""]: - raise ValueError( - "Lowered function does not have a top-level pjit but it has " - f"non-replicated sharding annotations, e.g., {op_str} at {loc}.\n" - "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#support-for-partitioning for a discussion.") - - def check_op(op: mlir.ir.Operation): - op_name = op.operation.name - if op_name == "func.func": - for a in op.operation.attributes: - # TODO: figure out how to parse the attributes properly - check_sharding(str(a), op.location) - - elif op_name == "stablehlo.custom_call": - call_target_name_attr = op.operation.attributes["call_target_name"] - if (not allow_all_custom_calls and - call_target_name_attr not in allowed_custom_call_targets_attrs): - disallowed_custom_call_ops.append(str(op)) - if call_target_name_attr == sharding_attr: - check_sharding(str(op), op.location) - - def walk_operations(op): - check_op(op) - for region in op.operation.regions: - for block in region: - for op in block: - walk_operations(op) - - walk_operations(mod) - if disallowed_custom_call_ops: - disallowed_custom_call_ops_str = "\n".join(disallowed_custom_call_ops) - msg = ("Cannot serialize code with custom calls whose targets have no " - "compatibility guarantees. 
Examples are:\n" - f"{disallowed_custom_call_ops_str}.\n" - "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-lowering-supports-only-select-custom-calls") - raise ValueError(msg) - - def _call_wrapped_with_new_constant_cache(fun: lu.WrappedFun, in_vals: Sequence[TfVal], fresh_constant_cache: bool = False @@ -1228,9 +996,6 @@ def _tfval_to_tensor_jax_dtype(val: TfVal, _thread_local_state.constant_cache[const_key] = (val, tf_val) return tf_val, jax_dtype -def default_jax_backend() -> str: - # Canonicalize to turn into CUDA or ROCM - return xb.canonicalize_platform(jax.default_backend()) def _eval_shape(shape: Sequence[shape_poly.DimSize], dtype=None) -> Sequence[TfVal]: # Returns a tuple of shape_poly.dim_as_value_dtype diff --git a/jax/experimental/jax2tf/jax_export.py b/jax/experimental/jax2tf/jax_export.py new file mode 100644 index 000000000000..89afe806694b --- /dev/null +++ b/jax/experimental/jax2tf/jax_export.py @@ -0,0 +1,278 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""JAX APIs for exporting code for interoperation. + +This module is used with jax2tf, but should have no TensorFlow dependencies. 
+""" +import dataclasses +import re +from typing import Callable, Dict, List, Optional, Sequence, Set, Union + +from absl import logging + +import jax +from jax import config +from jax import sharding + +from jax._src import core +from jax._src import util +from jax._src import xla_bridge as xb +from jax._src.interpreters import mlir +from jax._src.interpreters import pxla +from jax._src.lib import xla_client +from jax._src.lib.mlir.dialects import stablehlo + + +map = util.safe_map +zip = util.safe_zip + +# These are the JAX custom call target names that are guaranteed to be stable. +# They are tested by back_compat_test.py. +_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE = [ + "Sharding", "SPMDFullToShardShape", "SPMDShardToFullShape", + "ducc_fft", "cu_threefry2x32", + # eigh on CPU + "lapack_ssyevd", "lapack_dsyevd", "lapack_cheevd", "lapack_zheevd", + # eigh on GPU + "cusolver_syevj", "cusolver_syevd", + # eigh on TPU + "Eigh", + # qr on CPU + "lapack_sgeqrf", "lapack_dgeqrf", "lapack_cgeqrf", "lapack_zgeqrf", + "lapack_sorgqr", "lapack_dorgqr", "lapack_cungqr", "lapack_zungqr", + # qr on GPU + "cusolver_geqrf", "cublas_geqrf_batched", + "cusolver_geqrf", "cusolver_orgqr", + # qr and svd on TPU + "Qr", "ProductOfElementaryHouseholderReflectors", + # TODO(atondwal, necula): add back_compat tests for lu on CPU/GPU + # # lu on CPU + # "lapack_sgetrf" , "lapack_dgetrf" , "lapack_cgetrf" , "lapack_zgetrf", + # # lu on GPU + # "cublas_getrf_batched", "cusolver_getrf", + # "hipblas_getrf_batched", "hipsolver_getrf", + # lu on TPU + "LuDecomposition", +] + + +@dataclasses.dataclass +class Exported: + """Represents a lowered and serialized JAX module.""" + in_avals: Sequence[core.ShapedArray] + out_avals: Sequence[core.ShapedArray] + # The in_shardings reflect only the module_kept_var_idx + in_shardings: Sequence[Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue]] + out_shardings: Sequence[Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue]] + 
lowering_platform: str # One of "tpu", "cpu", "cuda", "rocm" + + mlir_module: mlir.ir.Module + mlir_module_serialized: bytes # VHLO bytecode format + xla_call_module_version: int # Follows the versions of XlaCallModule + module_kept_var_idx: Sequence[int] # Specifies if an argument is kept in the + # lowering. As long as `out_avals`. + dim_args_spec: Sequence[str] + + +def default_jax_backend() -> str: + # Canonicalize to turn into CUDA or ROCM + return xb.canonicalize_platform(jax.default_backend()) + + +def serialize_native(fun_jax: Callable, + args_avals: Sequence[core.ShapedArray], *, + lowering_platform: Optional[str], + strict_checks: bool) -> Exported: + arg_specs_jax = [ + jax.ShapeDtypeStruct(aval.shape, aval.dtype, named_shape=aval.named_shape) + for aval in args_avals + ] + + if not hasattr(fun_jax, "lower"): + # We support convert(pjit(f_jax)) and convert(jit(f_jax)) but also + # convert(f_jax), in which case a "jit" is implied. In that case we raise + # an error if the lowered function contains non-replicated sharding annotations. + fun_jax_lower = jax.jit(fun_jax).lower + allow_non_replicated_sharding = False + else: + # If we have a pjit or pmap already we do not wrap with another, and we + # allow shardings. 
+ fun_jax_lower = fun_jax.lower + allow_non_replicated_sharding = True + + lowered = fun_jax_lower( + *arg_specs_jax, + _experimental_lowering_platform=lowering_platform)._lowering # type: ignore + + mlir_module = lowered.stablehlo() + + if xla_client.mlir_api_version >= 46: + xla_call_module_version = 4 + mlir_str = mlir.module_to_bytecode(mlir_module) + target_version = stablehlo.get_earliest_forward_compatible_version() + mlir_module_serialized = xla_client._xla.mlir.serialize_portable_artifact( + mlir_str, target_version) + else: + xla_call_module_version = 3 + mlir_module_serialized = mlir.module_to_bytecode(mlir_module) + + # Figure out the result types and shapes + if "global_out_avals" in lowered.compile_args: + # This is currently the case for pjit + out_avals = lowered.compile_args["global_out_avals"] + elif "shards" in lowered.compile_args: # for PmapComputation + out_avals = lowered.compile_args["shards"].out_sharded_avals + else: + out_avals = lowered.compile_args["out_avals"] + if lowered.compile_args["host_callbacks"]: + raise NotImplementedError("host_callbacks are not yet implemented for the jax2tf native lowering") + + if "kept_var_idx" in lowered.compile_args: + module_kept_var_idx = tuple(sorted(lowered.compile_args["kept_var_idx"])) + else: + # For pmap + module_kept_var_idx = tuple(range(len(args_avals))) + + # We must compute the dim_args_spec: for each dimension variable, encode how + # to compute its value from the shape of the explicit arguments. E.g., "2.1" + # denotes args_tf[2].shape[1]. The order of the dimension variables must match + # the order of the first N arguments of the lowered function. + # If we use --jax_dynamic_shapes, the dimension variables are listed in the + # order in which they are encountered by scanning the arguments and their + # shapes in order. Otherwise, the dimension variables are passed in the + # alphabetical order of their names. 
+ dim_args_spec_dict: Dict[str, str] = {} # map dim var name to dim_args_spec + dim_vars_order: List[str] = [] + all_dim_vars: Set[str] = set() + current_kept_arg_idx = -1 # The index among the kept arguments + for arg_idx, aval in enumerate(args_avals): + is_kept = arg_idx in module_kept_var_idx + if is_kept: + current_kept_arg_idx += 1 + + for axis_idx, d in enumerate(aval.shape): + if not core.is_constant_dim(d): + # We collect dimension variables even from dropped args + all_dim_vars = all_dim_vars.union(d.get_vars()) + if not is_kept: continue + d_var = d.to_var() + # We can compute dim vars only from trivial polynomials + if d_var is None: continue + if d_var not in dim_args_spec_dict: + dim_vars_order.append(d_var) + dim_args_spec_dict[d_var] = f"{current_kept_arg_idx}.{axis_idx}" + + if all_dim_vars: + dim_args_spec_set = set(dim_vars_order) + if dim_args_spec_set != all_dim_vars: + missing = all_dim_vars.difference(dim_args_spec_set) + args_list = [f" Arg[{arg_idx}] - {'KEPT ' if arg_idx in module_kept_var_idx else 'DROPPED'}: {aval}" + for arg_idx, aval in enumerate(args_avals)] + raise ValueError( + "The following dimension variables cannot be computed from the static " + f"shapes of the kept lowered arguments: {missing}. These are the " + "argument shapes:\n" + + "\n".join(args_list) + + "\n" + "Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#dimension-variables-must-be-solvable-from-the-input-shapes for more details.") + + if config.jax_dynamic_shapes: + # In the order we have seen them + dim_args_spec = [dim_args_spec_dict[d_var] for d_var in dim_vars_order] + else: + # In sorted order by name + dim_args_spec = [dim_args_spec_dict[d_var] for d_var in sorted(dim_vars_order)] + else: + dim_args_spec = [] + + # Log and then check the module. 
+ if logging.vlog_is_on(3): + mlir_module_text = mlir.module_to_string(mlir_module) + logmsg = f"version={xla_call_module_version} lowering_platform={lowering_platform}, dim_args_spec=" + ", ".join(dim_args_spec) + logging.vlog(3, "Lowered JAX module: %s\n%s", logmsg, mlir_module_text) + + check_module(mlir_module, + allow_non_replicated_sharding=allow_non_replicated_sharding, + allow_all_custom_calls=not strict_checks) + + return Exported( + in_avals=args_avals, + out_avals=out_avals, + in_shardings=lowered.compile_args["in_shardings"], + out_shardings=lowered.compile_args["out_shardings"], + lowering_platform=lowering_platform or default_jax_backend(), + mlir_module=mlir_module, + mlir_module_serialized=mlir_module_serialized, + module_kept_var_idx=module_kept_var_idx, + xla_call_module_version=xla_call_module_version, + dim_args_spec=dim_args_spec + ) + + +def check_module(mod: mlir.ir.Module, *, + allow_non_replicated_sharding: bool, + allow_all_custom_calls: bool): + """Run a number of checks on the module. + + Args: + allow_non_replicated_sharding: whether the module is allowed to contain + non_replicated sharding annotations. + allow_all_custom_calls: whether we should allow all custom calls, or + only those who we have explicitly marked as stable. 
+ """ + sharding_attr = mlir.ir.StringAttr.get("Sharding", mod.context) + allowed_custom_call_targets_attrs = [ + mlir.ir.StringAttr.get(target, mod.context) + for target in _CUSTOM_CALL_TARGETS_GUARANTEED_STABLE] + disallowed_custom_call_ops: List[str] = [] + def check_sharding(op_str: str, loc: mlir.ir.Location): + # Check the shardings in an operation or attribute (`op_str`) + if not allow_non_replicated_sharding: + m = re.search(r'mhlo.sharding\s*=\s*"([^"]+)"', op_str) + if m and m.group(1) not in ["{replicated}", ""]: + raise ValueError( + "Lowered function does not have a top-level pjit but it has " + f"non-replicated sharding annotations, e.g., {op_str} at {loc}.\n" + "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#support-for-partitioning for a discussion.") + + def check_op(op: mlir.ir.Operation): + op_name = op.operation.name + if op_name == "func.func": + for a in op.operation.attributes: + # TODO: figure out how to parse the attributes properly + check_sharding(str(a), op.location) + + elif op_name == "stablehlo.custom_call": + call_target_name_attr = op.operation.attributes["call_target_name"] + if (not allow_all_custom_calls and + call_target_name_attr not in allowed_custom_call_targets_attrs): + disallowed_custom_call_ops.append(str(op)) + if call_target_name_attr == sharding_attr: + check_sharding(str(op), op.location) + + def walk_operations(op): + check_op(op) + for region in op.operation.regions: + for block in region: + for op in block: + walk_operations(op) + + walk_operations(mod) + if disallowed_custom_call_ops: + disallowed_custom_call_ops_str = "\n".join(disallowed_custom_call_ops) + msg = ("Cannot serialize code with custom calls whose targets have no " + "compatibility guarantees. 
Examples are:\n" + f"{disallowed_custom_call_ops_str}.\n" + "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-lowering-supports-only-select-custom-calls") + raise ValueError(msg) diff --git a/jax/experimental/jax2tf/tests/back_compat_test.py b/jax/experimental/jax2tf/tests/back_compat_test.py index 84f69ea376c0..34a95b6de355 100644 --- a/jax/experimental/jax2tf/tests/back_compat_test.py +++ b/jax/experimental/jax2tf/tests/back_compat_test.py @@ -21,7 +21,7 @@ There is one test for each version of a custom call target, e.g., `test_ducc_fft` tests the FFT custom calls on CPU. Only custom call targets tested here should be listed in -jax2tf._CUSTOM_CALL_TARGETS_GUARANTEED_STABLE. All other custom +jax_export._CUSTOM_CALL_TARGETS_GUARANTEED_STABLE. All other custom call targets will result in an error when encountered during serialization. Once we stop using a custom call target in JAX, you can remove it from the @@ -203,7 +203,7 @@ def run_one_test(self, func: Callable[..., jax.Array], res_from_jax = tuple(np.array(a) for a in res_from_jax) # Use the native exporter, to make sure we get the proper serialized module. 
- exported = jax2tf.jax2tf.serialize_native( + exported = jax2tf.jax_export.serialize_native( jax.jit(func), [core.ShapedArray(a.shape, a.dtype) for a in data.inputs], lowering_platform=default_jax_backend(), @@ -309,7 +309,7 @@ def test_detect_different_custom_calls(self): self.run_one_test(jnp.sin, platform_dummy_data) def test_custom_call_coverage(self): - targets_to_cover = set(jax2tf.jax2tf._CUSTOM_CALL_TARGETS_GUARANTEED_STABLE) + targets_to_cover = set(jax2tf.jax_export._CUSTOM_CALL_TARGETS_GUARANTEED_STABLE) # Add here all the testdatas that should cover the targets guaranteed # stable covering_testdatas = [ diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py index 2df6d4a61098..ac8dd1667d6b 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_test.py +++ b/jax/experimental/jax2tf/tests/jax2tf_test.py @@ -36,6 +36,7 @@ import jax._src.xla_bridge from jax.config import config from jax.experimental import jax2tf +from jax.experimental.jax2tf import jax_export from jax.experimental.jax2tf.tests import tf_test_util from jax.experimental.maps import xmap from jax.experimental.shard_map import shard_map @@ -1493,7 +1494,7 @@ def apply_transform(func, transform: str): stack.enter_context(mesh) # Run the JAX native version, to check it works, and to fill caches. 
_ = func_to_convert(*args) - exported = jax2tf.jax2tf.serialize_native( + exported = jax_export.serialize_native( func_to_convert, [core.ShapedArray(a.shape, a.dtype) for a in args], lowering_platform='tpu', From 4994472e143ba06105d0ed9258832f71ad23f0e7 Mon Sep 17 00:00:00 2001 From: John QiangZhang Date: Thu, 23 Mar 2023 15:49:44 -0700 Subject: [PATCH 47/65] Add padding option "SAME_LOWER" for ticket https://github.com/google/jax/pull/14990 PiperOrigin-RevId: 518984018 --- jax/_src/lax/convolution.py | 37 +++++++++++++++++++++---------------- jax/_src/lax/lax.py | 25 +++++++++++++++++++++---- jax/_src/lax_reference.py | 11 +++++++++-- tests/lax_test.py | 12 +++++++----- 4 files changed, 58 insertions(+), 27 deletions(-) diff --git a/jax/_src/lax/convolution.py b/jax/_src/lax/convolution.py index 5259135726fa..985edb8fa40e 100644 --- a/jax/_src/lax/convolution.py +++ b/jax/_src/lax/convolution.py @@ -74,25 +74,30 @@ def conv_general_dilated( rhs: a rank `n+2` dimensional array of kernel weights. window_strides: a sequence of `n` integers, representing the inter-window strides. - padding: either the string `'SAME'`, the string `'VALID'`, or a sequence of - `n` `(low, high)` integer pairs that give the padding to apply before and - after each spatial dimension. - lhs_dilation: `None`, or a sequence of `n` integers, giving the - dilation factor to apply in each spatial dimension of `lhs`. LHS dilation - is also known as transposed convolution. - rhs_dilation: `None`, or a sequence of `n` integers, giving the - dilation factor to apply in each spatial dimension of `rhs`. RHS dilation - is also known as atrous convolution. - dimension_numbers: either `None`, a ``ConvDimensionNumbers`` object, or - a 3-tuple ``(lhs_spec, rhs_spec, out_spec)``, where each element is a - string of length `n+2`. 
+ padding: either the strings `'SAME'`, `'SAME_LOWER'`, or `'VALID'`, or a + sequence of `n` `(low, high)` integer pairs that give the padding to apply + before and after each spatial dimension. `'SAME'` and `'SAME_LOWER'` add + padding to produce same output size as the input. The padding is split + between the two sides equally or almost equally. In case the padding is an + odd number, the extra padding is added at the end for `'SAME'` and at the + beginning for `'SAME_LOWER'`. + lhs_dilation: `None`, or a sequence of `n` integers, giving the dilation + factor to apply in each spatial dimension of `lhs`. LHS dilation is also + known as transposed convolution. + rhs_dilation: `None`, or a sequence of `n` integers, giving the dilation + factor to apply in each spatial dimension of `rhs`. RHS dilation is also + known as atrous convolution. + dimension_numbers: either `None`, a ``ConvDimensionNumbers`` object, or a + 3-tuple ``(lhs_spec, rhs_spec, out_spec)``, where each element is a string + of length `n+2`. feature_group_count: integer, default 1. See XLA HLO docs. batch_group_count: integer, default 1. See XLA HLO docs. precision: Optional. Either ``None``, which means the default precision for - the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``, - ``Precision.HIGH`` or ``Precision.HIGHEST``), a string (e.g. 'highest' or - 'fastest', see the ``jax.default_matmul_precision`` context manager), or a - tuple of two :class:`~jax.lax.Precision` enums or strings indicating precision of + the backend, a :class:`~jax.lax.Precision` enum value + (``Precision.DEFAULT``, ``Precision.HIGH`` or ``Precision.HIGHEST``), a + string (e.g. 'highest' or 'fastest', see the + ``jax.default_matmul_precision`` context manager), or a tuple of two + :class:`~jax.lax.Precision` enums or strings indicating precision of ``lhs`` and ``rhs``. preferred_element_type: Optional. 
Either ``None``, which means the default accumulation type for the input types, or a datatype, indicating to diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index ebb2caf6680f..73414e5b8479 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -4442,23 +4442,40 @@ def _dilate_shape(shape, dilation): def _ceil_divide(x1, x2): return -np.floor_divide(np.negative(x1), x2) + +class PaddingType(enum.Enum): + VALID = 1 + SAME = 2 + SAME_LOWER = 3 + + def padtype_to_pads(in_shape, window_shape, window_strides, padding): """Convert padding string to list of pairs of pad values.""" - PaddingType = xla_client.PaddingType if isinstance(padding, str): - mapping = {'VALID': PaddingType.VALID, 'SAME': PaddingType.SAME} + mapping = { + 'VALID': PaddingType.VALID, + 'SAME': PaddingType.SAME, + 'SAME_LOWER': PaddingType.SAME_LOWER, + } try: padding = mapping[padding.upper()] except KeyError as err: msg = "Unrecognized padding type: expected 'VALID' or 'SAME', got {}." raise RuntimeError(msg.format(padding)) from err - if padding == PaddingType.SAME: + if padding == PaddingType.SAME or padding == PaddingType.SAME_LOWER: out_shape = _ceil_divide(in_shape, window_strides) pad_sizes = np.maximum(0, (out_shape - 1) * window_strides + window_shape - in_shape) - return [(pad_size // 2, pad_size - pad_size // 2) for pad_size in pad_sizes] + if padding == PaddingType.SAME: + return [ + (pad_size // 2, pad_size - pad_size // 2) for pad_size in pad_sizes + ] + else: + return [ + (pad_size - pad_size // 2, pad_size // 2) for pad_size in pad_sizes + ] elif padding == PaddingType.VALID: return [(0, 0)] * len(in_shape) else: diff --git a/jax/_src/lax_reference.py b/jax/_src/lax_reference.py index 27488194d98a..96d3b3dc2564 100644 --- a/jax/_src/lax_reference.py +++ b/jax/_src/lax_reference.py @@ -339,12 +339,19 @@ def _conv(lhs, rhs, window_strides, pads): view, view_axes, rhs, rhs_axes, out_axes, use_blas=True) def padtype_to_pads(in_shape, filter_shape, window_strides, padding): 
- if padding.upper() == 'SAME': + if padding.upper() == 'SAME' or padding.upper() == 'SAME_LOWER': out_shape = np.ceil(np.true_divide(in_shape, window_strides)).astype(int) pad_sizes = [_max((out_size - 1) * stride + filter_size - in_size, 0) for out_size, stride, filter_size, in_size in zip(out_shape, window_strides, filter_shape, in_shape)] - return [(pad_size // 2, pad_size - pad_size // 2) for pad_size in pad_sizes] + if padding.upper() == 'SAME': + return [ + (pad_size // 2, pad_size - pad_size // 2) for pad_size in pad_sizes + ] + else: + return [ + (pad_size - pad_size // 2, pad_size // 2) for pad_size in pad_sizes + ] else: return [(0, 0)] * len(in_shape) diff --git a/tests/lax_test.py b/tests/lax_test.py index 299a82145c1a..e8abb8da7dfa 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -253,11 +253,13 @@ def testConcatenateAgainstNumpy(self, dim, base_shape, dtype, num_arrs): self._CheckAgainstNumpy(numpy_op, op, args_maker) @jtu.sample_product( - [dict(lhs_shape=(b, i, 9, 10), rhs_shape=(j, i, 4, 5)) - for b, i, j in itertools.product([2, 3], repeat=3)], - dtype=lax_test_util.float_dtypes, - strides=[(1, 1), (1, 2), (2, 1)], - padding=["VALID", "SAME"], + [ + dict(lhs_shape=(b, i, 9, 10), rhs_shape=(j, i, 4, 5)) + for b, i, j in itertools.product([2, 3], repeat=3) + ], + dtype=lax_test_util.float_dtypes, + strides=[(1, 1), (1, 2), (2, 1)], + padding=["VALID", "SAME", "SAME_LOWER"], ) def testConv(self, lhs_shape, rhs_shape, dtype, strides, padding): rng = jtu.rand_small(self.rng()) From 32b8c42e44055c8ea3cb777b198317184ddd1917 Mon Sep 17 00:00:00 2001 From: George Necula Date: Thu, 23 Mar 2023 16:04:23 -0700 Subject: [PATCH 48/65] [jax2tf] A simple failing test on TPU with native serialization PiperOrigin-RevId: 518987577 --- jax/experimental/jax2tf/tests/jax2tf_test.py | 24 +++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py 
index ac8dd1667d6b..0884d35bb56d 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_test.py +++ b/jax/experimental/jax2tf/tests/jax2tf_test.py @@ -59,6 +59,28 @@ def test_empty(self): f_jax = lambda x, y: x self.ConvertAndCompare(f_jax, 0.7, 1) + def test_sin(self): + f_tf = jax2tf.convert(jnp.sin) + x = np.float32(.5) + sin_x = np.sin(x) + self.assertAllClose(sin_x, f_tf(x)) + self.assertAllClose(sin_x, tf.function(f_tf, autograph=False, + jit_compile=True)(x)) + # TODO: The following, with jit_compile=False, fails with + # native serialization because the tf.function() somehow executes the + # XlaCallModule op on CPU. This is despite the `with tf.device()` + # tf_preferred_device = ( + # tf.config.list_logical_devices("TPU") + + # tf.config.list_logical_devices("GPU") + + # tf.config.list_logical_devices())[0] + # logging.info("Running TF on %s", tf_preferred_device) + # with tf.device(tf_preferred_device): + # self.assertAllClose(sin_x, tf.function(f_tf, autograph=False, + # jit_compile=False)(x)) + + # self.assertAllClose(sin_x, tf.function(f_tf, autograph=False, + # jit_compile=False)(x)) + def test_basics(self): f_jax = lambda x: jnp.sin(jnp.cos(x)) self.ConvertAndCompare(f_jax, 0.7) @@ -1564,7 +1586,7 @@ def get_serialized_computation( else: lowered = jax.jit(f_jax, abstracted_axes=abstracted_axes).lower(*args) stablehlo_module_text = mlir.module_to_string(lowered._lowering.stablehlo()) - logging.info(f'Serialized ir.Module = {stablehlo_module_text}') + logging.info("Serialized ir.Module = %s", stablehlo_module_text) return stablehlo_module_text, 3 From 195f84714ae568f12536d8e106952135d808c396 Mon Sep 17 00:00:00 2001 From: Anish Tondwalkar Date: Thu, 23 Mar 2023 16:26:26 -0700 Subject: [PATCH 49/65] Migrate igamma_p off xla_fallback We decompose it into a series or a call to igammac. 
PiperOrigin-RevId: 518993077 --- jax/_src/lax/special.py | 131 ++++++++++++++++++++++++++++-- tests/filecheck/math.filecheck.py | 5 -- 2 files changed, 126 insertions(+), 10 deletions(-) diff --git a/jax/_src/lax/special.py b/jax/_src/lax/special.py index 57b812cf0f4a..804adaf7703e 100644 --- a/jax/_src/lax/special.py +++ b/jax/_src/lax/special.py @@ -17,13 +17,19 @@ LAX decompositions for special functions into their StableHLO counterparts. """ +from enum import Enum +import numpy as np from functools import partial -from jax._src.lax.lax import (exp, full_like, log, log1p, mul, neg, np, reciprocal, - select, sign, square, standard_naryop, standard_unop, - xla, xops, +from jax._src.lax.lax import (bitwise_and, bitwise_not, bitwise_or, + broadcast_in_dim, broadcast_shapes, + convert_element_type, eq, exp, full_like, + gt, le, log, log1p, lt, mul, neg, reciprocal, + reduce, select, sign, square, standard_naryop, + standard_unop, xla, xops, _broadcast_translate, _const, _dtype, _float, - _nary_lower_hlo, _ones) + _nary_lower_hlo, _ones, _isnan, _reduce) +from jax._src.lax.control_flow import while_loop from jax._src.lax.utils import (standard_translate) from jax._src import dtypes @@ -118,6 +124,120 @@ def igammac_gradx(g, a, x): def igammac_grada(g, a, x): return -igamma_grada(g, a, x) +# The below is directly ported from tensorflow/compiler/xla/client/lib/math.cc +# We try to follow the corresponding functions as closely as possible, so that +# we can quickly incorporate changes. 
+class IgammaMode(Enum): + VALUE = 1 + DERIVATIVE = 2 + SAMPLE_DERIVATIVE = 3 + +def _any(predicates: Array) -> Array: + f = _const(predicates, False) + predicates_shape = predicates.shape + all_dimensions = tuple(range(len(predicates_shape))) + return reduce(predicates, f, bitwise_or, all_dimensions) + +def _igamma_series(ax, x, a, enabled, dtype, mode): + def cond_fn(vals): + return _any(vals[0]) + + def body_fn(vals): + enabled, r, c, ans, x, dc_da, dans_da = vals + + r = r + _const(r, 1.) + dc_da = dc_da * (x / r) - (c * x) / (r * r) + dans_da = dans_da + dc_da + c = c * (x / r) + ans = ans + c + + if mode == IgammaMode.VALUE: + conditional = bitwise_and(enabled, c / ans > dtypes.finfo(dtype).eps) + else: + conditional = bitwise_and(enabled, + abs(dc_da / dans_da) > dtypes.finfo(dtype).eps) + + # TODO: Make this a vmap. Might be tricky with the imports. + return ( + conditional, + select(enabled, r, vals[1]), + select(enabled, c, vals[2]), + select(enabled, ans, vals[3]), + select(enabled, x, vals[4]), + select(enabled, dc_da, vals[5]), + select(enabled, dans_da, vals[6]), + ) + + init_vals = ( + enabled, a, full_like(a, 1), full_like(a, 1), x, full_like(a, 0), + full_like(a, 0), + ) + + vals = while_loop(cond_fn, body_fn, init_vals) + ans = vals[3] + dans_da = vals[6] + + if mode == IgammaMode.VALUE: + return (ans * ax) / a + + dlogax_da = log(x) - digamma(a + _const(a, 1)) + + if mode == IgammaMode.DERIVATIVE: + return ax * (ans * dlogax_da + dans_da) / a + elif mode == IgammaMode.SAMPLE_DERIVATIVE: + return -(dans_da + ans * dlogax_da) * x / a + else: + raise ValueError("Invalid IgammaMode") + +def igamma_impl(a, x): + broadcasted_shape = broadcast_shapes(a.shape, x.shape) + a = broadcast_in_dim(a, broadcasted_shape, list(range(a.ndim))) + x = broadcast_in_dim(x, broadcasted_shape, list(range(x.ndim))) + + def doit(a, x, dtype): + is_nan = bitwise_or(_isnan(a), _isnan(x)) + x_is_zero = eq(x, _const(x, 0)) + x_is_infinity = eq(x, _const(x, float('inf'))) + 
domain_error = bitwise_or(lt(x, _const(x, 0)), le(a, _const(a, 0))) + use_igammac = bitwise_and(gt(x, _const(x, 1)), gt(x, a)) + ax = a * log(x) - x - lgamma(a) + underflow = lt(ax, -log(dtypes.finfo(dtype).max)) + ax = exp(ax) + enabled = bitwise_not( + _reduce(bitwise_or,[x_is_zero, domain_error, underflow, is_nan])) + + output = select( + use_igammac, + _const(a, 1) - + _igammac_continued_fraction(ax, x, a, bitwise_and(enabled, use_igammac), + dtype, IgammaMode.VALUE), + _igamma_series(ax, x, a, bitwise_and(enabled, bitwise_not(use_igammac)), + dtype, IgammaMode.VALUE) + ) + output = select(x_is_zero, full_like(a, 0), output) + output = select(x_is_infinity, full_like(a, 1), output) + output = select(bitwise_or(domain_error, is_nan), + full_like(a, float('nan')), output) + return output + + needs_upcast = a.dtype == dtypes.bfloat16 or a.dtype == np.float16 + if needs_upcast: + a_dtype = a.dtype + a = convert_element_type(a, np.float32) + x = convert_element_type(x, np.float32) + a_x_type = np.float32 + else: + a_x_type = a.dtype + result = doit(a, x, a_x_type) + if needs_upcast: + result = convert_element_type(result, a_dtype) + return result + +def _igammac_continued_fraction(ax, x, a, enabled, dtype, mode): + # TODO(atondwal): implement _igammac_continued_fraction in JAX. + # Right now we fallback to the XLA implementation of IgammacContinuedFraction. 
+ return igammac(a, x) + lgamma_p = standard_unop(_float, 'lgamma') ad.defjvp(lgamma_p, lambda g, x: mul(g, digamma(x))) mlir.register_lowering(lgamma_p, partial(_nary_lower_hlo, chlo.LgammaOp)) @@ -126,7 +246,8 @@ def igammac_grada(g, a, x): mlir.register_lowering(digamma_p, partial(_nary_lower_hlo, chlo.DigammaOp)) igamma_p = standard_naryop([_float, _float], 'igamma') -xla.register_translation(igamma_p, partial(_broadcast_translate, xops.Igamma)) +mlir.register_lowering(igamma_p, + mlir.lower_fun(igamma_impl, multiple_results=False)) igamma_grad_a_p = standard_naryop([_float, _float], 'igamma_grad_a') xla.register_translation(igamma_grad_a_p, diff --git a/tests/filecheck/math.filecheck.py b/tests/filecheck/math.filecheck.py index 2df2cf6843dd..bfc5d73fc219 100644 --- a/tests/filecheck/math.filecheck.py +++ b/tests/filecheck/math.filecheck.py @@ -262,11 +262,6 @@ def main(_): # CHECK-SAME: tensor print_ir(np.float32(1), np.float32(2))(lax.gt) - # CHECK-LABEL: TEST: igamma float32[] float32[] - # CHECK: xla_fallback_igamma - # CHECK-SAME: tensor - print_ir(np.float32(0), np.float32(0))(lax.igamma) - # CHECK-LABEL: TEST: igammac float32[] float32[] # CHECK: xla_fallback_igammac # CHECK-SAME: tensor From f63a09c6a972e8061f2525beb826be8abf96efe0 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 23 Mar 2023 16:39:20 -0700 Subject: [PATCH 50/65] lax_numpy: move quantile-based functions to reductions.py --- jax/_src/lax/eigh.py | 3 +- jax/_src/numpy/lax_numpy.py | 201 +------------------------------ jax/_src/numpy/reductions.py | 199 ++++++++++++++++++++++++++++++ jax/numpy/__init__.py | 19 ++- tests/lax_numpy_reducers_test.py | 99 +++++++++++++++ tests/lax_numpy_test.py | 98 --------------- 6 files changed, 309 insertions(+), 310 deletions(-) diff --git a/jax/_src/lax/eigh.py b/jax/_src/lax/eigh.py index ecde2997edf3..8fcb80567f0b 100644 --- a/jax/_src/lax/eigh.py +++ b/jax/_src/lax/eigh.py @@ -33,6 +33,7 @@ import jax import jax._src.numpy.lax_numpy as jnp 
import jax._src.numpy.linalg as jnp_linalg +from jax._src.numpy import reductions from jax._src.numpy import ufuncs from jax import lax from jax._src.lax import qdwh @@ -360,7 +361,7 @@ def nearly_diagonal_case(agenda, blocks, eigenvectors): def default_case(agenda, blocks, eigenvectors): V = _slice(eigenvectors, (0, offset), (n, b), (N, B)) # TODO: Improve this? - split_point = jnp.nanmedian(_mask(jnp.diag(ufuncs.real(H)), (b,), jnp.nan)) + split_point = reductions.nanmedian(_mask(jnp.diag(ufuncs.real(H)), (b,), jnp.nan)) H_minus, V_minus, H_plus, V_plus, rank = split_spectrum( H, b, split_point, V0=V) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 118aab5095f5..181ff755f488 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -1609,7 +1609,7 @@ def _pad(array: ArrayLike, pad_width: PadValueLike[int], return array stat_funcs: Dict[str, PadStatFunc] = { - "maximum": reductions.amax, "minimum": reductions.amin, "mean": reductions.mean, "median": median} + "maximum": reductions.amax, "minimum": reductions.amin, "mean": reductions.mean, "median": reductions.median} pad_width = _broadcast_to_pairs(pad_width, nd, "pad_width") pad_width_arr = np.array(pad_width) @@ -4582,161 +4582,6 @@ def corrcoef(x: ArrayLike, y: Optional[ArrayLike] = None, rowvar: bool = True) - return c -@util._wraps(np.quantile, skip_params=['out', 'overwrite_input']) -@partial(jit, static_argnames=('axis', 'overwrite_input', 'interpolation', - 'keepdims', 'method')) -def quantile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None, - out: None = None, overwrite_input: bool = False, method: str = "linear", - keepdims: bool = False, interpolation: None = None) -> Array: - util.check_arraylike("quantile", a, q) - if overwrite_input or out is not None: - msg = ("jax.numpy.quantile does not support overwrite_input=True or " - "out != None") - raise ValueError(msg) - if interpolation is not None: - warnings.warn("The 
interpolation= argument to 'quantile' is deprecated. " - "Use 'method=' instead.", DeprecationWarning) - return _quantile(asarray(a), asarray(q), axis, interpolation or method, keepdims, False) - -@util._wraps(np.nanquantile, skip_params=['out', 'overwrite_input']) -@partial(jit, static_argnames=('axis', 'overwrite_input', 'interpolation', - 'keepdims', 'method')) -def nanquantile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None, - out: None = None, overwrite_input: bool = False, method: str = "linear", - keepdims: bool = False, interpolation: None = None) -> Array: - util.check_arraylike("nanquantile", a, q) - if overwrite_input or out is not None: - msg = ("jax.numpy.nanquantile does not support overwrite_input=True or " - "out != None") - raise ValueError(msg) - if interpolation is not None: - warnings.warn("The interpolation= argument to 'nanquantile' is deprecated. " - "Use 'method=' instead.", DeprecationWarning) - return _quantile(asarray(a), asarray(q), axis, interpolation or method, keepdims, True) - -def _quantile(a: Array, q: Array, axis: Optional[Union[int, Tuple[int, ...]]], - interpolation: str, keepdims: bool, squash_nans: bool) -> Array: - if interpolation not in ["linear", "lower", "higher", "midpoint", "nearest"]: - raise ValueError("interpolation can only be 'linear', 'lower', 'higher', " - "'midpoint', or 'nearest'") - a, = util.promote_dtypes_inexact(a) - keepdim = [] - if issubdtype(a.dtype, np.complexfloating): - raise ValueError("quantile does not support complex input, as the operation is poorly defined.") - if axis is None: - a = ravel(a) - axis = 0 - elif isinstance(axis, tuple): - keepdim = list(shape(a)) - nd = ndim(a) - axis = tuple(_canonicalize_axis(ax, nd) for ax in axis) - if len(set(axis)) != len(axis): - raise ValueError('repeated axis') - for ax in axis: - keepdim[ax] = 1 - - keep = set(range(nd)) - set(axis) - # prepare permutation - dimensions = list(range(nd)) - for i, s in 
enumerate(sorted(keep)): - dimensions[i], dimensions[s] = dimensions[s], dimensions[i] - do_not_touch_shape = tuple(x for idx,x in enumerate(shape(a)) if idx not in axis) - touch_shape = tuple(x for idx,x in enumerate(shape(a)) if idx in axis) - a = lax.reshape(a, do_not_touch_shape + (int(np.prod(touch_shape)),), dimensions) - axis = _canonicalize_axis(-1, ndim(a)) - else: - axis = _canonicalize_axis(axis, ndim(a)) - - q_shape = shape(q) - q_ndim = ndim(q) - if q_ndim > 1: - raise ValueError(f"q must be have rank <= 1, got shape {shape(q)}") - - a_shape = shape(a) - - if squash_nans: - a = where(ufuncs.isnan(a), nan, a) # Ensure nans are positive so they sort to the end. - a = lax.sort(a, dimension=axis) - counts = reductions.sum(ufuncs.logical_not(ufuncs.isnan(a)), axis=axis, dtype=q.dtype, - keepdims=keepdims) - shape_after_reduction = counts.shape - q = lax.expand_dims( - q, tuple(range(q_ndim, len(shape_after_reduction) + q_ndim))) - counts = lax.expand_dims(counts, tuple(range(q_ndim))) - q = lax.mul(q, lax.sub(counts, _lax_const(q, 1))) - low = lax.floor(q) - high = lax.ceil(q) - high_weight = lax.sub(q, low) - low_weight = lax.sub(_lax_const(high_weight, 1), high_weight) - - low = lax.max(_lax_const(low, 0), lax.min(low, counts - 1)) - high = lax.max(_lax_const(high, 0), lax.min(high, counts - 1)) - low = lax.convert_element_type(low, int64) - high = lax.convert_element_type(high, int64) - out_shape = q_shape + shape_after_reduction - index = [lax.broadcasted_iota(int64, out_shape, dim + q_ndim) - for dim in range(len(shape_after_reduction))] - if keepdims: - index[axis] = low - else: - index.insert(axis, low) - low_value = a[tuple(index)] - index[axis] = high - high_value = a[tuple(index)] - else: - a = where(reductions.any(ufuncs.isnan(a), axis=axis, keepdims=True), nan, a) - a = lax.sort(a, dimension=axis) - n = lax.convert_element_type(array(a_shape[axis]), lax_internal._dtype(q)) - q = lax.mul(q, n - 1) - low = lax.floor(q) - high = lax.ceil(q) - 
high_weight = lax.sub(q, low) - low_weight = lax.sub(_lax_const(high_weight, 1), high_weight) - - low = lax.clamp(_lax_const(low, 0), low, n - 1) - high = lax.clamp(_lax_const(high, 0), high, n - 1) - low = lax.convert_element_type(low, int64) - high = lax.convert_element_type(high, int64) - - slice_sizes = list(a_shape) - slice_sizes[axis] = 1 - dnums = lax.GatherDimensionNumbers( - offset_dims=tuple(range( - q_ndim, - len(a_shape) + q_ndim if keepdims else len(a_shape) + q_ndim - 1)), - collapsed_slice_dims=() if keepdims else (axis,), - start_index_map=(axis,)) - low_value = lax.gather(a, low[..., None], dimension_numbers=dnums, - slice_sizes=slice_sizes) - high_value = lax.gather(a, high[..., None], dimension_numbers=dnums, - slice_sizes=slice_sizes) - if q_ndim == 1: - low_weight = lax.broadcast_in_dim(low_weight, low_value.shape, - broadcast_dimensions=(0,)) - high_weight = lax.broadcast_in_dim(high_weight, high_value.shape, - broadcast_dimensions=(0,)) - - if interpolation == "linear": - result = lax.add(lax.mul(low_value.astype(q.dtype), low_weight), - lax.mul(high_value.astype(q.dtype), high_weight)) - elif interpolation == "lower": - result = low_value - elif interpolation == "higher": - result = high_value - elif interpolation == "nearest": - pred = lax.le(high_weight, _lax_const(high_weight, 0.5)) - result = lax.select(pred, low_value, high_value) - elif interpolation == "midpoint": - result = lax.mul(lax.add(low_value, high_value), _lax_const(low_value, 0.5)) - else: - raise ValueError(f"interpolation={interpolation!r} not recognized") - if keepdims and keepdim: - if q_ndim > 0: - keepdim = [shape(q)[0], *keepdim] - result = reshape(result, keepdim) - return lax.convert_element_type(result, a.dtype) - - @partial(vectorize, excluded={0, 2, 3}) def _searchsorted_via_scan(sorted_arr: Array, query: Array, side: str, dtype: type) -> Array: op = _sort_le_comparator if side == 'left' else _sort_lt_comparator @@ -4859,50 +4704,6 @@ def _const(v): return 
vectorize(lax.switch, excluded=(1,))(indices, funclist, x) -@util._wraps(np.percentile, skip_params=['out', 'overwrite_input']) -@partial(jit, static_argnames=('axis', 'overwrite_input', 'interpolation', - 'keepdims', 'method')) -def percentile(a: ArrayLike, q: ArrayLike, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - out: None = None, overwrite_input: bool = False, method: str = "linear", - keepdims: bool = False, interpolation: None = None) -> Array: - util.check_arraylike("percentile", a, q) - q, = util.promote_dtypes_inexact(q) - return quantile(a, q / 100, axis=axis, out=out, overwrite_input=overwrite_input, - interpolation=interpolation, method=method, keepdims=keepdims) - -@util._wraps(np.nanpercentile, skip_params=['out', 'overwrite_input']) -@partial(jit, static_argnames=('axis', 'overwrite_input', 'interpolation', - 'keepdims', 'method')) -def nanpercentile(a: ArrayLike, q: ArrayLike, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - out: None = None, overwrite_input: bool = False, method: str = "linear", - keepdims: bool = False, interpolation: None = None) -> Array: - util.check_arraylike("nanpercentile", a, q) - q = ufuncs.true_divide(q, float32(100.0)) - return nanquantile(a, q, axis=axis, out=out, overwrite_input=overwrite_input, - interpolation=interpolation, method=method, - keepdims=keepdims) - -@util._wraps(np.median, skip_params=['out', 'overwrite_input']) -@partial(jit, static_argnames=('axis', 'overwrite_input', 'keepdims')) -def median(a: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None, - out: None = None, overwrite_input: bool = False, - keepdims: bool = False) -> Array: - util.check_arraylike("median", a) - return quantile(a, 0.5, axis=axis, out=out, overwrite_input=overwrite_input, - keepdims=keepdims, method='midpoint') - -@util._wraps(np.nanmedian, skip_params=['out', 'overwrite_input']) -@partial(jit, static_argnames=('axis', 'overwrite_input', 'keepdims')) -def nanmedian(a: ArrayLike, axis: 
Optional[Union[int, Tuple[int, ...]]] = None, - out: None = None, overwrite_input: bool = False, - keepdims: bool = False) -> Array: - util.check_arraylike("nanmedian", a) - return nanquantile(a, 0.5, axis=axis, out=out, - overwrite_input=overwrite_input, keepdims=keepdims, - method='midpoint') - @util._wraps(np.place, lax_description=""" Numpy function :func:`numpy.place` is not available in JAX and will raise a diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index 4363fbc3253f..7e245487d3dd 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -25,6 +25,7 @@ from jax._src import api from jax._src import core from jax._src import dtypes +from jax._src.numpy import ufuncs from jax._src.numpy.util import ( _broadcast_to, check_arraylike, _complex_elem_type, promote_dtypes_inexact, promote_dtypes_numeric, _where, _wraps) @@ -684,3 +685,201 @@ def _cumulative_reduction(a: ArrayLike, axis: Axis = None, fill_nan=True, fill_value=0) nancumprod = _make_cumulative_reduction(np.nancumprod, lax.cumprod, fill_nan=True, fill_value=1) + +# Quantiles +@_wraps(np.quantile, skip_params=['out', 'overwrite_input']) +@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', + 'keepdims', 'method')) +def quantile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None, + out: None = None, overwrite_input: bool = False, method: str = "linear", + keepdims: bool = False, interpolation: None = None) -> Array: + check_arraylike("quantile", a, q) + if overwrite_input or out is not None: + msg = ("jax.numpy.quantile does not support overwrite_input=True or " + "out != None") + raise ValueError(msg) + if interpolation is not None: + warnings.warn("The interpolation= argument to 'quantile' is deprecated. 
" + "Use 'method=' instead.", DeprecationWarning) + return _quantile(_asarray(a), _asarray(q), axis, interpolation or method, keepdims, False) + +@_wraps(np.nanquantile, skip_params=['out', 'overwrite_input']) +@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', + 'keepdims', 'method')) +def nanquantile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None, + out: None = None, overwrite_input: bool = False, method: str = "linear", + keepdims: bool = False, interpolation: None = None) -> Array: + check_arraylike("nanquantile", a, q) + if overwrite_input or out is not None: + msg = ("jax.numpy.nanquantile does not support overwrite_input=True or " + "out != None") + raise ValueError(msg) + if interpolation is not None: + warnings.warn("The interpolation= argument to 'nanquantile' is deprecated. " + "Use 'method=' instead.", DeprecationWarning) + return _quantile(_asarray(a), _asarray(q), axis, interpolation or method, keepdims, True) + +def _quantile(a: Array, q: Array, axis: Optional[Union[int, Tuple[int, ...]]], + interpolation: str, keepdims: bool, squash_nans: bool) -> Array: + if interpolation not in ["linear", "lower", "higher", "midpoint", "nearest"]: + raise ValueError("interpolation can only be 'linear', 'lower', 'higher', " + "'midpoint', or 'nearest'") + a, = promote_dtypes_inexact(a) + keepdim = [] + if dtypes.issubdtype(a.dtype, np.complexfloating): + raise ValueError("quantile does not support complex input, as the operation is poorly defined.") + if axis is None: + a = a.ravel() + axis = 0 + elif isinstance(axis, tuple): + keepdim = list(a.shape) + nd = a.ndim + axis = tuple(_canonicalize_axis(ax, nd) for ax in axis) + if len(set(axis)) != len(axis): + raise ValueError('repeated axis') + for ax in axis: + keepdim[ax] = 1 + + keep = set(range(nd)) - set(axis) + # prepare permutation + dimensions = list(range(nd)) + for i, s in enumerate(sorted(keep)): + dimensions[i], dimensions[s] = dimensions[s], 
dimensions[i] + do_not_touch_shape = tuple(x for idx,x in enumerate(a.shape) if idx not in axis) + touch_shape = tuple(x for idx,x in enumerate(a.shape) if idx in axis) + a = lax.reshape(a, do_not_touch_shape + (int(np.prod(touch_shape)),), dimensions) + axis = _canonicalize_axis(-1, a.ndim) + else: + axis = _canonicalize_axis(axis, a.ndim) + + q_shape = q.shape + q_ndim = q.ndim + if q_ndim > 1: + raise ValueError(f"q must be have rank <= 1, got shape {q.shape}") + + a_shape = a.shape + + if squash_nans: + a = _where(ufuncs.isnan(a), np.nan, a) # Ensure nans are positive so they sort to the end. + a = lax.sort(a, dimension=axis) + counts = sum(ufuncs.logical_not(ufuncs.isnan(a)), axis=axis, dtype=q.dtype, keepdims=keepdims) + shape_after_reduction = counts.shape + q = lax.expand_dims( + q, tuple(range(q_ndim, len(shape_after_reduction) + q_ndim))) + counts = lax.expand_dims(counts, tuple(range(q_ndim))) + q = lax.mul(q, lax.sub(counts, _lax_const(q, 1))) + low = lax.floor(q) + high = lax.ceil(q) + high_weight = lax.sub(q, low) + low_weight = lax.sub(_lax_const(high_weight, 1), high_weight) + + low = lax.max(_lax_const(low, 0), lax.min(low, counts - 1)) + high = lax.max(_lax_const(high, 0), lax.min(high, counts - 1)) + low = lax.convert_element_type(low, int) + high = lax.convert_element_type(high, int) + out_shape = q_shape + shape_after_reduction + index = [lax.broadcasted_iota(int, out_shape, dim + q_ndim) + for dim in range(len(shape_after_reduction))] + if keepdims: + index[axis] = low + else: + index.insert(axis, low) + low_value = a[tuple(index)] + index[axis] = high + high_value = a[tuple(index)] + else: + a = _where(any(ufuncs.isnan(a), axis=axis, keepdims=True), np.nan, a) + a = lax.sort(a, dimension=axis) + n = lax.convert_element_type(a_shape[axis], lax_internal._dtype(q)) + q = lax.mul(q, n - 1) + low = lax.floor(q) + high = lax.ceil(q) + high_weight = lax.sub(q, low) + low_weight = lax.sub(_lax_const(high_weight, 1), high_weight) + + low = 
lax.clamp(_lax_const(low, 0), low, n - 1) + high = lax.clamp(_lax_const(high, 0), high, n - 1) + low = lax.convert_element_type(low, int) + high = lax.convert_element_type(high, int) + + slice_sizes = list(a_shape) + slice_sizes[axis] = 1 + dnums = lax.GatherDimensionNumbers( + offset_dims=tuple(range( + q_ndim, + len(a_shape) + q_ndim if keepdims else len(a_shape) + q_ndim - 1)), + collapsed_slice_dims=() if keepdims else (axis,), + start_index_map=(axis,)) + low_value = lax.gather(a, low[..., None], dimension_numbers=dnums, + slice_sizes=slice_sizes) + high_value = lax.gather(a, high[..., None], dimension_numbers=dnums, + slice_sizes=slice_sizes) + if q_ndim == 1: + low_weight = lax.broadcast_in_dim(low_weight, low_value.shape, + broadcast_dimensions=(0,)) + high_weight = lax.broadcast_in_dim(high_weight, high_value.shape, + broadcast_dimensions=(0,)) + + if interpolation == "linear": + result = lax.add(lax.mul(low_value.astype(q.dtype), low_weight), + lax.mul(high_value.astype(q.dtype), high_weight)) + elif interpolation == "lower": + result = low_value + elif interpolation == "higher": + result = high_value + elif interpolation == "nearest": + pred = lax.le(high_weight, _lax_const(high_weight, 0.5)) + result = lax.select(pred, low_value, high_value) + elif interpolation == "midpoint": + result = lax.mul(lax.add(low_value, high_value), _lax_const(low_value, 0.5)) + else: + raise ValueError(f"interpolation={interpolation!r} not recognized") + if keepdims and keepdim: + if q_ndim > 0: + keepdim = [np.shape(q)[0], *keepdim] + result = result.reshape(keepdim) + return lax.convert_element_type(result, a.dtype) + +@_wraps(np.percentile, skip_params=['out', 'overwrite_input']) +@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', + 'keepdims', 'method')) +def percentile(a: ArrayLike, q: ArrayLike, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + out: None = None, overwrite_input: bool = False, method: str = "linear", + keepdims: bool = 
False, interpolation: None = None) -> Array: + check_arraylike("percentile", a, q) + q, = promote_dtypes_inexact(q) + return quantile(a, q / 100, axis=axis, out=out, overwrite_input=overwrite_input, + interpolation=interpolation, method=method, keepdims=keepdims) + +@_wraps(np.nanpercentile, skip_params=['out', 'overwrite_input']) +@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', + 'keepdims', 'method')) +def nanpercentile(a: ArrayLike, q: ArrayLike, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + out: None = None, overwrite_input: bool = False, method: str = "linear", + keepdims: bool = False, interpolation: None = None) -> Array: + check_arraylike("nanpercentile", a, q) + q = ufuncs.true_divide(q, 100.0) + return nanquantile(a, q, axis=axis, out=out, overwrite_input=overwrite_input, + interpolation=interpolation, method=method, + keepdims=keepdims) + +@_wraps(np.median, skip_params=['out', 'overwrite_input']) +@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'keepdims')) +def median(a: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None, + out: None = None, overwrite_input: bool = False, + keepdims: bool = False) -> Array: + check_arraylike("median", a) + return quantile(a, 0.5, axis=axis, out=out, overwrite_input=overwrite_input, + keepdims=keepdims, method='midpoint') + +@_wraps(np.nanmedian, skip_params=['out', 'overwrite_input']) +@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'keepdims')) +def nanmedian(a: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None, + out: None = None, overwrite_input: bool = False, + keepdims: bool = False) -> Array: + check_arraylike("nanmedian", a) + return nanquantile(a, 0.5, axis=axis, out=out, + overwrite_input=overwrite_input, keepdims=keepdims, + method='midpoint') diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index 63b39f07b408..23a8f8e849bc 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -107,6 +107,8 @@ 
float16 as float16, float32 as float32, float64 as float64, + float8_e4m3fn as float8_e4m3fn, + float8_e5m2 as float8_e5m2, float_ as float_, floating as floating, fmax as fmax, @@ -166,7 +168,6 @@ logspace as logspace, mask_indices as mask_indices, matmul as matmul, - median as median, meshgrid as meshgrid, moveaxis as moveaxis, msort as msort, @@ -175,9 +176,6 @@ nanargmax as nanargmax, nanargmin as nanargmin, argpartition as argpartition, - nanmedian as nanmedian, - nanpercentile as nanpercentile, - nanquantile as nanquantile, ndim as ndim, newaxis as newaxis, nonzero as nonzero, @@ -189,14 +187,12 @@ packbits as packbits, pad as pad, partition as partition, - percentile as percentile, pi as pi, piecewise as piecewise, place as place, printoptions as printoptions, promote_types as promote_types, put as put, - quantile as quantile, ravel as ravel, ravel_multi_index as ravel_multi_index, repeat as repeat, @@ -258,11 +254,6 @@ zeros_like as zeros_like, ) -from jax._src.numpy.lax_numpy import ( - float8_e4m3fn, - float8_e5m2, -) - from jax._src.numpy.index_tricks import ( c_ as c_, index_exp as index_exp, @@ -298,19 +289,25 @@ cumproduct as cumproduct, max as max, mean as mean, + median as median, min as min, nancumsum as nancumsum, nancumprod as nancumprod, nanmax as nanmax, nanmean as nanmean, + nanmedian as nanmedian, nanmin as nanmin, + nanpercentile as nanpercentile, nanprod as nanprod, + nanquantile as nanquantile, nanstd as nanstd, nansum as nansum, nanvar as nanvar, + percentile as percentile, prod as prod, product as product, ptp as ptp, + quantile as quantile, sometrue as sometrue, std as std, sum as sum, diff --git a/tests/lax_numpy_reducers_test.py b/tests/lax_numpy_reducers_test.py index f68137771905..ee38560d16ab 100644 --- a/tests/lax_numpy_reducers_test.py +++ b/tests/lax_numpy_reducers_test.py @@ -16,6 +16,7 @@ import collections from functools import partial import itertools +import unittest from absl.testing import absltest from absl.testing 
import parameterized @@ -655,6 +656,104 @@ def testCov(self, shape, dtype, y_shape, y_dtype, rowvar, ddof, bias, fweights, self._CompileAndCheck(jnp_fun, args_maker, atol=tol, rtol=tol) + @jtu.sample_product( + [dict(op=op, q_rng=q_rng) + for (op, q_rng) in ( + ("percentile", partial(jtu.rand_uniform, low=0., high=100.)), + ("quantile", partial(jtu.rand_uniform, low=0., high=1.)), + ("nanpercentile", partial(jtu.rand_uniform, low=0., high=100.)), + ("nanquantile", partial(jtu.rand_uniform, low=0., high=1.)), + ) + ], + [dict(a_shape=a_shape, axis=axis) + for a_shape, axis in ( + ((7,), None), + ((47, 7), 0), + ((47, 7), ()), + ((4, 101), 1), + ((4, 47, 7), (1, 2)), + ((4, 47, 7), (0, 2)), + ((4, 47, 7), (1, 0, 2)), + ) + ], + a_dtype=default_dtypes, + q_dtype=[np.float32], + q_shape=scalar_shapes + [(1,), (4,)], + keepdims=[False, True], + method=['linear', 'lower', 'higher', 'nearest', 'midpoint'], + ) + def testQuantile(self, op, q_rng, a_shape, a_dtype, q_shape, q_dtype, + axis, keepdims, method): + a_rng = jtu.rand_some_nan(self.rng()) + q_rng = q_rng(self.rng()) + if "median" in op: + args_maker = lambda: [a_rng(a_shape, a_dtype)] + else: + args_maker = lambda: [a_rng(a_shape, a_dtype), q_rng(q_shape, q_dtype)] + + @jtu.ignore_warning(category=RuntimeWarning, + message="All-NaN slice encountered") + def np_fun(*args): + args = [x if jnp.result_type(x) != jnp.bfloat16 else + np.asarray(x, np.float32) for x in args] + if numpy_version <= (1, 22): + return getattr(np, op)(*args, axis=axis, keepdims=keepdims, + interpolation=method) + else: + return getattr(np, op)(*args, axis=axis, keepdims=keepdims, + method=method) + jnp_fun = partial(getattr(jnp, op), axis=axis, keepdims=keepdims, + method=method) + + # TODO(phawkins): we currently set dtype=False because we aren't as + # aggressive about promoting to float64. It's not clear we want to mimic + # Numpy here. 
+ tol_spec = {np.float16: 1E-2, np.float32: 2e-4, np.float64: 5e-6} + tol = max(jtu.tolerance(a_dtype, tol_spec), + jtu.tolerance(q_dtype, tol_spec)) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, + tol=tol) + self._CompileAndCheck(jnp_fun, args_maker, rtol=tol) + + @unittest.skipIf(not config.jax_enable_x64, "test requires X64") + @unittest.skipIf(jtu.device_under_test() != 'cpu', "test is for CPU float64 precision") + def testPercentilePrecision(self): + # Regression test for https://github.com/google/jax/issues/8513 + x = jnp.float64([1, 2, 3, 4, 7, 10]) + self.assertEqual(jnp.percentile(x, 50), 3.5) + + @jtu.sample_product( + [dict(a_shape=a_shape, axis=axis) + for a_shape, axis in ( + ((7,), None), + ((47, 7), 0), + ((4, 101), 1), + ) + ], + a_dtype=default_dtypes, + keepdims=[False, True], + op=["median", "nanmedian"], + ) + def testMedian(self, op, a_shape, a_dtype, axis, keepdims): + if op == "median": + a_rng = jtu.rand_default(self.rng()) + else: + a_rng = jtu.rand_some_nan(self.rng()) + args_maker = lambda: [a_rng(a_shape, a_dtype)] + def np_fun(*args): + args = [x if jnp.result_type(x) != jnp.bfloat16 else + np.asarray(x, np.float32) for x in args] + return getattr(np, op)(*args, axis=axis, keepdims=keepdims) + jnp_fun = partial(getattr(jnp, op), axis=axis, keepdims=keepdims) + # TODO(phawkins): we currently set dtype=False because we aren't as + # aggressive about promoting to float64. It's not clear we want to mimic + # Numpy here. 
+ tol_spec = {np.float32: 2e-4, np.float64: 5e-6} + tol = jtu.tolerance(a_dtype, tol_spec) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, + tol=tol) + self._CompileAndCheck(jnp_fun, args_maker, rtol=tol) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 42edd02c51ce..b559e78a7513 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -3912,104 +3912,6 @@ def args_maker(): return [] self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) - @jtu.sample_product( - [dict(op=op, q_rng=q_rng) - for (op, q_rng) in ( - ("percentile", partial(jtu.rand_uniform, low=0., high=100.)), - ("quantile", partial(jtu.rand_uniform, low=0., high=1.)), - ("nanpercentile", partial(jtu.rand_uniform, low=0., high=100.)), - ("nanquantile", partial(jtu.rand_uniform, low=0., high=1.)), - ) - ], - [dict(a_shape=a_shape, axis=axis) - for a_shape, axis in ( - ((7,), None), - ((47, 7), 0), - ((47, 7), ()), - ((4, 101), 1), - ((4, 47, 7), (1, 2)), - ((4, 47, 7), (0, 2)), - ((4, 47, 7), (1, 0, 2)), - ) - ], - a_dtype=default_dtypes, - q_dtype=[np.float32], - q_shape=scalar_shapes + [(1,), (4,)], - keepdims=[False, True], - method=['linear', 'lower', 'higher', 'nearest', 'midpoint'], - ) - def testQuantile(self, op, q_rng, a_shape, a_dtype, q_shape, q_dtype, - axis, keepdims, method): - a_rng = jtu.rand_some_nan(self.rng()) - q_rng = q_rng(self.rng()) - if "median" in op: - args_maker = lambda: [a_rng(a_shape, a_dtype)] - else: - args_maker = lambda: [a_rng(a_shape, a_dtype), q_rng(q_shape, q_dtype)] - - @jtu.ignore_warning(category=RuntimeWarning, - message="All-NaN slice encountered") - def np_fun(*args): - args = [x if jnp.result_type(x) != jnp.bfloat16 else - np.asarray(x, np.float32) for x in args] - if numpy_version <= (1, 22): - return getattr(np, op)(*args, axis=axis, keepdims=keepdims, - interpolation=method) - else: 
- return getattr(np, op)(*args, axis=axis, keepdims=keepdims, - method=method) - jnp_fun = partial(getattr(jnp, op), axis=axis, keepdims=keepdims, - method=method) - - # TODO(phawkins): we currently set dtype=False because we aren't as - # aggressive about promoting to float64. It's not clear we want to mimic - # Numpy here. - tol_spec = {np.float16: 1E-2, np.float32: 2e-4, np.float64: 5e-6} - tol = max(jtu.tolerance(a_dtype, tol_spec), - jtu.tolerance(q_dtype, tol_spec)) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, - tol=tol) - self._CompileAndCheck(jnp_fun, args_maker, rtol=tol) - - @unittest.skipIf(not config.jax_enable_x64, "test requires X64") - @unittest.skipIf(jtu.device_under_test() != 'cpu', "test is for CPU float64 precision") - def testPercentilePrecision(self): - # Regression test for https://github.com/google/jax/issues/8513 - x = jnp.float64([1, 2, 3, 4, 7, 10]) - self.assertEqual(jnp.percentile(x, 50), 3.5) - - @jtu.sample_product( - [dict(a_shape=a_shape, axis=axis) - for a_shape, axis in ( - ((7,), None), - ((47, 7), 0), - ((4, 101), 1), - ) - ], - a_dtype=default_dtypes, - keepdims=[False, True], - op=["median", "nanmedian"], - ) - def testMedian(self, op, a_shape, a_dtype, axis, keepdims): - if op == "median": - a_rng = jtu.rand_default(self.rng()) - else: - a_rng = jtu.rand_some_nan(self.rng()) - args_maker = lambda: [a_rng(a_shape, a_dtype)] - def np_fun(*args): - args = [x if jnp.result_type(x) != jnp.bfloat16 else - np.asarray(x, np.float32) for x in args] - return getattr(np, op)(*args, axis=axis, keepdims=keepdims) - jnp_fun = partial(getattr(jnp, op), axis=axis, keepdims=keepdims) - # TODO(phawkins): we currently set dtype=False because we aren't as - # aggressive about promoting to float64. It's not clear we want to mimic - # Numpy here. 
- tol_spec = {np.float32: 2e-4, np.float64: 5e-6} - tol = jtu.tolerance(a_dtype, tol_spec) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, - tol=tol) - self._CompileAndCheck(jnp_fun, args_maker, rtol=tol) - @jtu.sample_product( shape=all_shapes, dtype=all_dtypes, From d138853ebe53cfffc31db3b95009f219eae53683 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 6 Feb 2023 11:32:28 -0500 Subject: [PATCH 51/65] Increase minimum NumPy version to 1.21. Also increase minimum SciPy version to 1.7, which was released just before NumPy 1.21. --- CHANGELOG.md | 1 + build/build.py | 4 ++-- jaxlib/setup.py | 2 +- setup.py | 4 ++-- tests/lax_numpy_reducers_test.py | 3 +-- tests/scipy_stats_test.py | 2 -- 6 files changed, 7 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 16b1e6be4a77..f5c3929ea356 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,7 @@ Remember to align the itemized text with the first line of an item within a list like bfloat16. These definitions were previously internal to JAX, but have been split into a separate package to facilitate sharing them with other projects. + * JAX now requires NumPy 1.21 or newer and SciPy 1.7 or newer. * Deprecations * The type `jax.numpy.DeviceArray` is deprecated. 
Use `jax.Array` instead, diff --git a/build/build.py b/build/build.py index aaa71e56d6be..0ac3a58e553b 100755 --- a/build/build.py +++ b/build/build.py @@ -83,8 +83,8 @@ def check_numpy_version(python_bin_path): version = shell( [python_bin_path, "-c", "import numpy as np; print(np.__version__)"]) numpy_version = tuple(map(int, version.split(".")[:2])) - if numpy_version < (1, 20): - print("ERROR: JAX requires NumPy 1.20 or newer, found " + version + ".") + if numpy_version < (1, 21): + print("ERROR: JAX requires NumPy 1.21 or newer, found " + version + ".") sys.exit(-1) return version diff --git a/jaxlib/setup.py b/jaxlib/setup.py index 1c607dfd3535..14a0d83eb73b 100644 --- a/jaxlib/setup.py +++ b/jaxlib/setup.py @@ -46,7 +46,7 @@ def has_ext_modules(self): author_email='jax-dev@google.com', packages=['jaxlib', 'jaxlib.xla_extension'], python_requires='>=3.8', - install_requires=['scipy>=1.5', 'numpy>=1.20', 'ml_dtypes>=0.0.3'], + install_requires=['scipy>=1.7', 'numpy>=1.21', 'ml_dtypes>=0.0.3'], url='https://github.com/google/jax', license='Apache-2.0', classifiers=[ diff --git a/setup.py b/setup.py index 2c8b186e5f05..466603eda598 100644 --- a/setup.py +++ b/setup.py @@ -65,9 +65,9 @@ def generate_proto(source): python_requires='>=3.8', install_requires=[ 'ml_dtypes>=0.0.3', - 'numpy>=1.20', + 'numpy>=1.21', 'opt_einsum', - 'scipy>=1.5', + 'scipy>=1.7', ], extras_require={ # Minimum jaxlib version; used in testing. 
diff --git a/tests/lax_numpy_reducers_test.py b/tests/lax_numpy_reducers_test.py index ee38560d16ab..bb0d758d8e18 100644 --- a/tests/lax_numpy_reducers_test.py +++ b/tests/lax_numpy_reducers_test.py @@ -492,8 +492,7 @@ def np_fun(x): jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims, where=where) jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun) args_maker = lambda: [rng(shape, dtype)] - if numpy_version >= (1, 20, 2) or np_op.__name__ in ("all", "any"): - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, atol=tol, rtol=tol) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, atol=tol, rtol=tol) self._CompileAndCheck(jnp_fun, args_maker) def testReductionOfOutOfBoundsAxis(self): # Issue 888 diff --git a/tests/scipy_stats_test.py b/tests/scipy_stats_test.py index fb7a5af9a998..8ef546d0ab6a 100644 --- a/tests/scipy_stats_test.py +++ b/tests/scipy_stats_test.py @@ -1322,8 +1322,6 @@ def evaluate_kde(kde, x): def testMode(self, shape, dtype, axis, contains_nans, keepdims): if scipy_version < (1, 9, 0) and keepdims != True: self.skipTest("scipy < 1.9.0 only support keepdims == True") - if numpy_version < (1, 21, 0) and contains_nans: - self.skipTest("numpy < 1.21.0 only support contains_nans == False") if contains_nans: rng = jtu.rand_some_nan(self.rng()) From b6b1e4216617d6e8866f0fbcaf97ce3ca9e33830 Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Thu, 23 Mar 2023 18:38:26 -0700 Subject: [PATCH 52/65] Remove PJRT C API bypass. Now that all functionality needed by frameworks is implemented, let's remove the possibility of not noticing missing functionality due to the bypass. 
PiperOrigin-RevId: 519018438 --- tests/BUILD | 14 -------------- tests/infeed_test.py | 7 +++++++ tests/pjit_test.py | 3 +++ 3 files changed, 10 insertions(+), 14 deletions(-) diff --git a/tests/BUILD b/tests/BUILD index dd692e7e12a1..d8df5b221e3f 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -30,7 +30,6 @@ jax_generate_backend_suites() jax_test( name = "api_test", srcs = ["api_test.py"], - pjrt_c_api_bypass = True, shard_count = 10, ) @@ -165,7 +164,6 @@ jax_test( backend_tags = { "tpu": ["noasan"], # Times out. }, - pjrt_c_api_bypass = True, shard_count = { "cpu": 10, "gpu": 4, @@ -183,7 +181,6 @@ jax_test( backend_tags = { "tpu": ["notsan"], # Times out under tsan. }, - pjrt_c_api_bypass = True, shard_count = { "cpu": 5, "gpu": 5, @@ -198,7 +195,6 @@ jax_test( jax_test( name = "array_test", srcs = ["array_test.py"], - pjrt_c_api_bypass = True, tags = ["multiaccelerator"], deps = [ "//jax:experimental", @@ -220,7 +216,6 @@ jax_test( jax_test( name = "infeed_test", srcs = ["infeed_test.py"], - pjrt_c_api_bypass = True, deps = [ "//jax:experimental_host_callback", ], @@ -300,7 +295,6 @@ jax_test( "cpu": ["noasan"], # Test times out. "tpu": ["noasan"], # Test times out. }, - pjrt_c_api_bypass = True, shard_count = { "cpu": 40, "gpu": 40, @@ -311,7 +305,6 @@ jax_test( jax_test( name = "lax_numpy_operators_test", srcs = ["lax_numpy_operators_test.py"], - pjrt_c_api_bypass = True, shard_count = { "cpu": 30, "gpu": 30, @@ -322,7 +315,6 @@ jax_test( jax_test( name = "lax_numpy_reducers_test", srcs = ["lax_numpy_reducers_test.py"], - pjrt_c_api_bypass = True, shard_count = { "cpu": 20, "gpu": 20, @@ -412,7 +404,6 @@ jax_test( jax_test( name = "lax_test", srcs = ["lax_test.py"], - pjrt_c_api_bypass = True, shard_count = { "cpu": 40, "gpu": 40, @@ -558,7 +549,6 @@ jax_test( "noasan", # Times out under asan. 
], }, - pjrt_c_api_bypass = True, shard_count = { "cpu": 30, "gpu": 30, @@ -785,7 +775,6 @@ jax_test( jax_test( name = "checkify_test", srcs = ["checkify_test.py"], - pjrt_c_api_bypass = True, shard_count = { "gpu": 2, "tpu": 2, @@ -871,7 +860,6 @@ jax_test( backend_tags = { "tpu": ["nomsan"], # TODO(b/213388298): this test fails msan. }, - pjrt_c_api_bypass = True, deps = [ "//jax:compilation_cache", "//jax:experimental", @@ -891,7 +879,6 @@ jax_test( name = "host_callback_test", srcs = ["host_callback_test.py"], args = ["--jax_host_callback_outfeed=true"], - pjrt_c_api_bypass = True, deps = [ "//jax:experimental", "//jax:experimental_host_callback", @@ -921,7 +908,6 @@ jax_test( jax_test( name = "host_callback_to_tf_test", srcs = ["host_callback_to_tf_test.py"], - pjrt_c_api_bypass = True, deps = [ "//jax:experimental_host_callback", "//jax:ode", diff --git a/tests/infeed_test.py b/tests/infeed_test.py index 9c08b22b9844..6c30416d08cd 100644 --- a/tests/infeed_test.py +++ b/tests/infeed_test.py @@ -14,6 +14,7 @@ import threading +from unittest import SkipTest from absl.testing import absltest import jax @@ -21,6 +22,7 @@ from jax import config from jax.experimental import host_callback as hcb from jax._src import core +from jax._src import xla_bridge from jax._src.lib import xla_client import jax._src.test_util as jtu import numpy as np @@ -31,6 +33,11 @@ @jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # infeed class InfeedTest(jtu.JaxTestCase): + def setUp(self): + if xla_bridge.using_pjrt_c_api(): + raise SkipTest("infeed not implemented in PJRT C API") + super().setUp() + @jax.numpy_rank_promotion("allow") # Test explicitly exercises implicit rank promotion. 
def testInfeed(self): diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 72d335bee75f..73e91f32630c 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -802,6 +802,9 @@ def f_for_pjit(x): @jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # outfeed def testOutfeed(self): + if xla_bridge.using_pjrt_c_api(): + raise unittest.SkipTest('outfeed not implemented in PJRT C API') + devices = np.array(jax.local_devices()) nr_devices = len(devices) shape = (nr_devices * 3, nr_devices * 5) From fd36ed65963c2112bd49616de0d1cd32fd19cdb8 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Wed, 22 Mar 2023 20:54:45 -0700 Subject: [PATCH 53/65] [dynamic-shapes] make dynamic shape staging-to-jaxpr work with pjit --- jax/_src/api.py | 2 +- jax/_src/lax/lax.py | 17 + jax/_src/pjit.py | 199 +++++-- jax/interpreters/partial_eval.py | 5 +- tests/dynamic_api_test.py | 891 ++++++++++++++++--------------- 5 files changed, 616 insertions(+), 498 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index 6ea645c28199..7ad1828795f6 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -308,7 +308,7 @@ def infer_params(*args, **kwargs): out_shardings=out_shardings, static_argnums=static_argnums, static_argnames=static_argnames, donate_argnums=donate_argnums, device=device, backend=backend, keep_unused=keep_unused, - inline=inline, resource_env=None) + inline=inline, resource_env=None, abstracted_axes=abstracted_axes) return pjit.common_infer_params(pjit_info_args, *args, **kwargs) has_explicit_sharding = pjit._pjit_explicit_sharding( diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 73414e5b8479..d02e86860da0 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -4710,4 +4710,21 @@ def handler(_, buf): return core.DArray(aval, buf) return handler + @staticmethod + def global_sharded_result_handler(aval, out_sharding, committed, + is_out_sharding_from_xla): + phys_aval, = BIntRules.physical_avals(aval) + phys_handler_maker = 
pxla.global_result_handlers[core.ShapedArray] + + if not dispatch.is_single_device_sharding(out_sharding): + raise NotImplementedError # TODO(mattjj) + else: + phys_sharding = out_sharding + phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed, + is_out_sharding_from_xla) + + def handler(bufs): + return core.DArray(aval, phys_handler(bufs)) + return handler + core.bint._rules = BIntRules diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 490de581b80c..09613769dc95 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -31,9 +31,9 @@ from jax.interpreters import partial_eval as pe from jax.interpreters import xla from jax._src.interpreters.pxla import PartitionSpec -from jax.tree_util import ( +from jax._src.tree_util import ( tree_map, tree_flatten, tree_unflatten, treedef_is_leaf, tree_structure, - treedef_tuple) + treedef_tuple, broadcast_prefix, all_leaves) from jax._src.sharding import Sharding from jax._src.sharding_impls import ( @@ -66,6 +66,9 @@ distributed_debug_log, split_list, tuple_insert, weakref_lru_cache, merge_lists) +map, unsafe_map = safe_map, map +zip, unsafe_zip = safe_zip, zip + traceback_util.register_exclusion(__file__) class _FromGdaSingleton: @@ -162,7 +165,7 @@ def _get_arg_names(fun, in_tree, args_flat): def _device_assignment_mismatch_error(fun_name, fails, args_flat, api_name, arg_names): arg_list = [] - for a, n in safe_zip(args_flat, arg_names): + for a, n in zip(args_flat, arg_names): da = a.sharding._device_assignment if hasattr(a, 'sharding') else None arg_list.append((n, da, shaped_abstractify(a))) @@ -312,9 +315,6 @@ def _resolve_axis_resources_and_shardings_arg( def pre_infer_params(fun, in_shardings, out_shardings, donate_argnums, static_argnums, static_argnames, device, backend, abstracted_axes): - # TODO(yashkatariya, mattjj): Remove when pjit supports dynamic shapes. 
- if jax.config.jax_dynamic_shapes: - raise ValueError("Dynamic shapes is not supported with pjit yet.") if abstracted_axes and not jax.config.jax_dynamic_shapes: raise ValueError("abstracted_axes must be used with --jax_dynamic_shapes") @@ -414,12 +414,13 @@ class PjitInfo(NamedTuple): keep_unused: bool inline: bool resource_env: Any + abstracted_axes: Optional[Any] def common_infer_params(pjit_info_args, *args, **kwargs): (fun, user_in_shardings, user_out_shardings, static_argnums, static_argnames, donate_argnums, device, backend, keep_unused, inline, - resource_env) = pjit_info_args + resource_env, abstracted_axes) = pjit_info_args if kwargs and not _is_unspecified(user_in_shardings): raise ValueError( @@ -435,6 +436,8 @@ def common_infer_params(pjit_info_args, *args, **kwargs): "Mesh context manager should not be used with jit when backend or " "device is also specified as an argument to jit.") + axes_specs = _flat_axes_specs(abstracted_axes, *args, **kwargs) + jit_name = 'jit' if resource_env is None else 'pjit' dbg = debug_info(jit_name, fun, args, kwargs, static_argnums, static_argnames) f = lu.wrap_init(fun) @@ -448,10 +451,10 @@ def common_infer_params(pjit_info_args, *args, **kwargs): # leads to wrong expansion. 
if kwargs: f, dyn_kwargs = argnames_partial_except(f, static_argnames, kwargs) - args_flat, in_tree = tree_flatten((dyn_args, dyn_kwargs)) + explicit_args, in_tree = tree_flatten((dyn_args, dyn_kwargs)) flat_fun, out_tree = flatten_fun(f, in_tree) else: - args_flat, in_tree = tree_flatten(dyn_args) + explicit_args, in_tree = tree_flatten(dyn_args) flat_fun, out_tree = flatten_fun_nokwargs(f, in_tree) dyn_kwargs = () del kwargs @@ -459,7 +462,7 @@ def common_infer_params(pjit_info_args, *args, **kwargs): if donate_argnums and not jax.config.jax_debug_nans: donated_invars = donation_vector(donate_argnums, dyn_args, dyn_kwargs) else: - donated_invars = (False,) * len(args_flat) + donated_invars = (False,) * len(explicit_args) # If backend or device is set as an arg on jit, then resolve them to # in_shardings and out_shardings as if user passed in in_shardings @@ -475,25 +478,37 @@ def common_infer_params(pjit_info_args, *args, **kwargs): del user_in_shardings, user_out_shardings - global_in_avals = tuple(shaped_abstractify(a) for a in args_flat) + if config.jax_dynamic_shapes: + in_type = pe.infer_lambda_input_type(axes_specs, explicit_args) + in_avals = tuple([a for a, e in in_type if e]) + else: + in_type = in_avals = tuple(shaped_abstractify(a) for a in explicit_args) canonicalized_in_shardings_flat = _process_in_axis_resources( - hashable_pytree(in_shardings), global_in_avals, in_tree, resource_env) + hashable_pytree(in_shardings), in_avals, in_tree, resource_env) jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr( - flat_fun, hashable_pytree(out_shardings), global_in_avals, dbg, + flat_fun, hashable_pytree(out_shardings), in_type, dbg, HashableFunction(out_tree, closure=()), HashableFunction(res_paths, closure=())) if any(_is_from_gda(i) for i in canonicalized_in_shardings_flat): canonicalized_in_shardings_flat = _maybe_replace_from_gda_with_pspec( - canonicalized_in_shardings_flat, args_flat) + canonicalized_in_shardings_flat, explicit_args) + assert 
len(explicit_args) == len(canonicalized_in_shardings_flat) - assert len(args_flat) == len(canonicalized_in_shardings_flat) + if config.jax_dynamic_shapes: + implicit_args = _extract_implicit_args(in_type, explicit_args) + else: + implicit_args = [] + args_flat = [*implicit_args, *explicit_args] - canonicalized_in_shardings_flat = ( - _UNSPECIFIED,) * len(consts) + canonicalized_in_shardings_flat - donated_invars = (False,) * len(consts) + donated_invars + num_extra_args = len(implicit_args) + len(consts) + canonicalized_in_shardings_flat = \ + (_UNSPECIFIED,) * num_extra_args + canonicalized_in_shardings_flat + donated_invars = (False,) * num_extra_args + donated_invars + assert (len(canonicalized_in_shardings_flat) == len(donated_invars) == + len(consts) + len(args_flat)) # in_shardings and out_shardings here are all GSPMDSharding. params = dict( @@ -506,9 +521,48 @@ def common_infer_params(pjit_info_args, *args, **kwargs): keep_unused=keep_unused, inline=inline, ) - return (consts + args_flat, global_in_avals, params, in_tree, out_tree(), + return (consts + args_flat, in_type, params, in_tree, out_tree(), donate_argnums) +def _extract_implicit_args( + in_type: Sequence[Tuple[core.AbstractValue, bool]], + explicit_args: Sequence[Any] +) -> Sequence[core.Tracer]: + """ + Given an input type and explicitly-passed arguments (per the user-facing API + calling convention), extract implicit axis size arguments from shapes of + explicit arguments (for the trace-time / jaxpr-level calling convention). + """ + # First, using `in_type` construct a list to represent the full argument list, + # leaving the implicit arguments as None placeholders for now. + explicit_args_ = iter(explicit_args) + args = [next(explicit_args_) if expl else None for _, expl in in_type] + assert next(explicit_args_, None) is None + del explicit_args, explicit_args_ + + # Next, populate the implicit arguments using the DBIdxs in `in_type`. 
+ for i, (aval, explicit) in enumerate(in_type): + if not explicit or not isinstance(aval, core.DShapedArray): + continue # can't populate an implicit argument + arg = args[i] + assert arg is not None + for d1, d2 in zip(aval.shape, arg.aval.shape): + if isinstance(d1, core.DBIdx): + if args[d1.val] is None: + args[d1.val] = d2 + assert core.same_referent(args[d1.val], d2) + assert all(x is not None for x in args) + return [x for x, (_, e) in zip(args, in_type) if not e] # type: ignore + +def _flat_axes_specs(abstracted_axes, *args, **kwargs + ) -> Optional[List[pe.AbstractedAxesSpec]]: + if abstracted_axes is None: return None + if kwargs: raise NotImplementedError + def ax_leaf(l): + return (isinstance(l, dict) and all_leaves(l.values()) or + isinstance(l, tuple) and all_leaves(l, lambda x: x is None)) + return broadcast_prefix(abstracted_axes, args, ax_leaf) + # in_shardings and out_shardings can't be None as the default value # because `None` means that the input is fully replicated. @@ -683,7 +737,8 @@ def infer_params(*args, **kwargs): out_shardings=out_shardings, static_argnums=static_argnums, static_argnames=static_argnames, donate_argnums=donate_argnums, device=device, backend=backend, keep_unused=keep_unused, - inline=inline, resource_env=resource_env) + inline=inline, resource_env=resource_env, + abstracted_axes=abstracted_axes) return common_infer_params(pjit_info_args, *args, **kwargs) has_explicit_sharding = _pjit_explicit_sharding( @@ -800,38 +855,44 @@ def __repr__(self): return "pytree leaf" @lru_cache(maxsize=4096) -def _process_in_axis_resources(in_shardings_thunk, global_in_avals, +def _process_in_axis_resources(in_shardings_thunk, in_type, in_tree, resource_env): orig_in_shardings = in_shardings_thunk() # Only do this if original in_shardings are unspecified. If they are # FROM_GDA or AUTO, go via flatten_axis_resources. 
if _is_unspecified(orig_in_shardings): - in_shardings_flat = (orig_in_shardings,) * len(global_in_avals) + in_shardings_flat = (orig_in_shardings,) * len(in_type) else: in_shardings_flat = flatten_axis_resources( "pjit in_shardings", in_tree, orig_in_shardings, tupled_args=True) - pjit_check_aval_sharding(in_shardings_flat, global_in_avals, "pjit arguments", - allow_uneven_sharding=False) + if not config.jax_dynamic_shapes: + pjit_check_aval_sharding(in_shardings_flat, in_type, + "pjit arguments", allow_uneven_sharding=False) # TODO(yashkatariya): Only check for is_auto or _is_unspecified when # FROM_GDA is removed. canonicalized_shardings = tuple( i if _is_unspecified_or_from_gda_or_auto(i) else to_gspmd_sharding(i, aval.ndim) - for i, aval in safe_zip(in_shardings_flat, global_in_avals)) + for i, aval in zip(in_shardings_flat, in_type)) return canonicalized_shardings @lu.cache -def _create_pjit_jaxpr(fun, global_in_avals, debug_info, out_paths): +def _create_pjit_jaxpr(fun, in_type, debug_info, out_paths): with dispatch.log_elapsed_time(f"Finished tracing + transforming {fun.__name__} " "for pjit in {elapsed_time} sec", event=dispatch.JAXPR_TRACE_EVENT): pe_debug = debug_info and pe.debug_info_final(fun, debug_info.traced_for) - jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic( - fun, global_in_avals, debug_info=pe_debug) + if config.jax_dynamic_shapes: + jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic2( + lu.annotate(fun, in_type), debug_info=pe_debug) + else: + jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic( + fun, in_type, debug_info=pe_debug) - jaxpr = jaxpr_debug_info(jaxpr, debug_info, out_paths()) + if not config.jax_dynamic_shapes: + jaxpr = jaxpr_debug_info(jaxpr, debug_info, out_paths()) if any(isinstance(c, core.Tracer) for c in consts): closed_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr)) @@ -844,7 +905,7 @@ def _create_pjit_jaxpr(fun, global_in_avals, debug_info, out_paths): @lru_cache(maxsize=4096) def 
_check_and_canonicalize_out_shardings( - out_shardings_thunk, out_tree, global_out_avals): + out_shardings_thunk, out_tree, out_type): orig_out_shardings = out_shardings_thunk() # TODO(yashkatariya): Remove the if branch and fix flatten_axis_resources # instead. This condition exists because flatten_axis_resources passes in an @@ -852,28 +913,29 @@ def _check_and_canonicalize_out_shardings( # pytrees (which shouldn't exist but they do). if (_is_unspecified(orig_out_shardings) or isinstance(orig_out_shardings, XLACompatibleSharding)): - out_shardings_flat = (orig_out_shardings,) * len(global_out_avals) + out_shardings_flat = (orig_out_shardings,) * len(out_type) else: out_shardings_flat = flatten_axis_resources( "pjit out_shardings", out_tree(), orig_out_shardings, tupled_args=False) - pjit_check_aval_sharding(out_shardings_flat, global_out_avals, "pjit outputs", - allow_uneven_sharding=False) + if not config.jax_dynamic_shapes: + pjit_check_aval_sharding(out_shardings_flat, out_type, "pjit outputs", + allow_uneven_sharding=False) canonicalized_out_shardings_flat = tuple( o if _is_unspecified(o) or is_auto(o) else to_gspmd_sharding(o, aval.ndim) - for o, aval in safe_zip(out_shardings_flat, global_out_avals) + for o, aval in zip(out_shardings_flat, out_type) ) return canonicalized_out_shardings_flat -def _pjit_jaxpr(fun, out_shardings_thunk, global_in_avals, debug_info, out_tree, +def _pjit_jaxpr(fun, out_shardings_thunk, in_type, debug_info, out_tree, result_paths): - jaxpr, final_consts, global_out_avals = _create_pjit_jaxpr( - fun, global_in_avals, debug_info, result_paths) + jaxpr, final_consts, out_type = _create_pjit_jaxpr( + fun, in_type, debug_info, result_paths) canonicalized_out_shardings_flat = _check_and_canonicalize_out_shardings( - out_shardings_thunk, out_tree, tuple(global_out_avals)) + out_shardings_thunk, out_tree, tuple(out_type)) # lu.cache needs to be able to create weakrefs to outputs, so we can't return a plain tuple return jaxpr, 
final_consts, canonicalized_out_shardings_flat @@ -1118,7 +1180,7 @@ def _resolve_in_shardings( (None if pjit_mesh is None or pjit_mesh.empty else list(pjit_mesh.devices.flat))) resolved_in_shardings = [] - for arg, pjit_in_s in safe_zip(args, pjit_in_shardings): + for arg, pjit_in_s in zip(args, pjit_in_shardings): arg_s, committed = ((arg.sharding, getattr(arg, '_committed', True)) if hasattr(arg, 'sharding') else (_UNSPECIFIED, False)) if _is_unspecified(pjit_in_s): @@ -1201,7 +1263,7 @@ def _pjit_call_impl(*args, jaxpr, distributed_debug_log(("Running pjit'd function", name), ("in_shardings", in_shardings), ("out_shardings", out_shardings), - ("abstract args", list(map(xla.abstractify, args))), + ("abstract args", map(xla.abstractify, args)), ("fingerprint", fingerprint)) try: return compiled.unsafe_call(*args) @@ -1255,7 +1317,7 @@ def __eq__(self, other): return (all(pxla.are_op_shardings_equal(s._op_sharding, o._op_sharding) # pytype: disable=attribute-error if isinstance(s, GSPMDSharding) and isinstance(o, GSPMDSharding) else s == o - for s, o in safe_zip(self.shardings, other.shardings)) and + for s, o in zip(self.shardings, other.shardings)) and self.device_assignment == other.device_assignment) @@ -1341,10 +1403,43 @@ def pjit_staging_rule(trace, *args, **params): all(_is_unspecified(o) for o in params["out_shardings"])): jaxpr = params['jaxpr'] return core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args) + elif config.jax_dynamic_shapes: + source_info = source_info_util.current() + out_tracers = [] + for aval in _out_type(params['jaxpr']): + if type(aval) is core.DShapedArray: + shape = [args[d.val] if type(d) is core.InDBIdx else + out_tracers[d.val] if type(d) is core.OutDBIdx else + d for d in aval.shape] + aval = aval.update(shape=tuple(core.get_referent(d) for d in shape)) + out_tracers.append(pe.DynamicJaxprTracer(trace, aval, source_info)) + eqn = core.new_jaxpr_eqn( + map(trace.getvar, args), map(trace.makevar, out_tracers), pjit_p, params, + 
params['jaxpr'].effects, source_info) + trace.frame.add_eqn(eqn) + return out_tracers else: return trace.default_process_primitive(pjit_p, args, params) pe.custom_staging_rules[pjit_p] = pjit_staging_rule +# TODO(mattjj): remove/trivialize this when jaxprs have type annotation on them, +# since it's actually not possible in general to infer the type from the term +def _out_type(jaxpr: core.ClosedJaxpr) -> List[core.AbstractValue]: + out = [] + in_idx = {v: i for i, v in enumerate(jaxpr.jaxpr.invars)} + out_idx = {x: i for i, x in enumerate(jaxpr.jaxpr.invars) + if type(x) is core.Var} + for x in jaxpr.jaxpr.outvars: + aval = x.aval + if type(aval) is core.DShapedArray: + shape = [core.InDBIdx(in_idx[d]) if d in in_idx else + core.OutDBIdx(out_idx[d]) if d in out_idx else + d for d in x.aval.shape] + aval = aval.update(shape=tuple(shape)) + out.append(aval) + return out + + def _pjit_typecheck(ctx_factory, *in_atoms, jaxpr, **params): return core._check_call(ctx_factory, pjit_p, in_atoms, dict(params, call_jaxpr=jaxpr.jaxpr)) @@ -1360,14 +1455,14 @@ def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings, out_shardings, resource_env, donated_invars, keep_unused, inline): effects = list(ctx.tokens_in.effects()) - output_types = safe_map(mlir.aval_to_ir_types, ctx.avals_out) + output_types = map(mlir.aval_to_ir_types, ctx.avals_out) output_types = [mlir.token_type()] * len(effects) + output_types flat_output_types = util.flatten(output_types) arg_shardings = [None if _is_unspecified(i) else i._to_xla_op_sharding(aval.ndim) - for aval, i in safe_zip(ctx.avals_in, in_shardings)] + for aval, i in zip(ctx.avals_in, in_shardings)] result_shardings = [None if _is_unspecified(o) else o._to_xla_op_sharding(aval.ndim) - for aval, o in safe_zip(ctx.avals_out, out_shardings)] + for aval, o in zip(ctx.avals_out, out_shardings)] # TODO(b/228598865): inlined calls cannot have shardings set directly on the # inputs or outputs because they are lost during MLIR->HLO conversion. 
@@ -1381,7 +1476,7 @@ def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings, call = func_dialect.CallOp(flat_output_types, ir.FlatSymbolRefAttr.get(func.name.value), mlir.flatten_lowering_ir_args(args)) - out_nodes = util.unflatten(call.results, safe_map(len, output_types)) + out_nodes = util.unflatten(call.results, map(len, output_types)) tokens, out_nodes = split_list(out_nodes, [len(effects)]) tokens_out = ctx.tokens_in.update_tokens(mlir.TokenSet(zip(effects, tokens))) ctx.set_tokens_out(tokens_out) @@ -1486,7 +1581,7 @@ def _filter_zeros(is_nz_l, l): def _known_jaxpr_fwd(known_jaxpr: core.ClosedJaxpr, fwds_known: Tuple[Optional[int]]) -> core.ClosedJaxpr: updated_jaxpr = known_jaxpr.jaxpr.replace( - outvars=[x for x, i in safe_zip(known_jaxpr.jaxpr.outvars, fwds_known) + outvars=[x for x, i in zip(known_jaxpr.jaxpr.outvars, fwds_known) if i is None]) return known_jaxpr.replace(jaxpr=updated_jaxpr) @@ -1505,7 +1600,7 @@ def _pjit_partial_eval(trace, *in_tracers, num_residuals = len(res_avals) def keep_where(l, should_keep): - return tuple(x for x, keep in zip(l, should_keep) if keep) + return tuple(x for x, keep in unsafe_zip(l, should_keep) if keep) residual_shardings = (_UNSPECIFIED,) * num_residuals # Compute the known outputs @@ -1526,7 +1621,7 @@ def keep_where(l, should_keep): known_user_out_shardings = keep_where(known_params['out_shardings'], known_outs) fwds_known_user = [ fwd if _is_unspecified(os) else None - for os, fwd in safe_zip(known_user_out_shardings, + for os, fwd in zip(known_user_out_shardings, fwds_known[:len(known_user_out_shardings)])] fwds_known = fwds_known_user + fwds_known[len(known_user_out_shardings):] del fwds_known_user @@ -1534,7 +1629,7 @@ def keep_where(l, should_keep): # Remove forwarded outvars and out_shardings known_params['jaxpr'] = _known_jaxpr_fwd(known_params['jaxpr'], tuple(fwds_known)) known_out_shardings = tuple( - s for s, i in safe_zip(known_params['out_shardings'], fwds_known) if i is None) + s for s, i in 
zip(known_params['out_shardings'], fwds_known) if i is None) known_params['out_shardings'] = known_out_shardings del known_out_shardings @@ -1694,7 +1789,7 @@ def dce_jaxpr_pjit_rule(used_outputs: List[bool], eqn: core.JaxprEqn eqn.params['jaxpr'], tuple(used_outputs)) def keep_where(xs, keeps): - return tuple(x for x, keep in safe_zip(xs, keeps) if keep) + return tuple(x for x, keep in zip(xs, keeps) if keep) eqn_params = eqn.params new_params = dict( @@ -1847,7 +1942,7 @@ def with_sharding_constraint(x, axis_resources=_UNSPECIFIED, outs = [sharding_constraint_p.bind(xf, sharding=to_gspmd_sharding(i, xf.ndim), resource_env=resource_env, unconstrained_dims=ud) - for xf, i, ud in safe_zip(x_flat, shardings_flat, unconstrained_dims)] + for xf, i, ud in zip(x_flat, shardings_flat, unconstrained_dims)] return tree_unflatten(tree, outs) def _sharding_constraint_impl(x, sharding, resource_env, unconstrained_dims): @@ -1983,7 +2078,7 @@ def _gda_check_and_get_sharding( return to_gspmd_sharding(gda_sharding, ndim) out = [] - for in_sharding_flat, arg in safe_zip(in_shardings_flat, args_flat): + for in_sharding_flat, arg in zip(in_shardings_flat, args_flat): if is_auto(in_sharding_flat): out.append(in_sharding_flat) elif isinstance(arg, array.ArrayImpl): diff --git a/jax/interpreters/partial_eval.py b/jax/interpreters/partial_eval.py index b6d44923a5ad..7170e2d3305e 100644 --- a/jax/interpreters/partial_eval.py +++ b/jax/interpreters/partial_eval.py @@ -1732,6 +1732,7 @@ def new_const(self, c): tracer = self.frame.constid_to_tracer.get(id(c)) if tracer is None: aval = raise_to_shaped(get_aval(c), weak_type=dtypes.is_weakly_typed(c)) + aval = self._lift_tracers_in_aval(aval) tracer = self._new_const(aval, c) return tracer @@ -1820,8 +1821,8 @@ def process_call(self, call_primitive, f, explicit_tracers, params): for aval, _ in out_type: if type(aval) is DShapedArray: shape = [[*consts, *in_tracers][d.val] if type(d) is InDBIdx else - out_tracers[d.val] if type(d) is OutDBIdx 
else - d for d in aval.shape] + out_tracers[d.val] if type(d) is OutDBIdx else + d for d in aval.shape] aval = aval.update(shape=tuple(get_referent(d) for d in shape)) out_tracers.append(DynamicJaxprTracer(self, aval, source_info)) invars = map(self.getvar, in_tracers) diff --git a/tests/dynamic_api_test.py b/tests/dynamic_api_test.py index 1d55b4e0fb53..191bbab446c3 100644 --- a/tests/dynamic_api_test.py +++ b/tests/dynamic_api_test.py @@ -37,9 +37,8 @@ python_version = (sys.version_info[0], sys.version_info[1]) -@unittest.skip("Test does not work with jax.Array") @jtu.with_config(jax_dynamic_shapes=True, jax_numpy_rank_promotion="allow") -class DynamicShapeTest(jtu.JaxTestCase): +class DynamicShapeStagingTest(jtu.JaxTestCase): def test_basic_staging(self): def f(x, _): return x @@ -223,8 +222,8 @@ def f(n): # { lambda ; a:i32[]. let # b:f32[a] = bcast[dims=() shape=(None,)] 0.0 a - # c:f32[a] = xla_call[ - # call_jaxpr={ lambda ; d:i32[] e:f32[d]. let in (e,) } + # c:f32[a] = pjit[ + # jaxpr={ lambda ; d:i32[] e:f32[d]. let in (e,) } # name= # ] a b # in (c,) } @@ -242,8 +241,8 @@ def f(n): # { lambda ; a:i32[]. let # b:i32[] = mul a 2 # c:f32[b] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] 0.0 b - # d:f32[b] = xla_call[ - # call_jaxpr={ lambda ; e:i32[] f:f32[e]. let in (f,) } + # d:f32[b] = pjit[ + # jaxpr={ lambda ; e:i32[] f:f32[e]. let in (f,) } # name= # ] b c # in (d,) } @@ -269,8 +268,8 @@ def g(): # { lambda ; a:i32[]. let # b:f32[a] = bcast[broadcast_dimensions=() shape=(None,)] 0.0 a # c:f32[a] = bcast[broadcast_dimensions=() shape=(None,)] 0.0 a - # d:f32[a] = xla_call[ - # call_jaxpr={ lambda ; e:i32[] f:f32[e] g:f32[e]. let + # d:f32[a] = pjit[ + # jaxpr={ lambda ; e:i32[] f:f32[e] g:f32[e]. let # h:f32[e] = add f g # in (h,) } # name=g @@ -296,8 +295,8 @@ def g(x, y): # { lambda ; a:i32[]. 
let # b:f32[a] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] 0.0 a - # c:f32[a] = xla_call[ - # call_jaxpr={ lambda ; d:i32[] e:f32[d] f:f32[d]. let + # c:f32[a] = pjit[ + # jaxpr={ lambda ; d:i32[] e:f32[d] f:f32[d]. let # g:f32[d] = add e f # in (g,) } # name=g @@ -345,8 +344,8 @@ def g(): # { lambda ; a:i32[]. let # b:f32[a] = bcast[broadcast_dimensions=() shape=(None,)] 0.0 a - # c:f32[a] = xla_call[ - # call_jaxpr={ lambda ; d:i32[] e:f32[d]. let + # c:f32[a] = pjit[ + # jaxpr={ lambda ; d:i32[] e:f32[d]. let # f:f32[d] = bcast[broadcast_dimensions=() shape=(None,)] 0.0 d # g:f32[d] = add f e # in (g,) } @@ -375,8 +374,8 @@ def f(n): # b:i32[] = mul a 2 # c:f32[b,b] = broadcast_in_dim[broadcast_dimensions=() shape=(None, None)] 0.0 # b b - # d:f32[b] = xla_call[ - # call_jaxpr={ lambda ; e:i32[] f:f32[e,e]. let + # d:f32[b] = pjit[ + # jaxpr={ lambda ; e:i32[] f:f32[e,e]. let # g:f32[e] = reduce_sum[axes=(0,)] f # in (g,) } # name= @@ -413,8 +412,8 @@ def f(n): # c:f32[b,a] = broadcast_in_dim[broadcast_dimensions=() shape=(None, None)] 0.0 # b a # d:f32[b] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] 1.0 b - # e:f32[b] = xla_call[ - # call_jaxpr={ lambda ; f:i32[] g:i32[] h:f32[f,g] i:f32[f]. let + # e:f32[b] = pjit[ + # jaxpr={ lambda ; f:i32[] g:i32[] h:f32[f,g] i:f32[f]. let # j:f32[f] = reduce_sum[axes=(1,)] h # k:f32[f] = add j i # in (k,) } @@ -446,8 +445,8 @@ def f(x): return jnp.sum(x) jaxpr = jax.make_jaxpr(f)(jnp.ones(3, jnp.dtype('float32'))) # { lambda ; a:f32[3]. let - # b:f32[] = xla_call[ - # call_jaxpr={ lambda ; c:i32[] d:f32[c]. let + # b:f32[] = pjit[ + # jaxpr={ lambda ; c:i32[] d:f32[c]. 
let # e:f32[] = reduce_sum[axes=(0,)] d # in (e,) } # name=f @@ -461,8 +460,8 @@ def f(x): b, = e.outvars self.assertLen(b.aval.shape, 0) - subjaxpr = e.params['call_jaxpr'] - c, d = subjaxpr.invars + subjaxpr = e.params['jaxpr'] + c, d = subjaxpr.jaxpr.invars self.assertLen(c.aval.shape, 0) self.assertLen(d.aval.shape, 1) self.assertIs(d.aval.shape[0], c) @@ -476,8 +475,8 @@ def fun(x): # { lambda ; a:i32[]. let # b:i32[] = add a a # c:f32[b] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] 1.0 b - # d:f32[] = xla_call[ - # call_jaxpr={ lambda ; e:i32[] f:f32[e]. let + # d:f32[] = pjit[ + # jaxpr={ lambda ; e:i32[] f:f32[e]. let # g:f32[] = reduce_sum[axes=(0,)] f # in (g,) } # name=f @@ -491,8 +490,8 @@ def fun(x): self.assertIs(b, b_) self.assertIs(c, c_) - subjaxpr = e3.params['call_jaxpr'] - e, f = subjaxpr.invars + subjaxpr = e3.params['jaxpr'] + e, f = subjaxpr.jaxpr.invars self.assertLen(e.aval.shape, 0) self.assertLen(f.aval.shape, 1) self.assertIs(f.aval.shape[0], e) @@ -501,8 +500,8 @@ def test_jit_abstracted_axes_staging3(self): f = jax.jit(jnp.sum, abstracted_axes=('n',)) jaxpr = jax.make_jaxpr(f, abstracted_axes=('n',))(jnp.arange(3.)) # { lambda ; a:i32[] b:f32[a]. let - # c:f32[] = xla_call[ - # call_jaxpr={ lambda ; d:i32[] e:f32[d]. let + # c:f32[] = pjit[ + # jaxpr={ lambda ; d:i32[] e:f32[d]. let # f:f32[] = reduce_sum[axes=(0,)] e # in (f,) } # name=sum @@ -515,8 +514,8 @@ def test_jit_abstracted_axes_staging3(self): c, = e.outvars self.assertLen(c.aval.shape, 0) - subjaxpr = e.params['call_jaxpr'] - d, e = subjaxpr.invars + subjaxpr = e.params['jaxpr'] + d, e = subjaxpr.jaxpr.invars self.assertLen(d.aval.shape, 0) self.assertLen(e.aval.shape, 1) self.assertIs(e.aval.shape[0], d) @@ -525,8 +524,8 @@ def test_jit_abstracted_axes_return_polymorphic_shape(self): f = jax.jit(lambda x: x, abstracted_axes=('n',)) jaxpr = jax.make_jaxpr(f)(jnp.arange(3)) # doesn't crash # { lambda ; a:i32[3]. 
let - # b:i32[3] = xla_call[ - # call_jaxpr={ lambda ; c:i32[] d:i32[c]. let in (d,) } + # b:i32[3] = pjit[ + # jaxpr={ lambda ; c:i32[] d:i32[c]. let in (d,) } # name= # ] 3 a # in (b,) } @@ -546,8 +545,8 @@ def test_jit_abstracted_axes_return_polymorphic_shape2(self): with jax.enable_checks(False): jaxpr = jax.make_jaxpr(f)(3) # { lambda ; a:i32[]. let - # b:f32[a] = xla_call[ - # call_jaxpr={ lambda ; c:i32[]. let + # b:f32[a] = pjit[ + # jaxpr={ lambda ; c:i32[]. let # d:f32[c] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] 1.0 # c # in (d,) } @@ -565,8 +564,8 @@ def test_jit_abstracted_axes_return_polymorphic_shape2(self): with jax.enable_checks(False): jaxpr = jax.make_jaxpr(lambda: f(3))() # { lambda ; . let - # a:f32[3] = xla_call[ - # call_jaxpr={ lambda ; b:i32[]. let + # a:f32[3] = pjit[ + # jaxpr={ lambda ; b:i32[]. let # c:f32[b] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] 1.0 # b # in (c,) } @@ -583,214 +582,9 @@ def test_jit_abstracted_axes_return_polymorphic_shape2(self): self.assertIsInstance(three_, int) self.assertEqual(three_, 3) - @unittest.skipIf(jtu.device_under_test() != 'iree', "iree test") - def test_jit_basic_iree(self): - @jax.jit - def f(i): - return jnp.sum(jnp.ones(i, dtype='float32')) - self.assertAllClose(f(3), jnp.array(3., dtype='float32'), check_dtypes=True) - - @unittest.skipIf(jtu.device_under_test() != 'iree', "iree test") - def test_jit_basic_iree_2(self): - count = 0 - - @partial(jax.jit, abstracted_axes=('n',)) - def f(x): - nonlocal count - count += 1 - return jnp.sum(x) - - x = f(np.arange(3)) - y = f(np.arange(4)) - self.assertAllClose(x, 3., check_dtypes=False) - self.assertAllClose(y, 6., check_dtypes=False) - self.assertEqual(count, 1) - - @unittest.skipIf(jtu.device_under_test() != 'iree', "iree test") - def test_jit_polymorphic_output_iree(self): - # like test_jit_basic_iree, but without the jnp.sum! 
- count = 0 - - @jax.jit - def f(i): - nonlocal count - count += 1 - return jnp.ones(i, dtype='float32') - - self.assertAllClose(f(3), np.ones(3, dtype='float32'), check_dtypes=True) - self.assertAllClose(f(4), np.ones(4, dtype='float32'), check_dtypes=True) - self.assertEqual(count, 1) - - @unittest.skip('TODO: need typechecking rule for concatenate') - def test_concatenate(self): - @partial(jax.jit, abstracted_axes=({0: 'n'},)) - def f(x): # x: f32[n, 4] - return jnp.concatenate([x, x, x], axis=0) - - f(np.ones((5, 4), dtype=np.float32)) - # TODO: add assertions - - @unittest.skipIf(jtu.device_under_test() != 'iree', "iree test") - def test_reshape(self): - @partial(jax.jit, abstracted_axes=({0: 'n'},)) - def f(x): # x: f32[n, 4] - return jnp.reshape(x, (2, -1)) - - f(np.ones((5, 4), dtype=np.float32)) - # TODO: add assertions - - @unittest.skipIf(jtu.device_under_test() != 'iree', "iree test") - def test_nested(self): - @jax.jit - def nested_f(x): # f32[h, v] -> f32[h, v] - # A nested call that needs shape variables - return jnp.sin(x) - - @partial(jax.jit, abstracted_axes=({0: 'h', 1: 'v'},)) - def f(x): # f32[h, w] -> f32[h, w] - return jnp.sin(x) + jax.jit(nested_f)(x) - f(np.ones((3, 5), dtype=np.float32)) - # TODO: add assertions - - @unittest.skipIf(jtu.device_under_test() != 'iree', "iree test") - def test_nested_arange(self): - def nested_f(x): # f32[h, v] -> f32[h, v] - # A nested call that needs to compute with shapes - return jnp.arange(x.shape[0] * x.shape[1], dtype=x.dtype).reshape(x.shape) - - @partial(jax.jit, abstracted_axes=({0: 'h', 1: 'w'},)) - def f(x): # f32[h, w] -> f32[h, w] - return x + jax.jit(nested_f)(x) - f(np.ones((3, 5), dtype=np.float32)) - # TODO: add assertions - - @unittest.skipIf(jtu.device_under_test() != 'iree', 'iree test') - def test_transpose(self): - # see also https://github.com/iree-org/iree-jax/issues/57 - @partial(jax.jit, abstracted_axes=({0: 'h', 1: 'w'},)) - def f(x): # f32[h, w] -> f32[w, h] - return x.T - - 
f(np.ones((3, 5), dtype=np.float32)) # doesn't crash - # TODO: add assertions - - @unittest.skipIf(jtu.device_under_test() != 'iree', 'iree test') - def test_matmul(self): - @partial(jax.jit, abstracted_axes=({0: 'w', 1: 'w'},)) - def f(x): # f32[w, w] -> f32[w, w] - return jnp.matmul(x, x) - - f(np.ones((5, 5), dtype=np.float32)) - # TODO: add assertions - - @unittest.skipIf(jtu.device_under_test() != 'iree', 'iree test') - def test_matmul_shape_error(self): - @partial(jax.jit, abstracted_axes=({0: 'h', 1: 'w'},)) - def f(x): # f32[h, w] -> error - return jnp.matmul(x, x) - - # TODO(necula): improve error message, print actual shapes - with self.assertRaisesRegex(TypeError, - re.escape("dot_general requires contracting dimensions to have the same shape, got")): - f(np.ones((5, 5), dtype=np.float32)) - - @unittest.skipIf(jtu.device_under_test() != 'iree', "iree test") - @unittest.skip("TODO: investigate failure") - def test_cond(self): - @partial(jax.jit, abstracted_axes=({0: 'w', 1: 'w'},)) - def f(x): # f32[w, w] -> f32[w, w] - return lax.cond(True, - lambda x: jnp.sin(x), - lambda x: jnp.matmul(x, x), x) - f(np.ones((5, 5), dtype=np.float32)) - # TODO: add assertions - - @unittest.skipIf(jtu.device_under_test() != 'iree', "iree test") - def test_arange(self): - @partial(jax.jit, abstracted_axes=({0: 'w'},)) - def f(x): # f32[w] -> f32[w] - return jnp.arange(x.shape[0], dtype=x.dtype) + x - f(np.ones((5,), dtype=np.float32)) - # TODO: add assertions - - @unittest.skip('failing w/ iree error') - @unittest.skipIf(jtu.device_under_test() != 'iree', "iree test") - def test_broadcast(self): - @partial(jax.jit, abstracted_axes=({0: 'w'},)) - def f(x): # f32[w] -> f32[w, w] - return jnp.broadcast_to(x, (x.shape[0], x.shape[0])) - f(np.ones((5,), dtype=np.float32)) - # TODO: add assertions - - @unittest.skipIf(jtu.device_under_test() != 'iree', "iree test") - def test_zeros(self): - @partial(jax.jit, abstracted_axes=({0: 'w'},)) - def f(x): # f32[w] -> f32[w] - return 
jnp.zeros(x.shape[0], dtype=x.dtype) + x - f(np.ones((5,), dtype=np.float32)) - # TODO: add assertions - - @unittest.skip('failing w/ iree error') - @unittest.skipIf(jtu.device_under_test() != 'iree', "iree test") - def test_stack(self): - @partial(jax.jit, abstracted_axes=({0: 'w'},)) - def f(x): - return jnp.stack([jnp.sin(x), jnp.cos(x)]) - - f(np.ones((5,), dtype=np.float32)) - # TODO: add assertions - - @unittest.skipIf(jtu.device_under_test() != 'iree', "iree test") - def test_jit_dependent_pair_output_iree(self): - # Like the above 'polymorhpic output' test, but now with a `2 * n`! - count = 0 - - @jax.jit - def f(n): - nonlocal count - count += 1 - return jnp.arange(2 * n) - - x = f(3) - y = f(4) - self.assertAllClose(x, jnp.arange(2 * 3), check_dtypes=False) - self.assertAllClose(y, jnp.arange(2 * 4), check_dtypes=False) - self.assertEqual(count, 1) - - @unittest.skip("revising slicing logic") - def test_slicing_basic(self): - f = jax.jit(lambda x, n: jnp.sum(x[:n])) - # TODO(mattjj): revise getslice, add typecheck rule for it, enable checks - with jax.enable_checks(False): - ans = f(jnp.arange(10), 3) - expected = jnp.sum(jnp.arange(10)[:3]) - self.assertAllClose(ans, expected, check_dtypes=True) - - # TODO(mattjj,dougalm,phawkins): debug iree failure, "failed to legalize - # operation 'while' that was explicitly marked illegal" - @unittest.skip("revising slicing logic") - def test_scan_basic(self): - def cumsum(x): - def body(i, _): - return i + 1, jnp.sum(x[:i+1]) - _, ans = lax.scan(body, 0, None, length=len(x)) - return ans - x = jnp.array([3, 1, 4, 1, 5, 9]) - with jax.enable_checks(False): - ans = cumsum(x) - expected = jnp.cumsum(x) - self.assertAllClose(ans, expected, check_dtypes=False) - - @unittest.skipIf(jtu.device_under_test() != 'iree', "iree test") - def test_jit_of_broadcast(self): - x = jax.jit(jnp.ones)(3) - self.assertAllClose(x, jnp.ones(3)) - - @unittest.skipIf(jtu.device_under_test() != 'iree', "iree test") - def 
test_jit_of_broadcast2(self): - x = jax.jit(lambda n: jnp.ones(2 * n))(3) - self.assertAllClose(x, jnp.ones(2 * 3)) - +@unittest.skip("Test does not work with jax.Array") +@jtu.with_config(jax_dynamic_shapes=True, jax_numpy_rank_promotion="allow") +class DynamicShapeAutodiffTest(jtu.JaxTestCase): def test_jvp_broadcast(self): @jax.jit def fn(n, x): @@ -800,8 +594,8 @@ def fn(n, x): lambda x, t: jax.jvp(lambda y: fn(3, y), (x,), (t,)) )(3., 4.) # { lambda ; a:f32[] b:f32[]. let - # c:f32[3] d:f32[3] = xla_call[ - # call_jaxpr={ lambda ; e:i32[] f:f32[] g:f32[]. let + # c:f32[3] d:f32[3] = pjit[ + # jaxpr={ lambda ; e:i32[] f:f32[] g:f32[]. let # h:f32[e] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] f e # i:f32[e] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] g e # in (h, i) } @@ -810,8 +604,8 @@ def fn(n, x): # in (c, d) } self.assertLen(outer_jaxpr.jaxpr.eqns, 1) eqn, = outer_jaxpr.jaxpr.eqns - self.assertIn('call_jaxpr', eqn.params) - jaxpr = eqn.params['call_jaxpr'] + self.assertIn('jaxpr', eqn.params) + jaxpr = eqn.params['jaxpr'].jaxpr self.assertLen(jaxpr.invars, 3) e, f, g = jaxpr.invars self.assertEqual(e.aval.shape, ()) @@ -834,8 +628,8 @@ def foo(x): x = t = jnp.arange(3.) outer_jaxpr = jax.make_jaxpr(lambda x, t: jax.jvp(foo, (x,), (t,)))(x, t) # { lambda ; a:f32[3] b:f32[3]. let - # c:f32[3] d:f32[3] = xla_call[ - # call_jaxpr={ lambda ; e:i32[] f:f32[e] g:f32[e]. let + # c:f32[3] d:f32[3] = pjit[ + # jaxpr={ lambda ; e:i32[] f:f32[e] g:f32[e]. 
let # h:f32[e] = sin f # i:f32[e] = cos f # j:f32[e] = mul g i @@ -845,8 +639,8 @@ def foo(x): # in (c, d) } self.assertLen(outer_jaxpr.jaxpr.eqns, 1) eqn, = outer_jaxpr.eqns - self.assertIn('call_jaxpr', eqn.params) - jaxpr = eqn.params['call_jaxpr'] + self.assertIn('jaxpr', eqn.params) + jaxpr = eqn.params['jaxpr'].jaxpr self.assertLen(jaxpr.invars, 3) e, f, g = jaxpr.invars self.assertEqual(e.aval.shape, ()) @@ -868,8 +662,8 @@ def foo(x): # primal computation outer_jaxpr = jax.make_jaxpr(lambda x: jax.linearize(foo, x))(x) # { lambda ; a:f32[3]. let - # b:f32[3] c:f32[3] = xla_call[ - # call_jaxpr={ lambda ; d:i32[] e:f32[d]. let + # b:f32[3] c:f32[3] = pjit[ + # jaxpr={ lambda ; d:i32[] e:f32[d]. let # f:f32[d] = sin e # g:f32[d] = cos e # in (f, g) } @@ -878,8 +672,8 @@ def foo(x): # in (b, c) } self.assertLen(outer_jaxpr.jaxpr.eqns, 1) eqn, = outer_jaxpr.jaxpr.eqns - self.assertIn('call_jaxpr', eqn.params) - jaxpr = eqn.params['call_jaxpr'] + self.assertIn('jaxpr', eqn.params) + jaxpr = eqn.params['jaxpr'].jaxpr self.assertLen(jaxpr.invars, 2) d, e = jaxpr.invars self.assertEqual(d.aval.shape, ()) @@ -898,15 +692,15 @@ def foo(x): outer_jaxpr = jax.make_jaxpr( lambda x, xdot: jax.linearize(foo, x)[1](xdot))(x, x) # { lambda ; a:f32[3] b:f32[3]. let - # _:f32[3] c:f32[3] = xla_call[ - # call_jaxpr={ lambda ; d:i32[] e:f32[d]. let + # _:f32[3] c:f32[3] = pjit[ + # jaxpr={ lambda ; d:i32[] e:f32[d]. let # f:f32[d] = sin e # g:f32[d] = cos e # in (f, g) } # name=foo # ] 3 a - # h:f32[3] = xla_call[ - # call_jaxpr={ lambda ; i:i32[] j:f32[i] k:f32[i]. let + # h:f32[3] = pjit[ + # jaxpr={ lambda ; i:i32[] j:f32[i] k:f32[i]. 
let # l:f32[i] = mul k j # in (l,) } # name=foo @@ -914,8 +708,8 @@ def foo(x): # in (h,) } self.assertLen(outer_jaxpr.jaxpr.eqns, 2) _, eqn = outer_jaxpr.jaxpr.eqns - self.assertIn('call_jaxpr', eqn.params) - jaxpr = eqn.params['call_jaxpr'] + self.assertIn('jaxpr', eqn.params) + jaxpr = eqn.params['jaxpr'].jaxpr self.assertLen(jaxpr.invars, 3) i, j, k = jaxpr.invars self.assertEqual(i.aval.shape, ()) @@ -933,10 +727,10 @@ def foo(x): x = jnp.arange(3.) outer_jaxpr = jax.make_jaxpr(lambda x: jax.linearize(foo, x))(x) # { lambda ; a:f32[3]. let - # b:f32[3] c:f32[3] = xla_call[ - # call_jaxpr={ lambda ; d:i32[] e:f32[d]. let - # f:f32[d] g:f32[d] = xla_call[ - # call_jaxpr={ lambda ; h:i32[] i:f32[h]. let + # b:f32[3] c:f32[3] = pjit[ + # jaxpr={ lambda ; d:i32[] e:f32[d]. let + # f:f32[d] g:f32[d] = pjit[ + # jaxpr={ lambda ; h:i32[] i:f32[h]. let # j:f32[h] = sin i # k:f32[h] = cos i # in (j, k) } @@ -962,16 +756,16 @@ def foo(x): x = jnp.arange(3.) outer_jaxpr = jax.make_jaxpr(jax.grad(foo))(x) # { lambda ; a:f32[3]. let - # _:f32[] b:f32[3] = xla_call[ - # call_jaxpr={ lambda ; c:i32[] d:f32[c]. let + # _:f32[] b:f32[3] = pjit[ + # jaxpr={ lambda ; c:i32[] d:f32[c]. let # e:f32[c] = sin d # f:f32[c] = cos d # g:f32[] = reduce_sum[axes=(0,)] e # in (g, f) } # name=foo # ] 3 a - # h:f32[3] = xla_call[ - # call_jaxpr={ lambda ; i:i32[] j:f32[i] k:f32[]. let + # h:f32[3] = pjit[ + # jaxpr={ lambda ; i:i32[] j:f32[i] k:f32[]. 
let # l:f32[i] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] k i # m:f32[i] = mul l j # in (m,) } @@ -980,8 +774,8 @@ def foo(x): # in (h,) } self.assertLen(outer_jaxpr.jaxpr.eqns, 2) fwd_eqn, bwd_eqn = outer_jaxpr.jaxpr.eqns - self.assertIn('call_jaxpr', fwd_eqn.params) - fwd_jaxpr = fwd_eqn.params['call_jaxpr'] + self.assertIn('jaxpr', fwd_eqn.params) + fwd_jaxpr = fwd_eqn.params['jaxpr'].jaxpr self.assertLen(fwd_jaxpr.invars, 2) c, d = fwd_jaxpr.invars self.assertEqual(c.aval.shape, ()) @@ -993,8 +787,8 @@ def foo(x): self.assertLen(fwd_eqn.outvars, 2) _, b = fwd_eqn.outvars self.assertEqual(b.aval.shape, (3,)) - self.assertIn('call_jaxpr', bwd_eqn.params) - bwd_jaxpr = bwd_eqn.params['call_jaxpr'] + self.assertIn('jaxpr', bwd_eqn.params) + bwd_jaxpr = bwd_eqn.params['jaxpr'].jaxpr self.assertLen(bwd_jaxpr.invars, 3) i, j, k = bwd_jaxpr.invars self.assertEqual(i.aval.shape, ()) @@ -1081,27 +875,391 @@ def loss_lin(params, batch): jaxpr = jax.make_jaxpr(jax.grad(loss))(params, batch) core.check_jaxpr(jaxpr.jaxpr) - @unittest.skipIf(jtu.device_under_test() != 'iree', "iree test") - def test_mlp_autodiff_dynamic_batch_iree(self): - count = 0 - - def predict(params, inputs): - for W, b in params: - outputs = jnp.dot(inputs, W) + b - inputs = jnp.maximum(0, outputs) - return outputs + def test_bint_broadcast(self): + d = lax.convert_element_type(3, core.bint(5)) + bint = lambda x, b: lax.convert_element_type(x, core.bint(b)) - def loss_ref(params, batch): - nonlocal count - count += 1 # count retraces - inputs, targets = batch - predictions = predict(params, inputs) - return jnp.sum((predictions - targets) ** 2) + x = lax.broadcast_in_dim(0, (d,), ()) # doesn't crash + self.assertIsInstance(x, core.DArray) + self.assertAllClose(x._data, np.zeros(5, dtype='int32'), check_dtypes=False) + self.assertEqual( + x._aval, core.DShapedArray((bint(3, 5),), x._data.dtype, True)) - loss = jax.jit(loss_ref, abstracted_axes=({}, {0: 'n'})) + def f(n): + return 
jnp.zeros(n) + x = jax.jit(f)(d) + self.assertIsInstance(x, core.DArray) + self.assertAllClose(x._data, np.zeros(5, dtype='int32'), check_dtypes=False) + self.assertEqual( + x._aval, core.DShapedArray((bint(3, 5),), x._data.dtype, False)) - params = [(jnp.ones((784, 256)), jnp.ones(256)), - (jnp.ones((256, 10)), jnp.ones( 10))] + jaxpr = jax.make_jaxpr(f)(d).jaxpr + # { lambda ; a:bint{≤5}[]. let + # b:f32[a] = broadcast_in_dim[...] 0.0 a + # in (b,) } + self.assertLen(jaxpr.invars, 1) + a, = jaxpr.invars + self.assertEqual(a.aval, core.DShapedArray((), core.bint(5))) + self.assertLen(jaxpr.eqns, 1) + eqn, = jaxpr.eqns + self.assertLen(eqn.outvars, 1) + b, = eqn.outvars + self.assertEqual(b.aval.shape, (a,)) + + def test_vmap_abstracted_axis(self): + def foo(x, y): + z = jax.vmap(jnp.sin)(x) * y + return jax.vmap(jnp.add)(x, z) + + x = jnp.arange(3.) + jaxpr = jax.make_jaxpr(foo, abstracted_axes=('n',))(x, x).jaxpr + self.assertLen(jaxpr.invars, 3) + a, b, c = jaxpr.invars + self.assertEqual(a.aval.shape, ()) + self.assertEqual(b.aval.shape, (a,)) + self.assertEqual(c.aval.shape, (a,)) + self.assertLen(jaxpr.eqns, 3) + self.assertLen(jaxpr.outvars, 1) + f, = jaxpr.outvars + self.assertEqual(f.aval.shape, (a,)) + + def test_vmap_abstracted_axes_2d(self): + def foo(x, y): + z = jax.vmap(jax.vmap(jnp.sin))(x) * y + return jax.vmap(jax.vmap(jnp.add))(x, z) + + x = jnp.arange(12.).reshape(3, 4) + jaxpr = jax.make_jaxpr(foo, abstracted_axes=('n', 'm'))(x, x).jaxpr + self.assertLen(jaxpr.invars, 4) + a, b, c, d = jaxpr.invars + self.assertEqual(a.aval.shape, ()) + self.assertEqual(b.aval.shape, ()) + self.assertEqual(c.aval.shape, (a, b)) + self.assertEqual(c.aval.shape, (a, b)) + self.assertLen(jaxpr.eqns, 3) + self.assertLen(jaxpr.outvars, 1) + f, = jaxpr.outvars + self.assertEqual(f.aval.shape, (a, b)) + + def test_vmap_of_indexing_basic(self): + x = jnp.arange(3.) 
+ + def f(idxs): + return jax.vmap(lambda i: x[i])(idxs) + + idxs = jnp.arange(3) + jaxpr = jax.make_jaxpr(f, abstracted_axes=('n',))(idxs).jaxpr + # { lambda a:f32[3]; b:i32[] c:i32[b]. let + # d:bool[b] = lt c 0 + # e:i32[b] = add c 3 + # f:i32[b] = select_n d c e + # g:i32[b,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(None, 1)] f b + # h:f32[b,1] = gather[ + # dimension_numbers=GatherDimensionNumbers(offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)) + # fill_value=None + # indices_are_sorted=False + # mode=GatherScatterMode.PROMISE_IN_BOUNDS + # slice_sizes=(1,) + # unique_indices=False + # ] a g + # i:f32[b] = squeeze[dimensions=(1,)] h + # in (i,) } + b, _ = jaxpr.invars + e, = (e for e in jaxpr.eqns if str(e.primitive) == 'gather') + h, = e.outvars + self.assertEqual(h.aval.shape, (b, 1)) + + def test_einsum_basic(self): + x = jnp.arange(20.).reshape(4, 5) + + def f(x): + return jnp.einsum('ij,kj->ik', x, x) + + jaxpr = jax.make_jaxpr(f, abstracted_axes=('n', 'm'))(x).jaxpr + # { lambda ; a:i32[] b:i32[] c:f32[a,b]. let + # d:f32[a,a] = pjit[ + # jaxpr={ lambda ; e:i32[] f:i32[] g:f32[e,f] h:f32[e,f]. let + # i:f32[e,e] = dot_general[ + # dimension_numbers=(((1,), (1,)), ((), ())) + # precision=None + # preferred_element_type=None + # ] g h + # in (i,) } + # name=_einsum + # ] a b c c + # in (d,) } + self.assertLen(jaxpr.invars, 3) + a, b, c = jaxpr.invars + self.assertEqual(c.aval.shape[0], a) + self.assertLen(jaxpr.eqns, 1) + self.assertLen(jaxpr.eqns[0].outvars, 1) + d, = jaxpr.eqns[0].outvars + self.assertEqual(d.aval.shape, (a, a)) + + def test_inferring_valid_subjaxpr_type_add(self): + def f(x): + return x + x.shape[0] + + jax.make_jaxpr(f, abstracted_axes=('n',))(jnp.arange(3)) # doesn't crash + + def test_slicing_basic_jaxpr(self): + def f(x): + return x[0] + + jaxpr = jax.make_jaxpr(f, abstracted_axes=(None, 'n'))(jnp.zeros((3, 4))) + # { lambda ; a:i32[] b:f32[3,a]. 
let + # c:f32[1,a] = dynamic_slice[slice_sizes=(1, None)] b 0 0 a + # d:f32[a] = squeeze[dimensions=(0,)] c + # in (d,) } + self.assertLen(jaxpr.jaxpr.invars, 2) + a, _ = jaxpr.jaxpr.invars + self.assertLen(jaxpr.jaxpr.outvars, 1) + d, = jaxpr.jaxpr.outvars + self.assertLen(d.aval.shape, 1) + self.assertEqual(d.aval.shape, (a,)) + + def test_shape_tuple_argument_to_zeros(self): + @partial(jax.jit, abstracted_axes=(('n',), ('n',))) + def f(x, y): + zero = jnp.zeros(jnp.shape(x)) + return zero * y + + x = jnp.arange(3.0) + y = jnp.arange(3.0) + 1 + jax.make_jaxpr(f)(x, y) # doesn't crash + +@unittest.skip("Test does not work with jax.Array") +@jtu.with_config(jax_dynamic_shapes=True, jax_numpy_rank_promotion="allow") +class DynamicShapeExecutionTest(jtu.JaxTestCase): + @unittest.skipIf(jtu.device_under_test() != 'iree', "iree test") + def test_jit_basic_iree(self): + @jax.jit + def f(i): + return jnp.sum(jnp.ones(i, dtype='float32')) + self.assertAllClose(f(3), jnp.array(3., dtype='float32'), check_dtypes=True) + + @unittest.skipIf(jtu.device_under_test() != 'iree', "iree test") + def test_jit_basic_iree_2(self): + count = 0 + + @partial(jax.jit, abstracted_axes=('n',)) + def f(x): + nonlocal count + count += 1 + return jnp.sum(x) + + x = f(np.arange(3)) + y = f(np.arange(4)) + self.assertAllClose(x, 3., check_dtypes=False) + self.assertAllClose(y, 6., check_dtypes=False) + self.assertEqual(count, 1) + + @unittest.skipIf(jtu.device_under_test() != 'iree', "iree test") + def test_jit_polymorphic_output_iree(self): + # like test_jit_basic_iree, but without the jnp.sum! 
+ count = 0 + + @jax.jit + def f(i): + nonlocal count + count += 1 + return jnp.ones(i, dtype='float32') + + self.assertAllClose(f(3), np.ones(3, dtype='float32'), check_dtypes=True) + self.assertAllClose(f(4), np.ones(4, dtype='float32'), check_dtypes=True) + self.assertEqual(count, 1) + + @unittest.skip('TODO: need typechecking rule for concatenate') + def test_concatenate(self): + @partial(jax.jit, abstracted_axes=({0: 'n'},)) + def f(x): # x: f32[n, 4] + return jnp.concatenate([x, x, x], axis=0) + + f(np.ones((5, 4), dtype=np.float32)) + # TODO: add assertions + + @unittest.skipIf(jtu.device_under_test() != 'iree', "iree test") + def test_reshape(self): + @partial(jax.jit, abstracted_axes=({0: 'n'},)) + def f(x): # x: f32[n, 4] + return jnp.reshape(x, (2, -1)) + + f(np.ones((5, 4), dtype=np.float32)) + # TODO: add assertions + + @unittest.skipIf(jtu.device_under_test() != 'iree', "iree test") + def test_nested(self): + @jax.jit + def nested_f(x): # f32[h, v] -> f32[h, v] + # A nested call that needs shape variables + return jnp.sin(x) + + @partial(jax.jit, abstracted_axes=({0: 'h', 1: 'v'},)) + def f(x): # f32[h, w] -> f32[h, w] + return jnp.sin(x) + jax.jit(nested_f)(x) + f(np.ones((3, 5), dtype=np.float32)) + # TODO: add assertions + + @unittest.skipIf(jtu.device_under_test() != 'iree', "iree test") + def test_nested_arange(self): + def nested_f(x): # f32[h, v] -> f32[h, v] + # A nested call that needs to compute with shapes + return jnp.arange(x.shape[0] * x.shape[1], dtype=x.dtype).reshape(x.shape) + + @partial(jax.jit, abstracted_axes=({0: 'h', 1: 'w'},)) + def f(x): # f32[h, w] -> f32[h, w] + return x + jax.jit(nested_f)(x) + f(np.ones((3, 5), dtype=np.float32)) + # TODO: add assertions + + @unittest.skipIf(jtu.device_under_test() != 'iree', 'iree test') + def test_transpose(self): + # see also https://github.com/iree-org/iree-jax/issues/57 + @partial(jax.jit, abstracted_axes=({0: 'h', 1: 'w'},)) + def f(x): # f32[h, w] -> f32[w, h] + return x.T + + 
f(np.ones((3, 5), dtype=np.float32)) # doesn't crash + # TODO: add assertions + + @unittest.skipIf(jtu.device_under_test() != 'iree', 'iree test') + def test_matmul(self): + @partial(jax.jit, abstracted_axes=({0: 'w', 1: 'w'},)) + def f(x): # f32[w, w] -> f32[w, w] + return jnp.matmul(x, x) + + f(np.ones((5, 5), dtype=np.float32)) + # TODO: add assertions + + @unittest.skipIf(jtu.device_under_test() != 'iree', 'iree test') + def test_matmul_shape_error(self): + @partial(jax.jit, abstracted_axes=({0: 'h', 1: 'w'},)) + def f(x): # f32[h, w] -> error + return jnp.matmul(x, x) + + # TODO(necula): improve error message, print actual shapes + with self.assertRaisesRegex(TypeError, + re.escape("dot_general requires contracting dimensions to have the same shape, got")): + f(np.ones((5, 5), dtype=np.float32)) + + @unittest.skipIf(jtu.device_under_test() != 'iree', "iree test") + @unittest.skip("TODO: investigate failure") + def test_cond(self): + @partial(jax.jit, abstracted_axes=({0: 'w', 1: 'w'},)) + def f(x): # f32[w, w] -> f32[w, w] + return lax.cond(True, + lambda x: jnp.sin(x), + lambda x: jnp.matmul(x, x), x) + f(np.ones((5, 5), dtype=np.float32)) + # TODO: add assertions + + @unittest.skipIf(jtu.device_under_test() != 'iree', "iree test") + def test_arange(self): + @partial(jax.jit, abstracted_axes=({0: 'w'},)) + def f(x): # f32[w] -> f32[w] + return jnp.arange(x.shape[0], dtype=x.dtype) + x + f(np.ones((5,), dtype=np.float32)) + # TODO: add assertions + + @unittest.skip('failing w/ iree error') + @unittest.skipIf(jtu.device_under_test() != 'iree', "iree test") + def test_broadcast(self): + @partial(jax.jit, abstracted_axes=({0: 'w'},)) + def f(x): # f32[w] -> f32[w, w] + return jnp.broadcast_to(x, (x.shape[0], x.shape[0])) + f(np.ones((5,), dtype=np.float32)) + # TODO: add assertions + + @unittest.skipIf(jtu.device_under_test() != 'iree', "iree test") + def test_zeros(self): + @partial(jax.jit, abstracted_axes=({0: 'w'},)) + def f(x): # f32[w] -> f32[w] + return 
jnp.zeros(x.shape[0], dtype=x.dtype) + x + f(np.ones((5,), dtype=np.float32)) + # TODO: add assertions + + @unittest.skip('failing w/ iree error') + @unittest.skipIf(jtu.device_under_test() != 'iree', "iree test") + def test_stack(self): + @partial(jax.jit, abstracted_axes=({0: 'w'},)) + def f(x): + return jnp.stack([jnp.sin(x), jnp.cos(x)]) + + f(np.ones((5,), dtype=np.float32)) + # TODO: add assertions + + @unittest.skipIf(jtu.device_under_test() != 'iree', "iree test") + def test_jit_dependent_pair_output_iree(self): + # Like the above 'polymorhpic output' test, but now with a `2 * n`! + count = 0 + + @jax.jit + def f(n): + nonlocal count + count += 1 + return jnp.arange(2 * n) + + x = f(3) + y = f(4) + self.assertAllClose(x, jnp.arange(2 * 3), check_dtypes=False) + self.assertAllClose(y, jnp.arange(2 * 4), check_dtypes=False) + self.assertEqual(count, 1) + + @unittest.skip("revising slicing logic") + def test_slicing_basic(self): + f = jax.jit(lambda x, n: jnp.sum(x[:n])) + # TODO(mattjj): revise getslice, add typecheck rule for it, enable checks + with jax.enable_checks(False): + ans = f(jnp.arange(10), 3) + expected = jnp.sum(jnp.arange(10)[:3]) + self.assertAllClose(ans, expected, check_dtypes=True) + + # TODO(mattjj,dougalm,phawkins): debug iree failure, "failed to legalize + # operation 'while' that was explicitly marked illegal" + @unittest.skip("revising slicing logic") + def test_scan_basic(self): + def cumsum(x): + def body(i, _): + return i + 1, jnp.sum(x[:i+1]) + _, ans = lax.scan(body, 0, None, length=len(x)) + return ans + x = jnp.array([3, 1, 4, 1, 5, 9]) + with jax.enable_checks(False): + ans = cumsum(x) + expected = jnp.cumsum(x) + self.assertAllClose(ans, expected, check_dtypes=False) + + @unittest.skipIf(jtu.device_under_test() != 'iree', "iree test") + def test_jit_of_broadcast(self): + x = jax.jit(jnp.ones)(3) + self.assertAllClose(x, jnp.ones(3)) + + @unittest.skipIf(jtu.device_under_test() != 'iree', "iree test") + def 
test_jit_of_broadcast2(self): + x = jax.jit(lambda n: jnp.ones(2 * n))(3) + self.assertAllClose(x, jnp.ones(2 * 3)) + + @unittest.skipIf(jtu.device_under_test() != 'iree', "iree test") + def test_mlp_autodiff_dynamic_batch_iree(self): + count = 0 + + def predict(params, inputs): + for W, b in params: + outputs = jnp.dot(inputs, W) + b + inputs = jnp.maximum(0, outputs) + return outputs + + def loss_ref(params, batch): + nonlocal count + count += 1 # count retraces + inputs, targets = batch + predictions = predict(params, inputs) + return jnp.sum((predictions - targets) ** 2) + + loss = jax.jit(loss_ref, abstracted_axes=({}, {0: 'n'})) + + params = [(jnp.ones((784, 256)), jnp.ones(256)), + (jnp.ones((256, 10)), jnp.ones( 10))] # two different size batches batch1 = (inputs, targets) = (jnp.ones((128, 784)), jnp.ones((128, 10))) @@ -1175,37 +1333,6 @@ def f(d): return d f(d) # doesn't crash - def test_bint_broadcast(self): - d = lax.convert_element_type(3, core.bint(5)) - bint = lambda x, b: lax.convert_element_type(x, core.bint(b)) - - x = lax.broadcast_in_dim(0, (d,), ()) # doesn't crash - self.assertIsInstance(x, core.DArray) - self.assertAllClose(x._data, np.zeros(5, dtype='int32'), check_dtypes=False) - self.assertEqual( - x._aval, core.DShapedArray((bint(3, 5),), x._data.dtype, True)) - - def f(n): - return jnp.zeros(n) - x = jax.jit(f)(d) - self.assertIsInstance(x, core.DArray) - self.assertAllClose(x._data, np.zeros(5, dtype='int32'), check_dtypes=False) - self.assertEqual( - x._aval, core.DShapedArray((bint(3, 5),), x._data.dtype, False)) - - jaxpr = jax.make_jaxpr(f)(d).jaxpr - # { lambda ; a:bint{≤5}[]. let - # b:f32[a] = broadcast_in_dim[...] 
0.0 a - # in (b,) } - self.assertLen(jaxpr.invars, 1) - a, = jaxpr.invars - self.assertEqual(a.aval, core.DShapedArray((), core.bint(5))) - self.assertLen(jaxpr.eqns, 1) - eqn, = jaxpr.eqns - self.assertLen(eqn.outvars, 1) - b, = eqn.outvars - self.assertEqual(b.aval.shape, (a,)) - def test_bint_iota(self): def f(d): return jnp.arange(d, dtype='int32') @@ -1287,118 +1414,6 @@ def f(x): mlir_str = f_lowered.compiler_ir() self.assertIn('tensor', str(mlir_str)) - def test_vmap_abstracted_axis(self): - def foo(x, y): - z = jax.vmap(jnp.sin)(x) * y - return jax.vmap(jnp.add)(x, z) - - x = jnp.arange(3.) - jaxpr = jax.make_jaxpr(foo, abstracted_axes=('n',))(x, x).jaxpr - self.assertLen(jaxpr.invars, 3) - a, b, c = jaxpr.invars - self.assertEqual(a.aval.shape, ()) - self.assertEqual(b.aval.shape, (a,)) - self.assertEqual(c.aval.shape, (a,)) - self.assertLen(jaxpr.eqns, 3) - self.assertLen(jaxpr.outvars, 1) - f, = jaxpr.outvars - self.assertEqual(f.aval.shape, (a,)) - - def test_vmap_abstracted_axes_2d(self): - def foo(x, y): - z = jax.vmap(jax.vmap(jnp.sin))(x) * y - return jax.vmap(jax.vmap(jnp.add))(x, z) - - x = jnp.arange(12.).reshape(3, 4) - jaxpr = jax.make_jaxpr(foo, abstracted_axes=('n', 'm'))(x, x).jaxpr - self.assertLen(jaxpr.invars, 4) - a, b, c, d = jaxpr.invars - self.assertEqual(a.aval.shape, ()) - self.assertEqual(b.aval.shape, ()) - self.assertEqual(c.aval.shape, (a, b)) - self.assertEqual(c.aval.shape, (a, b)) - self.assertLen(jaxpr.eqns, 3) - self.assertLen(jaxpr.outvars, 1) - f, = jaxpr.outvars - self.assertEqual(f.aval.shape, (a, b)) - - def test_vmap_of_indexing_basic(self): - x = jnp.arange(3.) - - def f(idxs): - return jax.vmap(lambda i: x[i])(idxs) - - idxs = jnp.arange(3) - jaxpr = jax.make_jaxpr(f, abstracted_axes=('n',))(idxs).jaxpr - # { lambda a:f32[3]; b:i32[] c:i32[b]. 
let - # d:bool[b] = lt c 0 - # e:i32[b] = add c 3 - # f:i32[b] = select_n d c e - # g:i32[b,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(None, 1)] f b - # h:f32[b,1] = gather[ - # dimension_numbers=GatherDimensionNumbers(offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)) - # fill_value=None - # indices_are_sorted=False - # mode=GatherScatterMode.PROMISE_IN_BOUNDS - # slice_sizes=(1,) - # unique_indices=False - # ] a g - # i:f32[b] = squeeze[dimensions=(1,)] h - # in (i,) } - b, _ = jaxpr.invars - e, = (e for e in jaxpr.eqns if str(e.primitive) == 'gather') - h, = e.outvars - self.assertEqual(h.aval.shape, (b, 1)) - - def test_einsum_basic(self): - x = jnp.arange(20.).reshape(4, 5) - - def f(x): - return jnp.einsum('ij,kj->ik', x, x) - - jaxpr = jax.make_jaxpr(f, abstracted_axes=('n', 'm'))(x).jaxpr - # { lambda ; a:i32[] b:i32[] c:f32[a,b]. let - # d:f32[a,a] = xla_call[ - # call_jaxpr={ lambda ; e:i32[] f:i32[] g:f32[e,f] h:f32[e,f]. let - # i:f32[e,e] = dot_general[ - # dimension_numbers=(((1,), (1,)), ((), ())) - # precision=None - # preferred_element_type=None - # ] g h - # in (i,) } - # name=_einsum - # ] a b c c - # in (d,) } - self.assertLen(jaxpr.invars, 3) - a, b, c = jaxpr.invars - self.assertEqual(c.aval.shape[0], a) - self.assertLen(jaxpr.eqns, 1) - self.assertLen(jaxpr.eqns[0].outvars, 1) - d, = jaxpr.eqns[0].outvars - self.assertEqual(d.aval.shape, (a, a)) - - def test_inferring_valid_subjaxpr_type_add(self): - def f(x): - return x + x.shape[0] - - jax.make_jaxpr(f, abstracted_axes=('n',))(jnp.arange(3)) # doesn't crash - - def test_slicing_basic_jaxpr(self): - def f(x): - return x[0] - - jaxpr = jax.make_jaxpr(f, abstracted_axes=(None, 'n'))(jnp.zeros((3, 4))) - # { lambda ; a:i32[] b:f32[3,a]. 
let - # c:f32[1,a] = dynamic_slice[slice_sizes=(1, None)] b 0 0 a - # d:f32[a] = squeeze[dimensions=(0,)] c - # in (d,) } - self.assertLen(jaxpr.jaxpr.invars, 2) - a, _ = jaxpr.jaxpr.invars - self.assertLen(jaxpr.jaxpr.outvars, 1) - d, = jaxpr.jaxpr.outvars - self.assertLen(d.aval.shape, 1) - self.assertEqual(d.aval.shape, (a,)) - def test_slicing_basic_lower(self): @partial(jax.jit, abstracted_axes=(None, 'n')) def f(x): @@ -1428,16 +1443,6 @@ def f(i): self.assertEqual(y.shape, (sz, 4)) self.assertAllClose(y._data, x) - def test_shape_tuple_argument_to_zeros(self): - @partial(jax.jit, abstracted_axes=(('n',), ('n',))) - def f(x, y): - zero = jnp.zeros(jnp.shape(x)) - return zero * y - - x = jnp.arange(3.0) - y = jnp.arange(3.0) + 1 - jax.make_jaxpr(f)(x, y) # doesn't crash - @unittest.skip("Test does not work with jax.Array") @jtu.with_config(jax_dynamic_shapes=True, jax_numpy_rank_promotion="allow") class PileTest(jtu.JaxTestCase): From 10aeadba266035714fd825764d7e3fc3e68e3ed6 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Thu, 23 Mar 2023 20:16:23 -0700 Subject: [PATCH 54/65] fix jax.Array.round() fixes #15190 --- jax/_src/numpy/array_methods.py | 2 +- tests/lax_numpy_test.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/jax/_src/numpy/array_methods.py b/jax/_src/numpy/array_methods.py index 364d2f72875b..576236d0e144 100644 --- a/jax/_src/numpy/array_methods.py +++ b/jax/_src/numpy/array_methods.py @@ -725,7 +725,7 @@ def max(self, values, *, indices_are_sorted=False, unique_indices=False, "ravel": lax_numpy.ravel, "repeat": lax_numpy.repeat, "reshape": _reshape, - "round": round, + "round": lax_numpy.round, "searchsorted": lax_numpy.searchsorted, "sort": lax_numpy.sort, "squeeze": lax_numpy.squeeze, diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index b559e78a7513..c68162d704c0 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -858,6 +858,10 @@ def testOperatorRound(self, jit): 
jround(jnp.array(1.234, jnp.float32)), check_dtypes=False) + def testRoundMethod(self): + # https://github.com/google/jax/issues/15190 + (jnp.arange(3.) / 5.).round() # doesn't crash + @jtu.sample_product(shape=[(5,), (5, 2)]) def testOperatorReversed(self, shape): rng = jtu.rand_default(self.rng()) From 9d60043506e8d39fb75aee64d28f0dcca5010824 Mon Sep 17 00:00:00 2001 From: Anish Tondwalkar Date: Fri, 24 Mar 2023 05:42:01 -0700 Subject: [PATCH 55/65] [jaxlib] fix build w/ depenency on stablehlo_serialization PiperOrigin-RevId: 519120624 --- jaxlib/mlir/_mlir_libs/BUILD.bazel | 1 + 1 file changed, 1 insertion(+) diff --git a/jaxlib/mlir/_mlir_libs/BUILD.bazel b/jaxlib/mlir/_mlir_libs/BUILD.bazel index 9bfb97424083..de9e166d6e45 100644 --- a/jaxlib/mlir/_mlir_libs/BUILD.bazel +++ b/jaxlib/mlir/_mlir_libs/BUILD.bazel @@ -171,6 +171,7 @@ py_extension( "@local_config_python//:headers", "@pybind11", "@stablehlo//:stablehlo_capi_headers", + "@stablehlo//:stablehlo_serialization", ], ) From a88509a844fa1115d2d6b42a475bfc9be73ae93a Mon Sep 17 00:00:00 2001 From: Anish Tondwalkar Date: Fri, 24 Mar 2023 05:57:59 -0700 Subject: [PATCH 56/65] Migrate igammac_p off xla_fallback path It is now decomposed into stablehlo ops. 
PiperOrigin-RevId: 519122775 --- jax/_src/lax/special.py | 149 +++++++++++++++++++++++++++++- tests/filecheck/math.filecheck.py | 5 - 2 files changed, 144 insertions(+), 10 deletions(-) diff --git a/jax/_src/lax/special.py b/jax/_src/lax/special.py index 804adaf7703e..83e30b91e845 100644 --- a/jax/_src/lax/special.py +++ b/jax/_src/lax/special.py @@ -26,7 +26,7 @@ convert_element_type, eq, exp, full_like, gt, le, log, log1p, lt, mul, neg, reciprocal, reduce, select, sign, square, standard_naryop, - standard_unop, xla, xops, + standard_unop, xla, xops, ne, div, sub, add, _broadcast_translate, _const, _dtype, _float, _nary_lower_hlo, _ones, _isnan, _reduce) from jax._src.lax.control_flow import while_loop @@ -234,9 +234,147 @@ def doit(a, x, dtype): return result def _igammac_continued_fraction(ax, x, a, enabled, dtype, mode): - # TODO(atondwal): implement _igammac_continued_fraction in JAX. - # Right now we fallback to the XLA implementation of IgammacContinuedFraction. - return igammac(a, x) + eps = dtypes.finfo(dtype).eps + + def cond_fn(vals): + enabled, _ans, _t, _y, _x, c, *_ = vals + return bitwise_and(c < _const(c, 2000), _any(enabled)) + + def body_fn(vals): + (enabled, ans, t, y, z, c, pkm1, qkm1, pkm2, qkm2, + dpkm2_da, dqkm2_da, dpkm1_da, dqkm1_da, dans_da) = vals + + c = c + _const(c, 1) + y = y + _const(y, 1) + z = z + _const(z, 2) + yc = y * c + pk = pkm1 * z - pkm2 * yc + qk = qkm1 * z - qkm2 * yc + qk_is_nonzero = ne(qk, _const(qk, 0)) + r = pk / qk + + t = select(qk_is_nonzero, abs(div(sub(ans, r), r)), full_like(r, 1)) + ans = select(qk_is_nonzero, r, ans) + + dpk_da = dpkm1_da * z - pkm1 - dpkm2_da * yc + pkm2 * c + dqk_da = dqkm1_da * z - qkm1 - dqkm2_da * yc + qkm2 * c + dans_da_new = select(qk_is_nonzero, div(dpk_da - ans * dqk_da, qk), dans_da) + grad_conditional = select(qk_is_nonzero, + abs(dans_da_new - dans_da), + full_like(dans_da, 1)) + + pkm2 = pkm1 + pkm1 = pk + qkm2 = qkm1 + qkm1 = qk + + dpkm2_da = dpkm1_da + dqkm2_da = dqkm1_da + 
dpkm1_da = dpk_da + dqkm1_da = dqk_da + + rescale = gt(abs(pk), reciprocal(_const(pk, eps))) + pkm2 = select(rescale, mul(pkm2, _const(pkm2, eps)), pkm2) + pkm1 = select(rescale, mul(pkm1, _const(pkm1, eps)), pkm1) + qkm2 = select(rescale, mul(qkm2, _const(qkm2, eps)), qkm2) + qkm1 = select(rescale, mul(qkm1, _const(qkm1, eps)), qkm1) + + dpkm2_da = select(rescale, mul(dpkm2_da, _const(dpkm2_da, eps)), dpkm2_da) + dqkm2_da = select(rescale, mul(dqkm2_da, _const(dqkm2_da, eps)), dqkm2_da) + dpkm1_da = select(rescale, mul(dpkm1_da, _const(dpkm1_da, eps)), dpkm1_da) + dqkm1_da = select(rescale, mul(dqkm1_da, _const(dqkm1_da, eps)), dqkm1_da) + + if mode == IgammaMode.VALUE: + conditional = bitwise_and(enabled, t > eps) + else: + conditional = bitwise_and(enabled, + grad_conditional > _const(grad_conditional, eps)) + + return (conditional, + select(enabled, ans, vals[1]), + select(enabled, t, vals[2]), + select(enabled, y, vals[3]), + select(enabled, z, vals[4]), + c, + select(enabled, pkm1, vals[6]), + select(enabled, qkm1, vals[7]), + select(enabled, pkm2, vals[8]), + select(enabled, qkm2, vals[9]), + select(enabled, dpkm2_da, vals[10]), + select(enabled, dqkm2_da, vals[11]), + select(enabled, dpkm1_da, vals[12]), + select(enabled, dqkm1_da, vals[13]), + select(enabled, dans_da_new, vals[14])) + + y = _const(a, 1) - a + z = x + y + _const(x, 1) + c = _const(x, 0) + pkm2 = full_like(x, 1) + qkm2 = x + pkm1 = x + _const(x, 1) + qkm1 = z * x + ans = pkm1 / qkm1 + t = full_like(x, 1) + dpkm2_da = full_like(x, 0) + dqkm2_da = full_like(x, 0) + dpkm1_da = full_like(x, 0) + dqkm1_da = -x + dans_da = (dpkm1_da - ans * dqkm1_da) / qkm1 + init_vals = (enabled, ans, t, y, z, + c, pkm1, qkm1, pkm2, qkm2, + dpkm2_da, dqkm2_da, dpkm1_da, dqkm1_da, dans_da) + + vals = while_loop(cond_fn, body_fn, init_vals) + ans = vals[1] + if mode == IgammaMode.VALUE: + return ans * ax + dans_da = vals[14] + dlogax_da = log(x) - digamma(a) + + if mode == IgammaMode.DERIVATIVE: + return mul(ax, 
add(mul(ans, dlogax_da), dans_da)) + elif mode == IgammaMode.SAMPLE_DERIVATIVE: + return neg(add(dans_da, mul(ans, dlogax_da)) * x) + else: + raise ValueError(f"Invalid mode: {mode}") + + +def igammac_impl(a, x): + broadcasted_shape = broadcast_shapes(a.shape, x.shape) + a = broadcast_in_dim(a, broadcasted_shape, list(range(a.ndim))) + x = broadcast_in_dim(x, broadcasted_shape, list(range(x.ndim))) + + def doit(a, x, dtype): + out_of_range = bitwise_or(le(x, _const(x, 0)), le(a, _const(a, 0))) + use_igamma = bitwise_or(lt(x, _const(x, 1)), lt(x, a)) + ax = a * log(x) - x - lgamma(a) + underflow = lt(ax, -log(dtypes.finfo(dtype).max)) + enabled = bitwise_not(bitwise_or(out_of_range, underflow)) + ax = exp(ax) + + igamma_call = _igamma_series(ax, x, a, bitwise_and(enabled, use_igamma), + dtype, IgammaMode.VALUE) + igammac_cf_call = _igammac_continued_fraction(ax, x, a, + bitwise_and(enabled, bitwise_not(use_igamma)), dtype, IgammaMode.VALUE) + + result = select(use_igamma, _const(a, 1) - igamma_call, igammac_cf_call) + x_is_infinity = eq(x, _const(x, float('inf'))) + result = select(x_is_infinity, full_like(result, 0), result); + return select(out_of_range, full_like(a, 1), result); + + needs_upcast = a.dtype == dtypes.bfloat16 or a.dtype == np.float16 + if needs_upcast: + a_dtype = a.dtype + a = convert_element_type(a, np.float32) + x = convert_element_type(x, np.float32) + a_x_type = np.float32 + else: + a_x_type = a.dtype + + result = doit(a, x, a_x_type) + if needs_upcast: + result = convert_element_type(result, a_dtype) + return result lgamma_p = standard_unop(_float, 'lgamma') ad.defjvp(lgamma_p, lambda g, x: mul(g, digamma(x))) @@ -256,7 +394,8 @@ def _igammac_continued_fraction(ax, x, a, enabled, dtype, mode): ad.defjvp(igamma_p, igamma_grada, igamma_gradx) igammac_p = standard_naryop([_float, _float], 'igammac') -xla.register_translation(igammac_p, partial(_broadcast_translate, xops.Igammac)) +mlir.register_lowering(igammac_p, + mlir.lower_fun(igammac_impl, 
multiple_results=False)) ad.defjvp(igammac_p, igammac_grada, igammac_gradx) diff --git a/tests/filecheck/math.filecheck.py b/tests/filecheck/math.filecheck.py index bfc5d73fc219..a60eab26cb88 100644 --- a/tests/filecheck/math.filecheck.py +++ b/tests/filecheck/math.filecheck.py @@ -262,11 +262,6 @@ def main(_): # CHECK-SAME: tensor print_ir(np.float32(1), np.float32(2))(lax.gt) - # CHECK-LABEL: TEST: igammac float32[] float32[] - # CHECK: xla_fallback_igammac - # CHECK-SAME: tensor - print_ir(np.float32(0), np.float32(0))(lax.igammac) - # CHECK-LABEL: TEST: igamma_grad_a float32[] float32[] # CHECK: xla_fallback_igamma_grad_a # CHECK-SAME: tensor From c6bc3ed7a17c9cbc744ead329ca6525fce42ddf2 Mon Sep 17 00:00:00 2001 From: Anish Tondwalkar Date: Fri, 24 Mar 2023 08:20:46 -0700 Subject: [PATCH 57/65] Migrate igamma_grad_a_p off xla_fallback PiperOrigin-RevId: 519148548 --- jax/_src/lax/special.py | 172 ++++++++++++++++-------------- tests/filecheck/math.filecheck.py | 5 - 2 files changed, 89 insertions(+), 88 deletions(-) diff --git a/jax/_src/lax/special.py b/jax/_src/lax/special.py index 83e30b91e845..ecadcb589edf 100644 --- a/jax/_src/lax/special.py +++ b/jax/_src/lax/special.py @@ -189,49 +189,31 @@ def body_fn(vals): else: raise ValueError("Invalid IgammaMode") -def igamma_impl(a, x): - broadcasted_shape = broadcast_shapes(a.shape, x.shape) - a = broadcast_in_dim(a, broadcasted_shape, list(range(a.ndim))) - x = broadcast_in_dim(x, broadcasted_shape, list(range(x.ndim))) - - def doit(a, x, dtype): - is_nan = bitwise_or(_isnan(a), _isnan(x)) - x_is_zero = eq(x, _const(x, 0)) - x_is_infinity = eq(x, _const(x, float('inf'))) - domain_error = bitwise_or(lt(x, _const(x, 0)), le(a, _const(a, 0))) - use_igammac = bitwise_and(gt(x, _const(x, 1)), gt(x, a)) - ax = a * log(x) - x - lgamma(a) - underflow = lt(ax, -log(dtypes.finfo(dtype).max)) - ax = exp(ax) - enabled = bitwise_not( - _reduce(bitwise_or,[x_is_zero, domain_error, underflow, is_nan])) - - output = select( - 
use_igammac, - _const(a, 1) - - _igammac_continued_fraction(ax, x, a, bitwise_and(enabled, use_igammac), - dtype, IgammaMode.VALUE), - _igamma_series(ax, x, a, bitwise_and(enabled, bitwise_not(use_igammac)), - dtype, IgammaMode.VALUE) - ) - output = select(x_is_zero, full_like(a, 0), output) - output = select(x_is_infinity, full_like(a, 1), output) - output = select(bitwise_or(domain_error, is_nan), - full_like(a, float('nan')), output) - return output - - needs_upcast = a.dtype == dtypes.bfloat16 or a.dtype == np.float16 - if needs_upcast: - a_dtype = a.dtype - a = convert_element_type(a, np.float32) - x = convert_element_type(x, np.float32) - a_x_type = np.float32 - else: - a_x_type = a.dtype - result = doit(a, x, a_x_type) - if needs_upcast: - result = convert_element_type(result, a_dtype) - return result +def igamma_impl(a, x, dtype): + is_nan = bitwise_or(_isnan(a), _isnan(x)) + x_is_zero = eq(x, _const(x, 0)) + x_is_infinity = eq(x, _const(x, float('inf'))) + domain_error = bitwise_or(lt(x, _const(x, 0)), le(a, _const(a, 0))) + use_igammac = bitwise_and(gt(x, _const(x, 1)), gt(x, a)) + ax = a * log(x) - x - lgamma(a) + underflow = lt(ax, -log(dtypes.finfo(dtype).max)) + ax = exp(ax) + enabled = bitwise_not( + _reduce(bitwise_or,[x_is_zero, domain_error, underflow, is_nan])) + + output = select( + use_igammac, + _const(a, 1) - + _igammac_continued_fraction(ax, x, a, bitwise_and(enabled, use_igammac), + dtype, IgammaMode.VALUE), + _igamma_series(ax, x, a, bitwise_and(enabled, bitwise_not(use_igammac)), + dtype, IgammaMode.VALUE) + ) + output = select(x_is_zero, full_like(a, 0), output) + output = select(x_is_infinity, full_like(a, 1), output) + output = select(bitwise_or(domain_error, is_nan), + full_like(a, float('nan')), output) + return output def _igammac_continued_fraction(ax, x, a, enabled, dtype, mode): eps = dtypes.finfo(dtype).eps @@ -339,42 +321,64 @@ def body_fn(vals): raise ValueError(f"Invalid mode: {mode}") -def igammac_impl(a, x): - 
broadcasted_shape = broadcast_shapes(a.shape, x.shape) - a = broadcast_in_dim(a, broadcasted_shape, list(range(a.ndim))) - x = broadcast_in_dim(x, broadcasted_shape, list(range(x.ndim))) - - def doit(a, x, dtype): - out_of_range = bitwise_or(le(x, _const(x, 0)), le(a, _const(a, 0))) - use_igamma = bitwise_or(lt(x, _const(x, 1)), lt(x, a)) - ax = a * log(x) - x - lgamma(a) - underflow = lt(ax, -log(dtypes.finfo(dtype).max)) - enabled = bitwise_not(bitwise_or(out_of_range, underflow)) - ax = exp(ax) - - igamma_call = _igamma_series(ax, x, a, bitwise_and(enabled, use_igamma), - dtype, IgammaMode.VALUE) - igammac_cf_call = _igammac_continued_fraction(ax, x, a, - bitwise_and(enabled, bitwise_not(use_igamma)), dtype, IgammaMode.VALUE) - - result = select(use_igamma, _const(a, 1) - igamma_call, igammac_cf_call) - x_is_infinity = eq(x, _const(x, float('inf'))) - result = select(x_is_infinity, full_like(result, 0), result); - return select(out_of_range, full_like(a, 1), result); - - needs_upcast = a.dtype == dtypes.bfloat16 or a.dtype == np.float16 - if needs_upcast: - a_dtype = a.dtype - a = convert_element_type(a, np.float32) - x = convert_element_type(x, np.float32) - a_x_type = np.float32 - else: - a_x_type = a.dtype +def igammac_impl(a, x, dtype): + out_of_range = bitwise_or(le(x, _const(x, 0)), le(a, _const(a, 0))) + use_igamma = bitwise_or(lt(x, _const(x, 1)), lt(x, a)) + ax = a * log(x) - x - lgamma(a) + underflow = lt(ax, -log(dtypes.finfo(dtype).max)) + enabled = bitwise_not(bitwise_or(out_of_range, underflow)) + ax = exp(ax) + + igamma_call = _igamma_series(ax, x, a, bitwise_and(enabled, use_igamma), + dtype, IgammaMode.VALUE) + igammac_cf_call = _igammac_continued_fraction(ax, x, a, + bitwise_and(enabled, bitwise_not(use_igamma)), dtype, IgammaMode.VALUE) + + result = select(use_igamma, _const(a, 1) - igamma_call, igammac_cf_call) + x_is_infinity = eq(x, _const(x, float('inf'))) + result = select(x_is_infinity, full_like(result, 0), result) + return 
select(out_of_range, full_like(a, 1), result) + +def igamma_grad_a_impl(a, x, dtype): + is_nan = bitwise_or(_isnan(a), _isnan(x)) + x_is_zero = eq(x, full_like(x,0)) + domain_error = bitwise_or(lt(x, full_like(x, 0)), le(a, full_like(a, 0))) + use_igammac = bitwise_and(gt(x, full_like(x,1)), gt(x, a)) + ax = a * log(x) - x - lgamma(a) + underflow = lt(ax, -log(dtypes.finfo(dtype).max)) + ax = exp(ax) + enabled = bitwise_not(bitwise_or(bitwise_or(bitwise_or( + x_is_zero, domain_error), underflow), is_nan)) + output = select(use_igammac, + -_igammac_continued_fraction(ax, x, a, bitwise_and(enabled, use_igammac), + dtype, IgammaMode.DERIVATIVE), + _igamma_series(ax, x, a, bitwise_and(enabled, bitwise_not(use_igammac)), + dtype, IgammaMode.DERIVATIVE)) + output = select(x_is_zero, full_like(output,0), output) + output = select(bitwise_or(domain_error, is_nan), + full_like(a, float('nan')), output) + return output + +def _up_and_broadcast(doit): + def up_and_broadcast(a, x): + broadcasted_shape = broadcast_shapes(a.shape, x.shape) + a = broadcast_in_dim(a, broadcasted_shape, list(range(a.ndim))) + x = broadcast_in_dim(x, broadcasted_shape, list(range(x.ndim))) + + needs_upcast = a.dtype == dtypes.bfloat16 or a.dtype == np.float16 + if needs_upcast: + a_dtype = a.dtype + a = convert_element_type(a, np.float32) + x = convert_element_type(x, np.float32) + a_x_type = np.float32 + else: + a_x_type = a.dtype - result = doit(a, x, a_x_type) - if needs_upcast: - result = convert_element_type(result, a_dtype) - return result + result = doit(a, x, a_x_type) + if needs_upcast: + result = convert_element_type(result, a_dtype) + return result + return up_and_broadcast lgamma_p = standard_unop(_float, 'lgamma') ad.defjvp(lgamma_p, lambda g, x: mul(g, digamma(x))) @@ -384,18 +388,20 @@ def doit(a, x, dtype): mlir.register_lowering(digamma_p, partial(_nary_lower_hlo, chlo.DigammaOp)) igamma_p = standard_naryop([_float, _float], 'igamma') -mlir.register_lowering(igamma_p, - 
mlir.lower_fun(igamma_impl, multiple_results=False)) +mlir.register_lowering(igamma_p, mlir.lower_fun(_up_and_broadcast(igamma_impl), + multiple_results=False)) igamma_grad_a_p = standard_naryop([_float, _float], 'igamma_grad_a') -xla.register_translation(igamma_grad_a_p, - partial(_broadcast_translate, xops.IgammaGradA)) +mlir.register_lowering(igamma_grad_a_p, + mlir.lower_fun(_up_and_broadcast(igamma_grad_a_impl), + multiple_results=False)) ad.defjvp(igamma_p, igamma_grada, igamma_gradx) igammac_p = standard_naryop([_float, _float], 'igammac') mlir.register_lowering(igammac_p, - mlir.lower_fun(igammac_impl, multiple_results=False)) + mlir.lower_fun(_up_and_broadcast(igammac_impl), + multiple_results=False)) ad.defjvp(igammac_p, igammac_grada, igammac_gradx) diff --git a/tests/filecheck/math.filecheck.py b/tests/filecheck/math.filecheck.py index a60eab26cb88..3cf953261e4d 100644 --- a/tests/filecheck/math.filecheck.py +++ b/tests/filecheck/math.filecheck.py @@ -262,11 +262,6 @@ def main(_): # CHECK-SAME: tensor print_ir(np.float32(1), np.float32(2))(lax.gt) - # CHECK-LABEL: TEST: igamma_grad_a float32[] float32[] - # CHECK: xla_fallback_igamma_grad_a - # CHECK-SAME: tensor - print_ir(np.float32(0), np.float32(0))(lax.igamma_grad_a) - # CHECK-LABEL: TEST: imag complex64[] # CHECK: hlo.imag # CHECK-SAME: tensor> From aa37c741a1e34337b52eea45dbd8d42c0bb5d047 Mon Sep 17 00:00:00 2001 From: Anish Tondwalkar Date: Fri, 24 Mar 2023 08:48:55 -0700 Subject: [PATCH 58/65] Migrate random_gamma_grad off xla_fallback PiperOrigin-RevId: 519154537 --- jax/_src/lax/special.py | 26 +++++++++++++++++++++++--- tests/filecheck/math.filecheck.py | 5 ----- 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/jax/_src/lax/special.py b/jax/_src/lax/special.py index ecadcb589edf..e76da9be3b64 100644 --- a/jax/_src/lax/special.py +++ b/jax/_src/lax/special.py @@ -320,7 +320,6 @@ def body_fn(vals): else: raise ValueError(f"Invalid mode: {mode}") - def igammac_impl(a, x, dtype): 
out_of_range = bitwise_or(le(x, _const(x, 0)), le(a, _const(a, 0))) use_igamma = bitwise_or(lt(x, _const(x, 1)), lt(x, a)) @@ -359,6 +358,26 @@ def igamma_grad_a_impl(a, x, dtype): full_like(a, float('nan')), output) return output +def random_gamma_grad_impl(a, x, dtype): + is_nan = bitwise_or(_isnan(a), _isnan(x)) + x_is_zero = eq(x, full_like(x,0)) + domain_error = bitwise_or(lt(x, full_like(x,0)), le(a, full_like(a,0))) + use_igammac = bitwise_and(gt(x, full_like(x,1)), gt(x, a)) + ax = a * log(x) - x - lgamma(a) + underflow = lt(ax, -log(dtypes.finfo(a.dtype).max)) + ax = exp(ax) + enabled = bitwise_not(bitwise_or(bitwise_or(bitwise_or + (x_is_zero, domain_error), underflow), is_nan)) + output = select(use_igammac, + -_igammac_continued_fraction(ax, x, a, bitwise_and(enabled, use_igammac), + dtype, IgammaMode.SAMPLE_DERIVATIVE), + _igamma_series(ax, x, a, bitwise_and(enabled, bitwise_not(use_igammac)), + dtype, IgammaMode.SAMPLE_DERIVATIVE)) + output = select(x_is_zero, full_like(output,0), output) + output = select(bitwise_or(domain_error, is_nan), + full_like(a, float('nan')), output) + return output + def _up_and_broadcast(doit): def up_and_broadcast(a, x): broadcasted_shape = broadcast_shapes(a.shape, x.shape) @@ -406,8 +425,9 @@ def up_and_broadcast(a, x): ad.defjvp(igammac_p, igammac_grada, igammac_gradx) random_gamma_grad_p = standard_naryop([_float, _float], 'random_gamma_grad') -xla.register_translation(random_gamma_grad_p, - partial(_broadcast_translate, xops.RandomGammaGrad)) +mlir.register_lowering(random_gamma_grad_p, + mlir.lower_fun(_up_and_broadcast(random_gamma_grad_impl), + multiple_results=False)) bessel_i0e_p = standard_unop(_float, 'bessel_i0e') xla.register_translation(bessel_i0e_p, standard_translate(bessel_i0e_p)) diff --git a/tests/filecheck/math.filecheck.py b/tests/filecheck/math.filecheck.py index 3cf953261e4d..6ae183c299c6 100644 --- a/tests/filecheck/math.filecheck.py +++ b/tests/filecheck/math.filecheck.py @@ -356,11 +356,6 @@ def 
integer_pow(x): return lax.integer_pow(x, 3) # CHECK-SAME: tensor print_ir(np.float32(1), np.float32(2))(lax.pow) - # CHECK-LABEL: TEST: random_gamma_grad float32[] float32[] - # CHECK: xla_fallback_random_gamma_grad - # CHECK-SAME: tensor - print_ir(np.float32(0), np.float32(0))(lax.random_gamma_grad) - # CHECK-LABEL: TEST: real complex128[] # CHECK: hlo.real # CHECK-SAME: tensor> From 2900c787744185565f3cdfda3f2395b24a068533 Mon Sep 17 00:00:00 2001 From: Frederic Bastien Date: Fri, 24 Mar 2023 07:59:11 -0700 Subject: [PATCH 59/65] remove another dependency not currently needed. --- .github/workflows/slurm_job_scripts/run_e2e_t5x_tests.sub | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/slurm_job_scripts/run_e2e_t5x_tests.sub b/.github/workflows/slurm_job_scripts/run_e2e_t5x_tests.sub index 444f9c691d97..77c4c315602a 100644 --- a/.github/workflows/slurm_job_scripts/run_e2e_t5x_tests.sub +++ b/.github/workflows/slurm_job_scripts/run_e2e_t5x_tests.sub @@ -49,7 +49,7 @@ rm -rf ${E2E_TESTS_WORKSPACE_DIR}/* \ && mkdir -p ${TFDS_DATA_DIR} \ && python3.8 -m pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html \ && git clone https://github.com/google-research/t5x.git ${T5X_DIR} \ -&& python3.8 -m pip uninstall -y cudf \ +&& python3.8 -m pip uninstall -y cudf dask-cudf \ && python3.8 -m pip install ${T5X_DIR} \ && python3.8 -m pip install ${T5X_DIR}/t5x/contrib/gpu/scripts_gpu/dummy_wikipedia_config \ && hostname > ${E2E_TESTS_WORKSPACE_DIR}/hostname.txt From f1af74f0eab45e1ae9214ac1f38d53ace101664a Mon Sep 17 00:00:00 2001 From: Frederic Bastien Date: Thu, 23 Mar 2023 11:57:20 -0700 Subject: [PATCH 60/65] WAR ssh timeout like: client_loop: send disconnect: Broken pipe https://github.com/google/jax/actions/runs/4500333187/jobs/7919324156#step:8:42 --- .github/workflows/nightly-ci-multiprocess-gpu.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git 
a/.github/workflows/nightly-ci-multiprocess-gpu.yml b/.github/workflows/nightly-ci-multiprocess-gpu.yml index b99df82ee8eb..c593c7b138ff 100644 --- a/.github/workflows/nightly-ci-multiprocess-gpu.yml +++ b/.github/workflows/nightly-ci-multiprocess-gpu.yml @@ -60,7 +60,8 @@ jobs: echo "Host headnode User ${USER} HostName ${IP} - IdentityFile ${GITHUB_WORKSPACE}/.ssh/id_rsa" > ./.ssh/config + IdentityFile ${GITHUB_WORKSPACE}/.ssh/id_rsa + ServerAliveInterval 30" > ./.ssh/config - name: Check SLURM is working run: | From 14f1f60b6a2653031dfe054f5f7e6727be2a401a Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 24 Mar 2023 09:59:55 -0700 Subject: [PATCH 61/65] After the SPMD bug fix, always take the _rewriting_take route for getitem instead of bouncing to host. PiperOrigin-RevId: 519170785 --- jax/_src/array.py | 15 ++++++++------- tests/array_test.py | 4 +--- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/jax/_src/array.py b/jax/_src/array.py index 38d8dddceb15..cf86b40b64a1 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -304,14 +304,15 @@ def __getitem__(self, idx): arr = self._arrays[arr_idx] return _single_device_array_from_buf(arr, committed=False) return lax_numpy._rewriting_take(self, idx) - elif (dispatch.is_single_device_sharding(self.sharding) or - self.is_fully_replicated or _is_reduced_on_dim(idx)): - return lax_numpy._rewriting_take(self, idx) else: - # TODO(yashkatariya): Don't bounce to host and use `_rewriting_take` or - # the fast path (see PmapSharding branch above) after after uneven - # partitioning support is added - return api.device_put(self._value[idx]) + if xla_extension_version >= 144: + return lax_numpy._rewriting_take(self, idx) + else: + if (dispatch.is_single_device_sharding(self.sharding) or + self.is_fully_replicated or _is_reduced_on_dim(idx)): + return lax_numpy._rewriting_take(self, idx) + else: + return api.device_put(self._value[idx]) def __iter__(self): if self.ndim == 0: diff --git 
a/tests/array_test.py b/tests/array_test.py index 632e15f843f7..704416912cff 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -469,11 +469,9 @@ def test_array_getitem_mesh_pspec_sharding_multi_device(self): arr, input_data = create_array( input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y'))) - # TODO(yashkatariya): `__getitem__` with a specific index takes the fast - # path after b/245667823 is fixed. s = arr[2:4, 0:1] self.assertIsInstance(s, array.ArrayImpl) - self.assertArraysEqual(s, np.array([[4], [6]])) + self.assertArraysEqual(s, input_data[2:4, 0:1]) p = arr[:2] self.assertIsInstance(p, array.ArrayImpl) From b9bd60dc58f19b92b64db391ecbea2b41a8e1063 Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Fri, 24 Mar 2023 11:14:59 -0700 Subject: [PATCH 62/65] Guard ArrayImpl checks by xla_extension_version. PiperOrigin-RevId: 519191714 --- jax/_src/array.py | 6 +++--- tests/lax_test.py | 4 +++- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/jax/_src/array.py b/jax/_src/array.py index cf86b40b64a1..ca882fb50bf5 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -508,7 +508,7 @@ def copy_to_host_async(self): device_replica_id_map(self.sharding, self.shape)) for arr in self._arrays: if device_to_replica_id_map[arr.device()] == 0: - if isinstance(arr, ArrayImpl): + if xla_extension_version >= 140: arr._copy_single_device_array_to_host_async() else: arr.copy_to_host_async() @@ -536,7 +536,7 @@ def _value(self) -> np.ndarray: devices = [arr.device() for arr in self._arrays] for arr, d in zip(self._arrays, devices): if device_to_replica_id_map[d] == 0: - if isinstance(arr, ArrayImpl): + if xla_extension_version >= 140: arr._copy_single_device_array_to_host_async() else: arr.copy_to_host_async() @@ -545,7 +545,7 @@ def _value(self) -> np.ndarray: npy_value = np.empty(self.shape, self.dtype) for arr, d in zip(self._arrays, devices): if device_to_replica_id_map[d] == 0: - if isinstance(arr, ArrayImpl): + if 
xla_extension_version >= 140: npy_value[device_to_index_map[d]] = ( arr._single_device_array_to_np_array()) else: diff --git a/tests/lax_test.py b/tests/lax_test.py index e8abb8da7dfa..826335721429 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -39,6 +39,7 @@ from jax._src.interpreters import mlir from jax.interpreters import batching from jax._src import array +from jax._src.lib import xla_extension_version from jax._src.lib.mlir.dialects import hlo from jax._src import dtypes from jax._src.interpreters import pxla @@ -2854,7 +2855,8 @@ def handler(arr): buf, = arr._arrays else: buf, = arr - buf.aval = core.ShapedArray(buf.shape, buf.dtype) + if xla_extension_version < 140: + buf.aval = core.ShapedArray(buf.shape, buf.dtype) return FooArray(aval.shape, buf) return handler From 5c08753a19f7cec9474fce7f3bc57a03b4c014df Mon Sep 17 00:00:00 2001 From: John QiangZhang Date: Fri, 24 Mar 2023 11:26:44 -0700 Subject: [PATCH 63/65] [1/n] store embedded tf.graph to stablehlo.custom_call PiperOrigin-RevId: 519194911 --- jax/experimental/jax2tf/call_tf.py | 192 +++++++++++++++--- jax/experimental/jax2tf/tests/call_tf_test.py | 137 ++++++++++++- 2 files changed, 294 insertions(+), 35 deletions(-) diff --git a/jax/experimental/jax2tf/call_tf.py b/jax/experimental/jax2tf/call_tf.py index 60c10997ca2a..0f74c465e43a 100644 --- a/jax/experimental/jax2tf/call_tf.py +++ b/jax/experimental/jax2tf/call_tf.py @@ -22,9 +22,10 @@ https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#calling-tensorflow-functions-from-jax. """ +import base64 import enum import functools -from typing import Any, Callable, Optional, Sequence, Tuple +from typing import Any, Callable, Optional, Sequence, Tuple, List from absl import logging @@ -62,8 +63,13 @@ # DLPack, if we are careful. 
_DLPACK_PLATFORMS = ("gpu",) -def call_tf(callable_tf: Callable, has_side_effects=True, - output_shape_dtype=None) -> Callable: + +def call_tf( + callable_tf: Callable, + has_side_effects=True, + output_shape_dtype=None, + use_custom_call=False, +) -> Callable: """Calls a TensorFlow function from JAX, with support for reverse autodiff. The ``callable_tf`` will be called with TensorFlow-compatible arguments ( @@ -72,7 +78,8 @@ def call_tf(callable_tf: Callable, has_side_effects=True, If ``call_tf`` appears in a JAX staging context (:func:`jax.jit`, or :func:`jax.pmap`, or :func:`jax.xmap`, or a control-flow primitive) then - ``callable_tf`` will be compiled with ``tf.function(callable_tf, jit_compile=True)`` + ``callable_tf`` will be compiled with ``tf.function(callable_tf, + jit_compile=True)`` and the resulting XLA computation will be embedded in JAX's XLA computation. If ``call_tf`` appears outside a JAX staging context, it will be called inline @@ -84,7 +91,8 @@ def call_tf(callable_tf: Callable, has_side_effects=True, custom gradients that may be defined for the code in ``callable_tf``. For an example and more details see the - `README `_. + `README + `_. Args: callable_tf: a TensorFlow Callable that can take a pytree of TensorFlow @@ -92,20 +100,29 @@ def call_tf(callable_tf: Callable, has_side_effects=True, has_side_effects: if True then it ensures that instances of this primitive are not removed or replicated by JAX optimizations such as dead-code elimination. - output_shape_dtype: An optional declaration of the expected shapes and dtypes - from the called TensorFlow function. If given it will be used during JAX - tracing to form the abstract values of the results of the `call_tf`. If - not given then we form a `tf.Graph` for the called TensorFlow function and - we use the TensorFlow-inferred shapes and types. 
Must be a pytree matching the - structure of the nested structure returned from the TensorFlow function, - containing objects with `.shape` and `.dtype` attributes, - e.g., `jax.ShapeDtypeStruct` or `jax.Array`. - + output_shape_dtype: An optional declaration of the expected shapes and + dtypes from the called TensorFlow function. If given it will be used + during JAX tracing to form the abstract values of the results of the + `call_tf`. If not given then we form a `tf.Graph` for the called + TensorFlow function and we use the TensorFlow-inferred shapes and types. + Must be a pytree matching the structure of the nested structure returned + from the TensorFlow function, containing objects with `.shape` and + `.dtype` attributes, e.g., `jax.ShapeDtypeStruct` or `jax.Array`. + use_custom_call: PLEASE DO NOT USE IT since it is experimental. We may + change the name in the future. Returns: a JAX callable that can be invoked with JAX pytree arguments, in - op-by-op mode or in a staged context. This callable can be used with - JAX's reverse-mode autodiff (:func:`jax.grad`). + op-by-op mode or in a staged context. This callable can be used with JAX's + reverse-mode autodiff (:func:`jax.grad`). """ + # TODO(johnqiangzhang): use_custom_call only work together with jax.convert + # native_serialization. currently we need users set both options manually. + # We need derive this automatically from jax2tf.convert context automatically. + if use_custom_call and output_shape_dtype is None: + raise ValueError( + "Please provide the output_shape_dtype if enable use_custom_call." 
+ ) + @jax.custom_vjp def make_call(*args_jax): """We wrap it all in `make_call` so that we can attach custom VJP.""" @@ -153,13 +170,30 @@ def callable_flat_tf(*args_tf_flat: TfVal) -> Sequence[TfVal]: f"results pytree: {res_treedef}\noutput_shape_dtype tree: {output_shape_dtype_tree}") assert len(output_avals) == len(res_tf_flat) - checked_res_tf_flat = [ - check_tf_result(i, r_tf, r_aval) - for i, (r_tf, r_aval) in enumerate( - zip(res_tf_flat, - (output_avals if output_avals is not None - else (None,) * len(res_tf_flat))))] - return checked_res_tf_flat + try: + checked_res_tf_flat = [ + check_tf_result(i, r_tf, r_aval) + for i, (r_tf, r_aval) in enumerate( + zip( + res_tf_flat, + ( + output_avals + if output_avals is not None + else (None,) * len(res_tf_flat) + ), + ) + ) + ] + return checked_res_tf_flat + except Exception as e: # pylint: disable=broad-except + # When a TensorFlow function is not XLA-compilable. + # TODO(johnqiangzhang): We skip the output shape check for use_custom_call. + # Since non-compilable functions may not have a defined output shape in the + # concrete_fn. I will add this check later. + if use_custom_call: + return [] + else: + raise e # Prepare a tf.function ahead of time, to cache the concrete functions. This # won't be used in op-by-op execution mode. 
@@ -172,7 +206,9 @@ def callable_flat_tf(*args_tf_flat: TfVal) -> Sequence[TfVal]: function_flat_tf=function_flat_tf, args_flat_sig_tf=args_flat_sig_tf, output_avals=output_avals, - has_side_effects=has_side_effects) + has_side_effects=has_side_effects, + use_custom_call=use_custom_call, + ) # We must have called callable_flat_tf by nοw assert res_treedef is not None @@ -388,8 +424,17 @@ def is_fully_known_shape(s): call_tf_p.def_effectful_abstract_eval(_call_tf_abstract_eval) -def _call_tf_lowering(ctx, *args_op, platform, - function_flat_tf, args_flat_sig_tf, **_): +def _call_tf_lowering( + ctx, + *args_op, + platform, + function_flat_tf, + args_flat_sig_tf, + has_side_effects, + use_custom_call, + output_avals, + **_, +): # This will most likely hit the cache, because we used it for abstract_eval # We use the same TF lowering device as for the embedding JAX computation. # One example when this is needed is when the code refers to variables on one @@ -400,8 +445,14 @@ def _call_tf_lowering(ctx, *args_op, platform, tf_platform = "GPU" else: raise ValueError("platform {platform} not supported") - code_gen, _ = _code_generator_and_avals(function_flat_tf, args_flat_sig_tf, # type: ignore - tf_platform) + code_gen, _ = _code_generator_and_avals( + function_flat_tf, + args_flat_sig_tf, # type: ignore + tf_platform, + use_custom_call, + has_side_effects, + output_avals, + ) assert code_gen is not None return code_gen(ctx.module_context, args_op) @@ -411,9 +462,15 @@ def _code_generator_and_avals( function_flat_tf, args_flat_sig_tf, tf_platform, -) -> Tuple[Optional[Callable[[mlir.ModuleContext, Sequence[ir.Value]], - Sequence[ir.Value]]], - Sequence[core.ShapedArray]]: + use_custom_call, + has_side_effects, + output_avals, +) -> Tuple[ + Optional[ + Callable[[mlir.ModuleContext, Sequence[ir.Value]], Sequence[ir.Value]] + ], + Sequence[core.ShapedArray], +]: # TODO(necula): we have refactored the code to not need to lower the code # just in order to get the avals, so in 
fact the returned avals from this # function are never used. We keep it here for now in case we detect @@ -441,6 +498,22 @@ def _code_generator_and_avals( else: captured_inputs.append(inp) + def code_gen_custom_call(ctx, args_op): # pylint: disable=unused-argument + captured_ops = tuple( + mlir.ir_constant(np.asarray(inp), canonicalize_types=False) + for inp in captured_inputs + ) + with jax2tf_internal.inside_call_tf(): + return emit_tf_embedded_graph_custom_call( + concrete_function_flat_tf, + tuple(args_op) + captured_ops, + has_side_effects, + output_avals, + ) + + if use_custom_call: + return code_gen_custom_call, () + def convert_to_spec(x): if isinstance(x, tf.TensorSpec): return x @@ -534,3 +607,60 @@ def _jax2tf_call_tf(*args: TfVal, return res_tf_flat jax2tf_internal.tf_impl[call_tf_p] = _jax2tf_call_tf + + +def emit_tf_embedded_graph_custom_call( + concrete_function_flat_tf, + operands: List[ir.Value], + has_side_effects, + output_avals, +): + """Emits MLIR about tf.graph custom_call. + + All call_tf caller function information is stored in tf.metadata. + This includes: + (1) The caller function name: This name will be used by the runtime to execute + the callback. + (2) The FunctionDef Dict: This list includes the caller function and all + related callees. By storing this information in tf.metadata, we can easily + retrieve it at runtime. + (3) The platform where to run this call_tf function. + """ + call_target_name = "tf_embedded_graph" + + # Generate metadata as attributes: + func_def_list = [concrete_function_flat_tf.function_def] + [ + func.definition + for func in concrete_function_flat_tf.graph._functions.values() + ] + # TODO(gleasonk): Here, we encode the tf.FunctionDef bytes using the base64 + # algorithm. We do this because StableHLO does not currently have a standard + # way to store bytes. 
+ tf_metadata = { + "call_tf_func_name": ir.StringAttr.get(concrete_function_flat_tf.name), + "function_def_list": ir.ArrayAttr.get( + [ + ir.StringAttr.get(base64.b64encode(f.SerializeToString())) + for f in func_def_list + ], + ), + } + + result_avals = output_avals + + result_types = util.flatten( + [mlir.aval_to_ir_types(aval) for aval in result_avals] + ) + + result = hlo.CustomCallOp( + result_types, + operands, + call_target_name=ir.StringAttr.get(call_target_name), + has_side_effect=ir.BoolAttr.get(has_side_effects), + api_version=mlir.i32_attr(2), + called_computations=ir.ArrayAttr.get([]), + backend_config=ir.StringAttr.get(""), + ) + # Store TF metadata in unregistered attribute + result.attributes["tf_metadata"] = ir.DictAttr.get(tf_metadata) + return result.results diff --git a/jax/experimental/jax2tf/tests/call_tf_test.py b/jax/experimental/jax2tf/tests/call_tf_test.py index bbb809213132..16cd3d33531c 100644 --- a/jax/experimental/jax2tf/tests/call_tf_test.py +++ b/jax/experimental/jax2tf/tests/call_tf_test.py @@ -12,10 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Tests for call_tf.""" +import base64 from functools import partial from typing import Callable, Dict, Tuple import unittest +from absl import logging from absl.testing import absltest from absl.testing import parameterized @@ -27,6 +29,8 @@ from jax.config import config from jax.experimental import jax2tf from jax.experimental.jax2tf.tests import tf_test_util +from jax._src.lib.mlir import ir +from tensorflow.core.framework import function_pb2 import numpy as np @@ -541,11 +545,11 @@ def fun_jax_pure(x): grad_jax_pure = jax.grad(grad_jax_pure) res_jax = grad_jax(np.float32(5.)) - print(f"Grad of {degree} degree is {res_jax}") + logging.info("Grad of %s degree is %s", degree, res_jax) self.assertAllClose(res_jax, grad_jax_pure(np.float32(5.))) def test_pmap(self): - print(f"Running test_pmap on {jax.local_device_count()} devices") + logging.info("Running test_pmap on %s devices", jax.local_device_count()) def plus_2_tf(x): return tf.math.add(2., x) @@ -695,8 +699,8 @@ def cos_tf_sin_jax(x): # Uses TF gradient for `cos_tf` and JAX gradient for `sin` jax.grad(cos_tf_sin_jax)(x) - print(jax.make_jaxpr(cos_tf_sin_jax)(x)) - print(jax.xla_computation(cos_tf_sin_jax)(x).as_hlo_text()) + logging.info(jax.make_jaxpr(cos_tf_sin_jax)(x)) + logging.info(jax.xla_computation(cos_tf_sin_jax)(x).as_hlo_text()) def test_tf_gather(self): """tf_gather gradient output is tf.IndexSlices.""" @@ -1303,6 +1307,131 @@ def fn(x): _, (g_f4_ft,) = tf_test_util.ComputeTfValueAndGrad(f4_tf, [x]) self.assertAllClose(jax.grad(f(4))(x), g_f4_ft.numpy()) + @classmethod + def _walk_stablehlo_operations(cls, op, cb): + """walk the stablehlo operation recursive with callback function.""" + cb(op) + for region in op.operation.regions: + for block in region: + for op in block: + cls._walk_stablehlo_operations(op, cb) + + def test_use_custom_call(self): + const = tf.Variable(0.0, dtype=tf.float32) + + @tf.function(jit_compile=True) + def tf_func_1(x): + return x * x + const + + @tf.function + def 
tf_func_2(x, y): + return tf_func_1(x) + y + + @tf.function + def tf_func_3(x, y, z): + return tf_func_2(x, y) + z, z + + x = jnp.array(3.0, dtype=jnp.float32) + y = jnp.array(3.0, dtype=jnp.float32) + z = jnp.array(5.0, dtype=jnp.float32) + output_shape_dtype = ( + jax.ShapeDtypeStruct(x.shape, x.dtype), + jax.ShapeDtypeStruct(z.shape, z.dtype), + ) + f_jax = jax.jit(jax2tf.call_tf(tf_func_3, use_custom_call=False)) + stablehlo_module = f_jax.lower(x, y, z).compiler_ir("stablehlo") + self.assertNotIn("stablehlo.custom_call", str(stablehlo_module)) + + f_jax = jax.jit( + jax2tf.call_tf( + tf_func_3, + use_custom_call=True, + output_shape_dtype=output_shape_dtype, + ) + ) + stablehlo_module = f_jax.lower(x, y, z).compiler_ir("stablehlo") + self.assertIn("stablehlo.custom_call", str(stablehlo_module)) + + concrete_function_flat_tf = tf_func_3.get_concrete_function(x, y, z) + expect_function_def_dict = {} + expect_function_def_dict[ + concrete_function_flat_tf.function_def.signature.name + ] = concrete_function_flat_tf.function_def + for k, v in concrete_function_flat_tf.graph._functions.items(): + expect_function_def_dict[k] = v.definition + + deserialized_function_def_dict = {} + + def extract_func_def(op): + if op.operation.name != "stablehlo.custom_call": + return + tf_metadata = ir.DictAttr(op.attributes["tf_metadata"]) + function_def_list = ir.ArrayAttr(tf_metadata["function_def_list"]) + + for fdef_str in function_def_list: + fdef_str_bytes = base64.b64decode(str(fdef_str)[1:-1]) + fdef = function_pb2.FunctionDef() + fdef.ParseFromString(fdef_str_bytes) + deserialized_function_def_dict.update({fdef.signature.name: fdef}) + + self._walk_stablehlo_operations(stablehlo_module, extract_func_def) + + for k, _ in expect_function_def_dict.items(): + self.assertEqual( + expect_function_def_dict[k], deserialized_function_def_dict[k] + ) + + def test_use_custom_call_non_compilable(self): + deserialized_function_def_dict = {} + + def extract_func_def(op): + if 
op.operation.name != "stablehlo.custom_call": + return + tf_metadata = ir.DictAttr(op.attributes["tf_metadata"]) + function_def_list = ir.ArrayAttr(tf_metadata["function_def_list"]) + + for fdef_str in function_def_list: + fdef_str_bytes = base64.b64decode(str(fdef_str)[1:-1]) + fdef = function_pb2.FunctionDef() + fdef.ParseFromString(fdef_str_bytes) + deserialized_function_def_dict.update({fdef.signature.name: fdef}) + + @tf.function(jit_compile=False) + def my_op(x): + return tf.py_function(np.sin, [x], tf.float32) + + x = jnp.ones([10], dtype=jnp.float32) + output_shape_dtype = jax.ShapeDtypeStruct(x.shape, x.dtype) + f_jax = jax.jit( + jax2tf.call_tf( + my_op, + use_custom_call=False, + output_shape_dtype=output_shape_dtype, + ) + ) + + f_jax = jax.jit( + jax2tf.call_tf( + my_op, + use_custom_call=True, + output_shape_dtype=output_shape_dtype, + ) + ) + stablehlo_module = f_jax.lower(x).compiler_ir("stablehlo") + concrete_function_flat_tf = my_op.get_concrete_function(x) + expect_function_def_dict = {} + expect_function_def_dict[ + concrete_function_flat_tf.function_def.signature.name + ] = concrete_function_flat_tf.function_def + for k, v in concrete_function_flat_tf.graph._functions.items(): + expect_function_def_dict[k] = v.definition + + self._walk_stablehlo_operations(stablehlo_module, extract_func_def) + for k, _ in expect_function_def_dict.items(): + self.assertEqual( + expect_function_def_dict[k], deserialized_function_def_dict[k] + ) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From 5b4325866d29cf022fb9ae207608d26b0091d48a Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 24 Mar 2023 12:32:53 -0700 Subject: [PATCH 64/65] Delete remote TPU support. TPU VMs are the only supported way to use TPUs as of JAX 0.4.0. 
PiperOrigin-RevId: 519211267 --- build/BUILD.bazel | 17 +- build/build.py | 11 +- build/build_wheel.py | 17 -- cloud_tpu_colabs/JAX_demo.ipynb | 24 --- cloud_tpu_colabs/Lorentz_ODE_Solver.ipynb | 28 --- cloud_tpu_colabs/Pmap_Cookbook.ipynb | 24 --- cloud_tpu_colabs/Wave_Equation.ipynb | 5 +- docs/jax-101/06-parallelism.ipynb | 18 +- docs/jax-101/06-parallelism.md | 13 +- jax/_src/lib/__init__.py | 6 - jax/_src/xla_bridge.py | 13 -- jax/tools/colab_tpu.py | 54 +++--- tests/notebooks/colab_tpu.ipynb | 222 ---------------------- 13 files changed, 30 insertions(+), 422 deletions(-) delete mode 100644 tests/notebooks/colab_tpu.ipynb diff --git a/build/BUILD.bazel b/build/BUILD.bazel index 114290dfff3d..b9f3de1ce3ea 100644 --- a/build/BUILD.bazel +++ b/build/BUILD.bazel @@ -23,18 +23,6 @@ licenses(["notice"]) # Apache 2 package(default_visibility = ["//visibility:public"]) -bool_flag( - name = "enable_remote_tpu", - build_setting_default = False, -) - -config_setting( - name = "remote_tpu_enabled", - flag_values = { - ":enable_remote_tpu": "True", - }, -) - py_binary( name = "build_wheel", srcs = ["build_wheel.py"], @@ -47,10 +35,7 @@ py_binary( "@xla//xla/python:xla_client", ] + if_windows([ "//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi.dll", - ]) + select({ - ":remote_tpu_enabled": ["@xla//xla/python/tpu_driver/client:py_tpu_client"], - "//conditions:default": [], - }) + if_cuda([ + ]) + if_cuda([ "//jaxlib/cuda:cuda_gpu_support", "@local_config_cuda//cuda:cuda-nvvm", ]) + if_rocm([ diff --git a/build/build.py b/build/build.py index 0ac3a58e553b..6a08d325f281 100755 --- a/build/build.py +++ b/build/build.py @@ -219,8 +219,7 @@ def write_bazelrc(*, python_bin_path, remote_build, cpu, cuda_compute_capabilities, rocm_amdgpu_targets, bazel_options, target_cpu_features, wheel_cpu, enable_mkl_dnn, enable_cuda, enable_nccl, - enable_tpu, enable_remote_tpu, enable_rocm, - enable_plugin_device): + enable_tpu, enable_rocm, enable_plugin_device): tf_cuda_paths = [] with 
open("../.jax_configure.bazelrc", "w") as f: @@ -286,8 +285,6 @@ def write_bazelrc(*, python_bin_path, remote_build, f.write("build --config=nonccl\n") if enable_tpu: f.write("build --config=tpu\n") - if enable_remote_tpu: - f.write("build --//build:enable_remote_tpu=true\n") if enable_rocm: f.write("build --config=rocm\n") if not enable_nccl: @@ -375,10 +372,6 @@ def main(): parser, "enable_tpu", help_str="Should we build with Cloud TPU VM support enabled?") - add_boolean_argument( - parser, - "enable_remote_tpu", - help_str="Should we build with remote Cloud TPU support enabled?") add_boolean_argument( parser, "enable_rocm", @@ -514,7 +507,6 @@ def main(): print("NCCL enabled: {}".format("yes" if args.enable_nccl else "no")) print("TPU enabled: {}".format("yes" if args.enable_tpu else "no")) - print("Remote TPU enabled: {}".format("yes" if args.enable_remote_tpu else "no")) print("ROCm enabled: {}".format("yes" if args.enable_rocm else "no")) if args.enable_rocm: @@ -542,7 +534,6 @@ def main(): enable_cuda=args.enable_cuda, enable_nccl=args.enable_nccl, enable_tpu=args.enable_tpu, - enable_remote_tpu=args.enable_remote_tpu, enable_rocm=args.enable_rocm, enable_plugin_device=args.enable_plugin_device, ) diff --git a/build/build_wheel.py b/build/build_wheel.py index a3f386c77016..ab4941b99af0 100644 --- a/build/build_wheel.py +++ b/build/build_wheel.py @@ -117,19 +117,6 @@ def patch_copy_xla_extension_stubs(dst_dir): f.write(src) -def patch_copy_tpu_client_py(dst_dir): - with open(r.Rlocation("xla/xla/python/tpu_driver/client/tpu_client.py")) as f: - src = f.read() - src = src.replace("from xla.python import xla_extension as _xla", - "from . import xla_extension as _xla") - src = src.replace("from xla.python import xla_client", - "from . import xla_client") - src = src.replace( - "from xla.python.tpu_driver.client import tpu_client_extension as _tpu_client", - "from . 
import tpu_client_extension as _tpu_client") - with open(os.path.join(dst_dir, "tpu_client.py"), "w") as f: - f.write(src) - def verify_mac_libraries_dont_reference_chkstack(): """Verifies that xla_extension.so doesn't depend on ____chkstk_darwin. @@ -250,10 +237,6 @@ def prepare_wheel(sources_path): copy_file("__main__/jaxlib/mlir/_mlir_libs/libjaxlib_mlir_capi.so", dst_dir=mlir_libs_dir) patch_copy_xla_extension_stubs(jaxlib_dir) - if exists("xla/xla/python/tpu_driver/client/tpu_client_extension.so"): - copy_to_jaxlib("xla/xla/python/tpu_driver/client/tpu_client_extension.so") - patch_copy_tpu_client_py(jaxlib_dir) - def edit_jaxlib_version(sources_path): version_regex = re.compile(r'__version__ = \"(.*)\"') diff --git a/cloud_tpu_colabs/JAX_demo.ipynb b/cloud_tpu_colabs/JAX_demo.ipynb index c0aa1980b5c8..6a6993f44ed2 100644 --- a/cloud_tpu_colabs/JAX_demo.ipynb +++ b/cloud_tpu_colabs/JAX_demo.ipynb @@ -1,29 +1,5 @@ { "cells": [ - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "hLEyhfMqmnrt" - }, - "source": [ - "## Colab JAX TPU Setup" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "5CTEVmyKmkfp" - }, - "outputs": [], - "source": [ - "import jax.tools.colab_tpu\n", - "jax.tools.colab_tpu.setup_tpu()" - ] - }, { "cell_type": "markdown", "metadata": { diff --git a/cloud_tpu_colabs/Lorentz_ODE_Solver.ipynb b/cloud_tpu_colabs/Lorentz_ODE_Solver.ipynb index 103c4be01233..1777d3d1ef79 100644 --- a/cloud_tpu_colabs/Lorentz_ODE_Solver.ipynb +++ b/cloud_tpu_colabs/Lorentz_ODE_Solver.ipynb @@ -25,30 +25,6 @@ "Alex Alemi" ] }, - { - "cell_type": "markdown", - "metadata": { - "id": "j-n2r719AKee", - "colab_type": "text" - }, - "source": [ - "# Cloud TPU Setup" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "ReFcuyaKAxh4", - "colab_type": "code", - "colab": {} - }, - "source": [ - "from jax.tools import colab_tpu\n", - "colab_tpu.setup_tpu()" - ], - 
"execution_count": 0, - "outputs": [] - }, { "cell_type": "markdown", "metadata": { @@ -76,10 +52,6 @@ "from jax import vmap, jit, grad, ops, lax, config\n", "from jax import random as jr\n", "\n", - "# The following is required to use TPU Driver as JAX's backend.\n", - "config.FLAGS.jax_xla_backend = \"tpu_driver\"\n", - "config.FLAGS.jax_backend_target = \"grpc://\" + os.environ['COLAB_TPU_ADDR']\n", - "\n", "import matplotlib as mpl\n", "import matplotlib.pyplot as plt\n", "import matplotlib.cm as cm\n", diff --git a/cloud_tpu_colabs/Pmap_Cookbook.ipynb b/cloud_tpu_colabs/Pmap_Cookbook.ipynb index 7a6ad5807f79..e5b4c4d61907 100644 --- a/cloud_tpu_colabs/Pmap_Cookbook.ipynb +++ b/cloud_tpu_colabs/Pmap_Cookbook.ipynb @@ -14,30 +14,6 @@ "accelerator": "TPU" }, "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "LpPtl0n4rg6L", - "colab_type": "text" - }, - "source": [ - "# Colab JAX TPU Setup" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "4DYY4Yyhq8vG", - "colab_type": "code", - "colab": {} - }, - "source": [ - "import jax.tools.colab_tpu\n", - "jax.tools.colab_tpu.setup_tpu()" - ], - "execution_count": 0, - "outputs": [] - }, { "cell_type": "markdown", "metadata": { diff --git a/cloud_tpu_colabs/Wave_Equation.ipynb b/cloud_tpu_colabs/Wave_Equation.ipynb index 15966551fecf..0591739191e0 100644 --- a/cloud_tpu_colabs/Wave_Equation.ipynb +++ b/cloud_tpu_colabs/Wave_Equation.ipynb @@ -42,10 +42,7 @@ "outputs": [], "source": [ "# Grab other packages for this demo.\n", - "!pip install -U -q Pillow moviepy proglog scikit-image\n", - "\n", - "import jax.tools.colab_tpu\n", - "jax.tools.colab_tpu.setup_tpu()" + "!pip install -U -q Pillow moviepy proglog scikit-image" ] }, { diff --git a/docs/jax-101/06-parallelism.ipynb b/docs/jax-101/06-parallelism.ipynb index 6699106d14e6..9211efc2a7be 100644 --- a/docs/jax-101/06-parallelism.ipynb +++ b/docs/jax-101/06-parallelism.ipynb @@ -25,23 +25,9 @@ "id": "7mCgBzix2fd3" }, "source": [ - "## Colab TPU 
Setup\n", + "## TPU Setup\n", "\n", - "If you're running this code in Google Colab, be sure to choose *Runtime*→*Change Runtime Type* and choose **TPU** from the Hardware Accelerator menu.\n", - "\n", - "Once this is done, you can run the following to set up the Colab TPU for use with JAX:" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "id": "hn7HtC2QS92b" - }, - "outputs": [], - "source": [ - "import jax.tools.colab_tpu\n", - "jax.tools.colab_tpu.setup_tpu()" + "This notebook requires multiple accelerators and we recommend running it using Kaggle TPU VMs." ] }, { diff --git a/docs/jax-101/06-parallelism.md b/docs/jax-101/06-parallelism.md index f48c72ef6cde..03688f08d551 100644 --- a/docs/jax-101/06-parallelism.md +++ b/docs/jax-101/06-parallelism.md @@ -27,18 +27,9 @@ Conceptually, this is not very different from vectorisation, where the same oper +++ {"id": "7mCgBzix2fd3"} -## Colab TPU Setup +## TPU Setup -If you're running this code in Google Colab, be sure to choose *Runtime*→*Change Runtime Type* and choose **TPU** from the Hardware Accelerator menu. - -Once this is done, you can run the following to set up the Colab TPU for use with JAX: - -```{code-cell} ipython3 -:id: hn7HtC2QS92b - -import jax.tools.colab_tpu -jax.tools.colab_tpu.setup_tpu() -``` +This notebook requires multiple accelerators and we recommend running it using Kaggle TPU VMs. +++ {"id": "gN6VbcdRTcdE"} diff --git a/jax/_src/lib/__init__.py b/jax/_src/lib/__init__.py index 3144402c80bc..121cbbf3ac2d 100644 --- a/jax/_src/lib/__init__.py +++ b/jax/_src/lib/__init__.py @@ -114,12 +114,6 @@ def _xla_gc_callback(*args): # Version number for MLIR:Python APIs, provided by jaxlib. mlir_api_version = xla_client.mlir_api_version -try: - from jaxlib import tpu_client as tpu_driver_client # pytype: disable=import-error -except: - tpu_driver_client = None # type: ignore - - # TODO(rocm): check if we need the same for rocm. 
cuda_path: Optional[str] cuda_path = os.path.join(os.path.dirname(jaxlib.__file__), "cuda") diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index ad814b9c7a01..9aed9638d544 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -34,7 +34,6 @@ from jax._src import lib from jax._src import distributed from jax._src.config import flags, bool_env, config, int_env -from jax._src.lib import tpu_driver_client from jax._src.lib import xla_client from jax._src import traceback_util from jax._src import util @@ -164,16 +163,6 @@ def get_compile_options( # Backends -def _make_tpu_driver_client() -> Optional[xla_client.Client]: - if tpu_driver_client is None: - logger.info("Remote TPU is not linked into jax; skipping remote TPU.") - return None - if FLAGS.jax_backend_target is None: - logger.info("No --jax_backend_target was provided; skipping remote TPU.") - return None - return tpu_driver_client.TpuBackend.create(worker=FLAGS.jax_backend_target) - - def tpu_client_timer_callback(timer_secs: float) -> Optional[xla_client.Client]: def _log_warning(): warnings.warn( @@ -218,8 +207,6 @@ def register_backend_factory(name: str, factory: BackendFactory, *, register_backend_factory('cpu', partial(xla_client.make_cpu_client, use_tfrt=True), priority=0) -register_backend_factory('tpu_driver', _make_tpu_driver_client, - priority=100) def make_gpu_client( diff --git a/jax/tools/colab_tpu.py b/jax/tools/colab_tpu.py index 4b9a02f5a094..09cf2df43510 100644 --- a/jax/tools/colab_tpu.py +++ b/jax/tools/colab_tpu.py @@ -14,34 +14,26 @@ """Utilities for running JAX on Cloud TPUs via Colab.""" -import requests -import os - -from jax.config import config - -TPU_DRIVER_MODE = 0 - - -def setup_tpu(tpu_driver_version='tpu_driver_20230216'): - """Sets up Colab to run on TPU. - - Note: make sure the Colab Runtime is set to Accelerator: TPU. - - Args - ---- - tpu_driver_version : (str) specify the version identifier for the tpu driver. 
- Set to "tpu_driver_nightly" to use the nightly tpu driver build. - """ - global TPU_DRIVER_MODE - - if not TPU_DRIVER_MODE: - colab_tpu_addr = os.environ['COLAB_TPU_ADDR'].split(':')[0] - url = f'http://{colab_tpu_addr}:8475/requestversion/{tpu_driver_version}' - requests.post(url) - TPU_DRIVER_MODE = 1 - - # The following is required to use TPU Driver as JAX's backend. - config.FLAGS.jax_xla_backend = "tpu_driver" - config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR'] - # TODO(skyewm): Remove this after SPMD is supported for colab tpu. - config.update('jax_array', False) +import textwrap + +message = """ +As of JAX 0.4.0, JAX only supports TPU VMs, not the older Colab TPUs. + +We recommend trying Kaggle Notebooks +(https://www.kaggle.com/code, click on "New Notebook" near the top) which offer +TPU VMs. You have to create an account, log in, and verify your account to get +accelerator support. +Once you do that, there's a new "TPU 1VM v3-8" accelerator option. This gives +you a TPU notebook environment similar to Colab, but using the newer TPU VM +architecture. This should be a less buggy, more performant, and overall better +experience than the older TPU node architecture. + +It is also possible to use Colab together with a self-hosted Jupyter kernel +running on a Cloud TPU VM. See +https://research.google.com/colaboratory/local-runtimes.html +for details. +""" + +def setup_tpu(tpu_driver_version=None): + """Returns an error. 
Do not use.""" + raise RuntimeError(textwrap.dedent(message)) diff --git a/tests/notebooks/colab_tpu.ipynb b/tests/notebooks/colab_tpu.ipynb deleted file mode 100644 index de1afce36688..000000000000 --- a/tests/notebooks/colab_tpu.ipynb +++ /dev/null @@ -1,222 +0,0 @@ -{ - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "JAX Colab TPU Test", - "provenance": [], - "collapsed_sections": [] - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "accelerator": "TPU" - }, - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "view-in-github", - "colab_type": "text" - }, - "source": [ - "\"Open" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "WkadOyTDCAWD", - "colab_type": "text" - }, - "source": [ - "# JAX Colab TPU Test\n", - "\n", - "This notebook is meant to be run in a [Colab](http://colab.research.google.com) TPU runtime as a basic check for JAX updates." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "_tKNrbqqBHwu", - "colab_type": "code", - "outputId": "bf0043b0-6f2b-44e4-9822-4f426b3d158e", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 68 - } - }, - "source": [ - "import jax\n", - "import jaxlib\n", - "\n", - "!cat /var/colab/hostname\n", - "print(jax.__version__)\n", - "print(jaxlib.__version__)" - ], - "execution_count": 1, - "outputs": [ - { - "output_type": "stream", - "text": [ - "tpu-s-2dna7uebo6z96\n", - "0.1.64\n", - "0.1.45\n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "DzVStuLobcoG", - "colab_type": "text" - }, - "source": [ - "## TPU Setup" - ] - }, - { - "cell_type": "code", - "metadata": { - "colab_type": "code", - "id": "IXF0_gNCRH08", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 34 - } - }, - "source": [ - "import jax.tools.colab_tpu\n", - "jax.tools.colab_tpu.setup_tpu()" - ], - "execution_count": 2 - }, - { - "cell_type": "markdown", - "metadata": { - "id": 
"oqEG21rADO1F", - "colab_type": "text" - }, - "source": [ - "## Confirm Device" - ] - }, - { - "cell_type": "code", - "metadata": { - "colab_type": "code", - "id": "8BwzMYhKGQj6", - "outputId": "d51b7f21-d300-4420-8c5c-483bace8617d", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 34 - } - }, - "source": [ - "from jaxlib import tpu_client_extension\n", - "import jax\n", - "key = jax.random.PRNGKey(1701)\n", - "arr = jax.random.normal(key, (1000,))\n", - "device = arr.device_buffer.device()\n", - "print(f\"JAX device type: {device}\")\n", - "assert isinstance(device, tpu_client_extension.TpuDevice), \"unexpected JAX device type\"" - ], - "execution_count": 5, - "outputs": [ - { - "output_type": "stream", - "text": [ - "JAX device type: TPU_0(host=0,(0,0,0,0))\n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "z0FUY9yUC4k1", - "colab_type": "text" - }, - "source": [ - "## Matrix Multiplication" - ] - }, - { - "cell_type": "code", - "metadata": { - "colab_type": "code", - "id": "eXn8GUl6CG5N", - "outputId": "9954a064-ef8b-4db3-aad7-85d07b50f678", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 34 - } - }, - "source": [ - "import jax\n", - "import numpy as np\n", - "\n", - "# matrix multiplication on GPU\n", - "key = jax.random.PRNGKey(0)\n", - "x = jax.random.normal(key, (3000, 3000))\n", - "result = jax.numpy.dot(x, x.T).mean()\n", - "print(result)" - ], - "execution_count": 6, - "outputs": [ - { - "output_type": "stream", - "text": [ - "1.021576\n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jCyKUn4-DCXn", - "colab_type": "text" - }, - "source": [ - "## XLA Compilation" - ] - }, - { - "cell_type": "code", - "metadata": { - "colab_type": "code", - "id": "2GOn_HhDPuEn", - "outputId": "a4384c55-41fb-44be-845d-17b86b152068", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 51 - } - }, - "source": [ - "@jax.jit\n", - "def 
selu(x, alpha=1.67, lmbda=1.05):\n", - " return lmbda * jax.numpy.where(x > 0, x, alpha * jax.numpy.exp(x) - alpha)\n", - "x = jax.random.normal(key, (5000,))\n", - "result = selu(x).block_until_ready()\n", - "print(result)" - ], - "execution_count": 8, - "outputs": [ - { - "output_type": "stream", - "text": [ - "[ 0.34676817 -0.7532211 1.7060809 ... 2.120809 -0.42622015\n", - " 0.13093244]\n" - ], - "name": "stdout" - } - ] - } - ] -} From f83a3474f99bae038fe39bcfd44ede38fe144566 Mon Sep 17 00:00:00 2001 From: archis Date: Fri, 24 Mar 2023 13:06:08 -0700 Subject: [PATCH 65/65] docs --- docs/jax.scipy.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/jax.scipy.rst b/docs/jax.scipy.rst index cad15d09fc66..2a5546e2e6f6 100644 --- a/docs/jax.scipy.rst +++ b/docs/jax.scipy.rst @@ -85,6 +85,7 @@ jax.scipy.signal istft stft welch + hilbert jax.scipy.sparse.linalg -----------------------