Source code for thesis.core.gpu

"""
GPU availability detection for the thesis framework.

Provides a single ``check_gpu()`` function that tests whether a compatible
GPU probtrackx2 binary and a suitable CUDA runtime are available on the
current system.  This check is intended to run **once at CLI startup** so
that every downstream tool sees a consistent ``config.hardware.gpu_enabled``
value, rather than each interface performing its own ad-hoc detection.

Typical usage::

    from thesis.core.gpu import check_gpu

    status = check_gpu()
    if status.available:
        print(f"GPU ready: {status.reason}")
    else:
        print(f"GPU unavailable: {status.reason}")
"""

import os
import re
import shutil
import subprocess
from typing import NamedTuple, Optional, Tuple

from thesis.core.logging import get_logger

logger = get_logger(__name__)

__all__ = ["GPUStatus", "GPU_BINARIES", "check_gpu", "present_gpu_binaries"]

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

# Candidate GPU probtrackx2 binaries in preference order.
# Versioned names (e.g. probtrackx2_gpu11.0) encode the minimum required
# CUDA major.minor version.  This is the single source of truth for the
# candidate list — other modules (e.g. the FSL interfaces) import it rather
# than redefining their own copy.
GPU_BINARIES = ["probtrackx2_gpu11.0", "probtrackx2_gpu"]

# Parses the required CUDA version from a versioned binary name.
_VERSION_RE = re.compile(r"gpu(\d+)\.(\d+)$")

# Parses "CUDA Version: X.Y" from nvidia-smi's header line.
_CUDA_VERSION_RE = re.compile(r"CUDA\s+Version:\s*(\d+)\.(\d+)", re.IGNORECASE)


# ---------------------------------------------------------------------------
# Result type
# ---------------------------------------------------------------------------


class GPUStatus(NamedTuple):
    """
    Outcome of probing the system for GPU probtrackx2 support.

    Attributes:
        available: ``True`` only when a compatible GPU binary and a CUDA
            runtime were both detected.
        binary: Name of the chosen GPU binary (for example
            ``"probtrackx2_gpu11.0"``); ``None`` when GPU support is
            unavailable.
        cuda_version: Detected CUDA version as ``(major, minor)``, or
            ``None`` when it could not be determined.
        reason: Single-line, human-readable summary of the outcome, usable
            in log messages and user-facing warnings.
    """

    available: bool
    binary: Optional[str]
    cuda_version: Optional[Tuple[int, int]]
    reason: str
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------


def _get_cuda_version() -> Optional[Tuple[int, int]]:
    """
    Query ``nvidia-smi`` for the maximum CUDA version supported by the driver.

    The header of ``nvidia-smi`` output contains a string of the form::

        CUDA Version: 12.4

    which is the highest CUDA version the installed driver can support.  A
    driver can run applications built against any CUDA runtime up to that
    version, so the value is used as an upper bound for compatibility checks.

    Returns:
        ``(major, minor)`` on success, or ``None`` if ``nvidia-smi`` is not
        found, times out, cannot be executed, exits with a non-zero code, or
        produces no recognisable CUDA version string.
    """
    if shutil.which("nvidia-smi") is None:
        logger.debug("nvidia-smi not found on PATH")
        return None

    try:
        result = subprocess.run(
            ["nvidia-smi"],
            capture_output=True,
            text=True,
            timeout=15,
        )
    except subprocess.TimeoutExpired:
        logger.debug("nvidia-smi timed out after 15 s")
        return None
    except OSError as exc:
        # Messages with dynamic values are pre-formatted so they render
        # correctly under both %-style and brace-style logger backends.
        logger.debug(f"nvidia-smi could not be executed: {exc}")
        return None

    if result.returncode != 0:
        logger.debug(f"nvidia-smi exited with code {result.returncode}")
        return None

    match = _CUDA_VERSION_RE.search(result.stdout)
    if not match:
        logger.debug("nvidia-smi output contained no 'CUDA Version:' string")
        return None

    return int(match.group(1)), int(match.group(2))


def _parse_required_cuda(binary_name: str) -> Optional[Tuple[int, int]]:
    """
    Extract the minimum required CUDA major.minor version from a binary name.

    For example ``"probtrackx2_gpu11.0"`` → ``(11, 0)``.

    Args:
        binary_name: Candidate binary name, e.g. ``"probtrackx2_gpu11.0"``.

    Returns:
        ``(major, minor)`` when the name ends in ``gpu<major>.<minor>``, or
        ``None`` for unversioned names like ``"probtrackx2_gpu"``.
    """
    match = _VERSION_RE.search(binary_name)
    if match:
        return int(match.group(1)), int(match.group(2))
    return None
def present_gpu_binaries() -> list:
    """Return the subset of :data:`GPU_BINARIES` found on ``$FSLDIR/bin`` or ``$PATH``.

    Candidates are returned in :data:`GPU_BINARIES` preference order.  This
    is a presence check only — CUDA compatibility is validated separately by
    :func:`check_gpu`.

    ``$FSLDIR/bin`` is consulted only when ``FSLDIR`` is set to a non-empty
    value; otherwise only ``$PATH`` is searched.

    Returns:
        List of binary names (not full paths) that were found.
    """
    fsl_dir = os.environ.get("FSLDIR")
    # Joining an empty/unset FSLDIR with "bin" would yield the *relative*
    # path "bin", i.e. a lookup in the current working directory, which
    # could report a stray ./bin/<name> file as an installed FSL binary.
    fsl_bin = os.path.join(fsl_dir, "bin") if fsl_dir else None

    return [
        name
        for name in GPU_BINARIES
        if (fsl_bin is not None and os.path.isfile(os.path.join(fsl_bin, name)))
        or shutil.which(name) is not None
    ]
# --------------------------------------------------------------------------- # Public API # ---------------------------------------------------------------------------
def check_gpu() -> GPUStatus:
    """
    Check whether a compatible GPU probtrackx2 binary is available.

    Performs three checks in sequence:

    1. **Binary presence** — delegates to :func:`present_gpu_binaries` to
       find which candidates from :data:`GPU_BINARIES` are installed.
    2. **CUDA runtime** — calls ``nvidia-smi`` (via
       :func:`_get_cuda_version`) to determine the maximum CUDA version
       supported by the installed driver.
    3. **Version compatibility** — for versioned binaries (e.g.
       ``probtrackx2_gpu11.0``) the detected CUDA version must be ≥ the
       encoded version; unversioned binaries (``probtrackx2_gpu``) only
       require that *some* CUDA version was detected.

    The first binary, in preference order, that passes all three checks is
    selected.

    Returns:
        :class:`GPUStatus` with ``available=True`` and the chosen binary when
        a compatible setup is detected, or ``available=False`` with a
        descriptive ``reason`` otherwise.
    """
    present = present_gpu_binaries()
    if not present:
        return GPUStatus(
            available=False,
            binary=None,
            cuda_version=None,
            reason=(
                f"no GPU probtrackx2 binary found "
                f"(searched: {', '.join(GPU_BINARIES)} in $FSLDIR/bin and $PATH)"
            ),
        )

    # GPU binary/binaries exist — now check the CUDA runtime.
    cuda = _get_cuda_version()
    if cuda is None:
        return GPUStatus(
            available=False,
            binary=None,
            cuda_version=None,
            reason=(
                f"GPU binary found ({', '.join(present)}) "
                f"but CUDA runtime is not available "
                f"(nvidia-smi not found or did not report a CUDA version)"
            ),
        )
    cuda_str = f"{cuda[0]}.{cuda[1]}"

    # Try each present binary in preference order, remembering why the
    # rejected ones were skipped so the failure reason can list them.
    mismatches = []
    for name in present:
        required = _parse_required_cuda(name)
        # Unversioned binaries (required is None) accept any detected CUDA;
        # versioned ones need cuda >= required (tuple comparison).
        if required is None or cuda >= required:
            return GPUStatus(
                available=True,
                binary=name,
                cuda_version=cuda,
                reason=f"using '{name}' (CUDA runtime: {cuda_str})",
            )

        req_str = f"{required[0]}.{required[1]}"
        # Pre-formatted so the message renders correctly under both %-style
        # and brace-style logger backends.
        logger.debug(
            f"GPU binary '{name}' requires CUDA >= {req_str}; "
            f"available CUDA: {cuda_str} — skipping"
        )
        mismatches.append(f"{name} needs CUDA >= {req_str}")

    # Every present binary failed the version check.
    return GPUStatus(
        available=False,
        binary=None,
        cuda_version=cuda,
        reason=(
            f"CUDA version mismatch (available: {cuda_str}): "
            + "; ".join(mismatches)
        ),
    )