

dkweiss31
Collaborator

Closes DYN-231

Add Floquet solvers and associated tests. The API as it currently stands supplies two functions, floquet and floquet_t, which return the Floquet modes (and associated quasienergies) at t=0 and the Floquet modes at other times t, respectively. I'm not thrilled with the name floquet_t, nor with the name of the corresponding integrator FloquetIntegrator_t, so any suggestions for better names are welcome.
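For context, here is a minimal sketch of what the t=0 computation involves, written for a hypothetical driven two-level system. The Hamiltonian, the hand-rolled time-sliced propagator, and the names `one_period_propagator` and `floquet_modes_t0` are illustrative assumptions, not the dynamiqs API or implementation. The modes at t=0 are the eigenvectors of the one-period propagator U(T, 0), and each eigenvalue has the form exp(-i ε T), from which the quasienergy ε is read off.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm

# Pauli matrices for a hypothetical driven two-level system
SX = jnp.array([[0, 1], [1, 0]], dtype=jnp.complex64)
SZ = jnp.array([[1, 0], [0, -1]], dtype=jnp.complex64)


def hamiltonian(t, amp, omega):
    # H(t) = 0.5 * sz + amp * cos(omega * t) * sx
    return 0.5 * SZ + amp * jnp.cos(omega * t) * SX


def one_period_propagator(amp, omega, nsteps=200):
    """Approximate U(T, 0) by a midpoint product of short-time exponentials."""
    T = 2 * jnp.pi / omega
    dt = T / nsteps
    mids = (jnp.arange(nsteps) + 0.5) * dt

    def step(u, t):
        # left-multiply by exp(-i H(t) dt), with t the midpoint of the slice
        return expm(-1j * dt * hamiltonian(t, amp, omega)) @ u, None

    u, _ = jax.lax.scan(step, jnp.eye(2, dtype=jnp.complex64), mids)
    return u


def floquet_modes_t0(amp, omega):
    """Quasienergies and t=0 Floquet modes (hypothetical helper, not the dynamiqs API)."""
    T = 2 * jnp.pi / omega
    # jnp.linalg.eig is currently only supported on the CPU backend
    evals, modes = jnp.linalg.eig(one_period_propagator(amp, omega))
    # U(T) phi_k = exp(-i eps_k T) phi_k; quasienergies are only defined modulo
    # omega, and this folds them into a zone of width omega centered on zero
    quasienergies = -jnp.angle(evals) / T
    return quasienergies, modes  # modes[:, k] is the k-th Floquet mode


quasienergies, modes = floquet_modes_t0(amp=0.3, omega=5.0)
```

With amp=0 the drive vanishes and the quasienergies reduce to the eigenvalues ±0.5 of the static Hamiltonian, which makes a quick sanity check.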


linear bot commented Aug 27, 2024

@dkweiss31 dkweiss31 added feature New feature or request 🌀 solvers Topic: solvers labels Aug 27, 2024
@dkweiss31 dkweiss31 changed the title Dkweiss31/floquet Add floquet and floquet_t Aug 27, 2024
@gautierronan
Collaborator

Quick comment raised by @BenjaminDAnjou: jnp.linalg.eig is currently only supported on the CPU backend (see jax-ml/jax#1259). This JAX issue has been open since 2019, so I don't think we can expect progress on this front.

Is there any other way to compute the Floquet basis, @dkweiss31? Otherwise, we could use the workaround suggested in the linked issue, that is, calling back to the CPU implementation of jnp.linalg.eig using external callbacks. Note that this is also what PyTorch does for eig on GPU.
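A minimal sketch of the external-callback workaround, assuming `jax.pure_callback` is used to run NumPy's `eig` on the host; the wrapper name `eig_via_callback` is hypothetical. The input is copied to the CPU, decomposed there, and the results are copied back, so it also works inside `jit` on a GPU at the cost of a device-host round trip. This sketch passes neither `vectorized` nor `vmap_method`, so it is only meant to be called unbatched.

```python
import jax
import jax.numpy as jnp
import numpy as np


def eig_via_callback(u):
    """Eigendecomposition of a square matrix computed by NumPy on the host."""
    # complex output dtype matching the input precision
    cdtype = jnp.result_type(u.dtype, jnp.complex64)
    result_shapes = (
        jax.ShapeDtypeStruct(u.shape[:-1], cdtype),  # eigenvalues
        jax.ShapeDtypeStruct(u.shape, cdtype),  # eigenvectors as columns
    )

    def _host_eig(x):
        vals, vecs = np.linalg.eig(np.asarray(x))
        # the callback must return exactly the declared dtypes
        return vals.astype(cdtype), vecs.astype(cdtype)

    return jax.pure_callback(_host_eig, result_shapes, u)
```

The callback is opaque to JAX, so this only covers the forward computation; it does not provide gradients through the eigendecomposition.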

@dkweiss31
Collaborator Author

Thanks for this, I hadn't appreciated that eig is only supported on CPUs. I'm not aware of another robust numerical technique for computing the Floquet modes. Happy to try out the callback workaround.

Collaborator

@gautierronan gautierronan left a comment


Thanks a lot @dkweiss31 for the PR, and sorry again for the late review.

In the most recent commits, I propose several changes to your PR:

  • Unified the implementation of floquet_t0 and floquet_t into a single integrator. This is mainly so that we don't need to compute the propagator multiple times: it is now computed only once, both for the modes at t=t0 and for the subsequent propagation of the modes to later times t.
  • Removed batching over T. This is kinda annoying, but it is related to not having support for batching over tsave, and I think removing it really simplifies the implementation. Again, I'm quite open to rethinking this in a later PR. In the meantime, I would propose to simply provide an example in the main docstring showing how to batch over multiple drive periods using an external jax.vmap. This is actually rather straightforward to achieve.
  • Enforced that tsave[0] and tsave[-1] lie within one period of each other. The alternative would indeed be to use jnp.mod(tsave, T) to simplify users' lives, but for now, it simplifies ours haha.
  • Proposed renaming quasiens to quasienergies and floquet_modes to modes, in an effort to have short but meaningful names. Happy to reconsider this one if you think the previous names were better.
  • Removed the safe argument in favor of equinox runtime error checking, which is quite neat.
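The unified approach can be sketched roughly as follows, under stated assumptions: a hypothetical two-level Hamiltonian, a hand-rolled midpoint time-slicing propagator, and a plain host-side Python check standing in for the equinox runtime error check (so this sketch only runs outside `jit`). The names `interval_propagator` and `floquet_modes` are illustrative, not the dynamiqs API. A single scan propagates from tsave[0] to tsave[0] + T, stopping at each saved time, so the same propagators give both the modes at t0 (from the eigendecomposition of U(t0 + T, t0)) and the modes at later times via Φ_k(t) = exp(i ε_k (t - t0)) U(t, t0) Φ_k(t0).

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.scipy.linalg import expm

# Pauli matrices for a hypothetical driven two-level system
SX = jnp.array([[0, 1], [1, 0]], dtype=jnp.complex64)
SZ = jnp.array([[1, 0], [0, -1]], dtype=jnp.complex64)


def hamiltonian(t, amp, omega):
    return 0.5 * SZ + amp * jnp.cos(omega * t) * SX


def interval_propagator(ta, tb, amp, omega, nsub=50):
    """U(tb, ta) from a midpoint product of nsub short-time exponentials."""
    dt = (tb - ta) / nsub
    mids = ta + (jnp.arange(nsub) + 0.5) * dt

    def step(u, t):
        return expm(-1j * dt * hamiltonian(t, amp, omega)) @ u, None

    u, _ = jax.lax.scan(step, jnp.eye(2, dtype=jnp.complex64), mids)
    return u


def floquet_modes(tsave, amp, omega):
    """Quasienergies and Floquet modes at every time in tsave from one propagation."""
    T = 2 * np.pi / omega
    tsave = jnp.asarray(tsave, dtype=jnp.float32)
    # host-side stand-in for the runtime check (small tolerance for float32 rounding)
    if float(tsave[-1] - tsave[0]) > T * (1 + 1e-5):
        raise ValueError("tsave[0] and tsave[-1] must lie within one drive period")
    # propagate once over [t0, t0 + T], stopping at each saved time
    times = jnp.append(tsave, tsave[0] + T)
    eye = jnp.eye(2, dtype=jnp.complex64)

    def step(u, bounds):
        u_next = interval_propagator(bounds[0], bounds[1], amp, omega) @ u
        return u_next, u_next

    _, us = jax.lax.scan(step, eye, jnp.stack([times[:-1], times[1:]], axis=1))
    us = jnp.concatenate([eye[None], us])  # us[i] = U(times[i], t0)
    # modes at t0 and quasienergies from U(t0 + T, t0) = us[-1]
    evals, modes_t0 = jnp.linalg.eig(us[-1])
    quasienergies = -jnp.angle(evals) / T
    # Phi_k(t) = exp(i eps_k (t - t0)) U(t, t0) Phi_k(t0), for each saved time
    phases = jnp.exp(1j * quasienergies[None, :] * (tsave[:, None] - tsave[0]))
    modes_t = (us[:-1] @ modes_t0) * phases[:, None, :]
    return quasienergies, modes_t
```

Because Floquet modes are T-periodic, including t0 + T in tsave returns the t0 modes again, which is a convenient consistency check.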

Let me know if all these changes work for you!

To finish the PR, I think there are mainly two things remaining:

  • Adding the docstring example for batching over T/tsave (drive period)
  • Fixing tests with this new version
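A minimal sketch of the kind of batching example discussed here, assuming a hypothetical driven two-level model H(t) = 0.5 σz + amp cos(ωt) σx; the function `quasienergies` is illustrative, not the dynamiqs API. Because the drive frequency (and hence the period T = 2π/ω) is just a function argument, an external `jax.vmap` batches over it, and nesting two vmaps with complementary `in_axes` gives a cartesian product of frequencies and amplitudes.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm

SX = jnp.array([[0, 1], [1, 0]], dtype=jnp.complex64)
SZ = jnp.array([[1, 0], [0, -1]], dtype=jnp.complex64)


def quasienergies(omega, amp, nsteps=200):
    """Quasienergies for one drive frequency and amplitude (hypothetical model)."""
    T = 2 * jnp.pi / omega
    dt = T / nsteps

    def step(u, t):
        h = 0.5 * SZ + amp * jnp.cos(omega * t) * SX
        return expm(-1j * dt * h) @ u, None

    mids = (jnp.arange(nsteps) + 0.5) * dt
    u, _ = jax.lax.scan(step, jnp.eye(2, dtype=jnp.complex64), mids)
    evals, _ = jnp.linalg.eig(u)  # CPU-only in current JAX
    return -jnp.angle(evals) / T


omegas = jnp.array([4.0, 5.0, 6.0])
amps = jnp.array([0.0, 0.1, 0.2, 0.3])

# batch over the drive frequency only, with a fixed amplitude
q_omega = jax.vmap(quasienergies, in_axes=(0, None))(omegas, 0.3)  # shape (3, 2)

# cartesian batching: outer vmap over omegas, inner vmap over amps
q_grid = jax.vmap(jax.vmap(quasienergies, in_axes=(None, 0)), in_axes=(0, None))(
    omegas, amps
)  # shape (3, 4, 2)
```

The outer vmap's axis comes first in the output, so q_grid[i, j] corresponds to omegas[i] and amps[j].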

@dkweiss31
Collaborator Author

Hey @gautierronan, thank you very much for this review and for the significant simplifications you introduced. I am very happy with all of your proposed changes and will write up an example of T batching shortly.

I had been working with jax 0.4.29, and upgraded following your comment that the vectorized option in pure_callback is deprecated. After updating how we call pure_callback, I ran into issues with my tests hanging, which is apparently a known issue for jax > 0.4.31 (jax-ml/jax#24255). However, it seems that a GPU implementation of eig is in the works (jax-ml/jax#24663) and should hopefully be merged soon. I therefore propose to just use jax.numpy.linalg.eig for now, and to update once the new GPU implementation is available.

gautierronan previously approved these changes Nov 15, 2024
Collaborator

@gautierronan gautierronan left a comment


Thanks a lot @dkweiss31! All good on my end! Let's wait for the review of @pierreguilmin and aim for merging next week :)

@dkweiss31
Collaborator Author

Sounds great, thanks very much @gautierronan!

- Nit fixes to documentation.
- Add str representation for modes and quasienergies.
- Add example of cartesian batching.
pierreguilmin previously approved these changes Nov 18, 2024
Collaborator

@pierreguilmin pierreguilmin left a comment


Beautiful work, congratulations @dkweiss31, that's an amazing addition to the library! 😍 And @gautierronan I loved how your thorough review helped it converge into this clean and simple API. I'm adding a few nits on top. Merging now!

@pierreguilmin pierreguilmin merged commit 9f961b9 into main Nov 18, 2024
1 check passed
@pierreguilmin pierreguilmin deleted the dkweiss31/floquet branch November 18, 2024 09:53
