"""
Processing context for the thesis framework.
Provides a context object that carries information about the current
processing state, patient data, paths, and configuration through
the pipeline.
"""
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Optional, Union
from thesis.core.config import PipelineConfig
from thesis.core.exceptions import ValidationError
from thesis.core.logging import get_logger
from thesis.core.utils import to_path
logger = get_logger(__name__)
__all__ = ["ProcessingContext", "create_context"]
@dataclass
class ProcessingContext:
    """
    Context object for medical imaging processing pipelines.

    Carries all relevant information about a processing run, including
    patient ID, data paths, configuration, and intermediate results.

    Constructing an instance has a filesystem side effect: ``output_dir``
    and ``working_dir`` are created (including parents) in ``__post_init__``.

    Attributes:
        patient_id: Unique patient identifier
        config: Configuration for this processing run
        data_dir: Shared assets base (config ``paths.assets_dir``): templates,
            atlases, ROIs. Not a parent of input_dir/output_dir.
        input_dir: Per-patient input data directory (``inputs_dir/<patient_id>``)
        output_dir: Per-patient output directory (``output_dir/<patient_id>``)
        working_dir: Temporary working/scratch directory
        metadata: Additional metadata
        results: Dictionary to store intermediate results

    Example:
        >>> ctx = ProcessingContext(
        ...     patient_id="DTI_LDF001",
        ...     config=my_config,
        ...     data_dir=Path("./data")
        ... )
        >>> ctx.add_result("registration", transform_matrix)
        >>> transform = ctx.get_result("registration")
    """

    patient_id: str
    config: PipelineConfig
    data_dir: Path
    input_dir: Optional[Path] = None
    output_dir: Optional[Path] = None
    working_dir: Optional[Path] = None
    metadata: Dict[str, Any] = field(default_factory=dict)
    results: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Normalize all paths, fill in defaults from config, create directories.

        Defaults (used only when the corresponding field is ``None``):

        * ``input_dir``   -> ``<config.paths.inputs_dir>/<patient_id>``
        * ``output_dir``  -> ``<config.paths.output_dir>/<patient_id>``
        * ``working_dir`` -> ``<config.paths.scratch_dir>/<patient_id>`` when
          ``scratch_dir`` is truthy, otherwise ``<output_dir>/temp``

        ``output_dir`` and ``working_dir`` are created on disk; ``input_dir``
        is not.
        """
        # ``data_dir`` is the resolved assets base (config ``paths.assets_dir``);
        # it anchors DataFile/DataDir lookups but is NOT a parent of the input
        # or output roots, which are independent.
        self.data_dir = to_path(self.data_dir)

        # ``self.config.paths`` is only touched inside the branches that need
        # it, so a context built with every directory given explicitly never
        # reads the path configuration.
        if self.input_dir is None:
            self.input_dir = (
                self._resolve_project_root(self.config.paths.inputs_dir) / self.patient_id
            )
        else:
            self.input_dir = to_path(self.input_dir)

        if self.output_dir is None:
            self.output_dir = (
                self._resolve_project_root(self.config.paths.output_dir) / self.patient_id
            )
        else:
            self.output_dir = to_path(self.output_dir)

        if self.working_dir is None:
            if self.config.paths.scratch_dir:
                self.working_dir = (
                    self._resolve_project_root(self.config.paths.scratch_dir)
                    / self.patient_id
                )
            else:
                self.working_dir = self.output_dir / "temp"
        else:
            self.working_dir = to_path(self.working_dir)

        # Create output directories
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.working_dir.mkdir(parents=True, exist_ok=True)
        logger.debug(f"Created processing context for patient: {self.patient_id}")

    @staticmethod
    def _resolve_project_root(base: Union[Path, str]) -> Path:
        """Resolve a project-level path relative to the current working directory.

        The input/output/scratch roots are independent, project-level
        directories. Absolute paths are returned as produced by ``to_path``;
        relative paths are anchored at the CWD at context creation time so
        they remain stable even if the process working directory changes
        later (e.g. during Nipype execution).
        """
        root = to_path(base)
        if root.is_absolute():
            return root
        return Path.cwd() / root

    def add_result(self, key: str, value: Any) -> None:
        """
        Store a processing result, replacing any existing entry for ``key``.

        Args:
            key: Result identifier
            value: Result data

        Example:
            >>> ctx.add_result("brain_mask", mask_path)
        """
        self.results[key] = value
        logger.debug(f"Added result to context: {key}")

    def get_result(self, key: str, default: Any = None) -> Any:
        """
        Retrieve a processing result.

        Args:
            key: Result identifier
            default: Default value if key not found

        Returns:
            Result data or default

        Example:
            >>> mask_path = ctx.get_result("brain_mask")
        """
        return self.results.get(key, default)

    def has_result(self, key: str) -> bool:
        """
        Check if a result exists.

        Args:
            key: Result identifier

        Returns:
            True if result exists, False otherwise
        """
        return key in self.results

    def add_metadata(self, key: str, value: Any) -> None:
        """
        Add metadata to the context, replacing any existing entry for ``key``.

        Args:
            key: Metadata key
            value: Metadata value

        Example:
            >>> ctx.add_metadata("acquisition_date", "2025-01-15")
        """
        self.metadata[key] = value
        logger.debug(f"Added metadata to context: {key}")

    def get_metadata(self, key: str, default: Any = None) -> Any:
        """
        Retrieve metadata.

        Args:
            key: Metadata key
            default: Default value if key not found

        Returns:
            Metadata value or default
        """
        return self.metadata.get(key, default)

    @staticmethod
    def _check_path_traversal(resolved: Path, base_dir: Path, label: str) -> None:
        """Raise ValidationError if *resolved* escapes *base_dir*.

        Both paths are passed through ``Path.resolve()`` before comparison, so
        ``..`` components and symlinks are taken into account.

        Args:
            resolved: Candidate path to validate
            base_dir: Directory the candidate must stay inside
            label: Human-readable name of ``base_dir`` used in the error message

        Raises:
            ValidationError: If the resolved candidate is not located under
                the resolved ``base_dir``.
        """
        try:
            resolved.resolve().relative_to(base_dir.resolve())
        except ValueError as exc:
            # Keep the original ValueError as the explicit cause for debugging.
            raise ValidationError(
                f"Path traversal detected: '{resolved}' escapes {label} '{base_dir}'"
            ) from exc

    def get_input_path(self, filename: str) -> Path:
        """
        Get full path for an input file.

        Args:
            filename: Input filename. An absolute path is returned as-is;
                a relative one is anchored under ``input_dir``.

        Returns:
            Full path to input file

        Raises:
            ValidationError: If ``input_dir`` is not set, or if a *relative*
                filename escapes input_dir.

        Example:
            >>> t1_path = ctx.get_input_path("DTI_LDF001_T1.nii.gz")
        """
        if self.input_dir is None:
            raise ValidationError("input_dir is not set")
        candidate = to_path(filename)
        # An absolute path is an explicit, deliberate location (e.g. a config
        # override pointing at a prior trial's outputs) — honour it as-is. The
        # traversal guard exists to catch RELATIVE escapes ("../.."), so apply it
        # only when anchoring a relative filename under input_dir.
        if candidate.is_absolute():
            return candidate
        result = self.input_dir / candidate
        self._check_path_traversal(result, self.input_dir, "input_dir")
        return result

    def get_output_path(self, filename: str, subdir: Optional[str] = None) -> Path:
        """
        Get full path for an output file.

        The parent directory is created when ``filename`` is absolute or when
        ``subdir`` is given. For relative paths the directory is created only
        AFTER the traversal check has passed, so a rejected path never leaves
        directories behind outside ``output_dir``.

        Args:
            filename: Output filename. An absolute path is returned as-is.
            subdir: Optional subdirectory within output_dir

        Returns:
            Full path to output file

        Raises:
            ValidationError: If ``output_dir`` is not set, or if a *relative*
                path escapes output_dir.

        Example:
            >>> reg_path = ctx.get_output_path("T1_registered.nii.gz", "registration")
        """
        if self.output_dir is None:
            raise ValidationError("output_dir is not set")
        # Absolute filename → explicit location; honour as-is. The traversal
        # guard only exists to catch relative escapes ("../..").
        candidate = to_path(filename)
        if candidate.is_absolute():
            candidate.parent.mkdir(parents=True, exist_ok=True)
            return candidate

        if subdir:
            output_path = self.output_dir / subdir / candidate
        else:
            output_path = self.output_dir / candidate

        # Validate BEFORE touching the filesystem.
        self._check_path_traversal(output_path, self.output_dir, "output_dir")
        if subdir:
            output_path.parent.mkdir(parents=True, exist_ok=True)
        return output_path

    def get_working_path(self, filename: str) -> Path:
        """
        Get full path for a temporary working file.

        Unlike the input/output accessors, every filename (absolute ones
        included) must end up inside ``working_dir``.

        Args:
            filename: Working filename

        Returns:
            Full path to working file

        Raises:
            ValidationError: If ``working_dir`` is not set, or if the
                resulting path escapes working_dir.

        Example:
            >>> temp_path = ctx.get_working_path("temp_mask.nii.gz")
        """
        if self.working_dir is None:
            raise ValidationError("working_dir is not set")
        result = self.working_dir / filename
        self._check_path_traversal(result, self.working_dir, "working_dir")
        return result

    def cleanup_working_dir(self) -> None:
        """
        Remove the working directory tree and recreate it empty.

        Use with caution - this deletes temporary files. The directory is
        only removed and recreated when it currently exists.

        Raises:
            ValidationError: If ``working_dir`` is not set.
        """
        import shutil

        if self.working_dir is None:
            raise ValidationError("working_dir is not set")
        if self.working_dir.exists():
            shutil.rmtree(self.working_dir)
            self.working_dir.mkdir(parents=True, exist_ok=True)
            logger.info(f"Cleaned up working directory: {self.working_dir}")

    def list_input_files(self, pattern: str = "*") -> list[Path]:
        """
        List entries in the input directory.

        Args:
            pattern: Glob pattern for matching entries

        Returns:
            Sorted list of matching paths; empty if ``input_dir`` is unset
            or does not exist

        Example:
            >>> nii_files = ctx.list_input_files("*.nii.gz")
        """
        if self.input_dir is None or not self.input_dir.exists():
            return []
        return sorted(self.input_dir.glob(pattern))

    def list_output_files(self, pattern: str = "*") -> list[Path]:
        """
        List entries in the output directory.

        Args:
            pattern: Glob pattern for matching entries

        Returns:
            Sorted list of matching paths; empty if ``output_dir`` is unset
            or does not exist
        """
        if self.output_dir is None or not self.output_dir.exists():
            return []
        return sorted(self.output_dir.glob(pattern))

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert context to dictionary.

        Directories are rendered with ``str()``. ``metadata`` is a shallow
        copy. Top-level ``Path`` values in ``results`` are rendered as
        strings; all other result values are passed through unchanged.
        ``config`` and ``data_dir`` are not included.

        Returns:
            Dictionary representation of context
        """
        return {
            "patient_id": self.patient_id,
            "input_dir": str(self.input_dir),
            "output_dir": str(self.output_dir),
            "working_dir": str(self.working_dir),
            "metadata": self.metadata.copy(),
            "results": {k: str(v) if isinstance(v, Path) else v for k, v in self.results.items()},
        }

    def __repr__(self) -> str:
        """Short representation: patient ID, output directory, result count."""
        return (
            f"ProcessingContext(patient_id='{self.patient_id}', "
            f"output_dir='{self.output_dir}', "
            f"results={len(self.results)})"
        )
def create_context(
    patient_id: str, config: PipelineConfig, data_dir: Optional[Path] = None, **kwargs
) -> ProcessingContext:
    """
    Build a :class:`ProcessingContext` for one patient.

    Args:
        patient_id: Patient identifier
        config: Configuration object
        data_dir: Shared assets base (templates, atlases, ROIs). When omitted,
            ``config.paths.assets_dir`` is used instead.
        **kwargs: Forwarded unchanged to ``ProcessingContext`` (for example
            ``input_dir``, ``output_dir``, ``working_dir``)

    Returns:
        Initialized ProcessingContext

    Example:
        >>> config = load_config("default")
        >>> ctx = create_context("DTI_LDF001", config)
    """
    # Only the assets base is decided here; the input and output roots are
    # resolved independently by the context's ``__post_init__``.
    assets_base = Path(config.paths.assets_dir if data_dir is None else data_dir)
    ctx = ProcessingContext(
        patient_id=patient_id,
        config=config,
        data_dir=assets_base,
        **kwargs,
    )
    logger.info(f"Created processing context for patient: {patient_id}")
    return ctx