"""Patient-to-template structural registration workflow."""
from __future__ import annotations
import shutil
from importlib.util import find_spec
from pathlib import Path
from typing import List
from nipype import Node, Workflow
from nipype.interfaces.ants import Registration
from nipype.interfaces.utility import Function
from thesis.core.config import PipelineConfig
from thesis.core.context import ProcessingContext
from thesis.core.contracts import attach_inputnode, attach_outputnode
from thesis.core.decorators import verify, workflow
from thesis.core.logging import get_logger
from .fireants_backend import build_fireants_node
from .paths import (
ResolvedRegistrationJob,
get_registration_job_dir,
get_registration_job_inverse_warped_image_path,
get_registration_job_transform_dir,
get_registration_job_warped_image_path,
resolve_fixed_image,
resolve_fixed_image_for_job,
resolve_moving_image_for_job,
resolve_registration_jobs,
)
from .viewers import describe_registration_viewer_command
logger = get_logger(__name__)

# Names re-exported by ``from ... import *``; everything else is private.
__all__ = [
    "build_workflow",
    "verify_requirements",
]
def _torch_cuda_available() -> bool:
"""Return whether torch reports CUDA availability."""
try:
import torch
except ImportError:
return False
return bool(torch.cuda.is_available())
def _torch_cuda_runtime_usable(device: str = "cuda") -> tuple[bool, str]:
"""Return whether a minimal CUDA torch operation succeeds."""
try:
import torch
except ImportError:
return False, "Python package 'torch' is not installed."
if not torch.cuda.is_available():
return False, "torch.cuda.is_available() is False."
try:
_ = torch.zeros(1, device=device) + 1
except Exception as exc:
return False, str(exc)
return True, ""
_STAGE_MAP = {"Rigid": ["Rigid"], "Affine": ["Rigid", "Affine"], "SyN": ["Rigid", "Affine", "SyN"]}
def _registration_stage_config(transform_type: str, metric: str) -> dict[str, object]:
"""Return stage-wise ANTs settings for the requested transform family."""
transforms = _STAGE_MAP[transform_type]
n = len(transforms)
m = metric.upper()
return {
"transforms": transforms,
"transform_parameters": [(0.1, 3.0, 0.0) if t == "SyN" else (0.1,) for t in transforms],
"number_of_iterations": [
[100, 70, 50, 20] if t == "SyN" else [1000, 500, 250, 100] for t in transforms
],
"metric": [m] * n,
"metric_weight": [1.0] * n,
"radius_or_number_of_bins": [4 if m == "CC" else 32] * n,
"sampling_strategy": ["Regular"] * n,
"sampling_percentage": [0.25] * n,
"convergence_threshold": [1e-6] * n,
"convergence_window_size": [10] * n,
"shrink_factors": [[8, 4, 2, 1]] * n,
"smoothing_sigmas": [[3, 2, 1, 0]] * n,
"sigma_units": ["vox"] * n,
"use_histogram_matching": [m != "CC"] * n,
"use_estimate_learning_rate_once": [True] * n,
}
def _launch_viewer_task(
template_image: str,
warped_image: str,
backend: str,
auto_open: bool,
enabled: bool,
overlay_opacity: float,
) -> str:
"""Launch a registration QC viewer from a Nipype worker process."""
if not enabled or not auto_open:
return ""
from thesis.workflows.registration.viewers import launch_registration_viewer
command = launch_registration_viewer(
template_image=Path(template_image),
warped_image=Path(warped_image),
backend=backend,
overlay_opacity=overlay_opacity,
)
print(f"Launched registration viewer: {' '.join(command)}")
return " ".join(command)
def _ants_transform_prefix(
    config: PipelineConfig, context: ProcessingContext, job: ResolvedRegistrationJob
) -> str:
    """Build the ``output_transform_prefix`` under which ANTs writes transforms.

    The default job keeps the ``<patient>_patient_to_template_`` stem; any
    other job is namespaced as ``<patient>_<job name>_``. Both live inside the
    job's transform directory.
    """
    out_dir = get_registration_job_transform_dir(config, context, job)
    label = "patient_to_template" if job.is_default else job.name
    return str(out_dir / f"{context.patient_id}_{label}_")
def _build_ants_node(
    config: PipelineConfig,
    context: ProcessingContext,
    moving: Path,
    fixed: Path,
    *,
    job: ResolvedRegistrationJob | None = None,
) -> Node:
    """Build the Nipype node that runs ANTs ``Registration`` for *job*.

    Args:
        config: Pipeline configuration; ``registration`` and ``hardware`` are read.
        context: Per-patient processing context used to derive output paths.
        moving: Moving (patient) image. Only set as a static input when it
            already exists; otherwise it must be connected at runtime.
        fixed: Fixed (template) image.
        job: Registration job to build. When omitted, the job flagged as
            default is used (falling back to the first resolved job), matching
            how ``build_workflow`` picks the default job.

    Returns:
        A configured ``Node`` named ``ants_registration`` for the default job,
        or ``ants_<safe_name>`` otherwise.
    """
    if job is None:
        jobs = resolve_registration_jobs(config)
        job = next((j for j in jobs if j.is_default), jobs[0])
    rcfg = config.registration
    hardware = config.hardware

    # An explicit ``None`` in the hardware config falls back to the defaults
    # instead of crashing ``float()`` or being passed through as-is.
    threads = getattr(hardware, "threads", None)
    if threads is None:
        threads = 1
    memory_gb = getattr(hardware, "memory_gb", None)
    if memory_gb is None:
        memory_gb = 16
    # Reserve a quarter of hardware.memory_gb per registration, at least 4 GB.
    mem_gb = max(4.0, float(memory_gb) / 4.0)

    node_name = "ants_registration" if job.is_default else f"ants_{job.safe_name}"
    # ``mem_gb`` is Node's public constructor argument for resource-aware plugins.
    reg = Node(Registration(), name=node_name, mem_gb=mem_gb)
    reg.inputs.dimension = 3
    reg.inputs.fixed_image = str(fixed)
    # Only set moving_image when the file already exists; full_pipeline produces
    # the T1 brain at runtime so it isn't on disk at graph-construction time.
    if moving.exists():
        reg.inputs.moving_image = str(moving)
    reg.inputs.output_transform_prefix = _ants_transform_prefix(config, context, job)
    reg.inputs.output_warped_image = str(
        get_registration_job_warped_image_path(config, context, job)
    )
    reg.inputs.output_inverse_warped_image = str(
        get_registration_job_inverse_warped_image_path(config, context, job)
    )
    # The input is literally named ``float``; setattr keeps that explicit.
    setattr(reg.inputs, "float", job.use_float)
    reg.inputs.interpolation = job.interpolation
    reg.inputs.collapse_output_transforms = rcfg.collapse_output_transforms
    reg.inputs.write_composite_transform = rcfg.write_composite_transform
    reg.inputs.initial_moving_transform_com = True
    reg.inputs.winsorize_lower_quantile = 0.005
    reg.inputs.winsorize_upper_quantile = 0.995
    reg.inputs.num_threads = threads
    reg.inputs.verbose = True
    for key, value in _registration_stage_config(job.transform_type, job.metric).items():
        setattr(reg.inputs, key, value)
    return reg
def _dedupe(errors: List[str]) -> List[str]:
"""Return *errors* with duplicates removed, preserving first-seen order."""
seen: set[str] = set()
out: List[str] = []
for err in errors:
if err not in seen:
seen.add(err)
out.append(err)
return out
[docs]
def verify_requirements(
    config: PipelineConfig, context: ProcessingContext, **_: object
) -> List[str]:
    """Run cross-cutting preflight checks for the registration workflow.

    Checks, in order:
    * the backend binaries/packages required by the configured jobs;
    * CUDA usability for FireANTs jobs whose device starts with ``cuda``;
    * the fixed/moving image configuration of every job;
    * the ``fsleyes`` binary when the QC viewer is set to auto-open with it.

    Args:
        config: Pipeline configuration; ``config.registration`` is read.
        context: Processing context (part of the ``@verify`` signature; not
            used by these checks).
        **_: Extra keyword arguments are accepted and ignored.

    Returns:
        Human-readable error messages, de-duplicated in first-seen order. An
        empty list means every check passed.
    """
    errors: List[str] = []
    reg_cfg = config.registration
    jobs = resolve_registration_jobs(config)
    needs_ants = any(job.method == "ants" for job in jobs)
    needs_fireants = any(job.method == "fireants" for job in jobs)

    # -- backend binaries / packages (each package looked up only once) ------
    if needs_ants and shutil.which("antsRegistration") is None:
        errors.append("ANTs 'antsRegistration' not found on PATH.")
    has_torch = find_spec("torch") is not None
    has_fireants = find_spec("fireants") is not None
    if needs_fireants:
        if not has_torch:
            errors.append("Python package 'torch' is required for registration.method='fireants'.")
        if not has_fireants:
            errors.append(
                "Python package 'fireants' is required for registration.method='fireants'."
            )

    # -- CUDA usability for FireANTs jobs -----------------------------------
    # Only meaningful once both packages are importable. Each distinct device
    # string is probed once: the probe runs a real CUDA operation, and repeated
    # identical messages would be de-duplicated anyway.
    if has_torch and has_fireants:
        cuda_devices = dict.fromkeys(
            job.fireants.device
            for job in jobs
            if job.method == "fireants" and job.fireants.device.startswith("cuda")
        )
        if cuda_devices and not _torch_cuda_available():
            errors.append(
                "registration.fireants.device requests CUDA, but "
                "torch.cuda.is_available() is False."
            )
        else:
            for device in cuda_devices:
                ok, err = _torch_cuda_runtime_usable(device)
                if not ok:
                    errors.append(
                        "registration.fireants.device requests CUDA, but a minimal torch CUDA "
                        f"operation failed: {err}. Use registration.fireants.device='cpu' "
                        "or install a PyTorch build compatible with this GPU."
                    )

    # -- per-job image configuration ----------------------------------------
    for job in jobs:
        if not job.fixed_image:
            errors.append("registration.fixed_image is not configured.")
        # A T2 moving image may come from hcp.t2_image when not set explicitly.
        needs_t2_fallback = job.moving_modality == "t2" and not job.moving_image
        if needs_t2_fallback and not getattr(config.hcp, "t2_image", None):
            errors.append(
                "registration.moving_modality is 't2' but neither "
                "registration.moving_image nor hcp.t2_image is configured."
            )

    # -- QC viewer -------------------------------------------------------------
    viewer_cfg = reg_cfg.viewer
    if (
        viewer_cfg.enabled
        and viewer_cfg.auto_open
        and viewer_cfg.backend == "fsleyes"
        and shutil.which("fsleyes") is None
    ):
        errors.append(
            "FSLeyes 'fsleyes' not found on PATH but registration viewer auto-open is enabled."
        )
    return _dedupe(errors)
_VIEWER_INPUTS = [
"template_image",
"warped_image",
"backend",
"auto_open",
"enabled",
"overlay_opacity",
]
[docs]
@workflow(
    name="registration",
    description="Patient-to-template registration workflow with ANTs and FireANTs backends.",
)
@verify(verify_requirements)
def build_workflow(*, config: PipelineConfig, context: ProcessingContext) -> Workflow:
    """Build the configured registration workflow for one patient.

    One registration node (ANTs or FireANTs) is created per resolved job. An
    ``inputnode`` exposes each job's moving image, and an ``outputnode``
    exposes per-job transform fields. The default job additionally publishes
    the legacy unsuffixed field names. When the QC viewer is enabled, it is
    either added as an auto-open node or a manual launch command is logged.

    Args:
        config: Pipeline configuration.
        context: Per-patient processing context (patient id, working dir).

    Returns:
        The assembled Nipype ``Workflow``.

    Raises:
        ValueError: If ``registration.fixed_image`` is not configured.
    """
    if not config.registration.fixed_image:
        raise ValueError(
            "registration.fixed_image must be configured for the registration workflow"
        )
    wf = Workflow(name=f"registration_{context.patient_id}")
    if context.working_dir:
        wf.base_dir = str(context.working_dir)
    jobs = resolve_registration_jobs(config)
    # ANTs only defines (inverse_)composite_transform when composite output is
    # requested (the same flag _build_ants_node forwards). Otherwise the
    # forward/reverse transform lists are the outputs that actually exist.
    ants_composite = bool(config.registration.write_composite_transform)
    # -- I/O contract fields -------------------------------------------------
    # Each job publishes per-job suffixed fields; the default job ALSO publishes
    # the legacy unsuffixed aliases (transform / forward_transforms /
    # reverse_transforms + moving_image) so existing tests + meta-workflow edges
    # keep working unchanged.
    input_fields: List[str] = []
    input_defaults: dict = {}
    output_fields: List[str] = []
    # (node, output-source, output-target) edges into outputnode.
    output_edges: List[tuple[Node, str, str]] = []
    moving_input_targets: List[tuple[Node, str]] = []
    default_warped_node: Node | None = None
    default_fixed: Path | None = None
    for job in jobs:
        moving = resolve_moving_image_for_job(config, context, job)
        fixed = resolve_fixed_image_for_job(config, context, job)
        get_registration_job_dir(config, context, job).mkdir(parents=True, exist_ok=True)
        get_registration_job_transform_dir(config, context, job).mkdir(
            parents=True, exist_ok=True
        )
        if job.method == "fireants":
            node = build_fireants_node(config, context, moving=moving, fixed=fixed, job=job)
            forward_src, reverse_src = "forward_transforms", "reverse_transforms"
        else:
            node = _build_ants_node(config, context, moving, fixed, job=job)
            if ants_composite:
                # Single composite files serve as the forward/reverse chains.
                forward_src = "composite_transform"
                reverse_src = "inverse_composite_transform"
            else:
                forward_src, reverse_src = "forward_transforms", "reverse_transforms"
        wf.add_nodes([node])
        suffix = "" if job.is_default else f"__{job.safe_name}"
        moving_field = f"moving_image{suffix}"
        if moving_field not in input_fields:
            input_fields.append(moving_field)
        input_defaults[moving_field] = str(moving) if moving.exists() else None
        moving_input_targets.append((node, moving_field))
        # Output contract: per-job fields (+ legacy aliases for the default job).
        transform_field = f"transform{suffix}"
        forward_field = f"forward_transforms{suffix}"
        reverse_field = f"reverse_transforms{suffix}"
        output_fields.extend([transform_field, forward_field, reverse_field])
        output_edges.append((node, forward_src, transform_field))
        output_edges.append((node, forward_src, forward_field))
        output_edges.append((node, reverse_src, reverse_field))
        if job.is_default:
            default_warped_node = node
            default_fixed = fixed
    if not ants_composite and any(job.method != "fireants" for job in jobs):
        logger.warning(
            "registration.write_composite_transform is False; ANTs jobs expose their "
            "forward/reverse transform lists instead of composite files. Apply the "
            "reverse list together with the ANTs 'reverse_invert_flags' output."
        )
    # inputnode.moving_image (default) defaults to the resolved moving image for
    # standalone runs; a meta-workflow overrides it with the runtime T1 brain.
    inputnode = attach_inputnode(wf, input_fields, defaults=input_defaults)
    for node, field in moving_input_targets:
        wf.connect(inputnode, field, node, "moving_image")
    outputnode = attach_outputnode(wf, output_fields)
    for node, source, target in output_edges:
        wf.connect(node, source, outputnode, target)
    fixed_for_viewer = (
        default_fixed if default_fixed is not None else resolve_fixed_image(config, context)
    )
    v = config.registration.viewer
    if v.enabled and v.auto_open and default_warped_node is not None:
        viewer = Node(
            Function(
                input_names=_VIEWER_INPUTS, output_names=["command"], function=_launch_viewer_task
            ),
            name="registration_viewer",
        )
        viewer.inputs.template_image = str(fixed_for_viewer)
        viewer.inputs.backend = v.backend
        viewer.inputs.auto_open = v.auto_open
        viewer.inputs.enabled = v.enabled
        viewer.inputs.overlay_opacity = v.overlay_opacity
        wf.add_nodes([viewer])
        wf.connect(default_warped_node, "warped_image", viewer, "warped_image")
    elif v.enabled:
        default_job = next((j for j in jobs if j.is_default), jobs[0])
        cmd = describe_registration_viewer_command(
            template_image=fixed_for_viewer,
            warped_image=get_registration_job_warped_image_path(config, context, default_job),
            backend=v.backend,
            overlay_opacity=v.overlay_opacity,
        )
        logger.info("Registration QC viewer not auto-opened. Launch manually with: {}", cmd)
    logger.info(
        "Built registration workflow for {} | {} job(s) | jobs={}",
        context.patient_id,
        len(jobs),
        [j.name for j in jobs],
    )
    return wf