Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Exclude tensorflow py_dep on free-threading and 3.14 builds
TensorFlow does not publish wheels for these Python versions (3.13 free-threaded, 3.14, and 3.14 free-threaded).

PiperOrigin-RevId: 815942724
  • Loading branch information
nitins17 authored and Google-ML-Automation committed Oct 7, 2025
commit a272fb2f551e64a26f9ce2d8afffda3c3e6f5775
3 changes: 1 addition & 2 deletions jax/experimental/jax2tf/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,7 @@ jax_multiplatform_test(
deps = [
":tf_test_util",
"//jax/experimental/jax2tf",
] + py_deps([
"tensorflow",
] + py_deps("tensorflow") + py_deps([
"absl/testing",
"absl/logging",
]),
Expand Down
24 changes: 17 additions & 7 deletions jax/experimental/jax2tf/tests/jax2tf_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,20 +36,28 @@
from jax._src import source_info_util
from jax._src import test_util as jtu
from jax._src import xla_bridge as xb
from jax.experimental import jax2tf
from jax.experimental.jax2tf.tests import tf_test_util
from jax._src.shard_map import shard_map
from jax.experimental import pjit
from jax.sharding import PartitionSpec as P

import numpy as np
import tensorflow as tf
try:
import tensorflow as tf
from jax.experimental import jax2tf
from jax.experimental.jax2tf.tests import tf_test_util
JaxToTfTestCase = tf_test_util.JaxToTfTestCase
except ImportError:
tf = None
jax2tf = None # type: ignore[assignment]
tf_test_util = None # type: ignore[assignment]
JaxToTfTestCase = jtu.JaxTestCase # type: ignore[misc]

config.parse_flags_with_absl()


@unittest.skipIf(tf is None, "Test requires tensorflow")
@jtu.thread_unsafe_test_class()
class Jax2TfTest(tf_test_util.JaxToTfTestCase):
class Jax2TfTest(JaxToTfTestCase):

def setUp(self):
super().setUp()
Expand Down Expand Up @@ -1209,7 +1217,7 @@ def f_simple(x):
include_xla_op_metadata=False
)

def assertAllOperationStartWith(self, g: tf.Graph, scope_name: str):
def assertAllOperationStartWith(self, g: "tf.Graph", scope_name: str):
"""Assert all operations name start with ```scope_name```.

Also, the scope_name occurs only once.
Expand Down Expand Up @@ -1631,8 +1639,9 @@ def loss(features, params):
)


@unittest.skipIf(tf is None, "Test requires tensorflow")
@jtu.with_config(jax_enable_custom_prng=True)
class Jax2tfWithCustomPRNGTest(tf_test_util.JaxToTfTestCase):
class Jax2tfWithCustomPRNGTest(JaxToTfTestCase):

def test_key_argument(self):
func = lambda key: jax.random.uniform(key, ())
Expand Down Expand Up @@ -1661,7 +1670,8 @@ def func():
self.assertEqual(tf_result, jax_result)


class Jax2TfVersioningTest(tf_test_util.JaxToTfTestCase):
@unittest.skipIf(tf is None, "Test requires tensorflow")
class Jax2TfVersioningTest(JaxToTfTestCase):
# Use a separate test case with the default jax_serialization_version
def setUp(self):
self.use_max_serialization_version = False
Expand Down
18 changes: 8 additions & 10 deletions jaxlib/jax.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ load("@jax_wheel_version_suffix//:wheel_version_suffix.bzl", "WHEEL_VERSION_SUFF
load("@local_config_cuda//cuda:build_defs.bzl", _cuda_library = "cuda_library", _if_cuda_is_configured = "if_cuda_is_configured")
load("@local_config_rocm//rocm:build_defs.bzl", _if_rocm_is_configured = "if_rocm_is_configured", _rocm_library = "rocm_library")
load("@nvidia_wheel_versions//:versions.bzl", "NVIDIA_WHEEL_VERSIONS")
load("@python_version_repo//:py_version.bzl", "HERMETIC_PYTHON_VERSION")
load("@python_version_repo//:py_version.bzl", "HERMETIC_PYTHON_VERSION", "HERMETIC_PYTHON_VERSION_KIND")
load("@rules_cc//cc:defs.bzl", _cc_proto_library = "cc_proto_library")
load("@rules_python//python:defs.bzl", "py_library", "py_test")
load("@test_shard_count//:test_shard_count.bzl", "USE_MINIMAL_SHARD_COUNT")
Expand Down Expand Up @@ -69,14 +69,11 @@ PLATFORM_TAGS_DICT = {
("Windows", "AMD64"): ("win", "amd64"),
}

# TODO(vam): remove this once zstandard builds against Python >3.13
def get_zstandard():
if HERMETIC_PYTHON_VERSION in ("3.13", "3.13-ft", "3.14", "3.14-ft"):
return []
return ["@pypi//zstandard"]

def get_optional_dep(package, excluded_py_versions = ["3.14", "3.14-ft"]):
if HERMETIC_PYTHON_VERSION in excluded_py_versions:
py_ver = HERMETIC_PYTHON_VERSION
if HERMETIC_PYTHON_VERSION_KIND == "ft":
py_ver += "-ft"
if py_ver in excluded_py_versions:
return []
return [package]

Expand All @@ -103,9 +100,10 @@ _py_deps = {
"tensorflow_core": [],
"tensorstore": get_optional_dep("@pypi//tensorstore"),
"torch": [],
"tensorflow": ["@pypi//tensorflow"],
"tensorflow": get_optional_dep("@pypi//tensorflow", ["3.13-ft", "3.14", "3.14-ft"]),
"tpu_ops": [],
"zstandard": get_zstandard(),
# TODO(vam): remove this once zstandard builds against Python >3.13
"zstandard": get_optional_dep("@pypi//zstandard", ["3.13", "3.13-ft", "3.14", "3.14-ft"]),
}

def all_py_deps(excluded = []):
Expand Down
Loading