
dpiponi
Contributor

@dpiponi dpiponi commented Feb 5, 2020

No description provided.

@xingyousong

+1 much support

@ValeryTyumen

+1

@shoyer
Collaborator

shoyer commented May 20, 2020

How much code does this generate in the jaxpr/XLA? I'm a little worried that it could generate surprisingly large amounts of code.

@mattjj
Collaborator

mattjj commented May 20, 2020

@shoyer the amount of code is a small constant factor times the logarithm of the sequence length. Months ago, when Dan first showed this to me, he demoed it on a sequence of length 1M with no problem IIRC.

@shoyer
Collaborator

shoyer commented May 20, 2020

Very nice. I now notice that it calls _scan recursively exactly once.
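Here is a rough sketch of why the generated code stays small, assuming the recursion has the usual pair/recurse/fill shape. The plain-Python function below uses lists instead of JAX arrays; the name `recursive_scan` is hypothetical and this is not the PR's code. It halves the sequence and calls itself exactly once. If the same structure is traced with arrays of static shape, the recursion unrolls at trace time, so each level would add only a fixed number of array operations and the total grows with the logarithm of the length.

```python
from typing import Callable, List, TypeVar

T = TypeVar("T")


def recursive_scan(fn: Callable[[T, T], T], elems: List[T]) -> List[T]:
    """Inclusive prefix scan of `elems` under an associative `fn` (hypothetical sketch)."""
    n = len(elems)
    if n < 2:
        return list(elems)
    # Combine adjacent pairs, halving the length.
    pairs = [fn(elems[2 * i], elems[2 * i + 1]) for i in range(n // 2)]
    # The single recursive call: scan the half-length sequence.
    # pairs_scan[i] is the prefix result at original (odd) index 2*i + 1.
    pairs_scan = recursive_scan(fn, pairs)
    result = list(elems)  # index 0 is already correct: its prefix is elems[0]
    for i in range(n // 2):
        result[2 * i + 1] = pairs_scan[i]
    # Each remaining even index 2*i (i >= 1) extends the prefix that
    # ends at 2*i - 1 by one more element.
    for i in range(1, (n + 1) // 2):
        result[2 * i] = fn(pairs_scan[i - 1], elems[2 * i])
    return result
```

For example, `recursive_scan(lambda a, b: a + b, list("abcde"))` returns `['a', 'ab', 'abc', 'abcd', 'abcde']`. String concatenation is associative but not commutative, so this also checks that operand order is preserved. Each call recurses once on n // 2 elements, so the recursion is about log2(n) levels deep, on the order of 20 levels for a sequence of length 1M.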

@mattjj mattjj self-assigned this May 21, 2020
@mattjj mattjj self-requested a review May 21, 2020 15:06
Collaborator

@mattjj mattjj left a comment

LGTM! Thanks so much, @dpiponi. Sorry I took 3.5 months longer to merge this than I said...

@mattjj
Collaborator

mattjj commented May 21, 2020

Internal tests pass!

@mattjj mattjj merged commit c459280 into jax-ml:master May 21, 2020
@AdrienCorenflos
Contributor

Hi,
It might be worth stating explicitly that what you implemented is the work-efficient version of the prefix sum, at least if I am reading the algorithm correctly.
I can make that change, but I wanted to check first that I wasn't misreading the recursion.
Adrien
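One way to check the work-efficiency question is to count how many times the combining function is applied. The sketch below assumes a recursion of this form: combine adjacent pairs, recurse once on the n // 2 results, then fill each remaining even position with one more combine. It compares that count with a Hillis-Steele style scan, which at strides d = 1, 2, 4, ... combines every element with the one d positions earlier. Both function names are hypothetical helpers for this comparison and are not part of JAX.

```python
def work_efficient_combines(n: int) -> int:
    """Combines used by the pair/recurse/fill scheme (assumed shape)."""
    if n < 2:
        return 0
    # n // 2 pairwise combines, one recursive call on n // 2 elements,
    # then (n + 1) // 2 - 1 fills at even indices: n - 1 combines per level.
    return n // 2 + work_efficient_combines(n // 2) + (n + 1) // 2 - 1


def hillis_steele_combines(n: int) -> int:
    """Combines used by the naive log-depth scan (Hillis-Steele style)."""
    total, d = 0, 1
    while d < n:
        # At stride d, every element with index >= d combines with index - d.
        total += n - d
        d *= 2
    return total
```

For n = 1024, the first gives 2036 combines and the second 9217. In general the recurrence W(n) = W(n // 2) + n - 1 keeps the work-efficient count below 2n, while the naive scheme does roughly n log2(n) combines. Both have logarithmic depth; they differ in total work.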


7 participants