mattjj (Collaborator) commented Mar 11, 2019

When a `while_loop`'s `init_val` is not mapped (batched) under `vmap`, we need to broadcast it along the batch dimension, since the result of `body_fun` will be mapped and the loop carry must have the same shape on every iteration. Since `init_val` can in general be a tuple, we need to handle tuples whose elements are a mix of mapped and unmapped values, i.e. a mixture of `None` and `int` bdims for `init_val`.

This also let us delete some redundant code.
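A minimal NumPy sketch of the broadcasting step described above, assuming a convention where each element of the `init_val` tuple comes with a bdim that is either an `int` (the axis it is mapped along) or `None` (unmapped). The helper names `move_or_broadcast` and `batch_init_val` are hypothetical and are not JAX's internal batching API; the sketch only shows how a mixed tuple can be brought to one common layout with the batch axis in front.

```python
import numpy as np

def move_or_broadcast(x, bdim, size):
    """Return x with its batch axis at position 0 (hypothetical helper).

    bdim is an int if x is mapped along that axis, or None if x is unmapped.
    Unmapped values are broadcast to a new leading axis of length `size`.
    """
    x = np.asarray(x)
    if bdim is None:
        # Unmapped: replicate the value once per batch element.
        return np.broadcast_to(x, (size,) + x.shape)
    # Mapped: move the existing batch axis to the front.
    return np.moveaxis(x, bdim, 0)

def batch_init_val(init_vals, bdims, size):
    """Handle a tuple init_val whose elements mix None and int bdims."""
    return tuple(move_or_broadcast(x, d, size) for x, d in zip(init_vals, bdims))

# An unmapped scalar counter plus an array mapped along axis 1.
out = batch_init_val((np.int32(0), np.arange(6.0).reshape(2, 3)), (None, 1), 3)
```

With `size=3`, the unmapped scalar counter becomes an array of shape `(3,)`, while the `(2, 3)` array mapped along axis 1 is moved to shape `(3, 2)`, so every element of the carry now has its batch axis at position 0.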

This bug was another case of me not thinking in terms of tuple-input/tuple-output primitives, since those are rare (and while_loop is one of them).
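A minimal repro of the kind of program this change targets, written against the current public `jax.vmap` and `jax.lax.while_loop` API (which may differ from JAX as of this PR; it is not the program from issue #489). Both elements of the tuple `init_val` are unmapped, but `body_fun` adds the batched `x` into the accumulator, so the accumulator coming out of the body is mapped. The function `f` and its constants are illustrative only.

```python
import jax
import jax.numpy as jnp
from jax import lax

def f(x):
    def cond_fun(carry):
        i, _ = carry
        return i < 3

    def body_fun(carry):
        i, acc = carry
        # acc becomes mapped because x is mapped under vmap.
        return i + 1, acc + x

    # Tuple init_val: neither element depends on x, so both start unmapped.
    return lax.while_loop(cond_fun, body_fun, (jnp.int32(0), jnp.float32(0.0)))

i, acc = jax.vmap(f)(jnp.arange(3.0))
```

Here `i` stays unmapped inside the loop and `vmap` broadcasts it to shape `(3,)` on output, while `acc` ends up equal to `3 * x`.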

fixes #489

mattjj self-assigned this Mar 11, 2019
mattjj mentioned this pull request Mar 11, 2019
mattjj merged commit 673bb11 into master Mar 11, 2019
mattjj deleted the vmap-while branch March 11, 2019 23:26

Successfully merging this pull request may close these issues: while_loop vmap error