From a272fb2f551e64a26f9ce2d8afffda3c3e6f5775 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Mon, 6 Oct 2025 17:05:13 -0700 Subject: [PATCH] Exclude tensorflow py_dep on free-threading and 3.14 builds TF doesn't produce wheels for these versions PiperOrigin-RevId: 815942724 --- jax/experimental/jax2tf/tests/BUILD | 3 +-- jax/experimental/jax2tf/tests/jax2tf_test.py | 24 ++++++++++++++------ jaxlib/jax.bzl | 18 +++++++-------- 3 files changed, 26 insertions(+), 19 deletions(-)

NOTE(review): this copy of the patch has had its line breaks collapsed, and the
author's e-mail address appears to have been stripped from the "From:" header.
In this form `git am` / `git apply` cannot consume the hunks. Regenerate the
file with `git format-patch` rather than hand-repairing the hunks, whose
"@@" line counts depend on exact blank/context lines.
NOTE(review): in jax2tf_test.py the `except ImportError:` guard wraps the
`jax2tf` and `tf_test_util` imports as well as `import tensorflow`. An
ImportError raised while importing jax2tf itself would therefore also set
`tf = None`, and Jax2TfTest, Jax2tfWithCustomPRNGTest and Jax2TfVersioningTest
would be skipped instead of failing. Consider guarding only the tensorflow
import and importing jax2tf/tf_test_util unguarded when tf is available.
NOTE(review): the tests use `unittest.skipIf`, but these hunks do not add
`import unittest`. Presumably it is already imported above line 36; confirm.
NOTE(review): get_optional_dep() in jaxlib/jax.bzl builds the free-threaded key
by appending "-ft" to HERMETIC_PYTHON_VERSION when
HERMETIC_PYTHON_VERSION_KIND == "ft". The removed get_zstandard() compared
HERMETIC_PYTHON_VERSION directly against "3.13-ft"/"3.14-ft". Confirm that
HERMETIC_PYTHON_VERSION never carries the "-ft" suffix itself; otherwise the key
becomes e.g. "3.13-ft-ft" and the exclusion silently stops matching.
NOTE(review): the BUILD hunk calls py_deps("tensorflow") with a bare string,
while the adjacent call passes a list. This assumes py_deps accepts both forms;
confirm against its definition (not shown here).

diff --git a/jax/experimental/jax2tf/tests/BUILD b/jax/experimental/jax2tf/tests/BUILD index fc07cc0b3e97..a7d599321318 100644 --- a/jax/experimental/jax2tf/tests/BUILD +++ b/jax/experimental/jax2tf/tests/BUILD @@ -78,8 +78,7 @@ jax_multiplatform_test( deps = [ ":tf_test_util", "//jax/experimental/jax2tf", - ] + py_deps([ - "tensorflow", + ] + py_deps("tensorflow") + py_deps([ "absl/testing", "absl/logging", ]), diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py index dea74838f175..095af221a8f3 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_test.py +++ b/jax/experimental/jax2tf/tests/jax2tf_test.py @@ -36,20 +36,28 @@ from jax._src import source_info_util from jax._src import test_util as jtu from jax._src import xla_bridge as xb -from jax.experimental import jax2tf -from jax.experimental.jax2tf.tests import tf_test_util from jax._src.shard_map import shard_map from jax.experimental import pjit from jax.sharding import PartitionSpec as P import numpy as np -import tensorflow as tf +try: + import tensorflow as tf + from jax.experimental import jax2tf + from jax.experimental.jax2tf.tests import tf_test_util + JaxToTfTestCase = tf_test_util.JaxToTfTestCase +except ImportError: + tf = None + jax2tf = None # type: ignore[assignment] + tf_test_util = None # type: ignore[assignment] + JaxToTfTestCase = jtu.JaxTestCase # type: ignore[misc] config.parse_flags_with_absl() +@unittest.skipIf(tf is None, "Test requires tensorflow") 
@jtu.thread_unsafe_test_class() -class Jax2TfTest(tf_test_util.JaxToTfTestCase): +class Jax2TfTest(JaxToTfTestCase): def setUp(self): super().setUp() @@ -1209,7 +1217,7 @@ def f_simple(x): include_xla_op_metadata=False ) - def assertAllOperationStartWith(self, g: tf.Graph, scope_name: str): + def assertAllOperationStartWith(self, g: "tf.Graph", scope_name: str): """Assert all operations name start with ```scope_name```. Also the scope_name only occur one time. @@ -1631,8 +1639,9 @@ def loss(features, params): ) +@unittest.skipIf(tf is None, "Test requires tensorflow") @jtu.with_config(jax_enable_custom_prng=True) -class Jax2tfWithCustomPRNGTest(tf_test_util.JaxToTfTestCase): +class Jax2tfWithCustomPRNGTest(JaxToTfTestCase): def test_key_argument(self): func = lambda key: jax.random.uniform(key, ()) @@ -1661,7 +1670,8 @@ def func(): self.assertEqual(tf_result, jax_result) -class Jax2TfVersioningTest(tf_test_util.JaxToTfTestCase): +@unittest.skipIf(tf is None, "Test requires tensorflow") +class Jax2TfVersioningTest(JaxToTfTestCase): # Use a separate test case with the default jax_serialization_version def setUp(self): self.use_max_serialization_version = False diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 0a3ce0a1e3dc..6d191bf32d9e 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -21,7 +21,7 @@ load("@jax_wheel_version_suffix//:wheel_version_suffix.bzl", "WHEEL_VERSION_SUFF load("@local_config_cuda//cuda:build_defs.bzl", _cuda_library = "cuda_library", _if_cuda_is_configured = "if_cuda_is_configured") load("@local_config_rocm//rocm:build_defs.bzl", _if_rocm_is_configured = "if_rocm_is_configured", _rocm_library = "rocm_library") load("@nvidia_wheel_versions//:versions.bzl", "NVIDIA_WHEEL_VERSIONS") -load("@python_version_repo//:py_version.bzl", "HERMETIC_PYTHON_VERSION") +load("@python_version_repo//:py_version.bzl", "HERMETIC_PYTHON_VERSION", "HERMETIC_PYTHON_VERSION_KIND") load("@rules_cc//cc:defs.bzl", _cc_proto_library = "cc_proto_library") 
load("@rules_python//python:defs.bzl", "py_library", "py_test") load("@test_shard_count//:test_shard_count.bzl", "USE_MINIMAL_SHARD_COUNT") @@ -69,14 +69,11 @@ PLATFORM_TAGS_DICT = { ("Windows", "AMD64"): ("win", "amd64"), } -# TODO(vam): remove this once zstandard builds against Python >3.13 -def get_zstandard(): - if HERMETIC_PYTHON_VERSION in ("3.13", "3.13-ft", "3.14", "3.14-ft"): - return [] - return ["@pypi//zstandard"] - def get_optional_dep(package, excluded_py_versions = ["3.14", "3.14-ft"]): - if HERMETIC_PYTHON_VERSION in excluded_py_versions: + py_ver = HERMETIC_PYTHON_VERSION + if HERMETIC_PYTHON_VERSION_KIND == "ft": + py_ver += "-ft" + if py_ver in excluded_py_versions: return [] return [package] @@ -103,9 +100,10 @@ _py_deps = { "tensorflow_core": [], "tensorstore": get_optional_dep("@pypi//tensorstore"), "torch": [], - "tensorflow": ["@pypi//tensorflow"], + "tensorflow": get_optional_dep("@pypi//tensorflow", ["3.13-ft", "3.14", "3.14-ft"]), "tpu_ops": [], - "zstandard": get_zstandard(), + # TODO(vam): remove this once zstandard builds against Python >3.13 + "zstandard": get_optional_dep("@pypi//zstandard", ["3.13", "3.13-ft", "3.14", "3.14-ft"]), } def all_py_deps(excluded = []):