# Source code for thesis.workflows.learned_atlas.model

"""Diffeomorphic registration U-Net and learnable template for the learned atlas.

Compact in-repo VoxelMorph/AtlasMorph model. The network predicts a stationary
velocity field only; scaling-and-squaring integration turns it into a
diffeomorphic deformation, and a spatial transformer resamples a jointly-learned
base template ``T``. The network emits only deformation fields (never density),
so the warped template can relocate existing template mass but cannot
hallucinate new mass.

All tensors are 5D ``(N, 1, X, Y, Z)``; the template is ``(K, 1, X, Y, Z)``.
torch is imported lazily so importing the package is safe without the optional
``ml`` extra; these classes are only used inside the Nipype Function body.
"""

from __future__ import annotations

from typing import Sequence

try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    _TORCH_AVAILABLE = True
except ImportError:  # pragma: no cover - exercised only when torch is absent.
    _TORCH_AVAILABLE = False


def _require_torch() -> None:
    """Guard for code paths that need the optional PyTorch dependency.

    Does nothing when ``torch`` was importable at module load time.

    Raises:
        DependencyError: When the ``torch`` import at module load failed,
            i.e. the optional ``ml`` dependency group is not installed.
    """
    if _TORCH_AVAILABLE:
        return

    # Imported lazily so this module has no hard dependency on the exceptions
    # module unless the error path is actually taken.
    from thesis.core.exceptions import DependencyError

    message = "The learned_atlas model requires PyTorch. Install with: pip install -e '.[ml]'"
    raise DependencyError(message)


if _TORCH_AVAILABLE:

    class ConvBlock(nn.Module):
        """Single 3x3x3 ``Conv3d`` (padding 1) followed by ``LeakyReLU(0.2)``."""

        def __init__(self, in_channels: int, out_channels: int, stride: int = 1) -> None:
            """Build the convolution and activation layers.

            Args:
                in_channels: Number of input feature channels.
                out_channels: Number of output feature channels.
                stride: Convolution stride (2 halves each spatial dimension, rounding up).
            """
            super().__init__()
            conv_kwargs = {"kernel_size": 3, "stride": stride, "padding": 1}
            self.conv = nn.Conv3d(in_channels, out_channels, **conv_kwargs)
            self.activation = nn.LeakyReLU(negative_slope=0.2)

        def forward(self, x: "torch.Tensor") -> "torch.Tensor":
            """Convolve ``x`` and apply the leaky ReLU non-linearity."""
            features = self.conv(x)
            activated = self.activation(features)
            return activated  # type: ignore[no-any-return]

    class SpatialTransformer(nn.Module):
        """Warp a volume by a dense voxel-unit displacement field with ``grid_sample``.

        Every tensor operation is out-of-place, so gradients flow through the
        displacement (and therefore through the scaling-and-squaring
        integration). The identity voxel grid is memoised in ``_grid_cache``
        under the key ``(shape, str(device), str(dtype))``.
        """

        def __init__(self, mode: str = "bilinear") -> None:
            """Store the ``grid_sample`` interpolation mode and an empty grid cache."""
            super().__init__()
            self.mode = mode
            self._grid_cache: dict = {}

        def _identity_grid(self, shape: Sequence[int], device, dtype) -> "torch.Tensor":
            """Return the ``(1, 3, X, Y, Z)`` identity voxel grid, building it on first use.

            Channel ``c`` of the grid holds the integer coordinate along spatial axis ``c``.
            """
            cache_key = (tuple(shape), str(device), str(dtype))
            if cache_key not in self._grid_cache:
                axes = [torch.arange(0, n, device=device, dtype=dtype) for n in shape]
                mesh = torch.meshgrid(*axes, indexing="ij")
                self._grid_cache[cache_key] = torch.stack(mesh, dim=0).unsqueeze(0)
            return self._grid_cache[cache_key]

        def forward(self, src: "torch.Tensor", flow: "torch.Tensor") -> "torch.Tensor":
            """Resample ``src`` at ``identity + flow``.

            Args:
                src: Volume to warp, ``(N, C, X, Y, Z)``.
                flow: Displacement in voxel units, ``(N, 3, X, Y, Z)``; channel
                    ``c`` displaces along spatial axis ``c``.

            Returns:
                The warped volume, same shape as ``src``. Samples outside the
                volume are clamped to the border.
            """
            spatial = src.shape[2:]
            identity = self._identity_grid(spatial, src.device, src.dtype)

            # Per-axis extent (dim - 1), floored at 1 so singleton axes never divide by zero.
            extent = torch.tensor(
                [max(int(n) - 1, 1) for n in spatial],
                device=src.device,
                dtype=src.dtype,
            ).view(1, 3, 1, 1, 1)
            # Map absolute voxel coordinates onto [-1, 1] (align_corners=True convention).
            normalised = 2.0 * ((identity + flow) / extent) - 1.0

            # grid_sample wants (N, X, Y, Z, 3) with coordinates ordered (z, y, x):
            # move the coordinate channel last, then reverse it.
            sample_grid = normalised.permute(0, 2, 3, 4, 1).flip(-1)
            return F.grid_sample(
                src, sample_grid, mode=self.mode, padding_mode="border", align_corners=True
            )

    class VecInt(nn.Module):
        """Integrate a stationary velocity field by scaling and squaring.

        The velocity is first scaled down by ``2**n_steps``; the resulting small
        displacement is then composed with itself ``n_steps`` times.
        """

        def __init__(self, transformer: "SpatialTransformer", n_steps: int = 7) -> None:
            """Set up the integrator.

            Args:
                transformer: Warper called as ``transformer(src, flow)``; shared
                    with the caller so only one identity-grid cache exists.
                n_steps: Number of squaring steps (0 returns the velocity itself).

            Raises:
                ValueError: If ``n_steps`` is negative.
            """
            super().__init__()
            if n_steps < 0:
                raise ValueError(f"n_steps must be >= 0, got {n_steps}.")
            self.n_steps = n_steps
            self.scale = 1.0 / (2**n_steps)
            self.transformer = transformer

        def forward(self, velocity: "torch.Tensor") -> "torch.Tensor":
            """Return the displacement obtained by integrating ``velocity``."""
            displacement = velocity * self.scale
            remaining = self.n_steps
            while remaining > 0:
                # Self-composition: u <- u + u o (id + u).
                displacement = displacement + self.transformer(displacement, displacement)
                remaining -= 1
            return displacement  # type: ignore[no-any-return]

    class UNet3D(nn.Module):
        """Compact VoxelMorph-style 3D U-Net producing a velocity field.

        The first encoder keeps full resolution (stride 1); every later encoder
        uses stride 2. The first ``len(enc_channels) - 1`` decoders each
        upsample to, and concatenate with, the matching encoder skip (deepest
        first); any further decoders run at the resulting resolution.

        The flow head is zero-initialised so the network starts at the identity
        transform (guarantees a fold-free start). Skip bookkeeping uses the same
        n-1 skip list in ``__init__`` and ``forward`` so non-default enc/dec
        widths do not mis-pair channels. If fewer decoders than skips are
        configured, the decoder features are resized to the input resolution
        before the flow head, so the velocity always lies on the input grid.
        """

        def __init__(
            self,
            in_channels: int = 2,
            enc_channels: Sequence[int] = (16, 32, 32, 32),
            dec_channels: Sequence[int] = (32, 32, 32, 16, 16),
            ndims: int = 3,
        ) -> None:
            """Initialise encoder/decoder and the zero-init velocity head.

            Args:
                in_channels: Channels of the concatenated (moving, fixed) input.
                enc_channels: Output widths of the encoder blocks.
                dec_channels: Output widths of the decoder blocks.
                ndims: Number of velocity components emitted by the flow head.
            """
            super().__init__()
            self.ndims = ndims

            self.encoders = nn.ModuleList()
            prev = in_channels
            for i, ch in enumerate(enc_channels):
                stride = 1 if i == 0 else 2
                self.encoders.append(ConvBlock(prev, ch, stride=stride))
                prev = ch

            # Skips are produced by the n-1 non-bottleneck encoders, reversed.
            skip_channels = list(enc_channels[:-1][::-1])
            self.decoders = nn.ModuleList()
            for i, ch in enumerate(dec_channels):
                skip = skip_channels[i] if i < len(skip_channels) else 0
                self.decoders.append(ConvBlock(prev + skip, ch))
                prev = ch

            self.flow = nn.Conv3d(prev, ndims, kernel_size=3, padding=1)
            nn.init.zeros_(self.flow.weight)
            nn.init.zeros_(self.flow.bias)  # type: ignore[arg-type]

        def forward(self, moving: "torch.Tensor", fixed: "torch.Tensor") -> "torch.Tensor":
            """Predict the stationary velocity field aligning ``moving`` to ``fixed``.

            Args:
                moving: ``(N, C_m, X, Y, Z)`` volume.
                fixed: ``(N, C_f, X, Y, Z)`` volume; ``C_m + C_f == in_channels``.

            Returns:
                Velocity ``(N, ndims, X, Y, Z)`` on the input spatial grid.
            """
            x = torch.cat([moving, fixed], dim=1)
            input_shape = x.shape[2:]

            skips: list = []
            for i, enc in enumerate(self.encoders):
                x = enc(x)
                if i < len(self.encoders) - 1:
                    skips.append(x)
            skips = skips[::-1]

            for i, dec in enumerate(self.decoders):
                if i < len(skips):
                    skip = skips[i]
                    x = F.interpolate(x, size=skip.shape[2:], mode="trilinear", align_corners=False)
                    x = torch.cat([x, skip], dim=1)
                x = dec(x)

            # With fewer decoders than skips the features are still at reduced
            # resolution; the downstream warp adds the flow to a full-resolution
            # identity grid, so bring the features back to the input grid first.
            # For the default configuration this branch is never taken.
            if x.shape[2:] != input_shape:
                x = F.interpolate(x, size=input_shape, mode="trilinear", align_corners=False)
            return self.flow(x)  # type: ignore[no-any-return]

    class AssignmentHead(nn.Module):
        """OPTIONAL Phase-2: soft assignment over K templates (bypassed for K=1).

        Global-average-pools the input volume per channel and maps the pooled
        vector through one linear layer to K logits.
        """

        def __init__(self, in_channels: int, n_templates: int) -> None:
            """Create the global pooling layer and the linear classifier.

            Args:
                in_channels: Channels of the volume passed to ``forward``.
                n_templates: Number of templates K (output width).
            """
            super().__init__()
            self.pool = nn.AdaptiveAvgPool3d(output_size=1)
            self.fc = nn.Linear(in_features=in_channels, out_features=n_templates)

        def forward(self, fixed: "torch.Tensor") -> "torch.Tensor":
            """Return softmax assignment weights ``(N, K)``; each row sums to 1."""
            pooled = torch.flatten(self.pool(fixed), start_dim=1)
            logits = self.fc(pooled)
            return torch.softmax(logits, dim=1)

    class LearnedAtlasModel(nn.Module):
        """Learned conditional deformable atlas: template(s) + diffeomorphic U-Net.

        The template(s) are one learnable ``(K, 1, X, Y, Z)`` parameter. A forward
        pass selects (or, for K > 1, softly blends) a template per subject,
        predicts a velocity field from the (template, subject) pair, integrates
        it by scaling-and-squaring and warps the template with the result.
        """

        def __init__(
            self,
            volume_shape: Sequence[int],
            n_templates: int = 1,
            int_steps: int = 7,
            enc_features: Sequence[int] = (16, 32, 32, 32),
            dec_features: Sequence[int] = (32, 32, 32, 16, 16),
            init_template: "torch.Tensor | None" = None,
        ) -> None:
            """Initialise the learned atlas model.

            Args:
                volume_shape: Spatial shape ``(X, Y, Z)``.
                n_templates: Number of templates K (1 = single sharp template).
                int_steps: Scaling-and-squaring integration steps.
                enc_features: U-Net encoder widths.
                dec_features: U-Net decoder widths.
                init_template: Optional ``(X, Y, Z)`` / ``(K, X, Y, Z)`` init. It is
                    copied, never aliased, so training does not modify it.

            Raises:
                ValueError: On invalid ``n_templates`` / ``volume_shape`` or an
                    ``init_template`` of incompatible shape.
            """
            super().__init__()
            _require_torch()
            if n_templates < 1:
                raise ValueError(f"n_templates must be >= 1, got {n_templates}.")
            if len(volume_shape) != 3:
                raise ValueError(f"volume_shape must be (X, Y, Z), got {tuple(volume_shape)}.")

            self.volume_shape = tuple(int(s) for s in volume_shape)
            self.n_templates = int(n_templates)

            self.template = nn.Parameter(self._build_template_init(init_template))
            self.unet = UNet3D(
                in_channels=2, enc_channels=tuple(enc_features), dec_channels=tuple(dec_features)
            )
            # ONE shared SpatialTransformer (single grid cache) for both the
            # integrator and the final warp.
            self.warp = SpatialTransformer(mode="bilinear")
            self.integrate = VecInt(self.warp, n_steps=int_steps)
            self.assignment = (
                AssignmentHead(in_channels=1, n_templates=self.n_templates)
                if self.n_templates > 1
                else None
            )

        def _build_template_init(self, init_template: "torch.Tensor | None") -> "torch.Tensor":
            """Create the initial template tensor ``(K, 1, X, Y, Z)``.

            Without ``init_template`` the template is small uniform noise in
            ``[0, 1e-3)``. Otherwise a 3D input gains leading ``K`` and channel
            axes, a 4D input gains the channel axis, and a single template is
            repeated to K copies.

            The result is always a fresh, contiguous float32 tensor detached from
            ``init_template``. ``nn.Parameter`` does not copy its data, so returning
            a view (which ``.float()``/``unsqueeze``/``.contiguous()`` produce for a
            contiguous float32 input) would let optimiser steps write into the
            caller's tensor.

            Raises:
                ValueError: If the reshaped init does not match ``(K, 1, X, Y, Z)``.
            """
            x, y, z = self.volume_shape
            if init_template is None:
                return torch.rand(self.n_templates, 1, x, y, z) * 1e-3
            t = init_template.detach().float()
            if t.dim() == 3:
                t = t.unsqueeze(0)
            if t.dim() == 4:
                t = t.unsqueeze(1)
            if t.shape[0] == 1 and self.n_templates > 1:
                t = t.repeat(self.n_templates, 1, 1, 1, 1)
            if tuple(t.shape) != (self.n_templates, 1, x, y, z):
                raise ValueError(
                    f"init_template shape {tuple(init_template.shape)} incompatible with "
                    f"({self.n_templates}, 1, {x}, {y}, {z})."
                )
            # Force a real copy so the parameter never shares storage with the input.
            return t.clone(memory_format=torch.contiguous_format)

        def _select_template(self, subject: "torch.Tensor") -> "torch.Tensor":
            """Return the per-subject template ``(N, 1, X, Y, Z)``.

            For K = 1 this is a broadcast view of the single template; for K > 1
            it is the assignment-weighted sum of all templates.
            """
            n = subject.shape[0]
            if self.n_templates == 1 or self.assignment is None:
                return self.template[0:1].expand(n, -1, -1, -1, -1)
            weights = self.assignment(subject)
            blended = (
                weights.view(n, self.n_templates, 1, 1, 1, 1) * self.template.unsqueeze(0)
            ).sum(dim=1)
            return blended  # type: ignore[no-any-return]

        def forward(self, subject: "torch.Tensor") -> dict:
            """Warp the template onto a subject volume.

            Args:
                subject: Subject volume(s) ``(N, 1, X, Y, Z)`` on ``volume_shape``.

            Returns:
                Dict with ``warped_template``, ``velocity``, ``deformation``
                (voxel-unit displacement) and the selected ``template``.
            """
            template = self._select_template(subject)
            velocity = self.unet(template, subject)
            deformation = self.integrate(velocity)
            warped = self.warp(template, deformation)
            return {
                "warped_template": warped,
                "velocity": velocity,
                "deformation": deformation,
                "template": template,
            }

        @torch.no_grad()
        def export_template(self) -> "torch.Tensor":
            """Return a snapshot of the learned template(s) as ``(K, X, Y, Z)`` float32 on CPU.

            ``copy=True`` guarantees the result never shares storage with the live
            parameter. Without it, a CPU float32 model would return a view that
            keeps changing as training continues.
            """
            snapshot = self.template.detach().squeeze(1).to(
                device="cpu", dtype=torch.float32, copy=True
            )
            return snapshot.contiguous()


def jacobian_determinant(deformation: "torch.Tensor") -> "torch.Tensor":
    """Voxel-space Jacobian determinant of ``phi(x) = x + deformation(x)``.

    NOTE: voxel-space only (internal fold-free check). The displacement field
    is in VOXEL units, so this determinant must NOT be fed to
    ``antsApplyTransforms`` or compared against ANTs/SimpleITK physical-space
    warps: physical-space validation requires first composing the field with
    the reference affine.

    Derivatives are forward differences, evaluated on the common
    ``(X-1, Y-1, Z-1)`` corner grid.

    Args:
        deformation: Displacement field ``(N, 3, X, Y, Z)`` in voxel units.

    Returns:
        Determinant ``(N, X-1, Y-1, Z-1)``.

    Raises:
        DependencyError: When torch is unavailable.
    """
    _require_torch()
    # du/dx, du/dy, du/dz for every displacement component, cropped so all
    # three live on the same (X-1, Y-1, Z-1) grid.
    d_dx = deformation.diff(dim=2)[:, :, :, :-1, :-1]
    d_dy = deformation.diff(dim=3)[:, :, :-1, :, :-1]
    d_dz = deformation.diff(dim=4)[:, :, :-1, :-1, :]

    # Row r of J is (du_r/dx, du_r/dy, du_r/dz) plus the identity on the diagonal.
    rows = []
    for comp in range(3):
        row = [d_dx[:, comp], d_dy[:, comp], d_dz[:, comp]]
        row[comp] = row[comp] + 1.0
        rows.append(row)
    (j00, j01, j02), (j10, j11, j12), (j20, j21, j22) = rows

    # Cofactor expansion along the first row.
    minor0 = j11 * j22 - j12 * j21
    minor1 = j10 * j22 - j12 * j20
    minor2 = j10 * j21 - j11 * j20
    return j00 * minor0 - j01 * minor1 + j02 * minor2
def count_negative_jacobians(deformation: "torch.Tensor") -> int:
    """Return the number of voxels with non-positive (folding) Jacobian det (voxel-space).

    Args:
        deformation: Displacement field ``(N, 3, X, Y, Z)`` in voxel units.

    Returns:
        Count, over the whole batch, of determinant entries ``<= 0``.
    """
    folding = jacobian_determinant(deformation) <= 0
    return int(folding.count_nonzero().item())