sharadmv (Collaborator) commented Sep 1, 2021

Adjust the fixed-point computation in the while loop batching rule to fix the error reported in #7063.

Here's the relevant repro:

import jax
from jax import lax
from jax.experimental import maps

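def f(x):
  z = x + lax.axis_index('a')  # varies along 'a'; becomes the loop carry
  y = x + lax.axis_index('b')  # varies along 'b'; closed over by the body
  def cond(carry):
    i, x = carry
    return x < 5
  def body(carry):
    i, x = carry
    # psum sums the 'b' axis away, so the carry never needs a 'b' axis.
    return i + 1, x + lax.psum(y, 'b')
  return lax.while_loop(cond, body, (0, z))[1]

maps.xmap(f, axis_sizes=dict(a=2, b=10), out_axes=['a'], in_axes={})(1.)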

The translation rule for while loop requires the predicate's shape to be a prefix of every carry shape (trivially true when the predicate is a scalar). When the predicate is not a scalar, the batching rule must make sure this property holds. Previously, whenever the predicate was non-scalar, we batched the predicate together with every value in the carry. This guaranteed that the predicate shape stayed a prefix of the carry shapes, satisfying the while loop's requirement, but the resulting overbatching can cause problems.
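To see why the prefix requirement matters, here is a minimal sketch of one way to evaluate a while loop whose predicate is batched: keep iterating while any batch element's predicate is true, and use the predicate to select between new and old carry values. The helper `batched_while` and the test values are hypothetical and written only for illustration; this is not JAX's actual translation rule, but it shows that broadcasting the predicate against each carry only works when the predicate's shape is a prefix of the carry's shape.

```python
import jax
import jax.numpy as jnp
from jax import lax


def batched_while(cond_fun, body_fun, init):
    """Hypothetical sketch of a while loop with a batched predicate.

    cond_fun returns a boolean array of shape (B,); every carry leaf is
    assumed to have leading dimension B, so the predicate shape is a
    prefix of each carry shape.
    """
    def outer_cond(carry):
        # Keep looping while at least one batch element wants another step.
        return jnp.any(cond_fun(carry))

    def outer_body(carry):
        pred = cond_fun(carry)
        new = body_fun(carry)

        def select(n, o):
            # Append trailing singleton dims so pred broadcasts against o.
            # This is only valid because pred.shape is a prefix of o.shape.
            p = pred.reshape(pred.shape + (1,) * (o.ndim - pred.ndim))
            return jnp.where(p, n, o)

        # Finished elements keep their old values; active ones advance.
        return jax.tree_util.tree_map(select, new, carry)

    return lax.while_loop(outer_cond, outer_body, init)


# Usage: three independent loops that stop after 1, 3 and 5 steps.
limits = jnp.array([1, 3, 5])

def cond_fun(carry):
    i, _ = carry
    return i < limits

def body_fun(carry):
    i, x = carry
    return i + 1, x + 1.0

i, x = batched_while(cond_fun, body_fun,
                     (jnp.zeros(3, jnp.int32), jnp.zeros((3, 2))))
```

Here `i` ends up as `[1, 3, 5]`, and each row of the `(3, 2)` carry `x` is filled with its own iteration count. If `x` instead had its batch axis somewhere other than first, the reshape-and-broadcast in `select` would pair predicate entries with the wrong rows.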

In the repro, the batching rule for while is triggered twice: once for axis 'a' and once for axis 'b'. The carry holds x, which is mapped over 'a', and the body closes over y, which is mapped over 'b'. However, neither the input nor the output carry has a 'b' axis, because the psum sums it out.

With the existing batching rule, the pass for 'a' batches the body and cond of the while loop correctly. In the pass for 'b', however, the fixed-point computation instantiates the output of cond_jaxpr because bool(cond_jaxpr.avals.out_shape[0]) is True (the cond was already batched for 'a', so the predicate is no longer a scalar). Instantiating the predicate forces the entire carry to be mapped, which adds a 'b' axis to the carry where none needs to exist.
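The fixed point decides which carry components end up batched. As a rough model, assuming batching is tracked as one boolean per carry component, it can be sketched in plain Python. The names `carry_fixed_point`, `repro_body_b` and `cond_batched` are hypothetical, and the booleans model only the 'b' pass of the repro, not JAX's real data structures; the point is that the fixed point alone leaves the carry unbatched along 'b', and it is only the non-scalar predicate that forces everything to be batched.

```python
from typing import Callable, List


def carry_fixed_point(body_out_batched: Callable[[List[bool]], List[bool]],
                      init_batched: List[bool]) -> List[bool]:
    """Hypothetical: grow the set of batched carries until it stops changing."""
    carry = list(init_batched)
    # Each non-final iteration marks at least one more carry as batched,
    # so len(carry) + 1 iterations always suffice.
    for _ in range(len(carry) + 1):
        out = body_out_batched(carry)
        new = [c or o for c, o in zip(carry, out)]
        if new == carry:
            return carry
        carry = new
    raise AssertionError("fixed point not reached")


def repro_body_b(carry_batched: List[bool]) -> List[bool]:
    # Models the repro's body under the 'b' pass: i + 1 is batched only if i
    # is, and x + psum(y, 'b') is batched only if x is, since psum removes 'b'.
    i_b, x_b = carry_batched
    return [i_b, x_b]


def cond_batched(carry_batched: List[bool]) -> bool:
    # Models cond: x < 5 is batched along 'b' only if x is.
    _, x_b = carry_batched
    return x_b


carry_b = carry_fixed_point(repro_body_b, [False, False])  # neither i nor x
pred_b = cond_batched(carry_b)                              # predicate unbatched
pred_is_scalar = False  # after the 'a' pass, the predicate has shape (2,)

# Old rule: a non-scalar predicate forces every carry to be batched.
old_rule = carry_b if pred_is_scalar else [True] * len(carry_b)
```

`old_rule` comes out as `[True, True]` even though both the fixed point and the predicate say nothing needs a 'b' axis, which is the overbatching described above.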

Instead, we'd like to stop instantiating merely because the predicate is non-scalar. We check the pred_bat produced by batching.batch_jaxpr, which tells us whether the predicate is batched. If it is, we do need to batch everything, reproducing the old instantiate=True behavior. If it is not, we still need the batched predicate shape to be a prefix of the carry shapes, so any carry that does get batched must have its new axis placed after the predicate's existing dimensions rather than at position 0. We accomplish this with some extra bookkeeping and modifications to batching.batch_jaxpr that let us thread axes through, instead of booleans indicating whether a value is batched along its leading dimension.
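Why the axis position matters can be shown with shapes alone. This sketch assumes the shapes from the repro after the 'a' pass (predicate and carry both of shape `(2,)`) and uses `jax.vmap`'s `out_axes` argument to place a hypothetical 'b'-varying result at two different positions. `is_prefix` is a hypothetical helper, and this illustrates only the shape constraint, not the actual change to batching.batch_jaxpr.

```python
import jax
import jax.numpy as jnp


def is_prefix(pred_shape, carry_shape):
    """True if pred_shape matches the leading dimensions of carry_shape."""
    return tuple(carry_shape[:len(pred_shape)]) == tuple(pred_shape)


# After the 'a' pass (size 2): the predicate and the carry x both have shape (2,).
pred = jnp.zeros((2,), dtype=bool)
x = jnp.arange(2.0)

# Stand-in for a value that varies along 'b' (size 10).
b_vals = jnp.arange(10.0)

# Placing the new 'b' axis first gives shape (10, 2): pred is no longer a prefix.
at_front = jax.vmap(lambda b: x + b, out_axes=0)(b_vals)

# Placing it right after the predicate's dimensions gives (2, 10): prefix kept.
after_pred = jax.vmap(lambda b: x + b, out_axes=pred.ndim)(b_vals)
```

Choosing the position from the predicate's rank, rather than always using axis 0, is what requires passing axes instead of plain booleans.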

Co-authored-by: Sharad Vikram [email protected]
Co-authored-by: Adam Paszke [email protected]

@google-cla google-cla bot added the cla: yes label Sep 1, 2021
sharadmv (Collaborator, Author) commented Sep 1, 2021

This is an updated version of #7206 that fixes some issues that popped up in internal tests.

sharadmv (Collaborator, Author) commented Sep 1, 2021

PTAL @apaszke

@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels Sep 2, 2021
Fixes jax-ml#7063

@copybara-service copybara-service bot merged commit 5dba8cf into jax-ml:main Sep 3, 2021
@sharadmv sharadmv deleted the while-batching branch September 3, 2021 03:02

Labels

cla: yes pull ready Ready for copybara import and testing

3 participants