Source code for thesis.workflows.tract_similarity.workflow

"""Tract similarity analysis Nipype workflows.

Two workflows are registered here:

- ``tract_similarity`` — per-patient. Loads the patient's native-space
  probtrackx2 output and the warped cohort mean atlas, computes four families
  of similarity metrics (overlap, voxelwise correlation, spatial distance,
  distribution similarity), and writes ``metrics.json`` alongside the four
  NIfTI volumes used to compute those metrics:
  ``subject_normalized.nii.gz`` (waytotal-normalized probtrackx2 density),
  ``atlas_normalized.nii.gz`` (normalized warped atlas),
  ``subject_mask.nii.gz`` (thresholded subject binary mask), and
  ``atlas_mask.nii.gz`` (thresholded atlas binary mask).
- ``tract_similarity_cohort`` — cohort-level. Aggregates every patient's
  ``metrics.json`` into ``summary.csv``, ``per_patient.csv``, and
  ``outliers.json``.

Prerequisites (verified pre-flight):

- ``hcp`` workflow has produced ``tractography/probtrackx2/fdt_paths.nii.gz``
  and ``waytotal`` for the patient (per-patient mode).
- ``atlas`` -> ``atlas_to_patient`` has warped a cohort mean map into the
  patient's native space (per-patient mode).
"""

from __future__ import annotations

from pathlib import Path
from typing import Any, List

import nipype.pipeline.engine as pe
from nipype.interfaces.utility import Function, IdentityInterface

from thesis.core.config import PipelineConfig
from thesis.core.context import ProcessingContext
from thesis.core.decorators import verify, workflow
from thesis.core.logging import get_logger
from thesis.workflows.tract_similarity._io import _discover_run_dirs, discover_patient_dirs

logger = get_logger(__name__)

# Input port names of the per-patient Function node. ``build_workflow`` passes
# this list as ``input_names``; the entries mirror the parameters of
# ``_compute_similarity_task`` one-for-one. ``_ready_anchor`` carries no data:
# the task body never reads it, it only exists so ``entry_gate.ready`` has a
# port to connect to.
_SIM_INPUTS = [
    "patient_output",
    "patient_id",
    "probtrackx_relpath",
    "fdt_name",
    "waytotal_name",
    "atlas_relpath",
    "output_subdir",
    "subject_threshold_mode",
    "subject_threshold_value",
    "atlas_threshold_mode",
    "atlas_threshold_value",
    "n_bins",
    "_ready_anchor",
]

# Input port names of the cohort Function node. ``build_cohort_workflow``
# passes this list as ``input_names``; the entries mirror the parameters of
# ``_aggregate_cohort_task``.
_COHORT_INPUTS = [
    "input_dir",
    "cohort_output_dir",
    "output_subdir",
    "outlier_sd_threshold",
]

# Names exported by ``from <this module> import *``.
__all__ = [
    "build_workflow",
    "build_cohort_workflow",
    "verify_requirements",
    "verify_cohort_requirements",
    "apply_threshold",
]


def apply_threshold(vol: Any, mode: str, value: float) -> float:
    """Compute the binarisation cutoff for a volume.

    Shared by the per-patient workflow and the threshold sweep so the cutoff
    formula has a single source of truth.

    Args:
        vol: Array-like volume. Only inspected when ``mode="fraction"``.
        mode: ``"fraction"`` returns ``value * max(volume)``;
            ``"absolute"`` returns ``value`` unchanged.
        value: Fraction of the peak (fraction mode) or the cutoff itself
            (absolute mode).

    Returns:
        The cutoff as a Python ``float``.

    Raises:
        ValueError: If *mode* is neither ``"fraction"`` nor ``"absolute"``.

    In fraction mode the peak is floored at ``1e-12`` so an all-zero volume
    still yields a strictly positive cutoff. NaN voxels are ignored when
    looking for the peak, and a zero-size or all-NaN volume is treated like an
    all-zero one. Without this, a single NaN would turn the cutoff into NaN
    (and every ``vol > cutoff`` comparison into ``False``), and a zero-size
    array would make ``np.max`` raise.
    """
    import numpy as np

    if mode == "absolute":
        return float(value)
    if mode == "fraction":
        data = np.asarray(vol)
        # Boolean indexing flattens the selection and drops the NaN voxels.
        usable = data[~np.isnan(data)]
        peak = float(usable.max()) if usable.size else 0.0
        return float(value) * max(peak, 1e-12)
    raise ValueError(f"unknown threshold mode: {mode!r}")
# ---------------------------------------------------------------------------
# Per-patient task body (runs in a Nipype subprocess — all imports local)
# ---------------------------------------------------------------------------


def _compute_similarity_task(
    patient_output: str,
    patient_id: str,
    probtrackx_relpath: str,
    fdt_name: str,
    waytotal_name: str,
    atlas_relpath: str,
    output_subdir: str,
    subject_threshold_mode: str,
    subject_threshold_value: float,
    atlas_threshold_mode: str,
    atlas_threshold_value: float,
    n_bins: int,
    _ready_anchor: object = None,
) -> str:
    """Per-patient Nipype Function body.

    Loads both volumes, computes the four metric families, and writes
    ``metrics.json`` plus four NIfTI volumes (``subject_normalized``,
    ``atlas_normalized``, ``subject_mask``, ``atlas_mask``).

    Args:
        patient_output: Patient output directory; every other path is
            resolved relative to it.
        patient_id: Identifier stored in ``metrics.json``.
        probtrackx_relpath: Location of the probtrackx2 output below
            *patient_output*.
        fdt_name: Forwarded to ``load_probtrackx_volume``.
        waytotal_name: Forwarded to ``load_probtrackx_volume``.
        atlas_relpath: Forwarded to ``resolve_atlas_file`` together with
            *patient_output*.
        output_subdir: Directory below *patient_output* that receives the
            results; created if missing.
        subject_threshold_mode: ``apply_threshold`` mode for the subject map.
        subject_threshold_value: ``apply_threshold`` value for the subject map.
        atlas_threshold_mode: ``apply_threshold`` mode for the atlas map.
        atlas_threshold_value: ``apply_threshold`` value for the atlas map.
        n_bins: Forwarded (as ``int``) to ``distribution_metrics``.
        _ready_anchor: Never read. It exists only so an upstream node can be
            connected to this one to enforce execution order.

    Returns:
        Path of the written ``metrics.json`` as a string.

    Raises:
        ValueError: If the subject and atlas volumes differ in shape.
    """
    import json
    import sys
    from pathlib import Path

    import nibabel as nib
    import numpy as np

    from thesis.workflows.tract_similarity._io import (
        load_atlas_normalized,
        load_probtrackx_volume,
        resolve_atlas_file,
        voxel_size_mm,
    )
    from thesis.workflows.tract_similarity._metrics import (
        correlation_metrics,
        distance_metrics,
        distribution_metrics,
        overlap_metrics,
    )

    base = Path(patient_output)
    probtrackx_dir = base / probtrackx_relpath
    atlas_file = resolve_atlas_file(base, atlas_relpath)
    print(f"[tract_similarity] probtrackx_dir = {probtrackx_dir}", file=sys.stderr)
    print(f"[tract_similarity] atlas_mean = {atlas_file}", file=sys.stderr)

    vol_prob, affine_prob = load_probtrackx_volume(
        probtrackx_dir, fdt_name=fdt_name, waytotal_name=waytotal_name
    )
    # The second value returned for the atlas is discarded; only shapes are compared.
    vol_atlas, _ = load_atlas_normalized(atlas_file)

    if vol_prob.shape != vol_atlas.shape:
        raise ValueError(
            f"Shape mismatch: probtrackx2 {vol_prob.shape} vs atlas "
            f"{vol_atlas.shape}. The warped atlas must be resampled onto the "
            f"patient's native DWI grid by atlas_to_patient."
        )

    # Imported locally like everything else in this body (see section banner).
    from thesis.workflows.tract_similarity.workflow import apply_threshold

    thr_prob = apply_threshold(vol_prob, subject_threshold_mode, subject_threshold_value)
    thr_atlas = apply_threshold(vol_atlas, atlas_threshold_mode, atlas_threshold_value)
    # A voxel belongs to a mask only when it is strictly above the cutoff.
    mask_prob = vol_prob > thr_prob
    mask_atlas = vol_atlas > thr_atlas
    spacing = voxel_size_mm(affine_prob)

    metrics = {
        "patient_id": patient_id,
        "thresholds": {
            "subject": {
                "mode": subject_threshold_mode,
                "value": float(subject_threshold_value),
                "effective": float(thr_prob),
            },
            "atlas": {
                "mode": atlas_threshold_mode,
                "value": float(atlas_threshold_value),
                "effective": float(thr_atlas),
            },
        },
        "voxel_counts": {
            "probtrackx": int(np.count_nonzero(mask_prob)),
            "atlas": int(np.count_nonzero(mask_atlas)),
            "intersection": int(np.count_nonzero(mask_prob & mask_atlas)),
            "union": int(np.count_nonzero(mask_prob | mask_atlas)),
        },
        "overlap": overlap_metrics(mask_prob, mask_atlas),
        "correlation": correlation_metrics(vol_prob, vol_atlas),
        "distance_mm": distance_metrics(mask_prob, mask_atlas, voxel_size_mm=spacing),
        "distribution": distribution_metrics(vol_prob, vol_atlas, n_bins=int(n_bins)),
    }

    out_dir = base / output_subdir
    out_dir.mkdir(parents=True, exist_ok=True)

    # All four volumes are written with the subject affine.
    for name, vol, dtype in (
        ("subject_normalized", vol_prob, np.float32),
        ("atlas_normalized", vol_atlas, np.float32),
        ("subject_mask", mask_prob, np.uint8),
        ("atlas_mask", mask_atlas, np.uint8),
    ):
        nib.save(
            nib.Nifti1Image(vol.astype(dtype), affine_prob),
            str(out_dir / f"{name}.nii.gz"),
        )
    print(
        f"[tract_similarity] wrote normalized volumes and thresholded masks under {out_dir}",
        file=sys.stderr,
    )

    out_file = out_dir / "metrics.json"
    out_file.write_text(json.dumps(metrics, indent=2), encoding="utf-8")
    print(f"[tract_similarity] wrote {out_file}", file=sys.stderr)
    return str(out_file)


# ---------------------------------------------------------------------------
# Cohort aggregation task body
# ---------------------------------------------------------------------------


def _aggregate_cohort_task(
    input_dir: str,
    cohort_output_dir: str,
    output_subdir: str,
    outlier_sd_threshold: float,
) -> list:
    """Cohort Nipype Function body.

    Scans the patient directories returned by ``discover_patient_dirs`` under
    *input_dir* for ``<output_subdir>/metrics.json`` and emits a cohort-level
    summary: ``per_patient.csv`` (one row per patient and metric),
    ``summary.csv`` (n, mean, std, median, q25, q75 per metric) and
    ``outliers.json`` (flagged values grouped by patient).

    A metrics file that cannot be read, is not valid UTF-8 / JSON, or does not
    contain a JSON object is reported on stderr and skipped, so one damaged
    patient output does not abort the whole aggregation. The same applies to
    a metric family whose value is not an object.

    Args:
        input_dir: Directory holding the per-patient output directories.
        cohort_output_dir: Destination directory; created if missing.
        output_subdir: Per-patient subdirectory that contains ``metrics.json``.
        outlier_sd_threshold: Number of standard deviations beyond which a
            value is flagged.

    Returns:
        ``[summary.csv, per_patient.csv, outliers.json]`` as path strings.

    Raises:
        RuntimeError: If no usable metric record was found.
    """
    import csv
    import json
    import sys
    from pathlib import Path

    import numpy as np

    from thesis.workflows.tract_similarity._io import discover_patient_dirs

    def _is_nan(value: object) -> bool:
        return isinstance(value, float) and value != value  # NaN != NaN

    def _write_csv(path: Path, header: list, rows: list) -> None:
        with path.open("w", encoding="utf-8", newline="") as fh:
            writer = csv.writer(fh)
            writer.writerow(header)
            writer.writerows(rows)

    in_path = Path(input_dir)
    out_path = Path(cohort_output_dir)
    out_path.mkdir(parents=True, exist_ok=True)

    patient_dirs = discover_patient_dirs(in_path)
    records: list = []
    for pdir in patient_dirs:
        metrics_file = pdir / output_subdir / "metrics.json"
        if not metrics_file.is_file():
            continue
        try:
            data = json.loads(metrics_file.read_text(encoding="utf-8"))
        except (OSError, ValueError) as exc:
            # json.JSONDecodeError and UnicodeDecodeError are both ValueError
            # subclasses; OSError covers files that exist but cannot be read.
            print(f"[tract_similarity_cohort] skipping {metrics_file}: {exc}", file=sys.stderr)
            continue
        if not isinstance(data, dict):
            print(
                f"[tract_similarity_cohort] skipping {metrics_file}: "
                f"expected a JSON object, got {type(data).__name__}",
                file=sys.stderr,
            )
            continue

        pid = data.get("patient_id", pdir.name)
        for family in ("overlap", "correlation", "distance_mm", "distribution"):
            section = data.get(family)
            if not section:
                continue
            if not isinstance(section, dict):
                print(
                    f"[tract_similarity_cohort] ignoring '{family}' in {metrics_file}: "
                    f"expected a JSON object, got {type(section).__name__}",
                    file=sys.stderr,
                )
                continue
            for metric_name, value in section.items():
                records.append((pid, family, metric_name, value))

    if not records:
        raise RuntimeError(
            f"No per-patient metrics.json found under {in_path}/*/{output_subdir}. "
            f"Run 'tract_similarity' for each patient first."
        )

    per_patient_csv = out_path / "per_patient.csv"
    _write_csv(per_patient_csv, ["patient_id", "metric_family", "metric_name", "value"], records)

    summary_rows: list = []
    outliers: dict = {}
    grouped: dict = {}
    for pid, family, name, value in records:
        grouped.setdefault((family, name), []).append((pid, value))

    for (family, name), pid_values in sorted(grouped.items()):
        # None and NaN entries are excluded from every statistic.
        values = np.asarray(
            [v for _, v in pid_values if v is not None and not _is_nan(v)], dtype=np.float64
        )
        n = int(values.size)
        if n == 0:
            summary_rows.append([family, name, 0, "", "", "", "", ""])
            continue
        mean = float(np.mean(values))
        std = float(np.std(values))
        median = float(np.median(values))
        q25 = float(np.percentile(values, 25))
        q75 = float(np.percentile(values, 75))
        summary_rows.append([family, name, n, mean, std, median, q25, q75])

        if std > 0.0:
            # NOTE(review): the deviation is measured from the median but
            # scaled by the standard deviation — confirm this mix is intended.
            for pid, v in pid_values:
                if v is None or _is_nan(v):
                    continue
                if abs(float(v) - median) > outlier_sd_threshold * std:
                    outliers.setdefault(pid, []).append(
                        {"family": family, "metric": name, "value": float(v)}
                    )

    summary_csv = out_path / "summary.csv"
    _write_csv(
        summary_csv,
        ["metric_family", "metric_name", "n", "mean", "std", "median", "q25", "q75"],
        summary_rows,
    )

    outliers_file = out_path / "outliers.json"
    outliers_file.write_text(json.dumps(outliers, indent=2), encoding="utf-8")

    print(
        f"[tract_similarity_cohort] wrote {summary_csv}, {per_patient_csv}, {outliers_file}",
        file=sys.stderr,
    )
    return [str(summary_csv), str(per_patient_csv), str(outliers_file)]


# ---------------------------------------------------------------------------
# Verifiers
# ---------------------------------------------------------------------------
def verify_requirements(config: PipelineConfig, context: ProcessingContext) -> List[str]:
    """Pre-flight checks for the per-patient tract_similarity workflow.

    Returns a list of human-readable problems; an empty list means every
    prerequisite was found. Two things are checked below
    ``context.output_dir``: that some probtrackx2 run directory holds both the
    density file and the waytotal file, and that the warped atlas exists,
    either as a literal path or as a glob match.
    """
    if context.output_dir is None:
        return ["output_dir is not set in the processing context"]

    settings = config.tract_similarity
    patient_root = Path(context.output_dir)
    problems: List[str] = []

    # --- probtrackx2 output -------------------------------------------------
    tracto_dir = patient_root / settings.probtrackx_relpath
    if tracto_dir.is_dir():
        pair_found = False
        for run_dir in _discover_run_dirs(tracto_dir):
            if (run_dir / settings.fdt_name).is_file() and (
                run_dir / settings.waytotal_name
            ).is_file():
                pair_found = True
                break
        if not pair_found:
            problems.append(
                f"No '{settings.fdt_name}' + '{settings.waytotal_name}' pair found under "
                f"{tracto_dir} (nor in left/ or right/ subdirectories). "
                f"Run 'hcp' workflow first."
            )
    else:
        problems.append(
            f"probtrackx output directory not found at {tracto_dir}. "
            f"Run 'hcp' workflow first."
        )

    # --- warped atlas -------------------------------------------------------
    # Try the configured value as a plain path first; only fall back to
    # treating it as a glob pattern when no such file exists.
    atlas_found = (patient_root / settings.atlas_relpath).is_file()
    if not atlas_found:
        atlas_found = any(True for _ in patient_root.glob(settings.atlas_relpath))
    if not atlas_found:
        problems.append(
            f"No warped atlas found matching '{settings.atlas_relpath}' under {patient_root}. "
            f"Run 'atlas_to_patient' first, or adjust 'tract_similarity.atlas_relpath'."
        )

    return problems
def verify_cohort_requirements(config: PipelineConfig, context: ProcessingContext) -> List[str]:
    """Pre-flight checks for the cohort aggregation workflow.

    Returns a list of problems (empty when everything is in place). The only
    check performed is that ``discover_patient_dirs`` finds something under
    the resolved cohort input directory.
    """
    problems: List[str] = []
    if context.output_dir is None:
        problems.append("output_dir is not set in the processing context")
        return problems

    cohort_root = _resolve_cohort_input_dir(Path(context.output_dir))
    if not discover_patient_dirs(cohort_root, sort=False):
        problems.append(
            f"No patient subdirectories found under {cohort_root}. "
            f"Run 'tract_similarity' for each patient first."
        )
    return problems
def _resolve_cohort_input_dir(output_dir: Path) -> Path: """Mirror of atlas/_resolve_input_dir for the cohort aggregation. When the CLI runs the cohort-level workflow it sets ``output_dir`` to ``<base>/cohort``. Patient data lives one level up. """ if output_dir.name == "cohort" and output_dir.parent.is_dir(): return output_dir.parent return output_dir # --------------------------------------------------------------------------- # Workflow builders # ---------------------------------------------------------------------------
@workflow(
    name="tract_similarity",
    description=(
        "Per-patient tract similarity analysis: compare native-space "
        "probtrackx2 density to the warped cohort mean atlas across four "
        "metric families (overlap, correlation, distance, distribution)."
    ),
    protocol="hcp",
)
@verify(verify_requirements)
def build_workflow(*, config: PipelineConfig, context: ProcessingContext) -> pe.Workflow:
    """Build the per-patient tract_similarity workflow.

    The workflow consists of a ``compute_metrics`` Function node wrapping
    ``_compute_similarity_task`` and an ``entry_gate`` identity node connected
    to it. Inputs (probtrackx density, warped atlas) live under
    ``context.output_dir`` and are located through ``probtrackx_relpath`` /
    ``atlas_relpath`` from ``config.tract_similarity`` — those support
    hemisphere-split layouts and glob patterns.

    Raises:
        ValueError: If ``context.output_dir`` is not set.
    """
    if context.output_dir is None:
        raise ValueError("output_dir must be set before building the workflow")

    settings = config.tract_similarity
    subject_thr = settings.subject_threshold
    atlas_thr = settings.atlas_threshold

    flow = pe.Workflow(name=f"tract_similarity_{context.patient_id}")
    if context.working_dir:
        flow.base_dir = str(context.working_dir)

    compute = pe.Node(
        Function(
            input_names=_SIM_INPUTS,
            output_names=["metrics_file"],
            function=_compute_similarity_task,
        ),
        name="compute_metrics",
    )

    # Every data input is fixed at build time; only ``_ready_anchor`` is left
    # to be supplied through a connection.
    static_inputs = {
        "patient_output": str(context.output_dir),
        "patient_id": context.patient_id,
        "probtrackx_relpath": settings.probtrackx_relpath,
        "fdt_name": settings.fdt_name,
        "waytotal_name": settings.waytotal_name,
        "atlas_relpath": settings.atlas_relpath,
        "output_subdir": settings.output_subdir,
        "subject_threshold_mode": subject_thr.mode,
        "subject_threshold_value": float(subject_thr.value),
        "atlas_threshold_mode": atlas_thr.mode,
        "atlas_threshold_value": float(atlas_thr.value),
        "n_bins": int(settings.n_bins),
    }
    for port, port_value in static_inputs.items():
        setattr(compute.inputs, port, port_value)

    # Fixed-name anchor for cross-workflow ordering. Meta-workflows
    # (e.g. full_pipeline stage 5) wire entry_gate.ready from upstream
    # probtrackx + atlas_to_patient completion; standalone runs leave it
    # unset and compute_metrics fires from its statically-set inputs.
    gate = pe.Node(IdentityInterface(fields=["ready"]), name="entry_gate")
    flow.add_nodes([gate, compute])
    flow.connect(gate, "ready", compute, "_ready_anchor")

    logger.info(
        "Built tract_similarity workflow for {} "
        "(subject_threshold={}:{}, atlas_threshold={}:{})",
        context.patient_id,
        subject_thr.mode,
        subject_thr.value,
        atlas_thr.mode,
        atlas_thr.value,
    )
    return flow
@workflow(
    name="tract_similarity_cohort",
    description=(
        "Cohort-level aggregation of per-patient tract_similarity metrics. "
        "Emits summary.csv, per_patient.csv, and outliers.json."
    ),
    protocol="hcp",
    scope="cohort",
)
@verify(verify_cohort_requirements)
def build_cohort_workflow(*, config: PipelineConfig, context: ProcessingContext) -> pe.Workflow:
    """Build the cohort-level tract_similarity aggregation workflow.

    The workflow holds a single ``aggregate_cohort`` Function node wrapping
    ``_aggregate_cohort_task``. Results go to
    ``<context.output_dir>/<cohort_output_subdir>``; patient data is read from
    the directory returned by ``_resolve_cohort_input_dir``.

    Raises:
        ValueError: If ``context.output_dir`` is not set.
    """
    if context.output_dir is None:
        raise ValueError("output_dir must be set before building the workflow")

    settings = config.tract_similarity
    output_root = Path(context.output_dir)
    results_dir = output_root / settings.cohort_output_subdir
    patients_root = _resolve_cohort_input_dir(Path(context.output_dir))

    flow = pe.Workflow(name="tract_similarity_cohort")
    if context.working_dir:
        flow.base_dir = str(context.working_dir)

    aggregate = pe.Node(
        Function(
            input_names=_COHORT_INPUTS,
            output_names=["generated_files"],
            function=_aggregate_cohort_task,
        ),
        name="aggregate_cohort",
    )
    static_inputs = {
        "input_dir": str(patients_root),
        "cohort_output_dir": str(results_dir),
        "output_subdir": settings.output_subdir,
        # The outlier cutoff comes from the QC section, not tract_similarity.
        "outlier_sd_threshold": float(config.qc.outlier_sd_threshold),
    }
    for port, port_value in static_inputs.items():
        setattr(aggregate.inputs, port, port_value)

    flow.add_nodes([aggregate])

    logger.info("Built tract_similarity_cohort workflow rooted at {}", patients_root)
    return flow