# Source code for thesis.workflows.atlas.workflow

"""Nipype workflow definition for the Atlas generator."""

from pathlib import Path
from typing import List

import nipype.pipeline.engine as pe
from nipype.interfaces.utility import Function

from thesis.core.config import PipelineConfig
from thesis.core.context import ProcessingContext
from thesis.core.decorators import produces, verify, workflow
from thesis.core.logging import get_logger
from thesis.core.path_declarations import CohortDir

# Module-level logger obtained from the project's ``get_logger`` helper,
# keyed on this module's ``__name__``.
logger = get_logger(__name__)

_ATLAS_INPUTS = [
    "input_dir",
    "output_dir",
    "atlas_config",
    "tractography_relpath",
    "output_subdir",
    "qc_config",
]


def _resolve_input_dir(output_dir: Path) -> Path:
    """Return the directory containing numeric patient subdirectories.

    When the CLI runs a cohort-level workflow, ``output_dir`` is
    ``<base>/cohort`` and patient data lives one level up.
    """
    patient_dirs = [d for d in output_dir.iterdir() if d.is_dir() and d.name.isdigit()]
    if not patient_dirs and output_dir.name == "cohort" and output_dir.parent.is_dir():
        parent_dirs = [d for d in output_dir.parent.iterdir() if d.is_dir() and d.name.isdigit()]
        if parent_dirs:
            return output_dir.parent
    return output_dir


def _atlas_task(
    input_dir: str,
    output_dir: str,
    atlas_config: dict,
    tractography_relpath: str,
    output_subdir: str,
    qc_config: dict | None = None,
) -> list:
    """Nipype Function body: build the statistical atlas and, optionally, its QC.

    The cohort stack is loaded once, and both the statistical-atlas outputs
    and the QC outputs come from that single in-memory pass. QC is skipped
    whenever ``qc_config`` is falsy (``None`` or an empty dict). In that case
    only the statistical atlas is produced.

    Args:
        input_dir: Directory holding the patient subdirectories. It is
            forwarded as the ``output_dir`` argument of
            ``generate_atlas_with_qc``.
        output_dir: Cohort output directory. The atlas is written to
            ``output_dir / output_subdir``. When QC is enabled, this path is
            also passed as ``qc_output_dir``.
        atlas_config: Keyword arguments for ``AtlasParams``. A string
            ``normalization_method`` is converted to ``NormalizationMethod``.
        tractography_relpath: Passed through to ``generate_atlas_with_qc``.
        output_subdir: Name of the atlas subdirectory under ``output_dir``.
        qc_config: Keyword arguments for ``QcParams``, or ``None`` to skip QC.

    Returns:
        Whatever ``generate_atlas_with_qc`` returns (annotated as ``list``).

    Raises:
        ProcessingError: If building the atlas path, constructing
            ``AtlasParams``, or running ``generate_atlas_with_qc`` fails.
            Errors raised while converting ``normalization_method`` or
            constructing ``QcParams`` propagate unwrapped.
    """
    # nipype's Function interface re-creates this function from its source
    # text, so every dependency is imported inside the body.
    from pathlib import Path

    from thesis.core.exceptions import ProcessingError
    from thesis.workflows.atlas._params import NormalizationMethod
    from thesis.workflows.atlas.compute import (
        AtlasParams,
        QcParams,
        generate_atlas_with_qc,
    )

    # Copy so the caller's mapping is never mutated.
    atlas_kwargs = dict(atlas_config)
    method = atlas_kwargs.get("normalization_method")
    if isinstance(method, str):
        # The node input carries the enum's plain string value; restore the enum.
        atlas_kwargs["normalization_method"] = NormalizationMethod(method)

    qc_params = None
    qc_output_dir = None
    if qc_config:
        qc_params = QcParams(**qc_config)
        qc_output_dir = output_dir

    try:
        atlas_dir = str(Path(output_dir) / output_subdir)
        atlas_params = AtlasParams(**atlas_kwargs)
        return generate_atlas_with_qc(
            output_dir=input_dir,
            atlas_dir=atlas_dir,
            params=atlas_params,
            qc_params=qc_params,
            qc_output_dir=qc_output_dir,
            tractography_relpath=tractography_relpath,
        )
    except Exception as exc:
        raise ProcessingError(f"Atlas generation failed: {exc}") from exc


def verify_requirements(config: PipelineConfig, context: ProcessingContext) -> List[str]:
    """Check the preconditions of the atlas workflow.

    Only the processing context is inspected. The cohort output directory
    must be set, and it must already exist as a directory. ``config`` is part
    of the signature registered via ``@verify`` but is not consulted here.

    Returns:
        A list of human-readable problems; empty when every check passes.
    """
    problems: List[str] = []
    if not context.output_dir:
        problems.append("output_dir is not set in the processing context")
    else:
        out_path = Path(context.output_dir)
        if not out_path.is_dir():
            problems.append(f"Output directory does not exist: {out_path}")
    return problems
def _atlas_config_inputs(config: PipelineConfig) -> dict:
    """Return the ``atlas_config`` node input built from ``config.atlas``.

    ``normalization_method`` is sent as its ``.value`` so the node input stays
    plain data; ``_atlas_task`` converts a string back into the enum.
    """
    atlas_cfg = config.atlas
    return {
        "presence_value": atlas_cfg.presence_value,
        "cov_mean_threshold_pct": atlas_cfg.cov_mean_threshold_pct,
        "normalization_method": atlas_cfg.normalization_method.value,
    }


def _qc_config_inputs(config: PipelineConfig) -> dict | None:
    """Return the ``qc_config`` node input, or ``None`` when atlas QC is disabled.

    Keys mirror the fields of ``config.atlas_qc``, except that
    ``output_subdir`` is sent under the key ``qc_subdir``.
    """
    qc_cfg = config.atlas_qc
    if not qc_cfg.enabled:
        return None
    return {
        "subject_density_threshold": qc_cfg.subject_density_threshold,
        "group_core_threshold": qc_cfg.group_core_threshold,
        "leave_one_out": qc_cfg.leave_one_out,
        "leave_one_out_min_subjects": qc_cfg.leave_one_out_min_subjects,
        "compute_cv": qc_cfg.compute_cv,
        "cv_mean_threshold": qc_cfg.cv_mean_threshold,
        "compute_log_stats": qc_cfg.compute_log_stats,
        "log_offset": qc_cfg.log_offset,
        "qc_subdir": qc_cfg.output_subdir,
    }


@workflow(
    name="atlas",
    description="Generates statistical reference atlas from cohort tractography data.",
    scope="cohort",
)
@produces(atlas_dir=CohortDir("atlas"))
@verify(verify_requirements)
def build_workflow(
    *,
    atlas_dir: Path,
    config: PipelineConfig,
    context: ProcessingContext,
) -> pe.Workflow:
    """Build the cohort-level atlas workflow.

    The workflow contains a single ``generate_atlas`` node that runs
    ``_atlas_task``. When ``config.atlas_qc.enabled`` is true, QC is computed
    by that same node rather than by a separate one.

    Raises:
        ValueError: If ``context.output_dir`` is ``None``.
    """
    # ``atlas_dir`` is requested only for the directory-creation side effect
    # of ``@produces``; ``_atlas_task`` rebuilds the path itself.
    # NOTE(review): ``@produces`` uses the literal "atlas", while the node
    # writes to ``config.atlas.output_subdir``. Confirm the two always agree.
    del atlas_dir

    wf = pe.Workflow(name="atlas_generation")
    if context.working_dir:
        wf.base_dir = str(context.working_dir)

    if context.output_dir is None:
        raise ValueError("output_dir must be set before building the workflow")
    cohort_out = Path(context.output_dir)
    patient_root = _resolve_input_dir(cohort_out)

    node = pe.Node(
        Function(
            input_names=_ATLAS_INPUTS,
            output_names="generated_files",
            function=_atlas_task,
        ),
        name="generate_atlas",
    )
    node.inputs.input_dir = str(patient_root)
    node.inputs.output_dir = str(cohort_out)
    node.inputs.atlas_config = _atlas_config_inputs(config)
    node.inputs.tractography_relpath = config.atlas.tractography_relpath
    node.inputs.output_subdir = config.atlas.output_subdir
    # QC runs inside this same node on the single in-memory cohort stack,
    # instead of in a second node that would re-collect and re-stack it.
    node.inputs.qc_config = _qc_config_inputs(config)

    # NOTE(review): nipype's MultiProc plugin schedules on ``Node.n_procs``.
    # Confirm the execution plugin in use actually reads
    # ``plugin_args["n_procs"]``.
    node.plugin_args = {"n_procs": config.hardware.threads}

    wf.add_nodes([node])
    logger.info("Built cohort-level Atlas workflow")
    return wf