Source code for thesis.workflows.registration.fireants_backend

"""FireANTs backend for patient-to-template registration.

The native staged flow mirrors FireANTs' own reference pipeline
(``MomentsRegistration → RigidRegistration → AffineRegistration →
GreedyRegistration``).  Moments registration supplies center-of-mass and
principal-axis initialization, which is what the old ANTs hack was substituting
for; rigid/affine refine that initialization; and (for ``SyN``) Greedy/SyN bakes
the affine into a single composite forward warp.  No ANTs subprocesses.

The reverse (template->patient) warp is produced by one of two ANTs-free routes,
selected via ``registration.fireants.inverse_method``:

* ``simpleitk`` (default): numerically invert the high-quality *forward*
  displacement field with SimpleITK's ``InvertDisplacementFieldImageFilter``
  (explicit convergence tolerance).  Works for both ``greedy`` and ``syn`` since
  it only needs the forward field, so it also unblocks ``deform_algo='syn'``.
* ``fireants``: FireANTs' own inverse-consistency solve
  (``save_as_ants_transforms(save_inverse=True)``); ``greedy`` only.
"""

from __future__ import annotations

from pathlib import Path
from typing import Any

from nipype import Node
from nipype.interfaces.utility import Function

from thesis.core.config import PipelineConfig
from thesis.core.context import ProcessingContext

from .paths import (
    ResolvedRegistrationJob,
    get_registration_job_inverse_warped_image_path,
    get_registration_job_transform_dir,
    get_registration_job_warped_image_path,
    registration_transform_filenames,
    resolve_fixed_image_for_job,
    resolve_moving_image_for_job,
    resolve_registration_jobs,
)

# Names exported by ``from <this module> import *``.  ``_run_staged`` is listed
# explicitly, so it is part of the star-export surface despite its leading
# underscore.
__all__ = [
    "_run_staged",
    "build_fireants_node",
    "run_fireants_cli_task",
    "run_fireants_registration_task",
]


# Nipype's Function interface treats every string input that points at an
# existing file as a file-hashable input.  warped_image / inverse_warped_image
# are OUTPUT paths for this node — this node itself overwrites them every run,
# so hashing their content self-invalidates the cache.  The canonical Nipype
# escape hatch is ``hash_files=False`` trait metadata, but that metadata is a
# dict mutation on ``trait_type._metadata`` AFTER trait construction and does
# NOT survive the pickle round-trip that the MultiProc plugin does when
# forking workers.  So workers see the default ``hash_files=True`` regardless
# of what the main process set, open the file, MD5 the bytes, and the stored
# hashfile ends up containing a content hash of an output that the next run
# will rewrite — cache miss on every run.
#
# Workaround: prefix the path with a non-filesystem-path token so Nipype's
# ``os.path.isfile(objekt)`` probe in ``specs._get_sorteddict`` returns False
# and the string is hashed as a plain string (path-only).  The task strips the
# prefix at function entry so filesystem writes still land at the real path.
# The trait metadata is still set alongside (defense in depth, covers Linear
# plugin and in-process tests).
#
# NOTE: the same literal is hard-coded as a local ``_NOHASH`` inside
# ``_run_staged`` and ``run_fireants_cli_task`` rather than read from this
# constant, so any change to the value below must be mirrored there.
_OUTPUT_PATH_SENTINEL = "__thesis_nohash__:"


def _run_staged(
    fixed_image: str,
    moving_image: str,
    warped_image: str,
    inverse_warped_image: str,
    transform_dir: str,
    patient_id: str,
    transform_type: str,
    device: str,
    scales: list[int],
    affine_iterations: list[int],
    deformable_iterations: list[int],
    rigid_iterations: list[int],
    optimizer: str,
    affine_lr: float,
    deformable_lr: float,
    rigid_lr: float,
    cc_kernel_size: int,
    deformation_type: str,
    dtype: str,
    do_moments: bool,
    do_rigid: bool,
    moments_scale: int,
    moments_moments: int,
    deform_algo: str,
    deformable_max_spacing_mm: float | None = None,
    job_name: str = "patient_to_template",
    loss_type: str = "cc",
    normalize: bool = True,
    inverse_method: str = "simpleitk",
    inverse_max_iterations: int = 50,
    inverse_tolerance: float = 0.01,
) -> tuple[str, str, list[str], list[str], str]:
    """Run the native staged FireANTs registration and persist its outputs.

    Stages run in the canonical FireANTs order: moments (initialization) →
    rigid → affine → deformable.  The moments and rigid stages are individually
    gated by ``do_moments`` / ``do_rigid``.  ``transform_type`` gates which of
    the remaining stages run and what is exported:

    * ``Rigid``: moments → rigid.  Forward = one ``.mat`` (rigid matrix saved
      as an ANTs transform).  Reverse = the inverse 4x4 written as a ``.mat``.
      With ``do_rigid=False`` the moments affine init is exported instead.
    * ``Affine``: moments → rigid → affine.  Forward/reverse as for ``Rigid``
      but using the affine matrix.
    * anything else (``SyN``): moments → rigid → affine → deformable
      (``GreedyRegistration``, or ``SyNRegistration`` when
      ``deform_algo == "syn"``) with ``init_affine``.  The affine refinement
      only runs when ``do_moments`` or ``do_rigid`` is enabled; otherwise the
      deformable stage starts from ``init_affine=None``.  Forward = one
      composite warp (affine baked in).  The reverse warp and the
      inverse-warped QC image depend on ``inverse_method``:

      - ``"simpleitk"``: the saved forward field is inverted numerically with
        SimpleITK; used for both ``greedy`` and ``syn``.
      - any other value: FireANTs' own inverse
        (``save_as_ants_transforms(save_inverse=True)``).  Combined with
        ``deform_algo="syn"`` this raises ``NotImplementedError`` — note that
        the check happens only *after* the registration has been optimized and
        the forward warp has been written.

    The moments stage always runs in float32 because FireANTs'
    ``MomentsRegistration`` rejects fp16/bf16.  When the configured ``dtype``
    is fp16/bf16 a separate float32 ``BatchedImages`` pair is loaded just for
    moments; rigid/affine/deformable run at the configured dtype.

    Args:
        fixed_image: Path to the fixed (template) image.
        moving_image: Path to the moving (patient) image.
        warped_image: Output path for the moving image warped to template
            space.  A leading ``"__thesis_nohash__:"`` prefix is stripped.
        inverse_warped_image: Output path for the template warped to patient
            space (only written for deformable runs; a stale file at this path
            is deleted for ``Rigid``/``Affine``).  A leading
            ``"__thesis_nohash__:"`` prefix is stripped.
        transform_dir: Directory for transform artifacts (created if missing).
        patient_id: Patient identifier used in output filenames.
        transform_type: One of ``"Rigid"``, ``"Affine"``, ``"SyN"``.  Any value
            other than ``"Rigid"``/``"Affine"`` is treated as deformable.
        device: Torch device string.
        scales: Multi-resolution scales.
        affine_iterations: Affine iterations per scale.
        deformable_iterations: Deformable iterations per scale.
        rigid_iterations: Rigid iterations per scale.
        optimizer: FireANTs optimizer name.
        affine_lr: Affine learning rate.
        deformable_lr: Deformable learning rate.
        rigid_lr: Rigid learning rate.
        cc_kernel_size: Cross-correlation kernel size.
        deformation_type: FireANTs deformation model.
        dtype: Torch dtype name (``float16``/``float32``/``bfloat16``);
            surrounding whitespace and quote characters are stripped first.
        do_moments: Whether to run the moments initialization stage.
        do_rigid: Whether to run the rigid stage.
        moments_scale: Downsampling scale for the moments stage.
        moments_moments: Number of moments (1 or 2).
        deform_algo: ``"syn"`` selects ``SyNRegistration``; any other value
            selects ``GreedyRegistration``.
        deformable_max_spacing_mm: Optional VRAM cap; when set, the run is
            deformable, and the fixed image's smallest spacing is finer than
            this value, fixed+moving are resampled isotropically to it before
            *all* stages (the resampled files are what gets loaded).
        job_name: Registration job name; selects the on-disk filename chain.
        loss_type: Similarity metric for rigid/affine/deformable (``"cc"``,
            ``"mi"``, ``"mse"``, ``"fusedcc"``, ``"fusedmi"``).  On CUDA,
            ``"cc"``/``"mi"`` are auto-upgraded to their fused equivalents; on
            any other device a ``"fused"`` prefix is removed.  The moments
            stage always keeps its own default metric.
        normalize: Min-max normalize fixed+moving intensities to [0, 1] before
            registration (mirrors FireANTs' template pipeline).
        inverse_method: ``"simpleitk"`` to invert the forward field
            numerically; any other value uses the FireANTs inverse.  Only
            consulted for deformable runs.
        inverse_max_iterations: Maximum iterations for the SimpleITK
            displacement-field inversion.
        inverse_tolerance: Mean-error tolerance for the SimpleITK inversion;
            the max-error tolerance is ``max(0.1, 10 * inverse_tolerance)``.

    Returns:
        Tuple of (warped image path, inverse-warped image path, forward
        transforms, reverse transforms, informational note).  The second
        element is currently always the empty string, even when the
        inverse-warped image was written.

    Raises:
        ValueError: If ``dtype`` is not a supported name, or if
            ``transform_type == "Rigid"`` with both ``do_rigid`` and
            ``do_moments`` disabled.
        RuntimeError: Re-raised with guidance when image loading fails with
            the CUDA "no kernel image is available" error; other
            ``RuntimeError``s from loading propagate unchanged.
        NotImplementedError: If ``deform_algo == "syn"`` and
            ``inverse_method`` is not ``"simpleitk"``.
    """
    from pathlib import Path

    import SimpleITK as sitk
    import torch
    from fireants.interpolator import fireants_interpolator
    from fireants.io import BatchedImages, Image
    from fireants.io.image import FakeBatchedImages
    from fireants.registration.affine import AffineRegistration
    from fireants.registration.moments import MomentsRegistration
    from fireants.registration.rigid import RigidRegistration

    def normalize_dtype_name(dtype_name: str) -> str:
        """Strip whitespace and surrounding quote characters from a dtype name."""
        return dtype_name.strip().strip('"').strip("'")

    def _resample_iso(img, ref, spacing_mm, out_path):
        """Resample ``img`` to isotropic ``spacing_mm`` on ``ref``'s grid and write it.

        ``ref`` supplies the target size (computed from its own size/spacing),
        origin and direction; ``img`` provides the intensities being resampled.
        Returns ``out_path``.
        """
        # NOTE(review): int() truncates, so the resampled extent can fall up to
        # one voxel short of ref's physical extent, and an axis whose extent is
        # smaller than spacing_mm would get size 0 — confirm this is acceptable.
        target_size = [int(s * sp / spacing_mm) for s, sp in zip(ref.GetSize(), ref.GetSpacing())]
        resampled = sitk.Resample(
            img,
            target_size,
            sitk.Transform(),
            sitk.sitkLinear,
            ref.GetOrigin(),
            [spacing_mm, spacing_mm, spacing_mm],
            ref.GetDirection(),
        )
        sitk.WriteImage(resampled, out_path)
        return out_path

    def _write_inverse_mat(homogeneous_matrix, out_path: str) -> None:
        """Write the inverse of a square homogeneous matrix as an ANTs ``.mat``.

        The top-left ``dims x dims`` block of the inverse becomes the affine
        matrix and the last column (first ``dims`` rows) the translation, where
        ``dims`` is the matrix size minus one.
        """
        inverse = torch.linalg.inv(homogeneous_matrix)
        dims = inverse.shape[0] - 1
        sitk_transform = sitk.AffineTransform(dims)
        sitk_transform.SetMatrix(inverse[:dims, :dims].reshape(-1).tolist())
        sitk_transform.SetTranslation(inverse[:dims, -1].tolist())
        sitk.WriteTransform(sitk_transform, out_path)

    def _invert_forward_warp_sitk(
        forward_warp_path: str,
        moving_ref_path: str,
        template_ref_path: str,
        reverse_warp_path: str,
        inverse_warped_path: str,
    ) -> None:
        """Build the reverse warp + QC image by inverting the forward field.

        The forward warp written by ``save_as_ants_transforms`` is an ANTs-format
        composite displacement field (affine baked in) on the *fixed/template*
        grid, in physical space, mapping fixed->moving coordinates.  Its inverse
        maps moving->fixed and is what downstream resampling needs to warp the
        template into patient space.

        ``InvertDisplacementFieldImageFilter`` reconstructs the inverse on the
        field's own (template) grid with an explicit fixed-point convergence
        tolerance; the result is then resampled onto the moving grid so the field
        covers the output (patient) space that transform application evaluates it
        on.  Displacement vectors are physical-space, so the identity-transform
        resample simply re-samples them at the moving grid's physical points.

        The iteration cap and tolerances come from the enclosing function's
        ``inverse_max_iterations`` / ``inverse_tolerance`` arguments.  Writes the
        reverse warp to ``reverse_warp_path`` and the warped template to
        ``inverse_warped_path``.
        """
        forward_field = sitk.ReadImage(forward_warp_path, sitk.sitkVectorFloat64)

        inverter = sitk.InvertDisplacementFieldImageFilter()
        inverter.SetMaximumNumberOfIterations(int(inverse_max_iterations))
        inverter.SetMeanErrorToleranceThreshold(float(inverse_tolerance))
        # Max-error tolerance is ten times the mean tolerance, floored at 0.1.
        inverter.SetMaxErrorToleranceThreshold(max(0.1, float(inverse_tolerance) * 10.0))
        inverter.SetEnforceBoundaryCondition(True)
        inverse_on_template_grid = inverter.Execute(forward_field)

        moving_ref = sitk.ReadImage(moving_ref_path)
        identity = sitk.Transform(moving_ref.GetDimension(), sitk.sitkIdentity)
        # Points of the moving grid that fall outside the template grid get a
        # zero displacement (default pixel value 0.0).
        inverse_field = sitk.Resample(
            inverse_on_template_grid,
            moving_ref,
            identity,
            sitk.sitkLinear,
            0.0,
            inverse_on_template_grid.GetPixelID(),
        )
        sitk.WriteImage(inverse_field, reverse_warp_path)

        # QC image: warp the template into moving/patient space via the inverse.
        inverse_transform = sitk.DisplacementFieldTransform(
            sitk.Cast(inverse_field, sitk.sitkVectorFloat64)
        )
        template = sitk.ReadImage(template_ref_path, sitk.sitkFloat32)
        warped_template = sitk.Resample(
            template,
            moving_ref,
            inverse_transform,
            sitk.sitkLinear,
            0.0,
            sitk.sitkFloat32,
        )
        sitk.WriteImage(warped_template, inverse_warped_path)

    # Strip the no-hash sentinel (see build_fireants_node for why).  The
    # main-process metadata is redundant in MultiProc because pickle drops
    # trait_type._metadata, so the sentinel is the load-bearing mechanism.
    # The literal duplicates the module-level _OUTPUT_PATH_SENTINEL value.
    _NOHASH = "__thesis_nohash__:"
    if warped_image.startswith(_NOHASH):
        warped_image = warped_image[len(_NOHASH) :]
    if inverse_warped_image.startswith(_NOHASH):
        inverse_warped_image = inverse_warped_image[len(_NOHASH) :]

    transform_dir_path = Path(transform_dir)
    transform_dir_path.mkdir(parents=True, exist_ok=True)

    dtype_map = {
        "float16": torch.float16,
        "float32": torch.float32,
        "bfloat16": torch.bfloat16,
    }
    normalized_dtype = normalize_dtype_name(dtype)
    if normalized_dtype not in dtype_map:
        raise ValueError(
            f"Unsupported FireANTs dtype '{dtype}'. Expected one of {sorted(dtype_map)}."
        )
    torch_dtype = dtype_map[normalized_dtype]

    # Pick the effective similarity metric.  FireANTs' own scripts use the
    # GPU-fused kernels on CUDA ("fusedcc"/"fusedmi") and the plain kernels on
    # CPU; mirror that.  abstract.py falls back to non-fused if the fused ops
    # are not importable, so this only ever upgrades — never breaks.
    use_cuda = str(device).startswith("cuda")
    if use_cuda:
        loss_type_eff = {"cc": "fusedcc", "mi": "fusedmi"}.get(loss_type, loss_type)
    elif loss_type.startswith("fused"):
        # Non-CUDA device: drop the "fused" prefix to get the plain metric name.
        loss_type_eff = loss_type.replace("fused", "")
    else:
        loss_type_eff = loss_type
    print(
        f"FireANTs metric: requested loss_type='{loss_type}', "
        f"effective='{loss_type_eff}', normalize={normalize}"
    )

    def _normalize01(arr):
        """Min-max normalize a ``[B, C, ...]`` tensor to [0, 1] per (batch, channel).

        Matches FireANTs' ``template_helpers.normalize``; a small epsilon guards
        against constant images.
        """
        flat = arr.flatten(2)
        amin = flat.min(dim=2).values
        amax = flat.max(dim=2).values
        # Re-expand the per-(batch, channel) extrema so they broadcast over
        # the spatial dimensions of ``arr``.
        for _ in range(arr.ndim - 2):
            amin = amin.unsqueeze(-1)
            amax = amax.unsqueeze(-1)
        return (arr - amin) / (amax - amin).clamp_min(1e-8)

    def _load_image(path: str, load_dtype: "torch.dtype") -> "Image":
        """Load an image on ``device`` and, if ``normalize`` is set, rescale it to [0, 1]."""
        image = Image.load_file(path, device=device, dtype=load_dtype)
        if normalize:
            image.array = _normalize01(image.array)
        return image

    # Everything that is not an explicit linear type is handled as deformable.
    is_deformable = transform_type not in {"Rigid", "Affine"}

    # Resolve the canonical filenames (shared with the transform resolver).
    names = registration_transform_filenames(patient_id, job_name, transform_type)
    forward_path = transform_dir_path / names["forward"][0]
    reverse_path = transform_dir_path / names["reverse"][0]

    # ------------------------------------------------------------------
    # Optional VRAM cap: resample fixed + moving before the deformable step.
    # The resampled files replace the originals for every stage below, not
    # only for the deformable one.
    # ------------------------------------------------------------------
    fixed_for_reg = fixed_image
    moving_for_reg = moving_image
    if is_deformable and deformable_max_spacing_mm is not None:
        fixed_sitk = sitk.ReadImage(fixed_image)
        min_spacing = min(fixed_sitk.GetSpacing())
        if min_spacing < float(deformable_max_spacing_mm):
            cap = float(deformable_max_spacing_mm)
            fixed_for_reg = _resample_iso(
                fixed_sitk,
                fixed_sitk,
                cap,
                str(transform_dir_path / f"{patient_id}_fixed_deformable.nii.gz"),
            )
            # The moving image is resampled onto the grid derived from the
            # *fixed* image (size, origin and direction all come from it).
            moving_for_reg = _resample_iso(
                sitk.ReadImage(moving_image),
                fixed_sitk,
                cap,
                str(transform_dir_path / f"{patient_id}_moving_deformable.nii.gz"),
            )

    try:
        fixed = BatchedImages([_load_image(fixed_for_reg, torch_dtype)])
        moving = BatchedImages([_load_image(moving_for_reg, torch_dtype)])
    except RuntimeError as exc:
        message = str(exc)
        # Translate the opaque CUDA arch-mismatch failure into actionable advice;
        # every other RuntimeError propagates untouched.
        if "no kernel image is available for execution on the device" in message:
            raise RuntimeError(
                "FireANTs CUDA initialization failed because the installed PyTorch build does not "
                "support this GPU architecture. Set registration.fireants.device='cpu' or install "
                "a PyTorch build compatible with this GPU. Original error: "
                f"{message}"
            ) from exc
        raise

    # GPU-holding handles; pre-bound so the VRAM teardown can del them
    # unconditionally regardless of which stages executed.
    fixed_f32: Any = None
    moving_f32: Any = None
    moments: Any = None
    rigid: Any = None
    affine: Any = None
    reg: Any = None
    inverse_params: Any = None
    inverse_moved: Any = None
    moved: Any = None
    forward_transforms: list[str] = []
    reverse_transforms: list[str] = []
    note = ""

    # ------------------------------------------------------------------
    # Stage 1 — moments (initialization).  Must run in float32.
    # ------------------------------------------------------------------
    init_translation = None
    init_moment = None
    init_rigid = None
    if do_moments:
        if torch_dtype != torch.float32:
            # MomentsRegistration rejects fp16/bf16 — load a float32 pair.
            fixed_f32 = BatchedImages([_load_image(fixed_for_reg, torch.float32)])
            moving_f32 = BatchedImages([_load_image(moving_for_reg, torch.float32)])
            moments_fixed, moments_moving = fixed_f32, moving_f32
        else:
            moments_fixed, moments_moving = fixed, moving

        moments = MomentsRegistration(
            scale=moments_scale,
            fixed_images=moments_fixed,
            moving_images=moments_moving,
            moments=moments_moments,
        )
        moments.optimize()
        init_translation = moments.get_rigid_transl_init()
        init_moment = moments.get_rigid_moment_init()
        # Used as the affine stage's init when the rigid stage is skipped.
        init_rigid = moments.get_affine_init()

    # ------------------------------------------------------------------
    # Stage 2 — rigid.
    # ------------------------------------------------------------------
    if do_rigid:
        rigid = RigidRegistration(
            scales=scales,
            iterations=rigid_iterations,
            fixed_images=fixed,
            moving_images=moving,
            loss_type=loss_type_eff,
            optimizer=optimizer,
            optimizer_lr=rigid_lr,
            cc_kernel_size=cc_kernel_size,
            init_translation=init_translation,
            init_moment=init_moment,
            dtype=torch_dtype,
        )
        rigid.optimize()
        # The optimized rigid matrix supersedes the moments init.
        init_rigid = rigid.get_rigid_matrix()

    # ------------------------------------------------------------------
    # Stage 3 — affine / linear export, or deformable.
    # ------------------------------------------------------------------
    if transform_type in {"Rigid", "Affine"}:
        if transform_type == "Affine":
            affine = AffineRegistration(
                scales=scales,
                iterations=affine_iterations,
                fixed_images=fixed,
                moving_images=moving,
                loss_type=loss_type_eff,
                optimizer=optimizer,
                optimizer_lr=affine_lr,
                cc_kernel_size=cc_kernel_size,
                init_rigid=init_rigid,
                dtype=torch_dtype,
            )
            affine.optimize()
            linear_stage: Any = affine
            # [0] selects the first (only) batch entry.
            homogeneous = affine.get_affine_matrix().detach().cpu()[0]
        else:
            # Rigid: the rigid stage is the linear transform.  When rigid is
            # disabled, fall back to the moments affine init.
            if rigid is not None:
                linear_stage = rigid
                homogeneous = rigid.get_rigid_matrix().detach().cpu()[0]
            elif moments is not None:
                linear_stage = moments
                homogeneous = moments.get_affine_init().detach().cpu()[0]
                # get_affine_init is [d, d+1]; pad to homogeneous 4x4.
                row = torch.zeros(1, homogeneous.shape[1], dtype=homogeneous.dtype)
                row[0, -1] = 1.0
                homogeneous = torch.cat([homogeneous, row], dim=0)
            else:
                raise ValueError(
                    "Rigid registration requires do_rigid or do_moments to be enabled."
                )

        linear_stage.save_as_ants_transforms(str(forward_path))
        forward_transforms = [str(forward_path)]

        # The reverse transform is the explicit matrix inverse of the forward one.
        _write_inverse_mat(homogeneous, str(reverse_path))
        reverse_transforms = [str(reverse_path)]

        moved = linear_stage.evaluate(fixed, moving)
        note = (
            "FireANTs native staged linear registration "
            f"({'affine' if transform_type == 'Affine' else 'rigid'}): "
            "exported forward and inverse transforms."
        )
    else:
        # ------------------------------------------------------------------
        # Deformable registration (SyN-family).
        # ------------------------------------------------------------------
        # ``affine`` is still None on this branch, so the refinement effectively
        # runs whenever moments or rigid produced an initialization.
        if affine is None and (do_moments or do_rigid):
            # Refine the affine before the deformable stage for best init.
            affine = AffineRegistration(
                scales=scales,
                iterations=affine_iterations,
                fixed_images=fixed,
                moving_images=moving,
                loss_type=loss_type_eff,
                optimizer=optimizer,
                optimizer_lr=affine_lr,
                cc_kernel_size=cc_kernel_size,
                init_rigid=init_rigid,
                dtype=torch_dtype,
            )
            affine.optimize()
            init_affine = affine.get_affine_matrix().detach()
        else:
            # Neither moments nor rigid ran, so init_rigid is None here.
            init_affine = init_rigid

        if deform_algo == "syn":
            from fireants.registration.syn import SyNRegistration

            deform_cls: Any = SyNRegistration
        else:
            from fireants.registration.greedy import GreedyRegistration

            deform_cls = GreedyRegistration

        reg = deform_cls(
            scales=scales,
            iterations=deformable_iterations,
            fixed_images=fixed,
            moving_images=moving,
            loss_type=loss_type_eff,
            optimizer=optimizer,
            optimizer_lr=deformable_lr,
            cc_kernel_size=cc_kernel_size,
            deformation_type=deformation_type,
            smooth_grad_sigma=1,
            init_affine=init_affine,
            dtype=torch_dtype,
        )
        reg.optimize()
        moved = reg.evaluate(fixed, moving)

        # Single composite forward warp (affine baked in).  This is the
        # high-quality direction; both greedy and syn produce it.
        reg.save_as_ants_transforms(str(forward_path))
        forward_transforms = [str(forward_path)]
        reverse_transforms = [str(reverse_path)]

        if inverse_method == "simpleitk":
            # Numerically invert the forward field (ANTs-free, explicit
            # convergence).  Independent of FireANTs' analytic inverse, so it
            # works for syn as well as greedy.
            # NOTE(review): the reference grids are the *_for_reg paths, i.e.
            # the VRAM-cap resampled files when the cap was applied — confirm
            # that the reverse warp is meant to live on that grid rather than
            # on the original moving image's grid.
            _invert_forward_warp_sitk(
                forward_warp_path=str(forward_path),
                moving_ref_path=moving_for_reg,
                template_ref_path=fixed_for_reg,
                reverse_warp_path=str(reverse_path),
                inverse_warped_path=str(inverse_warped_image),
            )
            note = (
                f"FireANTs native staged deformable registration ({deform_algo}): "
                "composite forward warp + inverse (reverse) warp via SimpleITK "
                f"displacement-field inversion of the forward field "
                f"(max_iter={inverse_max_iterations}, tol={inverse_tolerance})."
            )
        else:
            # Any inverse_method other than "simpleitk" lands here.
            # NOTE(review): this check fires only after reg.optimize() and the
            # forward-warp export above have already run.
            if deform_algo == "syn":
                raise NotImplementedError(
                    "deform_algo='syn' has no analytic inverse warp "
                    "(SyNRegistration.get_inverse_warp_parameters raises). Use "
                    "deform_algo='greedy', or set "
                    "registration.fireants.inverse_method='simpleitk'."
                )

            # Analytic inverse composite warp (FireANTs inverse-consistency solve).
            reg.save_as_ants_transforms(str(reverse_path), save_inverse=True)

            # Inverse-warped QC image (template → patient).
            inverse_params = reg.get_inverse_warp_parameters(fixed, moving)
            inverse_moved = fireants_interpolator(
                fixed(),
                **inverse_params,
                mode=fixed.get_interpolator_type(),
                align_corners=True,
            )
            FakeBatchedImages(inverse_moved, moving).write_image(str(inverse_warped_image))

            note = (
                "FireANTs native staged deformable registration (greedy): single "
                "composite forward warp + analytic inverse warp via "
                "save_as_ants_transforms."
            )

    # The warped moving image is written using ``fixed`` as the reference batch.
    FakeBatchedImages(moved, fixed).write_image(str(warped_image))

    # Linear transforms produce no inverse-warped QC image; clean up a stale one.
    inverse_path_obj = Path(inverse_warped_image)
    if not is_deformable and inverse_path_obj.exists():
        inverse_path_obj.unlink()

    # Clean up VRAM-cap temp files.
    # NOTE(review): this cleanup and the VRAM teardown below only run on the
    # success path; if any stage above raises, the temp files are left in
    # transform_dir and the CUDA cache is not emptied here.
    for tmp_name in [
        f"{patient_id}_fixed_deformable.nii.gz",
        f"{patient_id}_moving_deformable.nii.gz",
    ]:
        (transform_dir_path / tmp_name).unlink(missing_ok=True)

    # Release PyTorch CUDA cache so subsequent GPU tasks (e.g. eddy_cuda,
    # a second patient's fireants) can allocate VRAM without OOM.
    import gc

    del fixed, moving, fixed_f32, moving_f32
    del moments, rigid, affine, reg
    del moved, inverse_params, inverse_moved
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    torch.cuda.empty_cache()

    # NOTE(review): the inverse-warped image path slot is always returned as ""
    # even when the file was written above — confirm this is intentional.
    return str(warped_image), "", forward_transforms, reverse_transforms, note


def run_fireants_registration_task(
    fixed_image: str,
    moving_image: str,
    warped_image: str,
    inverse_warped_image: str,
    transform_dir: str,
    patient_id: str,
    transform_type: str,
    device: str,
    scales: list[int],
    affine_iterations: list[int],
    deformable_iterations: list[int],
    rigid_iterations: list[int],
    optimizer: str,
    affine_lr: float,
    deformable_lr: float,
    rigid_lr: float,
    cc_kernel_size: int,
    deformation_type: str,
    dtype: str,
    do_moments: bool,
    do_rigid: bool,
    moments_scale: int,
    moments_moments: int,
    deform_algo: str,
    job_name: str = "patient_to_template",
    deformable_max_spacing_mm: float | None = None,
    use_gpu: bool = True,
    loss_type: str = "cc",
    normalize: bool = True,
    inverse_method: str = "simpleitk",
    inverse_max_iterations: int = 50,
    inverse_tolerance: float = 0.01,
) -> tuple[str, str, list[str], list[str], str]:
    """In-process Function-node entry point; forwards everything to ``_run_staged``.

    Every argument except ``use_gpu`` is passed through by keyword unchanged.
    The output paths may still carry the ``"__thesis_nohash__:"`` prefix that
    :func:`build_fireants_node` adds to dodge Nipype's file-content hashing;
    the prefix is removed inside :func:`_run_staged`, not here.

    ``use_gpu`` is accepted so the node's input list matches this signature,
    but it is not forwarded and has no effect in this function.

    Returns:
        Whatever :func:`_run_staged` returns: a tuple of warped image path,
        inverse-warped image path, forward transforms, reverse transforms, and
        an informational note.
    """
    # Imported at call time by absolute module path rather than relying on the
    # module namespace (presumably because the Nipype Function interface
    # executes this function from its source text — TODO confirm).
    from thesis.workflows.registration.fireants_backend import _run_staged

    staged_kwargs = {
        "fixed_image": fixed_image,
        "moving_image": moving_image,
        "warped_image": warped_image,
        "inverse_warped_image": inverse_warped_image,
        "transform_dir": transform_dir,
        "patient_id": patient_id,
        "transform_type": transform_type,
        "device": device,
        "scales": scales,
        "affine_iterations": affine_iterations,
        "deformable_iterations": deformable_iterations,
        "rigid_iterations": rigid_iterations,
        "optimizer": optimizer,
        "affine_lr": affine_lr,
        "deformable_lr": deformable_lr,
        "rigid_lr": rigid_lr,
        "cc_kernel_size": cc_kernel_size,
        "deformation_type": deformation_type,
        "dtype": dtype,
        "do_moments": do_moments,
        "do_rigid": do_rigid,
        "moments_scale": moments_scale,
        "moments_moments": moments_moments,
        "deform_algo": deform_algo,
        "job_name": job_name,
        "deformable_max_spacing_mm": deformable_max_spacing_mm,
        "loss_type": loss_type,
        "normalize": normalize,
        "inverse_method": inverse_method,
        "inverse_max_iterations": inverse_max_iterations,
        "inverse_tolerance": inverse_tolerance,
    }
    return _run_staged(**staged_kwargs)
def run_fireants_cli_task(
    fixed_image: str,
    moving_image: str,
    warped_image: str,
    inverse_warped_image: str,
    transform_dir: str,
    patient_id: str,
    transform_type: str,
    device: str,
    scales: list[int],
    affine_iterations: list[int],
    deformable_iterations: list[int],
    rigid_iterations: list[int],
    optimizer: str,
    affine_lr: float,
    deformable_lr: float,
    rigid_lr: float,
    cc_kernel_size: int,
    deformation_type: str,
    dtype: str,
    do_moments: bool,
    do_rigid: bool,
    moments_scale: int,
    moments_moments: int,
    deform_algo: str,
    job_name: str = "patient_to_template",
    deformable_max_spacing_mm: float | None = None,
    use_gpu: bool = True,
    loss_type: str = "cc",
    normalize: bool = True,
    inverse_method: str = "simpleitk",
    inverse_max_iterations: int = 50,
    inverse_tolerance: float = 0.01,
) -> tuple[str, str, list[str], list[str], str]:
    """Run staged FireANTs registration out-of-process via the CLI (Flavor 2).

    Strips the ``"__thesis_nohash__:"`` prefix from the output paths (added by
    :func:`build_fireants_node`), builds the argv for
    ``python -m thesis.cli_tools.fireants_register``, runs it with
    :func:`subprocess.run` using the current interpreter, and parses the last
    non-empty-stripped stdout line as the JSON result.

    Scalar options are passed as ``--flag value`` pairs (stringified with
    ``str``); the per-scale lists are passed as one flag followed by one
    argument per element.  ``do_moments``, ``do_rigid`` and ``normalize`` are
    only emitted when false (``--no-moments`` / ``--no-rigid`` /
    ``--no-normalize``), and ``--deformable-max-spacing-mm`` only when a value
    is given.  ``use_gpu`` is accepted for signature parity with the node's
    input list but is not passed to the CLI.

    Returns:
        Tuple of warped image path, inverse-warped image path, forward
        transforms, reverse transforms, and an informational note, read from
        the ``warped_image`` / ``inverse_warped_image`` /
        ``forward_transforms`` / ``reverse_transforms`` / ``note`` keys of the
        JSON object printed by the CLI.

    Raises:
        RuntimeError: If the subprocess exits non-zero.  When its stderr
            contains the CUDA "no kernel image" failure, the message carries
            the same guidance as the in-process path.
        json.JSONDecodeError: If the subprocess exits 0 but its last stdout
            line is empty or not valid JSON.  The message names the CLI and
            includes the offending line and the tail of stderr.
        KeyError: If the JSON object lacks one of the expected keys.
    """
    import json
    import subprocess
    import sys

    # Same literal as the module-level _OUTPUT_PATH_SENTINEL; kept local so the
    # function body is self-contained (presumably because the Nipype Function
    # interface runs it outside the module namespace — TODO confirm).
    _NOHASH = "__thesis_nohash__:"
    if warped_image.startswith(_NOHASH):
        warped_image = warped_image[len(_NOHASH) :]
    if inverse_warped_image.startswith(_NOHASH):
        inverse_warped_image = inverse_warped_image[len(_NOHASH) :]

    argv: list[str] = [
        sys.executable,
        "-m",
        "thesis.cli_tools.fireants_register",
        "--fixed",
        str(fixed_image),
        "--moving",
        str(moving_image),
        "--warped",
        str(warped_image),
        "--inverse-warped",
        str(inverse_warped_image),
        "--transform-dir",
        str(transform_dir),
        "--patient-id",
        str(patient_id),
        "--job-name",
        str(job_name),
        "--transform-type",
        str(transform_type),
        "--device",
        str(device),
        "--dtype",
        str(dtype),
        "--optimizer",
        str(optimizer),
        "--affine-lr",
        str(affine_lr),
        "--deformable-lr",
        str(deformable_lr),
        "--rigid-lr",
        str(rigid_lr),
        "--cc-kernel-size",
        str(cc_kernel_size),
        "--deformation-type",
        str(deformation_type),
        "--deform-algo",
        str(deform_algo),
        "--loss-type",
        str(loss_type),
        "--inverse-method",
        str(inverse_method),
        "--inverse-max-iterations",
        str(inverse_max_iterations),
        "--inverse-tolerance",
        str(inverse_tolerance),
        "--moments-scale",
        str(moments_scale),
        "--moments-moments",
        str(moments_moments),
        "--scales",
        *[str(s) for s in scales],
        "--affine-iterations",
        *[str(i) for i in affine_iterations],
        "--deformable-iterations",
        *[str(i) for i in deformable_iterations],
        "--rigid-iterations",
        *[str(i) for i in rigid_iterations],
    ]
    # Boolean stages default to "on" in the CLI; only the negations are sent.
    if not do_moments:
        argv.append("--no-moments")
    if not do_rigid:
        argv.append("--no-rigid")
    if not normalize:
        argv.append("--no-normalize")
    if deformable_max_spacing_mm is not None:
        argv += ["--deformable-max-spacing-mm", str(deformable_max_spacing_mm)]

    # List-form argv (no shell), with stdout/stderr captured as text.
    completed = subprocess.run(argv, capture_output=True, text=True)
    stderr = completed.stderr or ""
    if completed.returncode != 0:
        if "no kernel image is available for execution on the device" in stderr:
            raise RuntimeError(
                "FireANTs CUDA initialization failed because the installed PyTorch build does "
                "not support this GPU architecture. Set registration.fireants.device='cpu' or "
                "install a PyTorch build compatible with this GPU. Original error: "
                f"{stderr}"
            )
        raise RuntimeError(
            f"thesis-fireants-register failed (exit {completed.returncode}): {stderr}"
        )

    # The CLI may log to stdout before the result; only the final line is JSON.
    stdout = (completed.stdout or "").strip()
    last_line = stdout.splitlines()[-1] if stdout else ""
    try:
        result = json.loads(last_line)
    except json.JSONDecodeError as exc:
        # Same exception type as before, but say *what* failed to parse: a
        # bare "Expecting value: line 1 column 1" gives no hint that the
        # subprocess produced no (or malformed) result output.
        raise json.JSONDecodeError(
            "thesis-fireants-register exited 0 but its last stdout line is not "
            f"valid JSON ({exc.msg}); last line: {last_line!r}; "
            f"stderr tail: {stderr[-2000:]!r}",
            exc.doc,
            exc.pos,
        ) from exc

    return (
        result["warped_image"],
        result["inverse_warped_image"],
        result["forward_transforms"],
        result["reverse_transforms"],
        result["note"],
    )
def build_fireants_node(
    config: PipelineConfig,
    context: ProcessingContext,
    *,
    moving: Path | None = None,
    fixed: Path | None = None,
    job: ResolvedRegistrationJob | None = None,
) -> Node:
    """Assemble the Nipype ``Function`` node that runs one FireANTs job.

    The node wraps either the in-process task or the CLI task, chosen by the
    job's ``fireants.driver`` setting (``"cli"`` selects the CLI task; any
    other value, or a missing attribute, selects the in-process task).  All
    FireANTs settings are copied from ``job.fireants`` onto the node inputs.

    Args:
        config: The merged pipeline configuration.
        context: The processing context; supplies ``patient_id`` and is passed
            to the path resolvers.
        moving: Explicit moving image.  When ``None`` it is resolved from
            *job*.
        fixed: Explicit fixed image.  When ``None`` it is resolved from *job*.
        job: The resolved registration job to implement.  When ``None`` the
            first entry of ``resolve_registration_jobs(config)`` is used.

    Returns:
        The configured node.  It is named ``"fireants_registration"`` when
        ``job.is_default`` is true and ``"fireants_<job.safe_name>"``
        otherwise.
    """
    if job is None:
        job = resolve_registration_jobs(config)[0]
    fireants_cfg = job.fireants

    if fixed is None:
        fixed = resolve_fixed_image_for_job(config, context, job)
    if moving is None:
        moving = resolve_moving_image_for_job(config, context, job)

    if job.is_default:
        node_name = "fireants_registration"
    else:
        node_name = f"fireants_{job.safe_name}"

    # The in-process task is the default driver; "cli" shells out instead.
    if getattr(fireants_cfg, "driver", "inprocess") == "cli":
        task_function = run_fireants_cli_task
    else:
        task_function = run_fireants_registration_task

    task_inputs = [
        "fixed_image",
        "moving_image",
        "warped_image",
        "inverse_warped_image",
        "transform_dir",
        "patient_id",
        "job_name",
        "transform_type",
        "device",
        "scales",
        "affine_iterations",
        "deformable_iterations",
        "rigid_iterations",
        "optimizer",
        "affine_lr",
        "deformable_lr",
        "rigid_lr",
        "cc_kernel_size",
        "deformation_type",
        "dtype",
        "do_moments",
        "do_rigid",
        "moments_scale",
        "moments_moments",
        "deform_algo",
        "deformable_max_spacing_mm",
        "use_gpu",
        "loss_type",
        "normalize",
        "inverse_method",
        "inverse_max_iterations",
        "inverse_tolerance",
    ]
    task_outputs = [
        "warped_image",
        "inverse_warped_image",
        "forward_transforms",
        "reverse_transforms",
        "note",
    ]
    interface = Function(
        input_names=task_inputs,
        output_names=task_outputs,
        function=task_function,
    )
    node = Node(interface, name=node_name)
    inputs = node.inputs

    inputs.fixed_image = str(fixed)
    inputs.moving_image = str(moving)

    # Output paths carry the _OUTPUT_PATH_SENTINEL prefix (see the module-level
    # comment on that constant).  Trait metadata is lost when MultiProc pickles
    # the node, so hash_files=False alone never reaches the worker that writes
    # the hashfile.  A non-path prefix makes os.path.isfile() return False
    # there as well, so the value is hashed as a plain string either way.
    warped_path = get_registration_job_warped_image_path(config, context, job)
    inputs.warped_image = _OUTPUT_PATH_SENTINEL + str(warped_path)
    inverse_warped_path = get_registration_job_inverse_warped_image_path(config, context, job)
    inputs.inverse_warped_image = _OUTPUT_PATH_SENTINEL + str(inverse_warped_path)
    inputs.transform_dir = str(get_registration_job_transform_dir(config, context, job))

    # Belt and braces: also set the trait metadata, which covers the Linear
    # plugin and in-process callers that never pickle.  It has to be written to
    # trait_type._metadata because Nipype's has_metadata reads from there;
    # setting ``trait(...).hash_files = False`` on the CTrait wrapper has no
    # effect.
    for output_field in ("warped_image", "inverse_warped_image"):
        inputs.trait(output_field).trait_type._metadata["hash_files"] = False

    inputs.patient_id = context.patient_id
    inputs.job_name = job.name
    inputs.transform_type = job.transform_type
    inputs.device = fireants_cfg.device

    # Per-level schedules are copied into fresh lists before being assigned.
    for field_name in (
        "scales",
        "affine_iterations",
        "deformable_iterations",
        "rigid_iterations",
    ):
        setattr(inputs, field_name, list(getattr(fireants_cfg, field_name)))

    # Settings forwarded verbatim from the job's FireANTs configuration.
    for field_name in (
        "optimizer",
        "affine_lr",
        "deformable_lr",
        "rigid_lr",
        "cc_kernel_size",
        "deformation_type",
        "dtype",
        "do_moments",
        "do_rigid",
        "moments_scale",
        "moments_moments",
        "deform_algo",
        "deformable_max_spacing_mm",
        "loss_type",
        "normalize",
    ):
        setattr(inputs, field_name, getattr(fireants_cfg, field_name))

    # Inverse-warp settings fall back to defaults when the config lacks them.
    for field_name, fallback in (
        ("inverse_method", "simpleitk"),
        ("inverse_max_iterations", 50),
        ("inverse_tolerance", 0.01),
    ):
        setattr(inputs, field_name, getattr(fireants_cfg, field_name, fallback))

    # Memory estimate: a quarter of the configured memory (16 GB assumed when
    # config.hardware has no memory_gb), but never less than 4 GB.
    configured_memory_gb = float(getattr(config.hardware, "memory_gb", 16))
    node._mem_gb = max(4.0, configured_memory_gb / 4.0)

    # Mark the node as a GPU consumer so Nipype's MultiProc scheduler counts it
    # against the n_gpu_procs budget.  Node.is_gpu_node() looks for
    # use_gpu/use_cuda on node.inputs; _n_gpus is NOT consulted by the
    # scheduler for Function nodes.
    inputs.use_gpu = True
    return node