hilbert transform #15121
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request.
Thanks! I think as far as testing goes, you could follow some of the FFT tests as a template: https://github.com/google/jax/blob/4bd7f408c60266641ea8db2957ea458edcdcf50d/tests/fft_test.py#L231 But the tests for functions in this file should go in https://github.com/google/jax/blob/main/tests/scipy_signal_test.py
I wrote some of those tests :) Looks like y'all have improved on them substantially! I'm getting some errors like this one, though: E jax._src.dtypes.TypePromotionError: Input dtypes ('float32', 'complex64') have no available implicit dtype promotion path when jax_numpy_dtype_promotion=strict. Try explicitly casting inputs to the desired output type, or set jax_numpy_dtype_promotion=standard. I can't quite grok how to make sure both sides have the same dtype. Can I get an assist, please?
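A minimal sketch of the explicit cast the error message asks for, assuming the failing operation is a real-valued weight array (float32) multiplied by a complex FFT output (complex64). The names `Xf`, `h`, and `weighted_spectrum` are hypothetical stand-ins for whatever the test actually computes; `jax.numpy_dtype_promotion` is used only to reproduce the strict setting locally.

```python
import jax
import jax.numpy as jnp


def weighted_spectrum(Xf, h):
    """Scale a complex spectrum by real-valued weights (hypothetical helper).

    Under jax_numpy_dtype_promotion="strict", `Xf * h` with Xf complex64 and
    h float32 raises TypePromotionError, so h is cast to Xf's dtype first.
    """
    return Xf * h.astype(Xf.dtype)


# Hypothetical stand-ins: a complex64 "FFT output" and float32 weights.
Xf = jnp.ones(8, dtype=jnp.complex64)
h = jnp.array([1, 2, 2, 2, 1, 0, 0, 0], dtype=jnp.float32)

with jax.numpy_dtype_promotion("strict"):
    y = weighted_spectrum(Xf, h)  # both operands are complex64 here

print(y.dtype)  # complex64
```

The same idea applies to whichever operand is real: cast it with `.astype(...)`, or build it with the target dtype up front, so no implicit promotion is needed.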
Looks great! A few comments below. Also, please add hilbert here so it will appear in the HTML docs: https://github.com/google/jax/blob/main/docs/jax.scipy.rst#jaxscipysignal
…ch via the apply_primitive route. PiperOrigin-RevId: 518282464
PiperOrigin-RevId: 518308356
The sharding tests are very important because jax2tf must do non-trivial manipulation of shardings, e.g., to wrap the inputs and the outputs with sharding annotations. Merged the two test classes ShardedJitHloTest and Sharding. One of the classes was just checking the annotations in the TF HLO, the other one was just running the code and comparing the results. Now we do both in one test. Refactored the code to log JAX HLO and to check the occurrences of annotations in TF HLO. Now we support checking that the occurrence count is equal to a value or greater than or equal to a value. Added more annotation checking. This makes the test more informative (because there is no other good way to check that sharding was applied correctly). But it also makes it possible for the test to fail when we change the JAX lowering. PiperOrigin-RevId: 518362978
…ce those are always True PiperOrigin-RevId: 518368963
PiperOrigin-RevId: 518412422
PiperOrigin-RevId: 518417901
PiperOrigin-RevId: 518418760
Since jax2tf.convert is called recursively for the purpose of serializing the vjp function, we must ensure that if the primal function is a pjit with shardings, then the vjp function is also converted as a pjit. Without this fix, serializing a pjit function with gradients fails with an error saying that there are shardings but no pjit at the top level.
PiperOrigin-RevId: 518572905
add wald to random.py
PiperOrigin-RevId: 518597609
PiperOrigin-RevId: 518634912
This is a follow-up on #15080 to restore (and indeed fix!) how pmap builds a jaxpr with debug info (i.e. parameter names and result paths). The difference from the machinery in #15080 is just to deal with pmap being final-style (i.e. build the jaxpr at the last second, well after pytrees have been flattened away and transformations have been applied), whereas the machinery for pjit in imagine, plumbing for the former is a bit more long-range and subtle. The main idea here is that we need to annotate and maintain debug info on the lu.WrappedFun instance, which we first form at the api.py level, then pass through transformations (which can either update or drop debug info), then finally hand off to the impl rule to be traced to a jaxpr. It makes sense as an annotation, parallel with the in_type annotation used for dynamic shapes, because the debug info has to be updated as transformations are applied, since they might e.g. add tangent inputs and outputs. In more detail: with an initial-style higher-order primitive (like pjit), a jaxpr is formed immediately. Transformations, like autodiff, are jaxpr-to-jaxpr, and so those transformations (like ad.jvp_jaxpr) need to return a new jaxpr either with updated debug info or no debug info at all. (The initial implementation in #15080 doesn't provide updated debug info in any of those jaxpr-to-jaxpr transformation functions, so the debug info is only applied to the jaxpr and then lowered to MLIR when the pjit is at the top level.) For final-style, like pmap here, instead of transformations being jaxpr-to-jaxpr, they're WrappedFun-to-WrappedFun. And so, analogously, transformations like ad.JVPTrace.process_map would need to produce a WrappedFun with updated debug info or no debug info at all. (Also analogously to #15080, this PR only implements enough for the debug info to be preserved for top-level pmaps.) This PR doesn't yet delete the trace-time debug info in partial_eval.py. But that'll happen too!
PiperOrigin-RevId: 518969936
We're going to want to decompose these using series and continued fraction representations, and for that we'll need control flow PiperOrigin-RevId: 518977008
…serialization This is a pure refactor, no functionality should change. PiperOrigin-RevId: 518982222
PiperOrigin-RevId: 518984018
PiperOrigin-RevId: 518987577
We decompose it into a series or a call to igammac. PiperOrigin-RevId: 518993077
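For readers unfamiliar with this kind of decomposition, here is one textbook way to combine a power series with a call to the complement, using `lax.fori_loop` for the control flow. It is an illustrative sketch, not the actual `lax.igamma` lowering: `igamma_series` and `igamma` are hypothetical names, and the fixed 200-term loop and the `x < a + 1` switch are a common heuristic assumed here rather than anything stated in the commit. It relies on the identity P(a, x) = 1 - Q(a, x), with `jax.scipy.special.gammaincc` supplying Q.

```python
import jax.numpy as jnp
from jax import lax
from jax.scipy.special import gammaincc, gammaln


def igamma_series(a, x, num_terms=200):
    """Regularized lower incomplete gamma P(a, x) via its power series.

    P(a, x) = x**a * exp(-x) / Gamma(a + 1) * sum_{k>=0} x**k / ((a+1)...(a+k))
    """
    def body(k, carry):
        term, total = carry
        term = term * x / (a + k)  # k-th series term from the (k-1)-th
        return term, total + term

    term0 = jnp.ones_like(x)  # the k = 0 term is 1
    _, total = lax.fori_loop(1, num_terms, body, (term0, term0))
    log_prefactor = a * jnp.log(x) - x - gammaln(a + 1.0)
    return jnp.exp(log_prefactor) * total


def igamma(a, x):
    """P(a, x): series for x < a + 1, otherwise 1 - Q(a, x) via igammac."""
    # jnp.where evaluates both branches and only selects between the results.
    return jnp.where(x < a + 1.0, igamma_series(a, x), 1.0 - gammaincc(a, x))
```

The series converges quickly when x is small relative to a, while the complement avoids summing many slowly shrinking terms when x is large.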
Also increase minimum SciPy version to 1.7, which was released just before NumPy 1.21.
Now that all functionality needed by frameworks is implemented, let's remove the possibility of not noticing missing functionality due to the bypass. PiperOrigin-RevId: 519018438
PiperOrigin-RevId: 519120624
It is now decomposed into stablehlo ops. PiperOrigin-RevId: 519122775
PiperOrigin-RevId: 519148548
PiperOrigin-RevId: 519154537
client_loop: send disconnect: Broken pipe https://github.com/google/jax/actions/runs/4500333187/jobs/7919324156#step:8:42
…item instead of bouncing to host. PiperOrigin-RevId: 519170785
PiperOrigin-RevId: 519191714
PiperOrigin-RevId: 519194911
TPU VMs are the only supported way to use TPUs as of JAX 0.4.0. PiperOrigin-RevId: 519211267
Ugh, the rebase seems to have messed with the CLA. I might just make a new branch and reopen the PR; my GH branch-juggling skills are not quite there. Plus, we need to squash the commits anyway.
Follows SciPy's implementation except for the ndim > 1 part. Tested locally to reproduce osp.signal.hilbert for even and odd N, but could use some guidance on how to implement a proper test.
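Roughly, the 1-D recipe used by SciPy's `scipy.signal.hilbert` (FFT, keep DC and, for even N, the Nyquist bin, double the positive frequencies, zero the negative ones, inverse FFT) can be sketched in JAX as follows. This is an illustrative sketch, not the PR's code: `hilbert_1d` is a hypothetical name, it only accepts 1-D real input (mirroring the missing ndim > 1 support), and it casts the weights to the spectrum's dtype so it also runs under strict dtype promotion.

```python
import numpy as np
import jax.numpy as jnp


def hilbert_1d(x, N=None):
    """Analytic signal of a real 1-D array via the FFT (hypothetical helper)."""
    x = jnp.asarray(x)
    if x.ndim != 1:
        raise ValueError("this sketch only handles 1-D input")
    if jnp.iscomplexobj(x):
        raise ValueError("x must be real")
    N = x.shape[0] if N is None else N
    if N <= 0:
        raise ValueError("N must be positive")
    Xf = jnp.fft.fft(x, n=N)
    # Static spectral weights: 1 at DC (and Nyquist for even N),
    # 2 for positive frequencies, 0 for negative frequencies.
    h = np.zeros(N)
    h[0] = 1.0
    if N % 2 == 0:
        h[N // 2] = 1.0
        h[1:N // 2] = 2.0
    else:
        h[1:(N + 1) // 2] = 2.0
    # Explicit cast keeps both operands complex, avoiding implicit promotion.
    return jnp.fft.ifft(Xf * jnp.asarray(h, dtype=Xf.dtype))
```

A test along the lines suggested in the review could compare this against `scipy.signal.hilbert` for one even and one odd length, with float32-level tolerances, and also check that the real part of the output reproduces the input.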