"""Cohort-scope workflow: per-HCP-subject leave-one-out similarity vs the cohort atlas.
Produces, for each numeric (HCP) subject under the cohort output directory, a
``metrics.json`` (and optionally four NIfTI volumes) under
``<subject>/<tract_similarity.output_subdir>/`` mirroring the per-patient
``tract_similarity`` workflow. The per-subject reference is a leave-one-out
cohort mean atlas, built in memory by subtracting that subject's volume from
the precomputed stack sum.
Registered as ``thesis run -w tract_similarity_hcp_loo -c <profile>``.
"""
from __future__ import annotations
from pathlib import Path
from typing import List, Tuple
import nipype.pipeline.engine as pe
from nipype.interfaces.utility import Function
from thesis.core.decorators import verify, workflow
from thesis.core.logging import get_logger
logger = get_logger(__name__)
def _discover_hcp_subjects(
input_dir: Path | str,
tractography_relpath: str,
) -> List[Tuple[str, List[Tuple[Path, float]]]]:
"""Discover numeric-pid HCP subject directories with valid tractography input.
Mirrors :func:`thesis.workflows.atlas._io._collect_patient_inputs` but also
preserves the patient ID and sorts deterministically. The returned list is
sorted lexicographically by ``pid`` so that ``patient_inputs[i]`` and
``pids[i]`` stay aligned downstream.
Args:
input_dir: Cohort directory containing numeric patient subdirectories.
tractography_relpath: Relative path under each subject directory where
tractography outputs live (e.g. ``"tractography/probtrackx2"``).
Returns:
List of ``(pid, runs)`` pairs. ``runs`` is a list of
``(fdt_path, waytotal)`` tuples, identical in shape to one entry of
``_collect_patient_inputs``'s output.
"""
base = Path(input_dir)
pairs: List[Tuple[str, List[Tuple[Path, float]]]] = []
for pdir in sorted(base.iterdir(), key=lambda d: d.name):
if not pdir.is_dir() or not pdir.name.isdigit():
continue
probtrackx_dir = pdir / tractography_relpath
hemi_dirs = [d for d in (probtrackx_dir / "left", probtrackx_dir / "right") if d.is_dir()]
run_dirs = hemi_dirs or [probtrackx_dir]
runs: List[Tuple[Path, float]] = []
for run_dir in run_dirs:
fdt_file = run_dir / "warped_streamlines" / "fdt_paths.nii.gz"
waytotal_file = run_dir / "waytotal"
if not fdt_file.is_file() or not waytotal_file.is_file():
continue
try:
waytotal = float(waytotal_file.read_text().strip())
except ValueError:
continue
if waytotal <= 0.0:
continue
runs.append((fdt_file, waytotal))
if runs:
pairs.append((pdir.name, runs))
return pairs
def _hcp_loo_task(
    input_dir: str,
    tractography_relpath: str,
    normalization_method: str,
    output_subdir: str,
    subject_threshold_mode: str,
    subject_threshold_value: float,
    atlas_threshold_mode: str,
    atlas_threshold_value: float,
    n_bins: int,
    minimum_subjects: int,
    write_volumes: bool,
    learned_prediction_relpath: str = "",
    learned_support_threshold: float = 0.0,
    _ready_anchor: object = None,
) -> list:
    """Per-cohort task body: compute LOO similarity metrics for every HCP subject.

    Runs entirely in template space — the cohort atlas inputs are already
    aligned. For each subject ``i`` the leave-one-out atlas is
    ``(stack_sum - stack[i]) / (N - 1)``; metrics come from the existing
    ``tract_similarity._metrics`` family. Writes per-subject artefacts into
    ``<input_dir>/<pid>/<output_subdir>/``.

    All imports are local because this body is invoked as a Nipype
    ``Function`` and must pickle cleanly to the subprocess.

    Args:
        input_dir: Directory containing the numeric subject subdirectories.
        tractography_relpath: Relative path to the tractography outputs under
            each subject directory.
        normalization_method: Value accepted by ``NormalizationMethod``; passed
            to ``_build_patient_stack``.
        output_subdir: Per-subject subdirectory that receives ``metrics.json``
            and the optional volumes.
        subject_threshold_mode: Mode passed to ``apply_threshold`` for the
            subject volume (also reused for the learned prediction).
        subject_threshold_value: Value passed to ``apply_threshold`` for the
            subject volume (also reused for the learned prediction).
        atlas_threshold_mode: Mode passed to ``apply_threshold`` for the
            leave-one-out atlas.
        atlas_threshold_value: Value passed to ``apply_threshold`` for the
            leave-one-out atlas.
        n_bins: Bin count forwarded to ``distribution_metrics``.
        minimum_subjects: Minimum number of valid subjects required. Values
            below 2 are treated as 2 because the leave-one-out mean divides
            by ``N - 1``.
        write_volumes: When true, also write the four NIfTI volumes
            (``subject_normalized``, ``atlas_normalized``, ``subject_mask``,
            ``atlas_mask``) per subject.
        learned_prediction_relpath: Optional path, relative to each subject
            directory, of a learned-atlas prediction volume. Empty disables
            the learned arm.
        learned_support_threshold: ``support_threshold`` forwarded to
            ``evaluate_prediction``.
        _ready_anchor: Not read by this body. NOTE(review): presumably an
            ordering hook so an upstream node can be connected — confirm.

    Returns:
        List of patient IDs for which artefacts were successfully written.

    Raises:
        ProcessingError: If too few valid subjects are found, if the stack
            cannot be built, or if the stack does not hold exactly one volume
            per discovered subject.
    """
    import json
    import sys
    from pathlib import Path

    import nibabel as nib
    import numpy as np

    from thesis.core.exceptions import ProcessingError
    from thesis.workflows.atlas._io import _build_patient_stack
    from thesis.workflows.atlas._params import NormalizationMethod
    from thesis.workflows.learned_atlas.diagnostics import evaluate_prediction
    from thesis.workflows.tract_similarity._io import voxel_size_mm
    from thesis.workflows.tract_similarity._metrics import (
        correlation_metrics,
        distance_metrics,
        distribution_metrics,
        overlap_metrics,
    )
    from thesis.workflows.tract_similarity.hcp_loo import _discover_hcp_subjects
    from thesis.workflows.tract_similarity.workflow import apply_threshold

    base = Path(input_dir)
    subjects = _discover_hcp_subjects(base, tractography_relpath)

    # The leave-one-out mean divides by (N - 1), and the reference image is
    # taken from the first subject, so fewer than two subjects can never work
    # regardless of the configured minimum.
    required = max(int(minimum_subjects), 2)
    if len(subjects) < required:
        raise ProcessingError(
            f"HCP LOO requires >= {required} subjects with valid "
            f"fdt+waytotal under {base}, found {len(subjects)}."
        )

    pids = [pid for pid, _ in subjects]
    patient_inputs = [runs for _, runs in subjects]

    # Load reference metadata from the first run for affine, header, shape.
    ref_file = patient_inputs[0][0][0]
    ref_img: nib.Nifti1Image = nib.load(str(ref_file))  # type: ignore[assignment]
    affine = np.asarray(ref_img.affine)
    header = ref_img.header
    _shape = ref_img.shape
    ref_shape: tuple[int, int, int] = (int(_shape[0]), int(_shape[1]), int(_shape[2]))

    try:
        stack = _build_patient_stack(
            patient_inputs, ref_shape, NormalizationMethod(normalization_method)
        )
    except ValueError as exc:
        raise ProcessingError(f"HCP LOO stack build failed under {base}: {exc}") from exc

    n = stack.shape[0]
    # stack[i] is attributed to pids[i] below; refuse to continue if the stack
    # does not carry exactly one volume per subject, otherwise metrics would be
    # written under the wrong patient ID.
    if n != len(pids):
        raise ProcessingError(
            f"HCP LOO stack under {base} holds {n} volumes for {len(pids)} "
            "subjects; cannot align volumes with patient IDs."
        )

    stack_sum = stack.sum(axis=0)
    spacing = voxel_size_mm(affine)

    successful: list = []
    for i, pid in enumerate(pids):
        try:
            vol_subj = stack[i]
            # Leave-one-out mean: remove this subject's contribution from the sum.
            loo_mean = (stack_sum - vol_subj) / float(n - 1)

            thr_subj = apply_threshold(vol_subj, subject_threshold_mode, subject_threshold_value)
            thr_atlas = apply_threshold(loo_mean, atlas_threshold_mode, atlas_threshold_value)
            mask_subj = vol_subj > thr_subj
            mask_atlas = loo_mean > thr_atlas

            overlap: dict = overlap_metrics(mask_subj, mask_atlas)
            correlation: dict = correlation_metrics(vol_subj, loo_mean)
            distance_mm: dict = distance_metrics(mask_subj, mask_atlas, voxel_size_mm=spacing)
            distribution: dict = distribution_metrics(vol_subj, loo_mean, n_bins=int(n_bins))
            voxel_counts: dict = {
                "probtrackx": int(np.count_nonzero(mask_subj)),
                "atlas": int(np.count_nonzero(mask_atlas)),
                "intersection": int(np.count_nonzero(mask_subj & mask_atlas)),
                "union": int(np.count_nonzero(mask_subj | mask_atlas)),
            }
            metrics: dict = {
                "patient_id": pid,
                "thresholds": {
                    "subject": {
                        "mode": subject_threshold_mode,
                        "value": float(subject_threshold_value),
                        "effective": float(thr_subj),
                    },
                    "atlas": {
                        "mode": atlas_threshold_mode,
                        "value": float(atlas_threshold_value),
                        "effective": float(thr_atlas),
                    },
                },
                "voxel_counts": voxel_counts,
                "overlap": overlap,
                "correlation": correlation,
                "distance_mm": distance_mm,
                "distribution": distribution,
            }

            # Optional learned-atlas arm: compared against the same subject
            # volume, thresholded with the subject threshold settings.
            if learned_prediction_relpath:
                pred_path = base / pid / learned_prediction_relpath
                if pred_path.is_file():
                    _pred_obj = nib.load(str(pred_path)).dataobj  # type: ignore[attr-defined]
                    vol_pred = np.asarray(_pred_obj, dtype=np.float32)
                    if vol_pred.shape[:3] == ref_shape:
                        thr_pred = apply_threshold(
                            vol_pred, subject_threshold_mode, subject_threshold_value
                        )
                        mask_pred = vol_pred > thr_pred
                        metrics["learned_atlas"] = {
                            "overlap": overlap_metrics(mask_subj, mask_pred),
                            "correlation": correlation_metrics(vol_subj, vol_pred),
                            "distance_mm": distance_metrics(
                                mask_subj, mask_pred, voxel_size_mm=spacing
                            ),
                            "distribution": distribution_metrics(
                                vol_subj, vol_pred, n_bins=int(n_bins)
                            ),
                            "non_circular": evaluate_prediction(
                                vol_pred,
                                vol_subj,
                                loo_mean,
                                support_threshold=float(learned_support_threshold),
                            ),
                        }
                    else:
                        print(
                            f"[tract_similarity_hcp_loo] WARNING: {pid} learned "
                            f"prediction shape {vol_pred.shape[:3]} != {ref_shape}; "
                            "skipping learned arm",
                            file=sys.stderr,
                        )
                else:
                    print(
                        f"[tract_similarity_hcp_loo] {pid}: no learned prediction "
                        f"at {pred_path}; skipping learned arm",
                        file=sys.stderr,
                    )

            out_dir = base / pid / output_subdir
            out_dir.mkdir(parents=True, exist_ok=True)
            (out_dir / "metrics.json").write_text(json.dumps(metrics, indent=2), encoding="utf-8")

            if write_volumes:
                for name, vol, dtype in (
                    ("subject_normalized", vol_subj, np.float32),
                    ("atlas_normalized", loo_mean, np.float32),
                    ("subject_mask", mask_subj, np.uint8),
                    ("atlas_mask", mask_atlas, np.uint8),
                ):
                    # An explicitly supplied header keeps its own data dtype,
                    # which decides the on-disk type. Copy the reference
                    # header and set the intended dtype so masks are stored
                    # as uint8 and volumes as float32.
                    vol_header = header.copy()
                    vol_header.set_data_dtype(dtype)
                    nib.save(
                        nib.Nifti1Image(np.asarray(vol).astype(dtype), affine, vol_header),
                        str(out_dir / f"{name}.nii.gz"),
                    )

            print(
                f"[tract_similarity_hcp_loo] {pid}: "
                f"dice={overlap['dice']:.4f}, "
                f"voxels_subj={voxel_counts['probtrackx']}, "
                f"voxels_atlas={voxel_counts['atlas']}",
                file=sys.stderr,
            )
            successful.append(pid)
        except Exception as exc:  # noqa: BLE001 — failure isolation per subject
            # Deliberate best-effort: one failing subject must not abort the cohort.
            print(
                f"[tract_similarity_hcp_loo] WARNING: subject {pid} failed: {exc}",
                file=sys.stderr,
            )
            continue
    return successful
def _resolve_input_dir(output_dir: Path) -> Path:
"""Return the directory containing numeric patient subdirectories.
Mirror of atlas/_resolve_input_dir.
When the CLI runs a cohort-level workflow, ``output_dir`` may be
``<base>/cohort`` and the numeric patient data lives one level up.
"""
patient_dirs = [d for d in output_dir.iterdir() if d.is_dir() and d.name.isdigit()]
if not patient_dirs and output_dir.name == "cohort" and output_dir.parent.is_dir():
parent_dirs = [d for d in output_dir.parent.iterdir() if d.is_dir() and d.name.isdigit()]
if parent_dirs:
return output_dir.parent
return output_dir
def verify_requirements(config, context) -> list:
    """Pre-flight checks for tract_similarity_hcp_loo.

    Returns a list of human-readable error strings; an empty list means the
    workflow is ready to build. The CLI surfaces these to the user verbatim.

    Args:
        config: Configuration object; ``config.atlas.tractography_relpath``
            and ``config.tract_similarity.hcp_loo.minimum_subjects`` are read.
        context: Processing context; only ``context.output_dir`` is read.

    Returns:
        A list with at most one error message. The checks run in order —
        ``output_dir`` set, ``output_dir`` exists, at least one valid subject,
        at least ``minimum_subjects`` valid subjects — and the first failure
        is reported.
    """
    if context.output_dir is None:
        return ["output_dir is not set in the processing context"]
    out = Path(context.output_dir)
    if not out.is_dir():
        return [f"output_dir does not exist: {out}"]

    # Cohort runs may point at <base>/cohort; the subjects then live one level up.
    input_dir = _resolve_input_dir(out)
    tractography_relpath = config.atlas.tractography_relpath
    minimum = int(config.tract_similarity.hcp_loo.minimum_subjects)

    subjects = _discover_hcp_subjects(input_dir, tractography_relpath)
    if not subjects:
        return [
            f"No numeric (HCP) subject directories with valid "
            f"'{tractography_relpath}/<run>/warped_streamlines/fdt_paths.nii.gz' "
            f"+ 'waytotal' found under {input_dir}. Run the upstream tractography "
            f"workflow first."
        ]
    if len(subjects) < minimum:
        return [
            f"HCP LOO requires at least N={minimum} subjects with valid fdt+waytotal "
            f"under {input_dir}, found {len(subjects)}. Run the upstream tractography "
            f"workflow first or lower tract_similarity.hcp_loo.minimum_subjects."
        ]
    return []
# Input port names for the Nipype ``Function`` node that wraps
# ``_hcp_loo_task``; passed as ``input_names`` in ``build_workflow``.
# The entries correspond one-to-one, in order, to the parameters of
# ``_hcp_loo_task`` — keep the two in sync when either changes.
_HCP_LOO_INPUTS = [
"input_dir",
"tractography_relpath",
"normalization_method",
"output_subdir",
"subject_threshold_mode",
"subject_threshold_value",
"atlas_threshold_mode",
"atlas_threshold_value",
"n_bins",
"minimum_subjects",
"write_volumes",
"learned_prediction_relpath",
"learned_support_threshold",
# Declared as a port but never assigned in build_workflow.
"_ready_anchor",
]
@workflow(
    name="tract_similarity_hcp_loo",
    description=(
        "Per-HCP-subject leave-one-out tract similarity vs the cohort atlas. "
        "Emits metrics.json + optional NIfTI volumes per subject in the same "
        "format as the per-patient tract_similarity workflow."
    ),
    protocol="hcp",
    scope="cohort",
)
@verify(verify_requirements)
def build_workflow(*, config, context) -> "pe.Workflow":
    """Build the cohort-scope HCP-LOO tract similarity workflow.

    The workflow consists of a single ``Function`` node,
    ``compute_hcp_loo_metrics``, that runs ``_hcp_loo_task`` with its inputs
    populated from ``config.tract_similarity`` and ``config.atlas``.

    Args:
        config: Configuration object; the ``tract_similarity`` and ``atlas``
            sections are read.
        context: Processing context; ``output_dir`` is required and
            ``working_dir``, when truthy, becomes the workflow base directory.

    Returns:
        The assembled ``pe.Workflow`` named ``tract_similarity_hcp_loo``.

    Raises:
        ValueError: If ``context.output_dir`` is ``None``.
    """
    if context.output_dir is None:
        raise ValueError("output_dir must be set before building the workflow")

    wf = pe.Workflow(name="tract_similarity_hcp_loo")
    if context.working_dir:
        wf.base_dir = str(context.working_dir)

    # Cohort runs may point at <base>/cohort; the subjects then live one level up.
    input_dir = _resolve_input_dir(Path(context.output_dir))
    ts = config.tract_similarity
    atlas_cfg = config.atlas

    node = pe.Node(
        Function(
            input_names=_HCP_LOO_INPUTS,
            output_names=["successful_subjects"],
            function=_hcp_loo_task,
        ),
        name="compute_hcp_loo_metrics",
    )
    # Values are coerced to plain str/float/int/bool before being assigned to
    # the node inputs. ``_ready_anchor`` is intentionally left unset here.
    node.inputs.input_dir = str(input_dir)
    node.inputs.tractography_relpath = atlas_cfg.tractography_relpath
    node.inputs.normalization_method = atlas_cfg.normalization_method.value
    node.inputs.output_subdir = ts.output_subdir
    node.inputs.subject_threshold_mode = ts.subject_threshold.mode
    node.inputs.subject_threshold_value = float(ts.subject_threshold.value)
    node.inputs.atlas_threshold_mode = ts.atlas_threshold.mode
    node.inputs.atlas_threshold_value = float(ts.atlas_threshold.value)
    node.inputs.n_bins = int(ts.n_bins)
    node.inputs.minimum_subjects = int(ts.hcp_loo.minimum_subjects)
    node.inputs.write_volumes = bool(ts.hcp_loo.write_volumes)
    node.inputs.learned_prediction_relpath = ts.hcp_loo.learned_prediction_relpath
    node.inputs.learned_support_threshold = float(ts.hcp_loo.learned_support_threshold)

    wf.add_nodes([node])
    logger.info(
        "Built tract_similarity_hcp_loo workflow rooted at {} (min_subjects={}, write_volumes={})",
        input_dir,
        ts.hcp_loo.minimum_subjects,
        ts.hcp_loo.write_volumes,
    )
    return wf