Conversation

joglekara (Contributor)

Follows SciPy's implementation, except for the ndim > 1 part.

Tested locally to reproduce osp.signal.hilbert for even and odd N, but I could use some guidance on how to implement a proper test.
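For reference, the core of the SciPy approach can be sketched in JAX roughly as follows (a simplified 1-D sketch, not the exact PR code; the ndim > 1 handling is omitted):

```python
import jax.numpy as jnp

def hilbert(x):
    # Analytic signal via the FFT, following scipy.signal.hilbert
    # for 1-D real input: zero out negative frequencies, double
    # positive ones, keep DC (and Nyquist, for even N) unscaled.
    N = x.shape[-1]
    Xf = jnp.fft.fft(x)
    h = jnp.zeros(N)
    if N % 2 == 0:  # even N: DC and Nyquist bins get weight 1
        h = h.at[0].set(1).at[N // 2].set(1).at[1:N // 2].set(2)
    else:           # odd N: no Nyquist bin
        h = h.at[0].set(1).at[1:(N + 1) // 2].set(2)
    return jnp.fft.ifft(Xf * h)
```

A handy sanity check: the real part of the analytic signal recovers the input.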

google-cla bot commented Mar 21, 2023

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.


jakevdp (Collaborator) commented Mar 21, 2023

Thanks! I think as far as testing goes, you could follow some of the FFT tests as a template: https://github.com/google/jax/blob/4bd7f408c60266641ea8db2957ea458edcdcf50d/tests/fft_test.py#L231

But the tests for functions in this file should go in https://github.com/google/jax/blob/main/tests/scipy_signal_test.py
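A minimal version of such a test, using a local stand-in for the PR's hilbert and comparing against SciPy for both even and odd N (names and tolerances here are illustrative, not the repo's test-harness conventions):

```python
import numpy as np
import scipy.signal as osp_signal
import jax.numpy as jnp

def hilbert(x):
    # Local 1-D stand-in for the hilbert this PR adds to jax.scipy.signal.
    N = x.shape[-1]
    h = jnp.zeros(N)
    if N % 2 == 0:
        h = h.at[0].set(1).at[N // 2].set(1).at[1:N // 2].set(2)
    else:
        h = h.at[0].set(1).at[1:(N + 1) // 2].set(2)
    return jnp.fft.ifft(jnp.fft.fft(x) * h)

def test_hilbert_matches_scipy():
    for n in (8, 9):  # exercise both even and odd N
        x = np.random.RandomState(0).randn(n).astype(np.float32)
        np.testing.assert_allclose(
            np.asarray(hilbert(jnp.asarray(x))),
            osp_signal.hilbert(x),
            rtol=1e-4, atol=1e-4)
```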

@jakevdp jakevdp self-assigned this Mar 21, 2023
joglekara (Contributor, Author)

I wrote some of those tests :) Looks like y'all have improved on them substantially!

E       jax._src.dtypes.TypePromotionError: Input dtypes ('float32', 'complex64') have no available implicit dtype promotion path when jax_numpy_dtype_promotion=strict. Try explicitly casting inputs to the desired output type, or set jax_numpy_dtype_promotion=standard.

I'm getting some of these, though, and I can't quite grok how to make sure both sides are the same dtype. Can I get an assist, please?
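For context, the error comes from combining a real float32 array with a complex64 one under strict promotion; an explicit cast on the real side resolves it (a minimal reproduction; array values here are illustrative only, not the PR's actual code):

```python
import jax
import jax.numpy as jnp

# Strict mode is what the failing test configuration enables.
jax.config.update("jax_numpy_dtype_promotion", "strict")

x = jnp.ones(4, dtype=jnp.float32)                # real signal
h = jnp.array([1, 2, 2, 0], dtype=jnp.complex64)  # complex filter weights

# `x * h` raises TypePromotionError under strict promotion; casting
# the float32 side explicitly makes the result type unambiguous:
result = x.astype(jnp.complex64) * h
```

Alternatively, keep the filter array in the same real dtype as the input and only go complex where the FFT output forces it.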

jakevdp (Collaborator) left a comment

Looks great! A few comments below. Also please add hilbert here so it will appear in the HTML docs: https://github.com/google/jax/blob/main/docs/jax.scipy.rst#jaxscipysignal

joglekara and others added 20 commits March 24, 2023 13:00
…ch via the apply_primitive route.

PiperOrigin-RevId: 518282464
The sharding tests are very important because jax2tf must do non-trivial manipulation of shardings, e.g., to wrap the inputs and the outputs with
sharding annotations.

Merged two test classes ShardedJitHloTest and Sharding. One of the classes was
just checking the annotations in the TF HLO, the other one was just running
the code and comparing the results. Now we do both in one test.

Refactored the code to log the JAX HLO and to check the occurrences of annotations
in the TF HLO. Now we support checking that the occurrence count is equal to, or
greater than or equal to, a given value.

Added more annotation checking. This makes the test more informative (because
there is no other good way to check that sharding was applied correctly).
But this makes it also possible that the test will fail when we change the
JAX lowering.

PiperOrigin-RevId: 518362978
…ce those are always True

PiperOrigin-RevId: 518368963
Since jax2tf.convert is called recursively for the purpose of
serializing the vjp function, we must ensure that if the primal
function is a pjit with shardings then the vjp function must also
be converted as a pjit.

Without this fix, serialization with gradients of a pjit function
will fail with an error that there are shardings but no pjit at
the top level.
PiperOrigin-RevId: 518572905
add wald to random.py
PiperOrigin-RevId: 518597609
This is a follow-up on #15080 to restore (and indeed fix!) how pmap builds a
jaxpr with debug info (i.e. parameter names and result paths). The difference
with the machinery in #15080 is just to deal with pmap being final-style (i.e.
build the jaxpr at the last second, well after pytrees have been flattened away
and transformations have been applied), whereas the machinery for pjit in
#15080 is initial-style (the jaxpr is formed up front). As you might
imagine, plumbing for the former is a bit more long-range and subtle.

The main idea here is that we need to annotate and maintain debug info on the
lu.WrappedFun instance, which we first form at the api.py level, then pass
through transformations (which can either update or drop debug info), then
finally hand off to the impl rule to be traced to a jaxpr. It makes sense as an
annotation, parallel with the in_type annotation used for dynamic shapes,
because the debug info has to be updated as transformations are applied, since
they might e.g. add tangent inputs and outputs.

In more detail: with an initial-style higher-order primitive (like pjit), a
jaxpr is formed immediately. Transformations, like autodiff, are
jaxpr-to-jaxpr, and so those transformations (like ad.jvp_jaxpr) need to return
a new jaxpr either with updated debug info or no debug info at all. (The initial
implementation in #15080 doesn't provide updated debug info in any of those
jaxpr-to-jaxpr transformation functions, so the debug info is only applied to
the jaxpr and then lowered to MLIR when the pjit is at the top level.)

For final-style, like pmap here, instead of transformations being
jaxpr-to-jaxpr, they're WrappedFun-to-WrappedFun. And so, analogously,
transformations, like ad.JVPTrace.process_map, would need to produce a
WrappedFun with updated debug info or no debug info at all. (Also analogously
to #15080, this PR only implements enough for the debug info to be preserved
for top-level pmaps.)

This PR doesn't yet delete the trace-time debug info in partial_eval.py. But
that'll happen too!
jakevdp and others added 26 commits March 24, 2023 13:04
We're going to want to decompose these using series and
continued fraction representations, and for that we'll need
control flow.

PiperOrigin-RevId: 518977008
…serialization

This is a pure refactor, no functionality should change.

PiperOrigin-RevId: 518982222
We decompose it into a series or a call to igammac.

PiperOrigin-RevId: 518993077
Also increase minimum SciPy version to 1.7, which was released just before NumPy 1.21.
Now that all functionality needed by frameworks is implemented, let's
remove the possibility of not noticing missing functionality due to
the bypass.

PiperOrigin-RevId: 519018438
It is now decomposed into stablehlo ops.

PiperOrigin-RevId: 519122775
PiperOrigin-RevId: 519148548
PiperOrigin-RevId: 519154537
…item instead of bouncing to host.

PiperOrigin-RevId: 519170785
TPU VMs are the only supported way to use TPUs as of JAX 0.4.0.

PiperOrigin-RevId: 519211267
joglekara (Contributor, Author) commented Mar 24, 2023

Ugh, the rebase seems to have messed with the CLA. I might just make a new branch and reopen the PR; my GH branch-juggling skills are not quite there.

Plus, we need to squash the commits anyway.

@joglekara joglekara closed this by deleting the head repository Mar 24, 2023