Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion jax/experimental/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

import jax
import jax.numpy as np
from jax import core
from jax import lax
from jax import ops
from jax.util import safe_map, safe_zip
Expand Down Expand Up @@ -141,7 +142,9 @@ def odeint(func, y0, t, *args, rtol=1.4e-8, atol=1.4e-8, mxstep=np.inf):
y0: array or pytree of arrays representing the initial value for the state.
t: array of float times for evaluation, like `np.linspace(0., 10., 101)`,
in which the values must be strictly increasing.
*args: tuple of additional arguments for `func`.
*args: tuple of additional arguments for `func`, which must be arrays,
scalars, or (nested) standard Python containers (tuples, lists, dicts,
namedtuples, i.e. pytrees) of those types.
rtol: float, relative local error tolerance for solver (optional).
atol: float, absolute local error tolerance for solver (optional).
mxstep: int, maximum number of steps to take for each timepoint (optional).
Expand All @@ -151,6 +154,12 @@ def odeint(func, y0, t, *args, rtol=1.4e-8, atol=1.4e-8, mxstep=np.inf):
point in `t`, represented as an array (or pytree of arrays) with the same
shape/structure as `y0` except with a new leading axis of length `len(t)`.
"""
def _check_arg(arg):
if not isinstance(arg, core.Tracer) and not core.valid_jaxtype(arg):
msg = ("The contents of odeint *args must be arrays or scalars, but got "
"\n{}.")
raise TypeError(msg.format(arg))
tree_map(_check_arg, args)
return _odeint_wrapper(func, rtol, atol, mxstep, y0, t, *args)

@partial(jax.jit, static_argnums=(0, 1, 2, 3))
Expand Down