

hawkinsp
Collaborator

Fixes #779

hawkinsp requested a review from mattjj May 29, 2019 21:15
Collaborator

mattjj left a comment


Nice!

Is there a way to write a Scatter that avoids the explicit broadcast, or do you know that's not possible? (IIRC you mentioned that Scatter can't do this without the explicit broadcast, though even if it could, this approach is much better for preserving our sanity. I'm not sure my mind could survive another descent into Gather and Scatter.)

hawkinsp
Collaborator Author


I thought about this briefly and decided (a) I couldn't think of a better way of using scatter to do this, although I am not completely certain given how complex scatter is, and (b) this fixes David's immediate problem. So why not check this in now and revise it if we come up with a better way?
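To make the broadcast approach concrete, here is a minimal sketch of the equivalence the fix relies on. It is written against today's public `jax.numpy` `.at[]` API, not the internal scatter batching rule this PR changes. Vmapping a scatter whose indices are batched but whose updates are not should give the same result as first broadcasting the updates along the batch axis and then vmapping with both batched. The helper `scatter_add` and the example shapes are hypothetical, chosen only for illustration.

```python
import jax
import jax.numpy as jnp

def scatter_add(operand, idx, updates):
    # Hypothetical helper: add `updates` into `operand` at positions `idx`.
    return operand.at[idx].add(updates)

operand = jnp.zeros(5)
indices = jnp.array([[0, 1], [2, 3], [3, 4]])  # batch of 3 index vectors
updates = jnp.array([1.0, 2.0])                # shared by every batch element

# Indices batched along axis 0, updates unbatched (in_axes=None).
out = jax.vmap(scatter_add, in_axes=(None, 0, None))(operand, indices, updates)

# Same computation with the updates explicitly broadcast along the batch
# axis, so both indices and updates are batched.
bcast = jnp.broadcast_to(updates, (indices.shape[0],) + updates.shape)
out_bcast = jax.vmap(scatter_add, in_axes=(None, 0, 0))(operand, indices, bcast)

print(out)
```

The cost of this approach is materializing an updates array with a leading batch dimension; in exchange, the unbatched-updates case presumably reuses the existing rule for indices and updates that are both batched, rather than needing a separate Scatter formulation.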

hawkinsp merged commit c65ccda into jax-ml:master May 29, 2019


Linked issue: implement scatter batching rule when indices are batched and updates are not

3 participants