"""Path helpers for the registration workflow."""
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional
from thesis.core.config import PipelineConfig
from thesis.core.config.validators import FireantsRegistrationConfig
from thesis.core.context import ProcessingContext
from thesis.core.naming import safe_node_name
from thesis.workflows.hcp.common import (
format_patient_path,
resolve_t1_path,
resolve_with_fallback,
)
#: Name of the implicit registration job synthesized when ``registration.jobs``
#: is empty. Its on-disk transforms use the legacy flat layout/filenames.
#: An explicit ``registration.jobs`` entry that uses this same name is also
#: flagged ``is_default`` by :func:`resolve_registration_jobs`, and so shares
#: the legacy flat layout.
DEFAULT_REGISTRATION_JOB = "patient_to_template"
# Public API of this module: the names exported by ``from ... import *``.
__all__ = [
    "DEFAULT_REGISTRATION_JOB",
    "ResolvedRegistrationJob",
    "get_registration_output_dir",
    "get_registration_transform_dir",
    "get_registration_warped_image_path",
    "get_registration_inverse_warped_image_path",
    "get_registration_transform_prefix",
    "get_registration_job_dir",
    "get_registration_job_transform_dir",
    "get_registration_job_transform_paths",
    "get_registration_job_warped_image_path",
    "get_registration_job_inverse_warped_image_path",
    "registration_transform_filenames",
    "resolve_registration_jobs",
    "resolve_fixed_image",
    "resolve_fixed_image_for_job",
    "resolve_moving_image",
    "resolve_moving_image_for_job",
]
[docs]
@dataclass(frozen=True)
class ResolvedRegistrationJob:
    """One registration job whose every setting has already been decided.

    Instances are produced by :func:`resolve_registration_jobs`, which layers
    each job's overrides on top of the shared ``registration.*`` defaults.
    No field is left "unset" here: an override that was not given has already
    been replaced by the shared value. ``fireants`` holds the FireANTs config
    after the sparse per-job block has been merged onto the shared one and
    re-validated. The record is immutable (``frozen=True``).

    Attributes:
        name: Job name (``"patient_to_template"`` for the implicit job).
        is_default: ``True`` for the implicit job and also for any explicit
            job named ``"patient_to_template"``. Such jobs use the legacy
            flat on-disk layout.
        method: Registration backend to run.
        moving_modality: Modality of the moving image (e.g. ``"t1"``).
        moving_image: Explicit moving-image path override, or ``None``.
        fixed_image: Fixed-image path override, or ``None``.
        interpolation: Interpolation mode.
        metric: Similarity metric.
        transform_type: Transform family (``Rigid``/``Affine``/``SyN``).
        use_float: Float-precision flag.
        fireants: Merged, re-validated FireANTs backend configuration.
    """

    # Field order is part of the positional constructor signature; keep it.
    name: str
    is_default: bool
    method: str
    moving_modality: str
    moving_image: Optional[str]
    fixed_image: Optional[str]
    interpolation: str
    metric: str
    transform_type: str
    use_float: bool
    fireants: FireantsRegistrationConfig

    @property
    def safe_name(self) -> str:
        """The job name run through ``safe_node_name`` (identifier-safe form)."""
        raw_name = self.name
        return safe_node_name(raw_name)
[docs]
def resolve_registration_jobs(config: PipelineConfig) -> List[ResolvedRegistrationJob]:
    """Expand ``config.registration`` into concrete, fully-merged jobs.

    This is the single source of truth for the named-jobs model. Each returned
    job is a ``registration.jobs`` entry layered on the shared
    ``registration.*`` defaults: a per-job field that is ``None`` inherits the
    shared value. When ``registration.jobs`` is empty, one implicit job named
    ``"patient_to_template"`` is synthesized directly from the shared fields,
    so single-registration configs keep working unchanged. An empty shared
    ``fixed_image`` is normalized to ``None``.

    ``is_default`` is ``True`` for the implicit job. It is also ``True`` for
    any explicit job whose name equals ``"patient_to_template"``.

    A non-empty per-job ``fireants`` mapping is laid over
    ``registration.fireants.model_dump()``. The result is passed through
    :meth:`FireantsRegistrationConfig.model_validate`, so an invalid override
    (e.g. a scale-length mismatch) raises here; pydantic v2
    ``model_copy(update=...)`` would skip that validation. The merge is one
    level deep: a nested value in the override replaces the shared value
    wholesale.

    Args:
        config: The merged pipeline configuration.

    Returns:
        One :class:`ResolvedRegistrationJob` per configured job, or a single
        implicit default job when none are configured.
    """
    shared = config.registration

    if not shared.jobs:
        implicit_job = ResolvedRegistrationJob(
            name=DEFAULT_REGISTRATION_JOB,
            is_default=True,
            method=shared.method,
            moving_modality=shared.moving_modality,
            moving_image=shared.moving_image,
            fixed_image=shared.fixed_image or None,
            interpolation=shared.interpolation,
            metric=shared.metric,
            transform_type=shared.transform_type,
            use_float=shared.use_float,
            fireants=shared.fireants,
        )
        return [implicit_job]

    def inherit(override, fallback):
        """Return *override* unless it is ``None``; otherwise *fallback*."""
        return fallback if override is None else override

    jobs: List[ResolvedRegistrationJob] = []
    for entry in shared.jobs:
        if not entry.fireants:
            backend = shared.fireants
        else:
            # Shallow overlay, then full re-validation of the combined config.
            overlay = {**shared.fireants.model_dump(), **entry.fireants}
            backend = FireantsRegistrationConfig.model_validate(overlay)

        jobs.append(
            ResolvedRegistrationJob(
                name=entry.name,
                is_default=(entry.name == DEFAULT_REGISTRATION_JOB),
                method=inherit(entry.method, shared.method),
                moving_modality=inherit(entry.moving_modality, shared.moving_modality),
                moving_image=inherit(entry.moving_image, shared.moving_image),
                fixed_image=inherit(entry.fixed_image, shared.fixed_image or None),
                interpolation=inherit(entry.interpolation, shared.interpolation),
                metric=inherit(entry.metric, shared.metric),
                transform_type=inherit(entry.transform_type, shared.transform_type),
                use_float=inherit(entry.use_float, shared.use_float),
                fireants=backend,
            )
        )
    return jobs
[docs]
def get_registration_output_dir(config: PipelineConfig, context: ProcessingContext) -> Path:
    """Compute the directory that receives one patient's registration artifacts.

    The base is ``context.output_dir`` when it is set (truthy). Otherwise it is
    the relative path ``outputs/<patient_id>``. The value of
    ``registration.output_subdir`` is appended to that base; when the
    attribute does not exist, ``"registration"`` is used.

    Args:
        config: Pipeline configuration; only ``config.registration`` is read.
        context: Per-patient processing context (``output_dir``, ``patient_id``).

    Returns:
        ``<base>/<output_subdir>``.
    """
    if context.output_dir:
        root = Path(context.output_dir)
    else:
        root = Path("outputs") / context.patient_id
    return root / getattr(config.registration, "output_subdir", "registration")
[docs]
def get_registration_warped_image_path(config: PipelineConfig, context: ProcessingContext) -> Path:
    """Locate the patient (moving) image resampled into template space.

    The file lives in :func:`get_registration_output_dir` and is named
    ``<patient_id>_moving_to_template.nii.gz``.
    """
    target_dir = get_registration_output_dir(config, context)
    return target_dir.joinpath(f"{context.patient_id}_moving_to_template.nii.gz")
[docs]
def get_registration_inverse_warped_image_path(
    config: PipelineConfig,
    context: ProcessingContext,
) -> Path:
    """Locate the template image resampled into patient (moving) space.

    The file lives in :func:`get_registration_output_dir` and is named
    ``<patient_id>_template_to_moving.nii.gz``.
    """
    target_dir = get_registration_output_dir(config, context)
    return target_dir.joinpath(f"{context.patient_id}_template_to_moving.nii.gz")
[docs]
def resolve_moving_image(config: PipelineConfig, context: ProcessingContext) -> Path:
    """Pick the moving (patient) image using the top-level registration settings.

    Lookup order:

    1. ``registration.moving_image`` when set: expanded with
       ``format_patient_path`` and passed to ``resolve_with_fallback`` with
       the input directory as base and ``[context.data_dir]`` as fallbacks.
    2. ``registration.moving_modality == "t1"``: delegated to
       :func:`resolve_t1_path`.
    3. ``config.hcp.t2_image`` when present and non-empty: resolved as in (1).
    4. Otherwise ``<input_dir>/T1w/T2w.nii.gz``.

    The input directory is ``context.input_dir`` when set, else the current
    working directory; either way it is ``.resolve()``-d.
    """
    settings = config.registration
    base_dir = Path(context.input_dir or ".").resolve()

    def from_template(template) -> Path:
        # Expand the patient placeholder, then search base_dir / data_dir.
        expanded = format_patient_path(str(template), context.patient_id)
        return resolve_with_fallback(expanded, base_dir, [context.data_dir])

    if settings.moving_image:
        return from_template(settings.moving_image)
    if settings.moving_modality == "t1":
        return resolve_t1_path(config, context)

    t2_template = getattr(config.hcp, "t2_image", None)
    if t2_template:
        return from_template(t2_template)
    return base_dir / "T1w" / "T2w.nii.gz"
[docs]
def resolve_fixed_image(config: PipelineConfig, context: ProcessingContext) -> Path:
    """Pick the fixed (template) image using the top-level registration settings.

    When ``registration.fixed_image`` is unset or empty, the placeholder
    ``Path("__missing_registration_fixed_image__")`` is returned instead of
    raising. NOTE(review): presumably a downstream step detects this sentinel;
    confirm.

    Otherwise the value is expanded with ``format_patient_path`` and passed to
    ``resolve_with_fallback``. The base is ``context.data_dir`` (resolved),
    and the fallback is the input directory: ``context.input_dir``, or the
    current working directory when unset, resolved.
    """
    template = config.registration.fixed_image
    if not template:
        return Path("__missing_registration_fixed_image__")
    search_dir = Path(context.input_dir or ".").resolve()
    expanded = format_patient_path(str(template), context.patient_id)
    return resolve_with_fallback(expanded, Path(context.data_dir).resolve(), [search_dir])
# --------------------------------------------------------------------------- #
# Job-aware path + image resolution #
# --------------------------------------------------------------------------- #
[docs]
def get_registration_job_dir(
    config: PipelineConfig, context: ProcessingContext, job: ResolvedRegistrationJob
) -> Path:
    """Return the directory holding one job's registration artifacts.

    Jobs flagged ``is_default`` keep the legacy flat layout and write directly
    into :func:`get_registration_output_dir` (``<output>/registration``), so
    existing on-disk paths and caches stay valid. Every other job gets its own
    subdirectory named after its identifier-safe name
    (``<output>/registration/<job.safe_name>``).
    """
    root = get_registration_output_dir(config, context)
    return root if job.is_default else root / job.safe_name
[docs]
def get_registration_job_warped_image_path(
    config: PipelineConfig, context: ProcessingContext, job: ResolvedRegistrationJob
) -> Path:
    """Return where a job writes its moving→template warped image.

    The file lives in :func:`get_registration_job_dir`. Default-flagged jobs
    use the legacy name ``<patient_id>_moving_to_template.nii.gz``; other jobs
    insert the job name: ``<patient_id>_<job.name>_moving_to_template.nii.gz``.

    NOTE(review): the directory uses ``job.safe_name`` but the filename embeds
    the raw ``job.name``. If job names can contain characters that
    ``safe_node_name`` rewrites, the two diverge; confirm this is intended.
    """
    job_dir = get_registration_job_dir(config, context, job)
    if job.is_default:
        stem = context.patient_id
    else:
        stem = f"{context.patient_id}_{job.name}"
    return job_dir / f"{stem}_moving_to_template.nii.gz"
[docs]
def get_registration_job_inverse_warped_image_path(
    config: PipelineConfig, context: ProcessingContext, job: ResolvedRegistrationJob
) -> Path:
    """Return where a job writes its template→moving inverse-warped image.

    The file lives in :func:`get_registration_job_dir`. Default-flagged jobs
    use the legacy name ``<patient_id>_template_to_moving.nii.gz``; other jobs
    insert the job name: ``<patient_id>_<job.name>_template_to_moving.nii.gz``.

    NOTE(review): the directory uses ``job.safe_name`` but the filename embeds
    the raw ``job.name``. Confirm this is intended for names that
    ``safe_node_name`` would rewrite.
    """
    job_dir = get_registration_job_dir(config, context, job)
    if job.is_default:
        stem = context.patient_id
    else:
        stem = f"{context.patient_id}_{job.name}"
    return job_dir / f"{stem}_template_to_moving.nii.gz"
[docs]
def resolve_moving_image_for_job(
    config: PipelineConfig, context: ProcessingContext, job: ResolvedRegistrationJob
) -> Path:
    """Resolve the moving image for a specific registration job.

    The *implicit* default job (synthesized when ``registration.jobs`` is
    empty) reuses :func:`resolve_moving_image` verbatim, preserving the
    existing behaviour. That job's ``moving_image`` / ``moving_modality`` are
    copied straight from the top-level fields that function reads.

    Every other job honours its own resolved ``moving_image`` /
    ``moving_modality``. This includes an explicit ``registration.jobs`` entry
    that happens to be named ``"patient_to_template"``: such an entry is
    flagged ``is_default`` for the legacy on-disk layout, but the top-level
    resolver would silently ignore its per-job overrides.

    Lookup order for non-implicit jobs:

    1. ``job.moving_image`` when set: expanded with ``format_patient_path``
       and passed to ``resolve_with_fallback`` with the input directory as
       base and ``[context.data_dir]`` as fallbacks.
    2. ``job.moving_modality == "t1"``: delegated to :func:`resolve_t1_path`.
    3. ``config.hcp.t2_image`` when present and non-empty: resolved as in (1).
    4. Otherwise ``<input_dir>/T1w/T2w.nii.gz``.

    Args:
        config: The merged pipeline configuration.
        context: Per-patient processing context.
        job: The resolved job (see :func:`resolve_registration_jobs`).

    Returns:
        Path to the moving image for *job*.
    """
    # Only the implicit job may take the legacy shortcut. An explicit job
    # named like the default is also ``is_default``, but its overrides live
    # on ``job``, not on the top-level ``registration`` fields.
    if job.is_default and not config.registration.jobs:
        return resolve_moving_image(config, context)
    input_dir = Path(context.input_dir).resolve() if context.input_dir else Path(".").resolve()
    if job.moving_image:
        raw = format_patient_path(str(job.moving_image), context.patient_id)
        return resolve_with_fallback(raw, input_dir, [context.data_dir])
    if job.moving_modality == "t1":
        return resolve_t1_path(config, context)
    hcp_t2 = getattr(config.hcp, "t2_image", None)
    if hcp_t2:
        raw = format_patient_path(str(hcp_t2), context.patient_id)
        return resolve_with_fallback(raw, input_dir, [context.data_dir])
    return input_dir / "T1w" / "T2w.nii.gz"
[docs]
def resolve_fixed_image_for_job(
    config: PipelineConfig, context: ProcessingContext, job: ResolvedRegistrationJob
) -> Path:
    """Resolve the fixed template image for a specific registration job.

    The *implicit* default job (synthesized when ``registration.jobs`` is
    empty) reuses :func:`resolve_fixed_image`. Its ``fixed_image`` is
    ``registration.fixed_image or None``, so the result is the same.

    Every other job uses its own resolved ``fixed_image``. This includes an
    explicit ``registration.jobs`` entry named ``"patient_to_template"``:
    such an entry is flagged ``is_default`` for the legacy on-disk layout,
    but its per-job ``fixed_image`` override must not be dropped.

    When the job has no fixed image, the placeholder
    ``Path("__missing_registration_fixed_image__")`` is returned instead of
    raising, matching :func:`resolve_fixed_image`. Otherwise the value is
    expanded with ``format_patient_path`` and passed to
    ``resolve_with_fallback`` with ``context.data_dir`` (resolved) as the base
    and the input directory as the fallback.

    Args:
        config: The merged pipeline configuration.
        context: Per-patient processing context.
        job: The resolved job (see :func:`resolve_registration_jobs`).

    Returns:
        Path to the fixed image for *job*, or the missing-image placeholder.
    """
    # Only the implicit job may take the legacy shortcut; an explicit job
    # named like the default carries its overrides on ``job``.
    if job.is_default and not config.registration.jobs:
        return resolve_fixed_image(config, context)
    if not job.fixed_image:
        return Path("__missing_registration_fixed_image__")
    input_dir = Path(context.input_dir).resolve() if context.input_dir else Path(".").resolve()
    raw = format_patient_path(str(job.fixed_image), context.patient_id)
    return resolve_with_fallback(raw, Path(context.data_dir).resolve(), [input_dir])