# Source code for thesis.workflows.preprocess.nodes.diffusion

"""Diffusion MRI processing node builders for DTIFit and BedpostX.

This module provides factory functions to create configured Nipype nodes for:
- DTIFit: Tensor fitting using FSL's dtifit
- BedpostX: Bayesian estimation of diffusion parameters using FSL's bedpostx_gpu/bedpostx
"""

import os
import shutil
from pathlib import Path
from typing import Any, cast

from nipype import Node
from nipype.interfaces.base import isdefined
from nipype.interfaces.fsl import BEDPOSTX5, DTIFit
from nipype.interfaces.fsl.base import Info

from thesis.core.logging import get_logger

# Module-level logger, named after this module via ``__name__``.
logger = get_logger(__name__)

# Public API: only the two node-builder functions are exported.
__all__ = ["prepare_dtifit_node", "prepare_bedpostx_node"]


class _PreprocessDTIFit(DTIFit):
    """DTIFit subclass whose reported outputs follow an absolute ``base_name``.

    The stock ``DTIFit._list_outputs`` relocates every output path into the
    node working directory, even when ``base_name`` is an absolute path. The
    preprocess workflow deliberately passes an absolute prefix so the tensor
    outputs are written straight into the patient results tree. This variant
    leaves absolute prefixes untouched and falls back to the inherited
    behavior whenever ``base_name`` is relative.
    """

    def _list_outputs(self) -> dict[str, Any]:
        """Return the output mapping, keeping absolute ``base_name`` prefixes."""
        prefix = cast(str, getattr(self.inputs, "base_name"))
        if not os.path.isabs(prefix):
            # Relative prefixes keep the inherited Nipype behavior.
            return cast(dict[str, Any], super()._list_outputs())

        # Optional outputs are only reported when their controlling input
        # flag is both defined and truthy; otherwise they are skipped.
        optional_flags = {
            "tensor": getattr(self.inputs, "save_tensor"),
            "sse": getattr(self.inputs, "sse"),
        }
        skipped = {"outputtype", "environ", "args"}
        skipped.update(
            output_name
            for output_name, flag in optional_flags.items()
            if not (isdefined(flag) and flag)
        )

        result = cast(dict[str, Any], self.output_spec().get())
        prefix_path = Path(prefix)
        suffix = Info.output_type_to_ext(cast(str, getattr(self.inputs, "output_type")))
        target_dir, stem = prefix_path.parent, prefix_path.name
        for output_name in set(result) - skipped:
            # Each file is named ``<prefix>_<output><ext>`` next to the prefix.
            result[output_name] = str(target_dir / f"{stem}_{output_name}{suffix}")
        return result


class _RobustBEDPOSTX5(BEDPOSTX5):
    """BEDPOSTX5 variant that clears leftover output directories before running.

    FSL's bedpostx will not start when ``{out_dir}.bedpostX`` is already
    present. If Nipype's cache gets invalidated (for instance because eddy
    re-ran), bedpostx would abort with "already been processed" although the
    cached result no longer matches the new inputs.
    """

    def _run_interface(self, runtime: Any) -> Any:
        """Delete stale ``.bedpostX`` directories, then run the parent interface."""
        configured = getattr(self.inputs, "out_dir", None)
        if not configured:
            return super()._run_interface(runtime)

        prefix = str(configured)
        # FSL appends ".bedpostX" to out_dir, so that is the expected output
        # location. The doubled suffix can be left behind by an earlier run in
        # which Nipype cached the already-suffixed path; both are removed.
        stale_candidates = (
            (".bedpostX", "Removing stale bedpostX output before re-run"),
            (".bedpostX.bedpostX", "Removing double-nested bedpostX output"),
        )
        for suffix, note in stale_candidates:
            leftover = Path(prefix + suffix)
            if leftover.exists():
                logger.debug(f"{note}: {leftover}")
                shutil.rmtree(str(leftover))
        return super()._run_interface(runtime)


def _detect_bedpostx_gpu() -> bool:
    """Report whether a ``bedpostx_gpu`` executable can be located.

    The lookup first searches ``PATH``; if that fails and the ``FSLDIR``
    environment variable is set to a non-empty value, it checks for an
    executable regular file at ``$FSLDIR/bin/bedpostx_gpu``.

    Returns:
        True if bedpostx_gpu was found by either lookup, False otherwise.

    Example:
        >>> gpu_available = _detect_bedpostx_gpu()
        >>> print(f"GPU support: {gpu_available}")
    """
    # First choice: anything resolvable through PATH.
    if shutil.which("bedpostx_gpu") is not None:
        return True

    # Fallback: the FSL installation's own bin directory.
    fsl_root = os.environ.get("FSLDIR")
    if not fsl_root:
        return False
    candidate = os.path.join(fsl_root, "bin", "bedpostx_gpu")
    return os.path.isfile(candidate) and os.access(candidate, os.X_OK)


def prepare_dtifit_node(
    use_wls: bool = True,
    save_tensor: bool = True,
    compute_sse: bool = True,
    compute_kurt: bool = False,
    name: str = "dtifit",
) -> Node:
    """Create a configured DTIFit node for tensor fitting.

    Uses FSL's dtifit to fit diffusion tensors to DWI data and compute derived
    metrics (FA, MD, eigenvalues, eigenvectors).

    Args:
        use_wls: Use weighted least squares fitting (--wls flag). Recommended
            for better handling of noise and outliers.
        save_tensor: Save the full diffusion tensor elements.
        compute_sse: Compute sum of squared errors between model and data.
        compute_kurt: Compute kurtosis tensors (--kurt flag). Requires
            additional high b-value shells.
        name: Node name for workflow visualization.

    Returns:
        Configured Nipype Node wrapping FSL DTIFit interface.

    Raises:
        ValueError: If name is empty.

    Example:
        >>> node = prepare_dtifit_node(use_wls=True, compute_kurt=False)
        >>> node.inputs.dwi = "dwi.nii.gz"
        >>> node.inputs.mask = "brain_mask.nii.gz"
        >>> node.inputs.bvecs = "bvecs"
        >>> node.inputs.bvals = "bvals"
        >>> node.inputs.base_name = "dti"
    """
    if not name:
        raise ValueError("Node name cannot be empty")

    logger.debug(
        f"Creating DTIFit node: use_wls={use_wls}, save_tensor={save_tensor}, "
        f"compute_sse={compute_sse}, compute_kurt={compute_kurt}"
    )

    # Wrap the absolute-prefix-aware DTIFit variant in a workflow node.
    dtifit = Node(_PreprocessDTIFit(), name=name)
    dtifit.inputs.save_tensor = save_tensor
    dtifit.inputs.sse = compute_sse

    # --wls and --kurt are passed through the free-form ``args`` input; the
    # input is left untouched when neither flag is requested.
    args_list: list[str] = []
    if use_wls:
        args_list.append("--wls")
    if compute_kurt:
        args_list.append("--kurt")

    if args_list:
        dtifit.inputs.args = " ".join(args_list)
        logger.debug(f"DTIFit args: {dtifit.inputs.args}")

    return dtifit
def prepare_bedpostx_node(
    n_fibres: int = 3,
    model: int = 2,
    burn_in: int = 1000,
    n_jumps: int = 1250,
    sample_every: int = 25,
    use_gpu: bool = True,
    name: str = "bedpostx",
) -> Node:
    """Create a configured BedpostX node for Bayesian diffusion parameter estimation.

    Uses Nipype's BEDPOSTX5 interface to estimate crossing fiber populations
    per voxel using Markov Chain Monte Carlo (MCMC) sampling. Supports GPU
    acceleration.

    Args:
        n_fibres: Number of fiber populations to model per voxel. This builder
            accepts values from 1 to 5. More fibers allow modeling complex
            crossings but increase computation time.
        model: Deconvolution model passed to bedpostx:
            1 = sticks
            2 = sticks with a range of diffusivities
            3 = zeppelins
        burn_in: Number of MCMC burn-in iterations to discard. Default 1000.
        n_jumps: Number of MCMC jumps after burn-in. Default 1250.
        sample_every: Sample every Nth MCMC iteration. Default 25.
            Total samples = n_jumps / sample_every = 50 by default.
        use_gpu: Attempt to use GPU-accelerated bedpostx_gpu if available.
            Falls back to CPU version if GPU not detected.
        name: Node name for workflow visualization.

    Returns:
        Configured Nipype Node wrapping FSL BEDPOSTX5 interface.

    Raises:
        ValueError: If n_fibres not in range [1, 5], model not in [1, 2, 3],
            or name is empty.

    Example:
        >>> node = prepare_bedpostx_node(n_fibres=3, model=2, use_gpu=True)
        >>> node.inputs.dwi = "dwi.nii.gz"
        >>> node.inputs.mask = "brain_mask.nii.gz"
        >>> node.inputs.bvecs = "bvecs"
        >>> node.inputs.bvals = "bvals"
    """
    # Validate parameters before any node is constructed.
    if not name:
        raise ValueError("Node name cannot be empty")

    if not (1 <= n_fibres <= 5):
        raise ValueError(f"n_fibres must be between 1 and 5 (BEDPOSTX5 limitation), got {n_fibres}")

    if model not in (1, 2, 3):
        # Model names follow the wording of FSL's bedpostx usage text.
        raise ValueError(
            f"model must be 1 (sticks), 2 (sticks with a range of diffusivities), "
            f"or 3 (zeppelins), got {model}"
        )

    logger.debug(
        f"Creating BedpostX node: n_fibres={n_fibres}, model={model}, "
        f"burn_in={burn_in}, n_jumps={n_jumps}, sample_every={sample_every}, "
        f"use_gpu={use_gpu}"
    )

    # Only probe for bedpostx_gpu when the caller asked for GPU execution.
    gpu_available = False
    if use_gpu:
        gpu_available = _detect_bedpostx_gpu()
        if gpu_available:
            logger.info("GPU-accelerated bedpostx_gpu detected and will be used")
        else:
            logger.warning(
                "GPU requested but bedpostx_gpu not found in PATH or FSLDIR/bin. "
                "Falling back to CPU version (this may take significantly longer)"
            )

    # Wrap the stale-output-cleaning BEDPOSTX5 variant in a workflow node.
    bedpostx = Node(_RobustBEDPOSTX5(), name=name)

    # Configure interface parameters
    bedpostx.inputs.n_fibres = n_fibres
    bedpostx.inputs.model = model
    bedpostx.inputs.burn_in = burn_in
    bedpostx.inputs.n_jumps = n_jumps
    bedpostx.inputs.sample_every = sample_every

    # The GPU input is only set when the binary was actually found, so a
    # missing bedpostx_gpu leaves the interface on its default command.
    if gpu_available:
        bedpostx.inputs.use_gpu = True

    logger.debug(
        f"BedpostX configured: n_fibres={n_fibres}, model={model}, "
        f"burn_in={burn_in}, n_jumps={n_jumps}, sample_every={sample_every}, "
        f"use_gpu={gpu_available}"
    )

    return bedpostx