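"""
Configuration for the preprocessing workflow.

This module provides Pydantic models for validating and managing
preprocessing workflow parameters: input file patterns, TOPUP acquisition
parameters, BET settings, BedpostX configuration, SynthSeg options,
registration chains, label transformation, DTIFit options, output naming
conventions, and workflow control flags. It also provides
``prepare_preprocess_paths``, which turns these patterns into concrete
per-patient input and output paths.
"""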
from pathlib import Path
from typing import Any, Dict, List, Optional, TypedDict
from pydantic import BaseModel, ConfigDict, Field, field_validator
from thesis.core.config import PipelineConfig
from thesis.core.context import ProcessingContext
from thesis.core.logging import get_logger
from thesis.core.utils import resolve_path
# Module-level logger used by prepare_preprocess_paths for missing-input warnings.
logger = get_logger(__name__)

# Public API of this module. BaseConfig and _INPUT_DEFAULTS are not exported.
# NOTE(review): "LabelTransformConfig" is listed here (and used by
# PreprocessConfig.label_transform) but is neither defined nor imported in the
# visible source -- confirm where it comes from, otherwise
# ``from <this module> import *`` raises AttributeError.
__all__ = [
    "AcqParamsConfig",
    "BetConfig",
    "BedpostXConfig",
    "PreprocessSynthSegConfig",
    "RegistrationStepConfig",
    "RegistrationChainConfig",
    "LabelTransformConfig",
    "DTIFitConfig",
    "PreprocessConfig",
    "PreprocessPaths",
    "prepare_preprocess_paths",
]
# ---------------------------------------------------------------------------
# Base Config
# ---------------------------------------------------------------------------
class BaseConfig(BaseModel):
    """Strict base for every config model in this module.

    Undeclared keys cause a validation error instead of being silently
    ignored, so misspelled options in a user config surface immediately.
    """

    # extra="forbid": any field not declared on the model is rejected.
    model_config = ConfigDict(extra="forbid")
# ---------------------------------------------------------------------------
# Component Config Models (Pydantic)
# ---------------------------------------------------------------------------
[docs]
class AcqParamsConfig(BaseConfig):
    """
    TOPUP acquisition parameters.

    Only two values are configurable. ``create_acqparams_file`` derives the
    readout time from ``bandwidth`` and ``phase_encoding_dirs`` and writes the
    fixed AP/PA encoding rows itself.

    Attributes:
        bandwidth: Effective readout bandwidth in Hz per pixel; must be > 0.
        phase_encoding_dirs: Number of phase-encoding lines used when deriving
            the readout time; must be > 1.
    """

    bandwidth: float = Field(
        1685.0,
        description="Effective readout bandwidth in Hz per pixel",
        gt=0.0,
    )
    phase_encoding_dirs: int = Field(
        96,
        description="Number of phase-encoding lines used to derive readout time",
        gt=1,
    )
[docs]
class BetConfig(BaseConfig):
    """
    BET brain-extraction settings, with a separate threshold per modality.

    Attributes:
        frac_dwi: Fractional intensity threshold for DWI data, within [0, 1].
        frac_t1: Fractional intensity threshold for T1 data, within [0, 1].
        frac_t2: Fractional intensity threshold for T2 data, within [0, 1].
        robust: Enable robust brain-centre estimation (default on).
        padding: Enable padding, which improves BET performance (default on).
        radius: Brain radius in mm. ``None`` (the default) lets BET estimate
            it; when given it must be positive.
    """

    # Per-modality fractional intensity thresholds, each bounded to [0, 1].
    frac_dwi: float = Field(
        0.3,
        description="Fractional intensity threshold for DWI (0-1)",
        ge=0.0,
        le=1.0,
    )
    frac_t1: float = Field(
        0.5,
        description="Fractional intensity threshold for T1 (0-1)",
        ge=0.0,
        le=1.0,
    )
    frac_t2: float = Field(
        0.5,
        description="Fractional intensity threshold for T2 (0-1)",
        ge=0.0,
        le=1.0,
    )

    # Behaviour switches.
    robust: bool = Field(True, description="Use robust brain center estimation")
    padding: bool = Field(True, description="Apply padding (improves BET performance)")

    # Optional explicit radius; the positivity bound only applies to non-None values.
    radius: Optional[int] = Field(
        None,
        description="Brain radius in mm (if None, BET estimates it)",
        gt=0,
    )
[docs]
class BedpostXConfig(BaseConfig):
    """
    Settings for BedpostX fibre-orientation estimation (MCMC sampling).

    Attributes:
        n_fibres: Fibre populations modelled per voxel, between 1 and 3.
        model: Model selector, between 1 and 3. The field metadata describes
            it as 1 = single-shell, 2 = multi-shell, 3 = ball-and-sticks.
        burn_in: Number of MCMC burn-in jumps (>= 0).
        n_jumps: Total number of MCMC jumps (>= 1).
        sample_every: Keep one sample every N jumps (>= 1).
        use_gpu: Use GPU acceleration (requires a CUDA-enabled BedpostX).
        weight: ARD weight (>= 0). The field metadata describes it as only
            relevant for model 3.

    NOTE(review): FSL's ``bedpostx`` usage text lists ``-model`` as
    1 = sticks, 2 = sticks with a range of diffusivities, 3 = zeppelin, and
    documents ``-w`` (ARD weight) without tying it to a particular model.
    Confirm the ``model``/``weight`` descriptions against the installed FSL.
    """

    # Model structure.
    n_fibres: int = Field(
        2,
        description="Number of fibre populations to model per voxel (1-3)",
        ge=1,
        le=3,
    )
    model: int = Field(
        1,
        description="Model type (1=single-shell, 2=multi-shell, 3=ball-and-sticks)",
        ge=1,
        le=3,
    )

    # MCMC sampling schedule.
    burn_in: int = Field(1000, description="Number of MCMC burn-in jumps", ge=0)
    n_jumps: int = Field(1250, description="Total number of MCMC jumps", ge=1)
    sample_every: int = Field(25, description="Sample every Nth jump", ge=1)

    # Execution and regularisation.
    use_gpu: bool = Field(
        False,
        description="Use GPU acceleration (requires CUDA-enabled BedpostX)",
    )
    weight: float = Field(1.0, description="ARD weight (only for model 3)", ge=0.0)
[docs]
class PreprocessSynthSegConfig(BaseConfig):
    """
    SynthSeg options used by the preprocessing workflow.

    Attributes:
        run_on_t1: Segment the T1-weighted image (default on).
        run_on_t2: Segment the T2-weighted image (default off).
        parc: Also produce a cortical parcellation.
        robust: Enable robust mode.
        fast: Enable fast mode.
        vol: Write a volumes CSV (default on).
        qc: Write a QC CSV (default on).
        crop: Crop size (>= 1), or ``None`` to leave it unset.
        cpu: Force CPU execution.
        threads: CPU thread count (>= 1), used when ``cpu`` is True.
        options: Extra SynthSeg command-line options as a mapping.
    """

    # Which images to segment.
    run_on_t1: bool = Field(True, description="Run SynthSeg on T1-weighted image")
    run_on_t2: bool = Field(False, description="Run SynthSeg on T2-weighted image")

    # Segmentation mode flags.
    parc: bool = Field(False, description="Generate cortical parcellation")
    robust: bool = Field(False, description="Use robust mode")
    fast: bool = Field(False, description="Use fast mode")

    # Auxiliary CSV outputs.
    vol: bool = Field(True, description="Write volumes CSV")
    qc: bool = Field(True, description="Write QC CSV")

    # Execution settings.
    crop: Optional[int] = Field(None, description="Optional crop size", ge=1)
    cpu: bool = Field(False, description="Force CPU execution")
    threads: int = Field(1, description="CPU thread count when cpu=True", ge=1)

    # Free-form passthrough; default_factory gives each instance its own dict.
    options: Dict[str, Any] = Field(
        default_factory=dict,
        description="Additional SynthSeg command-line options",
    )
[docs]
# Registration back-ends accepted by RegistrationStepConfig.method
# (matched case-sensitively).
_VALID_REGISTRATION_METHODS = ("ants", "fsl", "dipy", "fireants")


class RegistrationStepConfig(BaseConfig):
    """
    Configuration for a single registration step.

    Attributes:
        method: Registration back-end; one of ``_VALID_REGISTRATION_METHODS``.
        metric: Similarity metric (MI, CC, etc.).
        transform_type: Transform type (Rigid, Affine, SyN).
        interpolation: Interpolation method (Linear, BSpline, etc.).
        use_float: Whether to use float precision.
        collapse_output_transforms: Whether to collapse output transforms.
        write_composite_transform: Whether to write composite transform.
    """

    method: str = Field(
        default="ants",
        description="Registration method (ants, fsl, etc.)",
    )
    metric: str = Field(
        default="MI",
        description="Similarity metric (MI, CC, etc.)",
    )
    transform_type: str = Field(
        default="Rigid",
        description="Transform type (Rigid, Affine, SyN)",
    )
    interpolation: str = Field(
        default="Linear",
        description="Interpolation method (Linear, BSpline, etc.)",
    )
    use_float: bool = Field(
        default=True,
        description="Use float precision",
    )
    collapse_output_transforms: bool = Field(
        default=True,
        description="Collapse output transforms where possible",
    )
    write_composite_transform: bool = Field(
        default=True,
        description="Write composite transform",
    )

    @field_validator("method")
    @classmethod
    def validate_method(cls, v: str) -> str:
        """Validate that ``method`` names a supported registration back-end.

        Args:
            v: The candidate method name.

        Returns:
            ``v`` unchanged when it is supported.

        Raises:
            ValueError: If ``v`` is not in ``_VALID_REGISTRATION_METHODS``;
                pydantic surfaces this as a ``ValidationError``.
        """
        if v not in _VALID_REGISTRATION_METHODS:
            # Include the rejected value so config typos are easy to spot.
            raise ValueError(
                f"method must be one of {list(_VALID_REGISTRATION_METHODS)}, got {v!r}"
            )
        return v
[docs]
def _ants_mi_step_factory(transform_type: str):
    """Return a zero-argument default factory for an ANTs/MI registration step.

    The method, transform type and metric are passed explicitly, so they count
    as explicitly set fields on the produced ``RegistrationStepConfig``.
    """

    def _build() -> "RegistrationStepConfig":
        return RegistrationStepConfig(
            method="ants",
            transform_type=transform_type,
            metric="MI",
        )

    return _build


class RegistrationChainConfig(BaseConfig):
    """
    Registration chain used by preprocessing, one step per space pair.

    Every step defaults to ANTs with the ``"MI"`` metric. The two intra-subject
    steps are rigid, and the template step uses SyN.

    Attributes:
        dwi_to_t1: Step mapping DWI space onto T1 space (Rigid by default).
        t2_to_t1: Step mapping T2 space onto T1 space (Rigid by default).
        t1_to_template: Step mapping T1 space onto template space (SyN by default).
    """

    dwi_to_t1: RegistrationStepConfig = Field(
        default_factory=_ants_mi_step_factory("Rigid"),
        description="Registration from DWI to T1",
    )
    t2_to_t1: RegistrationStepConfig = Field(
        default_factory=_ants_mi_step_factory("Rigid"),
        description="Registration from T2 to T1",
    )
    t1_to_template: RegistrationStepConfig = Field(
        default_factory=_ants_mi_step_factory("SyN"),
        description="Registration from T1 to template",
    )
[docs]
class DTIFitConfig(BaseConfig):
    """
    Options for DTIFit tensor estimation.

    Attributes:
        use_wls: Fit with weighted least squares (default on).
        compute_kurt: Also compute kurtosis parameters (default off).
        save_tensor: Write the full tensor components (default on).
    """

    use_wls: bool = Field(True, description="Use weighted least squares fitting")
    compute_kurt: bool = Field(False, description="Compute kurtosis parameters")
    save_tensor: bool = Field(True, description="Save full tensor components")
# ---------------------------------------------------------------------------
# Main PreprocessConfig
# ---------------------------------------------------------------------------
[docs]
class PreprocessConfig(BaseConfig):
    """
    Comprehensive configuration for the preprocessing workflow.

    Input entries are filename patterns (or paths). ``prepare_preprocess_paths``
    substitutes ``{patient_id}`` in them and resolves them against the input
    directory. Output entries are joined onto the output directory.

    Attributes:
        t1_image: T1-weighted image file pattern or path.
        t2_image: Optional T2-weighted image file pattern or path.
        dwi_ap: Anterior-posterior DWI image.
        dwi_ap_bval: b-values file for AP acquisition.
        dwi_ap_bvec: b-vectors file for AP acquisition.
        dwi_pa: Optional posterior-anterior DWI image (for TOPUP).
        acq_params: TOPUP acquisition parameters.
        bet: Brain extraction parameters for different modalities.
        bedpostx: Fibre orientation estimation parameters.
        synthseg: Segmentation parameters.
        registration: Multi-step registration configuration.
        label_transform: Label transformation and atlas warping config.
        dtifit: DTI tensor estimation options.
        t1_final: Output T1 filename (HCP-style layout).
        t2_final: Output T2 filename (HCP-style layout).
        diffusion_dir: Output diffusion directory name.
        bedpostx_dir: Output BedpostX directory name.
        run_topup: Whether to run TOPUP distortion correction.
        run_eddy: Whether to run eddy current correction.
        run_dtifit: Whether to run DTIFit.
        run_bedpostx: Whether to run BedpostX.
        run_synthseg: Whether to run SynthSeg.
        run_coregistration: Whether to run the intra-subject coregistration
            chain (DWI->T1, T2->T1, MNI-label warps).
        run_robustfov: Whether to run FSL robustfov before N4 to crop the neck
            from T1/T2 images.
    """

    # ---- Input file patterns ----
    t1_image: str = Field(
        default="{patient_id}_T1.nii.gz",
        description="T1-weighted image file pattern or path",
    )
    t2_image: Optional[str] = Field(
        default="{patient_id}_T2.nii.gz",
        description="T2-weighted image file pattern or path",
    )
    dwi_ap: str = Field(
        default="{patient_id}_dmri_AP.nii.gz",
        description="Anterior-posterior DWI image",
    )
    dwi_ap_bval: str = Field(
        default="{patient_id}_dmri_AP.bval",
        description="b-values file for AP acquisition",
    )
    dwi_ap_bvec: str = Field(
        default="{patient_id}_dmri_AP.bvec",
        description="b-vectors file for AP acquisition",
    )
    dwi_pa: Optional[str] = Field(
        default="{patient_id}_dmri_PA.nii.gz",
        description="Posterior-anterior DWI image (for TOPUP)",
    )

    # ---- Acquisition parameters ----
    acq_params: AcqParamsConfig = Field(
        default_factory=AcqParamsConfig,
        description="TOPUP acquisition parameters",
    )

    # ---- BET parameters ----
    bet: BetConfig = Field(
        default_factory=BetConfig,
        description="Brain extraction parameters",
    )

    # ---- BedpostX parameters ----
    bedpostx: BedpostXConfig = Field(
        default_factory=BedpostXConfig,
        description="Fibre orientation estimation parameters",
    )

    # ---- SynthSeg parameters ----
    synthseg: PreprocessSynthSegConfig = Field(
        default_factory=PreprocessSynthSegConfig,
        description="Segmentation parameters",
    )

    # ---- Registration chain ----
    registration: RegistrationChainConfig = Field(
        default_factory=RegistrationChainConfig,
        description="Multi-step registration configuration",
    )

    # ---- Label transformation ----
    # NOTE(review): LabelTransformConfig is neither defined nor imported in
    # this module's visible source -- confirm its import, otherwise class
    # creation fails with NameError.
    label_transform: LabelTransformConfig = Field(
        default_factory=LabelTransformConfig,
        description="Label transformation and atlas warping config",
    )

    # ---- DTIFit options ----
    dtifit: DTIFitConfig = Field(
        default_factory=DTIFitConfig,
        description="DTI tensor estimation options",
    )

    # ---- Output naming (HCP-style) ----
    t1_final: str = Field(
        default="T1w/T1w_acpc_dc_restore_1.25.nii.gz",
        description="Output T1 filename",
    )
    t2_final: str = Field(
        default="T1w/T2w_acpc_dc_restore.nii.gz",
        description="Output T2 filename",
    )
    diffusion_dir: str = Field(
        default="T1w/Diffusion",
        description="Output diffusion directory name",
    )
    bedpostx_dir: str = Field(
        default="T1w/Diffusion.bedpostX",
        description="Output BedpostX directory name",
    )

    # ---- Workflow control flags ----
    run_topup: bool = Field(
        default=True,
        description="Run TOPUP distortion correction",
    )
    run_eddy: bool = Field(
        default=True,
        description="Run eddy current correction",
    )
    run_dtifit: bool = Field(
        default=True,
        description="Run DTIFit",
    )
    run_bedpostx: bool = Field(
        default=True,
        description="Run BedpostX",
    )
    run_synthseg: bool = Field(
        default=True,
        description="Run SynthSeg",
    )
    run_coregistration: bool = Field(
        default=True,
        description=(
            "Run preprocess's intra-subject coregistration (DWI->T1, T2->T1, and "
            "MNI-label warps). Distinct from the top-level patient->template "
            "`registration` block. full_pipeline consumes the DWI->T1 transform "
            "to place atlas ROIs on the DWI grid; when off, ROIs are regridded "
            "instead (only valid when T1 and DWI already share a world)."
        ),
    )
    run_robustfov: bool = Field(
        default=True,
        description="Run FSL robustfov before N4 to crop neck from T1/T2 images",
    )
# ---------------------------------------------------------------------------
# Path Resolution
# ---------------------------------------------------------------------------
[docs]
class PreprocessPaths(TypedDict):
    """Filesystem locations resolved for one preprocessing run.

    All keys are required. ``Optional`` entries may hold ``None``, for example
    when the corresponding filename pattern is empty.
    """

    # Base directories.
    input_dir: Path
    output_dir: Path

    # Resolved input files.
    t1_image: Optional[Path]
    t2_image: Optional[Path]
    dwi_ap: Path
    dwi_ap_bval: Path
    dwi_ap_bvec: Path
    dwi_pa: Optional[Path]

    # Output locations, rooted at output_dir.
    diffusion_dir: Path
    bedpostx_dir: Path
    t1_final: Path
    t2_final: Optional[Path]
# Fallback filename patterns for prepare_preprocess_paths, used when the
# pipeline config carries no PreprocessConfig. They are read straight from
# PreprocessConfig's field defaults, so the two definitions cannot drift apart.
# Key order matches the original literal.
_INPUT_DEFAULTS: Dict[str, str] = {
    key: PreprocessConfig.model_fields[key].default
    for key in (
        "t1_image",
        "t2_image",
        "dwi_ap",
        "dwi_ap_bval",
        "dwi_ap_bvec",
        "dwi_pa",
        "diffusion_dir",
        "bedpostx_dir",
        "t1_final",
        "t2_final",
    )
}
[docs]
def prepare_preprocess_paths(config: PipelineConfig, context: ProcessingContext) -> PreprocessPaths:
    """Resolve preprocessing input/output paths from config and context.

    Filename patterns come from ``config.preprocess`` when it is a
    :class:`PreprocessConfig`; otherwise ``_INPUT_DEFAULTS`` is used. Every
    ``{patient_id}`` placeholder is replaced with ``context.patient_id``, or
    with an empty string when that is unset.

    Input files are resolved with ``resolve_path`` against
    ``context.input_dir``, falling back to the current working directory.
    Output locations are joined onto ``context.output_dir``, falling back to
    ``<input_dir>/derivatives``.

    Missing T1/DWI input files only produce a warning; file existence is not
    enforced here.

    Args:
        config: Pipeline configuration. Its optional ``preprocess`` attribute
            supplies the filename patterns.
        context: Processing context providing ``input_dir``, ``output_dir``
            and ``patient_id`` (each may be ``None``).

    Returns:
        A :class:`PreprocessPaths` mapping. An input whose pattern is empty or
        ``None`` maps to ``None``. ``t2_final`` is ``None`` whenever
        ``t2_image`` is.

    Raises:
        ValueError: If the pattern for ``dwi_ap``, ``dwi_ap_bval`` or
            ``dwi_ap_bvec`` is empty or ``None``. These inputs are mandatory.
    """
    preproc_cfg = getattr(config, "preprocess", None)
    if not isinstance(preproc_cfg, PreprocessConfig):
        # A missing attribute or any other object type falls back to the defaults.
        preproc_cfg = None

    input_dir = (
        Path(context.input_dir).resolve() if context.input_dir is not None else Path(".").resolve()
    )
    output_dir = (
        Path(context.output_dir).resolve()
        if context.output_dir is not None
        else input_dir / "derivatives"
    )
    patient_id = context.patient_id if context.patient_id is not None else ""

    def _pat(key: str) -> Optional[str]:
        """Return the pattern for ``key`` with ``{patient_id}`` filled in, or None if empty."""
        if preproc_cfg is not None:
            raw = getattr(preproc_cfg, key, _INPUT_DEFAULTS[key])
        else:
            raw = _INPUT_DEFAULTS[key]
        return raw.replace("{patient_id}", patient_id) if raw else None

    def _resolve(key: str) -> Optional[Path]:
        """Resolve an input pattern against ``input_dir``; None when the pattern is empty."""
        pat = _pat(key)
        return resolve_path(input_dir, pat) if pat else None

    def _output(key: str) -> Path:
        """Join an output pattern onto ``output_dir`` (an empty pattern yields ``output_dir``)."""
        return output_dir / (_pat(key) or "")

    t1_image = _resolve("t1_image")
    t2_image = _resolve("t2_image")
    dwi_ap = _resolve("dwi_ap")
    dwi_ap_bval = _resolve("dwi_ap_bval")
    dwi_ap_bvec = _resolve("dwi_ap_bvec")
    dwi_pa = _resolve("dwi_pa")

    for label, path in (
        ("T1 image", t1_image),
        ("DWI AP image", dwi_ap),
        ("DWI AP bval file", dwi_ap_bval),
        ("DWI AP bvec file", dwi_ap_bvec),
    ):
        if path is not None and not path.exists():
            logger.warning("{} not found: {}", label, path)

    # Explicit check instead of `assert`: asserts are stripped under `python -O`,
    # and this reports which mandatory patterns were left empty.
    if dwi_ap is None or dwi_ap_bval is None or dwi_ap_bvec is None:
        missing = [
            key
            for key, resolved in (
                ("dwi_ap", dwi_ap),
                ("dwi_ap_bval", dwi_ap_bval),
                ("dwi_ap_bvec", dwi_ap_bvec),
            )
            if resolved is None
        ]
        raise ValueError(
            "Preprocessing requires non-empty filename patterns for: " + ", ".join(missing)
        )

    return {
        "input_dir": input_dir,
        "output_dir": output_dir,
        "t1_image": t1_image,
        "t2_image": t2_image,
        "dwi_ap": dwi_ap,
        "dwi_ap_bval": dwi_ap_bval,
        "dwi_ap_bvec": dwi_ap_bvec,
        "dwi_pa": dwi_pa,
        "diffusion_dir": _output("diffusion_dir"),
        "bedpostx_dir": _output("bedpostx_dir"),
        "t1_final": _output("t1_final"),
        "t2_final": _output("t2_final") if t2_image is not None else None,
    }