change while loop batching fixed point condition #7781
Merged
Adjust the while loop fixed point to address the error in #7063.
Here's the relevant repro:
The translation rule for the while loop requires that the predicate shape be a prefix of every carry shape (this is trivially true when the predicate is a scalar). When the predicate is not a scalar, we have to enforce this ourselves. Previously, whenever the predicate was non-scalar, we batched the predicate and every value in the carry. That keeps the predicate shape a prefix of the carry shapes, satisfying the condition in the while loop, but the overbatching can cause problems.
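As a rough illustration of a batched predicate, here is a minimal sketch using only the public `jax.vmap` and `lax.while_loop` APIs. The names `count_up` and `limits` are hypothetical and do not come from this PR. The predicate depends on a per-example `limit`, so under `vmap` it becomes non-scalar, and the carry values (which start out as unbatched constants) end up carrying the batch dimension as well.

```python
import jax
import jax.numpy as jnp
from jax import lax


def count_up(limit):
    # carry = (i, total); both start as plain, unbatched constants.
    def cond(carry):
        i, _ = carry
        # Depends on `limit`, so the predicate is batched under vmap.
        return i < limit

    def body(carry):
        i, total = carry
        return i + 1, total + 2.0

    return lax.while_loop(cond, body, (0, 0.0))


limits = jnp.array([1, 3, 2])
# Each batch element stops at its own limit; the batched loop keeps
# running until every predicate is False.
i_final, total = jax.vmap(count_up)(limits)
```

Printing `jax.make_jaxpr(jax.vmap(count_up))(limits)` is one way to inspect the shapes of the predicate and the carry in the batched `while`.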
In the repro, the batching rule for `while` is triggered twice, once for axis `'a'` and once for axis `'b'`. The body has the `x` value in the carry, which is mapped across `'a'`, and closes over `y` (which has a `'b'` axis). However, the input and output carry axes do not include `'b'`, because it is summed out with the `psum`.

With the current batching rule, for `'a'` we properly batch the body and cond of the while loop. In the batching rule for `'b'`, specifically in the fixed point, we instantiate the output for the `cond_jaxpr` because `bool(cond_jaxpr.avals.out_shape[0])` is `True` (this is because we have batched the cond for `'a'`). This causes us to map the entire `carry`, which adds an axis to the carry for `'b'` when none needs to exist.

What we'd instead like to do is not `instantiate` just because the predicate is non-scalar. We check the `pred_bat` produced by `batching.batch_jaxpr` (which tells us if the predicate is batched or not). If it is, we do need to batch everything and reproduce the `instantiate=True` behavior. If it is not, we need to make sure that the batched predicate shape is still a prefix of the carry shapes. We accomplish this with some extra bookkeeping and modifications to `batching.batch_jaxpr` to allow us to thread axes in, as opposed to bools indicating whether something is batched on the leading dimension.

Co-authored-by: Sharad Vikram [email protected]
Co-authored-by: Adam Paszke [email protected]
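
The scenario in the description (a carry mapped over `'a'`, a closed-over value mapped over `'b'` and reduced with `psum`) can be sketched roughly as follows. This is not the original repro from #7063. It is an assumed reconstruction using nested `jax.vmap` with named axes, and the function name `f`, the threshold `10.0`, and the array sizes are all hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax


def f(x, y):
    def cond(x_):
        # Depends on x, so the predicate is batched once 'a' is batched.
        return x_ < 10.0

    def body(x_):
        # y varies along 'b', but psum removes that axis again, so the
        # carry itself has no reason to vary along 'b'.
        return x_ + lax.psum(y, 'b')

    return lax.while_loop(cond, body, x)


x = jnp.arange(3.0)  # mapped over 'a' (hypothetical size 3)
y = jnp.ones(4)      # mapped over 'b' (hypothetical size 4)

inner = jax.vmap(f, in_axes=(0, None), axis_name='a')
outer = jax.vmap(inner, in_axes=(None, 0), axis_name='b')

# The default out_axes=0 on the outer vmap broadcasts the result along
# 'b', so this sketch does not depend on whether the carry was batched
# over 'b' internally.
out = outer(x, y)
```

Here `psum(y, 'b')` is 4 for every element, so each row of `out` holds the first value in the sequence `x, x + 4, x + 8, ...` that is no longer below 10.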