|
36 | 36 | from jax._src import source_info_util
|
37 | 37 | from jax._src import test_util as jtu
|
38 | 38 | from jax._src import xla_bridge as xb
|
39 |
| -from jax.experimental import jax2tf |
40 |
| -from jax.experimental.jax2tf.tests import tf_test_util |
41 | 39 | from jax._src.shard_map import shard_map
|
42 | 40 | from jax.experimental import pjit
|
43 | 41 | from jax.sharding import PartitionSpec as P
|
44 | 42 |
|
45 | 43 | import numpy as np
|
46 |
| -import tensorflow as tf |
| 44 | +try: |
| 45 | + import tensorflow as tf |
| 46 | + from jax.experimental import jax2tf |
| 47 | + from jax.experimental.jax2tf.tests import tf_test_util |
| 48 | + JaxToTfTestCase = tf_test_util.JaxToTfTestCase |
| 49 | +except ImportError: |
| 50 | + tf = None |
| 51 | + jax2tf = None # type: ignore[assignment] |
| 52 | + tf_test_util = None # type: ignore[assignment] |
| 53 | + JaxToTfTestCase = jtu.JaxTestCase # type: ignore[misc] |
47 | 54 |
|
48 | 55 | config.parse_flags_with_absl()
|
49 | 56 |
|
50 | 57 |
|
| 58 | +@unittest.skipIf(tf is None, "Test requires tensorflow") |
51 | 59 | @jtu.thread_unsafe_test_class()
|
52 |
| -class Jax2TfTest(tf_test_util.JaxToTfTestCase): |
| 60 | +class Jax2TfTest(JaxToTfTestCase): |
53 | 61 |
|
54 | 62 | def setUp(self):
|
55 | 63 | super().setUp()
|
@@ -1209,7 +1217,7 @@ def f_simple(x):
|
1209 | 1217 | include_xla_op_metadata=False
|
1210 | 1218 | )
|
1211 | 1219 |
|
1212 |
| - def assertAllOperationStartWith(self, g: tf.Graph, scope_name: str): |
| 1220 | + def assertAllOperationStartWith(self, g: "tf.Graph", scope_name: str): |
1213 | 1221 | """Assert all operations name start with ```scope_name```.
|
1214 | 1222 |
|
1215 | 1223 | Also the scope_name only occur one time.
|
@@ -1631,8 +1639,9 @@ def loss(features, params):
|
1631 | 1639 | )
|
1632 | 1640 |
|
1633 | 1641 |
|
| 1642 | +@unittest.skipIf(tf is None, "Test requires tensorflow") |
1634 | 1643 | @jtu.with_config(jax_enable_custom_prng=True)
|
1635 |
| -class Jax2tfWithCustomPRNGTest(tf_test_util.JaxToTfTestCase): |
| 1644 | +class Jax2tfWithCustomPRNGTest(JaxToTfTestCase): |
1636 | 1645 |
|
1637 | 1646 | def test_key_argument(self):
|
1638 | 1647 | func = lambda key: jax.random.uniform(key, ())
|
@@ -1661,7 +1670,8 @@ def func():
|
1661 | 1670 | self.assertEqual(tf_result, jax_result)
|
1662 | 1671 |
|
1663 | 1672 |
|
1664 |
| -class Jax2TfVersioningTest(tf_test_util.JaxToTfTestCase): |
| 1673 | +@unittest.skipIf(tf is None, "Test requires tensorflow") |
| 1674 | +class Jax2TfVersioningTest(JaxToTfTestCase): |
1665 | 1675 | # Use a separate test case with the default jax_serialization_version
|
1666 | 1676 | def setUp(self):
|
1667 | 1677 | self.use_max_serialization_version = False
|
|
0 commit comments