"""Standalone ANTs transform application workflow.
Applies one or more transform jobs defined under ``transforms.jobs`` in
configuration to a patient, warping images from template space into patient
space (or the reverse direction).
Registered as the ``transform`` workflow; run it with ``thesis run -w transform -p <patient_id> -c <config>``.
"""
from __future__ import annotations
import shutil
from pathlib import Path
from typing import List, Tuple, Union
from nipype import Node, Workflow
from nipype.interfaces.utility import Function, IdentityInterface, Merge
from thesis.core.config import PipelineConfig
from thesis.core.context import ProcessingContext
from thesis.core.decorators import verify, workflow
from thesis.core.logging import get_logger
from thesis.core.naming import safe_node_name as _safe_node_name
from ..registration.paths import get_registration_job_transform_paths
from .operations import (
build_output_filename,
get_transform_type_from_paths,
validate_transform_inputs,
)
logger = get_logger(__name__)
__all__ = ["build_workflow", "verify_requirements"]
# Target-space labels returned by _resolve_job and passed on to
# build_output_filename when naming each transformed output.
_PATIENT_SPACE = "patient"
_TEMPLATE_SPACE = "template"
# See registration.fireants_backend for the full rationale. Short version:
# Nipype's Function interface content-hashes every string input that
# os.path.isfile() says is a real file, including output paths this node
# itself overwrites every run. ``hash_files=False`` trait metadata gets
# dropped by MultiProc's pickle round-trip; prefixing the path with a
# non-filesystem-path sentinel makes os.path.isfile return False in the
# worker regardless of metadata, so the string is hashed verbatim.
# NOTE: _apply_transform_task repeats this literal instead of reading this
# constant because it keeps all of its dependencies inside its own body
# (it even imports pathlib locally); keep the two copies in sync.
_OUTPUT_PATH_SENTINEL = "__thesis_nohash__:"
# Input names of the per-image Function node; they mirror the parameters of
# _apply_transform_task. The two ``_ordering_signal*`` entries carry no data:
# build_workflow connects entry_gate.ready / entry_gate.preprocess_done to
# them only to force execution order.
_APPLY_INPUTS = [
"input_file",
"output_file",
"transforms",
"reference_image",
"interpolation",
"_ordering_signal",
"_ordering_signal_preprocess",
]
def _apply_transform_task(
    input_file: str,
    output_file: str,
    transforms: list,
    reference_image: str,
    interpolation: str,
    _ordering_signal: str = "",
    _ordering_signal_preprocess: str = "",
) -> str:
    """Nipype Function node: apply ANTs transforms to one image.

    The function is self-contained (all imports live in the body and the
    no-hash sentinel is repeated locally) so it can run as a Nipype Function
    node independently of this module's globals.

    Args:
        input_file: Path of the image to warp.
        output_file: Destination path, optionally prefixed with the
            ``__thesis_nohash__:`` sentinel, which is stripped before use.
        transforms: Ordered transform chain. A single path (``str`` or
            ``Path``) is accepted and treated as a one-element chain.
        reference_image: Image passed to ``apply_transform_ants`` as the
            reference.
        interpolation: Interpolation mode forwarded unchanged.
        _ordering_signal: Ignored; only lets upstream nodes gate execution.
        _ordering_signal_preprocess: Ignored; same purpose as
            ``_ordering_signal``.

    Returns:
        ``str`` of the value returned by ``apply_transform_ants``.
    """
    from pathlib import Path

    from thesis.workflows.transforms.operations import apply_transform_ants

    _NOHASH = "__thesis_nohash__:"
    if output_file.startswith(_NOHASH):
        output_file = output_file[len(_NOHASH) :]

    # A bare string is iterable too: without this guard a single-path chain
    # would be turned into one Path per character.
    if isinstance(transforms, (str, Path)):
        transforms = [transforms]

    return str(
        apply_transform_ants(
            input_image=Path(input_file),
            output_image=Path(output_file),
            transforms=[Path(t) for t in transforms],
            reference_image=Path(reference_image),
            interpolation=interpolation,
        )
    )
def _fmt(value: str, patient_id: str) -> str:
"""Substitute ``{patient_id}`` in *value* if present."""
return value.format(patient_id=patient_id) if value and "{patient_id}" in value else value
def _chain(raw: Union[str, List[str], None], patient_id: str) -> List[str]:
"""Format and return a transform chain as a list of resolved path strings."""
if isinstance(raw, list):
return [_fmt(v, patient_id) for v in raw if v]
return [_fmt(raw, patient_id)] if raw else []
def _absolutize_reference(path: str, output_base: Path) -> str:
"""Anchor a relative path (reference image *or* input image) at the trial root.
Per-job ``reference_image`` overrides and ``input_files`` entries in protocol
YAML are often written relative to the trial root (e.g.
``outputs/cohort/atlas/atlas_mean.nii.gz``), so the config does not hard-code a
specific trial directory. ``output_base`` is ``<trial>/outputs/<pid>``, so the
trial root is ``output_base.parent.parent``. Anchoring here makes ``.exists()``
work from any Nipype worker CWD (workers ``chdir`` into their node dir at
runtime, so a bare relative path would not resolve). Absolute paths — external
assets such as masks or the template — pass through unchanged.
"""
if not path:
return path
p = Path(path)
if p.is_absolute():
return path
return str((output_base.parent.parent / p).resolve())
def _resolve_job(
    config: PipelineConfig,
    job,
    patient_id: str,
    output_base: Path,
    context: ProcessingContext,
) -> Tuple[List[str], str, str, str]:
    """Resolve (transform_chain, reference, target_space, raw_reference) for a job.

    ``template_to_patient`` jobs target patient space and fall back to
    ``transforms.reference_image``. Any other direction is handled as
    ``patient_to_template``: it targets template space and falls back to
    ``transforms.template_reference_image``. A per-job ``reference_image``
    always wins over the fallback.

    With ``job.from_registration`` set, the chain comes from
    :func:`get_registration_job_transform_paths` for that registration job
    and the job's direction. Otherwise it is the static
    ``transforms.template_to_patient`` / ``transforms.patient_to_template``
    chain, with ``{patient_id}`` substituted.

    ``reference`` is the formatted and trial-root-anchored reference path.
    ``raw_reference`` is the unformatted configured value, or ``""`` when
    none is configured.
    """
    transforms_cfg = config.transforms
    to_patient = job.direction == "template_to_patient"

    if to_patient:
        target_space = _PATIENT_SPACE
        raw_reference = job.reference_image or transforms_cfg.reference_image
    else:
        target_space = _TEMPLATE_SPACE
        raw_reference = job.reference_image or transforms_cfg.template_reference_image

    registration_job = getattr(job, "from_registration", None)
    if registration_job:
        chain = get_registration_job_transform_paths(
            config, context, registration_job, job.direction
        )
    else:
        static_chain = (
            transforms_cfg.template_to_patient
            if to_patient
            else transforms_cfg.patient_to_template
        )
        chain = _chain(static_chain, patient_id)

    raw_reference = raw_reference or ""
    reference = _absolutize_reference(_fmt(raw_reference, patient_id), output_base)
    return chain, reference, target_space, raw_reference
def _output_base(context: ProcessingContext) -> Path:
"""Return the patient-output base directory."""
return Path(context.output_dir) if context.output_dir else Path("outputs") / context.patient_id
[docs]
def verify_requirements(config: PipelineConfig, context: ProcessingContext) -> List[str]:
    """Preflight checks: ANTs binary, jobs configured, transform/reference paths exist.

    Returns:
        Human-readable error messages; an empty list means every check
        passed. Per-job messages are prefixed with ``Job '<name>'``. When no
        jobs are configured, the per-job checks are skipped entirely.
    """
    problems: List[str] = []
    if shutil.which("antsApplyTransforms") is None:
        problems.append("ANTs 'antsApplyTransforms' not found on PATH.")

    jobs = config.transforms.jobs
    if not jobs:
        problems.append(
            "No transform jobs defined. Add at least one entry under "
            "'transforms.jobs' in your configuration."
        )
        return problems

    patient_id = context.patient_id
    base = _output_base(context)
    for job in jobs:
        problems.extend(_verify_job(config, context, job, patient_id, base))
    return problems


def _verify_job(
    config: PipelineConfig,
    context: ProcessingContext,
    job,
    patient_id: str,
    output_base: Path,
) -> List[str]:
    """Return the preflight errors for a single transform job (empty if valid)."""
    prefix = f"Job '{job.name}'"
    chain, reference, _, raw_reference = _resolve_job(
        config, job, patient_id, output_base, context
    )
    # A job driven by a registration job reads its chain from registration
    # outputs that only appear once registration has run, so the static-chain
    # checks do not apply to it.
    from_registration = getattr(job, "from_registration", None)
    found: List[str] = []

    if job.direction == "template_to_patient":
        if not (from_registration or config.transforms.template_to_patient):
            found.append(
                f"{prefix}: direction is 'template_to_patient' but "
                "'transforms.template_to_patient' is not configured."
            )
        if not raw_reference:
            found.append(
                f"{prefix}: direction is 'template_to_patient' but neither "
                "'job.reference_image' nor 'transforms.reference_image' is configured."
            )
    else:
        if not (from_registration or config.transforms.patient_to_template):
            found.append(
                f"{prefix}: direction is 'patient_to_template' but "
                "'transforms.patient_to_template' is not configured."
            )
        if not raw_reference:
            found.append(
                f"{prefix}: direction is 'patient_to_template' but neither "
                "'job.reference_image' nor 'transforms.template_reference_image' "
                "is configured."
            )

    if not job.input_files:
        found.append(f"{prefix}: 'input_files' is empty — no images to transform.")
        return found

    inputs = [
        Path(_absolutize_reference(_fmt(raw, patient_id), output_base))
        for raw in job.input_files
    ]
    # Registration-produced transforms do not exist yet at preflight, so only
    # the inputs and the reference are checked for those jobs.
    transform_paths = [] if from_registration else [Path(t) for t in chain]
    found.extend(
        f"{prefix}: {e}"
        for e in validate_transform_inputs(
            input_files=inputs,
            transform_paths=transform_paths,
            reference_image=Path(reference),
        )
    )
    return found
def _referenced_registration_jobs(config: PipelineConfig) -> List[str]:
"""Return the distinct registration-job names referenced by transform jobs.
Preserves first-seen order so the gate fields and edges are deterministic.
"""
seen: List[str] = []
for job in config.transforms.jobs:
name = getattr(job, "from_registration", None)
if name and name not in seen:
seen.append(name)
return seen
def _build_entry_gate(config: PipelineConfig, patient_id: str) -> Node:
    """Build the fixed-name entry gate and seed it with static config values.

    Base fields:
      * ``ready`` / ``preprocess_done``: ordering-only anchors; the latter
        gates per-job reference_image overrides on preprocess completion.
      * ``forward_transforms`` / ``reverse_transforms``: legacy chains for
        the default ``patient_to_template`` registration job.
      * ``reference_image``: shared reference for template_to_patient jobs.

    Every explicit registration job named in ``transforms.jobs[*]
    .from_registration`` additionally gets a ``forward_transforms__<regjob>``
    / ``reverse_transforms__<regjob>`` pair. When embedded in full_pipeline,
    the meta-workflow connects these fields from upstream registration and
    preprocess outputs, which override the static seeds set here.
    """
    gate_fields = [
        "ready",
        "forward_transforms",
        "reverse_transforms",
        "reference_image",
        "preprocess_done",
    ]
    explicit_jobs = [
        regjob
        for regjob in _referenced_registration_jobs(config)
        if regjob != "patient_to_template"
    ]
    for regjob in explicit_jobs:
        gate_fields.extend(
            (f"forward_transforms__{regjob}", f"reverse_transforms__{regjob}")
        )
    gate = Node(IdentityInterface(fields=gate_fields), name="entry_gate")

    transforms_cfg = config.transforms
    # Only configured static chains are seeded; unconfigured ones stay unset.
    static_chains = (
        ("reverse_transforms", transforms_cfg.template_to_patient),
        ("forward_transforms", transforms_cfg.patient_to_template),
    )
    for field_name, raw_chain in static_chains:
        if raw_chain:
            setattr(gate.inputs, field_name, _chain(raw_chain, patient_id))
    if transforms_cfg.reference_image:
        gate.inputs.reference_image = _fmt(transforms_cfg.reference_image, patient_id)
    return gate
[docs]
@workflow(
    name="transform",
    description=(
        "Apply pre-computed ANTs transforms to images defined in "
        "config.transforms.jobs. Supports template→patient and "
        "patient→template directions."
    ),
)
@verify(verify_requirements)
def build_workflow(*, config: PipelineConfig, context: ProcessingContext) -> Workflow:
    """Build a transform workflow from ``config.transforms.jobs``.

    One Nipype Function node is created per ``(job, input file)`` pair. Every
    node is wired to the fixed-name ``entry_gate``, which carries the
    ordering signals plus the transform chains / reference image a
    meta-workflow may inject. All nodes feed a fixed-name ``exit_merge`` ->
    ``exit_gate`` pair whose ``done`` field fires once every transform has
    finished.

    Raises:
        ValueError: If ``config.transforms.jobs`` is empty.
    """
    if not config.transforms.jobs:
        raise ValueError(
            "No transform jobs defined. Add at least one entry under "
            "'transforms.jobs' in your configuration."
        )
    wf = Workflow(name=f"transform_{context.patient_id}")
    if context.working_dir:
        wf.base_dir = str(context.working_dir)
    patient_id = context.patient_id
    output_base = _output_base(context)

    entry_gate = _build_entry_gate(config, patient_id)
    # The gate's reference_image is connected into every template_to_patient
    # node without a per-job override (below), so at runtime it replaces the
    # node's own trial-root-anchored reference. Anchor the seeded value the
    # same way. Workers chdir into their node directory, so a trial-relative
    # ``transforms.reference_image`` would otherwise not resolve. Absolute
    # paths pass through unchanged, and a meta-workflow connection into the
    # gate still takes over.
    if config.transforms.reference_image:
        entry_gate.inputs.reference_image = _absolutize_reference(
            _fmt(config.transforms.reference_image, patient_id), output_base
        )
    wf.add_nodes([entry_gate])

    transform_nodes: List[Node] = []
    for job in config.transforms.jobs:
        chain, reference, target_space, _ = _resolve_job(
            config, job, patient_id, output_base, context
        )
        transform_type = get_transform_type_from_paths([Path(t) for t in chain])
        out_dir = output_base / job.output_subdir
        out_dir.mkdir(parents=True, exist_ok=True)

        # Gate field that drives this job's transform chain. Template_to_patient
        # jobs read the reverse chain; all others read the forward chain.
        # Explicit (non-default) registration jobs use the per-job suffixed
        # field; the legacy default ("patient_to_template" or no
        # from_registration) uses the unsuffixed one.
        to_patient = job.direction == "template_to_patient"
        from_registration = getattr(job, "from_registration", None)
        chain_field = "reverse_transforms" if to_patient else "forward_transforms"
        static_chain = (
            config.transforms.template_to_patient
            if to_patient
            else config.transforms.patient_to_template
        )
        uses_job_field = bool(from_registration) and from_registration != "patient_to_template"
        if uses_job_field:
            chain_field = f"{chain_field}__{from_registration}"
        # The gate -> node connection replaces node.inputs.transforms at
        # runtime. _build_entry_gate never seeds the per-registration-job
        # fields, nor the legacy field when no static chain is configured. In a
        # standalone run those fields therefore arrive Undefined and wipe out
        # the chain _resolve_job read from the registration outputs. Seed them
        # with that chain instead; a meta-workflow connection still overrides
        # it, and a configured static legacy chain keeps its precedence.
        if from_registration and (uses_job_field or not static_chain):
            setattr(entry_gate.inputs, chain_field, chain)

        for raw_input in job.input_files:
            # Anchor relative input_files (e.g. a cohort atlas at
            # outputs/cohort/atlas/atlas_mean.nii.gz) at the trial root, the
            # same way per-job reference_image is anchored. Absolute external
            # inputs (masks, templates) pass through unchanged.
            input_path = Path(_absolutize_reference(_fmt(raw_input, patient_id), output_base))
            output_name = build_output_filename(
                input_path, transform_type, job.direction, target_space
            )
            output_path = out_dir / output_name
            node_name = _safe_node_name(f"transform_{job.name}_{input_path.stem}")
            node = Node(
                Function(
                    input_names=_APPLY_INPUTS,
                    output_names=["transformed_image"],
                    function=_apply_transform_task,
                ),
                name=node_name,
            )
            node.inputs.input_file = str(input_path)
            # Prefix output_file with the no-hash sentinel: this is an OUTPUT
            # path that this node overwrites every run. Without the prefix,
            # Nipype content-hashes the file bytes as part of the input hash,
            # and .nii.gz gzip timestamps self-invalidate the cache. The
            # task strips the prefix at function entry.
            node.inputs.output_file = _OUTPUT_PATH_SENTINEL + str(output_path)
            node.inputs.transforms = chain
            node.inputs.reference_image = reference
            node.inputs.interpolation = job.interpolation
            wf.add_nodes([node])
            transform_nodes.append(node)

            # Ordering-only edges: these inputs carry no data.
            wf.connect(entry_gate, "ready", node, "_ordering_signal")
            wf.connect(entry_gate, "preprocess_done", node, "_ordering_signal_preprocess")
            # Runtime overrides from the gate. A per-job reference_image
            # override is authoritative, so the gate's reference_image is not
            # connected for such jobs.
            wf.connect(entry_gate, chain_field, node, "transforms")
            if to_patient and not job.reference_image:
                wf.connect(entry_gate, "reference_image", node, "reference_image")
            logger.debug("Transform node '{}': {} → {}", node_name, input_path, output_path)

    # Fixed-name sink aggregating every transform_* output into a single
    # "all transforms finished" signal for meta-workflow consumers.
    if transform_nodes:
        merge = Node(Merge(len(transform_nodes)), name="exit_merge")
        exit_gate = Node(IdentityInterface(fields=["done"]), name="exit_gate")
        wf.add_nodes([merge, exit_gate])
        for idx, node in enumerate(transform_nodes, start=1):
            wf.connect(node, "transformed_image", merge, f"in{idx}")
        wf.connect(merge, "out", exit_gate, "done")

    logger.info(
        "Built transform workflow for {} | {} job(s), {} node(s)",
        patient_id,
        len(config.transforms.jobs),
        len(transform_nodes),
    )
    return wf