"""Extended QC checks for pipeline outputs.
Covers per-target connectivity maps, SynthSeg quality, ROI transform
comparison, waypoints validation, brain mask overlay, warp field
Jacobian analysis, BedpostX fibre quality, and cross-subject outlier
detection.
All plotting uses the ``Agg`` backend for headless rendering.
"""
import csv
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
from thesis.workflows.qc._plotting import configure_headless_matplotlib
configure_headless_matplotlib()
import matplotlib.pyplot as plt # noqa: E402
import numpy as np # noqa: E402
from thesis.core.logging import get_logger # noqa: E402
logger = get_logger(__name__)
# Optional dependencies. Functions in this module consult these flags and
# skip work (typically returning an empty result) when a backend is missing.
try:
    import nibabel as nib

    NIBABEL_AVAILABLE = True
except ImportError:  # pragma: no cover
    NIBABEL_AVAILABLE = False
# nilearn is guarded with a broad ``Exception`` (not only ImportError) so that
# any failure while importing its plotting stack merely disables the overlay
# helpers instead of breaking import of this module.
try:
    from nilearn.plotting import plot_roi, plot_stat_map

    NILEARN_AVAILABLE = True
except Exception as exc:  # pragma: no cover
    logger.debug("Failed to import nilearn plotting support: {}", exc)
    NILEARN_AVAILABLE = False
# Public API of this module (names exported by ``from ... import *``).
# NOTE(review): "generate_connectivity_map_figures" and
# "generate_roi_transform_comparison" are not defined in the portion of this
# file under review (section 4 below has a header but no body). Confirm they
# exist; otherwise ``from <module> import *`` raises AttributeError.
__all__ = [
    "generate_connectivity_map_figures",
    "collect_connectivity_map_stats",
    "parse_synthseg_qc_csv",
    "parse_synthseg_volumes_csv",
    "generate_synthseg_overlay",
    "generate_roi_transform_comparison",
    "validate_waypoints_file",
    "generate_brain_mask_overlay",
    "compute_jacobian_stats",
    "generate_bedpostx_overlay",
    "generate_waytotal_overlay",
    "detect_batch_outliers",
]
# ---------------------------------------------------------------------------
# 1. Per-target connectivity maps
# ---------------------------------------------------------------------------
[docs]
def collect_connectivity_map_stats(
    tractography_dir: Union[str, Path],
) -> Dict[str, Dict[str, Any]]:
    """Collect density statistics for each ``seeds_to_<target>`` map.

    The target name is the file name with the leading ``seeds_to_`` prefix
    and the trailing ``.nii.gz``/``.nii`` extension removed. Occurrences of
    those substrings elsewhere in the name are preserved.

    Args:
        tractography_dir: ProbTrackX2 output directory.

    Returns:
        Mapping of target name to stats dict (``nonzero_voxels``,
        ``mean_density``, ``max_density``, ``total_streamlines``), computed
        over voxels with a strictly positive value. Empty if nibabel is
        unavailable or no maps are found.
    """
    if not NIBABEL_AVAILABLE:
        return {}
    prefix = "seeds_to_"
    base = Path(tractography_dir)
    results: Dict[str, Dict[str, Any]] = {}
    for seed_file in sorted(base.glob(prefix + "*.nii*")):
        # Strip only the known prefix and NIfTI suffix. A blanket
        # ``str.replace`` would also mangle targets whose names happen to
        # contain "seeds_to_" or ".nii".
        name = seed_file.name
        for suffix in (".nii.gz", ".nii"):
            if name.endswith(suffix):
                name = name[: -len(suffix)]
                break
        target = name[len(prefix):]
        img = nib.load(str(seed_file))
        data: np.ndarray = np.asarray(img.dataobj)  # type: ignore[attr-defined]
        nz = data[data > 0]
        has_voxels = nz.size > 0
        results[target] = {
            "nonzero_voxels": int(nz.size),
            "mean_density": round(float(np.mean(nz)), 2) if has_voxels else 0.0,
            "max_density": int(np.max(nz)) if has_voxels else 0,
            "total_streamlines": int(np.sum(nz)) if has_voxels else 0,
        }
    return results
# ---------------------------------------------------------------------------
# 2. SynthSeg QC score parsing
# ---------------------------------------------------------------------------
[docs]
def parse_synthseg_qc_csv(
    patient_output: Union[str, Path],
    threshold: float = 0.6,
) -> Dict[str, Any]:
    """Parse a SynthSeg ``_qc.csv`` file and flag low-quality subjects.

    Columns whose header contains ``"qc"`` are preferred. If none yield a
    number, every column except ``subject`` is used. This way a numeric
    subject identifier is never mistaken for a QC score. Cells that do not
    parse as a finite float (names, blanks, NaN, inf) are ignored.

    Args:
        patient_output: Patient-level output directory.
        threshold: Minimum acceptable mean QC score (0-1).

    Returns:
        Dict with ``path``, ``scores`` (list of floats), ``mean_score``,
        ``passed`` (bool), and ``threshold``. Empty dict if the directory or
        file is missing, unreadable, or contains no usable scores.
    """
    import math

    seg_dir = Path(patient_output) / "segmentation" / "synthseg"
    if not seg_dir.is_dir():
        return {}
    # Sort so the chosen file is deterministic when several match.
    qc_files = sorted(seg_dir.glob("*_qc.csv"))
    if not qc_files:
        return {}
    qc_file = qc_files[0]
    try:
        with open(qc_file, newline="", encoding="utf-8") as f:
            rows = [row for row in csv.reader(f) if row]
    except (OSError, UnicodeDecodeError, csv.Error) as exc:
        logger.debug("Could not parse SynthSeg QC CSV {}: {}", qc_file, exc)
        return {}
    if len(rows) < 2:  # need a header plus at least one data row
        return {}
    header = [name.strip().lower() for name in rows[0]]
    data_rows = rows[1:]

    def _collect(keep: Any) -> List[float]:
        """Return finite floats from cells whose column index satisfies *keep*."""
        values: List[float] = []
        for row in data_rows:
            for idx, cell in enumerate(row):
                if not keep(idx):
                    continue
                try:
                    value = float(cell.strip())
                except ValueError:
                    continue
                if math.isfinite(value):
                    values.append(value)
        return values

    qc_columns = {idx for idx, name in enumerate(header) if "qc" in name}
    scores = _collect(qc_columns.__contains__)
    if not scores:
        # SynthSeg-style files have one score column per structure group plus
        # a "subject" column. Use everything except the subject identifier.
        subject_columns = {idx for idx, name in enumerate(header) if name == "subject"}
        scores = _collect(lambda idx: idx not in subject_columns)
    if not scores:
        return {}
    mean_score = float(np.mean(scores))
    return {
        "path": str(qc_file),
        "scores": scores,
        "mean_score": round(mean_score, 4),
        "passed": mean_score >= threshold,
        "threshold": threshold,
    }
# ---------------------------------------------------------------------------
# 3. SynthSeg segmentation overlay
# ---------------------------------------------------------------------------
[docs]
def generate_synthseg_overlay(
    patient_output: Union[str, Path],
    background_image: Union[str, Path],
    output_dir: Union[str, Path],
) -> List[Path]:
    """Generate a colour-coded SynthSeg label overlay on T1w.

    One PNG is written per ``*_synthseg.nii*`` file, named
    ``<basename-without-NIfTI-extension>_overlay.png`` under
    ``<output_dir>/synthseg``.

    Args:
        patient_output: Patient-level output directory.
        background_image: T1w anatomical background.
        output_dir: Directory for output PNGs.

    Returns:
        List of generated PNG paths (empty if nilearn/nibabel are missing or
        no segmentation is found).
    """
    if not NILEARN_AVAILABLE or not NIBABEL_AVAILABLE:
        return []
    seg_dir = Path(patient_output) / "segmentation" / "synthseg"
    if not seg_dir.is_dir():
        return []
    seg_files = sorted(seg_dir.glob("*_synthseg.nii*"))
    if not seg_files:
        return []
    out = Path(output_dir) / "synthseg"
    out.mkdir(parents=True, exist_ok=True)
    generated: List[Path] = []
    try:
        for seg_file in seg_files:
            # ``Path.stem`` drops only the last suffix ("x.nii.gz" -> "x.nii"),
            # so strip the full NIfTI extension explicitly.
            stem = seg_file.name
            for suffix in (".nii.gz", ".nii"):
                if stem.endswith(suffix):
                    stem = stem[: -len(suffix)]
                    break
            out_path = out / f"{stem}_overlay.png"
            display = plot_roi(
                roi_img=str(seg_file),
                bg_img=str(background_image),
                display_mode="ortho",
                title="SynthSeg segmentation",
            )
            try:
                display.savefig(str(out_path), dpi=150)
            finally:
                display.close()
            generated.append(out_path)
    finally:
        # Release matplotlib figures even if plotting/saving failed.
        plt.close("all")
    logger.info("Generated {} SynthSeg overlay(s)", len(generated))
    return generated
# ---------------------------------------------------------------------------
# 4. Pre- vs post-transform ROI comparison
# ---------------------------------------------------------------------------
# ---------------------------------------------------------------------------
# 5. Waypoints file validation
# ---------------------------------------------------------------------------
[docs]
def validate_waypoints_file(
    tractography_dir: Union[str, Path],
    reference_image: Optional[Union[str, Path]] = None,
) -> Dict[str, Any]:
    """Validate waypoints text files.

    Recursively finds ``waypoints*.txt`` and ``merged_waypoints*.txt`` under
    *tractography_dir*. It checks that every non-blank line (a path) exists
    and, if a reference image is given and nibabel is available, that each
    path's first three dimensions match the reference.

    Args:
        tractography_dir: ProbTrackX2 output directory.
        reference_image: Optional reference NIfTI (e.g. brain mask)
            to check dimension compatibility. If it cannot be loaded, a
            warning is logged and only existence is checked.

    Returns:
        Dict with ``files_checked``, ``total_paths``, ``missing`` (list of
        path strings), ``dimension_mismatches`` (list of dicts with
        ``path``/``expected``/``actual``), and ``valid`` (False if anything
        is missing or mismatched).
    """
    base = Path(tractography_dir)
    result: Dict[str, Any] = {
        "files_checked": 0,
        "total_paths": 0,
        "missing": [],
        "dimension_mismatches": [],
        "valid": True,
    }
    # Load the reference shape if provided; a broken reference should not
    # abort the existence checks.
    ref_shape: Optional[tuple] = None
    if reference_image and NIBABEL_AVAILABLE:
        ref_path = Path(reference_image)
        if ref_path.exists():
            try:
                ref_shape = nib.load(str(ref_path)).shape[:3]  # type: ignore[attr-defined]
            except Exception as exc:
                logger.warning("Could not load reference image {}: {}", ref_path, exc)
    for pattern in ("waypoints*.txt", "merged_waypoints*.txt"):
        # Sorted for a deterministic order of reported paths.
        for wp_file in sorted(base.rglob(pattern)):
            result["files_checked"] += 1
            try:
                lines = wp_file.read_text().splitlines()
            except OSError as exc:
                logger.debug("Could not read waypoints file {}: {}", wp_file, exc)
                continue
            for line in lines:
                entry = line.strip()
                if not entry:
                    continue
                result["total_paths"] += 1
                nii_path = Path(entry)
                if not nii_path.exists():
                    result["missing"].append(str(nii_path))
                    result["valid"] = False
                    continue
                if not (ref_shape and NIBABEL_AVAILABLE):
                    continue
                try:
                    wp_shape = nib.load(str(nii_path)).shape[:3]  # type: ignore[attr-defined]
                except Exception as exc:
                    # Best-effort: an unreadable image is logged, not fatal.
                    logger.debug("Could not load waypoint image {}: {}", nii_path, exc)
                    continue
                if wp_shape != ref_shape:
                    result["dimension_mismatches"].append(
                        {
                            "path": str(nii_path),
                            "expected": ref_shape,
                            "actual": wp_shape,
                        }
                    )
                    result["valid"] = False
    return result
# ---------------------------------------------------------------------------
# 6. Brain mask overlay
# ---------------------------------------------------------------------------
[docs]
def generate_brain_mask_overlay(
    brain_mask: Union[str, Path],
    background_image: Union[str, Path],
    output_dir: Union[str, Path],
) -> List[Path]:
    """Generate an overlay of the brain mask on an anatomical image.

    Writes ``<output_dir>/brain_mask/brain_mask_overlay.png``.

    Args:
        brain_mask: Path to the brain mask NIfTI.
        background_image: Anatomical background (T1w or DWI b0).
        output_dir: Directory for output PNGs.

    Returns:
        List containing the generated PNG path, or an empty list if nilearn
        is unavailable or the mask does not exist.
    """
    if not NILEARN_AVAILABLE:
        return []
    mask_path = Path(brain_mask)
    if not mask_path.exists():
        return []
    out = Path(output_dir) / "brain_mask"
    out.mkdir(parents=True, exist_ok=True)
    out_path = out / "brain_mask_overlay.png"
    try:
        display = plot_roi(
            roi_img=str(mask_path),
            bg_img=str(background_image),
            display_mode="ortho",
            title="Brain mask",
            cmap="Greens",
        )
        try:
            display.savefig(str(out_path), dpi=150)
        finally:
            display.close()
    finally:
        # Release matplotlib figures even if plotting/saving failed.
        plt.close("all")
    logger.info("Generated brain mask overlay: {}", out_path)
    return [out_path]
# ---------------------------------------------------------------------------
# 7. SynthSeg volumes sanity check
# ---------------------------------------------------------------------------
# Approximate expected volume fractions relative to total intracranial volume.
# Keys are lower-case structure names compared against the lower-cased column
# names of the SynthSeg volumes CSV in ``parse_synthseg_volumes_csv``. Values
# are ``(low, high)``: a warning is emitted when a matched column's fraction
# is below ``low`` or above ``high``.
# NOTE(review): these are heuristic bounds with no recorded source; confirm
# them against a normative dataset before relying on the warnings.
_VOLUME_BOUNDS: Dict[str, Tuple[float, float]] = {
    "lateral ventricle": (0.005, 0.15),
    "cerebral white matter": (0.15, 0.45),
    "cerebral cortex": (0.20, 0.55),
    "thalamus": (0.003, 0.03),
    "caudate": (0.001, 0.02),
    "putamen": (0.001, 0.025),
    "hippocampus": (0.001, 0.02),
    "cerebellum cortex": (0.03, 0.15),
    "brain-stem": (0.01, 0.06),
}
[docs]
def parse_synthseg_volumes_csv(
    patient_output: Union[str, Path],
) -> Dict[str, Any]:
    """Parse SynthSeg volumes CSV and flag anatomically unreasonable values.

    Column names are stripped and lower-cased. Cells that do not parse as a
    finite float (e.g. the subject name) are ignored. If the CSV has several
    data rows, later rows overwrite earlier values for the same column.

    Volume fractions use the ``total intracranial`` column as denominator
    when it is present and positive. SynthSeg volume CSVs are expected to
    carry that column next to the per-structure volumes, so summing every
    column would roughly double the denominator. Without it, the sum of all
    numeric columns is used.

    A bound from ``_VOLUME_BOUNDS`` applies to a column whose name is the
    structure name itself or that name prefixed with ``"left "`` or
    ``"right "``. Exact matching keeps e.g. the much smaller
    ``"left inferior lateral ventricle"`` from being judged against the
    ``"lateral ventricle"`` bounds.

    Args:
        patient_output: Patient-level output directory.

    Returns:
        Dict with ``path``, ``volumes`` (dict of structure -> volume,
        including ``total intracranial`` when present), ``total_volume``
        (the denominator used for fractions), and ``warnings`` (list of issue
        strings). Empty dict if no readable CSV with numeric values is found.
    """
    import math

    seg_dir = Path(patient_output) / "segmentation" / "synthseg"
    if not seg_dir.is_dir():
        return {}
    # Sort so the chosen file is deterministic when several match.
    vol_files = sorted(seg_dir.glob("*_volumes.csv"))
    if not vol_files:
        return {}
    vol_file = vol_files[0]
    volumes: Dict[str, float] = {}
    try:
        with open(vol_file, newline="") as f:
            for row in csv.DictReader(f):
                for key, val in row.items():
                    # ``key`` is None for surplus cells in over-long rows.
                    name = key.strip().lower() if key else ""
                    if not name:
                        continue
                    try:
                        value = float(val)
                    except (ValueError, TypeError):
                        continue
                    if math.isfinite(value):
                        volumes[name] = value
    except (OSError, UnicodeDecodeError, csv.Error) as exc:
        logger.debug("Could not parse SynthSeg volumes CSV {}: {}", vol_file, exc)
        return {}
    if not volumes:
        return {}
    tiv_key = "total intracranial"
    structures = {k: v for k, v in volumes.items() if k != tiv_key}
    tiv = volumes.get(tiv_key, 0.0)
    total = tiv if tiv > 0 else sum(structures.values())
    if total <= 0:
        return {
            "path": str(vol_file),
            "volumes": volumes,
            "total_volume": 0,
            "warnings": ["Total volume is zero"],
        }
    volume_warnings: List[str] = []
    # Check volume fractions
    for structure, (low, high) in _VOLUME_BOUNDS.items():
        accepted = {structure, f"left {structure}", f"right {structure}"}
        for vol_key, vol_val in structures.items():
            if vol_key not in accepted:
                continue
            frac = vol_val / total
            if frac < low:
                volume_warnings.append(f"{vol_key}: {frac:.3f} of total (expected >= {low})")
            elif frac > high:
                volume_warnings.append(f"{vol_key}: {frac:.3f} of total (expected <= {high})")
    # Hemisphere asymmetry check (cerebral + cerebellar white matter per side)
    left_wm = sum(v for k, v in structures.items() if "left" in k and "white" in k)
    right_wm = sum(v for k, v in structures.items() if "right" in k and "white" in k)
    if left_wm > 0 and right_wm > 0:
        ratio = max(left_wm, right_wm) / min(left_wm, right_wm)
        if ratio > 2.0:
            volume_warnings.append(f"Hemisphere WM asymmetry ratio: {ratio:.2f} (>2.0)")
    return {
        "path": str(vol_file),
        "volumes": volumes,
        "total_volume": round(total, 1),
        "warnings": volume_warnings,
    }
# ---------------------------------------------------------------------------
# 8. Cross-subject outlier detection
# ---------------------------------------------------------------------------
[docs]
def detect_batch_outliers(
    stats_list: List[Dict[str, Any]],
    sd_threshold: float = 2.0,
) -> List[Dict[str, Any]]:
    """Flag subjects whose metrics are outliers relative to the batch.

    Three metrics are checked, in this order: ``waytotal`` (top-level key),
    ``subject_space.nonzero_voxels`` and ``subject_space.volume_fraction``.
    A metric is only evaluated when at least 3 subjects provide it. A
    subject is flagged when ``|value - mean| / std > sd_threshold``, using
    the population standard deviation (``np.std``, ddof=0). Records whose
    ``subject_space`` is missing or ``None`` are skipped for the nested
    metrics.

    Args:
        stats_list: Batch statistics from :func:`collect_batch_stats`.
        sd_threshold: Number of standard deviations to flag.

    Returns:
        List of dicts, each with ``patient_id``, ``metric``,
        ``value``, ``mean``, ``std``, ``z_score``.
    """
    if len(stats_list) < 3:
        return []

    def _subject_space(record: Dict[str, Any]) -> Dict[str, Any]:
        # ``subject_space`` may be absent or explicitly None.
        return record.get("subject_space") or {}

    # (metric name, value extractor, decimals used to round mean/std)
    metrics = (
        ("waytotal", lambda s: s.get("waytotal"), 1),
        ("nonzero_voxels", lambda s: _subject_space(s).get("nonzero_voxels"), 1),
        ("volume_fraction", lambda s: _subject_space(s).get("volume_fraction"), 6),
    )
    outliers: List[Dict[str, Any]] = []
    for metric, extract, ndigits in metrics:
        pairs = [(s["patient_id"], extract(s)) for s in stats_list if extract(s) is not None]
        outliers.extend(_batch_zscore_outliers(pairs, metric, sd_threshold, ndigits))
    return outliers


def _batch_zscore_outliers(
    pairs: List[Tuple[Any, Any]],
    metric: str,
    sd_threshold: float,
    ndigits: int,
) -> List[Dict[str, Any]]:
    """Return outlier records for one metric's ``(patient_id, value)`` pairs."""
    if len(pairs) < 3:
        return []
    vals = np.array([v for _, v in pairs])
    mean, std = float(np.mean(vals)), float(np.std(vals))
    # ``not std > 0`` also rejects NaN (e.g. from NaN inputs).
    if not std > 0:
        return []
    flagged: List[Dict[str, Any]] = []
    for pid, val in pairs:
        z = abs(val - mean) / std
        if z > sd_threshold:
            flagged.append(
                {
                    "patient_id": pid,
                    "metric": metric,
                    "value": val,
                    "mean": round(mean, ndigits),
                    "std": round(std, ndigits),
                    "z_score": round(z, 2),
                }
            )
    return flagged
# ---------------------------------------------------------------------------
# 9. Warp field Jacobian analysis
# ---------------------------------------------------------------------------
[docs]
def compute_jacobian_stats(
    warp_field: Union[str, Path],
    output_dir: Optional[Union[str, Path]] = None,
) -> Dict[str, Any]:
    """Compute Jacobian determinant statistics from a displacement warp field.

    At each voxel the Jacobian is ``I + grad(u)``, where ``u`` is the
    displacement field. Finite differences (``np.gradient``) are divided by
    the voxel size from the image header. The displacement is therefore
    assumed to be in the same physical units as the voxel size (mm for
    ANTs-style warps; confirm for other tools).

    Args:
        warp_field: Path to a 4D ``(x, y, z, 3)`` or 5D ``(x, y, z, 1, 3)``
            warp field NIfTI (e.g. from ANTs).
        output_dir: If provided *and* at least one voxel has a negative
            determinant, the determinant volume is written to
            ``<output_dir>/jacobian/jacobian_determinant.nii.gz``.

    Returns:
        Dict with ``negative_voxels``, ``negative_fraction``,
        ``min_jacobian``, ``max_jacobian``, ``mean_jacobian`` and, when a
        volume was written, ``jacobian_image``. Empty dict if nibabel is
        unavailable, or if the warp field is missing, unreadable, mis-shaped,
        or has fewer than two voxels along any spatial axis.
    """
    if not NIBABEL_AVAILABLE:
        return {}
    warp_path = Path(warp_field)
    if not warp_path.exists():
        return {}
    try:
        img = nib.load(str(warp_path))
        data: np.ndarray = np.asarray(img.dataobj)  # type: ignore[attr-defined]
        zooms = tuple(img.header.get_zooms())  # type: ignore[attr-defined]
    except Exception as exc:
        logger.debug("Could not load warp field {}: {}", warp_path, exc)
        return {}
    # Handle both 4D (x,y,z,3) and 5D (x,y,z,1,3) warp fields
    if data.ndim == 5:
        data = data[:, :, :, 0, :]
    if data.ndim != 4 or data.shape[-1] != 3:
        logger.debug("Warp field has unexpected shape: {}", data.shape)
        return {}
    # np.gradient needs at least two samples along every differentiated axis.
    if min(data.shape[:3]) < 2:
        logger.debug("Warp field too small for finite differences: {}", data.shape)
        return {}
    # Voxel spacing for the three spatial axes; non-positive or missing
    # zooms (broken headers) fall back to unit spacing.
    spacing = [float(z) if float(z) > 0 else 1.0 for z in zooms[:3]]
    spacing += [1.0] * (3 - len(spacing))
    # NOTE(review): displacement components are treated as aligned with the
    # voxel axes. Warps whose affine rotates/flips axes relative to the
    # displacement frame (e.g. LPS vs RAS) are not re-oriented here.
    J = np.zeros(data.shape[:3] + (3, 3), dtype=np.float64)
    for i in range(3):
        grads = np.gradient(data[..., i].astype(np.float64), *spacing)
        for j, grad in enumerate(grads):
            J[..., i, j] = grad
        J[..., i, i] += 1.0  # identity + displacement gradient
    jac = np.linalg.det(J)
    total = int(jac.size)
    neg_count = int(np.sum(jac < 0))
    result: Dict[str, Any] = {
        "negative_voxels": neg_count,
        "negative_fraction": round(neg_count / total, 6) if total > 0 else 0.0,
        "min_jacobian": round(float(np.min(jac)), 4),
        "max_jacobian": round(float(np.max(jac)), 4),
        "mean_jacobian": round(float(np.mean(jac)), 4),
    }
    # The determinant map is only persisted when folding was detected.
    if output_dir and neg_count > 0:
        out = Path(output_dir) / "jacobian"
        out.mkdir(parents=True, exist_ok=True)
        jac_img = nib.Nifti1Image(jac.astype(np.float32), img.affine)  # type: ignore[attr-defined]
        jac_path = out / "jacobian_determinant.nii.gz"
        nib.save(jac_img, jac_path)
        result["jacobian_image"] = str(jac_path)
    return result
# ---------------------------------------------------------------------------
# 10. BedpostX fibre quality
# ---------------------------------------------------------------------------
[docs]
def generate_bedpostx_overlay(
    bedpostx_dir: Union[str, Path],
    background_image: Union[str, Path],
    output_dir: Union[str, Path],
) -> Tuple[List[Path], Dict[str, Any]]:
    """Generate an overlay of f1 (primary fibre fraction) and report stats.

    Statistics are computed over voxels with ``f1 > 0.01``. The overlay is
    written to ``<output_dir>/bedpostx/f1_overlay.png``.

    Args:
        bedpostx_dir: BedpostX output directory containing
            ``mean_f1samples.nii.gz``.
        background_image: Anatomical background.
        output_dir: Directory for output PNGs.

    Returns:
        Tuple of (list of PNG paths, stats dict with ``mean_f1``,
        ``median_f1``, ``min_f1``, ``max_f1``, ``low_f1_fraction``). Returns
        ``([], {})`` if nibabel/nilearn are missing or the f1 file is absent.
    """
    if not NIBABEL_AVAILABLE or not NILEARN_AVAILABLE:
        return [], {}
    f1_path = Path(bedpostx_dir) / "mean_f1samples.nii.gz"
    if not f1_path.exists():
        return [], {}
    img = nib.load(str(f1_path))
    data: np.ndarray = np.asarray(img.dataobj)  # type: ignore[attr-defined]
    # Exclude near-zero voxels (background / no fibre estimate); extract once.
    vals = data[data > 0.01]
    if vals.size:
        stats: Dict[str, Any] = {
            "mean_f1": round(float(np.mean(vals)), 4),
            "median_f1": round(float(np.median(vals)), 4),
            "min_f1": round(float(np.min(vals)), 4),
            "max_f1": round(float(np.max(vals)), 4),
            "low_f1_fraction": round(float(np.count_nonzero(vals < 0.1) / vals.size), 4),
        }
    else:
        stats = {
            "mean_f1": 0.0,
            "median_f1": 0.0,
            "min_f1": 0.0,
            "max_f1": 0.0,
            "low_f1_fraction": 0.0,
        }
    out = Path(output_dir) / "bedpostx"
    out.mkdir(parents=True, exist_ok=True)
    out_path = out / "f1_overlay.png"
    try:
        display = plot_stat_map(
            stat_map_img=str(f1_path),
            bg_img=str(background_image),
            threshold=0.05,
            display_mode="ortho",
            title=f"BedpostX f1 (mean={stats['mean_f1']:.3f})",
            colorbar=True,
        )
        try:
            display.savefig(str(out_path), dpi=150)
        finally:
            display.close()
    finally:
        # Release matplotlib figures even if plotting/saving failed.
        plt.close("all")
    logger.info("Generated BedpostX f1 overlay: {}", out_path)
    return [out_path], stats
# ---------------------------------------------------------------------------
# 11. Template-space waytotal overlay (reuses track density infrastructure)
# ---------------------------------------------------------------------------
[docs]
def generate_waytotal_overlay(
    tractography_dir: Union[str, Path],
    template_image: Union[str, Path],
    output_dir: Union[str, Path],
) -> List[Path]:
    """Generate overlay for template-space ``waytotal.nii.gz``.

    Reads ``<tractography_dir>/warped_streamlines/waytotal.nii.gz``. It
    thresholds the display at the median of the positive voxels (0.0 if
    there are none) and writes ``<output_dir>/normtracks/template_waytotal.png``.

    Args:
        tractography_dir: ProbTrackX2 output directory.
        template_image: Template background image.
        output_dir: Directory for output PNGs.

    Returns:
        List of generated PNG paths (empty if nilearn/nibabel are missing or
        the waytotal image is absent).
    """
    if not NILEARN_AVAILABLE or not NIBABEL_AVAILABLE:
        return []
    wt_path = Path(tractography_dir) / "warped_streamlines" / "waytotal.nii.gz"
    if not wt_path.exists():
        return []
    out = Path(output_dir) / "normtracks"
    out.mkdir(parents=True, exist_ok=True)
    out_path = out / "template_waytotal.png"
    img = nib.load(str(wt_path))
    data: np.ndarray = np.asarray(img.dataobj)  # type: ignore[attr-defined]
    nz = data[data > 0]
    thresh = float(np.percentile(nz, 50)) if nz.size > 0 else 0.0
    try:
        display = plot_stat_map(
            stat_map_img=str(wt_path),
            bg_img=str(template_image),
            threshold=thresh,
            display_mode="ortho",
            title="Waytotal (template space, >= 50th pct)",
            colorbar=True,
        )
        try:
            display.savefig(str(out_path), dpi=150)
        finally:
            display.close()
    finally:
        # Release matplotlib figures even if plotting/saving failed.
        plt.close("all")
    logger.info("Generated template waytotal overlay: {}", out_path)
    return [out_path]