"""Nipype workflow definition for the learned conditional deformable atlas.
Owns the cohort-scope ``@workflow`` registration, the ``learned_atlas`` config
namespace binding, the pre-flight verifier, and the wiring of a single Nipype
``Function`` node (the trainer). The trainer body lives in the sibling
``train`` module; the Function wrapper here imports and calls it so torch stays
out of this module's import path.
Byte-compatibility: the trainer derives the output affine/header from the first
cohort fdt image exactly as :func:`thesis.workflows.atlas.compute` does (NOT
from a pre-existing atlas file), so the learned template is byte-compatible with
the averaging atlas's ``atlas_mean.nii.gz`` on every clean run.
Note on import ordering: ``config.learned_atlas`` resolves only after this
package is imported (which registers the namespace). The CLI imports the
selected workflow before loading config, so ``thesis run -w learned_atlas``
works; verify_requirements/build_workflow defensively ``getattr`` the namespace.
"""
from __future__ import annotations
import importlib.util
from pathlib import Path
from typing import List, Optional
import nipype.pipeline.engine as pe
from nipype.interfaces.utility import Function
from thesis.core.config import PipelineConfig
from thesis.core.context import ProcessingContext
from thesis.core.decorators import produces, verify, workflow
from thesis.core.logging import get_logger
from thesis.core.path_declarations import CohortDir
from ._params import LearnedAtlasConfig
# Module-level logger for code that runs in this (driver) process.
# ``_learned_atlas_task`` deliberately does not use it, because that body runs
# inside a Nipype node; see its docstring.
logger = get_logger(__name__)
# Name of the cohort output subdirectory, and the single source of truth for
# it: the ``CohortDir`` literal given to ``@produces`` and the default of
# ``LearnedAtlasConfig.output_subdir`` MUST both equal this value.
_OUTPUT_SUBDIR = "learned_atlas"

# Input port names of the trainer ``Function`` node, listed in the same order
# as the parameters of ``_learned_atlas_task``. Every name here needs both a
# matching parameter on that function and a ``node.inputs.<name>`` assignment
# in ``build_workflow``.
_LEARNED_ATLAS_INPUTS = (
    "input_dir atlas_dir tractography_relpath "
    "normalization_method training_space affine_native_relpath "
    "emit_baseline_maps presence_value cov_mean_threshold_pct "
    "save_fields verify_jacobian train_config"
).split()
def _resolve_input_dir(output_dir: Path) -> Path:
    """Locate the directory that holds the numeric per-patient subdirectories.

    Normally this is ``output_dir`` itself. The parent of ``output_dir`` is
    returned instead only when all of the following hold:

    * ``output_dir`` contains no all-digit subdirectories,
    * ``output_dir`` is named ``cohort``,
    * its parent is a directory that does contain all-digit subdirectories.

    Mirror of :func:`thesis.workflows.atlas.workflow._resolve_input_dir`.
    """

    def _has_patient_dirs(root: Path) -> bool:
        # A patient directory is any subdirectory whose name is all digits.
        return any(entry.is_dir() and entry.name.isdigit() for entry in root.iterdir())

    if _has_patient_dirs(output_dir):
        return output_dir

    parent = output_dir.parent
    fall_back_to_parent = (
        output_dir.name == "cohort"
        and parent.is_dir()
        and _has_patient_dirs(parent)
    )
    return parent if fall_back_to_parent else output_dir
def verify_requirements(config: PipelineConfig, context: ProcessingContext) -> List[str]:
    """Pre-flight checks for the ``learned_atlas`` workflow.

    Verifies:

    1. The cohort output directory exists.
    2. torch is importable (the ``ml`` extra is installed).
    3. The cohort has at least ``learned_atlas.min_subjects`` subjects with
       valid tractography input.
    4. When ``training_space='affine_native'``, it *warns* (via the logger)
       that the output is NOT a drop-in for the template-space
       ``atlas_to_patient`` jobs. This is advisory and does not block the run.

    Checks (2) and (3) accumulate, so the user sees every problem at once.

    Args:
        config: Fully merged pipeline configuration.
        context: Processing context (carries ``output_dir`` for the cohort).

    Returns:
        A list of human-readable error strings; empty means ready to build.
    """
    if context.output_dir is None:
        return ["output_dir is not set in the processing context"]
    out_path = Path(context.output_dir)
    if not out_path.is_dir():
        return [f"Output directory does not exist: {out_path}"]

    errors: List[str] = []

    # torch is OPTIONAL (the 'ml' group). Probe without importing so the CLI
    # loads cleanly when torch is absent.
    if importlib.util.find_spec("torch") is None:
        errors.append(
            "PyTorch is required for the learned_atlas workflow but is not "
            "installed. Install the ML extras: pip install -e '.[ml]'"
        )

    la: Optional[LearnedAtlasConfig] = getattr(config, "learned_atlas", None)
    if la is None:
        # Append rather than return a fresh list, so an already-collected
        # "torch missing" error is not silently dropped.
        errors.append(
            "learned_atlas config namespace is not registered; import "
            "thesis.workflows.learned_atlas before loading config."
        )
        return errors

    tractography_relpath = config.atlas.tractography_relpath
    input_dir = _resolve_input_dir(out_path)
    # Cheap metadata-only scan (no volume data is loaded here).
    from thesis.workflows.tract_similarity.hcp_loo import _discover_hcp_subjects

    subjects = _discover_hcp_subjects(input_dir, tractography_relpath)
    if len(subjects) < la.min_subjects:
        errors.append(
            f"learned_atlas requires at least {la.min_subjects} subjects with valid "
            f"tractography input under '{tractography_relpath}' in {input_dir}, "
            f"found {len(subjects)}."
        )

    if la.training_space == "affine_native":
        # Advisory only. affine_native is a supported training space (the
        # trainer receives ``affine_native_relpath``). Reporting this as an
        # error would make the mode impossible to run, because any non-empty
        # return means "not ready".
        logger.warning(
            "training_space='affine_native' produces a template in affine-native "
            "space that is NOT drop-in for the template-space atlas_to_patient "
            "jobs. Use training_space='template_native' for a drop-in atlas, or "
            "add an explicit warp step downstream."
        )
    return errors
def _learned_atlas_task(
    input_dir: str,
    atlas_dir: str,
    tractography_relpath: str,
    normalization_method: str,
    training_space: str,
    affine_native_relpath: "Optional[str]",
    emit_baseline_maps: list,
    presence_value: float,
    cov_mean_threshold_pct: float,
    save_fields: bool,
    verify_jacobian: bool,
    train_config: dict,
) -> list:
    """Nipype Function body: delegate to the trainer and return generated paths.

    This is a thin boundary. Nipype's ``Function`` interface captures this
    function's *source text* and re-``exec``s it in a bare namespace when the
    node runs. Nothing from this module is available at that point, neither
    its imports, its globals, nor its ``from __future__ import annotations``.
    Consequently:

    * All imports are LOCAL, and progress goes to
      ``print(..., file=sys.stderr)`` instead of the module's loguru logger.
    * Every non-builtin annotation must be a string literal.
      ``"Optional[str]"`` is quoted because an unquoted ``Optional`` would be
      evaluated at ``def`` time in that bare namespace and raise a NameError.

    Returns:
        List of absolute paths to generated files (atlas maps first in canonical
        order, then per-subject fields, then the training-metadata JSON).

    Raises:
        ProcessingError: Re-raised from the trainer unchanged. Any other
            exception from the trainer call is wrapped in a ProcessingError.
    """
    import sys

    from thesis.core.exceptions import ProcessingError
    from thesis.workflows.learned_atlas.train import train_learned_atlas

    print(
        f"[learned_atlas] training in {atlas_dir} "
        f"(space={training_space}, device={train_config.get('device')})",
        file=sys.stderr,
        flush=True,
    )
    try:
        # Coerce to plain builtins so the trainer does not depend on the
        # container types Nipype hands back for these inputs.
        return train_learned_atlas(
            input_dir=input_dir,
            atlas_dir=atlas_dir,
            tractography_relpath=tractography_relpath,
            normalization_method=normalization_method,
            training_space=training_space,
            affine_native_relpath=affine_native_relpath,
            emit_baseline_maps=list(emit_baseline_maps),
            presence_value=float(presence_value),
            cov_mean_threshold_pct=float(cov_mean_threshold_pct),
            save_fields=bool(save_fields),
            verify_jacobian=bool(verify_jacobian),
            train_config=dict(train_config),
        )
    except ProcessingError:
        raise
    except Exception as exc:  # noqa: BLE001 - surface as a typed pipeline error
        raise ProcessingError(f"Learned atlas training failed: {exc}") from exc
def _build_train_config(la: LearnedAtlasConfig) -> dict:
    """Flatten the nested ``learned_atlas`` settings into the trainer's dict.

    Collects device/dtype plus the ``model``, ``loss`` and ``optimizer``
    sub-sections into one flat ``dict``. This dict is the trainer node's
    ``train_config`` input. Feature lists are copied into fresh lists.
    """
    return {
        "device": la.device,
        "dtype": la.dtype,
        "n_templates": la.model.n_templates,
        "int_steps": la.model.int_steps,
        "enc_features": list(la.model.enc_features),
        "dec_features": list(la.model.dec_features),
        "similarity_weight": la.loss.similarity_weight,
        "presence_weight": la.loss.presence_weight,
        "smoothness_weight": la.loss.smoothness_weight,
        "log_offset": la.loss.log_offset,
        "lr": la.optimizer.lr,
        "epochs": la.optimizer.epochs,
        "batch_size": la.optimizer.batch_size,
        "seed": la.optimizer.seed,
    }


@workflow(
    name="learned_atlas",
    description=(
        "Train a learned conditional deformable tract-density atlas (sharp "
        "learnable template + deformation-only diffeomorphic network) from the "
        "cohort; emits the five atlas maps byte-compatible with the averaging atlas."
    ),
    scope="cohort",
    config_namespace="learned_atlas",
    config_schema=LearnedAtlasConfig,
)
@produces(atlas_dir=CohortDir(_OUTPUT_SUBDIR))
@verify(verify_requirements)
def build_workflow(
    *,
    atlas_dir: Path,
    config: PipelineConfig,
    context: ProcessingContext,
) -> pe.Workflow:
    """Build the cohort-level learned-atlas training workflow.

    The workflow holds a single Nipype ``Function`` node that runs
    ``_learned_atlas_task`` with inputs taken from the ``atlas`` and
    ``learned_atlas`` config namespaces.

    Args:
        atlas_dir: Resolved ``CohortDir`` output (created on resolution); the
            single source of truth for the output location.
        config: Fully merged pipeline configuration.
        context: Processing context carrying the cohort ``output_dir``.

    Returns:
        A configured :class:`nipype.pipeline.engine.Workflow`.

    Raises:
        ValueError: If ``context.output_dir`` is not set, or if the
            ``learned_atlas`` config namespace is not registered.
    """
    if context.output_dir is None:
        raise ValueError("output_dir must be set before building the workflow")
    # Defensive getattr: the namespace only exists once this package has been
    # imported before config loading. Fail with an actionable message rather
    # than an opaque AttributeError.
    la: Optional[LearnedAtlasConfig] = getattr(config, "learned_atlas", None)
    if la is None:
        raise ValueError(
            "learned_atlas config namespace is not registered; import "
            "thesis.workflows.learned_atlas before loading config."
        )
    atlas_cfg = config.atlas

    wf = pe.Workflow(name="learned_atlas_training")
    if context.working_dir:
        wf.base_dir = str(context.working_dir)

    input_dir = _resolve_input_dir(Path(context.output_dir))

    node = pe.Node(
        Function(
            input_names=_LEARNED_ATLAS_INPUTS,
            output_names="generated_files",
            function=_learned_atlas_task,
        ),
        name="train_learned_atlas",
    )
    node.inputs.input_dir = str(input_dir)
    node.inputs.atlas_dir = str(atlas_dir)
    node.inputs.tractography_relpath = atlas_cfg.tractography_relpath
    node.inputs.normalization_method = atlas_cfg.normalization_method.value
    node.inputs.training_space = la.training_space
    node.inputs.affine_native_relpath = la.affine_native_relpath
    node.inputs.emit_baseline_maps = list(la.emit_baseline_maps)
    node.inputs.presence_value = float(atlas_cfg.presence_value)
    node.inputs.cov_mean_threshold_pct = float(atlas_cfg.cov_mean_threshold_pct)
    node.inputs.save_fields = bool(la.save_fields)
    node.inputs.verify_jacobian = bool(la.verify_jacobian)
    node.inputs.train_config = _build_train_config(la)
    # NOTE(review): nipype's MultiProc plugin sizes per-node requests from
    # ``node.n_procs``, while ``node.plugin_args`` is read by the cluster
    # plugins (SGE/PBS/...). Confirm which scheduler is meant to consume
    # this key.
    node.plugin_args = {"n_procs": config.hardware.threads}
    wf.add_nodes([node])

    logger.info(
        "Built cohort-level learned_atlas workflow rooted at {} (space={}, device={})",
        input_dir,
        la.training_space,
        la.device,
    )
    return wf