"""
GPU availability detection for the thesis framework.
Provides a single ``check_gpu()`` function that tests whether a compatible
GPU probtrackx2 binary and a suitable CUDA runtime are available on the
current system. This check is intended to run **once at CLI startup** so
that every downstream tool sees a consistent ``config.hardware.gpu_enabled``
value, rather than each interface performing its own ad-hoc detection.
Typical usage::

    from thesis.core.gpu import check_gpu

    status = check_gpu()
    if status.available:
        print(f"GPU ready: {status.reason}")
    else:
        print(f"GPU unavailable: {status.reason}")
"""
import os
import re
import shutil
import subprocess
from typing import NamedTuple, Optional, Tuple
from thesis.core.logging import get_logger
logger = get_logger(__name__)
__all__ = ["GPUStatus", "GPU_BINARIES", "check_gpu", "present_gpu_binaries"]
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Candidate GPU probtrackx2 binaries in preference order.
# Versioned names (e.g. probtrackx2_gpu11.0) encode the minimum required
# CUDA major.minor version. This is the single source of truth for the
# candidate list — other modules (e.g. the FSL interfaces) import it rather
# than redefining their own copy.
GPU_BINARIES = ["probtrackx2_gpu11.0", "probtrackx2_gpu"]
# Parses the required CUDA version from a versioned binary name.
# The pattern is anchored at the end of the name, so an unversioned name such
# as "probtrackx2_gpu" (no trailing "<major>.<minor>") does not match.
_VERSION_RE = re.compile(r"gpu(\d+)\.(\d+)$")
# Parses "CUDA Version: X.Y" from nvidia-smi's header line.
# Matching is case-insensitive and tolerates variable whitespace around the
# colon; group 1 is the major version, group 2 the minor version.
_CUDA_VERSION_RE = re.compile(r"CUDA\s+Version:\s*(\d+)\.(\d+)", re.IGNORECASE)
# ---------------------------------------------------------------------------
# Result type
# ---------------------------------------------------------------------------
class GPUStatus(NamedTuple):
    """
    Outcome of a GPU availability probe, as an immutable named tuple.

    Attributes:
        available: ``True`` only when both a compatible GPU binary and a
            CUDA runtime were found.
        binary: Name of the GPU binary that was selected
            (e.g. ``"probtrackx2_gpu11.0"``); ``None`` when no binary was
            selected.
        cuda_version: Detected CUDA version as ``(major, minor)``, or
            ``None`` when no version could be determined.
        reason: Short, human-readable explanation of the outcome, suitable
            for log messages and user-facing warnings.
    """

    available: bool
    binary: Optional[str]
    cuda_version: Optional[Tuple[int, int]]
    reason: str
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
def _get_cuda_version() -> Optional[Tuple[int, int]]:
    """
    Query ``nvidia-smi`` for the maximum CUDA version supported by the driver.

    The header line of ``nvidia-smi`` output contains a string of the form::

        CUDA Version: 12.4

    which is the highest CUDA version the installed driver can support.
    Callers use this value as an upper bound when checking whether a binary
    built for a given CUDA version can run.

    Returns:
        ``(major, minor)`` on success, or ``None`` if ``nvidia-smi`` is not
        found, times out, cannot be executed, exits with a non-zero code, or
        produces no recognisable CUDA version string.
    """
    timeout_s = 15  # single source for both the call and the log message

    if shutil.which("nvidia-smi") is None:
        logger.debug("nvidia-smi not found on PATH")
        return None

    try:
        result = subprocess.run(
            ["nvidia-smi"],
            capture_output=True,
            text=True,
            # Best-effort probe: a stray non-decodable byte in the output
            # must not raise UnicodeDecodeError; only the ASCII header matters.
            errors="replace",
            timeout=timeout_s,
        )
    except subprocess.TimeoutExpired:
        logger.debug(f"nvidia-smi timed out after {timeout_s} s")
        return None
    except OSError as exc:
        # Messages are pre-rendered so they come out correctly whether the
        # project logger formats with ``%s`` or ``{}`` placeholders.
        logger.debug(f"nvidia-smi could not be executed: {exc}")
        return None

    if result.returncode != 0:
        logger.debug(f"nvidia-smi exited with code {result.returncode}")
        return None

    match = _CUDA_VERSION_RE.search(result.stdout)
    if match is None:
        logger.debug("nvidia-smi output contained no 'CUDA Version:' string")
        return None

    return int(match.group(1)), int(match.group(2))
def _parse_required_cuda(binary_name: str) -> Optional[Tuple[int, int]]:
    """
    Read the minimum CUDA version encoded at the end of a binary name.

    ``"probtrackx2_gpu11.0"`` yields ``(11, 0)``; a name with no trailing
    ``<major>.<minor>`` suffix, such as ``"probtrackx2_gpu"``, yields ``None``.
    """
    found = _VERSION_RE.search(binary_name)
    if found is None:
        return None
    major, minor = found.groups()
    return int(major), int(minor)
def present_gpu_binaries() -> list:
    """Return the subset of :data:`GPU_BINARIES` found on ``$FSLDIR/bin`` or ``$PATH``.

    Candidates are returned in :data:`GPU_BINARIES` preference order. This is a
    presence check only — CUDA compatibility is validated separately by
    :func:`check_gpu`.

    ``$FSLDIR/bin`` is consulted only when ``FSLDIR`` is set to a non-empty
    value; otherwise only ``$PATH`` is searched.

    Returns:
        A list of binary names (not full paths); empty if none were found.
    """
    fsl_dir = os.environ.get("FSLDIR")
    # With FSLDIR unset/empty, joining "" and "bin" would give the *relative*
    # path "bin" and wrongly match files under the current working directory.
    fsl_bin = os.path.join(fsl_dir, "bin") if fsl_dir else None

    def _is_present(name: str) -> bool:
        if fsl_bin is not None and os.path.isfile(os.path.join(fsl_bin, name)):
            return True
        return shutil.which(name) is not None

    return [name for name in GPU_BINARIES if _is_present(name)]
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def check_gpu() -> GPUStatus:
    """
    Check whether a compatible GPU probtrackx2 binary is available.

    Performs three checks in sequence:

    1. **Binary presence** — looks for each candidate in ``$FSLDIR/bin/``
       then ``$PATH``.
    2. **CUDA runtime** — calls ``nvidia-smi`` to determine the maximum CUDA
       version supported by the installed driver.
    3. **Version compatibility** — for versioned binaries (e.g.
       ``probtrackx2_gpu11.0``) the CUDA version must be ≥ the encoded
       version; unversioned binaries (``probtrackx2_gpu``) only require that
       *some* CUDA version was detected.

    The first binary (in preference order) that passes all three checks is
    selected.

    Returns:
        :class:`GPUStatus` with ``available=True`` and the chosen binary when
        a compatible setup is detected, or ``available=False`` with a
        descriptive ``reason`` otherwise. On a version mismatch the detected
        ``cuda_version`` is still reported.
    """
    present = present_gpu_binaries()
    if not present:
        return GPUStatus(
            available=False,
            binary=None,
            cuda_version=None,
            reason=(
                f"no GPU probtrackx2 binary found "
                f"(searched: {', '.join(GPU_BINARIES)} in $FSLDIR/bin and $PATH)"
            ),
        )

    # GPU binary/binaries exist — now check the CUDA runtime.
    cuda = _get_cuda_version()
    if cuda is None:
        return GPUStatus(
            available=False,
            binary=None,
            cuda_version=None,
            reason=(
                f"GPU binary found ({', '.join(present)}) "
                f"but CUDA runtime is not available "
                f"(nvidia-smi not found or did not report a CUDA version)"
            ),
        )
    cuda_str = f"{cuda[0]}.{cuda[1]}"

    # Try each present binary in preference order, remembering why each
    # rejected candidate was skipped so the final reason can list them.
    mismatches = []
    for name in present:
        required = _parse_required_cuda(name)
        # Unversioned binaries (required is None) accept any detected CUDA.
        if required is None or cuda >= required:
            return GPUStatus(
                available=True,
                binary=name,
                cuda_version=cuda,
                reason=f"using '{name}' (CUDA runtime: {cuda_str})",
            )

        req_str = f"{required[0]}.{required[1]}"
        # Pre-rendered so the message is correct whether the project logger
        # formats with ``%s`` or ``{}`` placeholders.
        logger.debug(
            f"GPU binary '{name}' requires CUDA >= {req_str}; "
            f"available CUDA: {cuda_str} — skipping"
        )
        mismatches.append(f"{name} needs CUDA >= {req_str}")

    # Every present binary failed the version check.
    return GPUStatus(
        available=False,
        binary=None,
        cuda_version=cuda,
        reason=(f"CUDA version mismatch (available: {cuda_str}): " + "; ".join(mismatches)),
    )