Source code for thesis.workflows.registration.paths

"""Path helpers for the registration workflow."""

from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional

from thesis.core.config import PipelineConfig
from thesis.core.config.validators import FireantsRegistrationConfig
from thesis.core.context import ProcessingContext
from thesis.core.naming import safe_node_name
from thesis.workflows.hcp.common import (
    format_patient_path,
    resolve_t1_path,
    resolve_with_fallback,
)

#: Name of the implicit registration job synthesized when ``registration.jobs``
#: is empty.  Its on-disk transforms use the legacy flat layout/filenames.
#: A configured job that reuses this name is likewise flagged ``is_default``
#: by :func:`resolve_registration_jobs`.
DEFAULT_REGISTRATION_JOB = "patient_to_template"

# Public API of this module: the names exported by a star-import.
__all__ = [
    "DEFAULT_REGISTRATION_JOB",
    "ResolvedRegistrationJob",
    "get_registration_output_dir",
    "get_registration_transform_dir",
    "get_registration_warped_image_path",
    "get_registration_inverse_warped_image_path",
    "get_registration_transform_prefix",
    "get_registration_job_dir",
    "get_registration_job_transform_dir",
    "get_registration_job_transform_paths",
    "get_registration_job_warped_image_path",
    "get_registration_job_inverse_warped_image_path",
    "registration_transform_filenames",
    "resolve_registration_jobs",
    "resolve_fixed_image",
    "resolve_fixed_image_for_job",
    "resolve_moving_image",
    "resolve_moving_image_for_job",
]


@dataclass(frozen=True)
class ResolvedRegistrationJob:
    """Immutable description of one registration job after merging.

    Instances are produced by :func:`resolve_registration_jobs`, which
    overlays each job's overrides on the shared ``registration.*`` settings.
    Consequently every scalar setting carries a concrete value here, and
    ``fireants`` holds a validated backend configuration rather than a sparse
    override mapping.

    Attributes:
        name: Job name (``"patient_to_template"`` for the implicit default).
        is_default: ``True`` for the implicit default job, and also for a
            configured job whose name equals ``"patient_to_template"``.
            Selects the legacy flat on-disk layout.
        method: Registration backend to use.
        moving_modality: Modality of the moving image.
        moving_image: Explicit moving-image path override, or ``None``.
        fixed_image: Fixed-image path override, or ``None``.
        interpolation: Interpolation mode.
        metric: Similarity metric.
        transform_type: Transform family (``Rigid``/``Affine``/``SyN``).
        use_float: Float-precision flag.
        fireants: FireANTs backend configuration for this job.
    """

    name: str
    is_default: bool
    method: str
    moving_modality: str
    moving_image: Optional[str]
    fixed_image: Optional[str]
    interpolation: str
    metric: str
    transform_type: str
    use_float: bool
    fireants: FireantsRegistrationConfig

    @property
    def safe_name(self) -> str:
        """Job name passed through :func:`safe_node_name`."""
        raw_name = self.name
        return safe_node_name(raw_name)
def resolve_registration_jobs(config: PipelineConfig) -> List[ResolvedRegistrationJob]:
    """Expand ``config.registration`` into concrete registration jobs.

    Every entry of ``registration.jobs`` is overlaid on the shared
    ``registration.*`` settings: a per-job field that is ``None`` inherits the
    shared value.  With no configured jobs, exactly one implicit job named
    ``"patient_to_template"`` is built from the shared settings alone, so a
    configuration without a ``jobs`` list still yields one job.

    A non-empty per-job ``fireants`` mapping is merged over the dumped shared
    ``registration.fireants`` model and passed through
    :meth:`FireantsRegistrationConfig.model_validate`, so an invalid override
    raises here.  (pydantic v2 ``model_copy(update=...)`` would skip that
    validation.)

    Args:
        config: The merged pipeline configuration.

    Returns:
        One :class:`ResolvedRegistrationJob` per configured job, in
        configuration order, or a one-element list holding the implicit
        default job.
    """
    shared = config.registration

    if not shared.jobs:
        # No explicit jobs: describe the single implicit default job.
        implicit_job = ResolvedRegistrationJob(
            name=DEFAULT_REGISTRATION_JOB,
            is_default=True,
            method=shared.method,
            moving_modality=shared.moving_modality,
            moving_image=shared.moving_image,
            # An empty shared fixed-image value is normalized to ``None``.
            fixed_image=shared.fixed_image or None,
            interpolation=shared.interpolation,
            metric=shared.metric,
            transform_type=shared.transform_type,
            use_float=shared.use_float,
            fireants=shared.fireants,
        )
        return [implicit_job]

    def _pick(override, fallback):
        """Return *override* unless it is ``None``, in which case *fallback*."""
        return fallback if override is None else override

    def _resolve_entry(entry) -> ResolvedRegistrationJob:
        """Merge one configured job entry with the shared settings."""
        if entry.fireants:
            # Shallow merge, then validate the combined mapping as a whole.
            combined = {**shared.fireants.model_dump(), **entry.fireants}
            backend_cfg = FireantsRegistrationConfig.model_validate(combined)
        else:
            backend_cfg = shared.fireants

        return ResolvedRegistrationJob(
            name=entry.name,
            is_default=entry.name == DEFAULT_REGISTRATION_JOB,
            method=_pick(entry.method, shared.method),
            moving_modality=_pick(entry.moving_modality, shared.moving_modality),
            moving_image=_pick(entry.moving_image, shared.moving_image),
            fixed_image=_pick(entry.fixed_image, shared.fixed_image or None),
            interpolation=_pick(entry.interpolation, shared.interpolation),
            metric=_pick(entry.metric, shared.metric),
            transform_type=_pick(entry.transform_type, shared.transform_type),
            use_float=_pick(entry.use_float, shared.use_float),
            fireants=backend_cfg,
        )

    return [_resolve_entry(entry) for entry in shared.jobs]
def registration_transform_filenames(
    patient_id: str,
    job_name: str,
    transform_type: str,
) -> dict[str, list[str]]:
    """Build the canonical transform filenames for one registration job.

    The result contains bare filenames only; callers join them onto the
    relevant transform directory.  Keeping the naming in this one helper lets
    the code that writes transforms and the code that resolves them agree on
    the exact names.

    Naming scheme:

    * job ``"patient_to_template"`` keeps the historical stems
      ``<pid>_patient_to_template`` (forward) and
      ``<pid>_template_to_patient`` (reverse);
    * any other job uses ``<pid>_<job>_forward`` / ``<pid>_<job>_reverse``.

    ``Rigid`` and ``Affine`` yield a single ``*_affine.mat`` per direction;
    ``SyN`` yields a single ``*_warp.nii.gz`` per direction.

    Args:
        patient_id: Patient identifier used as the filename prefix.
        job_name: Registration job name (e.g. ``"patient_to_template"``).
        transform_type: One of ``"Rigid"``, ``"Affine"`` or ``"SyN"``.

    Returns:
        Mapping with the keys ``"forward"`` and ``"reverse"``, each holding
        the ordered list of transform filenames for that direction.

    Raises:
        ValueError: If ``transform_type`` is not a recognised value.
    """
    # File suffix emitted for each supported transform family.
    suffix_by_type = {
        "Rigid": "affine.mat",
        "Affine": "affine.mat",
        "SyN": "warp.nii.gz",
    }

    uses_legacy_names = job_name == "patient_to_template"
    if uses_legacy_names:
        forward_stem = f"{patient_id}_patient_to_template"
        reverse_stem = f"{patient_id}_template_to_patient"
    else:
        forward_stem = f"{patient_id}_{job_name}_forward"
        reverse_stem = f"{patient_id}_{job_name}_reverse"

    suffix = suffix_by_type.get(transform_type)
    if suffix is None:
        raise ValueError(
            f"Unknown registration transform_type '{transform_type}'. "
            "Expected one of 'Rigid', 'Affine', 'SyN'."
        )

    return {
        "forward": [f"{forward_stem}_{suffix}"],
        "reverse": [f"{reverse_stem}_{suffix}"],
    }
def get_registration_output_dir(config: PipelineConfig, context: ProcessingContext) -> Path:
    """Return the directory that holds a patient's registration artifacts.

    The root is ``context.output_dir`` when that is set, otherwise
    ``outputs/<patient_id>``.  The sub-directory name comes from
    ``config.registration.output_subdir`` and defaults to ``"registration"``
    when that attribute is absent.
    """
    if context.output_dir:
        root = Path(context.output_dir)
    else:
        root = Path("outputs") / context.patient_id

    leaf = getattr(config.registration, "output_subdir", "registration")
    return root / leaf
def get_registration_transform_dir(config: PipelineConfig, context: ProcessingContext) -> Path:
    """Return the ``transforms`` folder inside the registration output directory."""
    registration_root = get_registration_output_dir(config, context)
    return registration_root / "transforms"
def get_registration_warped_image_path(config: PipelineConfig, context: ProcessingContext) -> Path:
    """Return where the patient image warped into template space is written.

    The file is ``<patient_id>_moving_to_template.nii.gz`` inside the
    registration output directory.
    """
    registration_root = get_registration_output_dir(config, context)
    filename = f"{context.patient_id}_moving_to_template.nii.gz"
    return registration_root / filename
def get_registration_inverse_warped_image_path(
    config: PipelineConfig,
    context: ProcessingContext,
) -> Path:
    """Return where the template image warped into patient space is written.

    The file is ``<patient_id>_template_to_moving.nii.gz`` inside the
    registration output directory.
    """
    registration_root = get_registration_output_dir(config, context)
    filename = f"{context.patient_id}_template_to_moving.nii.gz"
    return registration_root / filename
def get_registration_transform_prefix(config: PipelineConfig, context: ProcessingContext) -> str:
    """Return the output prefix for patient-to-template transform files.

    The prefix is the transform directory joined with
    ``<patient_id>_patient_to_template_``, returned as a plain string.
    """
    stem = f"{context.patient_id}_patient_to_template_"
    prefix_path = get_registration_transform_dir(config, context) / stem
    return str(prefix_path)
def resolve_moving_image(config: PipelineConfig, context: ProcessingContext) -> Path:
    """Locate the moving image using the shared ``registration.*`` settings.

    Candidates are tried in this order:

    1. ``registration.moving_image`` when set: patient-formatted, then
       resolved against the input directory with ``context.data_dir`` as the
       fallback root.
    2. The T1 image (via :func:`resolve_t1_path`) when
       ``registration.moving_modality`` is ``"t1"``.
    3. ``hcp.t2_image`` when set, resolved the same way as (1).
    4. ``<input_dir>/T1w/T2w.nii.gz``.

    The input directory is ``context.input_dir`` (resolved), or the resolved
    current directory when that is unset.
    """
    settings = config.registration
    search_root = Path(context.input_dir or ".").resolve()

    def _locate(template) -> Path:
        """Fill in the patient id and resolve against the search roots."""
        expanded = format_patient_path(str(template), context.patient_id)
        return resolve_with_fallback(expanded, search_root, [context.data_dir])

    if settings.moving_image:
        return _locate(settings.moving_image)

    if settings.moving_modality == "t1":
        return resolve_t1_path(config, context)

    t2_template = getattr(config.hcp, "t2_image", None)
    if not t2_template:
        # Nothing configured: fall back to the conventional T2w location.
        return search_root / "T1w" / "T2w.nii.gz"
    return _locate(t2_template)
def resolve_fixed_image(config: PipelineConfig, context: ProcessingContext) -> Path:
    """Locate the fixed (template) image from ``registration.fixed_image``.

    When no fixed image is configured, the placeholder path
    ``__missing_registration_fixed_image__`` is returned instead of raising.
    Otherwise the configured value is patient-formatted and resolved against
    ``context.data_dir`` first, with the input directory as the fallback root.
    """
    template = config.registration.fixed_image
    if template:
        patient_dir = Path(context.input_dir or ".").resolve()
        expanded = format_patient_path(str(template), context.patient_id)
        data_root = Path(context.data_dir).resolve()
        return resolve_with_fallback(expanded, data_root, [patient_dir])

    # Placeholder for "no fixed image configured".
    return Path("__missing_registration_fixed_image__")
# --------------------------------------------------------------------------- # # Job-aware path + image resolution # # --------------------------------------------------------------------------- #
def get_registration_job_dir(
    config: PipelineConfig, context: ProcessingContext, job: ResolvedRegistrationJob
) -> Path:
    """Return the output directory for one registration job.

    A job flagged ``is_default`` writes straight into the registration output
    directory (the legacy flat layout).  Every other job gets its own
    sub-directory named after ``job.safe_name``.
    """
    registration_root = get_registration_output_dir(config, context)
    return registration_root if job.is_default else registration_root / job.safe_name
def get_registration_job_transform_dir(
    config: PipelineConfig, context: ProcessingContext, job: ResolvedRegistrationJob
) -> Path:
    """Return the ``transforms`` folder inside a job's output directory."""
    job_root = get_registration_job_dir(config, context, job)
    return job_root / "transforms"
def get_registration_job_warped_image_path(
    config: PipelineConfig, context: ProcessingContext, job: ResolvedRegistrationJob
) -> Path:
    """Return a job's moving→template warped-image output path.

    Jobs flagged ``is_default`` use ``<pid>_moving_to_template.nii.gz``;
    other jobs insert the job name: ``<pid>_<job>_moving_to_template.nii.gz``.
    """
    job_root = get_registration_job_dir(config, context, job)
    if not job.is_default:
        filename = f"{context.patient_id}_{job.name}_moving_to_template.nii.gz"
    else:
        filename = f"{context.patient_id}_moving_to_template.nii.gz"
    return job_root / filename
def get_registration_job_inverse_warped_image_path(
    config: PipelineConfig, context: ProcessingContext, job: ResolvedRegistrationJob
) -> Path:
    """Return a job's template→moving inverse-warped-image output path.

    Jobs flagged ``is_default`` use ``<pid>_template_to_moving.nii.gz``;
    other jobs insert the job name: ``<pid>_<job>_template_to_moving.nii.gz``.
    """
    job_root = get_registration_job_dir(config, context, job)
    if not job.is_default:
        filename = f"{context.patient_id}_{job.name}_template_to_moving.nii.gz"
    else:
        filename = f"{context.patient_id}_template_to_moving.nii.gz"
    return job_root / filename
def get_registration_job_transform_paths(
    config: PipelineConfig,
    context: ProcessingContext,
    job_name: str,
    direction: str,
) -> List[str]:
    """Return the transform file paths of one job for one direction.

    Filenames come from :func:`registration_transform_filenames` and are
    joined onto the job's transform directory, so the order and names match
    whatever that helper defines.

    Args:
        config: The merged pipeline configuration.
        context: The processing context (supplies patient id + output dirs).
        job_name: Name of the registration job to look up.
        direction: ``"patient_to_template"`` for the forward chain, or
            ``"template_to_patient"`` for the reverse chain.

    Returns:
        Ordered list of transform paths as strings.  They are rooted at the
        job's transform directory, so they are absolute only when the
        registration output directory is.

    Raises:
        KeyError: If *job_name* is not a resolved registration job.
        ValueError: If *direction* is not a recognised value.
    """
    jobs = {}
    for candidate in resolve_registration_jobs(config):
        jobs[candidate.name] = candidate

    job = jobs.get(job_name)
    if job is None:
        raise KeyError(f"Unknown registration job '{job_name}'. " f"Known jobs: {sorted(jobs)}.")

    # The job lookup is validated before the direction, so an unknown job
    # always reports KeyError even when the direction is also invalid.
    direction_keys = (
        ("patient_to_template", "forward"),
        ("template_to_patient", "reverse"),
    )
    for known_direction, chain_key in direction_keys:
        if direction == known_direction:
            break
    else:
        raise ValueError(
            f"Unknown transform direction '{direction}'. Expected "
            "'patient_to_template' or 'template_to_patient'."
        )

    transform_dir = get_registration_job_transform_dir(config, context, job)
    filenames = registration_transform_filenames(
        context.patient_id, job_name, job.transform_type
    )
    return [str(transform_dir / filename) for filename in filenames[chain_key]]
def resolve_moving_image_for_job(
    config: PipelineConfig, context: ProcessingContext, job: ResolvedRegistrationJob
) -> Path:
    """Resolve the moving image for a specific registration job.

    The job's own resolved ``moving_image`` / ``moving_modality`` are always
    used.  Jobs flagged ``is_default`` used to be routed through
    :func:`resolve_moving_image`, which reads the shared ``registration.*``
    values; because a configured job named ``"patient_to_template"`` is also
    flagged ``is_default``, its overrides were silently dropped.  For the
    implicit default job the resolved fields are copies of the shared ones,
    so the result is the same as :func:`resolve_moving_image`.

    Candidates are tried in this order:

    1. ``job.moving_image`` when set: patient-formatted, then resolved against
       the input directory with ``context.data_dir`` as the fallback root.
    2. The T1 image (via :func:`resolve_t1_path`) when
       ``job.moving_modality`` is ``"t1"``.
    3. ``hcp.t2_image`` when set, resolved the same way as (1).
    4. ``<input_dir>/T1w/T2w.nii.gz``.

    Args:
        config: The merged pipeline configuration.
        context: The processing context (patient id, input and data dirs).
        job: The resolved registration job whose moving image is wanted.

    Returns:
        Path of the moving image for *job*.
    """
    input_dir = Path(context.input_dir).resolve() if context.input_dir else Path(".").resolve()

    def _resolve_template(template) -> Path:
        """Fill in the patient id and resolve against the search roots."""
        raw = format_patient_path(str(template), context.patient_id)
        return resolve_with_fallback(raw, input_dir, [context.data_dir])

    if job.moving_image:
        return _resolve_template(job.moving_image)

    if job.moving_modality == "t1":
        return resolve_t1_path(config, context)

    hcp_t2 = getattr(config.hcp, "t2_image", None)
    if hcp_t2:
        return _resolve_template(hcp_t2)

    # Nothing configured: fall back to the conventional T2w location.
    return input_dir / "T1w" / "T2w.nii.gz"
def resolve_fixed_image_for_job(
    config: PipelineConfig, context: ProcessingContext, job: ResolvedRegistrationJob
) -> Path:
    """Resolve the fixed template image for a specific registration job.

    The job's own resolved ``fixed_image`` is always used.  Jobs flagged
    ``is_default`` used to be routed through :func:`resolve_fixed_image`,
    which reads the shared ``registration.fixed_image``; because a configured
    job named ``"patient_to_template"`` is also flagged ``is_default``, its
    ``fixed_image`` override was silently dropped.  For the implicit default
    job the resolved field is a copy of the shared one, so the result is the
    same as :func:`resolve_fixed_image`.

    Args:
        config: The merged pipeline configuration.  Not read here; kept so
            the signature matches the other job-aware resolvers.
        context: The processing context (patient id, input and data dirs).
        job: The resolved registration job whose fixed image is wanted.

    Returns:
        The placeholder path ``__missing_registration_fixed_image__`` when
        the job has no fixed image.  Otherwise the patient-formatted
        ``job.fixed_image`` resolved against ``context.data_dir`` first, with
        the input directory as the fallback root.
    """
    if not job.fixed_image:
        # Placeholder for "no fixed image configured".
        return Path("__missing_registration_fixed_image__")

    input_dir = Path(context.input_dir).resolve() if context.input_dir else Path(".").resolve()
    raw = format_patient_path(str(job.fixed_image), context.patient_id)
    return resolve_with_fallback(raw, Path(context.data_dir).resolve(), [input_dir])