jakevdp
Collaborator

@jakevdp jakevdp commented Oct 16, 2023

This adds a JAX Enhancement Proposal (JEP) discussing the scope of NumPy & SciPy wrappers going forward. Comments welcome!

Rendered preview: https://jax--18137.org.readthedocs.build/en/18137/jep/18137-numpy-scipy-scope.html

@jakevdp jakevdp self-assigned this Oct 16, 2023
@jakevdp jakevdp marked this pull request as draft October 16, 2023 19:00
@jakevdp jakevdp added the JEP (JAX enhancement proposal) label Oct 16, 2023
@jakevdp jakevdp changed the title from "Initial draft of Numpy/Scipy scope JEP" to "JEP 18137: Scope of JAX NumPy & SciPy Wrappers" Oct 16, 2023
@chrisflesher
Contributor

chrisflesher commented Oct 18, 2023

First of all I'd like to thank the JAX team for your efforts. You guys are my heroes!

Our group is currently using JAX to develop underwater mapping software. We have found it easy to convert existing code from standard numpy / scipy to JAX. Part of the reason I wrote the Rotation / Slerp classes was that they let us convert 3000+ lines of existing code to JAX in about an hour, pretty cool!

The main thing I'd like to question is Axis 5: Functional vs. Object-Oriented APIs.

Is there a more explicit way to determine whether a class implementation is acceptable, like a checklist or something? For example, I'm wondering what the issue with the Rotation class is. In my mind it seemed like a good fit because:

  • it uses the standard scipy API
  • the class is compatible with vmap (due to jax.numpy.vectorize decorators)
  • the class is compatible with grad (because its state is a single JAX array)
  • unclear what a decent 3rd party alternative would be (we originally tried jax_transformations3d but found it difficult to use in practice)
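To make the vmap/grad points above concrete, here is a minimal sketch of a class whose only state is one JAX array. `QuatRotation` and its methods are hypothetical and are not the actual `jax.scipy.spatial.transform.Rotation` code; where the comment above credits `jax.numpy.vectorize` decorators for vmap support, this sketch instead registers the class as a pytree, which is one alternative way to let `jax.vmap` and `jax.grad` trace through instances. Quaternions are assumed to be scalar-last `(x, y, z, w)`, SciPy's default convention.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class QuatRotation:
    """Hypothetical rotation class whose only state is one JAX array.

    `quat` holds unit quaternion(s) in scalar-last (x, y, z, w) order.
    """

    def __init__(self, quat):
        self.quat = quat

    # Pytree protocol: the quaternion array is the only leaf, so JAX
    # transformations can flatten and rebuild instances of this class.
    def tree_flatten(self):
        return (self.quat,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

    @classmethod
    def from_quat(cls, quat):
        quat = jnp.asarray(quat)
        return cls(quat / jnp.linalg.norm(quat))

    def apply(self, vec):
        # Rotate a single 3-vector: v' = v + w*t + q x t, with t = 2 * (q x v).
        q, w = self.quat[:3], self.quat[3]
        t = 2.0 * jnp.cross(q, vec)
        return vec + w * t + jnp.cross(q, t)


# vmap: store a batch of quaternions in one object and map over it.
s = jnp.sqrt(0.5)
quats = jnp.array([[0.0, 0.0, s, s],       # 90 degrees about z
                   [0.0, 0.0, 0.0, 1.0]])  # identity
vecs = jnp.array([[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
rotated = jax.vmap(lambda r, v: r.apply(v))(QuatRotation(quats), vecs)


# grad: differentiate a scalar function of the quaternion through the class.
def loss(quat):
    return QuatRotation.from_quat(quat).apply(jnp.array([1.0, 0.0, 0.0]))[1]


g = jax.grad(loss)(jnp.array([0.0, 0.0, 0.1, 1.0]))
```

Here `rotated` should be approximately `[[0, 1, 0], [1, 0, 0]]`, and `g` is the gradient of the rotated vector's y component with respect to the (unnormalized) quaternion.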

For the Slerp class, I get your point that a pure function might be better here. For us the main benefit was being able to use a standard API to port existing code.
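For comparison, a minimal sketch of what the pure-function alternative could look like: `slerp(q0, q1, t)` interpolates between two unit quaternions, and evaluating at many times is a `jax.vmap` over `t` instead of constructing an interpolator object. The name and signature are hypothetical (SciPy's `Slerp` instead takes a sequence of times and rotations), and quaternions are assumed to be scalar-last `(x, y, z, w)`.

```python
import jax
import jax.numpy as jnp


def slerp(q0, q1, t):
    """Hypothetical pure-function slerp between two quaternions.

    q0, q1: arrays of shape (4,); t: scalar in [0, 1].
    """
    q0 = q0 / jnp.linalg.norm(q0)
    q1 = q1 / jnp.linalg.norm(q1)
    dot = jnp.dot(q0, q1)
    # q and -q describe the same rotation; flip q1 to take the shorter arc.
    q1 = jnp.where(dot < 0.0, -q1, q1)
    dot = jnp.clip(jnp.abs(dot), 0.0, 1.0)
    theta = jnp.arccos(dot)
    sin_theta = jnp.sin(theta)
    # For nearly identical quaternions, fall back to linear interpolation
    # (normalized below) so the forward pass never divides by ~0.
    small = sin_theta < 1e-6
    safe_sin = jnp.where(small, 1.0, sin_theta)
    w0 = jnp.where(small, 1.0 - t, jnp.sin((1.0 - t) * theta) / safe_sin)
    w1 = jnp.where(small, t, jnp.sin(t * theta) / safe_sin)
    out = w0 * q0 + w1 * q1
    return out / jnp.linalg.norm(out)


identity = jnp.array([0.0, 0.0, 0.0, 1.0])
z90 = jnp.array([0.0, 0.0, jnp.sqrt(0.5), jnp.sqrt(0.5)])  # 90 degrees about z

halfway = slerp(identity, z90, 0.5)  # should be a 45-degree rotation about z
# Many times at once: vmap over t, holding both endpoints fixed.
ts = jnp.linspace(0.0, 1.0, 5)
path = jax.vmap(slerp, in_axes=(None, None, 0))(identity, z90, ts)
```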

I recently submitted a PR for the CubicHermiteSpline class, which follows the same design pattern as Rotation and Slerp. I was hoping to eventually use it to add a loss function to the optax repo.
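As an illustration only, here is a hypothetical pure-function counterpart, `cubic_hermite(x, y, dydx, xq)`, which evaluates the piecewise cubic Hermite interpolant from knot values and first derivatives (the same inputs `scipy.interpolate.CubicHermiteSpline(x, y, dydx)` takes). It is a sketch, not the code from the closed PR; because it is a plain function of arrays, it composes with `jax.grad` the way a loss function would need.

```python
import jax
import jax.numpy as jnp


def cubic_hermite(x, y, dydx, xq):
    """Hypothetical functional cubic Hermite interpolant.

    x: strictly increasing knots, shape (n,); y, dydx: values and first
    derivatives at the knots, shape (n,); xq: query points (any shape).
    Queries outside [x[0], x[-1]] use the first or last segment's cubic.
    """
    # Index i of the segment [x[i], x[i+1]] containing each query point.
    i = jnp.clip(jnp.searchsorted(x, xq, side="right") - 1, 0, x.shape[0] - 2)
    h = x[i + 1] - x[i]
    s = (xq - x[i]) / h  # local coordinate, in [0, 1] inside the segment
    # Cubic Hermite basis functions.
    h00 = 2 * s**3 - 3 * s**2 + 1
    h10 = s**3 - 2 * s**2 + s
    h01 = -2 * s**3 + 3 * s**2
    h11 = s**3 - s**2
    return (h00 * y[i] + h10 * h * dydx[i]
            + h01 * y[i + 1] + h11 * h * dydx[i + 1])


# A cubic Hermite interpolant reproduces a cubic exactly, e.g. y = x**3.
x = jnp.array([0.0, 1.0, 2.0, 3.0])
y, dydx = x**3, 3 * x**2
values = cubic_hermite(x, y, dydx, jnp.array([0.5, 1.5, 2.5, 3.0]))
# Differentiate with respect to the query point: d/dx x**3 at 1.5 is 6.75.
slope = jax.grad(lambda q: cubic_hermite(x, y, dydx, q))(1.5)
```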

In the short term it totally makes sense that you guys are trying to find ways to live within your available resources. However it makes me wonder if the long term vision is really to avoid implementing most of the numpy / scipy API? Some of this stuff seems like it could help other people port existing code to JAX (e.g. scipy.interpolate and scipy.ndimage).

@jakevdp
Collaborator Author

jakevdp commented Oct 18, 2023

> However it makes me wonder if the long term vision is really to avoid implementing most of the numpy / scipy API?

This document is an attempt to write down the long-term vision. We have never intended to implement the entire numpy/scipy API as part of the core JAX package (though of course we would welcome partner projects which implement parts of it that don't fit into jax itself).

For example, if the optax loss function needs a cubic hermite spline, is there any reason not to implement that spline functionality in optax?

@chrisflesher
Contributor

chrisflesher commented Oct 19, 2023

I closed the CubicHermiteSpline PR, it sounds like it is being discouraged.

@chrisflesher
Contributor

chrisflesher commented Oct 22, 2023

If you guys want to remove Rotation and Slerp from this repo, I started a 3rd party version here:
https://github.com/chrisflesher/jax-scipy-spatial

@jakevdp jakevdp requested a review from froystig October 30, 2023 20:05
@jakevdp jakevdp marked this pull request as ready for review October 30, 2023 20:05
This was referenced Nov 3, 2023
Member

@froystig froystig left a comment

Nice work!

@google-ml-butler google-ml-butler bot added the kokoro:force-run and pull ready labels Nov 4, 2023
@copybara-service copybara-service bot merged commit 28b512a into jax-ml:main Nov 4, 2023
@jakevdp jakevdp deleted the jep-numpy-scipy branch November 4, 2023 02:38
Labels

JEP: JAX enhancement proposal
pull ready: Ready for copybara import and testing

4 participants