"""Diffusion MRI processing node builders for DTIFit and BedpostX.
This module provides factory functions to create configured Nipype nodes for:
- DTIFit: Tensor fitting using FSL's dtifit
- BedpostX: Bayesian estimation of diffusion parameters using FSL's bedpostx_gpu/bedpostx
"""
import os
import shutil
from pathlib import Path
from typing import Any, cast
from nipype import Node
from nipype.interfaces.base import isdefined
from nipype.interfaces.fsl import BEDPOSTX5, DTIFit
from nipype.interfaces.fsl.base import Info
from thesis.core.logging import get_logger
# Module-level logger obtained from the project's logging helper.
logger = get_logger(__name__)
# Public API of this module: the two node factory functions.
__all__ = ["prepare_dtifit_node", "prepare_bedpostx_node"]
class _PreprocessDTIFit(DTIFit):
    """DTIFit subclass whose reported outputs respect an absolute ``base_name``.

    The stock ``DTIFit._list_outputs`` places every output path inside the node
    working directory, even if ``base_name`` is absolute. The preprocess
    workflow deliberately passes an absolute prefix so the tensor maps are
    written straight into the patient results tree. This subclass reports those
    absolute locations and falls back to the parent implementation whenever
    ``base_name`` is relative.
    """

    def _list_outputs(self) -> dict[str, Any]:
        """Return expected output paths, keeping an absolute prefix unchanged.

        Relative ``base_name`` values are handled by ``DTIFit._list_outputs``.
        For an absolute prefix, each reported output becomes
        ``<prefix parent>/<prefix name>_<key><ext>``, where ``<ext>`` is derived
        from the configured ``output_type``. The optional ``tensor`` and ``sse``
        outputs are reported only when their controlling input
        (``save_tensor`` / ``sse``) is defined and truthy.
        """
        prefix = cast(str, getattr(self.inputs, "base_name"))
        if not os.path.isabs(prefix):
            return cast(dict[str, Any], super()._list_outputs())

        # Map each optional output to the input flag that enables it.
        enabling_flags = {
            "tensor": getattr(self.inputs, "save_tensor"),
            "sse": getattr(self.inputs, "sse"),
        }
        # Spec entries that are not output files, plus disabled optional outputs.
        skipped = {"outputtype", "environ", "args"} | {
            key
            for key, flag in enabling_flags.items()
            if not (isdefined(flag) and flag)
        }

        outputs = cast(dict[str, Any], self.output_spec().get())
        prefix_path = Path(prefix)
        extension = Info.output_type_to_ext(cast(str, getattr(self.inputs, "output_type")))
        for key in outputs.keys() - skipped:
            outputs[key] = str(prefix_path.parent / f"{prefix_path.name}_{key}{extension}")
        return outputs
class _RobustBEDPOSTX5(BEDPOSTX5):
    """BEDPOSTX5 that removes a stale output directory before re-running.

    FSL's bedpostx refuses to run if ``{out_dir}.bedpostX`` already exists.
    When Nipype's cache is invalidated (e.g. because eddy re-ran), bedpostx
    would otherwise fail with "already been processed" even though the cached
    result is no longer coherent with the new inputs.
    """

    def _run_interface(self, runtime: Any) -> Any:
        """Clear leftover ``.bedpostX`` outputs for ``out_dir``, then run bedpostx.

        Two sibling paths are removed if present:

        * ``out_dir + ".bedpostX"``: the directory bedpostx itself creates.
        * ``out_dir + ".bedpostX.bedpostX"``: a double-nested directory that a
          prior run may have left behind when Nipype cached the already-suffixed
          path.

        A relative ``out_dir`` is interpreted relative to the current working
        directory at call time. Directories are deleted recursively. Files and
        symlinks (including dangling ones) are unlinked, because
        ``shutil.rmtree`` rejects both. For a symlink only the link is removed,
        never its target.

        Args:
            runtime: Nipype runtime object, passed through unchanged to the
                parent implementation.

        Returns:
            Whatever ``BEDPOSTX5._run_interface`` returns.
        """
        out_dir = getattr(self.inputs, "out_dir", None)
        if out_dir:
            out_str = str(out_dir)
            stale_targets = (
                (".bedpostX", "Removing stale bedpostX output before re-run"),
                (".bedpostX.bedpostX", "Removing double-nested bedpostX output"),
            )
            for suffix, reason in stale_targets:
                stale = Path(out_str + suffix)
                # lexists (unlike exists) is also true for a dangling symlink,
                # which would still block bedpostx from creating its directory.
                if not os.path.lexists(stale):
                    continue
                logger.debug(f"{reason}: {stale}")
                if stale.is_dir() and not stale.is_symlink():
                    shutil.rmtree(stale)
                else:
                    stale.unlink()
        return super()._run_interface(runtime)
def _detect_bedpostx_gpu() -> bool:
    """Report whether an executable ``bedpostx_gpu`` can be located.

    The lookup first searches ``PATH`` via :func:`shutil.which`. If nothing is
    found there and ``FSLDIR`` is set, it then checks ``$FSLDIR/bin`` for a
    regular file named ``bedpostx_gpu`` that carries execute permission.

    Returns:
        True if either lookup succeeds, False otherwise.

    Example:
        >>> gpu_available = _detect_bedpostx_gpu()  # doctest: +SKIP
        >>> print(f"GPU support: {gpu_available}")  # doctest: +SKIP
    """
    if shutil.which("bedpostx_gpu") is not None:
        return True

    fsl_root = os.environ.get("FSLDIR")
    if not fsl_root:
        return False

    candidate = os.path.join(fsl_root, "bin", "bedpostx_gpu")
    return os.path.isfile(candidate) and os.access(candidate, os.X_OK)
def prepare_dtifit_node(
    use_wls: bool = True,
    save_tensor: bool = True,
    compute_sse: bool = True,
    compute_kurt: bool = False,
    name: str = "dtifit",
) -> Node:
    """Create a configured DTIFit node for tensor fitting.

    Uses FSL's dtifit to fit diffusion tensors to DWI data and compute
    derived metrics (FA, MD, eigenvalues, eigenvectors). The node wraps
    ``_PreprocessDTIFit``, so an absolute ``base_name`` keeps the reported
    outputs at that absolute location.

    Args:
        use_wls: Pass ``--wls`` so dtifit fits the tensor with weighted least
            squares.
        save_tensor: Save the full diffusion tensor elements.
        compute_sse: Compute sum of squared errors between model and data.
        compute_kurt: Pass ``--kurt`` so dtifit additionally writes a mean
            kurtosis map. This is only meaningful for multi-shell data.
            NOTE(review): that map is presumably not among the outputs the
            Nipype interface reports; verify against DTIFit's output spec.
        name: Node name for workflow visualization.

    Returns:
        Configured Nipype Node wrapping FSL DTIFit interface.

    Raises:
        ValueError: If name is empty.

    Example:
        >>> node = prepare_dtifit_node(use_wls=True, compute_kurt=False)
        >>> node.inputs.dwi = "dwi.nii.gz"
        >>> node.inputs.mask = "brain_mask.nii.gz"
        >>> node.inputs.bvecs = "bvecs"
        >>> node.inputs.bvals = "bvals"
        >>> node.inputs.base_name = "dti"
    """
    if not name:
        raise ValueError("Node name cannot be empty")

    logger.debug(
        f"Creating DTIFit node: use_wls={use_wls}, save_tensor={save_tensor}, "
        f"compute_sse={compute_sse}, compute_kurt={compute_kurt}"
    )

    dtifit = Node(_PreprocessDTIFit(), name=name)
    dtifit.inputs.save_tensor = save_tensor
    dtifit.inputs.sse = compute_sse

    # The WLS and kurtosis switches are forwarded as raw command-line flags
    # through the generic ``args`` input.
    extra_flags = []
    if use_wls:
        extra_flags.append("--wls")
    if compute_kurt:
        extra_flags.append("--kurt")
    if extra_flags:
        dtifit.inputs.args = " ".join(extra_flags)
        logger.debug(f"DTIFit args: {dtifit.inputs.args}")

    return dtifit
def prepare_bedpostx_node(
    n_fibres: int = 3,
    model: int = 2,
    burn_in: int = 1000,
    n_jumps: int = 1250,
    sample_every: int = 25,
    use_gpu: bool = True,
    name: str = "bedpostx",
) -> Node:
    """Create a configured BedpostX node for Bayesian diffusion parameter estimation.

    Uses FSL's bedpostx through the BEDPOSTX5 interface to estimate crossing
    fibre populations per voxel with Markov Chain Monte Carlo (MCMC) sampling.
    The interface is wrapped in ``_RobustBEDPOSTX5``, which clears stale
    ``.bedpostX`` output before a re-run. Supports GPU acceleration.

    Args:
        n_fibres: Number of fibre populations to model per voxel (1-5).
            More fibres allow modelling complex crossings but increase
            computation time.
        model: Deconvolution model passed to bedpostx's ``-model`` option:
            1 = sticks (monoexponential; suited to single-shell data)
            2 = sticks with a range of diffusivities (multiexponential;
                intended for multi-shell data)
            3 = zeppelins
        burn_in: Number of MCMC burn-in iterations to discard (>= 0).
            Default 1000.
        n_jumps: Number of MCMC jumps after burn-in (>= 1). Default 1250.
        sample_every: Keep every Nth MCMC jump (1 <= sample_every <= n_jumps).
            Default 25. Total samples = n_jumps / sample_every = 50 by default.
        use_gpu: Attempt to use GPU-accelerated bedpostx_gpu if available.
            Falls back to the CPU version if it is not detected.
        name: Node name for workflow visualization.

    Returns:
        Configured Nipype Node wrapping FSL BEDPOSTX5 interface.

    Raises:
        ValueError: If name is empty, n_fibres is not in [1, 5], model is not
            one of 1, 2, 3, burn_in is negative, n_jumps or sample_every is
            below 1, or sample_every exceeds n_jumps (no samples would be kept).

    Example:
        >>> node = prepare_bedpostx_node(n_fibres=3, model=2, use_gpu=True)
        >>> node.inputs.dwi = "dwi.nii.gz"
        >>> node.inputs.mask = "brain_mask.nii.gz"
        >>> node.inputs.bvecs = "bvecs"
        >>> node.inputs.bvals = "bvals"
    """
    # Validate parameters before touching any Nipype objects.
    if not name:
        raise ValueError("Node name cannot be empty")
    if not (1 <= n_fibres <= 5):
        raise ValueError(f"n_fibres must be between 1 and 5 (BEDPOSTX5 limitation), got {n_fibres}")
    if model not in (1, 2, 3):
        raise ValueError(
            f"model must be 1 (sticks), 2 (sticks with a range of diffusivities), "
            f"or 3 (zeppelins), got {model}"
        )
    if burn_in < 0:
        raise ValueError(f"burn_in must be >= 0, got {burn_in}")
    if n_jumps < 1:
        raise ValueError(f"n_jumps must be >= 1, got {n_jumps}")
    if sample_every < 1:
        raise ValueError(f"sample_every must be >= 1, got {sample_every}")
    if sample_every > n_jumps:
        # n_jumps / sample_every would round down to zero retained samples.
        raise ValueError(
            f"sample_every ({sample_every}) must not exceed n_jumps ({n_jumps})"
        )

    logger.debug(
        f"Creating BedpostX node: n_fibres={n_fibres}, model={model}, "
        f"burn_in={burn_in}, n_jumps={n_jumps}, sample_every={sample_every}, "
        f"use_gpu={use_gpu}"
    )

    # Only probe for the GPU binary when the caller asked for it.
    gpu_available = False
    if use_gpu:
        gpu_available = _detect_bedpostx_gpu()
        if gpu_available:
            logger.info("GPU-accelerated bedpostx_gpu detected and will be used")
        else:
            logger.warning(
                "GPU requested but bedpostx_gpu not found in PATH or FSLDIR/bin. "
                "Falling back to CPU version (this may take significantly longer)"
            )

    bedpostx = Node(_RobustBEDPOSTX5(), name=name)
    bedpostx.inputs.n_fibres = n_fibres
    bedpostx.inputs.model = model
    bedpostx.inputs.burn_in = burn_in
    bedpostx.inputs.n_jumps = n_jumps
    bedpostx.inputs.sample_every = sample_every
    # use_gpu is left unset (interface default) unless the GPU binary was found.
    if gpu_available:
        bedpostx.inputs.use_gpu = True

    logger.debug(
        f"BedpostX configured: n_fibres={n_fibres}, model={model}, "
        f"burn_in={burn_in}, n_jumps={n_jumps}, sample_every={sample_every}, "
        f"use_gpu={gpu_available}"
    )
    return bedpostx