-
Notifications
You must be signed in to change notification settings - Fork 3.2k
Closed
Labels
bug — Something isn't working
Description
I am running an ODE computation that I want to vectorize using `vmap`, and I need gradients through it. Everything works fine until I add `vmap` — then `jax.value_and_grad` raises an `AssertionError`.
Here is a stripped-down script to reproduce it:
import jax
import jax.numpy as np
from jax.experimental.ode import build_odeint
def dx_dt(x, *args):
    """Return the right-hand side of the linear ODE ``dx/dt = 0.1 * x``.

    Args:
        x: Current state (a scalar or array; the product is elementwise).
        *args: Any extra positional arguments passed by the ODE solver.
            They are accepted and ignored.

    Returns:
        ``0.1 * x``, the time derivative of the state.
    """
    return 0.1 * x
# Build an ODE solver for dx_dt. The functions below call it as
# ``ode(y0, t)`` with an initial state and an array of time points.
# NOTE(review): build_odeint is an older jax.experimental.ode API — confirm it
# is available in the installed JAX version.
ode = build_odeint(dx_dt)
def f(x):
    """Integrate the ODE from a single initial state and sum the final state.

    Args:
        x: Scalar used as the first component of the initial state
            ``[x, 0.1]``.

    Returns:
        Sum of the state components at the last entry of ``t`` (t = 5).
    """
    y0 = np.array([x, 0.1])
    t = np.array([0., 5.])
    y = ode(y0, t)
    # y[-1] is taken as the state at the final time point
    # (assumes the solver returns one row per entry of ``t`` — confirm).
    return y[-1].sum()
def g(x):
    """Batched counterpart of ``f``: integrate the ODE from two initial states.

    Args:
        x: Scalar used as the first component of both initial states,
            ``[x, 0.1]`` and ``[x, 0.2]``.

    Returns:
        Sum, over both trajectories, of the state components at the last
        time point. Evaluating this works, but differentiating it with
        ``jax.value_and_grad`` raises the AssertionError reported here.
    """
    # Two initial values for the ODE, one per row.
    y0_arr = np.array([[x, 0.1],
                       [x, 0.2]])
    t = np.array([0., 5.])
    # Run the ODE once per initial value by mapping over the leading axis.
    y = jax.vmap(lambda y0: ode(y0, t))(y0_arr)
    # y[:, -1] selects the final time point of each trajectory
    # (assumes axis 1 indexes the entries of ``t`` — confirm).
    return y[:, -1].sum()
# Works: the unbatched value and gradient, and the batched (vmap) forward pass.
print(f(1.))
print(jax.value_and_grad(f)(1.))
print(g(1.))
# Doesn't work: taking the gradient through the vmap-ed ODE solve raises
# AssertionError (see the traceback below).
print(jax.value_and_grad(g)(1.))
Here is the output:
1.8135929
(DeviceArray(1.8135929, dtype=float32), DeviceArray(1.6487213, dtype=float32))
3.792058
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
<ipython-input-6-1446cf27a440> in <module>
30 # Doesn't work
31 print(g(1.))
---> 32 print(jax.value_and_grad(g)(1.))
~/opt/anaconda3/lib/python3.7/site-packages/jax/api.py in value_and_grad_f(*args, **kwargs)
416 f_partial, dyn_args = _argnums_partial(f, argnums, args)
417 if not has_aux:
--> 418 ans, vjp_py = _vjp(f_partial, *dyn_args)
419 else:
420 ans, vjp_py, aux = _vjp(f_partial, *dyn_args, has_aux=True)
~/opt/anaconda3/lib/python3.7/site-packages/jax/api.py in _vjp(fun, *primals, **kwargs)
1334 if not has_aux:
1335 flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
-> 1336 out_primal, out_vjp = ad.vjp(flat_fun, primals_flat)
1337 out_tree = out_tree()
1338 else:
~/opt/anaconda3/lib/python3.7/site-packages/jax/interpreters/ad.py in vjp(traceable, primals, has_aux)
104 def vjp(traceable, primals, has_aux=False):
105 if not has_aux:
--> 106 out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
107 else:
108 out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)
~/opt/anaconda3/lib/python3.7/site-packages/jax/interpreters/ad.py in linearize(traceable, *primals, **kwargs)
96 pval_primals, pval_tangents = tree_unflatten(out_tree(), out_pvals)
97 aval_primals, const_primals = unzip2(pval_primals)
---> 98 assert all(aval_primal is None for aval_primal in aval_primals)
99 if not has_aux:
100 return const_primals, pval_tangents, jaxpr, consts
AssertionError:
Could you let me know whether what I'm trying to do is unsupported, or whether there is a workaround?
Metadata
Metadata
Assignees
Labels
bug — Something isn't working