From 011639cf3621c52c00ffc1a24abf7f4dacf19966 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 14 May 2025 21:28:15 +0000 Subject: [PATCH] Reenable CUDA version checks from Python. These had been accidentally broken at some point in the plugin switchover. --- CHANGELOG.md | 2 + jax/_src/xla_bridge.py | 136 -------------------------------- jax_plugins/cuda/__init__.py | 148 +++++++++++++++++++++++++++++++++++ 3 files changed, 150 insertions(+), 136 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a0c30132c169..9fd4e50304d0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. given its name. * Changes + * Additional checking for the versions of CUDA package dependencies was + reenabled, having been accidentally disabled in a previous release. * JAX nightly packages are now published to artifact registry. To install these packages, see the [JAX installation guide](https://docs.jax.dev/en/latest/installation.html#jax-nightly-installation). * `jax.sharding.PartitionSpec` no longer inherits from a tuple. 
diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index 72a16d5fbe5c..ce0c36fdcca4 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -31,7 +31,6 @@ import pkgutil import platform as py_platform import threading -import traceback from typing import Any, Sequence, Union import warnings @@ -311,141 +310,6 @@ def _check_cuda_compute_capability(devices_to_check): ) -def _check_cuda_versions(raise_on_first_error: bool = False, - debug: bool = False): - assert cuda_versions is not None - results: list[dict[str, Any]] = [] - - def _make_msg(name: str, - runtime_version: int, - build_version: int, - min_supported: int, - debug_msg: bool = False): - if debug_msg: - return (f"Package: {name}\n" - f"Version JAX was built against: {build_version}\n" - f"Minimum supported: {min_supported}\n" - f"Installed version: {runtime_version}") - if min_supported: - req_str = (f"The local installation version must be no lower than " - f"{min_supported}.") - else: - req_str = ("The local installation must be the same version as " - "the version against which JAX was built.") - msg = (f"Outdated {name} installation found.\n" - f"Version JAX was built against: {build_version}\n" - f"Minimum supported: {min_supported}\n" - f"Installed version: {runtime_version}\n" - f"{req_str}") - return msg - - def _version_check(name: str, - get_version, - get_build_version, - scale_for_comparison: int = 1, - min_supported_version: int = 0): - """Checks the runtime CUDA component version against the JAX one. - - Args: - name: Of the CUDA component. - get_version: A function to get the local runtime version of the component. - get_build_version: A function to get the build version of the component. - scale_for_comparison: For rounding down a version to ignore patch/minor. - min_supported_version: An absolute minimum version required. Must be - passed without rounding down. 
- - Raises: - RuntimeError: If the component is not found, or is of unsupported version, - and if raising the error is not deferred till later. - """ - - build_version = get_build_version() - try: - version = get_version() - except Exception as e: - err_msg = f"Unable to load {name}. Is it installed?" - if raise_on_first_error: - raise RuntimeError(err_msg) from e - err_msg += f"\n{traceback.format_exc()}" - results.append({"name": name, "installed": False, "msg": err_msg}) - return - - if not min_supported_version: - min_supported_version = build_version // scale_for_comparison - passed = min_supported_version <= version - - if not passed or debug: - msg = _make_msg(name=name, - runtime_version=version, - build_version=build_version, - min_supported=min_supported_version, - debug_msg=passed) - if not passed and raise_on_first_error: - raise RuntimeError(msg) - else: - record = {"name": name, - "installed": True, - "msg": msg, - "passed": passed, - "build_version": build_version, - "version": version, - "minimum_supported": min_supported_version} - results.append(record) - - _version_check("CUDA", cuda_versions.cuda_runtime_get_version, - cuda_versions.cuda_runtime_build_version, - scale_for_comparison=10, - min_supported_version=12010) - _version_check( - "cuDNN", - cuda_versions.cudnn_get_version, - cuda_versions.cudnn_build_version, - # NVIDIA promise both backwards and forwards compatibility for cuDNN patch - # versions: - # https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#api-compat - scale_for_comparison=100, - min_supported_version=9100 - ) - _version_check("cuFFT", cuda_versions.cufft_get_version, - cuda_versions.cufft_build_version, - # Ignore patch versions. - scale_for_comparison=100) - _version_check("cuSOLVER", cuda_versions.cusolver_get_version, - cuda_versions.cusolver_build_version, - # Ignore patch versions. 
- scale_for_comparison=100, - min_supported_version=11400) - _version_check("cuPTI", cuda_versions.cupti_get_version, - cuda_versions.cupti_build_version, - min_supported_version=18) - _version_check("cuBLAS", cuda_versions.cublas_get_version, - cuda_versions.cublas_build_version, - # Ignore patch versions. - scale_for_comparison=100, - min_supported_version=120100) - _version_check("cuSPARSE", cuda_versions.cusparse_get_version, - cuda_versions.cusparse_build_version, - # Ignore patch versions. - scale_for_comparison=100, - min_supported_version=12100) - - errors = [] - debug_results = [] - for result in results: - message: str = result['msg'] - if not result['installed'] or not result['passed']: - errors.append(message) - else: - debug_results.append(message) - - join_str = f'\n{"-" * 50}\n' - if debug_results: - print(f'CUDA components status (debug):\n' - f'{join_str.join(debug_results)}') - if errors: - raise RuntimeError(f'Unable to use CUDA because of the ' - f'following issues with CUDA components:\n' - f'{join_str.join(errors)}') def get_num_nodes_from_gpu_topology(topology: str) -> int: try: diff --git a/jax_plugins/cuda/__init__.py b/jax_plugins/cuda/__init__.py index 1be29326c95f..9df7fc69ff1a 100644 --- a/jax_plugins/cuda/__init__.py +++ b/jax_plugins/cuda/__init__.py @@ -17,6 +17,8 @@ import logging import os import pathlib +import traceback +from typing import Any from jax._src.lib import triton from jax._src.lib import xla_client @@ -29,8 +31,12 @@ cuda_plugin_extension = importlib.import_module( f'{pkg_name}.cuda_plugin_extension' ) + cuda_versions = importlib.import_module( + f'{pkg_name}._versions' + ) except ImportError: cuda_plugin_extension = None + cuda_versions = None else: break @@ -76,11 +82,153 @@ def _get_library_path(): return None +def _check_cuda_versions(raise_on_first_error: bool = False, + debug: bool = False): + assert cuda_versions is not None + results: list[dict[str, Any]] = [] + + def _make_msg(name: str, + runtime_version: 
int, + build_version: int, + min_supported: int, + debug_msg: bool = False): + if debug_msg: + return (f"Package: {name}\n" + f"Version JAX was built against: {build_version}\n" + f"Minimum supported: {min_supported}\n" + f"Installed version: {runtime_version}") + if min_supported: + req_str = (f"The local installation version must be no lower than " + f"{min_supported}.") + else: + req_str = ("The local installation must be the same version as " + "the version against which JAX was built.") + msg = (f"Outdated {name} installation found.\n" + f"Version JAX was built against: {build_version}\n" + f"Minimum supported: {min_supported}\n" + f"Installed version: {runtime_version}\n" + f"{req_str}") + return msg + + def _version_check(name: str, + get_version, + get_build_version, + scale_for_comparison: int = 1, + min_supported_version: int = 0): + """Checks the runtime CUDA component version against the JAX one. + + Args: + name: Of the CUDA component. + get_version: A function to get the local runtime version of the component. + get_build_version: A function to get the build version of the component. + scale_for_comparison: For rounding down a version to ignore patch/minor. + min_supported_version: An absolute minimum version required. Must be + passed without rounding down. + + Raises: + RuntimeError: If the component is not found, or is of unsupported version, + and if raising the error is not deferred till later. + """ + + build_version = get_build_version() + try: + version = get_version() + except Exception as e: + err_msg = f"Unable to load {name}. Is it installed?" 
+ if raise_on_first_error: + raise RuntimeError(err_msg) from e + err_msg += f"\n{traceback.format_exc()}" + results.append({"name": name, "installed": False, "msg": err_msg}) + return + + if not min_supported_version: + min_supported_version = build_version // scale_for_comparison + passed = min_supported_version <= version + + if not passed or debug: + msg = _make_msg(name=name, + runtime_version=version, + build_version=build_version, + min_supported=min_supported_version, + debug_msg=passed) + if not passed and raise_on_first_error: + raise RuntimeError(msg) + else: + record = {"name": name, + "installed": True, + "msg": msg, + "passed": passed, + "build_version": build_version, + "version": version, + "minimum_supported": min_supported_version} + results.append(record) + + _version_check("CUDA", cuda_versions.cuda_runtime_get_version, + cuda_versions.cuda_runtime_build_version, + scale_for_comparison=10, + min_supported_version=12010) + _version_check( + "cuDNN", + cuda_versions.cudnn_get_version, + cuda_versions.cudnn_build_version, + # NVIDIA promise both backwards and forwards compatibility for cuDNN patch + # versions: + # https://docs.nvidia.com/deeplearning/cudnn/backend/latest/developer/forward-compatibility.html#cudnn-api-compatibility + scale_for_comparison=100, + ) + _version_check("cuFFT", cuda_versions.cufft_get_version, + cuda_versions.cufft_build_version, + # Ignore patch versions. + scale_for_comparison=100) + _version_check("cuSOLVER", cuda_versions.cusolver_get_version, + cuda_versions.cusolver_build_version, + # Ignore patch versions. + scale_for_comparison=100, + min_supported_version=11400) + _version_check("cuPTI", cuda_versions.cupti_get_version, + cuda_versions.cupti_build_version, + min_supported_version=18) + _version_check("cuBLAS", cuda_versions.cublas_get_version, + cuda_versions.cublas_build_version, + # Ignore patch versions. 
+ scale_for_comparison=100, + min_supported_version=120100) + _version_check("cuSPARSE", cuda_versions.cusparse_get_version, + cuda_versions.cusparse_build_version, + # Ignore patch versions. + scale_for_comparison=100, + min_supported_version=12100) + + errors = [] + debug_results = [] + for result in results: + message: str = result['msg'] + if not result['installed'] or not result['passed']: + errors.append(message) + else: + debug_results.append(message) + + join_str = f'\n{"-" * 50}\n' + if debug_results: + print(f'CUDA components status (debug):\n' + f'{join_str.join(debug_results)}') + if errors: + raise RuntimeError(f'Unable to use CUDA because of the ' + f'following issues with CUDA components:\n' + f'{join_str.join(errors)}') + + def initialize(): path = _get_library_path() if path is None: return + if not os.getenv("JAX_SKIP_CUDA_CONSTRAINTS_CHECK"): + _check_cuda_versions(raise_on_first_error=True) + else: + print('Skipped CUDA versions constraints check due to the ' + 'JAX_SKIP_CUDA_CONSTRAINTS_CHECK env var being set.') + options = xla_client.generate_pjrt_gpu_plugin_options() c_api = xb.register_plugin( 'cuda', priority=500, library_path=str(path), options=options