Add same_lower for lax conv op (#1) #14990
Closed
copybara-service bot
pushed a commit
that referenced
this pull request
Mar 20, 2023
PiperOrigin-RevId: 518055103
copybara-service bot
pushed a commit
that referenced
this pull request
Mar 22, 2023
PiperOrigin-RevId: 518055103
copybara-service bot
pushed a commit
to openxla/xla
that referenced
this pull request
Mar 22, 2023
PiperOrigin-RevId: 518055103
copybara-service bot
pushed a commit
that referenced
this pull request
Mar 23, 2023
PiperOrigin-RevId: 518055103
copybara-service bot
pushed a commit
that referenced
this pull request
Mar 23, 2023
PiperOrigin-RevId: 518984018
clrpackages
pushed a commit
to clearlinux-pkgs/pypi-jax
that referenced
this pull request
Apr 4, 2023
…0.4.8 Adam Paszke (1): Optimize canonicalize_shape Anish Tondwalkar (12): geqrf_p and householder_product_p directly call custom_calls Eigh primitive is now a customcall [jax2tf] Add back_compat test for LuDecomposition LuDecomposition moved from fallback path to custom_call Refactor special functions into their own module. Migrate igamma_p off xla_fallback [jaxlib] fix build w/ depenency on stablehlo_serialization Migrate igammac_p off xla_fallback path Migrate igamma_grad_a_p off xla_fallback Migrate random_gamma_grad off xla_fallback Migrate besseli0e off xla_fallback Migrate regularized_incomplete_beta_p off xla_fallback Blake Hechtman (1): [LAX:RBG] Allow any type to RngBitGenerator. BF16 values are heavily quantized for long distributions which leads to failing the distribution test but in reality the distributions match. Colin Gaffney (1): Set coordinator address to allow it to later be used to initialize OCDBT coordinator server. Allow user to pass ts.Context when serializing or deserializing. Cristian Garcia (2): Fix typing for register_pytree_with_keys add trailing-whitespace pre-commit hook Emilio Cota (2): math_benchmark: add dot op math_benchmark: add --set_env flag Etienne Pot (1): Fix isinstance(k, PRNGKeyArray) on PRNGKeyArray subclasses Frederic Bastien (6): Add reference to the C code where I was looking for them. Also add some high-level description of what is needed. Fix inspect_array_sharding with grad. WAR the dependency issue in the nightly CI container. Add missing file WAR ssh timeout like: remove another dependency not currently needed. 
George Necula (29): Stop using version 1 of XlaCallModuleOp [jax2tf] Work around a bug in jax2tf tests, to unblock the test [jax2tf] Refactor the JAX native lowering to separate out the TF-specific parts [jax2tf] Fix platform enforcement for native serialization [jax2tf] Refactor the native lowering [jax2tf] Documentation for the native serialization mode [jax2tf] Rename experimental_native_lowering to native_serialization [jax2tf] Add the first version of a custom call backwards compatibility test [jax2tf] Add tests for approx_top_k. [jax2tf] Refactor the backwards compatibility tests. [jax2tf] Add backwards compatibility tests for lax.eigh custom calls [jax2tf] Add Sharding backward compatibility test [jax2tf] Minor improvement in an error message [jax2tf] Re-enable fixed tests [jax2tf] Improvements to the documentation [jax2tf] Fix test that requires non-native serialization [jax2tf] Add backward compatibility tests for qr custom calls [jax2tf] Update CHANGELOG for native serialization. [jax2tf] Ensure that the gradient function is serialized natively. [jax2tf] Clean up the jax2tf sharding_tests. [jax2tf] Fix grad of pjit in native lowering. [jax2tf] Fix tests broken by upgrade of XlaCallModule [jax2tf] Minor addition to the documentation [jax2tf] Create a jax_export library with JAX-only pieces for native serialization [jax2tf] A simple failing test on TPU with native serialization [jax2tf] Turn an error into a warning with native serialization [shape_poly] Fixed bug with dimension variables in unused args [shape_poly] Refactor the computation of the dimension variables in native serialization Remove jax2tf experimental_native_lowering. Ivy Zheng (1): Add an optional `flatten_func` argument to custom node registration even when `flatten_with_keys` is given, for better perf for those in need. 
Jake VanderPlas (32): Fix jnp.sort & jnp.vdot in no-jit mode jnp.argmin/max: correctly handle out argument Add xla garbage collection to gc.callback jnp.argsort: fix annotations & behavior under disable_jit checkify_test: avoid passing argument to at[i].get() jnp.ndarray.at: deprecate passing additional arguments by position [typing] better annotations for jnp.ndarray.at Remove leading underscores in jax._src.numpy.util [typing] add type annotations to index_update code [sparse] add BCOO lowering for div jnp.arange: better validation of inputs README: improve Colab TPU installation discussion Add regression test for #4780 jnp.einsum: make signature match documentation jax.numpy reductions: validate axis for scalar input Improve error for tolist() and tobytes() on tracer objects DOC: remove jax 0.4.1 banner from index page Sharp bits: refer to ndarray.at in out-of-bound indexing discussion Improve error for indexing with string jnp.mean: fix incorrect return value for large arrays Document ShapeDtypeStruct jax.random: remove scale from wald function DOC: add formulae for distributions in jax.random internal: refactor array methods into separate private submodule Fix mypy issue in jax/experimental/jet.py lax_numpy: move quantile-based functions to reductions.py jax.typing: recommend instance check in Python 3.10 or newer [sparse] fix coo efficiency warning CI: add numpy & scipy to mypy env jax.scipy.linalg.expm: support batched inputs Add deprecation warnings for several top-level jax imports Add minimal pyproject.toml specifying build system Jake Vanderplas (1): Copybara import of the project: Jieying Luo (1): [PJRT C API] Add parsing PJRT client create options from json file. John QiangZhang (2): Add padding option "SAME_LOWER" for ticket jax-ml/jax#14990 [1/n] store embedded tf.graph to stablehlo.custom_call Kevin Gleason (2): Add support for StableHLO Serialized Portable Artifacts in JAX2TF. 
Improve handling of dynamic shapes in jax native serialization Mark Sandler (1): Fixes broken examples, and (invalid) comment for PartitionSpec Matthew Johnson (14): [shard_map] bug fix: extend axis env in partial_eval_custom rule update docs to remove stale reference to laziness optimization use Partial to make ravel_pytree unflatteners jit-friendly [dynamic-shapes] don't require buf objects have dtype attribute add test for #7155, fixes #7155 [pytrees] fix function underlying tree-flattening with keys make mlir arg and result names work with static_argnums/argnames make mlir arg and result names work with pmap separate register_pytree_node and register_pytree_with_keys tests add experimental jax_log_checkpoint_residuals option enable pjit recursive typechecking improve scan error messages fix jax.Array.round() [dynamic-shapes] make dynamic shape staging-to-jaxpr work with pjit Misha (1): Fix loc and scale parameters in scipy.logistic. Add CDF and SF for several distributions. Neil Girdhar (1): Correct register_pytree_with_keys annotation Parker Schuh (12): Move PyBuffer methods used by PyArray to c++. Rollforward with fixes: Remove _execute_replicated from UnloadedMeshExecutable.load since it is not required anymore for jit(pmap) cases [Rollforward] Move PyBuffer methods used by PyArray to c++. Hide jit-of-pmap warning. Use batched_device_put for pxla.shard_sharded_device_array_slow_path. Add benchmarks for np.array, device_put, and _arrays. Implement copy_to_host_async and _value with a single call to Avoid extra construction of ShapedArray in array __getitem__. Avoid extra construction of ShapedArray in array __getitem__. Delete the C++ GetEnableJaxArray() flag. Redefine `compile_and_serialize` as `serialize(lowered.compile())`. Guard ArrayImpl checks by xla_extension_version. Peter Hawkins (46): Split _src modules cloud_tpu_init, lazy_loader, path, monitoring into their own pytype_library Bazel targets. Redefine jnp.DeviceArray as jax.Array. 
Increase precision of detrend test on TPU. Split source_info_util into its own Bazel target. Split Mesh and ResourceEnv into a new module jax._src.mesh. [JAX] Split _src/xla_bridge.py into a separate Bazel target. Split _src/profiler into a separate BUILD target. Split _src files custom_api_util, deprecations, effects and environment_info into separate Bazel targets. No changes. Move _src/tree_util.py into a separate Bazel target. Split _src/mesh into a separate Bazel target. Split Sharding subclasses out of _src/sharding.py into _src/sharding_impls.py Split basearray into separate Bazel module. Switch JAX to use the OpenXLA repository. Fix build breakage from OpenXLA switch. Improve type of jnp.mgrid[...]. Make PRNG seed types more liberal in what they accept. Relax the argument type annotation of dynamic_index_in_dim. Make Tracer types on JaxprTrace more precise. Add device_buffers property to jax.Array type. Improve pytype inference for Sharding type. Mark jax.numpy.DeviceArray as deprecated. Use jax.Array instead. Relax argument type annotation for lax.dynamic_slice. Remove references to jax.config.jax_array, which is always True at head. Relax type of the `values` argument to .at[...].set(...) and friends. [JAX] Check for AttributeError from getattr(), not KeyError. [JAX] Delete ShardedDeviceArray. Mark jax.interpreters.pxla.ShardedDeviceArray as deprecated. Relax type annotations on lax slicing functions. Fix mypy failures in jax2tf. Revert: `custom_vjp` symbolic zeros support [XLA:Python] Change JAX and the XLA Python extension to get NumPy bfloat16/float8 types from ml_dtypes. [XLA:Python] Allow passing ExecutableBuildOptions to outfeed receiver. Move jax._src.typing into a separate Bazel target. Increase minimum NumPy version to 1.21. Delete remote TPU support. Suppress mypy warnings about missing imports. Split dtype argument from other arguments in special functions. Add support for using pip-installed CUDA wheels. 
Fix duplicate definition of 'cuda' extra in setup.py. Use pytype_strict_library() in Bazel build rules. Move jax.interpreters.partial_eval to jax._src.interpreters.partial_eval. Split core.py and several files in an SCC with it into a separate Bazel build target. Update the CUDA installation instructions. Add version constraints to CUDA pip wheel dependencies. Recommend --local_test_jobs in bazel test command line on GPU. Ravin Kumar (2): Update user_guides.rst Fix hessian llnk Rebecca Chen (2): Silence some pytype errors. Silence some pytype errors. Roy Frostig (1): `custom_vjp` symbolic zeros support Ruoxin Sang (1): Fix typo "compileable"->"compilable". Sharad Vikram (3): Add print statement to help debug spurious test failure Copy seq_lengths before creating descriptor Add jaxlib version guard for rnn test Shawn Presser (1): autodidax: fix jaxpr_subcomp return type annotation Skye Wanderman-Milne (5): Update versions and changelog for jax + jaxlib 0.4.6 release Remove PJRT C API bypass. Remove 'pjrt_c_api_unimplemented' pytest mark. Bump minimum jaxlib version from 0.4.6 to 0.4.7. Turn on PJRT C API by default. Yash Katariya (52): Fix the type annotation on shard_args Use shard_args and global_result_handlers since the `aval_to_result_handler` and `dispatch.device_put` will be removed soon. Remove _execute_replicated from UnloadedMeshExecutable.load since it is not required anymore for jit(pmap) cases Make `.devices()` a set rather than a list because the code looks at sharding.device_set [Rollback] Remove _execute_replicated from UnloadedMeshExecutable.load since it is not required anymore for jit(pmap) cases Fix copy_array_to_devices_with_sharding by making it take a committed argument so that the Array created has the right semantics. [Rollback] Move PyBuffer methods used by PyArray to c++. 
Rollback the device_put_sharded and device_put_replicated change of using batched_device_put [Fix forward] Rollback the device_put_sharded and device_put_replicated change of using batched_device_put Go via `_rewriting_take` if reducing on a dim for __getitem__ so that we can preserve the sharding and run it via XLA which will do sharding propagation. Bump minimum jaxlib version to 0.4.6 which means xla_extension_version == 137 and mlir_api_version == 45 Don't use the dispatch.device_put path since it is deprecated and will be removed soon. Add Cuda 12 build configs to bazelrc If the bufs are on the same devices passed to batched_device_put then create an Array directly rather than going via xc.batched_device_put. Fixing the transfer guard problem should help in removing this workaround too. Fix the usage of device_put_handlers since that is deprecated. Use batched_device_put instead Make `pxla.replicate` go via batched_device_put rather than `pxla.device_put`. Delete jax_jit.device_put since it is not used anywhere except for 1 test. Replace it with batched_device_put batched_device_put was fixed to correctly use the x64 flag so there is no need to canonicalize dtype anymore. Remove GDA tests from JAX since GDA is deprecated. There are jax.Array tests for all the corresponding GDA tests Remove GDA from JAX since jax.Array is the default type and cannot be disabled anymore as per https://jax.readthedocs.io/en/latest/jax_array_migration.html#how-can-i-disable-jax-array-for-now Remove the check for `if not isinstance(old_token, array.ArrayImpl)` since py_executable always return jax.Arrays Remove references to `jax.config.jax_jit_pjit_api_merge`, which is always True at head. Remove the helper jit functions from api.py [Jax cleanup] Remove dispatch.result_handlers since they are not used. Clean up pjit after jax.Array Remove in_positional_semantics and out_positional_semantics from xmap Error if jax_array or jax_jit_pjit_api_merge is set to False. 
Remove _PositionalSemantics class since it is not used anymore because jax.Array always has GLOBAL semantics Improve the empty mesh error message raised in pjit if mesh is not used and Pspec is passed to in|out_shardings Add benchmarks for accessing index and replica id in addressable_shards Optimize accessing `index` and `replica_id` of Remove pxla.OutputType enum class now that the only output can be jax.Array Remove C++ jit support since it has been replaced with Pjit. Keep `CompiledFunction` alive as a shim which cannot be instantiated but will work for isinstance checks. Raise a better error message when there is a device assignment mismatch via the apply_primitive route. Remove the config.jax_array and jax_jit_pjit_api_merge flag usage since those are always True Deprecated xla_call_p since it has been replaced with pjit.pjit_p After the SPMD bug fix, always take the _rewriting_take route for getitem instead of bouncing to host. If each host has the full value of the Array, allow fetching it to host. Fixes #15162 Add SDA deprecation warning to pytest.ini Add `src` argument to device_put as an experimental arg Delete benchmark and pmap_benchmark files as they are legacy and replaced with api_benchmark.py Temporarily fix the compilation cache test which is failing on latest jaxlib release Add deprecation warning for FROM_GDA usage since that argument is not required anymore. Prepare for jax and jaxlib 0.4.7 release Update the commit in workspace too Finish jax and jaxlib 0.4.7 release Remove Cuda 11.4 support. JAX from 0.4.8 release will support cuda 11.8 and cuda 12 Deprecate FROM_GDA and remove its support from pjit's code since jax.Array inside pjit has sharding inference capabilities by default. Remove MeshPspecSharding since it has been more than 3 months since it was deprecated (Nov 2, 2022). The replacement name is NamedSharding. 
Jax 0.4.7 has been released so assert that length of warnings is 1 in test_cache_read_warning Remove global_arg_shapes from pmap since it was only used for sharded_jit and sharded_jit was removed from JAX a long time ago Yu-Hang 'Maxin' Tang (1): add build option to create editable jaxlib Yuanzhong Xu (1): Enable more mesh shape assignment jax authors (3): jnp.einsum is parametrizable with dot_general. fix typo: "one of more" -> "one or more" Internal Code Change jiayaobo (3): add rayleigh distribution to random.py add wald random generator remove scale in wald docstring mehdiataei (1): Fixed spelling error in msgs pizzud (1): lazy_loader_module: Move to new internal_test_util directory. vfdev (1): Typo fix in ResizeMethod docstring, scale.py
See ticket #14989
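For readers unfamiliar with the padding mode this PR adds, here is a rough sketch of how "SAME" and "SAME_LOWER" are usually told apart. It assumes the common convention that both modes pad so the output size is ceil(input / stride), and that when the total padding is odd, "SAME" puts the extra element at the end while "SAME_LOWER" puts it at the start. The helper name `same_pads` is hypothetical and is not taken from this PR's diff.

```python
def same_pads(in_size, window, stride, lower=False):
    """Return (pad_low, pad_high) for one spatial dimension.

    Hypothetical helper for illustration only, not the code from this PR.
    Assumes the output size is ceil(in_size / stride).
    """
    out_size = -(-in_size // stride)  # integer ceiling division
    total = max(0, (out_size - 1) * stride + window - in_size)
    small = total // 2
    large = total - small
    # SAME: the extra element (if any) goes on the high side.
    # SAME_LOWER: the extra element (if any) goes on the low side.
    return (large, small) if lower else (small, large)


if __name__ == "__main__":
    print(same_pads(5, 2, 1))              # SAME
    print(same_pads(5, 2, 1, lower=True))  # SAME_LOWER
```

The two modes only differ when the total padding is odd; with an even total they produce identical pads.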