"""Nipype workflow definition for the Atlas generator."""
from pathlib import Path
from typing import List
import nipype.pipeline.engine as pe
from nipype.interfaces.utility import Function
from thesis.core.config import PipelineConfig
from thesis.core.context import ProcessingContext
from thesis.core.decorators import produces, verify, workflow
from thesis.core.logging import get_logger
from thesis.core.path_declarations import CohortDir
logger = get_logger(__name__)

# Input names of the ``generate_atlas`` Nipype Function node assembled in
# ``build_workflow``. Every entry must match a parameter name of
# ``_atlas_task`` (the function that node wraps), and each one is assigned on
# ``atlas_node.inputs`` before the workflow runs.
_ATLAS_INPUTS = [
    "input_dir",  # directory holding the numeric patient subdirectories
    "output_dir",  # cohort output directory (atlas and QC are written below it)
    "atlas_config",  # keyword arguments for AtlasParams
    "tractography_relpath",  # forwarded unchanged to generate_atlas_with_qc
    "output_subdir",  # atlas subdirectory relative to output_dir
    "qc_config",  # keyword arguments for QcParams, or None to skip QC
]
def _resolve_input_dir(output_dir: Path) -> Path:
"""Return the directory containing numeric patient subdirectories.
When the CLI runs a cohort-level workflow, ``output_dir`` is
``<base>/cohort`` and patient data lives one level up.
"""
patient_dirs = [d for d in output_dir.iterdir() if d.is_dir() and d.name.isdigit()]
if not patient_dirs and output_dir.name == "cohort" and output_dir.parent.is_dir():
parent_dirs = [d for d in output_dir.parent.iterdir() if d.is_dir() and d.name.isdigit()]
if parent_dirs:
return output_dir.parent
return output_dir
def _atlas_task(
    input_dir: str,
    output_dir: str,
    atlas_config: dict,
    tractography_relpath: str,
    output_subdir: str,
    qc_config: dict | None = None,
) -> list:
    """Nipype Function body: generate the statistical atlas and (optional) QC.

    The cohort stack is loaded once and both the statistical-atlas outputs and
    the QC outputs are derived from that single in-memory pass. When
    ``qc_config`` is ``None`` (or empty) only the statistical atlas is produced.

    All imports live inside the body because Nipype's ``Function`` interface
    rebuilds this function from its source text, so it cannot rely on
    module-level names.

    Args:
        input_dir: Directory holding the patient subdirectories; passed to
            ``generate_atlas_with_qc`` as its ``output_dir`` argument.
        output_dir: Cohort output directory. The atlas is written to
            ``output_dir / output_subdir``; it is also the QC output directory
            when QC is enabled.
        atlas_config: Keyword arguments for ``AtlasParams``. A string
            ``normalization_method`` is converted to ``NormalizationMethod``.
        tractography_relpath: Forwarded unchanged to ``generate_atlas_with_qc``.
        output_subdir: Atlas subdirectory relative to ``output_dir``.
        qc_config: Keyword arguments for ``QcParams``; falsy disables QC.

    Returns:
        Whatever ``generate_atlas_with_qc`` returns (exposed by the node as
        ``generated_files``).

    Raises:
        ProcessingError: Wrapping any exception raised while building the
            atlas/QC parameter objects or while generating the outputs.
    """
    from pathlib import Path

    from thesis.core.exceptions import ProcessingError
    from thesis.workflows.atlas._params import NormalizationMethod
    from thesis.workflows.atlas.compute import (
        AtlasParams,
        QcParams,
        generate_atlas_with_qc,
    )

    try:
        # Parameter construction is inside the ``try`` so that invalid atlas
        # *and* QC settings are reported uniformly as ProcessingError.
        params = dict(atlas_config)
        method = params.get("normalization_method")
        if isinstance(method, str):
            # The workflow hands over the enum's plain value; restore the enum.
            params["normalization_method"] = NormalizationMethod(method)
        qc_params = QcParams(**qc_config) if qc_config else None
        return generate_atlas_with_qc(
            output_dir=input_dir,
            atlas_dir=str(Path(output_dir) / output_subdir),
            params=AtlasParams(**params),
            qc_params=qc_params,
            qc_output_dir=output_dir if qc_config else None,
            tractography_relpath=tractography_relpath,
        )
    except Exception as exc:
        raise ProcessingError(f"Atlas generation failed: {exc}") from exc
def verify_requirements(config: PipelineConfig, context: ProcessingContext) -> List[str]:
    """Verify that the cohort output directory is set and is a directory.

    Args:
        config: Pipeline configuration. Not inspected by this check.
        context: Processing context whose ``output_dir`` is validated.

    Returns:
        Human-readable problems; an empty list when the requirements are met.
    """
    if not context.output_dir:
        return ["output_dir is not set in the processing context"]
    out_path = Path(context.output_dir)
    if not out_path.exists():
        return [f"Output directory does not exist: {out_path}"]
    if not out_path.is_dir():
        # Distinguish "exists but is a file" from "missing" for a clearer message.
        return [f"Output path is not a directory: {out_path}"]
    return []
def _atlas_config_from(config: PipelineConfig) -> dict:
    """Collect the ``AtlasParams`` keyword arguments handed to ``_atlas_task``."""
    atlas = config.atlas
    return {
        "presence_value": atlas.presence_value,
        "cov_mean_threshold_pct": atlas.cov_mean_threshold_pct,
        # Passed as the enum's plain value; ``_atlas_task`` converts a string
        # back into ``NormalizationMethod`` before building ``AtlasParams``.
        "normalization_method": atlas.normalization_method.value,
    }


def _qc_config_from(config: PipelineConfig) -> dict | None:
    """Collect the ``QcParams`` keyword arguments, or ``None`` when QC is disabled."""
    qc = config.atlas_qc
    if not qc.enabled:
        return None
    return {
        "subject_density_threshold": qc.subject_density_threshold,
        "group_core_threshold": qc.group_core_threshold,
        "leave_one_out": qc.leave_one_out,
        "leave_one_out_min_subjects": qc.leave_one_out_min_subjects,
        "compute_cv": qc.compute_cv,
        "cv_mean_threshold": qc.cv_mean_threshold,
        "compute_log_stats": qc.compute_log_stats,
        "log_offset": qc.log_offset,
        # The config's ``output_subdir`` is handed to QcParams as ``qc_subdir``.
        "qc_subdir": qc.output_subdir,
    }


@workflow(
    name="atlas",
    description="Generates statistical reference atlas from cohort tractography data.",
    scope="cohort",
)
@produces(atlas_dir=CohortDir("atlas"))
@verify(verify_requirements)
def build_workflow(
    *,
    atlas_dir: Path,
    config: PipelineConfig,
    context: ProcessingContext,
) -> pe.Workflow:
    """Build the cohort-level atlas workflow.

    The workflow has a single ``generate_atlas`` node wrapping ``_atlas_task``.
    When ``config.atlas_qc.enabled`` is set, QC is computed in that same node
    from the single in-memory cohort stack, instead of in a separate node that
    would re-collect and re-stack the cohort.

    Args:
        atlas_dir: Supplied via ``@produces``; not used directly.
        config: Pipeline configuration; the ``atlas``, ``atlas_qc`` and
            ``hardware`` sections are read.
        context: Processing context. ``output_dir`` is required; ``working_dir``,
            if set, becomes the workflow's base directory.

    Returns:
        The assembled Nipype workflow.

    Raises:
        ValueError: If ``context.output_dir`` is not set.
    """
    # Resolved purely for its mkdir side-effect; _atlas_task rebuilds the path.
    # NOTE(review): @produces declares CohortDir("atlas") while the task writes
    # to config.atlas.output_subdir; these only coincide when that setting is
    # "atlas" — confirm.
    del atlas_dir

    # Validate first so a misconfigured context fails before anything is built.
    if context.output_dir is None:
        raise ValueError("output_dir must be set before building the workflow")
    output_dir = Path(context.output_dir)
    input_dir = _resolve_input_dir(output_dir)

    wf = pe.Workflow(name="atlas_generation")
    if context.working_dir:
        wf.base_dir = str(context.working_dir)

    atlas_node = pe.Node(
        Function(input_names=_ATLAS_INPUTS, output_names="generated_files", function=_atlas_task),
        name="generate_atlas",
    )
    atlas_node.inputs.input_dir = str(input_dir)
    atlas_node.inputs.output_dir = str(output_dir)
    atlas_node.inputs.atlas_config = _atlas_config_from(config)
    atlas_node.inputs.tractography_relpath = config.atlas.tractography_relpath
    atlas_node.inputs.output_subdir = config.atlas.output_subdir
    atlas_node.inputs.qc_config = _qc_config_from(config)

    # NOTE(review): Nipype's MultiProc plugin sizes a node from ``Node.n_procs``
    # rather than ``plugin_args``; confirm the execution plugin in use honors this.
    atlas_node.plugin_args = {"n_procs": config.hardware.threads}

    wf.add_nodes([atlas_node])
    logger.info("Built cohort-level Atlas workflow")
    return wf