"""Diffeomorphic registration U-Net and learnable template for the learned atlas.
Compact in-repo VoxelMorph/AtlasMorph model. The network predicts a stationary
velocity field only; scaling-and-squaring integration turns it into a
diffeomorphic deformation, and a spatial transformer resamples a jointly-learned
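base template ``T``. The network emits only deformation fields (never density),
so the warped template can only move and interpolate existing template values;
it cannot create structure that is absent from the template. Resampling is not
mass-preserving (no Jacobian modulation is applied), so integrated totals can
change under local expansion or compression.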
All tensors are 5D ``(N, 1, X, Y, Z)``; the template is ``(K, 1, X, Y, Z)``.
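torch is imported optionally (guarded by ``try``/``except ImportError``), so
importing the package is safe without the optional ``ml`` extra; the
``nn.Module`` classes below are only defined when torch is present and are only
used inside the Nipype Function body.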
"""
from __future__ import annotations
from typing import Sequence
try:
import torch
import torch.nn as nn
import torch.nn.functional as F
_TORCH_AVAILABLE = True
except ImportError: # pragma: no cover - exercised only when torch is absent.
_TORCH_AVAILABLE = False
def _require_torch() -> None:
    """Fail fast with the project's DependencyError if PyTorch is missing.

    Raises:
        DependencyError: If torch could not be imported, i.e. the optional
            ``ml`` dependency group is not installed.
    """
    if _TORCH_AVAILABLE:
        return
    # Imported here so the exception module is only loaded on the failure path.
    from thesis.core.exceptions import DependencyError

    message = "The learned_atlas model requires PyTorch. Install with: pip install -e '.[ml]'"
    raise DependencyError(message)
if _TORCH_AVAILABLE:

    class ConvBlock(nn.Module):
        """3D convolution (kernel 3, padding 1) followed by LeakyReLU(0.2)."""

        def __init__(self, in_channels: int, out_channels: int, stride: int = 1) -> None:
            """Initialise the block.

            Args:
                in_channels: Number of input channels.
                out_channels: Number of output channels.
                stride: Convolution stride (the U-Net encoder uses 2 to downsample).
            """
            super().__init__()
            self.conv = nn.Conv3d(
                in_channels, out_channels, kernel_size=3, stride=stride, padding=1
            )
            self.activation = nn.LeakyReLU(0.2)

        def forward(self, x: "torch.Tensor") -> "torch.Tensor":
            """Apply convolution and activation."""
            return self.activation(self.conv(x))  # type: ignore[no-any-return]

    class SpatialTransformer(nn.Module):
        """Resample a volume with a voxel-displacement field via ``grid_sample``.

        Fully functional: the normalised sampling grid is built out-of-place (no
        in-place index assignment) so autograd works through the integrated
        deformation. An identity voxel grid is cached per (shape, device, dtype)
        in a plain dict (not a registered buffer); entries are never evicted.
        """

        def __init__(self, mode: str = "bilinear") -> None:
            """Initialise the transformer.

            Args:
                mode: Interpolation mode forwarded to ``F.grid_sample``.
            """
            super().__init__()
            self.mode = mode
            self._grid_cache: dict = {}

        def _identity_grid(self, shape: Sequence[int], device, dtype) -> "torch.Tensor":
            """Build (or fetch cached) identity voxel grid ``(1, 3, X, Y, Z)``.

            Channel ``c`` holds the voxel index along spatial axis ``c``.
            """
            key = (tuple(shape), str(device), str(dtype))
            grid = self._grid_cache.get(key)
            if grid is None:
                vectors = [torch.arange(0, s, device=device, dtype=dtype) for s in shape]
                coords = torch.meshgrid(*vectors, indexing="ij")
                grid = torch.stack(coords, dim=0).unsqueeze(0)
                self._grid_cache[key] = grid
            return grid

        def forward(self, src: "torch.Tensor", flow: "torch.Tensor") -> "torch.Tensor":
            """Warp ``src`` by the voxel-displacement field ``flow``.

            Args:
                src: Source volume ``(N, C, X, Y, Z)``.
                flow: Displacement field ``(N, 3, X, Y, Z)`` in voxel units.

            Returns:
                Warped volume ``(N, C, X, Y, Z)``. Samples that fall outside the
                volume take the nearest border value (``padding_mode="border"``).
            """
            shape = src.shape[2:]
            grid = self._identity_grid(shape, src.device, src.dtype)
            new_locs = grid + flow  # (N, 3, X, Y, Z) absolute voxel coordinates
            # Out-of-place normalisation to [-1, 1]: voxel 0 -> -1 and voxel
            # dim_i - 1 -> +1, matching align_corners=True. max(..., 1) avoids a
            # division by zero on singleton axes.
            size = torch.tensor(
                [max(int(s) - 1, 1) for s in shape],
                device=src.device,
                dtype=src.dtype,
            ).view(1, 3, 1, 1, 1)
            new_locs = 2.0 * (new_locs / size) - 1.0
            # grid_sample expects (N, X, Y, Z, 3) whose last-axis components are
            # ordered (W, H, D), i.e. (z, y, x) for our (x, y, z) channels.
            new_locs = new_locs.permute(0, 2, 3, 4, 1)
            new_locs = new_locs[..., [2, 1, 0]]
            return F.grid_sample(
                src, new_locs, mode=self.mode, padding_mode="border", align_corners=True
            )

    class VecInt(nn.Module):
        """Scaling-and-squaring integration of a stationary velocity field.

        The velocity is scaled by ``1 / 2**n_steps`` and the resulting small
        displacement is composed with itself ``n_steps`` times.
        """

        def __init__(self, transformer: "SpatialTransformer", n_steps: int = 7) -> None:
            """Initialise the integrator.

            Args:
                transformer: Transformer used for self-composition. It is shared
                    with the caller so both use a single identity-grid cache.
                n_steps: Number of squaring steps (0 performs no squaring).

            Raises:
                ValueError: If ``n_steps`` is negative.
            """
            super().__init__()
            if n_steps < 0:
                raise ValueError(f"n_steps must be >= 0, got {n_steps}.")
            self.n_steps = n_steps
            self.scale = 1.0 / (2**n_steps)
            self.transformer = transformer

        def forward(self, velocity: "torch.Tensor") -> "torch.Tensor":
            """Integrate ``velocity`` ``(N, 3, X, Y, Z)`` into a displacement field of the same shape."""
            disp = velocity * self.scale
            for _ in range(self.n_steps):
                # Compose the deformation with itself: u(x) + u(x + u(x)).
                disp = disp + self.transformer(disp, disp)
            return disp  # type: ignore[no-any-return]

    class UNet3D(nn.Module):
        """Compact VoxelMorph-style 3D U-Net producing a velocity field.

        The flow head is zero-initialised so the network starts at the identity
        transform (guarantees a fold-free start). Skip bookkeeping uses the same
        n-1 skip list in ``__init__`` and ``forward`` so non-default enc/dec
        widths do not mis-pair channels.
        """

        def __init__(
            self,
            in_channels: int = 2,
            enc_channels: Sequence[int] = (16, 32, 32, 32),
            dec_channels: Sequence[int] = (32, 32, 32, 16, 16),
            ndims: int = 3,
        ) -> None:
            """Initialise encoder/decoder and the zero-init velocity head.

            Args:
                in_channels: Channels of the concatenated (moving, fixed) input.
                enc_channels: Encoder widths; the first level keeps resolution,
                    every later level downsamples with stride 2.
                dec_channels: Decoder widths; the first ``len(enc_channels) - 1``
                    decoders upsample and concatenate a skip connection, the rest
                    run at the resolution reached so far without a skip.
                ndims: Spatial dimensionality; only 3 is supported.

            Raises:
                ValueError: If ``ndims`` is not 3 (all layers are ``Conv3d`` and
                    the downstream transformer expects a 3-channel field).
            """
            super().__init__()
            if ndims != 3:
                raise ValueError(f"UNet3D only supports ndims=3, got {ndims}.")
            self.ndims = ndims
            self.encoders = nn.ModuleList()
            prev = in_channels
            for i, ch in enumerate(enc_channels):
                stride = 1 if i == 0 else 2
                self.encoders.append(ConvBlock(prev, ch, stride=stride))
                prev = ch
            # Skips are produced by the n-1 non-bottleneck encoders, reversed.
            skip_channels = list(enc_channels[:-1][::-1])
            self.decoders = nn.ModuleList()
            for i, ch in enumerate(dec_channels):
                skip = skip_channels[i] if i < len(skip_channels) else 0
                self.decoders.append(ConvBlock(prev + skip, ch))
                prev = ch
            self.flow = nn.Conv3d(prev, ndims, kernel_size=3, padding=1)
            nn.init.zeros_(self.flow.weight)
            nn.init.zeros_(self.flow.bias)  # type: ignore[arg-type]

        def forward(self, moving: "torch.Tensor", fixed: "torch.Tensor") -> "torch.Tensor":
            """Predict the stationary velocity field aligning ``moving`` to ``fixed``."""
            x = torch.cat([moving, fixed], dim=1)
            skips: list = []
            for i, enc in enumerate(self.encoders):
                x = enc(x)
                if i < len(self.encoders) - 1:
                    skips.append(x)
            skips = skips[::-1]
            for i, dec in enumerate(self.decoders):
                if i < len(skips):
                    skip = skips[i]
                    # Resize to the skip's exact shape (handles odd dimensions).
                    x = F.interpolate(
                        x, size=skip.shape[2:], mode="trilinear", align_corners=False
                    )
                    x = torch.cat([x, skip], dim=1)
                x = dec(x)
            return self.flow(x)  # type: ignore[no-any-return]

    class AssignmentHead(nn.Module):
        """OPTIONAL Phase-2: soft assignment over K templates (bypassed for K=1)."""

        def __init__(self, in_channels: int, n_templates: int) -> None:
            """Initialise the assignment head.

            Args:
                in_channels: Channels of the subject volume.
                n_templates: Number of templates K to assign over.
            """
            super().__init__()
            self.pool = nn.AdaptiveAvgPool3d(1)
            self.fc = nn.Linear(in_channels, n_templates)

        def forward(self, fixed: "torch.Tensor") -> "torch.Tensor":
            """Return soft assignment weights ``(N, K)`` summing to 1."""
            pooled = self.pool(fixed).flatten(1)
            return F.softmax(self.fc(pooled), dim=1)

    class LearnedAtlasModel(nn.Module):
        """Learned conditional deformable atlas: template(s) + diffeomorphic U-Net."""

        def __init__(
            self,
            volume_shape: Sequence[int],
            n_templates: int = 1,
            int_steps: int = 7,
            enc_features: Sequence[int] = (16, 32, 32, 32),
            dec_features: Sequence[int] = (32, 32, 32, 16, 16),
            init_template: "torch.Tensor | None" = None,
        ) -> None:
            """Initialise the learned atlas model.

            Args:
                volume_shape: Spatial shape ``(X, Y, Z)``.
                n_templates: Number of templates K (1 = single sharp template).
                int_steps: Scaling-and-squaring integration steps.
                enc_features: U-Net encoder widths.
                dec_features: U-Net decoder widths.
                init_template: Optional ``(X, Y, Z)`` / ``(K, X, Y, Z)`` /
                    ``(K, 1, X, Y, Z)`` init (tensor or array-like). It is copied,
                    never aliased.

            Raises:
                ValueError: On invalid ``n_templates`` / ``volume_shape`` /
                    ``init_template`` shape.
            """
            super().__init__()
            _require_torch()
            if n_templates < 1:
                raise ValueError(f"n_templates must be >= 1, got {n_templates}.")
            if len(volume_shape) != 3:
                raise ValueError(f"volume_shape must be (X, Y, Z), got {tuple(volume_shape)}.")
            self.volume_shape = tuple(int(s) for s in volume_shape)
            self.n_templates = int(n_templates)
            self.template = nn.Parameter(self._build_template_init(init_template))
            self.unet = UNet3D(
                in_channels=2, enc_channels=tuple(enc_features), dec_channels=tuple(dec_features)
            )
            # ONE shared SpatialTransformer (single grid cache) for both the
            # integrator and the final warp.
            self.warp = SpatialTransformer(mode="bilinear")
            self.integrate = VecInt(self.warp, n_steps=int_steps)
            self.assignment = (
                AssignmentHead(in_channels=1, n_templates=self.n_templates)
                if self.n_templates > 1
                else None
            )

        def _build_template_init(self, init_template: "torch.Tensor | None") -> "torch.Tensor":
            """Create the initial template tensor ``(K, 1, X, Y, Z)``.

            The result is always a fresh float32 copy. Without the copy, a
            contiguous float32 ``init_template`` would be wrapped as-is by
            ``nn.Parameter`` and optimiser steps would overwrite the caller's data.

            Raises:
                ValueError: If ``init_template`` cannot be brought to
                    ``(K, 1, X, Y, Z)``.
            """
            x, y, z = self.volume_shape
            if init_template is None:
                return torch.rand(self.n_templates, 1, x, y, z) * 1e-3
            source = torch.as_tensor(init_template)
            t = source.detach().to(dtype=torch.float32, copy=True)
            if t.dim() == 3:
                t = t.unsqueeze(0)
            if t.dim() == 4:
                t = t.unsqueeze(1)
            # Broadcast a single 5D template to all K slots; other ranks fall
            # through to the shape check below instead of failing inside repeat().
            if t.dim() == 5 and t.shape[0] == 1 and self.n_templates > 1:
                t = t.repeat(self.n_templates, 1, 1, 1, 1)
            if tuple(t.shape) != (self.n_templates, 1, x, y, z):
                raise ValueError(
                    f"init_template shape {tuple(source.shape)} incompatible with "
                    f"({self.n_templates}, 1, {x}, {y}, {z})."
                )
            return t.contiguous()

        def _select_template(self, subject: "torch.Tensor") -> "torch.Tensor":
            """Return the per-subject template ``(N, 1, X, Y, Z)``.

            For K=1 the single template is broadcast with ``expand`` (no copy);
            otherwise the K templates are blended with the assignment weights.
            """
            n = subject.shape[0]
            if self.n_templates == 1 or self.assignment is None:
                return self.template[0:1].expand(n, -1, -1, -1, -1)
            weights = self.assignment(subject)
            blended = (
                weights.view(n, self.n_templates, 1, 1, 1, 1) * self.template.unsqueeze(0)
            ).sum(dim=1)
            return blended  # type: ignore[no-any-return]

        def forward(self, subject: "torch.Tensor") -> dict:
            """Warp the template onto a subject volume.

            Args:
                subject: Subject volume ``(N, 1, X, Y, Z)``. Its spatial shape
                    must match ``volume_shape`` for the concatenation in the U-Net.

            Returns:
                Dict with ``"warped_template"`` ``(N, 1, X, Y, Z)``, ``"velocity"``
                and ``"deformation"`` ``(N, 3, X, Y, Z)`` in voxel units, and the
                selected ``"template"`` ``(N, 1, X, Y, Z)``.
            """
            template = self._select_template(subject)
            velocity = self.unet(template, subject)
            deformation = self.integrate(velocity)
            warped = self.warp(template, deformation)
            return {
                "warped_template": warped,
                "velocity": velocity,
                "deformation": deformation,
                "template": template,
            }

        @torch.no_grad()
        def export_template(self) -> "torch.Tensor":
            """Return a snapshot of the learned template(s) as ``(K, X, Y, Z)`` float32 on CPU.

            A copy is always made. Without it, a CPU float32 template would be
            returned as a view of the live Parameter, and later training steps
            would mutate the "exported" tensor.
            """
            return (
                self.template.detach()
                .squeeze(1)
                .to(device="cpu", dtype=torch.float32, copy=True)
                .contiguous()
            )
[docs]
def jacobian_determinant(deformation: "torch.Tensor") -> "torch.Tensor":
    """Voxel-space Jacobian determinant of ``phi(x) = x + deformation(x)``.

    NOTE: voxel-space only (internal fold-free check). The displacement field is
    in VOXEL units, so this determinant must NOT be fed to ``antsApplyTransforms``
    or compared against ANTs/SimpleITK physical-space warps: physical-space
    validation requires first composing the field with the reference affine.

    Derivatives are forward differences, so the last voxel along each axis is
    dropped.

    Args:
        deformation: Displacement field ``(N, 3, X, Y, Z)`` in voxel units.

    Returns:
        Determinant ``(N, X-1, Y-1, Z-1)``.

    Raises:
        DependencyError: When torch is unavailable.
        ValueError: If ``deformation`` is not ``(N, 3, X, Y, Z)``. Previously a
            field with more than 3 channels was silently truncated.
    """
    _require_torch()
    if deformation.dim() != 5 or deformation.shape[1] != 3:
        raise ValueError(
            "deformation must have shape (N, 3, X, Y, Z), got "
            f"{tuple(deformation.shape)}."
        )
    disp = deformation
    # Forward differences along x, y, z, cropped to a common (X-1, Y-1, Z-1) grid.
    grad_x = disp[:, :, 1:, :-1, :-1] - disp[:, :, :-1, :-1, :-1]
    grad_y = disp[:, :, :-1, 1:, :-1] - disp[:, :, :-1, :-1, :-1]
    grad_z = disp[:, :, :-1, :-1, 1:] - disp[:, :, :-1, :-1, :-1]
    # J[r][c] = d(phi_r)/d(x_c) = d(u_r)/d(x_c) + delta_rc.
    j00 = grad_x[:, 0] + 1.0
    j01 = grad_y[:, 0]
    j02 = grad_z[:, 0]
    j10 = grad_x[:, 1]
    j11 = grad_y[:, 1] + 1.0
    j12 = grad_z[:, 1]
    j20 = grad_x[:, 2]
    j21 = grad_y[:, 2]
    j22 = grad_z[:, 2] + 1.0
    # Cofactor expansion along the first row.
    return (
        j00 * (j11 * j22 - j12 * j21)
        - j01 * (j10 * j22 - j12 * j20)
        + j02 * (j10 * j21 - j11 * j20)
    )
[docs]
def count_negative_jacobians(deformation: "torch.Tensor") -> int:
    """Count folding voxels in a voxel-space displacement field.

    A voxel is counted when its voxel-space Jacobian determinant is <= 0
    (non-positive, i.e. the mapping folds or collapses there).

    Args:
        deformation: Displacement field ``(N, 3, X, Y, Z)`` in voxel units.

    Returns:
        Number of non-positive determinant entries across the whole batch.
    """
    folded = jacobian_determinant(deformation) <= 0
    return int(folded.sum().item())