Add a GPU implementation of `lax.linalg.eig`. #24663
Merged · +1,214 −55
Conversation
Oh.... This is... Amazing! Thanks enormously for this, really, it's been on my secret wishlist for so long...
dkweiss31 added a commit to dynamiqs/dynamiqs that referenced this pull request on Nov 14, 2024:

… issue with jax versions >0.4.31 (jax-ml/jax#24255). It seems, however, that there is a GPU implementation of eig in the works (jax-ml/jax#24663), making this pure callback stuff unnecessary.
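For context, a minimal sketch of the `pure_callback` workaround being discussed (the helper name and the complex64 dtypes are illustrative, not taken from the linked code):

```python
import jax
import jax.numpy as jnp
import numpy as np

def eig_host(a):
    """Eigendecomposition via NumPy/LAPACK on the host, callable from jitted code."""
    # Declare the output shapes/dtypes: eigenvalues (..., n), eigenvectors (..., n, n).
    w_shape = jax.ShapeDtypeStruct(a.shape[:-1], jnp.complex64)
    v_shape = jax.ShapeDtypeStruct(a.shape, jnp.complex64)

    def _host_eig(x):
        # Runs outside of JAX tracing, on the host CPU.
        w, v = np.linalg.eig(np.asarray(x))
        return w.astype(np.complex64), v.astype(np.complex64)

    return jax.pure_callback(_host_eig, (w_shape, v_shape), a)

# Usage: works under jit on any backend, since the callback runs on the host.
a = jnp.eye(2, dtype=jnp.complex64)
w, v = jax.jit(eig_host)(a)
```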
pierreguilmin added a commit to dynamiqs/dynamiqs that referenced this pull request on Nov 18, 2024:

* initial commit for floquet, not yet working
* floquet not working still, wip
* against my wishes I split up floquet and floquet_t: I found myself constantly using conditional logic, and it made more sense to just split them into separate functions since they are doing different things. Code is running but have not updated the tests yet
* bug fix
* renamed FloquetIntegratort -> FloquetIntegrator_t for better visual differentiation (still don't like the name tho)
* separate out FloquetQubit and FloquetQubit_t in tests; added more tests, all passing
* Towards allowing T batching, some batching tests failing still
* test for batching timecallable not working, while constant Hamiltonian does work
* Fixed batching by using internal methods _sepropagator and _floquet to avoid repeated calls to e.g. _flat_vectorize which was yielding weird/incorrect results
* Streamline batching tests, all passing
* ruff nits
* one more nit
* quasi_energies -> quasienergies; also save T as part of FloquetResult; added to mkdocs.yml
* floquet modes are vectors, so need the extra dimension of size 1 on the end; use my definition of sigmap
* various comment nits
* nit
* move _check_periodic into _check_floquet
* ruff
* can use final_propagator attribute, now we don't need the hacky options thing
* ruff nits
* wrap jnp.linalg.eig to ensure it is called on a cpu, as that function is not supported on gpus, see jax-ml/jax#1259 (see the sketch below this list)
* ruff format and lint
* fix merge done with GUI
* per conversations with Ronan, simplified the api to a single function `floquet` which first computes the t0=tsave[..., 0] floquet modes, then uses those to compute the floquet modes for the remaining times if tsave.shape[-1] > 1.
* added warning about not taking tsave % T, updated docstrings
* updated docs, small updates to docstrings
* got rid of unnecessary _broadcast_floquet_args function
* stop allowing for batching over tsave, for now
* forgot to remove duplicated test cases now that tsave batching not allowed
* Rename `quasiens` to `quasienergies` and `floquet_modes` to `modes`.
* Unify implementation of `floquet_t0` and `floquet_t` to call sepropagator only once.
* Remove the `safe` argument in favor of runtime error checking by equinox.
* Remove support for batching over `T` for now.
* use allclose instead of strict equality for floating point differences
* update tests to reflect no more batching on T :sadface:
* Additional changes flagged by Ronan re no more T batching
* updated eig_callback_cpu to reflect deprecations
* removed eig callback due to some of my tests hanging: this is a known issue with jax versions >0.4.31 (jax-ml/jax#24255). It seems, however, that there is a GPU implementation of eig in the works (jax-ml/jax#24663), making this pure callback stuff unnecessary
* fix typo
* fixed error_if for periodic checking
* Added T vmap example
* Nit: docstring formulation and runtime checking condition
* Ruff formatting
* Minor fixes to FloquetResult
  - Nit fixes to documentation.
  - Add str representation for modes and quasienergies.
  - Add example of cartesian batching.
* Minor fixes to apis/floquet.py
* Change type of drive period T to float
* Add missing imports to floquet API example code block

---------

Co-authored-by: Ronan Gautier <[email protected]>
Co-authored-by: pierreguilmin <[email protected]>
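The commit log above mentions wrapping `jnp.linalg.eig` so that it always runs on the CPU backend. A minimal sketch of what such a wrapper might look like (the function name is illustrative, and this assumes eager execution; under `jit` the whole computation would need to be placed on CPU, e.g. via `jax.default_device`):

```python
import jax
import jax.numpy as jnp

def eig_on_cpu(a):
    """Run jnp.linalg.eig on the CPU backend, where it is implemented."""
    # Commit the operand to the first CPU device; in eager mode the
    # eigendecomposition then executes on CPU, not on the accelerator.
    cpu = jax.devices("cpu")[0]
    w, v = jnp.linalg.eig(jax.device_put(a, cpu))
    return w, v

# Example: a real matrix with complex eigenvalues ±1j.
a = jnp.array([[0.0, 1.0], [-1.0, 0.0]])
w, v = eig_on_cpu(a)
```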
This feature has been in the queue for a long time (see #1259), and some folks have found that they can use `pure_callback` to call the CPU version as a workaround. It has recently come up that there can be issues when using `pure_callback` with JAX calls in the body (#24255; this should be investigated separately). This change adds a native solution for computing `lax.linalg.eig` on GPU.

By default, this is implemented by calling LAPACK on host directly because this has good performance for small to moderately sized problems (less than about 2048^2). For larger matrices, a GPU-backed implementation based on [MAGMA](https://icl.utk.edu/magma/) can have significantly better performance. (I should note that I haven't done a huge amount of benchmarking yet, but this was the breakeven point used by PyTorch, and I find roughly similar behavior so far.)

We don't want to add MAGMA as a required dependency, but if a user has installed it, JAX can use it when the `jax_gpu_use_magma` configuration variable is set to `"on"`. By default, we try to dlopen `libmagma.so`, but the path to a non-standard installation location can be specified using the `JAX_GPU_MAGMA_PATH` environment variable.

PiperOrigin-RevId: 697631402
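Based on the description above, opting in to the MAGMA path might look something like this (the config name and the `"on"` value are quoted from the PR text; the install path is a placeholder, and this assumes a JAX build that includes this change plus a GPU runtime and a MAGMA installation):

```python
import os
# Only needed for a non-standard MAGMA install; the path here is a placeholder.
os.environ["JAX_GPU_MAGMA_PATH"] = "/opt/magma/lib/libmagma.so"

import jax
import jax.numpy as jnp

# Ask JAX to use the MAGMA-backed eig kernel instead of host LAPACK.
jax.config.update("jax_gpu_use_magma", "on")

# A large nonsymmetric matrix, the regime where the PR suggests MAGMA wins out.
key = jax.random.key(0)
a = jax.random.normal(key, (4096, 4096))
w, v = jnp.linalg.eig(a)
```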
Amazing!