Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
given its name.

* Changes
* Version checking of CUDA package dependencies has been re-enabled; it was
accidentally disabled in a previous release. Set the
`JAX_SKIP_CUDA_CONSTRAINTS_CHECK` environment variable to skip the check.
* JAX nightly packages are now published to artifact registry. To install
these packages, see the [JAX installation guide](https://docs.jax.dev/en/latest/installation.html#jax-nightly-installation).
* `jax.sharding.PartitionSpec` no longer inherits from a tuple.
Expand Down
136 changes: 0 additions & 136 deletions jax/_src/xla_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
import pkgutil
import platform as py_platform
import threading
import traceback
from typing import Any, Sequence, Union
import warnings

Expand Down Expand Up @@ -311,141 +310,6 @@ def _check_cuda_compute_capability(devices_to_check):
)


def _check_cuda_versions(raise_on_first_error: bool = False,
                         debug: bool = False):
  """Checks the installed CUDA component versions against JAX's requirements.

  Each component's runtime version is compared with a minimum that is either
  given explicitly or derived from the version JAX was built against.

  Args:
    raise_on_first_error: If True, raise as soon as one component is missing
      or too old. Otherwise all problems are collected and reported together
      in a single error after every component has been checked.
    debug: If True, also record the components that passed and print their
      status before any error is raised.

  Raises:
    RuntimeError: If any component cannot be loaded or is older than its
      minimum supported version.
  """
  assert cuda_versions is not None
  results: list[dict[str, Any]] = []

  def _make_msg(name: str,
                runtime_version: int,
                build_version: int,
                min_supported: int,
                debug_msg: bool = False):
    """Formats the status (debug_msg=True) or failure report for a component."""
    if debug_msg:
      return (f"Package: {name}\n"
              f"Version JAX was built against: {build_version}\n"
              f"Minimum supported: {min_supported}\n"
              f"Installed version: {runtime_version}")
    if min_supported:
      req_str = (f"The local installation version must be no lower than "
                 f"{min_supported}.")
    else:
      req_str = ("The local installation must be the same version as "
                 "the version against which JAX was built.")
    msg = (f"Outdated {name} installation found.\n"
           f"Version JAX was built against: {build_version}\n"
           f"Minimum supported: {min_supported}\n"
           f"Installed version: {runtime_version}\n"
           f"{req_str}")
    return msg

  def _version_check(name: str,
                     get_version,
                     get_build_version,
                     scale_for_comparison: int = 1,
                     min_supported_version: int = 0):
    """Checks the runtime CUDA component version against the JAX one.

    Args:
      name: Of the CUDA component.
      get_version: A function to get the local runtime version of the component.
      get_build_version: A function to get the build version of the component.
      scale_for_comparison: For rounding down a version to ignore patch/minor.
        Only used when min_supported_version is not given.
      min_supported_version: An absolute minimum version required. Must be
        passed without rounding down.

    Raises:
      RuntimeError: If the component is not found, or is of unsupported version,
        and if raising the error is not deferred till later.
    """

    build_version = get_build_version()
    try:
      version = get_version()
    except Exception as e:
      err_msg = f"Unable to load {name}. Is it installed?"
      if raise_on_first_error:
        raise RuntimeError(err_msg) from e
      # Deferred reporting: keep the traceback so the combined error raised
      # at the end still shows why the component could not be loaded.
      err_msg += f"\n{traceback.format_exc()}"
      results.append({"name": name, "installed": False, "msg": err_msg})
      return

    if not min_supported_version:
      # Round the build version down to a multiple of the scale so that the
      # minimum stays in the same units as the runtime version. A bare
      # `build_version // scale` is orders of magnitude smaller than any real
      # runtime version, which would make this check always pass.
      min_supported_version = (
          (build_version // scale_for_comparison) * scale_for_comparison)
    passed = min_supported_version <= version

    # Successful checks are only recorded when debug output was requested.
    if not passed or debug:
      msg = _make_msg(name=name,
                      runtime_version=version,
                      build_version=build_version,
                      min_supported=min_supported_version,
                      debug_msg=passed)
      if not passed and raise_on_first_error:
        raise RuntimeError(msg)
      else:
        record = {"name": name,
                  "installed": True,
                  "msg": msg,
                  "passed": passed,
                  "build_version": build_version,
                  "version": version,
                  "minimum_supported": min_supported_version}
        results.append(record)

  _version_check("CUDA", cuda_versions.cuda_runtime_get_version,
                 cuda_versions.cuda_runtime_build_version,
                 scale_for_comparison=10,
                 min_supported_version=12010)
  _version_check(
      "cuDNN",
      cuda_versions.cudnn_get_version,
      cuda_versions.cudnn_build_version,
      # NVIDIA promise both backwards and forwards compatibility for cuDNN patch
      # versions:
      # https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#api-compat
      scale_for_comparison=100,
      min_supported_version=9100
  )
  _version_check("cuFFT", cuda_versions.cufft_get_version,
                 cuda_versions.cufft_build_version,
                 # Ignore patch versions.
                 scale_for_comparison=100)
  _version_check("cuSOLVER", cuda_versions.cusolver_get_version,
                 cuda_versions.cusolver_build_version,
                 # Ignore patch versions.
                 scale_for_comparison=100,
                 min_supported_version=11400)
  _version_check("cuPTI", cuda_versions.cupti_get_version,
                 cuda_versions.cupti_build_version,
                 min_supported_version=18)
  _version_check("cuBLAS", cuda_versions.cublas_get_version,
                 cuda_versions.cublas_build_version,
                 # Ignore patch versions.
                 scale_for_comparison=100,
                 min_supported_version=120100)
  _version_check("cuSPARSE", cuda_versions.cusparse_get_version,
                 cuda_versions.cusparse_build_version,
                 # Ignore patch versions.
                 scale_for_comparison=100,
                 min_supported_version=12100)

  # Split the collected records into failures and (debug-only) successes.
  # Records for components that failed to load carry no 'passed' key, so
  # 'installed' has to be tested first.
  errors = []
  debug_results = []
  for result in results:
    message: str = result['msg']
    if not result['installed'] or not result['passed']:
      errors.append(message)
    else:
      debug_results.append(message)

  join_str = f'\n{"-" * 50}\n'
  if debug_results:
    print(f'CUDA components status (debug):\n'
          f'{join_str.join(debug_results)}')
  if errors:
    raise RuntimeError(f'Unable to use CUDA because of the '
                       f'following issues with CUDA components:\n'
                       f'{join_str.join(errors)}')

def get_num_nodes_from_gpu_topology(topology: str) -> int:
try:
Expand Down
148 changes: 148 additions & 0 deletions jax_plugins/cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import logging
import os
import pathlib
import traceback
from typing import Any

from jax._src.lib import triton
from jax._src.lib import xla_client
Expand All @@ -29,8 +31,12 @@
cuda_plugin_extension = importlib.import_module(
f'{pkg_name}.cuda_plugin_extension'
)
cuda_versions = importlib.import_module(
f'{pkg_name}._versions'
)
except ImportError:
cuda_plugin_extension = None
cuda_versions = None
else:
break

Expand Down Expand Up @@ -76,11 +82,153 @@ def _get_library_path():
return None


def _check_cuda_versions(raise_on_first_error: bool = False,
                         debug: bool = False):
  """Checks the installed CUDA component versions against JAX's requirements.

  Each component's runtime version is compared with a minimum that is either
  given explicitly or derived from the version JAX was built against.

  Args:
    raise_on_first_error: If True, raise as soon as one component is missing
      or too old. Otherwise all problems are collected and reported together
      in a single error after every component has been checked.
    debug: If True, also record the components that passed and print their
      status before any error is raised.

  Raises:
    RuntimeError: If any component cannot be loaded or is older than its
      minimum supported version.
  """
  assert cuda_versions is not None
  results: list[dict[str, Any]] = []

  def _make_msg(name: str,
                runtime_version: int,
                build_version: int,
                min_supported: int,
                debug_msg: bool = False):
    """Formats the status (debug_msg=True) or failure report for a component."""
    if debug_msg:
      return (f"Package: {name}\n"
              f"Version JAX was built against: {build_version}\n"
              f"Minimum supported: {min_supported}\n"
              f"Installed version: {runtime_version}")
    if min_supported:
      req_str = (f"The local installation version must be no lower than "
                 f"{min_supported}.")
    else:
      req_str = ("The local installation must be the same version as "
                 "the version against which JAX was built.")
    msg = (f"Outdated {name} installation found.\n"
           f"Version JAX was built against: {build_version}\n"
           f"Minimum supported: {min_supported}\n"
           f"Installed version: {runtime_version}\n"
           f"{req_str}")
    return msg

  def _version_check(name: str,
                     get_version,
                     get_build_version,
                     scale_for_comparison: int = 1,
                     min_supported_version: int = 0):
    """Checks the runtime CUDA component version against the JAX one.

    Args:
      name: Of the CUDA component.
      get_version: A function to get the local runtime version of the component.
      get_build_version: A function to get the build version of the component.
      scale_for_comparison: For rounding down a version to ignore patch/minor.
        Only used when min_supported_version is not given.
      min_supported_version: An absolute minimum version required. Must be
        passed without rounding down.

    Raises:
      RuntimeError: If the component is not found, or is of unsupported version,
        and if raising the error is not deferred till later.
    """

    build_version = get_build_version()
    try:
      version = get_version()
    except Exception as e:
      err_msg = f"Unable to load {name}. Is it installed?"
      if raise_on_first_error:
        raise RuntimeError(err_msg) from e
      # Deferred reporting: keep the traceback so the combined error raised
      # at the end still shows why the component could not be loaded.
      err_msg += f"\n{traceback.format_exc()}"
      results.append({"name": name, "installed": False, "msg": err_msg})
      return

    if not min_supported_version:
      # Round the build version down to a multiple of the scale so that the
      # minimum stays in the same units as the runtime version. A bare
      # `build_version // scale` is orders of magnitude smaller than any real
      # runtime version, which would make this check always pass.
      min_supported_version = (
          (build_version // scale_for_comparison) * scale_for_comparison)
    passed = min_supported_version <= version

    # Successful checks are only recorded when debug output was requested.
    if not passed or debug:
      msg = _make_msg(name=name,
                      runtime_version=version,
                      build_version=build_version,
                      min_supported=min_supported_version,
                      debug_msg=passed)
      if not passed and raise_on_first_error:
        raise RuntimeError(msg)
      else:
        record = {"name": name,
                  "installed": True,
                  "msg": msg,
                  "passed": passed,
                  "build_version": build_version,
                  "version": version,
                  "minimum_supported": min_supported_version}
        results.append(record)

  _version_check("CUDA", cuda_versions.cuda_runtime_get_version,
                 cuda_versions.cuda_runtime_build_version,
                 scale_for_comparison=10,
                 min_supported_version=12010)
  _version_check(
      "cuDNN",
      cuda_versions.cudnn_get_version,
      cuda_versions.cudnn_build_version,
      # NVIDIA promise both backwards and forwards compatibility for cuDNN patch
      # versions:
      # https://docs.nvidia.com/deeplearning/cudnn/backend/latest/developer/forward-compatibility.html#cudnn-api-compatibility
      scale_for_comparison=100,
  )
  _version_check("cuFFT", cuda_versions.cufft_get_version,
                 cuda_versions.cufft_build_version,
                 # Ignore patch versions.
                 scale_for_comparison=100)
  _version_check("cuSOLVER", cuda_versions.cusolver_get_version,
                 cuda_versions.cusolver_build_version,
                 # Ignore patch versions.
                 scale_for_comparison=100,
                 min_supported_version=11400)
  _version_check("cuPTI", cuda_versions.cupti_get_version,
                 cuda_versions.cupti_build_version,
                 min_supported_version=18)
  _version_check("cuBLAS", cuda_versions.cublas_get_version,
                 cuda_versions.cublas_build_version,
                 # Ignore patch versions.
                 scale_for_comparison=100,
                 min_supported_version=120100)
  _version_check("cuSPARSE", cuda_versions.cusparse_get_version,
                 cuda_versions.cusparse_build_version,
                 # Ignore patch versions.
                 scale_for_comparison=100,
                 min_supported_version=12100)

  # Split the collected records into failures and (debug-only) successes.
  # Records for components that failed to load carry no 'passed' key, so
  # 'installed' has to be tested first.
  errors = []
  debug_results = []
  for result in results:
    message: str = result['msg']
    if not result['installed'] or not result['passed']:
      errors.append(message)
    else:
      debug_results.append(message)

  join_str = f'\n{"-" * 50}\n'
  if debug_results:
    print(f'CUDA components status (debug):\n'
          f'{join_str.join(debug_results)}')
  if errors:
    raise RuntimeError(f'Unable to use CUDA because of the '
                       f'following issues with CUDA components:\n'
                       f'{join_str.join(errors)}')


def initialize():
path = _get_library_path()
if path is None:
return

if not os.getenv("JAX_SKIP_CUDA_CONSTRAINTS_CHECK"):
_check_cuda_versions(raise_on_first_error=True)
else:
print('Skipped CUDA versions constraints check due to the '
'JAX_SKIP_CUDA_CONSTRAINTS_CHECK env var being set.')

options = xla_client.generate_pjrt_gpu_plugin_options()
c_api = xb.register_plugin(
'cuda', priority=500, library_path=str(path), options=options
Expand Down
Loading