jax/interpreters/batching.py (5 additions, 18 deletions)

@@ -285,24 +285,11 @@ def zeros_like_batched(batched_args, batch_dims):
 
 
 def bdim_at_front(x, bdim, broadcast_size=1, force_broadcast=False):
-  if bdim is None:
-    return broadcast(x, broadcast_size, force_broadcast=force_broadcast)
-  else:
-    return move_dim_to_front(x, bdim)
+  return moveaxis(broadcast_size, 0, bdim, x, force_broadcast=force_broadcast)
 
 def move_dim_to_front(x, dim):
-  aval = get_aval(x)
-  if type(aval) is AbstractTuple:
-    return pack(map(partial(move_dim_to_front, dim=dim), x))
-  elif isinstance(aval, ShapedArray):
-    assert 0 <= dim < onp.ndim(x)
-    if dim == 0:
-      return x
-    else:
-      perm = (dim,) + tuple(range(dim)) + tuple(range(dim + 1, onp.ndim(x)))
-      return x.transpose(perm)
-  else:
-    raise TypeError(type(x))
+  assert dim is not None
+  return moveaxis(None, 0, dim, x)
 
 def dimsize(dim, x):
   aval = get_aval(x)
@@ -323,7 +310,7 @@ def dimsize(dim, x):
   else:
     raise TypeError(type(dim))
 
-def moveaxis(sz, dst, src, x):
+def moveaxis(sz, dst, src, x, force_broadcast=True):
   aval = get_aval(x)
   if type(aval) is AbstractTuple:
     if type(src) is tuple and type(dst) is tuple:
@@ -341,7 +328,7 @@ def moveaxis(sz, dst, src, x):
       return x
     else:
       if src is None:
-        x = broadcast(x, sz, force_broadcast=True)
+        x = broadcast(x, sz, force_broadcast=force_broadcast)
         src = 0
       dst_ = dst % (aval.ndim + 1)
       if src == dst_:
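
The net effect of the batching.py change: `bdim_at_front` and `move_dim_to_front` become thin wrappers around a single `moveaxis` that both materializes a batch axis (when the source dim is `None`) and transposes an existing one. For illustration, a minimal standalone sketch of that unified logic in plain NumPy; `moveaxis_sketch` is a hypothetical name, it assumes a non-negative `dst`, and it is not the JAX implementation itself:

import numpy as onp

def moveaxis_sketch(sz, dst, src, x):
  # src=None means x is unbatched: replicate it sz times along a new
  # leading axis and treat that axis as the batch dim.
  if src is None:
    x = onp.broadcast_to(x, (sz,) + onp.shape(x))
    src = 0
  if src == dst:
    return x
  # Otherwise move the existing batch axis from position src to position dst.
  perm = [i for i in range(onp.ndim(x)) if i != src]
  perm.insert(dst, src)
  return x.transpose(perm)

# bdim_at_front(x, bdim, size) corresponds to moveaxis_sketch(size, 0, bdim, x):
print(moveaxis_sketch(3, 0, None, onp.ones(4)).shape)          # (3, 4): broadcast
print(moveaxis_sketch(None, 0, 2, onp.ones((4, 5, 3))).shape)  # (3, 4, 5): move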
jax/lax.py (5 additions, 6 deletions)

@@ -3683,12 +3683,11 @@ def _while_loop_batching_rule(batched_args, batch_dims, aval_out, cond_jaxpr,
   size = sizes.pop()
   assert not sizes
 
-  if init_val_bd is None:
-    # TODO(mattjj): if cond_consts_bd is also None, we could keep cond_fun
-    # unbatched and avoid the masking logic, but we ignore that optimization
-    init_val = batching.bdim_at_front(init_val, init_val_bd, size,
-                                      force_broadcast=True)
-    init_val_bd = 0
+  # TODO(mattjj): if cond_consts_bd is also None, we could keep cond_fun
+  # unbatched and avoid the masking logic, but we ignore that optimization
+  init_val = batching.bdim_at_front(init_val, init_val_bd, size,
+                                    force_broadcast=True)
+  init_val_bd = 0
 
   def batched_cond_fun(batched_loop_carry):
     @lu.wrap_init
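
The bug being fixed here: with a tuple loop carry, `init_val_bd` is a tuple of per-element batch dims such as `(None, 0)`, which is not `None`, so the old guard skipped the normalization even though some carry elements still needed broadcasting. Calling `bdim_at_front` unconditionally lets the tuple-aware `moveaxis` handle each element. A toy sketch of that per-element normalization, using a hypothetical `normalize_carry` helper in plain NumPy (not the actual JAX rule):

import numpy as onp

def normalize_carry(carry, bdims, size):
  # Put every element's batch dim at axis 0, replicating unbatched elements.
  out = []
  for elt, bd in zip(carry, bdims):
    if bd is None:
      elt = onp.broadcast_to(elt, (size,) + onp.shape(elt))  # unbatched: replicate
    else:
      elt = onp.moveaxis(elt, bd, 0)  # batched: move batch dim to the front
    out.append(elt)
  return tuple(out)

# A mixed carry like the one in issue #489: scalar constant + batched key.
u, key = onp.float32(1.), onp.zeros((2, 2), onp.uint32)
print([c.shape for c in normalize_carry((u, key), (None, 0), size=2)])
# prints [(2,), (2, 2)]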
tests/batching_test.py (13 additions, 0 deletions)

@@ -913,6 +913,19 @@ def fun(x):
     expected = (onp.array([10, 11]), onp.array([20, 20]))
     self.assertAllClose(ans, expected, check_dtypes=False)
 
+  def testIssue489(self):
+    def f(key):
+      def body_fn(uk):
+        key = uk[1]
+        u = random.uniform(key, (), dtype=np.float64)
+        key, _ = random.split(key)
+        return u, key
+
+      u, _ = lax.while_loop(lambda uk: uk[0] > 0.5, body_fn, (1., key))
+      return u
+
+    print(vmap(f)(random.split(random.PRNGKey(0), 2)))  # no crash
+
 
 if __name__ == '__main__':
   absltest.main()
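
The regression test reproduces issue #489: the initial carry `(1., key)` pairs a constant scalar, which `vmap` sees as unbatched, with a key batched along dim 0, so the carry's batch-dim spec is the tuple `(None, 0)`, exactly the case the old `if init_val_bd is None` guard in `_while_loop_batching_rule` mishandled.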