From 91eaa24b559fa7aab724a59025d5b90f50e6bb62 Mon Sep 17 00:00:00 2001
From: Matthew Johnson
Date: Fri, 24 Apr 2020 01:41:41 -0700
Subject: [PATCH] add more ode tests

---
 tests/ode_test.py | 44 ++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 44 insertions(+)

diff --git a/tests/ode_test.py b/tests/ode_test.py
index 5a9a41535798..94b27a4eb646 100644
--- a/tests/ode_test.py
+++ b/tests/ode_test.py
@@ -65,6 +65,50 @@ def dynamics(y, t):
     jtu.check_grads(integrate, (y0, ts), modes=["rev"], order=2,
                     rtol=tol, atol=tol)
 
+  @jtu.skip_on_devices("tpu")
+  def test_decay(self):
+    """Check reverse-mode gradients of odeint with extra dynamics args."""
+    def decay(y, t, arg1, arg2):
+      return -np.sqrt(t) - y + arg1 - np.mean((y + arg2)**2)
+
+    rng = onp.random.RandomState(0)
+    arg1 = rng.randn(3)
+    arg2 = rng.randn(3)
+
+    def integrate(y0, ts):
+      return odeint(decay, y0, ts, arg1, arg2)
+
+    y0 = rng.randn(3)
+    ts = np.linspace(0.1, 0.2, 4)
+
+    # Use a looser tolerance when num_float_bits reports 32-bit floats.
+    tol = 1e-1 if num_float_bits(onp.float64) == 32 else 1e-3
+    jtu.check_grads(integrate, (y0, ts), modes=["rev"], order=2,
+                    rtol=tol, atol=tol)
+
+  @jtu.skip_on_devices("tpu")
+  def test_swoop(self):
+    """Check reverse-mode gradients of odeint at two initial states."""
+    def swoop(y, t, arg1, arg2):
+      return np.array(y - np.sin(t) - np.cos(t) * arg1 + arg2)
+
+    ts = np.array([0.1, 0.2])
+    # Use a looser tolerance when num_float_bits reports 32-bit floats.
+    tol = 1e-1 if num_float_bits(onp.float64) == 32 else 1e-3
+
+    y0 = np.linspace(0.1, 0.9, 10)
+    integrate = lambda y0, ts: odeint(swoop, y0, ts, 0.1, 0.2)
+    jtu.check_grads(integrate, (y0, ts), modes=["rev"], order=2,
+                    rtol=tol, atol=tol)
+
+    # The integrand must use its own `y0` argument, and `big_y0` must be the
+    # point handed to check_grads; closing over `big_y0` instead would make
+    # the gradient with respect to the first argument identically zero.
+    big_y0 = np.linspace(1.1, 10.9, 10)
+    integrate = lambda y0, ts: odeint(swoop, y0, ts, 0.1, 0.3)
+    jtu.check_grads(integrate, (big_y0, ts), modes=["rev"], order=2,
+                    rtol=tol, atol=tol)
+
 
 if __name__ == '__main__':
   absltest.main()