"""
Pydantic models for configuration validation.
These models ensure that configurations have the correct structure
and valid values before being used in processing pipelines.
"""
import re
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Union, cast
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from thesis.core.logging import get_logger
from thesis.core.utils import to_path
# Module-scoped logger from the project's logging helper, named after this module.
logger = get_logger(__name__)
# ---------------------------------------------------------------------------
# Enums (defined here to avoid circular imports with workflow modules)
# ---------------------------------------------------------------------------
[docs]
class NormalizationMethod(str, Enum):
    """How streamline-density volumes are scaled before atlas combination.

    Because the enum mixes in ``str``, every member compares equal to its raw
    value (``NormalizationMethod.MAX == "max"``) and can be looked up from a
    config string with ``NormalizationMethod("max")``.

    Members:
        WAYTOTAL: Divide by the waytotal (total streamline count). This is the
            default for the FSL ProbTrackX2 backend; each voxel becomes the
            fraction of all streamlines that passed through it.
        MAX: Divide by the largest voxel value in the volume, mapping the
            volume onto the [0, 1] range.
        SOFTMAX: Apply ``exp(x) / sum(exp(x))`` so the volume becomes a
            probability distribution summing to 1.0.
        STREAMLINE_DENSITY: MRtrix3 analogue of ``WAYTOTAL``. Divides by the
            number stored in ``waytotal``. The MRtrix3 workflow writes that
            number as the sum of SIFT2 per-streamline weights when SIFT2 is
            enabled, matching the SIFT2-weighted TDI numerator, and as the
            raw streamline count when it is not. Yields a streamline fraction
            in [0, 1].
    """

    # FSL ProbTrackX2 default.
    WAYTOTAL = "waytotal"
    # Peak-normalised to [0, 1].
    MAX = "max"
    # Softmax over all voxels.
    SOFTMAX = "softmax"
    # MRtrix3 counterpart of WAYTOTAL.
    STREAMLINE_DENSITY = "streamline_density"
# ---------------------------------------------------------------------------
# Shared allowed-value sets for registration (referenced by both
# RegistrationConfig and the per-job RegistrationJobConfig so the two models
# never drift).
# ---------------------------------------------------------------------------
# Backends accepted for ``method``.
REGISTRATION_METHODS = ["ants", "fireants", "fsl", "dipy"]

# Structural modalities that may serve as the moving image.
REGISTRATION_MOVING_MODALITIES = ["t1", "t2"]

# Interpolation modes accepted for ``interpolation`` (order preserved: it is
# echoed verbatim in validation error messages).
REGISTRATION_INTERPOLATIONS = [
    "Linear", "NearestNeighbor", "BSpline", "MultiLabel", "Gaussian",
    "CosineWindowedSinc", "WelchWindowedSinc",
    "HammingWindowedSinc", "LanczosWindowedSinc",
    "GenericLabel",
]

# Transform families accepted for ``transform_type``.
REGISTRATION_TRANSFORM_TYPES = ["Rigid", "Affine", "SyN"]
# Public API of this module (order kept stable).
__all__ = [
    # enums
    "NormalizationMethod",
    # base + infrastructure
    "BaseConfig", "PathConfig", "HardwareConfig",
    # atlas building
    "AtlasConfig", "AtlasQCConfig",
    # data access / execution engine
    "S3Config", "NipypeConfig",
    # preprocessing and registration
    "PreprocessingConfig",
    "FireantsRegistrationConfig", "RegistrationViewerConfig",
    "RegistrationJobConfig", "RegistrationConfig",
    # segmentation, atlas sources and transforms
    "SegmentationConfig", "AtlasSourceConfig", "AtlasTransformConfig",
    "TransformJobConfig", "TransformsConfig", "SynthSegConfig",
    # tractography, validation and QC
    "TractographyConfig", "ValidationConfig", "QCConfig",
    "OutputSettingsConfig", "SideThresholdConfig",
    # HCP leave-one-out / similarity analyses
    "HcpLooConfig", "TractSimilarityConfig", "ThresholdGridConfig",
    "TractSimilaritySweepConfig", "HCPConfig",
    # top-level pipeline
    "PipelineConfig",
]
[docs]
class BaseConfig(BaseModel):
    """Shared parent for the configuration models in this module.

    Undeclared keys are a validation error (``extra="forbid"``), so a misspelt
    option in a config file fails loudly instead of being silently dropped.
    """

    # Inherited by subclasses unless they override model_config.
    model_config = ConfigDict(extra="forbid")
[docs]
class PathConfig(BaseConfig):
    """Filesystem locations used by the pipeline.

    The input, asset and output roots are configured independently; none is
    derived from another. Note that the shipped defaults do nest
    ``inputs_dir`` (``data/raw``) inside ``assets_dir`` (``data``).

    Every field is passed through ``to_path`` before validation, so plain
    strings are accepted (with ``~`` and environment variables expanded).

    Attributes:
        inputs_dir: Per-patient source data root; a patient's input directory
            is ``inputs_dir / <patient_id>``.
        assets_dir: Cohort-shared, read-only assets (templates, atlases, ROIs,
            reference images). Anchors ``DataFile``/``DataDir`` lookups.
        output_dir: Root for all workflow outputs/derivatives; a patient's
            output directory is ``output_dir / <patient_id>``.
        scratch_dir: Optional temporary/scratch directory override. When
            unset, scratch defaults to ``output_dir / <patient_id> / temp``.
        log_dir: Directory for runtime logs.
        scripts_dir: Optional directory of user-supplied workflow scripts
            (``*.py``).
    """

    inputs_dir: Path = Field(
        description="Per-patient source data root (input_dir = inputs_dir/<patient_id>)",
        default=Path("data/raw"),
    )
    assets_dir: Path = Field(
        description="Shared read-only assets (templates, atlases, ROIs)",
        default=Path("data"),
    )
    output_dir: Path = Field(
        description="Root for all workflow outputs (output_dir/<patient_id>)",
        default=Path("outputs"),
    )
    scratch_dir: Optional[Path] = Field(
        description="Optional temporary/scratch directory override",
        default=None,
    )
    log_dir: Path = Field(
        description="Directory for log files",
        default=Path("logs"),
    )
    scripts_dir: Optional[Path] = Field(
        description=(
            "Optional directory of user-supplied workflow scripts (*.py). "
            "When set, `thesis list-workflows` scans the directory and lists "
            "any scripts whose @workflow decorator registers successfully."
        ),
        default=None,
    )

    @field_validator("*", mode="before")
    @classmethod
    def convert_to_path(cls, v):
        """Normalise every incoming value with the project's ``to_path`` helper.

        That helper turns string paths into ``Path`` objects, expanding ``~``
        and environment variables.
        """
        normalised = to_path(v)
        return normalised
[docs]
class HardwareConfig(BaseConfig):
    """Compute resources made available to workflow execution.

    Attributes:
        threads: CPU threads available to workflows (1-128).
        memory_gb: Memory budget in gigabytes (1-1024).
        gpu_enabled: Turns on GPU-aware execution.
        gpu_device: Optional GPU device index.
        n_gpu_procs: Scheduler GPU worker slots, i.e. how many GPU nodes may
            run concurrently.
        n_gpus: Physical GPUs exposed to each worker slot.
    """

    threads: int = Field(
        description="Number of CPU threads to use",
        default=4,
        ge=1,
        le=128,
    )
    memory_gb: int = Field(
        description="Maximum memory in GB",
        default=16,
        ge=1,
        le=1024,
    )
    gpu_enabled: bool = Field(
        description="Whether to use GPU acceleration",
        default=False,
    )
    gpu_device: Optional[int] = Field(
        description="GPU device ID (if multiple GPUs)",
        default=None,
    )
    n_gpu_procs: int = Field(
        description="Nipype GPU worker slots (concurrent GPU nodes allowed by scheduler)",
        default=1,
        ge=1,
    )
    n_gpus: int = Field(
        description="Physical GPUs visible per worker slot",
        default=1,
        ge=1,
    )
[docs]
class AtlasConfig(BaseConfig):
    """Settings for the cohort atlas-building workflow.

    Attributes:
        presence_value: Threshold in [0, 1]. A voxel counts as "present" in
            the ``prob_threshold`` calculation when its normalised value
            exceeds it.
        cov_mean_threshold_pct: Fraction in [0, 1] of the atlas mean-map
            maximum. The coefficient of variation is only computed where the
            mean exceeds it, which suppresses low-signal voxels.
        normalization_method: How streamline-density volumes are normalised
            before combination (see ``NormalizationMethod``).
        tractography_relpath: Path, relative to each patient directory, of the
            per-patient tractography output the workflow reads.
        output_subdir: Subdirectory of the cohort output directory that
            receives the statistical atlas.
    """

    presence_value: float = Field(
        description=(
            "Threshold for prob_threshold calculation. A voxel is counted as "
            "'present' when its normalised value exceeds this threshold."
        ),
        default=0.10,
        ge=0.0,
        le=1.0,
    )
    cov_mean_threshold_pct: float = Field(
        description=(
            "Global atlas mean-map percentage used to suppress low-signal voxels "
            "when computing cov. CV is only computed where mean exceeds this "
            "fraction of the atlas mean maximum."
        ),
        default=0.01,
        ge=0.0,
        le=1.0,
    )
    normalization_method: NormalizationMethod = Field(
        description=(
            "Method for normalizing streamline density volumes before atlas "
            "combination. Options: 'waytotal' (FSL ProbTrackX2 — divide by "
            "total streamline count), 'streamline_density' (MRtrix3 — divide "
            "by SIFT2 weight sum or raw count, matching the TDI units), "
            "'max' (divide by maximum voxel value), 'softmax' (probability "
            "distribution summing to 1.0)."
        ),
        default=NormalizationMethod.WAYTOTAL,
    )
    tractography_relpath: str = Field(
        description=(
            "Relative path under each patient directory where per-patient "
            "tractography output lives. The atlas workflow expects "
            "`<patient>/<tractography_relpath>/<run>/warped_streamlines/"
            "fdt_paths.nii.gz` (+ `waytotal`). Override to "
            "'tractography/mrtrix3' for the MRtrix3 backend."
        ),
        default="tractography/probtrackx2",
    )
    output_subdir: str = Field(
        description=(
            "Subdirectory under the cohort output directory where the "
            "statistical atlas (atlas_mean.nii.gz et al.) is written. "
            "Override per backend (e.g. 'atlas_mrtrix3') when running "
            "probtrackx2 and mrtrix3 into the same outputs/ tree to "
            "avoid clobbering."
        ),
        default="atlas",
    )
[docs]
class S3Config(BaseConfig):
    """Configuration for S3 data download.

    Attributes:
        enabled: Whether S3 download is enabled.
        bucket: S3 bucket name.
        region: AWS region.
        prefix: Bucket prefix (e.g., 'HCP_1200').
        cache_policy: Download behavior when files exist locally; one of
            'skip_if_exists', 'check_size' or 'always'.
        max_retries: Maximum retry attempts for failed downloads (>= 0).
        retry_backoff: Exponential backoff multiplier for retries (>= 1.0).
        required_patterns: File patterns that must be downloaded.
        optional_patterns: File patterns that should be downloaded if available.
    """

    enabled: bool = Field(
        default=False,
        description="Enable automatic S3 data download for HCP workflows",
    )
    bucket: str = Field(
        default="hcp-openaccess",
        description="S3 bucket name containing HCP data",
    )
    region: str = Field(
        default="us-east-1",
        description="AWS region where the S3 bucket is located",
    )
    prefix: str = Field(
        default="HCP_1200",
        description="S3 key prefix (folder path within bucket)",
    )
    cache_policy: str = Field(
        default="skip_if_exists",
        description=(
            "Cache behavior: 'skip_if_exists' (don't re-download), "
            "'check_size' (re-download if size differs), "
            "'always' (always re-download)"
        ),
    )
    max_retries: int = Field(
        default=3,
        ge=0,
        description="Maximum number of retry attempts for failed downloads",
    )
    retry_backoff: float = Field(
        default=2.0,
        ge=1.0,
        description="Exponential backoff multiplier for retry delays",
    )
    required_patterns: List[str] = Field(
        default=[
            "T1w/Diffusion/data.nii*",
            "T1w/Diffusion/bvals",
            "T1w/Diffusion/bvecs",
            "T1w/Diffusion/nodif_brain_mask.nii*",
            "T1w/Diffusion.bedpostX/merged_*samples.nii.gz",
            "T1w/T1w_acpc_dc_restore_1.25.nii.gz",
        ],
        description="File patterns (glob) that must be downloaded for workflow to succeed",
    )
    optional_patterns: List[str] = Field(
        default=[
            "T1w/T1w_acpc_dc_restore_brain.nii.gz",
            "T1w/brainmask_fs.nii.gz",
        ],
        description="File patterns (glob) to download if available (won't fail if missing)",
    )

    @field_validator("cache_policy")
    @classmethod
    def validate_cache_policy(cls, v: str) -> str:
        """Validate that cache policy is one of the allowed values.

        Raises:
            ValueError: If ``v`` is not an allowed policy.
        """
        # An ordered list rather than a set. str hashing is randomised per
        # process, so a set's repr (and therefore this error message) could
        # list the values in a different order on every run. A list also
        # matches the other validators in this module.
        allowed = ["skip_if_exists", "check_size", "always"]
        if v not in allowed:
            raise ValueError(f"cache_policy must be one of {allowed}, got '{v}'")
        return v
[docs]
class PreprocessingConfig(BaseConfig):
    """Configuration for preprocessing pipeline.

    Attributes:
        denoise: Whether denoising is enabled.
        bias_correction: Whether bias correction is enabled.
        brain_extraction: Whether brain extraction is enabled.
        brain_extraction_method: Selected extraction backend; one of 'bet',
            'synthstrip' or 'ants'.
        motion_correction: Whether motion correction is enabled.
        eddy_correction: Whether eddy-current correction is enabled.
        topup: Whether TOPUP distortion correction is enabled.
        acq_params: Optional TOPUP acquisition parameters.
    """

    denoise: bool = Field(default=True, description="Whether to apply denoising")
    bias_correction: bool = Field(
        default=True, description="Whether to apply bias field correction"
    )
    brain_extraction: bool = Field(default=True, description="Whether to perform brain extraction")
    # The description lists every value accepted by validate_bet_method
    # ("ants" was previously missing from the schema text).
    brain_extraction_method: str = Field(
        default="synthstrip", description="Brain extraction method (bet, synthstrip, ants)"
    )
    motion_correction: bool = Field(default=True, description="Whether to apply motion correction")
    eddy_correction: bool = Field(
        default=True, description="Whether to apply eddy current correction"
    )
    topup: bool = Field(default=True, description="Whether to run TOPUP for distortion correction")
    acq_params: Optional[Dict[str, Any]] = Field(
        default=None, description="Acquisition parameters for TOPUP"
    )

    @field_validator("brain_extraction_method")
    @classmethod
    def validate_bet_method(cls, v):
        """Validate brain extraction method.

        Raises:
            ValueError: If ``v`` is not 'bet', 'synthstrip' or 'ants'.
        """
        valid_methods = ["bet", "synthstrip", "ants"]
        if v not in valid_methods:
            raise ValueError(f"brain_extraction_method must be one of {valid_methods}")
        return v
[docs]
class RegistrationViewerConfig(BaseConfig):
    """Options controlling how the registration QC viewer is launched.

    Attributes:
        enabled: Turns registration QC viewer support on or off.
        backend: Viewer implementation, either 'fsleyes' or 'html'.
        auto_open: Launch the viewer automatically once a run finishes.
        overlay_opacity: Opacity in [0, 1] of the transformed patient image
            overlay.
    """

    enabled: bool = Field(
        description="Whether registration QC viewer support is enabled",
        default=True,
    )
    backend: str = Field(description="Viewer backend name", default="fsleyes")
    auto_open: bool = Field(description="Launch viewer automatically after run", default=False)
    overlay_opacity: float = Field(
        description="Overlay opacity for the transformed patient image",
        default=0.5,
        ge=0.0,
        le=1.0,
    )

    @field_validator("backend")
    @classmethod
    def validate_backend(cls, v: str) -> str:
        """Accept only the supported viewer backends; raise ValueError otherwise."""
        supported_backends = ["fsleyes", "html"]
        if v in supported_backends:
            return v
        raise ValueError(f"backend must be one of {supported_backends}")
[docs]
class FireantsRegistrationConfig(BaseConfig):
    """Configuration for the FireANTs registration backend.

    Attributes:
        device: Torch device string used by FireANTs.
        scales: Multi-resolution registration scales.
        affine_iterations: Iterations per scale for affine registration.
        deformable_iterations: Iterations per scale for deformable registration.
        optimizer: FireANTs optimizer name ('Adam' or 'SGD').
        affine_lr: Learning rate for affine registration.
        deformable_lr: Learning rate for deformable registration.
        cc_kernel_size: Cross-correlation kernel size.
        loss_type: Similarity metric for the staged registration (cc/mi/mse/
            fusedcc/fusedmi); auto-fused on CUDA.
        normalize: Min-max normalize intensities to [0, 1] before registration.
        deformation_type: FireANTs deformation model.
        dtype: Torch dtype name ('float16', 'float32' or 'bfloat16').
        do_moments: Run the MomentsRegistration initialization stage.
        do_rigid: Run the RigidRegistration stage before affine/deformable.
        moments_scale: Downsampling scale for the moments stage.
        moments_moments: Number of moments used by the moments stage.
        rigid_iterations: Iterations per scale for rigid registration.
        rigid_lr: Learning rate for rigid registration.
        deform_algo: Deformable engine for SyN-family transforms ('greedy'
            or 'syn').
        driver: How the staged registration is driven ('inprocess' or 'cli').
        deformable_max_spacing_mm: Optional spacing cap (mm) for the images
            fed to the deformable step; ``None`` means no cap.
        inverse_method: How the reverse warp is produced ('simpleitk' or
            'fireants').
        inverse_max_iterations: Max fixed-point iterations for the SimpleITK
            inversion.
        inverse_tolerance: Mean-error tolerance (mm) for the SimpleITK
            inversion.

    All per-scale iteration lists must have the same length as ``scales``.
    """

    device: str = Field(default="cuda", description="Torch device used by FireANTs")
    scales: List[int] = Field(
        default=[4, 2, 1],
        description="Multi-resolution scales used by FireANTs",
    )
    affine_iterations: List[int] = Field(
        default=[200, 100, 50],
        description="Affine iterations per FireANTs scale",
    )
    deformable_iterations: List[int] = Field(
        default=[200, 100, 25],
        description="Deformable iterations per FireANTs scale",
    )
    optimizer: str = Field(default="Adam", description="FireANTs optimizer name")
    affine_lr: float = Field(default=3e-3, gt=0.0, description="Affine optimizer learning rate")
    deformable_lr: float = Field(
        default=0.5,
        gt=0.0,
        description="Deformable optimizer learning rate",
    )
    cc_kernel_size: int = Field(
        default=5,
        ge=1,
        description="Cross-correlation kernel size",
    )
    loss_type: str = Field(
        default="cc",
        description=(
            "Similarity metric for the rigid/affine/deformable stages: 'cc' "
            "(local normalized cross-correlation, same-modality), 'mi' (mutual "
            "information, robust to contrast differences / cross-dataset), 'mse', "
            "or the GPU-fused variants 'fusedcc'/'fusedmi'. On a CUDA device "
            "'cc'/'mi' are auto-upgraded to their fused equivalents; on CPU the "
            "fused variants fall back to non-fused."
        ),
    )
    normalize: bool = Field(
        default=True,
        description=(
            "Min-max normalize fixed and moving intensities to [0, 1] before "
            "registration (mirrors FireANTs' own template pipeline). Strongly "
            "recommended when the two images come from different sources / "
            "intensity ranges (e.g. MNI vs a study template)."
        ),
    )
    deformation_type: str = Field(
        default="compositive",
        description="FireANTs deformation type",
    )
    dtype: str = Field(default="float32", description="Torch dtype used by FireANTs")
    do_moments: bool = Field(
        default=True,
        description=(
            "Run the FireANTs MomentsRegistration initialization stage "
            "(center-of-mass + principal-axis alignment) before rigid/affine."
        ),
    )
    do_rigid: bool = Field(
        default=True,
        description="Run the FireANTs RigidRegistration stage before affine/deformable.",
    )
    moments_scale: int = Field(
        default=4,
        ge=1,
        description=(
            "Downsampling scale used by the FireANTs MomentsRegistration stage. "
            "FireANTs recommends 4 (the moments init is robust at low resolution)."
        ),
    )
    moments_moments: int = Field(
        default=1,
        ge=1,
        description=(
            "Number of moments used by the FireANTs MomentsRegistration stage. "
            "FireANTs recommends 1 (center-of-mass + first-order); 2 adds a "
            "second-order principal-axis term that is prone to 180-degree flips "
            "on near-symmetric brains."
        ),
    )
    rigid_iterations: List[int] = Field(
        default=[200, 100, 25],
        description="Rigid iterations per FireANTs scale.",
    )
    rigid_lr: float = Field(
        default=3e-2,
        gt=0.0,
        description="Rigid optimizer learning rate.",
    )
    deform_algo: str = Field(
        default="greedy",
        description=(
            "Deformable engine used for SyN-family transforms: 'greedy' (default, "
            "yields an analytic inverse warp) or 'syn' (no analytic-inverse export)."
        ),
    )
    driver: str = Field(
        default="inprocess",
        description=(
            "How the FireANTs staged registration is driven: 'inprocess' (Python "
            "API in the Nipype Function node) or 'cli' (shell out to the "
            "thesis-fireants-register console script)."
        ),
    )
    deformable_max_spacing_mm: Optional[float] = Field(
        default=None,
        gt=0.0,
        description=(
            "Cap the spacing (mm) of the fixed/moving images fed to the deformable "
            "step. When the template has finer spacing than this cap, both images "
            "are resampled to it before FireANTs runs. Default 'null' = no cap "
            "(deformable runs at the template's native resolution; best quality, "
            "but the template's full grid must fit on the GPU). Set to e.g. 1.0 "
            "to trade quality for VRAM headroom."
        ),
    )
    inverse_method: str = Field(
        default="simpleitk",
        description=(
            "How the reverse (template->patient) warp is produced for SyN-family "
            "transforms: 'simpleitk' (default) numerically inverts the forward "
            "displacement field with SimpleITK's InvertDisplacementFieldImageFilter "
            "— ANTs-free, has an explicit convergence tolerance, and works for both "
            "'greedy' and 'syn'; 'fireants' uses FireANTs' own inverse-consistency "
            "solve (greedy only; 'syn' has no analytic inverse). 'simpleitk' is "
            "recommended: it inverts the high-quality forward field directly instead "
            "of re-optimizing an under-constrained reverse warp."
        ),
    )
    inverse_max_iterations: int = Field(
        default=50,
        ge=1,
        description=(
            "Maximum fixed-point iterations for the SimpleITK displacement-field "
            "inversion (inverse_method='simpleitk')."
        ),
    )
    inverse_tolerance: float = Field(
        default=0.01,
        gt=0.0,
        description=(
            "Mean-error tolerance (in displacement units, i.e. mm) for the SimpleITK "
            "displacement-field inversion. Lower = tighter inverse, more iterations."
        ),
    )

    @field_validator("deform_algo")
    @classmethod
    def validate_deform_algo(cls, v: str) -> str:
        """Validate the deformable engine.

        This field was previously the only enumerated string field here
        without a validator, so a typo (e.g. 'gredy') passed silently.

        Raises:
            ValueError: If ``v`` is not 'greedy' or 'syn'.
        """
        valid_algos = ["greedy", "syn"]
        if v not in valid_algos:
            raise ValueError(f"deform_algo must be one of {valid_algos}")
        return v

    @field_validator("inverse_method")
    @classmethod
    def validate_inverse_method(cls, v: str) -> str:
        """Validate the reverse-warp reconstruction method."""
        valid_methods = ["simpleitk", "fireants"]
        if v not in valid_methods:
            raise ValueError(f"inverse_method must be one of {valid_methods}")
        return v

    @field_validator("driver")
    @classmethod
    def validate_driver(cls, v: str) -> str:
        """Validate the FireANTs driver."""
        valid_drivers = ["inprocess", "cli"]
        if v not in valid_drivers:
            raise ValueError(f"driver must be one of {valid_drivers}")
        return v

    @field_validator("optimizer")
    @classmethod
    def validate_optimizer(cls, v: str) -> str:
        """Validate the FireANTs optimizer."""
        valid_optimizers = ["Adam", "SGD"]
        if v not in valid_optimizers:
            raise ValueError(f"optimizer must be one of {valid_optimizers}")
        return v

    @field_validator("dtype")
    @classmethod
    def validate_dtype(cls, v: str) -> str:
        """Validate the FireANTs torch dtype."""
        valid_dtypes = ["float16", "float32", "bfloat16"]
        if v not in valid_dtypes:
            raise ValueError(f"dtype must be one of {valid_dtypes}")
        return v

    @field_validator("loss_type")
    @classmethod
    def validate_loss_type(cls, v: str) -> str:
        """Validate the FireANTs similarity metric."""
        valid_loss_types = ["cc", "mi", "mse", "fusedcc", "fusedmi"]
        if v not in valid_loss_types:
            raise ValueError(f"loss_type must be one of {valid_loss_types}")
        return v

    @model_validator(mode="after")
    def validate_scale_lengths(self) -> "FireantsRegistrationConfig":
        """Ensure per-scale iteration lists match the configured scales.

        Checks run in the order affine, deformable, rigid, so the first
        mismatching list is the one reported.
        """
        n_scales = len(self.scales)
        for field_name in ("affine_iterations", "deformable_iterations", "rigid_iterations"):
            if len(getattr(self, field_name)) != n_scales:
                raise ValueError(f"{field_name} must have the same length as scales")
        return self
[docs]
class RegistrationJobConfig(BaseConfig):
    """Configuration for a single named registration job.

    A job overrides the shared ``RegistrationConfig`` defaults for one
    patient-to-template registration. All override fields default to ``None``
    (meaning "inherit the shared value"); only ``name`` is required.

    The ``fireants`` field, when set, is a sparse dict merged onto
    ``registration.fireants``. Pydantic's ``model_copy(update=...)`` does NOT
    validate its ``update`` mapping. Unknown keys are therefore rejected here,
    at config-load time; this validator does not check the override values.

    Attributes:
        name: Unique job name used in node/output naming and referenced by
            ``transforms.jobs[*].from_registration``.
        method: Optional registration-backend override.
        moving_modality: Optional moving-image modality override.
        moving_image: Optional explicit moving-image path override.
        fixed_image: Optional fixed-image path override.
        interpolation: Optional interpolation-mode override.
        metric: Optional similarity-metric override.
        transform_type: Optional transform-family override.
        use_float: Optional float-precision override.
        fireants: Optional sparse FireANTs override block.
    """

    name: str = Field(description="Unique registration job name used in node/output naming")
    method: Optional[str] = Field(default=None, description="Registration backend override")
    moving_modality: Optional[str] = Field(
        default=None, description="Moving-image modality override (t1, t2)"
    )
    moving_image: Optional[str] = Field(
        default=None, description="Explicit moving-image path override"
    )
    fixed_image: Optional[str] = Field(default=None, description="Fixed template image override")
    interpolation: Optional[str] = Field(default=None, description="Interpolation method override")
    metric: Optional[str] = Field(default=None, description="Similarity metric override")
    transform_type: Optional[str] = Field(
        default=None, description="Transform family override (Rigid, Affine, SyN)"
    )
    use_float: Optional[bool] = Field(
        default=None, description="Float-precision override for registration"
    )
    fireants: Optional[Dict[str, Any]] = Field(
        default=None,
        description=(
            "Sparse FireANTs override block merged onto registration.fireants. "
            "Only the keys present here override the shared values; keys must be "
            "FireantsRegistrationConfig field names."
        ),
    )

    @field_validator("method")
    @classmethod
    def validate_method(cls, v: Optional[str]) -> Optional[str]:
        """Validate the per-job registration method (None inherits the shared value)."""
        if v is not None and v not in REGISTRATION_METHODS:
            raise ValueError(f"method must be one of {REGISTRATION_METHODS}")
        return v

    @field_validator("moving_modality")
    @classmethod
    def validate_moving_modality(cls, v: Optional[str]) -> Optional[str]:
        """Validate the per-job moving-image modality (None inherits the shared value)."""
        if v is not None and v not in REGISTRATION_MOVING_MODALITIES:
            raise ValueError(f"moving_modality must be one of {REGISTRATION_MOVING_MODALITIES}")
        return v

    @field_validator("interpolation")
    @classmethod
    def validate_interpolation(cls, v: Optional[str]) -> Optional[str]:
        """Validate the per-job interpolation method (None inherits the shared value)."""
        if v is not None and v not in REGISTRATION_INTERPOLATIONS:
            raise ValueError(f"interpolation must be one of {REGISTRATION_INTERPOLATIONS}")
        return v

    @field_validator("transform_type")
    @classmethod
    def validate_transform_type(cls, v: Optional[str]) -> Optional[str]:
        """Validate the per-job transform family (None inherits the shared value).

        Uses the shared ``REGISTRATION_TRANSFORM_TYPES`` set, which exists so
        the job and shared registration models accept the same values.
        """
        if v is not None and v not in REGISTRATION_TRANSFORM_TYPES:
            raise ValueError(f"transform_type must be one of {REGISTRATION_TRANSFORM_TYPES}")
        return v

    @field_validator("fireants")
    @classmethod
    def validate_fireants_keys(
        cls, v: Optional[Dict[str, Any]]
    ) -> Optional[Dict[str, Any]]:
        """Reject override keys that are not ``FireantsRegistrationConfig`` fields.

        ``model_copy(update=...)`` applies its mapping without validation, so a
        misspelt key would otherwise be accepted silently. This restores the
        ``extra="forbid"`` typo protection that ``BaseConfig`` gives every other
        model. Only key names are checked here, not values.

        Raises:
            ValueError: If any key is not a FireANTs config field.
        """
        if v is None:
            return v
        known = set(FireantsRegistrationConfig.model_fields)
        # Sorted so the error message is stable across runs.
        unknown = sorted(set(v) - known)
        if unknown:
            raise ValueError(
                "fireants override keys must be FireantsRegistrationConfig fields; "
                f"unknown: {unknown}; allowed: {sorted(known)}"
            )
        return v
[docs]
class RegistrationConfig(BaseConfig):
    """Configuration for image registration.

    Attributes:
        enabled: Whether the registration workflow is enabled.
        method: Registration backend (ants, fireants, fsl, dipy).
        moving_modality: Structural modality used as the moving image.
        moving_image: Optional explicit moving-image override.
        fixed_image: Template-space fixed image.
        interpolation: Interpolation mode (one of ``REGISTRATION_INTERPOLATIONS``).
        metric: Similarity metric.
        transform_type: Transform family (Rigid, Affine, SyN).
        use_float: Whether float precision is used.
        collapse_output_transforms: Whether ANTs should collapse output transforms.
        write_composite_transform: Whether ANTs should emit composite transforms.
        output_subdir: Output subdirectory under the patient output dir.
        fireants: FireANTs backend settings.
        viewer: Registration QC viewer settings.
        jobs: Named registration jobs. When empty, the top-level fields act as
            a single implicit job named 'patient_to_template'. Job names must
            be unique.

    The shared value sets (``REGISTRATION_*``) are the ones the per-job
    ``RegistrationJobConfig`` validates against, so a value accepted here is
    also accepted as a job override.
    """

    enabled: bool = Field(default=True, description="Whether registration is enabled")
    method: str = Field(
        default="ants", description="Registration method (ants, fireants, fsl, dipy)"
    )
    moving_modality: str = Field(
        default="t1", description="Structural modality for the moving image (t1, t2)"
    )
    moving_image: Optional[str] = Field(
        default=None,
        description="Optional explicit moving-image path override",
    )
    fixed_image: str = Field(default="", description="Fixed template image path")
    interpolation: str = Field(default="Linear", description="Interpolation method")
    metric: str = Field(default="MI", description="Similarity metric (MI, CC, etc.)")
    transform_type: str = Field(default="SyN", description="Type of transform (Rigid, Affine, SyN)")
    use_float: bool = Field(default=True, description="Use float precision for registration")
    collapse_output_transforms: bool = Field(
        default=True,
        description="Collapse ANTs output transforms where possible",
    )
    write_composite_transform: bool = Field(
        default=True,
        description="Write ANTs composite transform outputs",
    )
    output_subdir: str = Field(
        default="registration",
        description="Patient output subdirectory for registration artifacts",
    )
    fireants: FireantsRegistrationConfig = Field(
        default_factory=FireantsRegistrationConfig,
        description="FireANTs backend settings",
    )
    viewer: RegistrationViewerConfig = Field(
        default_factory=RegistrationViewerConfig,
        description="Registration QC viewer settings",
    )
    jobs: List["RegistrationJobConfig"] = Field(
        default_factory=list,
        description=(
            "Named registration jobs. When empty, the top-level fields above are "
            "treated as a single implicit job named 'patient_to_template'. Each "
            "job may override the shared defaults and supply a sparse 'fireants' "
            "block merged onto registration.fireants."
        ),
    )

    @field_validator("method")
    @classmethod
    def validate_method(cls, v):
        """Validate registration method."""
        if v not in REGISTRATION_METHODS:
            raise ValueError(f"method must be one of {REGISTRATION_METHODS}")
        return v

    @field_validator("moving_modality")
    @classmethod
    def validate_moving_modality(cls, v: str) -> str:
        """Validate the selected moving-image modality."""
        if v not in REGISTRATION_MOVING_MODALITIES:
            raise ValueError(f"moving_modality must be one of {REGISTRATION_MOVING_MODALITIES}")
        return v

    @field_validator("interpolation")
    @classmethod
    def validate_interpolation(cls, v: str) -> str:
        """Validate the shared interpolation method.

        Job overrides were already checked against this set while the shared
        value was not, so the two models had drifted.
        """
        if v not in REGISTRATION_INTERPOLATIONS:
            raise ValueError(f"interpolation must be one of {REGISTRATION_INTERPOLATIONS}")
        return v

    @field_validator("transform_type")
    @classmethod
    def validate_transform_type(cls, v: str) -> str:
        """Validate the shared transform family against the shared allowed set."""
        if v not in REGISTRATION_TRANSFORM_TYPES:
            raise ValueError(f"transform_type must be one of {REGISTRATION_TRANSFORM_TYPES}")
        return v

    @model_validator(mode="after")
    def _validate_unique_job_names(self) -> "RegistrationConfig":
        """Ensure all registration job names are unique.

        Raises:
            ValueError: Listing every duplicated name (sorted, for a stable
                message).
        """
        seen = set()
        duplicates = set()
        # Single O(n) pass instead of calling list.count per name.
        for job in self.jobs:
            if job.name in seen:
                duplicates.add(job.name)
            seen.add(job.name)
        if duplicates:
            raise ValueError(
                f"registration.jobs names must be unique; duplicates: {sorted(duplicates)}"
            )
        return self
[docs]
class SegmentationConfig(BaseConfig):
    """Options for tissue segmentation.

    Attributes:
        method: Segmentation backend ('synthseg', 'fast' or 'freesurfer').
        tissue_types: Tissue classes to segment.
        create_masks: Emit binary masks for the segmented tissues.
        labels: Optional mapping from label name to integer value.
    """

    method: str = Field(description="Segmentation method", default="synthseg")
    tissue_types: List[str] = Field(
        description="Tissue types to segment",
        default=["GM", "WM", "CSF"],
    )
    create_masks: bool = Field(description="Whether to create binary masks", default=True)
    labels: Optional[Dict[str, int]] = Field(
        description="Label mapping (name -> value)",
        default=None,
    )

    @field_validator("method")
    @classmethod
    def validate_method(cls, v):
        """Accept only the supported segmentation backends; raise ValueError otherwise."""
        supported = ["synthseg", "fast", "freesurfer"]
        if v in supported:
            return v
        raise ValueError(f"method must be one of {supported}")
[docs]
class AtlasSourceConfig(BaseConfig):
    """One atlas label-map source used to derive ROIs.

    Attributes:
        name: Unique source name, used in node and output naming.
        roi_file: Path of the atlas image in template space.
        transform: Optional name of an entry in
            ``transforms.atlas_transforms`` to apply to this atlas.
        label_file: Optional label CSV; required whenever ``label_name``
            entries are used.
        waypoint_labels: ROI-extraction mapping for this source (empty by
            default).
    """

    name: str = Field(description="Unique atlas source name used in node/output naming")
    roi_file: str = Field(description="Template-space atlas image path")
    transform: Optional[str] = Field(
        description="Optional named transform from transforms.atlas_transforms to use for this atlas",
        default=None,
    )
    label_file: Optional[str] = Field(
        description="Optional atlas label CSV path; required when label_name entries are used",
        default=None,
    )
    waypoint_labels: Dict[str, Any] = Field(
        description="ROI label configuration for this atlas source",
        default_factory=dict,
    )
[docs]
class TractographyConfig(BaseConfig):
    """Configuration for tractography.

    Holds parameters for the FSL ProbTrackX2 backend and for the MRtrix3
    (``tckgen``) backend. The MRtrix3-specific fields, from
    ``tckgen_algorithm`` onward, are ignored when ``method != "mrtrix3"``.

    Attributes:
        method: Tractography backend; one of ``probtrackx2``, ``tckgen``,
            ``dsi_studio`` or ``mrtrix3``.
        run_tractography: Optional workflow-level execution toggle.
        n_samples: Streamlines per seed voxel (>= 100).
        n_steps: Maximum steps per streamline (>= 100).
        step_length: Step length in millimetres, within [0.1, 2.0].
        curvature_threshold: Maximum curvature threshold, within [0, 1].
        loop_check: Whether loop checks are performed on paths.
        dist_thresh: Samples shorter than this length (mm) are discarded;
            0 disables the limit.
        fibst: Starting fibre; only effective when ``rand_fib == 0``.
        rand_fib: Fibre sampling mode (0-3).
        mod_euler: Whether modified Euler streamlining is used.
        hemisphere: Which hemisphere(s) to run tractography for
            (``left``, ``right``, ``both`` or ``both-separately``).
        mem_gb_gpu: Nipype scheduler memory hint (GB) for GPU ProbTrackX2 nodes.
        mem_gb_cpu: Nipype scheduler memory hint (GB) for CPU ProbTrackX2 nodes.
        gpu_runtime_env: Environment variables injected only into the GPU
            ProbTrackX2 subprocess. Keys should carry a ``SINGULARITYENV_`` or
            ``APPTAINERENV_`` prefix; unprefixed keys trigger a warning.
        use_waypoints: Whether waypoint masks are used.
        waypoint_files: Optional mapping of names to waypoint mask files.
        roi_files: Optional template-space ROI files to transform.
        seed_roi: Optional seed ROI path.
        waypoint_masks: Optional list of waypoint mask files.
        atlas_sources: Optional atlas ROI source list.
        roi_labels: Legacy single-atlas ROI configuration.
        synthseg_roi_labels: Optional SynthSeg-derived ROI configuration.
        force_dir: Whether ProbTrackX2 is passed ``--forcedir``.
        opd: Whether ProbTrackX2 is passed ``--opd`` (output path
            distribution).
        tckgen_algorithm: ``tckgen`` tracking algorithm.
        tckgen_select: ``tckgen -select`` target streamline count.
        tckgen_seeds: ``tckgen -seeds`` seed-attempt cap; 0 means unlimited.
        tckgen_minlength: ``tckgen -minlength`` in mm.
        tckgen_maxlength: ``tckgen -maxlength`` in mm.
        tckgen_backtrack: Whether ``tckgen -backtrack`` is used.
        tckgen_crop_at_gmwmi: Whether ``tckgen -crop_at_gmwmi`` is used.
        tckgen_cutoff: Optional ``tckgen -cutoff``; None keeps the MRtrix
            default.
        response_algorithm: ``dwi2response`` algorithm.
        fod_algorithm: ``dwi2fod`` algorithm.
        use_mtnormalise: Whether ``mtnormalise`` runs after ``dwi2fod``.
        fivett_algorithm: ``5ttgen`` backend.
        mask_source: Brain-mask source (``fsl_nodif`` or ``dwi2mask``).
        mask_dilate_voxels: Number of dilation passes applied to the brain mask.
        use_sift2: Whether ``tcksift2`` weights streamlines for ``tckmap``.
        seed_strategy: ``tckgen`` seed source (``roi`` or ``gmwmi``).
        gmwmi_apply_roi_filters: Whether configured ROI masks still filter
            streamlines when ``seed_strategy == "gmwmi"``.
    """

    method: str = Field(default="probtrackx2", description="Tractography method")
    run_tractography: Optional[bool] = Field(
        default=None, description="Whether to run tractography in HCP workflow"
    )
    n_samples: int = Field(default=5000, ge=100, description="Number of streamlines per voxel")
    n_steps: int = Field(default=2000, ge=100, description="Maximum number of steps")
    step_length: float = Field(default=0.5, ge=0.1, le=2.0, description="Step length in mm")
    curvature_threshold: float = Field(
        default=0.2, ge=0.0, le=1.0, description="Curvature threshold (0-1)"
    )
    loop_check: bool = Field(
        default=True,
        description="Perform loop checks on paths (slower, but allows lower curvature threshold)",
    )
    dist_thresh: float = Field(
        default=0.0,
        ge=0.0,
        description="Discard samples shorter than this threshold in mm (0 = no limit)",
    )
    fibst: int = Field(
        default=1,
        ge=1,
        description="Starting fibre for tracking (only effective when rand_fib=0)",
    )
    rand_fib: int = Field(
        default=0,
        ge=0,
        le=3,
        description=(
            "Fibre sampling mode: 0=default, 1=random (f>thresh), "
            "2=proportional (f>thresh), 3=all populations"
        ),
    )
    mod_euler: bool = Field(
        default=False,
        description="Use modified Euler streamlining",
    )
    hemisphere: str = Field(
        default="both",
        description=(
            "Which hemisphere(s) to run tractography for: "
            "left, right, both (merged), or both-separately (two independent runs)"
        ),
    )
    mem_gb_gpu: float = Field(
        default=8.0,
        gt=0,
        description="Nipype scheduler memory hint in GB for GPU ProbTrackX2 nodes",
    )
    mem_gb_cpu: float = Field(
        default=4.0,
        gt=0,
        description="Nipype scheduler memory hint in GB for CPU ProbTrackX2 nodes",
    )
    gpu_runtime_env: Optional[Dict[str, str]] = Field(
        default=None,
        description=(
            "Environment variables injected ONLY into the GPU probtrackx2 "
            "subprocess (the Nipype CommandLine node's ``.inputs.environ``), and "
            "thus only into its ``singularity exec --cleanenv`` FSL container. "
            "Keys MUST carry a SINGULARITYENV_/APPTAINERENV_ prefix to survive "
            "``--cleanenv`` and reach inside the container. Never applied to the "
            "ANTs or SynthSeg nodes, and ignored for CPU runs. Example: "
            "{'SINGULARITYENV_LD_PRELOAD': '/path/HAMi-core/build/libvgpu.so', "
            "'SINGULARITYENV_CUDA_DEVICE_MEMORY_LIMIT': '11g'} to cap each "
            "probtrackx2 to 11 GB so several co-locate on one GPU without OOM."
        ),
    )
    use_waypoints: bool = Field(default=True, description="Whether to use waypoint masks")
    waypoint_files: Optional[Dict[str, Path]] = Field(
        default=None, description="Paths to waypoint mask files"
    )
    roi_files: Optional[List[Path]] = Field(
        default=None, description="Template-space ROI files to transform"
    )
    seed_roi: Optional[Path] = Field(
        default=None, description="Seed ROI (patient or template space depending on workflow)"
    )
    waypoint_masks: Optional[List[Path]] = Field(default=None, description="Waypoint mask files")
    atlas_sources: Optional[List[AtlasSourceConfig]] = Field(
        default=None,
        description="Template-space atlas sources, each with its own atlas image and labels",
    )
    roi_labels: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Label-map ROI configuration (label_file, roi_file, waypoint_labels)",
    )
    synthseg_roi_labels: Optional[Dict[str, Any]] = Field(
        default=None,
        description="SynthSeg-sourced ROI label configuration (waypoint_labels with label_values)",
    )
    force_dir: bool = Field(default=True, description="ProbTrackX2 --forcedir flag")
    opd: bool = Field(default=True, description="ProbTrackX2 --opd flag")
    # MRtrix3-specific tckgen parameters (ignored when method != "mrtrix3")
    tckgen_algorithm: str = Field(
        default="iFOD2",
        description="tckgen tracking algorithm (iFOD2, iFOD1, SD_STREAM, Tensor_Det, ...)",
    )
    tckgen_select: int = Field(
        default=5000,
        ge=1,
        description="tckgen target streamline count (-select)",
    )
    tckgen_seeds: int = Field(
        default=5_000_000,
        # 0 must be accepted: the description (and MRtrix3 ``tckgen -seeds``)
        # define 0 as "unlimited". The previous ``ge=1`` rejected that value.
        # NOTE(review): confirm the MRtrix3 workflow forwards 0 verbatim as
        # ``-seeds 0``.
        ge=0,
        description="tckgen seed attempt cap (-seeds); 0 means unlimited",
    )
    tckgen_minlength: float = Field(
        default=30.0,
        ge=0.0,
        description="tckgen minimum streamline length in mm (-minlength)",
    )
    tckgen_maxlength: float = Field(
        default=250.0,
        gt=0.0,
        description="tckgen maximum streamline length in mm (-maxlength)",
    )
    tckgen_backtrack: bool = Field(
        default=True,
        description="tckgen -backtrack flag (improves ACT-based tracking)",
    )
    tckgen_crop_at_gmwmi: bool = Field(
        default=True,
        description="tckgen -crop_at_gmwmi flag (precise endpoint placement)",
    )
    tckgen_cutoff: Optional[float] = Field(
        default=None,
        ge=0.0,
        description="tckgen FOD amplitude cutoff (-cutoff). None uses MRtrix default.",
    )
    response_algorithm: str = Field(
        default="dhollander",
        description="dwi2response algorithm (dhollander | tournier | msmt_5tt | fa | manual)",
    )
    fod_algorithm: str = Field(
        default="msmt_csd",
        description="dwi2fod algorithm (msmt_csd for multi-shell, csd for single-shell)",
    )
    use_mtnormalise: bool = Field(
        default=True,
        description="Run mtnormalise after dwi2fod to correct intensity bias across tissues",
    )
    fivett_algorithm: str = Field(
        default="fsl",
        description="5ttgen backend (fsl | freesurfer | hsvs)",
    )
    mask_source: str = Field(
        default="fsl_nodif",
        description=(
            "Brain-mask source: 'fsl_nodif' reuses the HCP nodif_brain_mask, "
            "'dwi2mask' generates one from the DWI."
        ),
    )
    mask_dilate_voxels: int = Field(
        default=1,
        ge=0,
        description="Number of dilation passes applied to the brain mask (MRtrix ACT recommends 1)",
    )
    use_sift2: bool = Field(
        default=True,
        description="Run tcksift2 after tckgen to weight streamlines for tckmap",
    )
    seed_strategy: str = Field(
        default="roi",
        description=(
            "tckgen seed source (MRtrix3 only): 'roi' seeds from the atlas/"
            "synthseg seed mask (default); 'gmwmi' seeds from the GM-WM interface "
            "via -seed_gmwmi (requires ACT, which is always enabled here)."
        ),
    )
    gmwmi_apply_roi_filters: bool = Field(
        default=True,
        description=(
            "When seed_strategy='gmwmi', still apply any configured waypoint/avoid/"
            "stop/target masks as tckgen -include/-exclude/-mask filters. False "
            "forces pure whole-brain tracking. Ignored when seed_strategy='roi'."
        ),
    )

    @field_validator("method")
    @classmethod
    def validate_method(cls, v):
        """Validate tractography method."""
        valid_methods = ["probtrackx2", "tckgen", "dsi_studio", "mrtrix3"]
        if v not in valid_methods:
            raise ValueError(f"method must be one of {valid_methods}")
        return v

    @field_validator("gpu_runtime_env")
    @classmethod
    def warn_unprefixed_gpu_runtime_env(
        cls, v: Optional[Dict[str, str]]
    ) -> Optional[Dict[str, str]]:
        """Warn on keys that will not survive ``singularity exec --cleanenv``.

        Only ``SINGULARITYENV_``/``APPTAINERENV_``-prefixed variables are
        forwarded into the container past ``--cleanenv``. A bare key (e.g.
        ``LD_PRELOAD``) would be set in the wrapper's host shell and then
        silently stripped, so it would never reach probtrackx2. That is a
        confusing no-op, so it is surfaced here rather than left to fail at
        runtime. The value is returned unchanged; this validator only warns.
        """
        if v:
            bad = [k for k in v if not k.startswith(("SINGULARITYENV_", "APPTAINERENV_"))]
            if bad:
                # NOTE(review): the "{}" placeholder assumes get_logger returns a
                # brace-style logger (e.g. loguru). Confirm this; stdlib logging
                # would need "%s".
                logger.warning(
                    "tractography.gpu_runtime_env keys {} lack a "
                    "SINGULARITYENV_/APPTAINERENV_ prefix; they will NOT survive "
                    "`singularity exec --cleanenv` and will not reach the "
                    "probtrackx2 container.",
                    bad,
                )
        return v

    @field_validator("response_algorithm")
    @classmethod
    def validate_response_algorithm(cls, v: str) -> str:
        """Validate dwi2response algorithm."""
        valid = ("dhollander", "tournier", "msmt_5tt", "fa", "manual")
        if v not in valid:
            raise ValueError(f"response_algorithm must be one of {valid}")
        return v

    @field_validator("fod_algorithm")
    @classmethod
    def validate_fod_algorithm(cls, v: str) -> str:
        """Validate dwi2fod algorithm."""
        valid = ("msmt_csd", "csd")
        if v not in valid:
            raise ValueError(f"fod_algorithm must be one of {valid}")
        return v

    @field_validator("fivett_algorithm")
    @classmethod
    def validate_fivett_algorithm(cls, v: str) -> str:
        """Validate 5ttgen backend."""
        valid = ("fsl", "freesurfer", "hsvs")
        if v not in valid:
            raise ValueError(f"fivett_algorithm must be one of {valid}")
        return v

    @field_validator("mask_source")
    @classmethod
    def validate_mask_source(cls, v: str) -> str:
        """Validate brain-mask source."""
        valid = ("fsl_nodif", "dwi2mask")
        if v not in valid:
            raise ValueError(f"mask_source must be one of {valid}")
        return v

    @field_validator("seed_strategy")
    @classmethod
    def validate_seed_strategy(cls, v: str) -> str:
        """Validate tckgen seed source."""
        valid = ("roi", "gmwmi")
        if v not in valid:
            raise ValueError(f"seed_strategy must be one of {valid}")
        return v

    @field_validator("hemisphere")
    @classmethod
    def validate_hemisphere(cls, v: str) -> str:
        """Validate hemisphere value."""
        valid = ("left", "right", "both", "both-separately")
        if v not in valid:
            raise ValueError(f"hemisphere must be one of {valid}")
        return v
[docs]
class HCPConfig(BaseConfig):
    """Locations and parameters for HCP-style preprocessed subject data.

    The directory and image fields (and ``mask_path``) are strings
    interpreted relative to the subject input directory.

    Attributes:
        diffusion_dir: Subdirectory that holds the diffusion data.
        bedpostx_dir: BedpostX output directory.
        t1_image: T1-weighted image.
        t2_image: Optional T2-weighted image.
        n_fibers: Number of fibres modelled by BedpostX.
        mask_name: Filename of the default diffusion brain mask.
        mask_path: Optional full mask path that overrides the default.
    """

    diffusion_dir: str = Field(
        description="Diffusion data subdirectory relative to input_dir",
        default="T1w/Diffusion",
    )
    bedpostx_dir: str = Field(
        description="BedpostX output directory relative to input_dir",
        default="T1w/Diffusion.bedpostX",
    )
    t1_image: str = Field(
        description="T1 image path relative to input_dir",
        default="T1w/T1w_acpc_dc_restore_1.25.nii",
    )
    t2_image: Optional[str] = Field(
        description="Optional T2 image path relative to input_dir",
        default=None,
    )
    n_fibers: int = Field(description="Number of fibers in BedpostX", default=3)
    mask_name: str = Field(description="Brain mask filename", default="nodif_brain_mask.nii")
    mask_path: Optional[str] = Field(
        description="Brain mask path relative to input_dir (overrides default)",
        default=None,
    )
[docs]
class ValidationConfig(BaseConfig):
    """Thresholds for the post-transformation ROI sanity checks.

    Attributes:
        check_rois: If True, warped ROI masks are validated after
            transformation (off by default).
        min_voxels: Smallest acceptable non-zero voxel count for an ROI.
        singularity_threshold: Smallest acceptable absolute determinant of
            the affine's 3x3 rotation/scale block.
        volume_ratio_min: Lower bound on the warped-to-original voxel ratio,
            within (0, 1].
        volume_ratio_max: Upper bound on the warped-to-original voxel ratio,
            at least 1.
    """

    check_rois: bool = Field(
        description="Run automated validation on warped ROI masks after transformation",
        default=False,
    )
    min_voxels: int = Field(
        description="Minimum non-zero voxel count; fewer voxels raises PipelineError",
        ge=1,
        default=10,
    )
    singularity_threshold: float = Field(
        description=(
            "Minimum absolute determinant of the affine 3x3 "
            "rotation/scale block; below this the ROI is rejected as "
            "degenerate (prevents probtrackx2 crash)"
        ),
        gt=0,
        default=1e-6,
    )
    volume_ratio_min: float = Field(
        description=(
            "Lower bound for warped-vs-original voxel count "
            "ratio; below this a warning is emitted"
        ),
        gt=0.0,
        le=1.0,
        default=0.5,
    )
    volume_ratio_max: float = Field(
        description=(
            "Upper bound for warped-vs-original voxel count "
            "ratio; above this a warning is emitted"
        ),
        ge=1.0,
        default=1.5,
    )
[docs]
class SynthSegConfig(BaseConfig):
    """Options for running SynthSeg, either standalone or embedded.

    Attributes:
        t1_image: Optional explicit T1-weighted input; strings are coerced to
            ``Path`` by ``_convert_t1_to_path``.
        parc: Request cortical parcellation output.
        robust: Enable robust mode.
        fast: Enable fast mode.
        vol: Write a volumes CSV.
        qc: Write a QC CSV.
        crop: Optional crop size (>= 1).
        cpu: Force CPU execution.
        threads: CPU thread count used when ``cpu`` is true.
        gpu_runtime_env: Environment variables injected only into the GPU
            SynthSeg subprocess; ignored in CPU mode.
    """

    t1_image: Optional[Path] = Field(description="Input T1 image override", default=None)
    parc: bool = Field(description="Request parcellation output", default=False)
    robust: bool = Field(description="Enable robust mode", default=False)
    fast: bool = Field(description="Enable fast mode", default=False)
    vol: bool = Field(description="Write volumes CSV", default=False)
    qc: bool = Field(description="Write QC CSV", default=False)
    crop: Optional[int] = Field(description="Optional crop size", ge=1, default=None)
    cpu: bool = Field(description="Force CPU execution", default=False)
    threads: int = Field(description="CPU thread count when cpu=true", ge=1, default=1)
    gpu_runtime_env: Optional[Dict[str, str]] = Field(
        description=(
            "Environment variables injected ONLY into the GPU SynthSeg subprocess "
            "(the mri_synthseg node's ``.inputs.environ``), and thus only into its "
            "Singularity exec / FreeSurfer container — never the probtrackx2 or ANTs "
            "containers. Ignored in CPU mode. Used to make the container's TensorFlow "
            "see the GPU when the image lacks the CUDA runtime: e.g. {'SINGULARITY_NV': "
            "'1', 'APPTAINER_NV': '1', 'SINGULARITYENV_LD_LIBRARY_PATH': "
            "'/path/cuda118/lib:/.singularity.d/libs', 'SINGULARITYENV_TF_FORCE_GPU_"
            "ALLOW_GROWTH': 'true'} (bare SINGULARITY_NV/APPTAINER_NV are read by the "
            "singularity CLI to enable --nv; SINGULARITYENV_/APPTAINERENV_ keys are "
            "forwarded inside the container past --cleanenv)."
        ),
        default=None,
    )

    @field_validator("t1_image", mode="before")
    @classmethod
    def _convert_t1_to_path(cls, v: Any) -> Optional[Path]:
        """Coerce the configured T1 image into a ``Path`` via ``to_path``."""
        converted = to_path(v)
        return cast(Optional[Path], converted)
[docs]
class NipypeConfig(BaseConfig):
    """Execution settings handed to Nipype workflows.

    Attributes:
        working_dir: Base working directory (defaults to ``work``).
        crash_dir: Optional crash-dump directory.
        plugin: Nipype execution plugin; must be one of the names accepted by
            ``validate_plugin``.
        plugin_args: Arguments for the execution plugin.
        stop_on_first_crash: Stop after the first crash.
        remove_unnecessary_outputs: Clean up intermediate outputs.
        keep_inputs: Keep node input files in the working directory.
        hash_method: Hash strategy used for caching.
        use_profiler: Enable the Nipype profiler.
    """

    working_dir: Path = Field(description="Base working directory", default=Path("work"))
    crash_dir: Optional[Path] = Field(description="Crash dump directory", default=None)
    plugin: str = Field(description="Nipype execution plugin", default="MultiProc")
    plugin_args: Dict[str, Any] = Field(description="Plugin arguments", default_factory=dict)
    stop_on_first_crash: bool = Field(description="Stop on first crash", default=False)
    remove_unnecessary_outputs: bool = Field(
        description="Clean intermediate outputs", default=True
    )
    keep_inputs: bool = Field(description="Keep input files in workdir", default=True)
    hash_method: str = Field(description="Hash method for caching", default="content")
    use_profiler: bool = Field(description="Enable Nipype profiler", default=False)

    @field_validator("working_dir", "crash_dir", mode="before")
    @classmethod
    def _convert_to_path(cls, v):
        """Coerce path-like input through ``to_path`` (expands ~ and env vars)."""
        converted = to_path(v)
        return converted

    @field_validator("plugin")
    @classmethod
    def validate_plugin(cls, v: str) -> str:
        """Accept only the supported Nipype execution plugins."""
        allowed = frozenset({"Linear", "MultiProc", "SGE", "PBS", "LSF", "SLURM", "SLURMGraph"})
        if v in allowed:
            return v
        raise ValueError(f"plugin must be one of {sorted(allowed)}, got '{v}'")
[docs]
class QCConfig(BaseConfig):
    """Configuration for QC visualisation outputs.

    Controls automatic generation of quality-control overlay images
    (ROI placement on anatomical backgrounds, track density maps on
    template images) after workflow execution, as well as extended
    QC checks (SynthSeg quality, outlier detection, etc.).

    Attributes:
        generate_overlays: Whether to produce ROI overlay PNGs at
            the end of the HCP workflow.
        track_density_thresholds: Percentile thresholds used when
            rendering track density figures (applied to non-zero
            voxels of `fdt_paths.nii.gz`). Each must lie strictly in
            (0, 100); the stored list is sorted ascending.
        synthseg_qc_threshold: Minimum acceptable SynthSeg QC score,
            constrained to [0, 1]. Subjects below this threshold are flagged.
        outlier_sd_threshold: Number of standard deviations (> 0) from
            the batch mean to flag a subject as an outlier.
    """

    generate_overlays: bool = Field(
        default=False,
        description="Generate ROI overlay PNGs after workflow execution",
    )
    track_density_thresholds: List[float] = Field(
        default=[50.0, 90.0, 99.0],
        description="Percentile thresholds for track density map visualisation",
    )
    # Bounded to [0, 1] so the constraint matches the documented score range.
    synthseg_qc_threshold: float = Field(
        default=0.6,
        ge=0.0,
        le=1.0,
        description="Minimum SynthSeg QC score (0-1) before flagging",
    )
    # Must be positive, matching AtlasQCConfig.outlier_sd_threshold; a zero or
    # negative SD multiple would flag every subject.
    outlier_sd_threshold: float = Field(
        default=2.0,
        gt=0.0,
        description="SD threshold for batch outlier detection",
    )

    @field_validator("track_density_thresholds")
    @classmethod
    def validate_thresholds(cls, v: List[float]) -> List[float]:
        """Validate that all thresholds are valid percentiles in (0, 100).

        Returns:
            The thresholds sorted in ascending order.

        Raises:
            ValueError: If any threshold is not strictly between 0 and 100.
        """
        for t in v:
            if not (0.0 < t < 100.0):
                raise ValueError(f"Each threshold must be between 0 and 100 (exclusive), got {t}")
        return sorted(v)
[docs]
class AtlasQCConfig(BaseConfig):
    """Settings for the QC outputs of the atlas workflow.

    Attributes:
        enabled: Turn atlas QC generation on (off by default).
        generate_group_plots: Request cohort-level atlas summary plots.
        generate_patient_reports: Request per-patient atlas QC outputs.
        outlier_sd_threshold: Standard deviations (> 0) from the cohort mean
            beyond which an atlas-derived value is flagged.
        subject_density_threshold: Minimum normalised streamline density,
            within [0, 1], for a subject voxel to count as present.
        group_core_threshold: Fraction of subjects, within [0, 1], that must
            have the voxel present for it to join the group core.
        leave_one_out: Compute leave-one-out statistics per subject.
        leave_one_out_min_subjects: Smallest cohort (>= 2) for which
            leave-one-out validation is run.
        compute_cv: Compute Coefficient of Variation (CV) maps.
        cv_mean_threshold: Fraction, within [0, 1], of the atlas mean maximum
            required before CV is computed for a voxel.
        compute_log_stats: Compute statistics in log space.
        log_offset: Positive offset added to zero values before the log.
        output_subdir: Subdirectory under the cohort output directory that
            receives the atlas QC artefacts.
    """

    enabled: bool = Field(description="Enable atlas QC generation", default=False)
    generate_group_plots: bool = Field(
        description="Generate cohort-level atlas QC plots",
        default=True,
    )
    generate_patient_reports: bool = Field(
        description="Generate per-patient atlas QC outputs",
        default=False,
    )
    outlier_sd_threshold: float = Field(
        description="SD threshold for atlas QC outlier detection",
        gt=0.0,
        default=2.0,
    )
    subject_density_threshold: float = Field(
        description=(
            "Minimum normalized streamline density for a subject voxel "
            "to be considered present"
        ),
        ge=0.0,
        le=1.0,
        default=0.01,
    )
    group_core_threshold: float = Field(
        description="Proportion of subjects required for a voxel to be part of the group core",
        ge=0.0,
        le=1.0,
        default=0.5,
    )
    leave_one_out: bool = Field(
        description="Compute leave-one-out statistics for each subject",
        default=False,
    )
    leave_one_out_min_subjects: int = Field(
        description="Minimum number of subjects required to run leave-one-out validation",
        ge=2,
        default=3,
    )
    compute_cv: bool = Field(
        description="Compute Coefficient of Variation (CV) maps",
        default=True,
    )
    cv_mean_threshold: float = Field(
        description="Fraction of atlas mean maximum required to compute CV",
        ge=0.0,
        le=1.0,
        default=0.01,
    )
    compute_log_stats: bool = Field(
        description="Compute statistics in log space",
        default=True,
    )
    log_offset: float = Field(
        description="Small offset added to zero values before log transformation",
        gt=0.0,
        default=1e-6,
    )
    output_subdir: str = Field(
        description=(
            "Subdirectory under the cohort output directory where atlas QC "
            "artefacts are written. Override per backend (e.g. "
            "'atlas_qc_mrtrix3') to coexist with another backend's QC "
            "under the same outputs/ tree."
        ),
        default="atlas_qc",
    )
[docs]
class SideThresholdConfig(BaseConfig):
    """Binarisation cutoff for one side (subject or atlas) of a tract comparison.

    Attributes:
        mode: ``"fraction"`` scales ``value`` by the volume maximum to obtain
            the cutoff; ``"absolute"`` uses ``value`` as a raw intensity.
        value: Positive threshold; additionally restricted to the open
            interval ``(0, 1)`` when ``mode == "fraction"``.
    """

    mode: Literal["fraction", "absolute"] = Field(
        description="Thresholding mode: 'fraction' of volume max, or 'absolute' voxel intensity.",
        default="fraction",
    )
    value: float = Field(
        description=(
            "Threshold value. Must be in (0, 1) when mode='fraction'; any positive "
            "number when mode='absolute'."
        ),
        gt=0.0,
        default=0.05,
    )

    @model_validator(mode="after")
    def _validate_fraction_range(self) -> "SideThresholdConfig":
        # Absolute thresholds only need the field-level ``gt=0`` constraint.
        inside_unit_interval = 0.0 < self.value < 1.0
        if self.mode != "fraction" or inside_unit_interval:
            return self
        raise ValueError("value must be in (0, 1) when mode='fraction'")
[docs]
class HcpLooConfig(BaseConfig):
    """Settings for the ``tract_similarity_hcp_loo`` workflow.

    Each HCP subject is compared against a cohort atlas built without that
    subject (leave-one-out). Two subjects are enough mathematically (the LOO
    atlas is then just the other subject), but the default floor of three
    keeps the reference meaningful.

    Attributes:
        minimum_subjects: Smallest cohort size (>= 2) for which the workflow
            runs.
        write_volumes: When True, four NIfTI volumes per subject
            (``subject_normalized.nii.gz``, ``atlas_normalized.nii.gz``,
            ``subject_mask.nii.gz``, ``atlas_mask.nii.gz``) are written next
            to ``metrics.json``; when False, only ``metrics.json`` is written.
        learned_prediction_relpath: Optional path, relative to each subject
            directory, of a learned-atlas prediction NIfTI in template space.
            An empty string disables the learned-template third arm.
        learned_support_threshold: Non-negative density threshold defining
            the held-out subject's true support for the learned-arm metrics.
    """

    minimum_subjects: int = Field(
        description=(
            "Minimum cohort size required to run the LOO comparison. "
            "Mathematically valid at 2; default 3 enforces a meaningful cohort."
        ),
        ge=2,
        default=3,
    )
    write_volumes: bool = Field(
        description=(
            "If True, write the four NIfTI volumes per subject alongside "
            "metrics.json. Set to False to skip them (saves disk on large cohorts)."
        ),
        default=True,
    )
    learned_prediction_relpath: str = Field(
        description=(
            "Optional relative path under each subject directory to a learned-"
            "atlas per-subject prediction NIfTI (template space). Empty disables "
            "the learned-template 3rd arm."
        ),
        default="",
    )
    learned_support_threshold: float = Field(
        description=(
            "Density threshold defining the held-out subject's true support for "
            "the non-circular learned-arm metrics."
        ),
        ge=0.0,
        default=0.0,
    )
[docs]
class TractSimilarityConfig(BaseConfig):
    """Configuration for the tract_similarity analysis workflow.

    Controls the per-patient comparison between the native-space probtrackx2
    tract density and the warped cohort mean atlas, plus the cohort-level
    aggregation of those per-patient metrics.

    Attributes:
        probtrackx_relpath: Relative path to the probtrackx2 output directory.
            Both the single-run layout and the hemisphere-split (``left/`` +
            ``right/``) layout are supported.
        fdt_name: Filename of the density volume inside each run directory
            (default ``fdt_paths.nii.gz``).
        waytotal_name: Filename of the ``waytotal`` text file inside each run
            directory.
        atlas_relpath: Relative path (or glob pattern) under the patient output
            directory that points at the warped atlas mean volume.
        subject_threshold: Binarisation threshold applied to the subject's
            probtrackx2 volume before overlap / distance metrics.
        atlas_threshold: Binarisation threshold applied to the warped atlas
            volume before overlap / distance metrics.
        n_bins: Histogram bin count (8-1024) for normalised mutual information.
        output_subdir: Subdirectory under the patient output directory where
            ``metrics.json`` is written.
        cohort_output_subdir: Subdirectory under the cohort output directory
            where the aggregated summary is written.
        hcp_loo: Settings for the per-HCP-subject leave-one-out comparison.
    """

    probtrackx_relpath: str = Field(
        default="tractography/probtrackx2",
        description=(
            "Relative path to the probtrackx2 output directory. Auto-detects "
            "single-run vs hemisphere-split (left/ + right/) layout and sums "
            "the waytotal-normalised volumes across hemispheres, matching how "
            "the atlas workflow builds its cohort stack."
        ),
    )
    fdt_name: str = Field(
        default="fdt_paths.nii.gz",
        description="Filename of the probtrackx2 density volume within each run dir",
    )
    waytotal_name: str = Field(
        default="waytotal",
        description="Filename of the probtrackx2 waytotal text file within each run dir",
    )
    atlas_relpath: str = Field(
        default="atlas_in_patient_space/atlas_mean*.nii.gz",
        description=(
            "Relative path or glob pattern for the warped cohort mean atlas "
            "in the patient's native space"
        ),
    )
    subject_threshold: SideThresholdConfig = Field(
        default_factory=SideThresholdConfig,
        description="Binarisation threshold for the subject's probtrackx2 volume.",
    )
    atlas_threshold: SideThresholdConfig = Field(
        default_factory=SideThresholdConfig,
        description="Binarisation threshold for the warped atlas volume.",
    )
    n_bins: int = Field(
        default=64,
        ge=8,
        le=1024,
        description="Number of histogram bins for normalised mutual information",
    )
    output_subdir: str = Field(
        default="tract_similarity",
        description="Patient-level output subdirectory for metrics.json",
    )
    cohort_output_subdir: str = Field(
        default="cohort/tract_similarity",
        description="Cohort-level output subdirectory for aggregated metrics",
    )
    hcp_loo: HcpLooConfig = Field(
        default_factory=HcpLooConfig,
        description=(
            "Per-HCP-subject leave-one-out comparison against the cohort atlas. "
            "See HcpLooConfig for details."
        ),
    )
[docs]
class ThresholdGridConfig(BaseConfig):
    """One binarisation-threshold axis of the similarity sweep.

    The axis is given either as ``start``/``stop``/``step`` (inclusive
    endpoint within float tolerance) or as an explicit ``values`` list. When
    ``values`` is set, it takes precedence and the range fields are not
    cross-checked. The sweep works in ``"fraction"`` mode, so every explicit
    value and both range endpoints must lie in ``(0, 1)``.
    """

    start: Optional[float] = Field(gt=0.0, lt=1.0, default=0.05)
    stop: Optional[float] = Field(gt=0.0, lt=1.0, default=0.50)
    step: Optional[float] = Field(gt=0.0, default=0.025)
    values: Optional[List[float]] = Field(default=None)

    @model_validator(mode="after")
    def _one_form(self) -> "ThresholdGridConfig":
        explicit = self.values
        if explicit is not None:
            # all() is vacuously True for [], so emptiness is checked separately.
            in_range = all(0.0 < x < 1.0 for x in explicit)
            if not explicit or not in_range:
                raise ValueError("values must be a non-empty list of floats in (0, 1)")
            return self

        start, stop, step = self.start, self.stop, self.step
        if start is None or stop is None or step is None:
            raise ValueError("specify values OR start/stop/step")
        if stop <= start:
            raise ValueError("stop must be > start")
        return self
[docs]
class TractSimilaritySweepConfig(BaseConfig):
    """Cohort grid search behind the ``tract_similarity_sweep`` workflow.

    Both ``subject_threshold.value`` and ``atlas_threshold.value`` are swept in
    ``"fraction"`` mode. The winning grid cell is the one that maximises the
    selected cohort aggregation (mean or median) of Dice.

    Attributes:
        subject_threshold_grid: Grid for the subject-side threshold.
        atlas_threshold_grid: Grid for the atlas-side threshold.
        aggregation: Cohort statistic used to rank grid cells.
        output_subdir: Output subdirectory under the cohort output root.
        emit_heatmap: Whether a PNG heatmap of the aggregated grid is written.
    """

    subject_threshold_grid: ThresholdGridConfig = Field(
        description="Grid over subject_threshold.value (fraction-of-max).",
        default_factory=ThresholdGridConfig,
    )
    atlas_threshold_grid: ThresholdGridConfig = Field(
        description="Grid over atlas_threshold.value (fraction-of-max).",
        default_factory=ThresholdGridConfig,
    )
    aggregation: Literal["mean", "median"] = Field(
        description="Cohort aggregation used to pick the best grid cell.",
        default="mean",
    )
    output_subdir: str = Field(
        description="Sweep output subdirectory under the cohort output root.",
        default="tract_similarity_sweep",
    )
    emit_heatmap: bool = Field(
        description="Emit a matplotlib heatmap (PNG) of the aggregated grid.",
        default=True,
    )
[docs]
class OutputSettingsConfig(BaseConfig):
    """Configuration for CLI output behavior (YAML-level defaults).

    CLI flags ``-v``, ``-q``, ``--summary``, ``--no-progress`` override
    values set here at runtime. Every field is restricted to a small closed
    set of values, enforced by the validators below.

    Attributes:
        verbosity: Default verbosity level (``"quiet"``, ``"normal"``, ``"verbose"``).
        summary: Summary detail (``"off"``, ``"compact"``, ``"full"``).
        progress: Whether to show progress bars/spinners.
            ``"auto"`` enables progress for TTY and disables for pipes/CI.
        output_format: Output format (``"human"``, ``"json"``).
    """

    verbosity: str = Field(
        default="normal",
        description="Default verbosity: quiet | normal | verbose",
    )
    summary: str = Field(
        default="compact",
        description="Summary detail: off | compact | full",
    )
    progress: str = Field(
        default="auto",
        description="Progress UI: auto | on | off",
    )
    output_format: str = Field(
        default="human",
        description="Output format: human | json",
    )

    @field_validator("verbosity")
    @classmethod
    def validate_verbosity(cls, v: str) -> str:
        """Validate verbosity value."""
        valid = ("quiet", "normal", "verbose")
        if v not in valid:
            raise ValueError(f"verbosity must be one of {valid}")
        return v

    @field_validator("summary")
    @classmethod
    def validate_summary(cls, v: str) -> str:
        """Validate summary value."""
        valid = ("off", "compact", "full")
        if v not in valid:
            raise ValueError(f"summary must be one of {valid}")
        return v

    @field_validator("progress")
    @classmethod
    def validate_progress(cls, v: str) -> str:
        """Validate progress value."""
        valid = ("auto", "on", "off")
        if v not in valid:
            raise ValueError(f"progress must be one of {valid}")
        return v

    @field_validator("output_format")
    @classmethod
    def validate_output_format(cls, v: str) -> str:
        """Validate output format value.

        ``output_format`` is documented as ``human | json`` but previously had
        no validator, unlike its sibling fields, so a typo such as ``"jsno"``
        was silently accepted.

        Raises:
            ValueError: If ``v`` is not one of the supported formats.
        """
        valid = ("human", "json")
        if v not in valid:
            raise ValueError(f"output_format must be one of {valid}")
        return v
class SLURMConfig(BaseConfig):
"""Configuration for SLURM-based grouped stage submission.
When ``enabled`` is ``False`` (the default) the SLURM submit path is never
taken and pipeline behavior is identical to local execution. The walltime
regex validator, plugin allow-list, and per-group sbatch directive
derivation are added in a later phase; this model only needs to *exist* so
that ``cfg.slurm.enabled`` never raises ``AttributeError``.
Attributes:
enabled: Whether grouped SLURM submission is used.
account: SLURM account/charge string (``--account``).
cpu_partition: Partition for CPU-only stage groups.
gpu_partition: Default (short) GPU partition.
gpu_partition_long: Long-running GPU partition (``None`` to disable).
gpu_short_max_minutes: ``est_minutes`` above this route to the long
GPU partition.
walltime_cpu: Walltime (``HH:MM:SS``) for CPU groups.
walltime_gpu: Walltime (``HH:MM:SS``) for short GPU groups.
walltime_gpu_long: Walltime (``HH:MM:SS``) for long GPU groups.
max_concurrent_gpu_jobs: In-flight GPU job cap (matches physical GPUs).
max_jobs_per_wave: Maximum in-flight jobs per submission wave.
submit_dir: Shared-FS directory for sbatch scripts / logs / manifest.
apptainer_image: Path to the Apptainer image used by stage jobs.
env_setup: Path to an environment-setup script sourced before runs.
sbatch_extra_args: Extra sbatch args keyed by compute class
(e.g. ``{"cpu": [...], "gpu": [...]}``).
"""
enabled: bool = Field(default=False, description="Enable grouped SLURM submission")
account: str = Field(default="", description="SLURM account/charge string")
cpu_partition: str = Field(default="CPU_short", description="Partition for CPU-only groups")
gpu_partition: str = Field(default="GPU_short", description="Default (short) GPU partition")
gpu_partition_long: Optional[str] = Field(
default="GPU_long", description="Long-running GPU partition (None to disable)"
)
gpu_short_max_minutes: int = Field(
default=120, description="est_minutes above this route to the long GPU partition"
)
walltime_cpu: str = Field(default="04:00:00", description="Walltime for CPU groups")
walltime_gpu: str = Field(default="04:00:00", description="Walltime for short GPU groups")
walltime_gpu_long: str = Field(default="24:00:00", description="Walltime for long GPU groups")
max_concurrent_gpu_jobs: int = Field(
default=6, description="In-flight GPU job cap (matches physical GPUs)"
)
max_jobs_per_wave: int = Field(
default=1800, ge=1, le=2000, description="Maximum in-flight jobs per submission wave"
)
submit_dir: Optional[Path] = Field(
default=None, description="Shared-FS directory for sbatch scripts, logs, and manifest"
)
apptainer_image: Optional[Path] = Field(
default=None, description="Path to the Apptainer image used by stage jobs"
)
env_setup: Optional[Path] = Field(
default=None, description="Path to an environment-setup script sourced before runs"
)
sbatch_extra_args: Dict[str, List[str]] = Field(
default_factory=dict,
description="Extra sbatch args keyed by compute class (e.g. cpu/gpu)",
)
@field_validator("submit_dir", "apptainer_image", "env_setup", mode="before")
@classmethod
def _convert_to_path(cls, v):
    """Coerce raw input for the optional path fields before type validation.

    Runs in ``before`` mode for ``submit_dir``, ``apptainer_image`` and
    ``env_setup``, delegating to :func:`to_path` (which turns string paths
    into ``Path`` objects, expanding ``~`` and environment variables).

    Args:
        v: The raw value supplied for one of the path fields.

    Returns:
        Whatever :func:`to_path` produces for ``v``.
    """
    converted = to_path(v)
    return converted
@field_validator("walltime_cpu", "walltime_gpu", "walltime_gpu_long")
@classmethod
def _validate_walltime(cls, v: str) -> str:
    """Validate that a SLURM walltime has an ``H+:MM:SS`` shape.

    Args:
        v: Walltime string from the config, e.g. ``"04:00:00"``.

    Returns:
        ``v`` unchanged when it is valid.

    Raises:
        ValueError: If ``v`` is not one or more hour digits followed by
            two-digit minutes and two-digit seconds, with nothing else
            (including no trailing newline).
    """
    # ``re.fullmatch`` instead of ``re.match(r"^...$")``: ``$`` also matches
    # just before a trailing newline, so a value such as "04:00:00\n" (easy to
    # produce with a YAML block scalar) would otherwise be accepted with the
    # newline still attached.
    if re.fullmatch(r"\d+:\d{2}:\d{2}", v) is None:
        raise ValueError(f"walltime must match H+:MM:SS (e.g. '04:00:00'), got '{v}'")
    return v
[docs]
class PipelineConfig(BaseModel):
"""Root configuration model for thesis pipelines.
Aggregates all sub-configurations and provides convenience methods.
Attributes:
paths: Filesystem paths.
hardware: Compute resources.
atlas: Atlas generation settings.
s3: S3 data download settings (optional).
preprocessing: Generic preprocessing steps.
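preprocess: Workflow-specific raw-to-HCP preprocessing settings.
mrtrix3: Workflow-specific settings for the MRtrix3 workflow.
registration: Image registration.
segmentation: Brain segmentation.
tractography: Tractography parameters.
hcp: HCP-specific overrides.
transforms: Pre-computed transform settings.
nipype: Nipype execution settings.
slurm: Grouped SLURM submission settings.
validation: ROI validation settings.
qc: QC generation settings.
atlas_qc: Atlas QC generation settings.
synthseg: Optional SynthSeg-specific overrides.
output: CLI output defaults.
tract_similarity: Tract similarity analysis settings (per-patient and cohort aggregation).
tract_similarity_sweep: Grid-search sweep over tract_similarity binarization thresholds.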
patient_id: Optional patient identifier.
protocol: Optional protocol name.
"""
model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")
paths: PathConfig = Field(default_factory=PathConfig)
hardware: HardwareConfig = Field(default_factory=HardwareConfig)
atlas: AtlasConfig = Field(default_factory=AtlasConfig)
s3: Optional[S3Config] = Field(
default=None,
description="S3 data download configuration (optional, for cloud deployments)",
)
preprocessing: PreprocessingConfig = Field(default_factory=PreprocessingConfig)
preprocess: Dict[str, Any] = Field(
default_factory=dict,
description="Workflow-specific config for the preprocess workflow",
)
mrtrix3: Dict[str, Any] = Field(
default_factory=dict,
description="Workflow-specific config for the mrtrix3 workflow",
)
registration: RegistrationConfig = Field(default_factory=RegistrationConfig)
segmentation: SegmentationConfig = Field(default_factory=SegmentationConfig)
tractography: TractographyConfig = Field(default_factory=TractographyConfig)
hcp: HCPConfig = Field(default_factory=HCPConfig)
transforms: TransformsConfig = Field(default_factory=TransformsConfig)
nipype: NipypeConfig = Field(default_factory=NipypeConfig)
slurm: SLURMConfig = Field(default_factory=SLURMConfig)
validation: ValidationConfig = Field(default_factory=ValidationConfig)
qc: QCConfig = Field(
default_factory=QCConfig,
description="QC visualisation settings (ROI overlays, track density figures)",
)
atlas_qc: AtlasQCConfig = Field(
default_factory=AtlasQCConfig,
description="Atlas QC settings for cohort-level and per-patient atlas outputs",
)
synthseg: Optional[SynthSegConfig] = Field(
default=None,
description="SynthSeg-specific config overrides (e.g. t1_image)",
)
output: OutputSettingsConfig = Field(
default_factory=OutputSettingsConfig,
description="CLI output behavior defaults (verbosity, summary, progress, format)",
)
tract_similarity: TractSimilarityConfig = Field(
default_factory=TractSimilarityConfig,
description="Tract similarity analysis settings (per-patient and cohort aggregation)",
)
tract_similarity_sweep: TractSimilaritySweepConfig = Field(
default_factory=TractSimilaritySweepConfig,
description="Grid-search sweep over tract_similarity binarization thresholds.",
)
patient_id: Optional[str] = Field(default=None, description="Patient identifier")
protocol: Optional[str] = Field(default=None, description="Acquisition protocol")
@model_validator(mode="before")
@classmethod
def _validate_workflow_namespaces(cls, data: Any) -> Any:
    """Validate workflow-registered namespaces and reject unknown top-level keys.

    ``extra="allow"`` on PipelineConfig lets unknown top-level keys reach
    this validator. Every key that is not a static field is looked up in
    :data:`NAMESPACE_REGISTRY`:

    * a registered workflow namespace is validated against its schema
      (``None`` is treated as ``{}``), and the typed instance replaces the
      raw value so the after-validator can reuse it without a second parse;
    * a key with no registered namespace raises ``ValueError``, preserving
      the historical typo-catching behaviour of ``extra="forbid"``.

    Args:
        data: Raw model input. Non-dict input is returned unchanged.

    Returns:
        A shallow copy of ``data`` in which registered namespaces hold schema
        instances. The caller's dict is left untouched.

    Raises:
        ValueError: If a top-level key is neither a static field nor a
            registered workflow namespace.
    """
    if not isinstance(data, dict):
        return data
    from thesis.core.config.namespace_registry import NAMESPACE_REGISTRY

    known_static = set(cls.model_fields.keys())
    # ``data`` is the caller's own dict (e.g. the argument to ``from_dict``),
    # so write the typed instances into a copy instead of mutating it.
    validated = dict(data)
    for key, raw in data.items():
        if key in known_static:
            continue
        schema = NAMESPACE_REGISTRY.get(key)
        if schema is None:
            # f-string (not ``"..." + key``) so a non-string key, e.g. an
            # integer YAML key, yields this ValueError rather than a TypeError.
            raise ValueError(
                f"Unknown top-level config key '{key}'. "
                f"Known fields: {sorted(known_static)}. "
                f"Registered workflow namespaces: {NAMESPACE_REGISTRY.list()}. "
                f"If '{key}' is a workflow namespace, make sure the workflow "
                "module is imported before validating the config."
            )
        # Store the typed instance directly. ``extra="allow"`` +
        # ``arbitrary_types_allowed=True`` pass it through core validation
        # unchanged, so the after-validator can reuse it.
        validated[key] = schema.model_validate(raw if raw is not None else {})
    return validated
@model_validator(mode="after")
def _instantiate_workflow_namespaces(self) -> "PipelineConfig":
    """Ensure each registered workflow namespace holds a typed schema instance.

    Walks :data:`NAMESPACE_REGISTRY`, skipping names that are static fields
    of this model. For every other namespace the current value (read via
    ``getattr``, defaulting to ``None``) is handled as follows:

    * ``None`` (namespace absent) -> replaced by ``schema()``;
    * already an instance of ``schema`` -> left alone (the normal case,
      since the before-validator stores typed instances);
    * a plain ``dict`` -> replaced by ``schema.model_validate(value)``;
    * anything else -> left alone.

    This lets workflow bodies use attribute access such as
    ``config.my_wf.threshold``.

    NOTE(review): replacements are written with ``object.__setattr__``,
    bypassing pydantic's ``__setattr__``; whether they are reflected in
    ``model_dump()`` / ``to_dict()`` depends on how pydantic stores extra
    fields — confirm before relying on it.

    Returns:
        ``self``, as an ``after`` model validator must.
    """
    from thesis.core.config.namespace_registry import NAMESPACE_REGISTRY

    static_fields = set(type(self).model_fields.keys())
    for namespace, schema in NAMESPACE_REGISTRY.items():
        if namespace in static_fields:
            continue
        current = getattr(self, namespace, None)
        if current is None:
            replacement = schema()
        elif isinstance(current, schema):
            # Already typed by the before-validator; nothing to do.
            continue
        elif isinstance(current, dict):
            # Defensive fallback for a raw dict that did not pass through the
            # before-validator.
            replacement = schema.model_validate(current)
        else:
            continue
        object.__setattr__(self, namespace, replacement)
    return self
@model_validator(mode="after")
def _validate_from_registration_references(self) -> "PipelineConfig":
    """Check that every ``transforms.jobs[*].from_registration`` names a real job.

    Valid targets are the ``name`` of each entry in ``registration.jobs``.
    When that list is empty, the only valid target is the implicit
    ``patient_to_template`` job synthesized from the top-level registration
    fields. Transform jobs whose ``from_registration`` is ``None`` are
    ignored.

    Returns:
        ``self`` unchanged.

    Raises:
        ValueError: If any reference does not match a known registration job.
    """
    references = set()
    for transform_job in self.transforms.jobs:
        target = transform_job.from_registration
        if target is not None:
            references.add(target)
    if not references:
        return self

    registration_jobs = self.registration.jobs
    available = (
        {reg_job.name for reg_job in registration_jobs}
        if registration_jobs
        else {"patient_to_template"}
    )
    missing = references - available
    if not missing:
        return self
    raise ValueError(
        "transforms.jobs[*].from_registration references unknown registration "
        f"job name(s): {sorted(missing)}. Known registration job(s): {sorted(available)}."
    )
@model_validator(mode="after")
def _validate_slurm_plugin_consistency(self) -> "PipelineConfig":
    """Forbid Nipype's ``SLURMGraph`` plugin while grouped SLURM submission is on.

    Grouped submission coalesces nodes into resource-homogeneous stage jobs,
    whereas ``SLURMGraph`` submits one job per node and so defeats the
    job-count budget. The error message points users to ``MultiProc`` or
    ``Linear`` instead.

    Returns:
        ``self`` unchanged when the combination is allowed.

    Raises:
        ValueError: If ``slurm.enabled`` is true and ``nipype.plugin`` is
            ``"SLURMGraph"``.
    """
    if not self.slurm.enabled or self.nipype.plugin != "SLURMGraph":
        return self
    raise ValueError(
        "nipype.plugin='SLURMGraph' is incompatible with slurm.enabled=true; "
        "grouped SLURM submission handles job emission, so set "
        "nipype.plugin to 'MultiProc' (or 'Linear')."
    )
[docs]
@classmethod
def from_dict(cls, config_dict: Dict[str, Any]) -> "PipelineConfig":
    """Build a validated PipelineConfig from a plain dictionary.

    A domain-named alias for ``model_validate``. The ``cast`` only narrows
    the static return type; it performs no runtime conversion.

    Args:
        config_dict: Mapping of configuration values.

    Returns:
        The validated PipelineConfig (an instance of ``cls``).

    Raises:
        pydantic.ValidationError: If ``config_dict`` fails validation.

    Example:
        >>> config = PipelineConfig.from_dict({"hardware": {"threads": 8}})
    """
    validated = cls.model_validate(config_dict)
    return cast("PipelineConfig", validated)
[docs]
def to_dict(self) -> Dict[str, Any]:
    """Serialise this config to a plain dictionary.

    Uses ``model_dump`` in ``"python"`` mode with ``exclude_none=False``, so
    ``None``-valued fields are kept and Python objects such as ``Path`` are
    not converted to JSON-compatible types.

    Returns:
        Dictionary representation of the config.

    Example:
        >>> config_dict = config.to_dict()
    """
    dumped = self.model_dump(mode="python", exclude_none=False)
    return cast(Dict[str, Any], dumped)
[docs]
def merge_with(self, other: Union["PipelineConfig", Dict[str, Any]]) -> "PipelineConfig":
    """Return a new config that combines this one with ``other``.

    ``self`` is serialised with :meth:`to_dict`. ``other`` is serialised the
    same way when it is a PipelineConfig, and is otherwise used as-is. The
    two dictionaries are combined by ``merge_configs`` (from ``.loaders``),
    with ``self``'s dict passed first, and the result is re-validated
    through ``from_dict`` on this instance's class.

    NOTE(review): when ``other`` is a PipelineConfig its full dump,
    including default values, is handed to ``merge_configs``; if that
    helper lets the second argument win, ``other``'s defaults will
    override customised values in ``self`` — confirm this is intended.

    Args:
        other: Another PipelineConfig or a dict of overrides.

    Returns:
        A new, validated PipelineConfig.

    Example:
        >>> merged = config.merge_with({"hardware": {"threads": 16}})
    """
    from .loaders import merge_configs

    base = self.to_dict()
    overrides = other.to_dict() if isinstance(other, PipelineConfig) else other
    combined = merge_configs(base, overrides)
    return self.__class__.from_dict(combined)