Skip to content

Commit a15ac6b

Browse files
Merge pull request #32434 from mattjj:32399
PiperOrigin-RevId: 816456798
2 parents ec4f3de + 8a6a84f commit a15ac6b

File tree

3 files changed

+37
-2
lines changed

3 files changed

+37
-2
lines changed

jax/_src/core.py

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3548,7 +3548,14 @@ def substitute(aval: AbstractValue):
35483548
out_type = [a.update(shape=tuple(in_map.get(d, out_map.get(d))
35493549
if type(d) is Var else d for d in a.shape))
35503550
if type(a) is DShapedArray else a for a in out_avals]
3551-
return out_type, call_jaxpr.effects
3551+
3552+
# jaxpr input effects are indexed to include jaxpr.constvars, but the eqn
3553+
# should have effects indexed only on its explicit arguments
3554+
effs = {e.replace(input_index=e.input_index - len(call_jaxpr.constvars))
3555+
if isinstance(e, effects.JaxprInputEffect)
3556+
else e for e in call_jaxpr.effects}
3557+
3558+
return out_type, effs
35523559

35533560
def _check_map(ctx_factory, prim, in_avals, params):
35543561
if "call_jaxpr" not in params:

jax/_src/pjit.py

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1980,7 +1980,15 @@ def _pjit_typecheck(ctx_factory, *in_atoms, jaxpr, **params):
19801980

19811981

19821982
def _pjit_abstract_eval(*args, jaxpr, out_shardings, **_):
1983-
return jaxpr.out_avals, jaxpr.effects
1983+
# jaxpr input effects are indexed to include jaxpr.constvars, but the pjit eqn
1984+
# should have effects indexed only on its explicit arguments
1985+
if jaxpr.constvars:
1986+
effs = {e.replace(input_index=e.input_index - len(jaxpr.constvars))
1987+
if isinstance(e, effects.JaxprInputEffect)
1988+
else e for e in jaxpr.effects}
1989+
else:
1990+
effs = jaxpr.effects
1991+
return jaxpr.out_avals, effs
19841992
jit_p.def_effectful_abstract_eval(_pjit_abstract_eval)
19851993

19861994

tests/jaxpr_effects_test.py

Lines changed: 20 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -257,6 +257,26 @@ def f(x):
257257
jaxpr = jax.make_jaxpr(f)(np.arange(jax.local_device_count()))
258258
self.assertSetEqual(jaxpr.effects, {foo_effect, bar_effect})
259259

260+
def test_pjit_const_input_effect_indexing(self):
261+
# https://github.com/jax-ml/jax/issues/32399
262+
@jax.jit
263+
def bar(x, w):
264+
def scan_fn(x, _):
265+
c = jnp.array([])
266+
o = w[...] @ x
267+
x = jnp.concatenate([x, c], axis=-1)
268+
return x, None
269+
270+
x, _ = jax.lax.scan(scan_fn, x, None, length=10)
271+
return x
272+
273+
274+
@jax.jit
275+
def foo(w):
276+
return bar(jnp.zeros((1,)), w)
277+
278+
foo(jax.new_ref(jnp.eye(1))) # don't crash
279+
260280

261281
@jtu.thread_unsafe_test_class() # because of mlir.register_lowering calls
262282
class EffectfulJaxprLoweringTest(jtu.JaxTestCase):

0 commit comments

Comments (0)