Skip to content

Commit d03fffb

Browse files
nitins17 authored and Google-ML-Automation committed
Exclude tensorflow py_dep on free-threading and 3.14 builds
TF doesn't produce wheels for these versions.

PiperOrigin-RevId: 815756217
1 parent 6fd0b21 commit d03fffb

File tree

3 files changed

+25
-19
lines changed

3 files changed

+25
-19
lines changed

jax/experimental/jax2tf/tests/BUILD

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -78,8 +78,7 @@ jax_multiplatform_test(
7878
deps = [
7979
":tf_test_util",
8080
"//jax/experimental/jax2tf",
81-
] + py_deps([
82-
"tensorflow",
81+
] + py_deps("tensorflow") + py_deps([
8382
"absl/testing",
8483
"absl/logging",
8584
]),

jax/experimental/jax2tf/tests/jax2tf_test.py

Lines changed: 17 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -36,20 +36,28 @@
3636
from jax._src import source_info_util
3737
from jax._src import test_util as jtu
3838
from jax._src import xla_bridge as xb
39-
from jax.experimental import jax2tf
40-
from jax.experimental.jax2tf.tests import tf_test_util
4139
from jax._src.shard_map import shard_map
4240
from jax.experimental import pjit
4341
from jax.sharding import PartitionSpec as P
4442

4543
import numpy as np
46-
import tensorflow as tf
44+
try:
45+
import tensorflow as tf
46+
from jax.experimental import jax2tf
47+
from jax.experimental.jax2tf.tests import tf_test_util
48+
JaxToTfTestCase = tf_test_util.JaxToTfTestCase
49+
except ImportError:
50+
tf = None
51+
jax2tf = None # type: ignore[assignment]
52+
tf_test_util = None # type: ignore[assignment]
53+
JaxToTfTestCase = jtu.JaxTestCase # type: ignore[misc]
4754

4855
config.parse_flags_with_absl()
4956

5057

58+
@unittest.skipIf(tf is None, "Test requires tensorflow")
5159
@jtu.thread_unsafe_test_class()
52-
class Jax2TfTest(tf_test_util.JaxToTfTestCase):
60+
class Jax2TfTest(JaxToTfTestCase):
5361

5462
def setUp(self):
5563
super().setUp()
@@ -1209,7 +1217,7 @@ def f_simple(x):
12091217
include_xla_op_metadata=False
12101218
)
12111219

1212-
def assertAllOperationStartWith(self, g: tf.Graph, scope_name: str):
1220+
def assertAllOperationStartWith(self, g: "tf.Graph", scope_name: str):
12131221
"""Assert all operations name start with ```scope_name```.
12141222
12151223
Also the scope_name only occur one time.
@@ -1631,8 +1639,9 @@ def loss(features, params):
16311639
)
16321640

16331641

1642+
@unittest.skipIf(tf is None, "Test requires tensorflow")
16341643
@jtu.with_config(jax_enable_custom_prng=True)
1635-
class Jax2tfWithCustomPRNGTest(tf_test_util.JaxToTfTestCase):
1644+
class Jax2tfWithCustomPRNGTest(JaxToTfTestCase):
16361645

16371646
def test_key_argument(self):
16381647
func = lambda key: jax.random.uniform(key, ())
@@ -1661,7 +1670,8 @@ def func():
16611670
self.assertEqual(tf_result, jax_result)
16621671

16631672

1664-
class Jax2TfVersioningTest(tf_test_util.JaxToTfTestCase):
1673+
@unittest.skipIf(tf is None, "Test requires tensorflow")
1674+
class Jax2TfVersioningTest(JaxToTfTestCase):
16651675
# Use a separate test case with the default jax_serialization_version
16661676
def setUp(self):
16671677
self.use_max_serialization_version = False

jaxlib/jax.bzl

Lines changed: 7 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,7 @@ load("@jax_wheel_version_suffix//:wheel_version_suffix.bzl", "WHEEL_VERSION_SUFF
2121
load("@local_config_cuda//cuda:build_defs.bzl", _cuda_library = "cuda_library", _if_cuda_is_configured = "if_cuda_is_configured")
2222
load("@local_config_rocm//rocm:build_defs.bzl", _if_rocm_is_configured = "if_rocm_is_configured", _rocm_library = "rocm_library")
2323
load("@nvidia_wheel_versions//:versions.bzl", "NVIDIA_WHEEL_VERSIONS")
24-
load("@python_version_repo//:py_version.bzl", "HERMETIC_PYTHON_VERSION")
24+
load("@python_version_repo//:py_version.bzl", "HERMETIC_PYTHON_VERSION", "HERMETIC_PYTHON_VERSION_KIND")
2525
load("@rules_cc//cc:defs.bzl", _cc_proto_library = "cc_proto_library")
2626
load("@rules_python//python:defs.bzl", "py_library", "py_test")
2727
load("@test_shard_count//:test_shard_count.bzl", "USE_MINIMAL_SHARD_COUNT")
@@ -69,14 +69,11 @@ PLATFORM_TAGS_DICT = {
6969
("Windows", "AMD64"): ("win", "amd64"),
7070
}
7171

72-
# TODO(vam): remove this once zstandard builds against Python >3.13
73-
def get_zstandard():
74-
if HERMETIC_PYTHON_VERSION in ("3.13", "3.13-ft", "3.14", "3.14-ft"):
75-
return []
76-
return ["@pypi//zstandard"]
77-
7872
def get_optional_dep(package, excluded_py_versions = ["3.14", "3.14-ft"]):
79-
if HERMETIC_PYTHON_VERSION in excluded_py_versions:
73+
py_ver = HERMETIC_PYTHON_VERSION
74+
if HERMETIC_PYTHON_VERSION_KIND == "ft":
75+
py_ver += "-ft"
76+
if py_ver in excluded_py_versions:
8077
return []
8178
return [package]
8279

@@ -103,9 +100,9 @@ _py_deps = {
103100
"tensorflow_core": [],
104101
"tensorstore": get_optional_dep("@pypi//tensorstore"),
105102
"torch": [],
106-
"tensorflow": ["@pypi//tensorflow"],
103+
"tensorflow": get_optional_dep("@pypi//tensorflow", ["3.13-ft", "3.14", "3.14-ft"]),
107104
"tpu_ops": [],
108-
"zstandard": get_zstandard(),
105+
"zstandard": get_optional_dep("@pypi//zstandard", ["3.13", "3.13-ft", "3.14", "3.14-ft"]),
109106
}
110107

111108
def all_py_deps(excluded = []):

0 commit comments

Comments
 (0)