stax.serial.apply_fun is not a valid JAX type inside odeint  #2920

Description

@skrsna

Hi,
FWIW, I'm using a self-built jax and jaxlib, following the instructions from #2083.

#
# Name                    Version                   Build  Channel
jax                       0.1.64                    <pip>
jaxlib                    0.1.45                    <pip>

I'm trying to get gradients through an ODE solver. First I ran into the AssertionError from #2718, which I think I solved by passing all the arguments directly into odeint. Then I followed the suggestion in #2531 to fix another AssertionError by taking a vmap of grads instead of a grad of a vmap. Now I'm getting the following error.
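For context, batch_grad is built by vmapping a grad of the jitted loss that appears in the traceback below; this is just a sketch of that pattern (the argnums/in_axes are what I believe matches the call shown), not a standalone reproduction:

from jax import grad, vmap

# grad of loss w.r.t. uparams (positional argument 5), vmapped over the
# leading batch axis of batch_y0, batch_t and batch_y; params, ufuncs and
# uparams are shared across the batch.
batch_grad = vmap(grad(loss, argnums=5),
                  in_axes=(0, 0, 0, None, None, None))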

Full trace back.

----> 1 batch_grad(batch_y0, batch_t, batch_y,[1.3,1.8], [U1,U2], [U1_params,U2_params])

~/Code/jax/jax/api.py in batched_fun(*args)
    805     _check_axis_sizes(in_tree, args_flat, in_axes_flat)
    806     out_flat = batching.batch(flat_fun, args_flat, in_axes_flat,
--> 807                               lambda: _flatten_axes(out_tree(), out_axes))
    808     return tree_unflatten(out_tree(), out_flat)
    809 

~/Code/jax/jax/interpreters/batching.py in batch(fun, in_vals, in_dims, out_dim_dests)
     32   # executes a batched version of `fun` following out_dim_dests
     33   batched_fun = batch_fun(fun, in_dims, out_dim_dests)
---> 34   return batched_fun.call_wrapped(*in_vals)
     35 
     36 @lu.transformation_with_aux

~/Code/jax/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    148     gen = None
    149 
--> 150     ans = self.f(*args, **dict(self.params, **kwargs))
    151     del args
    152     while stack:

~/Code/jax/jax/api.py in value_and_grad_f(*args, **kwargs)
    436     f_partial, dyn_args = argnums_partial(f, argnums, args)
    437     if not has_aux:
--> 438       ans, vjp_py = _vjp(f_partial, *dyn_args)
    439     else:
    440       ans, vjp_py, aux = _vjp(f_partial, *dyn_args, has_aux=True)

~/Code/jax/jax/api.py in _vjp(fun, *primals, **kwargs)
   1437   if not has_aux:
   1438     flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
-> 1439     out_primal, out_vjp = ad.vjp(flat_fun, primals_flat)
   1440     out_tree = out_tree()
   1441   else:

~/Code/jax/jax/interpreters/ad.py in vjp(traceable, primals, has_aux)
    104 def vjp(traceable, primals, has_aux=False):
    105   if not has_aux:
--> 106     out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
    107   else:
    108     out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)

~/Code/jax/jax/interpreters/ad.py in linearize(traceable, *primals, **kwargs)
     93   _, in_tree = tree_flatten(((primals, primals), {}))
     94   jvpfun_flat, out_tree = flatten_fun(jvpfun, in_tree)
---> 95   jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
     96   out_primals_pvals, out_tangents_pvals = tree_unflatten(out_tree(), out_pvals)
     97   assert all(out_primal_pval.is_known() for out_primal_pval in out_primals_pvals)

~/Code/jax/jax/interpreters/partial_eval.py in trace_to_jaxpr(fun, pvals, instantiate, stage_out, bottom, trace_type)
    435   with new_master(trace_type, bottom=bottom) as master:
    436     fun = trace_to_subjaxpr(fun, master, instantiate)
--> 437     jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
    438     assert not env
    439     del master

~/Code/jax/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    148     gen = None
    149 
--> 150     ans = self.f(*args, **dict(self.params, **kwargs))
    151     del args
    152     while stack:

~/Code/jax/jax/api.py in f_jitted(*args, **kwargs)
    152     flat_fun, out_tree = flatten_fun(f, in_tree)
    153     out = xla.xla_call(flat_fun, *args_flat, device=device, backend=backend,
--> 154                        name=flat_fun.__name__)
    155     return tree_unflatten(out_tree(), out)
    156 

~/Code/jax/jax/core.py in _call_bind(processor, post_processor, primitive, f, *args, **params)
   1003     tracers = map(top_trace.full_raise, args)
   1004     process = getattr(top_trace, processor)
-> 1005     outs = map(full_lower, process(primitive, f, tracers, params))
   1006   return apply_todos(env_trace_todo(), outs)
   1007 

~/Code/jax/jax/interpreters/ad.py in process_call(self, call_primitive, f, tracers, params)
    342     name = params.get('name', f.__name__)
    343     params = dict(params, name=wrap_name(name, 'jvp'))
--> 344     result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents, **params)
    345     primal_out, tangent_out = tree_unflatten(out_tree_def(), result)
    346     return [JVPTracer(self, p, t) for p, t in zip(primal_out, tangent_out)]

~/Code/jax/jax/core.py in _call_bind(processor, post_processor, primitive, f, *args, **params)
   1003     tracers = map(top_trace.full_raise, args)
   1004     process = getattr(top_trace, processor)
-> 1005     outs = map(full_lower, process(primitive, f, tracers, params))
   1006   return apply_todos(env_trace_todo(), outs)
   1007 

~/Code/jax/jax/interpreters/partial_eval.py in process_call(self, call_primitive, f, tracers, params)
    175     in_pvs, in_consts = unzip2([t.pval for t in tracers])
    176     fun, aux = partial_eval(f, self, in_pvs)
--> 177     out_flat = call_primitive.bind(fun, *in_consts, **params)
    178     out_pvs, jaxpr, env = aux()
    179     env_tracers = map(self.full_raise, env)

~/Code/jax/jax/core.py in _call_bind(processor, post_processor, primitive, f, *args, **params)
   1003     tracers = map(top_trace.full_raise, args)
   1004     process = getattr(top_trace, processor)
-> 1005     outs = map(full_lower, process(primitive, f, tracers, params))
   1006   return apply_todos(env_trace_todo(), outs)
   1007 

~/Code/jax/jax/interpreters/batching.py in process_call(self, call_primitive, f, tracers, params)
    146     else:
    147       f, dims_out = batch_subtrace(f, self.master, dims)
--> 148       vals_out = call_primitive.bind(f, *vals, **params)
    149       return [BatchTracer(self, v, d) for v, d in zip(vals_out, dims_out())]
    150 

~/Code/jax/jax/core.py in _call_bind(processor, post_processor, primitive, f, *args, **params)
    999   if top_trace is None:
   1000     with new_sublevel():
-> 1001       outs = primitive.impl(f, *args, **params)
   1002   else:
   1003     tracers = map(top_trace.full_raise, args)

~/Code/jax/jax/interpreters/xla.py in _xla_call_impl(fun, device, backend, name, *args)
    460 
    461 def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name):
--> 462   compiled_fun = _xla_callable(fun, device, backend, name, *map(arg_spec, args))
    463   try:
    464     return compiled_fun(*args)

~/Code/jax/jax/linear_util.py in memoized_fun(fun, *args)
    219       fun.populate_stores(stores)
    220     else:
--> 221       ans = call(fun, *args)
    222       cache[key] = (ans, fun.stores)
    223     return ans

~/Code/jax/jax/interpreters/xla.py in _xla_callable(fun, device, backend, name, *arg_specs)
    477   pvals: Sequence[pe.PartialVal] = [pe.PartialVal.unknown(aval) for aval in abstract_args]
    478   jaxpr, pvals, consts = pe.trace_to_jaxpr(
--> 479       fun, pvals, instantiate=False, stage_out=True, bottom=True)
    480 
    481   _map(prefetch, it.chain(consts, jaxpr_literals(jaxpr)))

~/Code/jax/jax/interpreters/partial_eval.py in trace_to_jaxpr(fun, pvals, instantiate, stage_out, bottom, trace_type)
    435   with new_master(trace_type, bottom=bottom) as master:
    436     fun = trace_to_subjaxpr(fun, master, instantiate)
--> 437     jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
    438     assert not env
    439     del master

~/Code/jax/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    148     gen = None
    149 
--> 150     ans = self.f(*args, **dict(self.params, **kwargs))
    151     del args
    152     while stack:

<ipython-input-17-de50dc731d85> in loss(batch_y0, batch_t, batch_y, params, ufuncs, uparams)
      1 @partial(jit, static_argnums=(4,))
      2 def loss(batch_y0, batch_t, batch_y, params, ufuncs,uparams):
----> 3     pred_y = odeint(batch_y0,batch_t,params,ufuncs,uparams)
      4     loss = np.mean(np.abs(pred_y-batch_y))
      5     return loss

~/Code/jax/jax/experimental/ode.py in odeint(func, y0, t, rtol, atol, mxstep, *args)
    152     shape/structure as `y0` except with a new leading axis of length `len(t)`.
    153   """
--> 154   return _odeint_wrapper(func, rtol, atol, mxstep, y0, t, *args)
    155 
    156 @partial(jax.jit, static_argnums=(0, 1, 2, 3))

~/Code/jax/jax/api.py in f_jitted(*args, **kwargs)
    149       dyn_args = args
    150     args_flat, in_tree = tree_flatten((dyn_args, kwargs))
--> 151     _check_args(args_flat)
    152     flat_fun, out_tree = flatten_fun(f, in_tree)
    153     out = xla.xla_call(flat_fun, *args_flat, device=device, backend=backend,

~/Code/jax/jax/api.py in _check_args(args)
   1558     if not (isinstance(arg, core.Tracer) or _valid_jaxtype(arg)):
   1559       raise TypeError("Argument '{}' of type {} is not a valid JAX type"
-> 1560                       .format(arg, type(arg)))
   1561 
   1562 def _valid_jaxtype(arg):

TypeError: Argument '<function serial.<locals>.apply_fun at 0x2b06c3d6f7a0>' of type <class 'function'> is not a valid JAX type

I'm passing two stax.serial modules, each with three Dense layers, as inputs to odeint in order to integrate the Lotka-Volterra ODEs. ufuncs and uparams contain the apply functions and the params of the stax.serial modules.

import jax.numpy as np

def lv_UDE(y, t, params, ufuncs, uparams):
    R, F = y
    alpha, theta = params
    U1, U2 = ufuncs
    U1_params, U2_params = uparams
    # Lotka-Volterra dynamics with neural-network interaction terms
    dRdt = alpha * R - U1(U1_params, y)
    dFdt = -theta * F + U2(U2_params, y)
    return np.array([dRdt, dFdt])
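For completeness, U1 and U2 are the apply functions returned by stax.serial; they are built roughly like this (the layer widths and activations here are illustrative, not my exact architecture):

from jax import random
from jax.experimental import stax

# three Dense layers, as described above; widths are placeholders
init_fn, U1 = stax.serial(stax.Dense(32), stax.Tanh,
                          stax.Dense(32), stax.Tanh,
                          stax.Dense(1))
_, U1_params = init_fn(random.PRNGKey(0), (-1, 2))  # y = (R, F) has two components
# U2 and U2_params are created the same way with a different PRNG key.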

I'm trying to get gradients through odeint w.r.t. uparams. Is there a workaround that lets me pass the stax.serial apply functions as arguments? Thanks in advance.
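One workaround I've been considering is to close over the apply functions so that only arrays (and pytrees of arrays) cross the odeint boundary, roughly like this; I'm not sure whether this is the intended approach:

import jax.numpy as np
from jax.experimental.ode import odeint

def make_lv_UDE(U1, U2):
    # U1 and U2 are captured in the closure instead of being passed to odeint
    def lv_UDE(y, t, params, uparams):
        R, F = y
        alpha, theta = params
        U1_params, U2_params = uparams
        dRdt = alpha * R - U1(U1_params, y)
        dFdt = -theta * F + U2(U2_params, y)
        return np.array([dRdt, dFdt])
    return lv_UDE

# only arrays / pytrees of arrays are passed as odeint arguments
pred_y = odeint(make_lv_UDE(U1, U2), y0, t, params, uparams)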
