
Conversation

@copybara-service copybara-service bot commented Nov 1, 2024

Add a GPU implementation of lax.linalg.eig.

This feature has been in the queue for a long time (see #1259), and some folks have found that they can use `pure_callback` to call the CPU version as a workaround. It has recently come up that there can be issues when using `pure_callback` with JAX calls in the body (#24255; this should be investigated separately).
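For context, the workaround mentioned above typically looks something like the sketch below (not code from this PR): it runs `numpy.linalg.eig` on the host via `jax.pure_callback`, keeping the callback body free of JAX calls so it does not hit the #24255 pattern.

```python
# Hedged sketch of the pure_callback workaround, assuming plain NumPy in the
# callback body (i.e. no JAX calls inside the callback, the pattern from #24255).
import jax
import jax.numpy as jnp
import numpy as np

def eig_via_callback(a):
    complex_dtype = jnp.result_type(a.dtype, np.complex64)

    def host_eig(x):
        w, v = np.linalg.eig(x)
        # NumPy returns real arrays when all eigenvalues are real; cast so the
        # results always match the declared output dtypes.
        return w.astype(complex_dtype), v.astype(complex_dtype)

    out_types = (
        jax.ShapeDtypeStruct(a.shape[:-1], complex_dtype),  # eigenvalues
        jax.ShapeDtypeStruct(a.shape, complex_dtype),       # right eigenvectors
    )
    return jax.pure_callback(host_eig, out_types, a)

# Works under jit; the eig itself runs on the host CPU.
w, v = jax.jit(eig_via_callback)(jnp.eye(4))
```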

This change adds a native solution for computing `lax.linalg.eig` on GPU. By default, this is implemented by calling LAPACK on host directly because this has good performance for small to moderately sized problems (less than about 2048^2). For larger matrices, a GPU-backed implementation based on [MAGMA](https://icl.utk.edu/magma/) can have significantly better performance. (I should note that I haven't done a huge amount of benchmarking yet, but this was the breakeven point used by PyTorch, and I find roughly similar behavior so far.)
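The call site itself does not change; the sketch below only illustrates the size regimes described above (the specific shapes are arbitrary examples):

```python
# Minimal sketch: the nonsymmetric eigendecomposition can now be called
# directly on a GPU backend; jnp.linalg.eig lowers to lax.linalg.eig.
import jax
import jax.numpy as jnp

key = jax.random.key(0)
a_small = jax.random.normal(key, (512, 512))    # host LAPACK path should do well here
a_large = jax.random.normal(key, (4096, 4096))  # regime where a MAGMA build may pay off

w, v = jnp.linalg.eig(a_small)  # eigenvalues and right eigenvectors
```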

We don't want to add MAGMA as a required dependency, but if a user has installed it, JAX can use it when the `jax_gpu_use_magma` configuration variable is set to `"on"`. By default, we try to dlopen `libmagma.so`, but the path to a non-standard installation location can be specified using the `JAX_GPU_MAGMA_PATH` environment variable.
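Opting in might look like the following sketch, assuming MAGMA is installed; the flag and environment-variable names are the ones described above, and the library path is a hypothetical example:

```python
# Sketch of enabling the MAGMA-backed path. The path below is a hypothetical
# example of a non-standard install location; omit it if libmagma.so is on the
# default search path.
import os
os.environ["JAX_GPU_MAGMA_PATH"] = "/opt/magma/lib/libmagma.so"

import jax
import jax.numpy as jnp

jax.config.update("jax_gpu_use_magma", "on")

# Larger problems are where the MAGMA path is expected to pay off.
a = jax.random.normal(jax.random.key(0), (4096, 4096))
w, v = jnp.linalg.eig(a)
```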

@PhilipVinc
Contributor

Oh....

This is...

Amazing!

Thanks enormously for this, really, it's been on my secret wishlist for so long...

@copybara-service copybara-service bot force-pushed the test_691072237 branch 5 times, most recently from 35199e1 to 1dd1a7b on November 7, 2024 18:04
@copybara-service copybara-service bot force-pushed the test_691072237 branch 2 times, most recently from 7738ef6 to fcf6754 on November 13, 2024 22:47
dkweiss31 added a commit to dynamiqs/dynamiqs that referenced this pull request Nov 14, 2024
… issue with jax versions >0.4.31 (jax-ml/jax#24255). It seems, however, that there is a GPU implementation of eig in the works (jax-ml/jax#24663), making this pure callback stuff unnecessary
pierreguilmin added a commit to dynamiqs/dynamiqs that referenced this pull request Nov 18, 2024
* initial commit for floquet, not yet working

* floquet not working still, wip

* against my wishes I split up floquet and floquet_t: I found myself constantly using conditional logic, and it made more sense to just split them into separate functions since they are doing different things. Code is running but have not updated the tests yet

* bug fix

* renamed FloquetIntegratort -> FloquetIntegrator_t for better visual differentiation (still don't like the name tho)

* separate out FloquetQubit and FloquetQubit_t in tests

added more tests, all passing

* Towards allowing T batching, some batching tests failing still

* test for batching timecallable not working, while constant Hamiltonian does work

* Fixed batching by using internal methods _sepropagator and _floquet to avoid repeated calls to e.g. _flat_vectorize which was yielding weird/incorrect results

* Streamline batching tests, all passing

* ruff nits

* one more nit

* quasi_energies -> quasienergies

also save T as part of FloquetResult

added to mkdocs.yml

* floquet modes are vectors, so need the extra dimension of size 1 on the end

use my definition of sigmap

* various comment nits

* nit

* move _check_periodic into _check_floquet

* ruff

* can use final_propagator attribute, now we don't need the hacky options thing

* ruff nits

* wrap jnp.linalg.eig to ensure it is called on a cpu, as that function is not supported on gpus, see jax-ml/jax#1259 (a sketch of such a wrapper appears after this commit list)

* ruff format and lint

* fix merge done with GUI

* per conversations with Ronan, simplified the API to a single function `floquet`, which first computes the t0=tsave[..., 0] floquet modes, then uses those to compute the floquet modes for the remaining times if tsave.shape[-1] > 1.

* added warning about not taking tsave % T, updated docstrings

* updated docs, small updates to docstrings

* got rid of unnecessary _broadcast_floquet_args function

* stop allowing for batching over tsave, for now

* forgot to remove duplicated test cases now that tsave batching not allowed

* Rename `quasiens` to `quasienergies` and `floquet_modes` to `modes`.

* Unify implementation of `floquet_t0` and `floquet_t` to call sepropagator only once.

* Remove the `safe` argument in favor of runtime error checking by equinox.

* Remove support for batching over `T` for now.

* use allclose instead of strict equality for floating point differences

* update tests to reflect no more batching on T :sadface:

* Additional changes flagged by Ronan re no more T batching

* updated eig_callback_cpu to reflect deprecations

* removed eig callback due to some of my tests hanging: this is a known issue with jax versions >0.4.31 (jax-ml/jax#24255). It seems, however, that there is a GPU implementation of eig in the works (jax-ml/jax#24663), making this pure callback stuff unnecessary

* fix typo

* fixed error_if for periodic checking

* Added T vmap example

* Nit: docstring formulation and runtime checking condition

* Ruff formatting

* Minor fixes to FloquetResult

- Nit fixes to documentation.
- Add str representation for modes and quasienergies.
- Add example of cartesian batching.

* Minor fixes to apis/floquet.py

* Change type of drive period T to float

* Add missing imports to floquet API example code block

---------

Co-authored-by: Ronan Gautier <[email protected]>
Co-authored-by: pierreguilmin <[email protected]>
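The "wrap jnp.linalg.eig to ensure it is called on a cpu" step mentioned in the commit list above amounts to a device-pinning wrapper; here is a hedged sketch of one way to do that for eager (non-jitted) calls. The helper name is hypothetical, not dynamiqs' actual code.

```python
# Hedged sketch of a CPU-pinned eig wrapper of the kind described in the commit
# list above (hypothetical helper, not dynamiqs' implementation).
import jax
import jax.numpy as jnp

_CPU = jax.devices("cpu")[0]

def eig_on_cpu(a):
    # Commit the operand to the host CPU so the eig kernel runs there; before
    # jax-ml/jax#24663, eig was only implemented on the CPU backend.
    a_cpu = jax.device_put(a, _CPU)
    w, v = jnp.linalg.eig(a_cpu)
    return w, v  # callers can device_put the results back onto an accelerator
```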
@copybara-service copybara-service bot merged commit ccb3317 into main Nov 18, 2024
1 check was pending
@copybara-service copybara-service bot deleted the test_691072237 branch November 18, 2024 16:12
@qiyang-ustc

qiyang-ustc commented Dec 16, 2024

Amazing!
