hilbert transform #15121
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request.
Thanks! I think as far as testing goes, you could follow some of the FFT tests as a template: https://github.com/google/jax/blob/4bd7f408c60266641ea8db2957ea458edcdcf50d/tests/fft_test.py#L231
But the tests for functions in this file should go in https://github.com/google/jax/blob/main/tests/scipy_signal_test.py
I wrote some of those tests :) Looks like y'all have improved on them substantially! I'm getting some of these errors, though, and I can't quite grok how to make sure both sides are the same dtype:

E jax._src.dtypes.TypePromotionError: Input dtypes ('float32', 'complex64') have no available implicit dtype promotion path when jax_numpy_dtype_promotion=strict. Try explicitly casting inputs to the desired output type, or set jax_numpy_dtype_promotion=standard.

Can I get an assist please?
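For what it's worth, the error message itself points at one possible fix: cast the real-valued operand to the complex dtype before combining the two. The snippet below is only a minimal sketch of that idea, not the code from this PR. The names `x`, `Xf`, and `h` are hypothetical stand-ins for a float32 signal, its FFT, and a real-valued spectral mask.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins: a float32 signal, its spectrum, and a real mask.
x = jnp.arange(8, dtype=jnp.float32)
Xf = jnp.fft.fft(x)                   # complex spectrum of the real input
h = jnp.ones(8, dtype=jnp.float32)    # real-valued mask (all ones here)

# Same promotion mode as named in the error message above.
with jax.numpy_dtype_promotion("strict"):
    # `Xf * h` would mix a complex array with a float32 array, which strict
    # mode refuses to promote implicitly. Casting the mask to the spectrum's
    # dtype first makes both operands agree.
    y = Xf * h.astype(Xf.dtype)
```

Because the cast uses `Xf.dtype` rather than a hard-coded `complex64`, the same line should also work if 64-bit mode is enabled.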
Looks great! A few comments below. Also, please add `hilbert` here so it will appear in the HTML docs: https://github.com/google/jax/blob/main/docs/jax.scipy.rst#jaxscipysignal
…ch via the apply_primitive route. PiperOrigin-RevId: 518282464
Follows SciPy's implementation except for the `ndim > 1` part. Tested locally to reproduce `osp.signal.hilbert` for even and odd N but could use some guidance on how to implement a proper test.
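For readers following along, here is a rough sketch of the FFT-based approach SciPy uses for `hilbert`, written with `jax.numpy`. It is an illustration under assumptions, not the code in this PR: `hilbert_sketch` is a hypothetical name, it only transforms along the last axis, and it does not support SciPy's optional `N` (padding) argument.

```python
import numpy as np
import jax.numpy as jnp


def hilbert_sketch(x):
    """Analytic signal of a real input along its last axis (illustrative only)."""
    x = jnp.asarray(x)
    n = x.shape[-1]
    Xf = jnp.fft.fft(x, axis=-1)

    # Build the one-sided spectral mask with NumPy, since n is static:
    # keep DC (and Nyquist when n is even), double the positive
    # frequencies, and zero out the negative ones.
    h = np.zeros(n)
    h[0] = 1
    if n % 2 == 0:
        h[n // 2] = 1
        h[1:n // 2] = 2
    else:
        h[1:(n + 1) // 2] = 2

    # Cast the mask explicitly so the multiply works under strict promotion.
    h = jnp.asarray(h, dtype=Xf.dtype)
    return jnp.fft.ifft(Xf * h, axis=-1)


# Usage: for a cosine with an integer number of cycles, the imaginary part
# of the analytic signal should be the matching sine.
t = np.arange(16)
signal = np.cos(2 * np.pi * 2 * t / 16).astype(np.float32)
analytic = hilbert_sketch(signal)
```

A test along these lines could compare the result against `scipy.signal.hilbert` for both even and odd lengths, which exercises the two branches of the mask.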