"""FireANTs backend for patient-to-template registration.
The native staged flow mirrors FireANTs' own reference pipeline
(``MomentsRegistration → RigidRegistration → AffineRegistration →
GreedyRegistration``). Moments registration supplies center-of-mass and
principal-axis initialization, which is what the old ANTs hack was substituting
for; rigid/affine refine that initialization; and (for ``SyN``) Greedy/SyN bakes
the affine into a single composite forward warp. No ANTs subprocesses.
The reverse (template->patient) warp is produced by one of two ANTs-free routes,
selected via ``registration.fireants.inverse_method``:
* ``simpleitk`` (default): numerically invert the high-quality *forward*
displacement field with SimpleITK's ``InvertDisplacementFieldImageFilter``
(explicit convergence tolerance). Works for both ``greedy`` and ``syn`` since
it only needs the forward field, so it also unblocks ``deform_algo='syn'``.
* ``fireants``: FireANTs' own inverse-consistency solve
(``save_as_ants_transforms(save_inverse=True)``); ``greedy`` only.
"""
from __future__ import annotations
from pathlib import Path
from typing import Any
from nipype import Node
from nipype.interfaces.utility import Function
from thesis.core.config import PipelineConfig
from thesis.core.context import ProcessingContext
from .paths import (
ResolvedRegistrationJob,
get_registration_job_inverse_warped_image_path,
get_registration_job_transform_dir,
get_registration_job_warped_image_path,
registration_transform_filenames,
resolve_fixed_image_for_job,
resolve_moving_image_for_job,
resolve_registration_jobs,
)
# Names exported by ``from <this module> import *``.
# NOTE(review): ``_run_staged`` is exported despite its leading underscore —
# presumably so callers outside this module (e.g. the CLI entry point) can
# treat it as a supported entry point; confirm before renaming or dropping it.
__all__ = [
    "_run_staged",
    "build_fireants_node",
    "run_fireants_cli_task",
    "run_fireants_registration_task",
]
# Nipype's Function interface treats every string input that points at an
# existing file as a file-hashable input. warped_image / inverse_warped_image
# are OUTPUT paths for this node — this node itself overwrites them every run,
# so hashing their content self-invalidates the cache. The canonical Nipype
# escape hatch is ``hash_files=False`` trait metadata, but that metadata is a
# dict mutation on ``trait_type._metadata`` AFTER trait construction and does
# NOT survive the pickle round-trip that the MultiProc plugin does when
# forking workers. So workers see the default ``hash_files=True`` regardless
# of what the main process set, open the file, MD5 the bytes, and the stored
# hashfile ends up containing a content hash of an output that the next run
# will rewrite — cache miss on every run.
#
# Workaround: prefix the path with a non-filesystem-path token so Nipype's
# ``os.path.isfile(objekt)`` probe in ``specs._get_sorteddict`` returns False
# and the string is hashed as a plain string (path-only). The task strips the
# prefix at function entry so filesystem writes still land at the real path.
# The trait metadata is still set alongside (defense in depth, covers Linear
# plugin and in-process tests).
#
# The same literal is repeated as a local ``_NOHASH`` inside ``_run_staged``
# and ``run_fireants_cli_task``; keep all copies in sync with this value.
# NOTE(review): the local copies presumably exist because Function-node bodies
# cannot rely on module-level globals — confirm before de-duplicating.
_OUTPUT_PATH_SENTINEL = "__thesis_nohash__:"
def _run_staged(
    fixed_image: str,
    moving_image: str,
    warped_image: str,
    inverse_warped_image: str,
    transform_dir: str,
    patient_id: str,
    transform_type: str,
    device: str,
    scales: list[int],
    affine_iterations: list[int],
    deformable_iterations: list[int],
    rigid_iterations: list[int],
    optimizer: str,
    affine_lr: float,
    deformable_lr: float,
    rigid_lr: float,
    cc_kernel_size: int,
    deformation_type: str,
    dtype: str,
    do_moments: bool,
    do_rigid: bool,
    moments_scale: int,
    moments_moments: int,
    deform_algo: str,
    deformable_max_spacing_mm: float | None = None,
    job_name: str = "patient_to_template",
    loss_type: str = "cc",
    normalize: bool = True,
    inverse_method: str = "simpleitk",
    inverse_max_iterations: int = 50,
    inverse_tolerance: float = 0.01,
) -> tuple[str, str, list[str], list[str], str]:
    """Run the native staged FireANTs registration and persist its outputs.

    Stages run in the canonical FireANTs order: moments (initialization) →
    rigid → affine → deformable. ``transform_type`` gates which stages run and
    what is exported:

    * ``Rigid``: moments → rigid. Forward = one ``.mat`` (saved through the
      stage's ``save_as_ants_transforms``). Reverse = the inverse homogeneous
      matrix written as a ``.mat``. When ``do_rigid`` is false the moments
      affine initialization is exported instead.
    * ``Affine``: moments → rigid → affine. Forward/reverse as for ``Rigid``
      but using the affine matrix.
    * anything else (``SyN``): moments → rigid → affine → deformable
      (``SyNRegistration`` when ``deform_algo == "syn"``, otherwise
      ``GreedyRegistration``). The affine refinement only runs when
      ``do_moments`` or ``do_rigid`` is set. Forward = one composite warp
      (affine baked in). The reverse warp and the inverse-warped QC image
      come from the route selected by ``inverse_method``.

    The moments stage always runs in float32 because FireANTs'
    ``MomentsRegistration`` rejects fp16/bf16. When the configured ``dtype``
    is not float32 a separate float32 ``BatchedImages`` pair is loaded just
    for moments; rigid/affine/deformable run at the configured dtype.

    Args:
        fixed_image: Path to the fixed (template) image.
        moving_image: Path to the moving (patient) image.
        warped_image: Output path for the moving image warped to template
            space. A leading ``"__thesis_nohash__:"`` sentinel is stripped.
        inverse_warped_image: Output path for the template warped to patient
            space (only written for deformable registration). A leading
            sentinel is stripped; a stale file is removed for linear runs.
        transform_dir: Directory for transform artifacts (created if missing).
        patient_id: Patient identifier used in output filenames.
        transform_type: One of ``"Rigid"``, ``"Affine"``, ``"SyN"``.
        device: Torch device string.
        scales: Multi-resolution scales.
        affine_iterations: Affine iterations per scale.
        deformable_iterations: Deformable iterations per scale.
        rigid_iterations: Rigid iterations per scale.
        optimizer: FireANTs optimizer name.
        affine_lr: Affine learning rate.
        deformable_lr: Deformable learning rate.
        rigid_lr: Rigid learning rate.
        cc_kernel_size: Cross-correlation kernel size.
        deformation_type: FireANTs deformation model.
        dtype: Torch dtype name (``float16``/``float32``/``bfloat16``);
            surrounding whitespace and quotes are tolerated.
        do_moments: Whether to run the moments initialization stage.
        do_rigid: Whether to run the rigid stage.
        moments_scale: Downsampling scale for the moments stage.
        moments_moments: Number of moments passed to ``MomentsRegistration``.
        deform_algo: ``"syn"`` selects ``SyNRegistration``; any other value
            selects ``GreedyRegistration``.
        deformable_max_spacing_mm: Optional VRAM cap; when set and the fixed
            image's finest spacing is smaller, fixed+moving are resampled
            isotropically onto the fixed grid before a deformable run.
        job_name: Registration job name; selects the on-disk filename chain.
        loss_type: Similarity metric for rigid/affine/deformable. On a
            ``cuda*`` device ``"cc"``/``"mi"`` are upgraded to
            ``"fusedcc"``/``"fusedmi"``; on any other device a ``fused*``
            name has its ``fused`` prefix removed.
        normalize: Min-max normalize fixed+moving intensities to [0, 1] when
            loading.
        inverse_method: ``"simpleitk"`` numerically inverts the forward field
            (works for greedy and syn). Any other value uses FireANTs'
            ``save_as_ants_transforms(save_inverse=True)``, which is rejected
            for ``deform_algo="syn"``.
        inverse_max_iterations: Iteration cap for the SimpleITK inversion.
        inverse_tolerance: Mean-error tolerance for the SimpleITK inversion;
            the max-error tolerance is ``max(0.1, 10 * inverse_tolerance)``.

    Returns:
        Tuple of (warped image path, ``""``, forward transforms, reverse
        transforms, informational note). The second element is always the
        empty string in this implementation; the inverse-warped QC image,
        when produced, is written to ``inverse_warped_image``.

    Raises:
        ValueError: Unsupported ``dtype``, or ``transform_type="Rigid"`` with
            both ``do_rigid`` and ``do_moments`` disabled.
        NotImplementedError: ``deform_algo="syn"`` combined with a
            non-``"simpleitk"`` ``inverse_method``.
        RuntimeError: Image loading failed; the CUDA "no kernel image" case
            is re-raised with remediation guidance.
    """
    from pathlib import Path

    import SimpleITK as sitk
    import torch
    from fireants.interpolator import fireants_interpolator
    from fireants.io import BatchedImages, Image
    from fireants.io.image import FakeBatchedImages
    from fireants.registration.affine import AffineRegistration
    from fireants.registration.moments import MomentsRegistration
    from fireants.registration.rigid import RigidRegistration

    def normalize_dtype_name(dtype_name: str) -> str:
        """Normalize configured dtype names before torch lookup."""
        return dtype_name.strip().strip('"').strip("'")

    def _resample_iso(img, ref, spacing_mm, out_path):
        """Resample ``img`` to isotropic ``spacing_mm`` on ``ref``'s grid and write it.

        ``ref`` supplies the target size (computed from its own size/spacing),
        origin and direction; ``img`` provides the intensities being resampled.
        """
        # Clamp to one voxel so an axis thinner than the cap never yields a
        # zero-length (invalid) output dimension.
        target_size = [
            max(1, int(s * sp / spacing_mm))
            for s, sp in zip(ref.GetSize(), ref.GetSpacing())
        ]
        resampled = sitk.Resample(
            img,
            target_size,
            sitk.Transform(),
            sitk.sitkLinear,
            ref.GetOrigin(),
            [spacing_mm, spacing_mm, spacing_mm],
            ref.GetDirection(),
        )
        sitk.WriteImage(resampled, out_path)
        return out_path

    def _write_inverse_mat(homogeneous_matrix, out_path: str) -> None:
        """Write the inverse of a homogeneous matrix as an ANTs ``.mat``."""
        # torch.linalg.inv only supports float/double (and complex) inputs, so
        # promote first; this also keeps the inversion out of reduced precision
        # when the registration ran in fp16/bf16.
        inverse = torch.linalg.inv(homogeneous_matrix.to(torch.float64))
        dims = inverse.shape[0] - 1
        sitk_transform = sitk.AffineTransform(dims)
        sitk_transform.SetMatrix(inverse[:dims, :dims].reshape(-1).tolist())
        sitk_transform.SetTranslation(inverse[:dims, -1].tolist())
        sitk.WriteTransform(sitk_transform, out_path)

    def _invert_forward_warp_sitk(
        forward_warp_path: str,
        moving_ref_path: str,
        template_ref_path: str,
        reverse_warp_path: str,
        inverse_warped_path: str,
    ) -> None:
        """Build the reverse warp + QC image by inverting the forward field.

        The forward warp written by ``save_as_ants_transforms`` is an ANTs-format
        composite displacement field (affine baked in) on the *fixed/template*
        grid, in physical space, mapping fixed->moving coordinates. Its inverse
        maps moving->fixed and is what downstream resampling needs to warp the
        template into patient space.

        ``InvertDisplacementFieldImageFilter`` reconstructs the inverse on the
        field's own (template) grid with an explicit fixed-point convergence
        tolerance; the result is then resampled onto the moving grid so the field
        covers the output (patient) space that transform application evaluates it
        on. Displacement vectors are physical-space, so the identity-transform
        resample simply re-samples them at the moving grid's physical points.
        """
        forward_field = sitk.ReadImage(forward_warp_path, sitk.sitkVectorFloat64)
        inverter = sitk.InvertDisplacementFieldImageFilter()
        inverter.SetMaximumNumberOfIterations(int(inverse_max_iterations))
        inverter.SetMeanErrorToleranceThreshold(float(inverse_tolerance))
        inverter.SetMaxErrorToleranceThreshold(max(0.1, float(inverse_tolerance) * 10.0))
        inverter.SetEnforceBoundaryCondition(True)
        inverse_on_template_grid = inverter.Execute(forward_field)

        moving_ref = sitk.ReadImage(moving_ref_path)
        identity = sitk.Transform(moving_ref.GetDimension(), sitk.sitkIdentity)
        inverse_field = sitk.Resample(
            inverse_on_template_grid,
            moving_ref,
            identity,
            sitk.sitkLinear,
            0.0,
            inverse_on_template_grid.GetPixelID(),
        )
        sitk.WriteImage(inverse_field, reverse_warp_path)

        # QC image: warp the template into moving/patient space via the inverse.
        inverse_transform = sitk.DisplacementFieldTransform(
            sitk.Cast(inverse_field, sitk.sitkVectorFloat64)
        )
        template = sitk.ReadImage(template_ref_path, sitk.sitkFloat32)
        warped_template = sitk.Resample(
            template,
            moving_ref,
            inverse_transform,
            sitk.sitkLinear,
            0.0,
            sitk.sitkFloat32,
        )
        sitk.WriteImage(warped_template, inverse_warped_path)

    # Strip the no-hash sentinel (see build_fireants_node for why). The
    # main-process metadata is redundant in MultiProc because pickle drops
    # trait_type._metadata, so the sentinel is the load-bearing mechanism.
    _NOHASH = "__thesis_nohash__:"
    if warped_image.startswith(_NOHASH):
        warped_image = warped_image[len(_NOHASH) :]
    if inverse_warped_image.startswith(_NOHASH):
        inverse_warped_image = inverse_warped_image[len(_NOHASH) :]

    transform_dir_path = Path(transform_dir)
    transform_dir_path.mkdir(parents=True, exist_ok=True)

    dtype_map = {
        "float16": torch.float16,
        "float32": torch.float32,
        "bfloat16": torch.bfloat16,
    }
    normalized_dtype = normalize_dtype_name(dtype)
    if normalized_dtype not in dtype_map:
        raise ValueError(
            f"Unsupported FireANTs dtype '{dtype}'. Expected one of {sorted(dtype_map)}."
        )
    torch_dtype = dtype_map[normalized_dtype]

    # Pick the effective similarity metric. FireANTs' own scripts use the
    # GPU-fused kernels on CUDA ("fusedcc"/"fusedmi") and the plain kernels on
    # CPU; mirror that. abstract.py falls back to non-fused if the fused ops
    # are not importable, so this only ever upgrades — never breaks.
    use_cuda = str(device).startswith("cuda")
    if use_cuda:
        loss_type_eff = {"cc": "fusedcc", "mi": "fusedmi"}.get(loss_type, loss_type)
    elif loss_type.startswith("fused"):
        loss_type_eff = loss_type.replace("fused", "")
    else:
        loss_type_eff = loss_type
    print(
        f"FireANTs metric: requested loss_type='{loss_type}', "
        f"effective='{loss_type_eff}', normalize={normalize}"
    )

    def _normalize01(arr):
        """Min-max normalize a ``[B, C, ...]`` tensor to [0, 1] per (batch, channel).

        Matches FireANTs' ``template_helpers.normalize``; a small epsilon guards
        against constant images.
        """
        flat = arr.flatten(2)
        amin = flat.min(dim=2).values
        amax = flat.max(dim=2).values
        # Re-expand the per-(batch, channel) extrema so they broadcast over
        # every spatial axis of ``arr``.
        for _ in range(arr.ndim - 2):
            amin = amin.unsqueeze(-1)
            amax = amax.unsqueeze(-1)
        return (arr - amin) / (amax - amin).clamp_min(1e-8)

    def _load_image(path: str, load_dtype: "torch.dtype") -> "Image":
        """Load an image and (optionally) normalize its intensities to [0, 1]."""
        image = Image.load_file(path, device=device, dtype=load_dtype)
        if normalize:
            image.array = _normalize01(image.array)
        return image

    is_deformable = transform_type not in {"Rigid", "Affine"}

    # Resolve the canonical filenames (shared with the transform resolver).
    names = registration_transform_filenames(patient_id, job_name, transform_type)
    forward_path = transform_dir_path / names["forward"][0]
    reverse_path = transform_dir_path / names["reverse"][0]

    # ------------------------------------------------------------------
    # Optional VRAM cap: resample fixed + moving before the deformable step.
    # ------------------------------------------------------------------
    fixed_for_reg = fixed_image
    moving_for_reg = moving_image
    if is_deformable and deformable_max_spacing_mm is not None:
        fixed_sitk = sitk.ReadImage(fixed_image)
        min_spacing = min(fixed_sitk.GetSpacing())
        if min_spacing < float(deformable_max_spacing_mm):
            cap = float(deformable_max_spacing_mm)
            fixed_for_reg = _resample_iso(
                fixed_sitk,
                fixed_sitk,
                cap,
                str(transform_dir_path / f"{patient_id}_fixed_deformable.nii.gz"),
            )
            moving_for_reg = _resample_iso(
                sitk.ReadImage(moving_image),
                fixed_sitk,
                cap,
                str(transform_dir_path / f"{patient_id}_moving_deformable.nii.gz"),
            )

    try:
        fixed = BatchedImages([_load_image(fixed_for_reg, torch_dtype)])
        moving = BatchedImages([_load_image(moving_for_reg, torch_dtype)])
    except RuntimeError as exc:
        message = str(exc)
        if "no kernel image is available for execution on the device" in message:
            raise RuntimeError(
                "FireANTs CUDA initialization failed because the installed PyTorch build does not "
                "support this GPU architecture. Set registration.fireants.device='cpu' or install "
                "a PyTorch build compatible with this GPU. Original error: "
                f"{message}"
            ) from exc
        raise

    # GPU-holding handles; pre-bound so the VRAM teardown can del them
    # unconditionally regardless of which stages executed.
    fixed_f32: Any = None
    moving_f32: Any = None
    moments: Any = None
    rigid: Any = None
    affine: Any = None
    reg: Any = None
    inverse_params: Any = None
    inverse_moved: Any = None
    moved: Any = None

    forward_transforms: list[str] = []
    reverse_transforms: list[str] = []
    note = ""

    # ------------------------------------------------------------------
    # Stage 1 — moments (initialization). Must run in float32.
    # ------------------------------------------------------------------
    init_translation = None
    init_moment = None
    init_rigid = None
    if do_moments:
        if torch_dtype != torch.float32:
            # MomentsRegistration rejects fp16/bf16 — load a float32 pair.
            fixed_f32 = BatchedImages([_load_image(fixed_for_reg, torch.float32)])
            moving_f32 = BatchedImages([_load_image(moving_for_reg, torch.float32)])
            moments_fixed, moments_moving = fixed_f32, moving_f32
        else:
            moments_fixed, moments_moving = fixed, moving
        moments = MomentsRegistration(
            scale=moments_scale,
            fixed_images=moments_fixed,
            moving_images=moments_moving,
            moments=moments_moments,
        )
        moments.optimize()
        init_translation = moments.get_rigid_transl_init()
        init_moment = moments.get_rigid_moment_init()
        init_rigid = moments.get_affine_init()

    # ------------------------------------------------------------------
    # Stage 2 — rigid.
    # ------------------------------------------------------------------
    if do_rigid:
        rigid = RigidRegistration(
            scales=scales,
            iterations=rigid_iterations,
            fixed_images=fixed,
            moving_images=moving,
            loss_type=loss_type_eff,
            optimizer=optimizer,
            optimizer_lr=rigid_lr,
            cc_kernel_size=cc_kernel_size,
            init_translation=init_translation,
            init_moment=init_moment,
            dtype=torch_dtype,
        )
        rigid.optimize()
        init_rigid = rigid.get_rigid_matrix()

    def _fit_affine(init_matrix):
        """Build and optimize the affine stage from ``init_matrix`` (may be ``None``)."""
        stage = AffineRegistration(
            scales=scales,
            iterations=affine_iterations,
            fixed_images=fixed,
            moving_images=moving,
            loss_type=loss_type_eff,
            optimizer=optimizer,
            optimizer_lr=affine_lr,
            cc_kernel_size=cc_kernel_size,
            init_rigid=init_matrix,
            dtype=torch_dtype,
        )
        stage.optimize()
        return stage

    # ------------------------------------------------------------------
    # Stage 3 — affine / linear export, or deformable.
    # ------------------------------------------------------------------
    if not is_deformable:
        if transform_type == "Affine":
            affine = _fit_affine(init_rigid)
            linear_stage: Any = affine
            homogeneous = affine.get_affine_matrix().detach().cpu()[0]
        elif rigid is not None:
            # Rigid: the rigid stage is the linear transform.
            linear_stage = rigid
            homogeneous = rigid.get_rigid_matrix().detach().cpu()[0]
        elif moments is not None:
            # Rigid stage disabled: fall back to the moments affine init.
            linear_stage = moments
            homogeneous = moments.get_affine_init().detach().cpu()[0]
            # get_affine_init is [d, d+1]; pad to a square homogeneous matrix.
            row = torch.zeros(1, homogeneous.shape[1], dtype=homogeneous.dtype)
            row[0, -1] = 1.0
            homogeneous = torch.cat([homogeneous, row], dim=0)
        else:
            raise ValueError(
                "Rigid registration requires do_rigid or do_moments to be enabled."
            )

        linear_stage.save_as_ants_transforms(str(forward_path))
        forward_transforms = [str(forward_path)]
        _write_inverse_mat(homogeneous, str(reverse_path))
        reverse_transforms = [str(reverse_path)]
        moved = linear_stage.evaluate(fixed, moving)
        note = (
            "FireANTs native staged linear registration "
            f"({'affine' if transform_type == 'Affine' else 'rigid'}): "
            "exported forward and inverse transforms."
        )
    else:
        # ------------------------------------------------------------------
        # Deformable registration (SyN-family).
        # ------------------------------------------------------------------
        if do_moments or do_rigid:
            # Refine the affine before the deformable stage for best init.
            affine = _fit_affine(init_rigid)
            init_affine = affine.get_affine_matrix().detach()
        else:
            # No initialization stage ran, so init_rigid is still None here.
            init_affine = init_rigid

        if deform_algo == "syn":
            from fireants.registration.syn import SyNRegistration

            deform_cls: Any = SyNRegistration
        else:
            from fireants.registration.greedy import GreedyRegistration

            deform_cls = GreedyRegistration

        reg = deform_cls(
            scales=scales,
            iterations=deformable_iterations,
            fixed_images=fixed,
            moving_images=moving,
            loss_type=loss_type_eff,
            optimizer=optimizer,
            optimizer_lr=deformable_lr,
            cc_kernel_size=cc_kernel_size,
            deformation_type=deformation_type,
            smooth_grad_sigma=1,
            init_affine=init_affine,
            dtype=torch_dtype,
        )
        reg.optimize()
        moved = reg.evaluate(fixed, moving)

        # Single composite forward warp (affine baked in). This is the
        # high-quality direction; both greedy and syn produce it.
        reg.save_as_ants_transforms(str(forward_path))
        forward_transforms = [str(forward_path)]
        reverse_transforms = [str(reverse_path)]

        if inverse_method == "simpleitk":
            # Numerically invert the forward field (ANTs-free, explicit
            # convergence). Independent of FireANTs' analytic inverse, so it
            # works for syn as well as greedy.
            _invert_forward_warp_sitk(
                forward_warp_path=str(forward_path),
                moving_ref_path=moving_for_reg,
                template_ref_path=fixed_for_reg,
                reverse_warp_path=str(reverse_path),
                inverse_warped_path=str(inverse_warped_image),
            )
            note = (
                f"FireANTs native staged deformable registration ({deform_algo}): "
                "composite forward warp + inverse (reverse) warp via SimpleITK "
                f"displacement-field inversion of the forward field "
                f"(max_iter={inverse_max_iterations}, tol={inverse_tolerance})."
            )
        else:
            # NOTE(review): every value other than "simpleitk" lands here,
            # including typos; confirm whether unknown values should be rejected.
            if deform_algo == "syn":
                raise NotImplementedError(
                    "deform_algo='syn' has no analytic inverse warp "
                    "(SyNRegistration.get_inverse_warp_parameters raises). Use "
                    "deform_algo='greedy', or set "
                    "registration.fireants.inverse_method='simpleitk'."
                )
            # Analytic inverse composite warp (FireANTs inverse-consistency solve).
            reg.save_as_ants_transforms(str(reverse_path), save_inverse=True)
            # Inverse-warped QC image (template → patient).
            inverse_params = reg.get_inverse_warp_parameters(fixed, moving)
            inverse_moved = fireants_interpolator(
                fixed(),
                **inverse_params,
                mode=fixed.get_interpolator_type(),
                align_corners=True,
            )
            FakeBatchedImages(inverse_moved, moving).write_image(str(inverse_warped_image))
            note = (
                "FireANTs native staged deformable registration (greedy): single "
                "composite forward warp + analytic inverse warp via "
                "save_as_ants_transforms."
            )

    FakeBatchedImages(moved, fixed).write_image(str(warped_image))

    # Linear transforms produce no inverse-warped QC image; clean up a stale one.
    inverse_path_obj = Path(inverse_warped_image)
    if not is_deformable and inverse_path_obj.exists():
        inverse_path_obj.unlink()

    # Clean up VRAM-cap temp files.
    for tmp_name in [
        f"{patient_id}_fixed_deformable.nii.gz",
        f"{patient_id}_moving_deformable.nii.gz",
    ]:
        (transform_dir_path / tmp_name).unlink(missing_ok=True)

    # Release PyTorch CUDA cache so subsequent GPU tasks (e.g. eddy_cuda,
    # a second patient's fireants) can allocate VRAM without OOM.
    import gc

    del fixed, moving, fixed_f32, moving_f32
    del moments, rigid, affine, reg
    del moved, inverse_params, inverse_moved
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.synchronize()
        torch.cuda.empty_cache()

    # NOTE(review): the inverse-warped slot is returned empty even when the
    # deformable branch wrote ``inverse_warped_image``; confirm whether
    # downstream consumers expect the real path here.
    return str(warped_image), "", forward_transforms, reverse_transforms, note
def run_fireants_registration_task(
    fixed_image: str,
    moving_image: str,
    warped_image: str,
    inverse_warped_image: str,
    transform_dir: str,
    patient_id: str,
    transform_type: str,
    device: str,
    scales: list[int],
    affine_iterations: list[int],
    deformable_iterations: list[int],
    rigid_iterations: list[int],
    optimizer: str,
    affine_lr: float,
    deformable_lr: float,
    rigid_lr: float,
    cc_kernel_size: int,
    deformation_type: str,
    dtype: str,
    do_moments: bool,
    do_rigid: bool,
    moments_scale: int,
    moments_moments: int,
    deform_algo: str,
    job_name: str = "patient_to_template",
    deformable_max_spacing_mm: float | None = None,
    use_gpu: bool = True,
    loss_type: str = "cc",
    normalize: bool = True,
    inverse_method: str = "simpleitk",
    inverse_max_iterations: int = 50,
    inverse_tolerance: float = 0.01,
) -> tuple[str, str, list[str], list[str], str]:
    """In-process Function-node entry point that delegates to :func:`_run_staged`.

    Every argument except ``use_gpu`` is forwarded unchanged by keyword. The
    output paths may still carry the no-hash sentinel added by
    :func:`build_fireants_node`; it is removed inside :func:`_run_staged`, not
    here. ``use_gpu`` is accepted so the node can expose it as an input but is
    not passed on.

    Returns:
        The tuple produced by :func:`_run_staged`: warped image path,
        inverse-warped image path, forward transforms, reverse transforms,
        and an informational note.
    """
    # Imported at call time so the function body is self-contained.
    from thesis.workflows.registration.fireants_backend import _run_staged as run_staged

    staged_kwargs = {
        "fixed_image": fixed_image,
        "moving_image": moving_image,
        "warped_image": warped_image,
        "inverse_warped_image": inverse_warped_image,
        "transform_dir": transform_dir,
        "patient_id": patient_id,
        "transform_type": transform_type,
        "device": device,
        "scales": scales,
        "affine_iterations": affine_iterations,
        "deformable_iterations": deformable_iterations,
        "rigid_iterations": rigid_iterations,
        "optimizer": optimizer,
        "affine_lr": affine_lr,
        "deformable_lr": deformable_lr,
        "rigid_lr": rigid_lr,
        "cc_kernel_size": cc_kernel_size,
        "deformation_type": deformation_type,
        "dtype": dtype,
        "do_moments": do_moments,
        "do_rigid": do_rigid,
        "moments_scale": moments_scale,
        "moments_moments": moments_moments,
        "deform_algo": deform_algo,
        "job_name": job_name,
        "deformable_max_spacing_mm": deformable_max_spacing_mm,
        "loss_type": loss_type,
        "normalize": normalize,
        "inverse_method": inverse_method,
        "inverse_max_iterations": inverse_max_iterations,
        "inverse_tolerance": inverse_tolerance,
    }
    return run_staged(**staged_kwargs)
def run_fireants_cli_task(
    fixed_image: str,
    moving_image: str,
    warped_image: str,
    inverse_warped_image: str,
    transform_dir: str,
    patient_id: str,
    transform_type: str,
    device: str,
    scales: list[int],
    affine_iterations: list[int],
    deformable_iterations: list[int],
    rigid_iterations: list[int],
    optimizer: str,
    affine_lr: float,
    deformable_lr: float,
    rigid_lr: float,
    cc_kernel_size: int,
    deformation_type: str,
    dtype: str,
    do_moments: bool,
    do_rigid: bool,
    moments_scale: int,
    moments_moments: int,
    deform_algo: str,
    job_name: str = "patient_to_template",
    deformable_max_spacing_mm: float | None = None,
    use_gpu: bool = True,
    loss_type: str = "cc",
    normalize: bool = True,
    inverse_method: str = "simpleitk",
    inverse_max_iterations: int = 50,
    inverse_tolerance: float = 0.01,
) -> tuple[str, str, list[str], list[str], str]:
    """Run staged FireANTs registration out-of-process via the CLI (Flavor 2).

    Strips the ``"__thesis_nohash__:"`` sentinel from the output paths (added
    by :func:`build_fireants_node`), builds the argv for
    ``python -m thesis.cli_tools.fireants_register`` using the current
    interpreter, runs it with :func:`subprocess.run`, and parses the last
    non-empty stdout line as the JSON result. ``use_gpu`` is accepted so the
    node can expose it as an input but is not passed to the CLI.

    Returns:
        Tuple of warped image path, inverse-warped image path, forward
        transforms, reverse transforms, and an informational note, read from
        the ``warped_image``, ``inverse_warped_image``, ``forward_transforms``,
        ``reverse_transforms`` and ``note`` keys of the JSON result.

    Raises:
        RuntimeError: The CLI exited non-zero (with dedicated guidance when
            stderr contains the CUDA "no kernel image" failure), printed
            nothing on stdout, or its last stdout line is not a JSON object
            carrying all expected keys.
    """
    import json
    import subprocess
    import sys

    # Local copy of the sentinel: the function body is kept self-contained.
    _NOHASH = "__thesis_nohash__:"
    if warped_image.startswith(_NOHASH):
        warped_image = warped_image[len(_NOHASH) :]
    if inverse_warped_image.startswith(_NOHASH):
        inverse_warped_image = inverse_warped_image[len(_NOHASH) :]

    # Single-valued options, in the order they are placed on the command line.
    scalar_options = [
        ("--fixed", fixed_image),
        ("--moving", moving_image),
        ("--warped", warped_image),
        ("--inverse-warped", inverse_warped_image),
        ("--transform-dir", transform_dir),
        ("--patient-id", patient_id),
        ("--job-name", job_name),
        ("--transform-type", transform_type),
        ("--device", device),
        ("--dtype", dtype),
        ("--optimizer", optimizer),
        ("--affine-lr", affine_lr),
        ("--deformable-lr", deformable_lr),
        ("--rigid-lr", rigid_lr),
        ("--cc-kernel-size", cc_kernel_size),
        ("--deformation-type", deformation_type),
        ("--deform-algo", deform_algo),
        ("--loss-type", loss_type),
        ("--inverse-method", inverse_method),
        ("--inverse-max-iterations", inverse_max_iterations),
        ("--inverse-tolerance", inverse_tolerance),
        ("--moments-scale", moments_scale),
        ("--moments-moments", moments_moments),
    ]
    # Multi-valued options: the flag followed by one token per element.
    list_options = [
        ("--scales", scales),
        ("--affine-iterations", affine_iterations),
        ("--deformable-iterations", deformable_iterations),
        ("--rigid-iterations", rigid_iterations),
    ]

    argv: list[str] = [sys.executable, "-m", "thesis.cli_tools.fireants_register"]
    for flag, value in scalar_options:
        argv += [flag, str(value)]
    for flag, values in list_options:
        argv.append(flag)
        argv.extend(str(item) for item in values)

    # Boolean stages default to "on" in the CLI; only the negations are passed.
    if not do_moments:
        argv.append("--no-moments")
    if not do_rigid:
        argv.append("--no-rigid")
    if not normalize:
        argv.append("--no-normalize")
    if deformable_max_spacing_mm is not None:
        argv += ["--deformable-max-spacing-mm", str(deformable_max_spacing_mm)]

    completed = subprocess.run(argv, capture_output=True, text=True, check=False)
    stderr = completed.stderr or ""
    if completed.returncode != 0:
        if "no kernel image is available for execution on the device" in stderr:
            raise RuntimeError(
                "FireANTs CUDA initialization failed because the installed PyTorch build does "
                "not support this GPU architecture. Set registration.fireants.device='cpu' or "
                "install a PyTorch build compatible with this GPU. Original error: "
                f"{stderr}"
            )
        raise RuntimeError(
            f"thesis-fireants-register failed (exit {completed.returncode}): {stderr}"
        )

    stdout = (completed.stdout or "").strip()
    if not stdout:
        # Without this guard json.loads("") fails with an opaque decode error.
        raise RuntimeError(
            "thesis-fireants-register exited successfully but printed no JSON result "
            f"on stdout. stderr: {stderr}"
        )
    last_line = stdout.splitlines()[-1]
    try:
        result = json.loads(last_line)
        return (
            result["warped_image"],
            result["inverse_warped_image"],
            result["forward_transforms"],
            result["reverse_transforms"],
            result["note"],
        )
    except (json.JSONDecodeError, KeyError, TypeError) as exc:
        raise RuntimeError(
            "thesis-fireants-register produced an unparseable result line "
            f"{last_line!r} ({type(exc).__name__}: {exc}). stderr: {stderr}"
        ) from exc
def build_fireants_node(
    config: PipelineConfig,
    context: ProcessingContext,
    *,
    moving: Path | None = None,
    fixed: Path | None = None,
    job: ResolvedRegistrationJob | None = None,
) -> Node:
    """Assemble the Nipype Function node that executes one FireANTs job.

    Args:
        config: The merged pipeline configuration.
        context: The processing context (supplies ``patient_id``).
        moving: Explicit moving image; resolved from *job* when ``None``.
        fixed: Explicit fixed image; resolved from *job* when ``None``.
        job: The resolved registration job to implement. When ``None`` the
            first job returned by ``resolve_registration_jobs(config)`` is
            used.

    Returns:
        The configured Function node. A job flagged ``is_default`` gets the
        name ``"fireants_registration"``; any other job is named
        ``"fireants_<job.safe_name>"``. The node wraps
        ``run_fireants_cli_task`` when the job's FireANTs config has
        ``driver == "cli"`` and ``run_fireants_registration_task`` otherwise.
    """
    if job is None:
        job = resolve_registration_jobs(config)[0]
    fireants_cfg = job.fireants

    if fixed is None:
        fixed = resolve_fixed_image_for_job(config, context, job)
    if moving is None:
        moving = resolve_moving_image_for_job(config, context, job)

    if job.is_default:
        node_name = "fireants_registration"
    else:
        node_name = f"fireants_{job.safe_name}"

    # Flavor 1 (in-process) is the default; Flavor 2 shells out to the CLI.
    if getattr(fireants_cfg, "driver", "inprocess") == "cli":
        task_function = run_fireants_cli_task
    else:
        task_function = run_fireants_registration_task

    input_names = [
        "fixed_image",
        "moving_image",
        "warped_image",
        "inverse_warped_image",
        "transform_dir",
        "patient_id",
        "job_name",
        "transform_type",
        "device",
        "scales",
        "affine_iterations",
        "deformable_iterations",
        "rigid_iterations",
        "optimizer",
        "affine_lr",
        "deformable_lr",
        "rigid_lr",
        "cc_kernel_size",
        "deformation_type",
        "dtype",
        "do_moments",
        "do_rigid",
        "moments_scale",
        "moments_moments",
        "deform_algo",
        "deformable_max_spacing_mm",
        "use_gpu",
        "loss_type",
        "normalize",
        "inverse_method",
        "inverse_max_iterations",
        "inverse_tolerance",
    ]
    output_names = [
        "warped_image",
        "inverse_warped_image",
        "forward_transforms",
        "reverse_transforms",
        "note",
    ]
    interface = Function(
        input_names=input_names,
        output_names=output_names,
        function=task_function,
    )
    node = Node(interface, name=node_name)
    inputs = node.inputs

    inputs.fixed_image = str(fixed)
    inputs.moving_image = str(moving)

    # Output paths carry the sentinel prefix so Nipype hashes them as plain
    # strings: trait metadata is dropped by MultiProc's pickle, so the
    # hash_files=False knob doesn't reach the worker that writes the hashfile,
    # whereas a non-path token makes os.path.isfile() return False there too.
    warped_path = get_registration_job_warped_image_path(config, context, job)
    inverse_warped_path = get_registration_job_inverse_warped_image_path(config, context, job)
    inputs.warped_image = _OUTPUT_PATH_SENTINEL + str(warped_path)
    inputs.inverse_warped_image = _OUTPUT_PATH_SENTINEL + str(inverse_warped_path)
    inputs.transform_dir = str(get_registration_job_transform_dir(config, context, job))

    # Keep the trait metadata too — defensive, covers the Linear plugin and
    # in-process callers that never pickle. Metadata lives on
    # trait_type._metadata; assigning directly to the CTrait wrapper does not
    # take effect because Nipype's has_metadata reads from _metadata.
    for out_field in ("warped_image", "inverse_warped_image"):
        inputs.trait(out_field).trait_type._metadata["hash_files"] = False

    inputs.patient_id = context.patient_id
    inputs.job_name = job.name
    inputs.transform_type = job.transform_type

    # Settings copied verbatim from the job's FireANTs configuration.
    for field in (
        "device",
        "optimizer",
        "affine_lr",
        "deformable_lr",
        "rigid_lr",
        "cc_kernel_size",
        "deformation_type",
        "dtype",
        "do_moments",
        "do_rigid",
        "moments_scale",
        "moments_moments",
        "deform_algo",
        "deformable_max_spacing_mm",
        "loss_type",
        "normalize",
    ):
        setattr(inputs, field, getattr(fireants_cfg, field))

    # Per-scale schedules are copied into fresh lists.
    for field in ("scales", "affine_iterations", "deformable_iterations", "rigid_iterations"):
        setattr(inputs, field, list(getattr(fireants_cfg, field)))

    # Inverse-warp settings fall back to defaults when the config lacks them.
    for field, fallback in (
        ("inverse_method", "simpleitk"),
        ("inverse_max_iterations", 50),
        ("inverse_tolerance", 0.01),
    ):
        setattr(inputs, field, getattr(fireants_cfg, field, fallback))

    # A quarter of the configured memory (16 GB assumed when unset), never
    # less than 4 GB.
    configured_memory_gb = float(getattr(config.hardware, "memory_gb", 16))
    node._mem_gb = max(4.0, configured_memory_gb / 4.0)

    # Declare GPU usage so Nipype's MultiProc scheduler counts this node toward
    # the n_gpu_procs budget. Node.is_gpu_node() checks for use_gpu/use_cuda in
    # node.inputs; _n_gpus is NOT used by the scheduler for Function nodes.
    inputs.use_gpu = True
    return node