
vmap + odeint + grad --> AssertionError #2531

@dsheldon

I am running an ODE computation that I want to vectorize with `vmap`, and I need gradients through it. Everything works until I add `vmap`; then `value_and_grad` fails with an `AssertionError`.

Here is a stripped-down script to reproduce it:

```python
import jax
import jax.numpy as np
from jax.experimental.ode import build_odeint

def dx_dt(x, *args):
    return 0.1*x

ode = build_odeint(dx_dt)

def f(x):
    y0 = np.array([x, 0.1])
    t = np.array([0., 5.])
    y = ode(y0, t)
    return y[-1].sum()

def g(x):
    # Two initial values for the ODE
    y0_arr = np.array([[x, 0.1],
                       [x, 0.2]])

    # Run ODE twice
    t = np.array([0., 5.])
    y = jax.vmap(lambda y0: ode(y0, t))(y0_arr)
    return y[:,-1].sum()

# Works
print(f(1.))
print(jax.value_and_grad(f)(1.))
print(g(1.))

# Doesn't work
print(jax.value_and_grad(g)(1.))
```

Here is the output:

```
1.8135929
(DeviceArray(1.8135929, dtype=float32), DeviceArray(1.6487213, dtype=float32))
3.792058
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-6-1446cf27a440> in <module>
     30 # Doesn't work
     31 print(g(1.))
---> 32 print(jax.value_and_grad(g)(1.))

~/opt/anaconda3/lib/python3.7/site-packages/jax/api.py in value_and_grad_f(*args, **kwargs)
    416     f_partial, dyn_args = _argnums_partial(f, argnums, args)
    417     if not has_aux:
--> 418       ans, vjp_py = _vjp(f_partial, *dyn_args)
    419     else:
    420       ans, vjp_py, aux = _vjp(f_partial, *dyn_args, has_aux=True)

~/opt/anaconda3/lib/python3.7/site-packages/jax/api.py in _vjp(fun, *primals, **kwargs)
   1334   if not has_aux:
   1335     flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
-> 1336     out_primal, out_vjp = ad.vjp(flat_fun, primals_flat)
   1337     out_tree = out_tree()
   1338   else:

~/opt/anaconda3/lib/python3.7/site-packages/jax/interpreters/ad.py in vjp(traceable, primals, has_aux)
    104 def vjp(traceable, primals, has_aux=False):
    105   if not has_aux:
--> 106     out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
    107   else:
    108     out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)

~/opt/anaconda3/lib/python3.7/site-packages/jax/interpreters/ad.py in linearize(traceable, *primals, **kwargs)
     96   pval_primals, pval_tangents = tree_unflatten(out_tree(), out_pvals)
     97   aval_primals, const_primals = unzip2(pval_primals)
---> 98   assert all(aval_primal is None for aval_primal in aval_primals)
     99   if not has_aux:
    100     return const_primals, pval_tangents, jaxpr, consts

AssertionError:
```
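Because `dx_dt` is linear, the expected values can be checked by hand. Each component evolves independently as $y(5) = y(0)\,e^{0.5}$, so:

```latex
\frac{dy}{dt} = 0.1\,y \;\Rightarrow\; y(5) = y(0)\,e^{0.5}
f(x) = (x + 0.1)\,e^{0.5}, \qquad f'(x) = e^{0.5} \approx 1.6487
g(x) = \big[(x + 0.1) + (x + 0.2)\big]\,e^{0.5} = (2x + 0.3)\,e^{0.5}, \qquad g(1) \approx 3.7921, \qquad g'(x) = 2\,e^{0.5} \approx 3.2974
```

The printed values for `f(1.)`, its gradient, and `g(1.)` agree with these, so the gradient that `value_and_grad(g)(1.)` should return is $2e^{0.5} \approx 3.2974$.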

Can you let me know whether combining `vmap`, `odeint`, and `grad` like this is unsupported, or whether there is a workaround?
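A possible workaround, as a sketch that assumes a newer JAX release where `jax.experimental.ode` provides `odeint(func, y0, t, *args)` (with the dynamics function called as `func(y, t, *args)`) rather than `build_odeint`: call `odeint` directly inside the vmapped function. This does not fix the `build_odeint` path above; it only shows the same computation written against that assumed API.

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint  # assumes a JAX version that ships odeint

def dx_dt(y, t):
    # odeint calls func(y, t, *args); this linear ODE ignores t.
    return 0.1 * y

def g(x):
    # Two initial values for the ODE, both depending on x.
    y0_arr = jnp.array([[x, 0.1],
                        [x, 0.2]])
    t = jnp.array([0., 5.])
    # For one y0 of shape (2,), odeint returns shape (len(t), 2);
    # vmapping over the two initial values gives shape (2, len(t), 2).
    y = jax.vmap(lambda y0: odeint(dx_dt, y0, t))(y0_arr)
    # Sum the final-time state of both trajectories.
    return y[:, -1].sum()

val, dval = jax.value_and_grad(g)(1.)
print(val, dval)
```

If this runs on a given JAX version, the returned value and gradient should be close to the hand-derived $2.3\,e^{0.5}$ and $2e^{0.5}$, up to the solver's tolerance and float32 precision.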

Labels: bug (Something isn't working)