
maxwillzq
Contributor

See ticket #14989

copybara-service bot pushed a commit that referenced this pull request Mar 20, 2023
copybara-service bot pushed a commit that referenced this pull request Mar 20, 2023
copybara-service bot pushed a commit that referenced this pull request Mar 22, 2023
copybara-service bot pushed a commit that referenced this pull request Mar 22, 2023
copybara-service bot pushed a commit to openxla/xla that referenced this pull request Mar 22, 2023
copybara-service bot pushed a commit that referenced this pull request Mar 23, 2023
copybara-service bot pushed a commit that referenced this pull request Mar 23, 2023
copybara-service bot pushed a commit that referenced this pull request Mar 23, 2023
copybara-service bot pushed a commit that referenced this pull request Mar 23, 2023
@maxwillzq maxwillzq closed this Mar 29, 2023
clrpackages pushed a commit to clearlinux-pkgs/pypi-jax that referenced this pull request Apr 4, 2023
…0.4.8

Adam Paszke (1):
      Optimize canonicalize_shape

Anish Tondwalkar (12):
      geqrf_p and householder_product_p directly call custom_calls
      Eigh primitive is now a customcall
      [jax2tf] Add back_compat test for LuDecomposition
      LuDecomposition moved from fallback path to custom_call
      Refactor special functions into their own module.
      Migrate igamma_p off xla_fallback
      [jaxlib] fix build w/ dependency on stablehlo_serialization
      Migrate igammac_p off xla_fallback path
      Migrate igamma_grad_a_p off xla_fallback
      Migrate random_gamma_grad off xla_fallback
      Migrate besseli0e off xla_fallback
      Migrate regularized_incomplete_beta_p off xla_fallback

Blake Hechtman (1):
      [LAX:RBG] Allow any type to RngBitGenerator. BF16 values are heavily quantized for long distributions which leads to failing the distribution test but in reality the distributions match.

Colin Gaffney (1):
      Set coordinator address to allow it to later be used to initialize OCDBT coordinator server. Allow user to pass ts.Context when serializing or deserializing.

Cristian Garcia (2):
      Fix typing for register_pytree_with_keys
      add trailing-whitespace pre-commit hook

Emilio Cota (2):
      math_benchmark: add dot op
      math_benchmark: add --set_env flag

Etienne Pot (1):
      Fix isinstance(k, PRNGKeyArray) on PRNGKeyArray subclasses

Frederic Bastien (6):
      Add reference to the C code where I was looking for them. Also add some high-level description of what is needed.
      Fix inspect_array_sharding with grad.
      WAR the dependency issue in the nightly CI container.
      Add missing file
      WAR ssh timeout like:
      remove another dependency not currently needed.

George Necula (29):
      Stop using version 1 of XlaCallModuleOp
      [jax2tf] Work around a bug in jax2tf tests, to unblock the test
      [jax2tf] Refactor the JAX native lowering to separate out the TF-specific parts
      [jax2tf] Fix platform enforcement for native serialization
      [jax2tf] Refactor the native lowering
      [jax2tf] Documentation for the native serialization mode
      [jax2tf] Rename experimental_native_lowering to native_serialization
      [jax2tf] Add the first version of a custom call backwards compatibility test
      [jax2tf] Add tests for approx_top_k.
      [jax2tf] Refactor the backwards compatibility tests.
      [jax2tf] Add backwards compatibility tests for lax.eigh custom calls
      [jax2tf] Add Sharding backward compatibility test
      [jax2tf] Minor improvement in an error message
      [jax2tf] Re-enable fixed tests
      [jax2tf] Improvements to the documentation
      [jax2tf] Fix test that requires non-native serialization
      [jax2tf] Add backward compatibility tests for qr custom calls
      [jax2tf] Update CHANGELOG for native serialization.
      [jax2tf] Ensure that the gradient function is serialized natively.
      [jax2tf] Clean up the jax2tf sharding_tests.
      [jax2tf] Fix grad of pjit in native lowering.
      [jax2tf] Fix tests broken by upgrade of XlaCallModule
      [jax2tf] Minor addition to the documentation
      [jax2tf] Create a jax_export library with JAX-only pieces for native serialization
      [jax2tf] A simple failing test on TPU with native serialization
      [jax2tf] Turn an error into a warning with native serialization
      [shape_poly] Fixed bug with dimension variables in unused args
      [shape_poly] Refactor the computation of the dimension variables in native serialization
      Remove jax2tf experimental_native_lowering.

Ivy Zheng (1):
      Add an optional `flatten_func` argument to custom node registration even when `flatten_with_keys` is given, for better perf for those in need.

Jake VanderPlas (32):
      Fix jnp.sort & jnp.vdot in no-jit mode
      jnp.argmin/max: correctly handle out argument
      Add xla garbage collection to gc.callback
      jnp.argsort: fix annotations & behavior under disable_jit
      checkify_test: avoid passing argument to at[i].get()
      jnp.ndarray.at: deprecate passing additional arguments by position
      [typing] better annotations for jnp.ndarray.at
      Remove leading underscores in jax._src.numpy.util
      [typing] add type annotations to index_update code
      [sparse] add BCOO lowering for div
      jnp.arange: better validation of inputs
      README: improve Colab TPU installation discussion
      Add regression test for #4780
      jnp.einsum: make signature match documentation
      jax.numpy reductions: validate axis for scalar input
      Improve error for tolist() and tobytes() on tracer objects
      DOC: remove jax 0.4.1 banner from index page
      Sharp bits: refer to ndarray.at in out-of-bound indexing discussion
      Improve error for indexing with string
      jnp.mean: fix incorrect return value for large arrays
      Document ShapeDtypeStruct
      jax.random: remove scale from wald function
      DOC: add formulae for distributions in jax.random
      internal: refactor array methods into separate private submodule
      Fix mypy issue in jax/experimental/jet.py
      lax_numpy: move quantile-based functions to reductions.py
      jax.typing: recommend instance check in Python 3.10 or newer
      [sparse] fix coo efficiency warning
      CI: add numpy & scipy to mypy env
      jax.scipy.linalg.expm: support batched inputs
      Add deprecation warnings for several top-level jax imports
      Add minimal pyproject.toml specifying build system

Jake Vanderplas (1):
      Copybara import of the project:

Jieying Luo (1):
      [PJRT C API] Add parsing PJRT client create options from json file.

John QiangZhang (2):
      Add padding option "SAME_LOWER" for ticket jax-ml/jax#14990
      [1/n] store embedded tf.graph to stablehlo.custom_call
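The "SAME_LOWER" padding option differs from "SAME" only in where the extra element goes when the total padding is odd. The sketch below is a pure-Python illustration of that convention for one spatial dimension, assuming the ONNX-style rule (SAME puts the larger half at the end, SAME_LOWER at the beginning). The helper `same_padding` is hypothetical and is not JAX's implementation.

```python
import math


def same_padding(in_size, window, stride, lower=False):
    """Hypothetical helper: (lo, hi) padding for one dim under SAME / SAME_LOWER.

    Both modes keep out_size = ceil(in_size / stride); they differ only in
    which side receives the extra element when the total padding is odd.
    """
    out_size = math.ceil(in_size / stride)
    # Total padding needed so the last window still fits inside the padded input.
    pad_total = max((out_size - 1) * stride + window - in_size, 0)
    small = pad_total // 2
    large = pad_total - small
    # SAME: larger half at the end (hi). SAME_LOWER: larger half at the start (lo).
    return (large, small) if lower else (small, large)


if __name__ == "__main__":
    print(same_padding(5, 4, 1))              # odd total padding of 3
    print(same_padding(5, 4, 1, lower=True))  # same total, flipped sides
```

With an input of length 5, a window of 4 and stride 1, the total padding is 3, so SAME yields (1, 2) while SAME_LOWER yields (2, 1); when the total is even, the two modes agree.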

Kevin Gleason (2):
      Add support for StableHLO Serialized Portable Artifacts in JAX2TF.
      Improve handling of dynamic shapes in jax native serialization

Mark Sandler (1):
      Fixes broken examples, and (invalid) comment for PartitionSpec

Matthew Johnson (14):
      [shard_map] bug fix: extend axis env in partial_eval_custom rule
      update docs to remove stale reference to laziness optimization
      use Partial to make ravel_pytree unflatteners jit-friendly
      [dynamic-shapes] don't require buf objects have dtype attribute
      add test for #7155, fixes #7155
      [pytrees] fix function underlying tree-flattening with keys
      make mlir arg and result names work with static_argnums/argnames
      make mlir arg and result names work with pmap
      separate register_pytree_node and register_pytree_with_keys tests
      add experimental jax_log_checkpoint_residuals option
      enable pjit recursive typechecking
      improve scan error messages
      fix jax.Array.round()
      [dynamic-shapes] make dynamic shape staging-to-jaxpr work with pjit

Misha (1):
      Fix loc and scale parameters in scipy.logistic. Add CDF and SF for several distributions.

Neil Girdhar (1):
      Correct register_pytree_with_keys annotation

Parker Schuh (12):
      Move PyBuffer methods used by PyArray to c++.
      Rollforward with fixes: Remove _execute_replicated from UnloadedMeshExecutable.load since it is not required anymore for jit(pmap) cases
      [Rollforward] Move PyBuffer methods used by PyArray to c++.
      Hide jit-of-pmap warning.
      Use batched_device_put for pxla.shard_sharded_device_array_slow_path.
      Add benchmarks for np.array, device_put, and _arrays.
      Implement copy_to_host_async and _value with a single call to
      Avoid extra construction of ShapedArray in array __getitem__.
      Avoid extra construction of ShapedArray in array __getitem__.
      Delete the C++ GetEnableJaxArray() flag.
      Redefine `compile_and_serialize` as `serialize(lowered.compile())`.
      Guard ArrayImpl checks by xla_extension_version.

Peter Hawkins (46):
      Split _src modules cloud_tpu_init, lazy_loader, path, monitoring into their own pytype_library Bazel targets.
      Redefine jnp.DeviceArray as jax.Array.
      Increase precision of detrend test on TPU.
      Split source_info_util into its own Bazel target.
      Split Mesh and ResourceEnv into a new module jax._src.mesh.
      [JAX] Split _src/xla_bridge.py into a separate Bazel target.
      Split _src/profiler into a separate BUILD target.
      Split _src files custom_api_util, deprecations, effects and environment_info into separate Bazel targets.
      No changes.
      Move _src/tree_util.py into a separate Bazel target.
      Split _src/mesh into a separate Bazel target.
      Split Sharding subclasses out of _src/sharding.py into _src/sharding_impls.py
      Split basearray into separate Bazel module.
      Switch JAX to use the OpenXLA repository.
      Fix build breakage from OpenXLA switch.
      Improve type of jnp.mgrid[...].
      Make PRNG seed types more liberal in what they accept.
      Relax the argument type annotation of dynamic_index_in_dim.
      Make Tracer types on JaxprTrace more precise.
      Add device_buffers property to jax.Array type.
      Improve pytype inference for Sharding type.
      Mark jax.numpy.DeviceArray as deprecated. Use jax.Array instead.
      Relax argument type annotation for lax.dynamic_slice.
      Remove references to jax.config.jax_array, which is always True at head.
      Relax type of the `values` argument to .at[...].set(...) and friends.
      [JAX] Check for AttributeError from getattr(), not KeyError.
      [JAX] Delete ShardedDeviceArray.
      Mark jax.interpreters.pxla.ShardedDeviceArray as deprecated.
      Relax type annotations on lax slicing functions.
      Fix mypy failures in jax2tf.
      Revert: `custom_vjp` symbolic zeros support
      [XLA:Python] Change JAX and the XLA Python extension to get NumPy bfloat16/float8 types from ml_dtypes.
      [XLA:Python] Allow passing ExecutableBuildOptions to outfeed receiver.
      Move jax._src.typing into a separate Bazel target.
      Increase minimum NumPy version to 1.21.
      Delete remote TPU support.
      Suppress mypy warnings about missing imports.
      Split dtype argument from other arguments in special functions.
      Add support for using pip-installed CUDA wheels.
      Fix duplicate definition of 'cuda' extra in setup.py.
      Use pytype_strict_library() in Bazel build rules.
      Move jax.interpreters.partial_eval to jax._src.interpreters.partial_eval.
      Split core.py and several files in an SCC with it into a separate Bazel build target.
      Update the CUDA installation instructions.
      Add version constraints to CUDA pip wheel dependencies.
      Recommend --local_test_jobs in bazel test command line on GPU.

Ravin Kumar (2):
      Update user_guides.rst
      Fix hessian link

Rebecca Chen (2):
      Silence some pytype errors.
      Silence some pytype errors.

Roy Frostig (1):
      `custom_vjp` symbolic zeros support

Ruoxin Sang (1):
      Fix typo "compileable"->"compilable".

Sharad Vikram (3):
      Add print statement to help debug spurious test failure
      Copy seq_lengths before creating descriptor
      Add jaxlib version guard for rnn test

Shawn Presser (1):
      autodidax: fix jaxpr_subcomp return type annotation

Skye Wanderman-Milne (5):
      Update versions and changelog for jax + jaxlib 0.4.6 release
      Remove PJRT C API bypass.
      Remove 'pjrt_c_api_unimplemented' pytest mark.
      Bump minimum jaxlib version from 0.4.6 to 0.4.7.
      Turn on PJRT C API by default.

Yash Katariya (52):
      Fix the type annotation on shard_args
      Use shard_args and global_result_handlers since the `aval_to_result_handler` and `dispatch.device_put` will be removed soon.
      Remove _execute_replicated from UnloadedMeshExecutable.load since it is not required anymore for jit(pmap) cases
      Make `.devices()` a set rather than a list because the code looks at sharding.device_set
      [Rollback] Remove _execute_replicated from UnloadedMeshExecutable.load since it is not required anymore for jit(pmap) cases
      Fix copy_array_to_devices_with_sharding by making it take a committed argument so that the Array created has the right semantics.
      [Rollback] Move PyBuffer methods used by PyArray to c++.
      Rollback the device_put_sharded and device_put_replicated change of using batched_device_put
      [Fix forward] Rollback the device_put_sharded and device_put_replicated change of using batched_device_put
      Go via `_rewriting_take` if reducing on a dim for __getitem__ so that we can preserve the sharding and run it via XLA which will do sharding propagation.
      Bump minimum jaxlib version to 0.4.6 which means xla_extension_version == 137 and mlir_api_version == 45
      Don't use the dispatch.device_put path since it is deprecated and will be removed soon.
      Add Cuda 12 build configs to bazelrc
      If the bufs are on the same devices passed to batched_device_put then create an Array directly rather than going via xc.batched_device_put. Fixing the transfer guard problem should help in removing this workaround too.
      Fix the usage of device_put_handlers since that is deprecated. Use batched_device_put instead
      Make `pxla.replicate` go via batched_device_put rather than `pxla.device_put`.
      Delete jax_jit.device_put since it is not used anywhere except for 1 test. Replace it with batched_device_put
      batched_device_put was fixed to correctly use the x64 flag so there is no need to canonicalize dtype anymore.
      Remove GDA tests from JAX since GDA is deprecated. There are jax.Array tests for all the corresponding GDA tests
      Remove GDA from JAX since jax.Array is the default type and cannot be disabled anymore as per https://jax.readthedocs.io/en/latest/jax_array_migration.html#how-can-i-disable-jax-array-for-now
      Remove the check for `if not isinstance(old_token, array.ArrayImpl)` since py_executable always return jax.Arrays
      Remove references to `jax.config.jax_jit_pjit_api_merge`, which is always True at head.
      Remove the helper jit functions from api.py
      [Jax cleanup]
      Remove dispatch.result_handlers since they are not used.
      Clean up pjit after jax.Array
      Remove in_positional_semantics and out_positional_semantics from xmap
      Error if jax_array or jax_jit_pjit_api_merge is set to False.
      Remove _PositionalSemantics class since it is not used anymore because jax.Array always has GLOBAL semantics
      Improve the empty mesh error message raised in pjit if mesh is not used and Pspec is passed to in|out_shardings
      Add benchmarks for accessing index and replica id in addressable_shards
      Optimize accessing `index` and `replica_id` of
      Remove pxla.OutputType enum class now that the only output can be jax.Array
      Remove C++ jit support since it has been replaced with Pjit. Keep `CompiledFunction` alive as a shim which cannot be instantiated but will work for isinstance checks.
      Raise a better error message when there is a device assignment mismatch via the apply_primitive route.
      Remove the config.jax_array and jax_jit_pjit_api_merge flag usage since those are always True
      Deprecated xla_call_p since it has been replaced with pjit.pjit_p
      After the SPMD bug fix, always take the _rewriting_take route for getitem instead of bouncing to host.
      If each host has the full value of the Array, allow fetching it to host. Fixes #15162
      Add SDA deprecation warning to pytest.ini
      Add `src` argument to device_put as an experimental arg
      Delete benchmark and pmap_benchmark files as they are legacy and replaced with api_benchmark.py
      Temporarily fix the compilation cache test which is failing on latest jaxlib release
      Add deprecation warning for FROM_GDA usage since that argument is not required anymore.
      Prepare for jax and jaxlib 0.4.7 release
      Update the commit in workspace too
      Finish jax and jaxlib 0.4.7 release
      Remove Cuda 11.4 support. JAX from 0.4.8 release will support cuda 11.8 and cuda 12
      Deprecate FROM_GDA and remove its support from pjit's code since jax.Array inside pjit has sharding inference capabilities by default.
      Remove MeshPspecSharding since it has been more than 3 months since it was deprecated (Nov 2, 2022). The replacement name is NamedSharding.
      Jax 0.4.7 has been released so assert that length of warnings is 1 in test_cache_read_warning
      Remove global_arg_shapes from pmap since it was only used for sharded_jit and sharded_jit was removed from JAX a long time ago

Yu-Hang 'Maxin' Tang (1):
      add build option to create editable jaxlib

Yuanzhong Xu (1):
      Enable more mesh shape assignment

jax authors (3):
      jnp.einsum is parametrizable with dot_general.
      fix typo: "one of more" -> "one or more"
      Internal Code Change

jiayaobo (3):
      add rayleigh distribution to random.py
      add wald random generator
      remove scale in wald docstring

mehdiataei (1):
      Fixed spelling error in msgs

pizzud (1):
      lazy_loader_module: Move to new internal_test_util directory.

vfdev (1):
      Typo fix in ResizeMethod docstring, scale.py