# Source code for thesis.core.context

"""
Processing context for the thesis framework.

Provides a context object that carries information about the current
processing state, patient data, paths, and configuration through
the pipeline.
"""

from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Optional, Union

from thesis.core.config import PipelineConfig
from thesis.core.exceptions import ValidationError
from thesis.core.logging import get_logger
from thesis.core.utils import to_path

# Module-level logger, named after this module via the project's logging helper.
logger = get_logger(__name__)

# Names exported by a star-import of this module (the public API).
__all__ = ["ProcessingContext", "create_context"]


@dataclass
class ProcessingContext:
    """
    Context object for medical imaging processing pipelines.

    Carries all relevant information about a processing run, including
    patient ID, data paths, configuration, and intermediate results.

    Attributes:
        patient_id: Unique patient identifier
        config: Configuration for this processing run
        data_dir: Shared assets base (config ``paths.assets_dir``): templates,
            atlases, ROIs. Not a parent of input_dir/output_dir.
        input_dir: Per-patient input data directory
            (defaults to ``inputs_dir/<patient_id>``)
        output_dir: Per-patient output directory
            (defaults to ``output_dir/<patient_id>``)
        working_dir: Temporary working/scratch directory (defaults to
            ``scratch_dir/<patient_id>`` when a scratch dir is configured,
            otherwise ``<output_dir>/temp``)
        metadata: Additional metadata
        results: Dictionary to store intermediate results

    Example:
        >>> ctx = ProcessingContext(
        ...     patient_id="DTI_LDF001",
        ...     config=my_config,
        ...     data_dir=Path("./data")
        ... )
        >>> ctx.add_result("registration", transform_matrix)
        >>> transform = ctx.get_result("registration")
    """

    patient_id: str
    config: PipelineConfig
    data_dir: Path
    input_dir: Optional[Path] = None
    output_dir: Optional[Path] = None
    working_dir: Optional[Path] = None
    metadata: Dict[str, Any] = field(default_factory=dict)
    results: Dict[str, Any] = field(default_factory=dict)

    @staticmethod
    def _resolve_project(base: Union[Path, str]) -> Path:
        """Resolve a project-level path relative to the current working directory.

        The input/output/scratch roots are independent, project-level
        directories. Absolute paths are returned unchanged. Relative paths are
        anchored at the CWD at context-creation time, so they remain stable
        even if the process working directory changes later (e.g. during
        Nipype execution).
        """
        p = to_path(base)
        return p if p.is_absolute() else Path.cwd() / p

    def __post_init__(self) -> None:
        """Resolve per-patient directories and create output/working dirs.

        ``data_dir`` is the resolved assets base (config ``paths.assets_dir``).
        It anchors DataFile/DataDir lookups but is NOT a parent of the input
        or output roots, which are independent.

        Explicitly supplied ``input_dir``/``output_dir``/``working_dir`` values
        are only normalised with ``to_path``; they are not re-anchored at CWD.
        """
        self.data_dir = to_path(self.data_dir)

        if self.input_dir is None:
            self.input_dir = (
                self._resolve_project(self.config.paths.inputs_dir) / self.patient_id
            )
        else:
            self.input_dir = to_path(self.input_dir)

        if self.output_dir is None:
            self.output_dir = (
                self._resolve_project(self.config.paths.output_dir) / self.patient_id
            )
        else:
            self.output_dir = to_path(self.output_dir)

        if self.working_dir is None:
            if self.config.paths.scratch_dir:
                self.working_dir = (
                    self._resolve_project(self.config.paths.scratch_dir)
                    / self.patient_id
                )
            else:
                # No scratch root configured: keep temporaries next to outputs.
                self.working_dir = self.output_dir / "temp"
        else:
            self.working_dir = to_path(self.working_dir)

        # Create output directories (the input directory is never created).
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.working_dir.mkdir(parents=True, exist_ok=True)

        logger.debug(f"Created processing context for patient: {self.patient_id}")

    def add_result(self, key: str, value: Any) -> None:
        """
        Store a processing result, overwriting any existing value for *key*.

        Args:
            key: Result identifier
            value: Result data

        Example:
            >>> ctx.add_result("brain_mask", mask_path)
        """
        self.results[key] = value
        logger.debug(f"Added result to context: {key}")

    def get_result(self, key: str, default: Any = None) -> Any:
        """
        Retrieve a processing result.

        Args:
            key: Result identifier
            default: Default value if key not found

        Returns:
            Result data or default

        Example:
            >>> mask_path = ctx.get_result("brain_mask")
        """
        return self.results.get(key, default)

    def has_result(self, key: str) -> bool:
        """
        Check if a result exists.

        Args:
            key: Result identifier

        Returns:
            True if result exists, False otherwise
        """
        return key in self.results

    def add_metadata(self, key: str, value: Any) -> None:
        """
        Add metadata to the context, overwriting any existing value for *key*.

        Args:
            key: Metadata key
            value: Metadata value

        Example:
            >>> ctx.add_metadata("acquisition_date", "2025-01-15")
        """
        self.metadata[key] = value
        logger.debug(f"Added metadata to context: {key}")

    def get_metadata(self, key: str, default: Any = None) -> Any:
        """
        Retrieve metadata.

        Args:
            key: Metadata key
            default: Default value if key not found

        Returns:
            Metadata value or default
        """
        return self.metadata.get(key, default)

    @staticmethod
    def _check_path_traversal(resolved: Path, base_dir: Path, label: str) -> None:
        """Raise ValidationError if *resolved* escapes *base_dir*.

        Both paths are passed through ``Path.resolve()`` (``..`` segments and
        existing symlinks collapsed) before the containment check. That means
        ``base/../x`` and absolute paths outside *base_dir* are rejected.

        Raises:
            ValidationError: If *resolved* does not lie under *base_dir*.
        """
        try:
            resolved.resolve().relative_to(base_dir.resolve())
        except ValueError:
            # relative_to's own ValueError is an implementation detail; report
            # a single domain error instead of a chained traceback.
            raise ValidationError(
                f"Path traversal detected: '{resolved}' escapes {label} '{base_dir}'"
            ) from None

    def get_input_path(self, filename: Union[str, Path]) -> Path:
        """
        Get full path for an input file.

        Args:
            filename: Input filename, relative to input_dir, or an absolute path

        Returns:
            Full path to input file (absolute *filename* is returned as-is)

        Raises:
            ValidationError: If input_dir is not set, or a *relative* filename
                escapes input_dir.

        Example:
            >>> t1_path = ctx.get_input_path("DTI_LDF001_T1.nii.gz")
        """
        if self.input_dir is None:
            raise ValidationError("input_dir is not set")
        candidate = to_path(filename)
        # An absolute path is an explicit, deliberate location (e.g. a config
        # override pointing at a prior trial's outputs) — honour it as-is. The
        # traversal guard exists to catch RELATIVE escapes ("../.."), so apply
        # it only when anchoring a relative filename under input_dir.
        if candidate.is_absolute():
            return candidate
        result = self.input_dir / candidate
        self._check_path_traversal(result, self.input_dir, "input_dir")
        return result

    def get_output_path(
        self, filename: Union[str, Path], subdir: Optional[Union[str, Path]] = None
    ) -> Path:
        """
        Get full path for an output file.

        For an absolute *filename*, its parent directory is created and the
        path is returned as-is (*subdir* is ignored). For a relative
        *filename* with a *subdir*, the file's parent directory is created,
        but only after the path has been validated.

        Args:
            filename: Output filename, relative to output_dir, or an absolute path
            subdir: Optional subdirectory within output_dir

        Returns:
            Full path to output file

        Raises:
            ValidationError: If output_dir is not set, or a *relative* path
                (filename and/or subdir) escapes output_dir.

        Example:
            >>> reg_path = ctx.get_output_path("T1_registered.nii.gz", "registration")
        """
        if self.output_dir is None:
            raise ValidationError("output_dir is not set")
        # Absolute filename → explicit location; honour as-is (the guard only
        # catches relative escapes). See get_input_path for the rationale.
        candidate = to_path(filename)
        if candidate.is_absolute():
            candidate.parent.mkdir(parents=True, exist_ok=True)
            return candidate

        if subdir:
            output_path = self.output_dir / subdir / candidate
        else:
            output_path = self.output_dir / candidate

        # Validate BEFORE touching the filesystem. Otherwise a hostile subdir
        # ("../../x" or an absolute path) would create directories outside
        # output_dir before the error is raised.
        self._check_path_traversal(output_path, self.output_dir, "output_dir")

        if subdir:
            output_path.parent.mkdir(parents=True, exist_ok=True)
        return output_path

    def get_working_path(self, filename: Union[str, Path]) -> Path:
        """
        Get full path for a temporary working file.

        Unlike get_input_path/get_output_path, absolute filenames are not
        honoured here. An absolute path outside working_dir fails the
        traversal check.

        Args:
            filename: Working filename, relative to working_dir

        Returns:
            Full path to working file

        Raises:
            ValidationError: If working_dir is not set, or *filename* escapes
                working_dir.

        Example:
            >>> temp_path = ctx.get_working_path("temp_mask.nii.gz")
        """
        if self.working_dir is None:
            raise ValidationError("working_dir is not set")
        result = self.working_dir / filename
        self._check_path_traversal(result, self.working_dir, "working_dir")
        return result

    def cleanup_working_dir(self) -> None:
        """
        Delete the working directory tree and recreate it empty.

        Use with caution: everything under working_dir (files and
        subdirectories) is removed with ``shutil.rmtree``.

        Raises:
            ValidationError: If working_dir is not set.
        """
        import shutil

        if self.working_dir is None:
            raise ValidationError("working_dir is not set")
        if self.working_dir.exists():
            shutil.rmtree(self.working_dir)
        self.working_dir.mkdir(parents=True, exist_ok=True)
        logger.info(f"Cleaned up working directory: {self.working_dir}")

    def list_input_files(self, pattern: str = "*") -> list[Path]:
        """
        List files in the input directory.

        Args:
            pattern: Glob pattern for matching files

        Returns:
            Sorted list of matching paths; empty if input_dir is unset or missing

        Example:
            >>> nii_files = ctx.list_input_files("*.nii.gz")
        """
        if self.input_dir is None or not self.input_dir.exists():
            return []
        return sorted(self.input_dir.glob(pattern))

    def list_output_files(self, pattern: str = "*") -> list[Path]:
        """
        List files in the output directory.

        Args:
            pattern: Glob pattern for matching files

        Returns:
            Sorted list of matching paths; empty if output_dir is unset or missing
        """
        if self.output_dir is None or not self.output_dir.exists():
            return []
        return sorted(self.output_dir.glob(pattern))

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert context to dictionary.

        Directory paths and any ``Path``-valued results are stringified.
        Metadata is shallow-copied. ``config`` and ``data_dir`` are not
        included.

        Returns:
            Dictionary representation of context
        """
        return {
            "patient_id": self.patient_id,
            "input_dir": str(self.input_dir),
            "output_dir": str(self.output_dir),
            "working_dir": str(self.working_dir),
            "metadata": self.metadata.copy(),
            "results": {
                k: str(v) if isinstance(v, Path) else v for k, v in self.results.items()
            },
        }

    def __repr__(self) -> str:
        """String representation (patient, output dir and result count)."""
        return (
            f"ProcessingContext(patient_id='{self.patient_id}', "
            f"output_dir='{self.output_dir}', "
            f"results={len(self.results)})"
        )
def create_context(
    patient_id: str, config: PipelineConfig, data_dir: Optional[Path] = None, **kwargs
) -> ProcessingContext:
    """
    Build a :class:`ProcessingContext` for a single patient.

    Args:
        patient_id: Patient identifier
        config: Configuration object
        data_dir: Shared assets base (templates, atlases, ROIs). When omitted,
            ``config.paths.assets_dir`` is used.
        **kwargs: Extra keyword arguments forwarded unchanged to
            ProcessingContext (e.g. ``input_dir``, ``output_dir``,
            ``working_dir``, ``metadata``)

    Returns:
        Initialized ProcessingContext

    Example:
        >>> config = load_config("default")
        >>> ctx = create_context("DTI_LDF001", config)
    """
    # Only the assets base is chosen here. The input/output roots are
    # resolved independently by ProcessingContext.__post_init__.
    assets_base = config.paths.assets_dir if data_dir is None else data_dir

    ctx = ProcessingContext(
        patient_id=patient_id,
        config=config,
        data_dir=Path(assets_base),
        **kwargs,
    )
    logger.info(f"Created processing context for patient: {patient_id}")
    return ctx