
sharadmv
Collaborator

As of #4008, there are no longer default implementations of `process_custom_jvp_call` and `process_custom_vjp_call`. As a result, any function with a custom JVP/VJP rule raises an error when it is traced under a `CallbackTrace`.

This PR adds default implementations. Ideally they would preserve the custom differentiation rules, since `CallbackTrace` leaves most primitives unchanged. If that turns out to be the better way forward, we can iterate on the initial implementation in this PR.
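To make the trade-off concrete, here is a minimal pure-Python sketch of what a default handler that drops the custom rule amounts to. It does not use JAX: `ToyTrace`, `ToyCallbackTrace`, the primitive name strings, and the simplified signatures are all hypothetical stand-ins, not JAX's actual `Trace` API. The base class raises, mirroring the error described above; the subclass falls back to evaluating the primal function and discards the JVP/VJP rules.

```python
import math
from typing import Any, Callable, List, Sequence


class ToyTrace:
    """Hypothetical stand-in for a trace base class that has no default
    custom-derivative handlers (not JAX's real Trace class)."""

    def process_custom_jvp_call(self, primitive: str, fun: Callable,
                                jvp: Callable, tracers: Sequence[Any]) -> Any:
        raise NotImplementedError(
            f"{type(self).__name__} has no process_custom_jvp_call")

    def process_custom_vjp_call(self, primitive: str, fun: Callable,
                                fwd: Callable, bwd: Callable,
                                tracers: Sequence[Any]) -> Any:
        raise NotImplementedError(
            f"{type(self).__name__} has no process_custom_vjp_call")


class ToyCallbackTrace(ToyTrace):
    """Hypothetical callback-style trace with default handlers: the custom
    rule is dropped and only the primal function is evaluated."""

    def __init__(self) -> None:
        # Records which custom-derivative calls had their rules discarded.
        self.dropped: List[str] = []

    def process_custom_jvp_call(self, primitive, fun, jvp, tracers):
        del jvp  # the custom JVP rule is NOT preserved by this default
        self.dropped.append(primitive)
        return fun(*tracers)

    def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers):
        del fwd, bwd  # the custom VJP rules are NOT preserved either
        self.dropped.append(primitive)
        return fun(*tracers)


def softplus(x: float) -> float:
    """Primal function: log(1 + e^x)."""
    return math.log1p(math.exp(x))


def softplus_jvp(primals, tangents):
    """Custom JVP rule: d/dx softplus(x) = sigmoid(x)."""
    (x,), (t,) = primals, tangents
    return softplus(x), t / (1.0 + math.exp(-x))


if __name__ == "__main__":
    try:
        ToyTrace().process_custom_jvp_call(
            "custom_jvp_call", softplus, softplus_jvp, [0.0])
    except NotImplementedError as err:
        print("without a default:", err)
    trace = ToyCallbackTrace()
    out = trace.process_custom_jvp_call(
        "custom_jvp_call", softplus, softplus_jvp, [0.0])
    print("with a default:", out, "dropped:", trace.dropped)
```

Preserving the custom info would presumably mean forwarding `jvp` (or `fwd`/`bwd`) into a new custom-derivative call built over the transformed function, rather than deleting it as this sketch does.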

process_custom_vjp_call to `jax.experimental.callback`
@google-cla google-cla bot added the cla: yes label Oct 16, 2020
Collaborator

@mattjj mattjj left a comment


LGTM!

@mattjj mattjj added the pull ready Ready for copybara import and testing label Oct 16, 2020
@copybara-service copybara-service bot merged commit 9ea1311 into jax-ml:master Oct 16, 2020
@sharadmv sharadmv deleted the callback-custom-jvp branch August 26, 2021 23:29

Labels

cla: yes pull ready Ready for copybara import and testing


2 participants