"""Tract similarity analysis Nipype workflows.
Two workflows are registered here:
- ``tract_similarity`` — per-patient. Loads the patient's native-space
probtrackx2 output and the warped cohort mean atlas, computes four families
of similarity metrics (overlap, voxelwise correlation, spatial distance,
distribution similarity), and writes ``metrics.json`` alongside the four
NIfTI volumes used to compute those metrics:
``subject_normalized.nii.gz`` (waytotal-normalized probtrackx2 density),
``atlas_normalized.nii.gz`` (normalized warped atlas),
``subject_mask.nii.gz`` (thresholded subject binary mask), and
``atlas_mask.nii.gz`` (thresholded atlas binary mask).
- ``tract_similarity_cohort`` — cohort-level. Aggregates every patient's
``metrics.json`` into ``summary.csv``, ``per_patient.csv``, and
``outliers.json``.
Prerequisites (verified pre-flight):
- ``hcp`` workflow has produced ``tractography/probtrackx2/fdt_paths.nii.gz``
and ``waytotal`` for the patient (per-patient mode).
- ``atlas`` -> ``atlas_to_patient`` has warped a cohort mean map into the
patient's native space (per-patient mode).
"""
from __future__ import annotations
from pathlib import Path
from typing import Any, List
import nipype.pipeline.engine as pe
from nipype.interfaces.utility import Function, IdentityInterface
from thesis.core.config import PipelineConfig
from thesis.core.context import ProcessingContext
from thesis.core.decorators import verify, workflow
from thesis.core.logging import get_logger
from thesis.workflows.tract_similarity._io import _discover_run_dirs, discover_patient_dirs
logger = get_logger(__name__)
# Input port names for the per-patient ``compute_metrics`` Function node.
# They mirror the parameters of ``_compute_similarity_task`` one-to-one.
# ``_ready_anchor`` is not read by the task body; it is only the target of
# the ``entry_gate.ready`` connection made in ``build_workflow``.
_SIM_INPUTS = [
    "patient_output",
    "patient_id",
    "probtrackx_relpath",
    "fdt_name",
    "waytotal_name",
    "atlas_relpath",
    "output_subdir",
    "subject_threshold_mode",
    "subject_threshold_value",
    "atlas_threshold_mode",
    "atlas_threshold_value",
    "n_bins",
    "_ready_anchor",
]
# Input port names for the cohort ``aggregate_cohort`` Function node; they
# mirror the parameters of ``_aggregate_cohort_task``.
_COHORT_INPUTS = [
    "input_dir",
    "cohort_output_dir",
    "output_subdir",
    "outlier_sd_threshold",
]
# Public names of this module (what ``from <module> import *`` exposes).
__all__ = [
    "build_workflow",
    "build_cohort_workflow",
    "verify_requirements",
    "verify_cohort_requirements",
    "apply_threshold",
]
[docs]
def apply_threshold(vol: Any, mode: str, value: float) -> float:
    """Compute the binarisation cutoff for a volume.

    ``mode="fraction"`` returns ``value * max(volume)``. The maximum is
    floored at ``1e-12`` so that an all-zero volume still produces a strictly
    positive cutoff for a positive ``value``. A zero-size volume is treated as
    having a maximum of zero (and therefore also gets the floor) instead of
    raising. ``mode="absolute"`` returns ``value`` unchanged.

    Shared by the per-patient workflow and the threshold sweep so the cutoff
    formula has a single source of truth.

    Args:
        vol: Array-like volume; converted with ``numpy.asarray``. Only read
            in ``"fraction"`` mode.
        mode: Either ``"fraction"`` or ``"absolute"``.
        value: Fraction of the volume maximum, or the absolute cutoff.

    Returns:
        The cutoff as a Python ``float``.

    Raises:
        ValueError: If ``mode`` is neither ``"fraction"`` nor ``"absolute"``.
    """
    import numpy as np

    if mode == "fraction":
        data = np.asarray(vol)
        # np.max has no identity for a zero-size array and would raise, so
        # the floor below could never apply to a truly empty volume.
        peak = float(np.max(data)) if data.size else 0.0
        return float(value) * max(peak, 1e-12)
    if mode == "absolute":
        return float(value)
    raise ValueError(f"unknown threshold mode: {mode!r}")
# ---------------------------------------------------------------------------
# Per-patient task body (runs in a Nipype subprocess — all imports local)
# ---------------------------------------------------------------------------
def _compute_similarity_task(
    patient_output: str,
    patient_id: str,
    probtrackx_relpath: str,
    fdt_name: str,
    waytotal_name: str,
    atlas_relpath: str,
    output_subdir: str,
    subject_threshold_mode: str,
    subject_threshold_value: float,
    atlas_threshold_mode: str,
    atlas_threshold_value: float,
    n_bins: int,
    _ready_anchor: object = None,
) -> str:
    """Compare one patient's tract density with the warped atlas.

    Loads the probtrackx2 volume and the atlas volume located under
    ``patient_output``, binarises each with its own threshold, evaluates the
    overlap, correlation, distance and distribution metric families, and
    writes ``metrics.json`` together with four NIfTI volumes
    (``subject_normalized``, ``atlas_normalized``, ``subject_mask``,
    ``atlas_mask``) into ``<patient_output>/<output_subdir>``. All four
    volumes are saved with the affine returned for the probtrackx2 volume.

    ``_ready_anchor`` is accepted but never read.

    Returns:
        Path of the written ``metrics.json`` as a string.

    Raises:
        ValueError: If the two volumes do not have the same shape.
    """
    import json
    import sys
    from pathlib import Path

    import nibabel as nib
    import numpy as np

    from thesis.workflows.tract_similarity._io import (
        load_atlas_normalized,
        load_probtrackx_volume,
        resolve_atlas_file,
        voxel_size_mm,
    )
    from thesis.workflows.tract_similarity._metrics import (
        correlation_metrics,
        distance_metrics,
        distribution_metrics,
        overlap_metrics,
    )

    def _emit(message: str) -> None:
        # Progress messages go to stderr.
        print(message, file=sys.stderr)

    def _count(mask: object) -> int:
        return int(np.count_nonzero(mask))

    def _threshold_entry(mode: str, value: float, effective: float) -> dict:
        return {"mode": mode, "value": float(value), "effective": float(effective)}

    # --- locate and load both volumes -------------------------------------
    base = Path(patient_output)
    probtrackx_dir = base / probtrackx_relpath
    atlas_file = resolve_atlas_file(base, atlas_relpath)
    _emit(f"[tract_similarity] probtrackx_dir = {probtrackx_dir}")
    _emit(f"[tract_similarity] atlas_mean = {atlas_file}")

    vol_prob, affine_prob = load_probtrackx_volume(
        probtrackx_dir,
        fdt_name=fdt_name,
        waytotal_name=waytotal_name,
    )
    vol_atlas, _atlas_affine = load_atlas_normalized(atlas_file)

    same_grid = vol_prob.shape == vol_atlas.shape
    if not same_grid:
        raise ValueError(
            f"Shape mismatch: probtrackx2 {vol_prob.shape} vs atlas "
            f"{vol_atlas.shape}. The warped atlas must be resampled onto the "
            f"patient's native DWI grid by atlas_to_patient."
        )

    # --- binarise ----------------------------------------------------------
    # Resolved by import rather than as a module global: this body runs in a
    # Nipype subprocess, so every dependency is imported locally.
    from thesis.workflows.tract_similarity.workflow import apply_threshold

    thr_prob = apply_threshold(vol_prob, subject_threshold_mode, subject_threshold_value)
    thr_atlas = apply_threshold(vol_atlas, atlas_threshold_mode, atlas_threshold_value)
    mask_prob = vol_prob > thr_prob
    mask_atlas = vol_atlas > thr_atlas
    spacing = voxel_size_mm(affine_prob)

    # --- assemble the report (key order is the order written to JSON) ------
    metrics: dict = {"patient_id": patient_id}
    metrics["thresholds"] = {
        "subject": _threshold_entry(subject_threshold_mode, subject_threshold_value, thr_prob),
        "atlas": _threshold_entry(atlas_threshold_mode, atlas_threshold_value, thr_atlas),
    }
    metrics["voxel_counts"] = {
        "probtrackx": _count(mask_prob),
        "atlas": _count(mask_atlas),
        "intersection": _count(mask_prob & mask_atlas),
        "union": _count(mask_prob | mask_atlas),
    }
    metrics["overlap"] = overlap_metrics(mask_prob, mask_atlas)
    metrics["correlation"] = correlation_metrics(vol_prob, vol_atlas)
    metrics["distance_mm"] = distance_metrics(mask_prob, mask_atlas, voxel_size_mm=spacing)
    metrics["distribution"] = distribution_metrics(vol_prob, vol_atlas, n_bins=int(n_bins))

    # --- persist volumes and metrics ----------------------------------------
    out_dir = base / output_subdir
    out_dir.mkdir(parents=True, exist_ok=True)

    volumes = {
        "subject_normalized": (vol_prob, np.float32),
        "atlas_normalized": (vol_atlas, np.float32),
        "subject_mask": (mask_prob, np.uint8),
        "atlas_mask": (mask_atlas, np.uint8),
    }
    for name, (vol, dtype) in volumes.items():
        image = nib.Nifti1Image(vol.astype(dtype), affine_prob)
        nib.save(image, str(out_dir / f"{name}.nii.gz"))
    _emit(f"[tract_similarity] wrote normalized volumes and thresholded masks under {out_dir}")

    out_file = out_dir / "metrics.json"
    payload = json.dumps(metrics, indent=2)
    out_file.write_text(payload, encoding="utf-8")
    _emit(f"[tract_similarity] wrote {out_file}")
    return str(out_file)
# ---------------------------------------------------------------------------
# Cohort aggregation task body
# ---------------------------------------------------------------------------
def _aggregate_cohort_task(
    input_dir: str,
    cohort_output_dir: str,
    output_subdir: str,
    outlier_sd_threshold: float,
) -> list:
    """Aggregate every patient's ``metrics.json`` into cohort-level files.

    For each directory returned by ``discover_patient_dirs(input_dir)`` the
    file ``<output_subdir>/metrics.json`` is read if present; files that fail
    to parse as JSON are reported on stderr and skipped. Three files are
    written into ``cohort_output_dir``:

    - ``per_patient.csv``: one row per (patient, family, metric) value.
    - ``summary.csv``: per metric, the count, mean, standard deviation,
      median and 25th/75th percentiles over values that are neither ``None``
      nor NaN.
    - ``outliers.json``: per patient, the metrics whose value lies further
      than ``outlier_sd_threshold`` standard deviations from the median.

    Returns:
        ``[summary.csv, per_patient.csv, outliers.json]`` as string paths.

    Raises:
        RuntimeError: If no metric value could be collected at all.
    """
    import csv
    import json
    import sys
    from pathlib import Path

    import numpy as np

    from thesis.workflows.tract_similarity._io import discover_patient_dirs

    families = ("overlap", "correlation", "distance_mm", "distribution")

    def _usable(candidate: object) -> bool:
        # A value counts unless it is None or a float NaN (NaN != NaN).
        if candidate is None:
            return False
        return not (isinstance(candidate, float) and candidate != candidate)

    def _dump_csv(target: Path, header: list, rows: list) -> None:
        with target.open("w", encoding="utf-8", newline="") as handle:
            sheet = csv.writer(handle)
            sheet.writerow(header)
            sheet.writerows(rows)

    in_path = Path(input_dir)
    out_path = Path(cohort_output_dir)
    out_path.mkdir(parents=True, exist_ok=True)

    # --- collect (patient, family, metric, value) records -------------------
    records: list = []
    for pdir in discover_patient_dirs(in_path):
        metrics_file = pdir / output_subdir / "metrics.json"
        if metrics_file.is_file():
            try:
                data = json.loads(metrics_file.read_text(encoding="utf-8"))
            except json.JSONDecodeError as exc:
                print(f"[tract_similarity_cohort] skipping {metrics_file}: {exc}", file=sys.stderr)
            else:
                pid = data.get("patient_id", pdir.name)
                records.extend(
                    (pid, family, metric_name, value)
                    for family in families
                    for metric_name, value in (data.get(family) or {}).items()
                )

    if not records:
        raise RuntimeError(
            f"No per-patient metrics.json found under {in_path}/*/ {output_subdir}. "
            f"Run 'tract_similarity' for each patient first."
        )

    per_patient_csv = out_path / "per_patient.csv"
    _dump_csv(per_patient_csv, ["patient_id", "metric_family", "metric_name", "value"], records)

    # --- group by metric, then summarise and flag outliers ------------------
    grouped: dict = {}
    for pid, family, name, value in records:
        key = (family, name)
        if key not in grouped:
            grouped[key] = []
        grouped[key].append((pid, value))

    summary_rows: list = []
    outliers: dict = {}
    for key in sorted(grouped):
        family, name = key
        kept = [(pid, v) for pid, v in grouped[key] if _usable(v)]
        values = np.asarray([v for _, v in kept], dtype=np.float64)
        if values.size == 0:
            summary_rows.append([family, name, 0, "", "", "", "", ""])
            continue

        n = int(values.size)
        mean = float(np.mean(values))
        std = float(np.std(values))
        median = float(np.median(values))
        q25 = float(np.percentile(values, 25))
        q75 = float(np.percentile(values, 75))
        summary_rows.append([family, name, n, mean, std, median, q25, q75])

        # No spread (or an undefined one) means nothing can be flagged.
        if not std > 0.0:
            continue
        limit = outlier_sd_threshold * std
        for pid, v in kept:
            if abs(float(v) - median) > limit:
                outliers.setdefault(pid, []).append(
                    {"family": family, "metric": name, "value": float(v)}
                )

    summary_csv = out_path / "summary.csv"
    _dump_csv(
        summary_csv,
        ["metric_family", "metric_name", "n", "mean", "std", "median", "q25", "q75"],
        summary_rows,
    )

    outliers_file = out_path / "outliers.json"
    outliers_file.write_text(json.dumps(outliers, indent=2), encoding="utf-8")

    print(
        f"[tract_similarity_cohort] wrote {summary_csv}, {per_patient_csv}, {outliers_file}",
        file=sys.stderr,
    )
    return [str(path) for path in (summary_csv, per_patient_csv, outliers_file)]
# ---------------------------------------------------------------------------
# Verifiers
# ---------------------------------------------------------------------------
[docs]
def verify_requirements(config: PipelineConfig, context: ProcessingContext) -> List[str]:
    """Pre-flight checks for the per-patient tract_similarity workflow.

    Checks, relative to ``context.output_dir``:

    - the probtrackx directory ``tract_similarity.probtrackx_relpath`` exists
      and at least one run directory reported by ``_discover_run_dirs`` holds
      both the ``fdt_name`` and ``waytotal_name`` files;
    - ``tract_similarity.atlas_relpath`` names an existing file, or matches
      at least one path when used as a glob pattern.

    Returns:
        One human-readable message per failed check; empty when all pass.
        If ``context.output_dir`` is ``None`` that is the only message.
    """
    if context.output_dir is None:
        return ["output_dir is not set in the processing context"]

    ts = config.tract_similarity
    base = Path(context.output_dir)
    errors: List[str] = []

    # --- tractography output ------------------------------------------------
    probtrackx_dir = base / ts.probtrackx_relpath
    if probtrackx_dir.is_dir():
        for run_dir in _discover_run_dirs(probtrackx_dir):
            if (run_dir / ts.fdt_name).is_file() and (run_dir / ts.waytotal_name).is_file():
                break
        else:
            # Loop finished without finding a complete density/waytotal pair.
            errors.append(
                f"No '{ts.fdt_name}' + '{ts.waytotal_name}' pair found under "
                f"{probtrackx_dir} (nor in left/ or right/ subdirectories). "
                f"Run 'hcp' workflow first."
            )
    else:
        errors.append(
            f"probtrackx output directory not found at {probtrackx_dir}. "
            f"Run 'hcp' workflow first."
        )

    # --- warped atlas: literal path first, glob pattern as fallback ---------
    atlas_found = (base / ts.atlas_relpath).is_file() or any(base.glob(ts.atlas_relpath))
    if not atlas_found:
        errors.append(
            f"No warped atlas found matching '{ts.atlas_relpath}' under {base}. "
            f"Run 'atlas_to_patient' first, or adjust 'tract_similarity.atlas_relpath'."
        )
    return errors
[docs]
def verify_cohort_requirements(config: PipelineConfig, context: ProcessingContext) -> List[str]:
    """Pre-flight checks for the cohort aggregation workflow.

    The cohort input directory is derived from ``context.output_dir`` with
    ``_resolve_cohort_input_dir``; the check passes when
    ``discover_patient_dirs`` reports at least one patient directory there.
    ``config`` is not consulted.

    Returns:
        A single-message list when ``context.output_dir`` is ``None`` or no
        patient directory is found; otherwise an empty list.
    """
    if context.output_dir is None:
        return ["output_dir is not set in the processing context"]

    input_dir = _resolve_cohort_input_dir(Path(context.output_dir))
    # Only existence matters here, so the directories are requested unsorted.
    if discover_patient_dirs(input_dir, sort=False):
        return []
    return [
        f"No patient subdirectories found under {input_dir}. "
        f"Run 'tract_similarity' for each patient first."
    ]
def _resolve_cohort_input_dir(output_dir: Path) -> Path:
    """Return the directory that holds the per-patient folders.

    Mirror of atlas/_resolve_input_dir for the cohort aggregation. When the
    CLI runs the cohort-level workflow it sets ``output_dir`` to
    ``<base>/cohort``, while patient data lives one level up. So a path whose
    last component is ``cohort`` and whose parent is an existing directory
    resolves to that parent; any other path is returned unchanged.
    """
    parent = output_dir.parent
    points_at_cohort = output_dir.name == "cohort" and parent.is_dir()
    return parent if points_at_cohort else output_dir
# ---------------------------------------------------------------------------
# Workflow builders
# ---------------------------------------------------------------------------
[docs]
@workflow(
    name="tract_similarity",
    description=(
        "Per-patient tract similarity analysis: compare native-space "
        "probtrackx2 density to the warped cohort mean atlas across four "
        "metric families (overlap, correlation, distance, distribution)."
    ),
    protocol="hcp",
)
@verify(verify_requirements)
def build_workflow(*, config: PipelineConfig, context: ProcessingContext) -> pe.Workflow:
    """Build the per-patient tract_similarity workflow.

    The workflow holds two nodes: ``compute_metrics``, a Function node
    wrapping ``_compute_similarity_task`` whose inputs are all set statically
    from ``config.tract_similarity`` and ``context``, and ``entry_gate``, an
    identity node whose ``ready`` field feeds the task's ``_ready_anchor``.

    Inputs (probtrackx density, warped atlas) live under
    ``context.output_dir`` and are located through ``ts.probtrackx_relpath``
    and ``ts.atlas_relpath`` from config — those support hemisphere-split
    layouts and glob patterns.

    Raises:
        ValueError: If ``context.output_dir`` is ``None``.
    """
    if context.output_dir is None:
        raise ValueError("output_dir must be set before building the workflow")

    wf = pe.Workflow(name=f"tract_similarity_{context.patient_id}")
    if context.working_dir:
        wf.base_dir = str(context.working_dir)

    ts = config.tract_similarity
    compute_metrics = pe.Node(
        Function(
            input_names=_SIM_INPUTS,
            output_names=["metrics_file"],
            function=_compute_similarity_task,
        ),
        name="compute_metrics",
    )

    # Every input except ``_ready_anchor`` is known at build time.
    static_inputs = {
        "patient_output": str(context.output_dir),
        "patient_id": context.patient_id,
        "probtrackx_relpath": ts.probtrackx_relpath,
        "fdt_name": ts.fdt_name,
        "waytotal_name": ts.waytotal_name,
        "atlas_relpath": ts.atlas_relpath,
        "output_subdir": ts.output_subdir,
        "subject_threshold_mode": ts.subject_threshold.mode,
        "subject_threshold_value": float(ts.subject_threshold.value),
        "atlas_threshold_mode": ts.atlas_threshold.mode,
        "atlas_threshold_value": float(ts.atlas_threshold.value),
        "n_bins": int(ts.n_bins),
    }
    for port, setting in static_inputs.items():
        setattr(compute_metrics.inputs, port, setting)

    # Fixed-name anchor for cross-workflow ordering. Meta-workflows
    # (e.g. full_pipeline stage 5) wire entry_gate.ready from upstream
    # probtrackx + atlas_to_patient completion; standalone runs leave it
    # unset and compute_metrics fires from its statically-set inputs.
    entry_gate = pe.Node(IdentityInterface(fields=["ready"]), name="entry_gate")
    wf.add_nodes([entry_gate, compute_metrics])
    wf.connect(entry_gate, "ready", compute_metrics, "_ready_anchor")

    subject_thr = ts.subject_threshold
    atlas_thr = ts.atlas_threshold
    logger.info(
        "Built tract_similarity workflow for {} "
        "(subject_threshold={}:{}, atlas_threshold={}:{})",
        context.patient_id,
        subject_thr.mode,
        subject_thr.value,
        atlas_thr.mode,
        atlas_thr.value,
    )
    return wf
[docs]
@workflow(
    name="tract_similarity_cohort",
    description=(
        "Cohort-level aggregation of per-patient tract_similarity metrics. "
        "Emits summary.csv, per_patient.csv, and outliers.json."
    ),
    protocol="hcp",
    scope="cohort",
)
@verify(verify_cohort_requirements)
def build_cohort_workflow(*, config: PipelineConfig, context: ProcessingContext) -> pe.Workflow:
    """Build the cohort-level tract_similarity aggregation workflow.

    The workflow consists of a single ``aggregate_cohort`` Function node
    wrapping ``_aggregate_cohort_task``. Patient data is read from the
    directory chosen by ``_resolve_cohort_input_dir``; results are written to
    ``<context.output_dir>/<tract_similarity.cohort_output_subdir>``. The
    outlier threshold is taken from ``config.qc.outlier_sd_threshold``.

    Raises:
        ValueError: If ``context.output_dir`` is ``None``.
    """
    if context.output_dir is None:
        raise ValueError("output_dir must be set before building the workflow")

    wf = pe.Workflow(name="tract_similarity_cohort")
    if context.working_dir:
        wf.base_dir = str(context.working_dir)

    ts = config.tract_similarity
    output_root = Path(context.output_dir)
    cohort_out = output_root / ts.cohort_output_subdir
    input_dir = _resolve_cohort_input_dir(Path(context.output_dir))

    aggregate = pe.Node(
        Function(
            input_names=_COHORT_INPUTS,
            output_names=["generated_files"],
            function=_aggregate_cohort_task,
        ),
        name="aggregate_cohort",
    )
    static_inputs = {
        "input_dir": str(input_dir),
        "cohort_output_dir": str(cohort_out),
        "output_subdir": ts.output_subdir,
        "outlier_sd_threshold": float(config.qc.outlier_sd_threshold),
    }
    for port, setting in static_inputs.items():
        setattr(aggregate.inputs, port, setting)

    wf.add_nodes([aggregate])
    logger.info("Built tract_similarity_cohort workflow rooted at {}", input_dir)
    return wf