# Source code for thesis.core.io

"""Input/Output utilities for the thesis framework."""

from pathlib import Path
from types import ModuleType
from typing import TYPE_CHECKING, List, Literal, Optional, Union, cast, overload

import numpy as np

from thesis.core.exceptions import DependencyError
from thesis.core.logging import get_logger

logger = get_logger(__name__)

if TYPE_CHECKING:
    from nibabel.spatialimages import SpatialImage

__all__ = [
    "load_nifti",
    "save_nifti",
    "load_bvals",
    "load_bvecs",
    "save_bvals",
    "save_bvecs",
    "check_file_exists",
    "ensure_directory",
    "get_file_info",
    "find_files",
    "copy_nifti_metadata",
]

# nibabel is an optional dependency. Import it if it is installed and record
# the outcome in two module-level names:
#   - nib_module: the imported module, or None when the import failed
#   - NIBABEL_AVAILABLE: boolean flag mirroring the same outcome
# A warning is logged once at import time when nibabel is missing.
try:
    import nibabel as nib

    nib_module: Optional[ModuleType] = nib
    NIBABEL_AVAILABLE = True
except ImportError:
    nib_module = None
    NIBABEL_AVAILABLE = False
    logger.warning("nibabel not available - NIfTI I/O will not work")


@overload
def load_nifti(file_path: Union[str, Path], as_array: Literal[True] = True) -> np.ndarray: ...


@overload
def load_nifti(file_path: Union[str, Path], as_array: Literal[False]) -> "SpatialImage": ...


def load_nifti(
    file_path: Union[str, Path], as_array: bool = True
) -> Union[np.ndarray, "SpatialImage"]:
    """
    Load a NIfTI image file.

    Args:
        file_path: Path to NIfTI file (.nii or .nii.gz)
        as_array: If True, returns the result of the image's ``get_fdata()``;
            if False, returns the nibabel image object itself

    Returns:
        Numpy array or nibabel image object

    Raises:
        DependencyError: If nibabel is not installed
        FileNotFoundError: If file doesn't exist

    Example:
        >>> data = load_nifti("T1.nii.gz")
        >>> print(data.shape)
        (256, 256, 256)
    """
    # Single guard: the flag check keeps the runtime contract, the ``None``
    # check narrows ``nib_module`` for static type checkers.
    if not NIBABEL_AVAILABLE or nib_module is None:
        raise DependencyError(
            "nibabel is required for NIfTI I/O. Install with: pip install nibabel"
        )

    file_path = Path(file_path)

    if not file_path.exists():
        raise FileNotFoundError(f"NIfTI file not found: {file_path}")

    try:
        img = nib_module.load(str(file_path))
        shape = img.shape  # type: ignore[attr-defined]
        logger.debug(f"Loaded NIfTI: {file_path} (shape: {shape})")

        if as_array:
            return cast(np.ndarray, img.get_fdata())  # type: ignore[attr-defined]
        return cast("SpatialImage", img)

    except Exception as e:
        # Log for context, then let the original exception propagate.
        logger.error(f"Error loading NIfTI file {file_path}: {e}")
        raise
def save_nifti(
    data: Union[np.ndarray, "SpatialImage"],
    file_path: Union[str, Path],
    affine: Optional[np.ndarray] = None,
    header: Optional[object] = None,
) -> Path:
    """
    Save data as a NIfTI image file.

    Args:
        data: Numpy array or nibabel image to save. Any object exposing a
            ``get_fdata`` attribute is treated as an image and saved as-is;
            ``affine`` and ``header`` are not used in that case.
        file_path: Path where to save the file
        affine: Affine transformation matrix (required if data is array)
        header: Optional NIfTI header (used only when data is an array)

    Returns:
        Path to saved file

    Raises:
        DependencyError: If nibabel is not installed
        ValueError: If affine is missing for array data

    Example:
        >>> data = np.random.rand(64, 64, 64)
        >>> affine = np.eye(4)
        >>> save_nifti(data, "output.nii.gz", affine=affine)
    """
    # Single guard: the flag check keeps the runtime contract, the ``None``
    # check narrows ``nib_module`` for static type checkers.
    if not NIBABEL_AVAILABLE or nib_module is None:
        raise DependencyError(
            "nibabel is required for NIfTI I/O. Install with: pip install nibabel"
        )

    file_path = Path(file_path)

    if hasattr(data, "get_fdata"):
        # Already an image object: write it out unchanged.
        img = data
    else:
        if affine is None:
            raise ValueError("affine matrix is required when saving array data")
        img = nib_module.Nifti1Image(data, affine, header=header)

    # Create the output directory only after the inputs have been validated,
    # so a rejected call leaves no stray directories behind.
    file_path.parent.mkdir(parents=True, exist_ok=True)

    try:
        nib_module.save(img, str(file_path))  # type: ignore[arg-type]
        logger.debug(f"Saved NIfTI: {file_path} (shape: {img.shape})")
        return file_path

    except Exception as e:
        # Log for context, then let the original exception propagate.
        logger.error(f"Error saving NIfTI file {file_path}: {e}")
        raise
def load_bvals(file_path: Union[str, Path]) -> np.ndarray:
    """
    Read b-values from a whitespace-delimited text file.

    Args:
        file_path: Location of the .bval file

    Returns:
        1D numpy array of b-values

    Raises:
        FileNotFoundError: If the file does not exist

    Example:
        >>> bvals = load_bvals("data.bval")
        >>> print(bvals.shape)
        (32,)
    """
    bval_path = Path(file_path)
    if not bval_path.exists():
        raise FileNotFoundError(f"b-values file not found: {bval_path}")

    try:
        values = np.loadtxt(bval_path)

        # Normalise to one dimension: a 2-D table is flattened and a lone
        # scalar is wrapped into a single-element array.
        if values.ndim == 2:
            values = values.flatten()
        elif values.ndim == 0:
            values = np.array([values])

        logger.debug(f"Loaded b-values: {bval_path} (n={len(values)})")
        return cast(np.ndarray, values)
    except Exception as exc:
        # Log for context, then let the original exception propagate.
        logger.error(f"Error loading b-values from {bval_path}: {exc}")
        raise
def load_bvecs(file_path: Union[str, Path]) -> np.ndarray:
    """
    Load b-vectors from a text file.

    The file may store the vectors as 3 rows of N values or as N rows of
    3 values; the result is always oriented as (3, N). A 3x3 table is
    returned as read, since its orientation cannot be inferred.

    Args:
        file_path: Path to .bvec file

    Returns:
        2D numpy array of shape (3, N) where N is number of directions

    Raises:
        FileNotFoundError: If the file does not exist
        ValueError: If neither dimension of the loaded table has length 3

    Example:
        >>> bvecs = load_bvecs("data.bvec")
        >>> print(bvecs.shape)
        (3, 32)
    """
    file_path = Path(file_path)

    if not file_path.exists():
        raise FileNotFoundError(f"b-vectors file not found: {file_path}")

    try:
        # ndmin=2 keeps the result two-dimensional even for a file holding a
        # single direction, so the shape checks below never index past the
        # end of ``shape`` and the return value is always (3, N).
        bvecs = np.loadtxt(file_path, ndmin=2)

        # Ensure shape is (3, N)
        if bvecs.shape[0] != 3:
            if bvecs.shape[1] != 3:
                raise ValueError(f"Invalid b-vectors shape: {bvecs.shape}")
            bvecs = bvecs.T

        logger.debug(f"Loaded b-vectors: {file_path} (shape: {bvecs.shape})")
        return cast(np.ndarray, bvecs)

    except Exception as e:
        # Log for context, then let the original exception propagate.
        logger.error(f"Error loading b-vectors from {file_path}: {e}")
        raise
def save_bvals(bvals: np.ndarray, file_path: Union[str, Path]) -> Path:
    """
    Write b-values to a text file, formatted as integers.

    Missing parent directories of the destination are created.

    Args:
        bvals: 1D array of b-values
        file_path: Destination path

    Returns:
        Path to saved file
    """
    destination = Path(file_path)
    destination.parent.mkdir(parents=True, exist_ok=True)

    try:
        np.savetxt(destination, bvals, fmt="%d")
        logger.debug(f"Saved b-values: {destination}")
        return destination
    except Exception as exc:
        # Log for context, then let the original exception propagate.
        logger.error(f"Error saving b-values to {destination}: {exc}")
        raise
def save_bvecs(bvecs: np.ndarray, file_path: Union[str, Path]) -> Path:
    """
    Save b-vectors to a text file with six decimal places.

    Args:
        bvecs: 2D array (or array-like) of shape (3, N)
        file_path: Path where to save

    Returns:
        Path to saved file

    Raises:
        ValueError: If the first dimension of ``bvecs`` is not 3
    """
    # Accept array-likes, and reject scalars with the same ValueError as any
    # other bad shape instead of failing on the shape lookup.
    vectors = np.asarray(bvecs)
    if vectors.ndim == 0 or vectors.shape[0] != 3:
        raise ValueError(f"b-vectors must have shape (3, N), got {vectors.shape}")

    # Touch the filesystem only once the input is known to be valid, so a
    # rejected call leaves no stray directories behind.
    file_path = Path(file_path)
    file_path.parent.mkdir(parents=True, exist_ok=True)

    try:
        np.savetxt(file_path, vectors, fmt="%.6f")
        logger.debug(f"Saved b-vectors: {file_path}")
        return file_path
    except Exception as e:
        # Log for context, then let the original exception propagate.
        logger.error(f"Error saving b-vectors to {file_path}: {e}")
        raise
def check_file_exists(file_path: Union[str, Path], raise_error: bool = True) -> bool:
    """
    Report whether a path points at an existing regular file.

    Directories and missing paths both count as "not found".

    Args:
        file_path: Path to check
        raise_error: If True, a missing file raises; if False, it yields False

    Returns:
        True if the file exists, False otherwise (only when raise_error=False)

    Raises:
        FileNotFoundError: If file doesn't exist and raise_error=True
    """
    candidate = Path(file_path)

    if candidate.exists() and candidate.is_file():
        return True

    if raise_error:
        raise FileNotFoundError(f"File not found: {candidate}")
    return False
def ensure_directory(dir_path: Union[str, Path], parents: bool = True) -> Path:
    """
    Create a directory if it is not already there and return its path.

    An already existing directory is not an error.

    Args:
        dir_path: Directory path
        parents: Whether missing parent directories are created too

    Returns:
        The directory as a Path object

    Example:
        >>> output_dir = ensure_directory("./results/patient_001")
    """
    target = Path(dir_path)
    target.mkdir(parents=parents, exist_ok=True)
    logger.debug(f"Ensured directory exists: {target}")
    return target
def get_file_info(file_path: Union[str, Path]) -> dict[str, object]:
    """
    Get information about a file.

    Args:
        file_path: Path to file

    Returns:
        Dictionary with the keys ``path``, ``name``, ``size_bytes``,
        ``size_mb``, ``modified``, ``is_file``, ``is_dir`` and ``suffix``.
        For ``.nii`` / ``.nii.gz`` files, and only when nibabel is available
        and can read the file, ``shape``, ``dtype`` and ``ndim`` are added.

    Raises:
        FileNotFoundError: If the path does not exist

    Example:
        >>> info = get_file_info("data.nii.gz")
        >>> print(info["size_mb"])
    """
    file_path = Path(file_path)

    if not file_path.exists():
        raise FileNotFoundError(f"File not found: {file_path}")

    stat = file_path.stat()

    info: dict[str, object] = {
        "path": str(file_path),
        "name": file_path.name,
        "size_bytes": stat.st_size,
        "size_mb": stat.st_size / (1024 * 1024),
        "modified": stat.st_mtime,
        "is_file": file_path.is_file(),
        "is_dir": file_path.is_dir(),
        "suffix": file_path.suffix,
    }

    # Match on the full file name: Path.suffix of "x.nii.gz" is only ".gz",
    # and testing for that would also send unrelated gzip files to nibabel.
    looks_like_nifti = file_path.name.endswith((".nii", ".nii.gz"))

    if looks_like_nifti and NIBABEL_AVAILABLE and nib_module is not None:
        try:
            img = nib_module.load(str(file_path))
            info["shape"] = img.shape  # type: ignore[attr-defined]
            info["dtype"] = str(img.get_data_dtype())  # type: ignore[attr-defined]
            info["ndim"] = len(img.shape)  # type: ignore[attr-defined]
        except Exception as e:
            # The NIfTI details are best-effort extras: keep the generic
            # information, but leave a trace of why they are missing.
            logger.debug(f"Could not read NIfTI metadata from {file_path}: {e}")

    return info
def find_files(
    directory: Union[str, Path], pattern: str = "*", recursive: bool = False
) -> List[Path]:
    """
    Collect the regular files under a directory that match a glob pattern.

    A missing directory is not an error: a warning is logged and an empty
    list is returned.

    Args:
        directory: Directory to search
        pattern: Glob pattern to match
        recursive: If True, descend into subdirectories as well

    Returns:
        Sorted list of matching file paths (directories are excluded)

    Example:
        >>> nii_files = find_files("./data", "*.nii.gz", recursive=True)
    """
    root = Path(directory)

    if not root.exists():
        logger.warning(f"Directory not found: {root}")
        return []

    # Choose the glob flavour once, then keep only regular files.
    search = root.rglob if recursive else root.glob
    matches = [entry for entry in sorted(search(pattern)) if entry.is_file()]

    logger.debug(f"Found {len(matches)} files matching '{pattern}' in {root}")
    return matches
def copy_nifti_metadata(
    source_path: Union[str, Path], target_data: np.ndarray, target_path: Union[str, Path]
) -> Path:
    """
    Save new data using the affine and header of an existing NIfTI file.

    Handy after processing an image when the result should carry the same
    affine matrix and header information as the input.

    Args:
        source_path: Source NIfTI file (for metadata)
        target_data: New data array to save
        target_path: Where to save the new file

    Returns:
        Path to saved file

    Raises:
        DependencyError: If nibabel is not installed

    Example:
        >>> processed = process_image(data)
        >>> copy_nifti_metadata("input.nii.gz", processed, "output.nii.gz")
    """
    if not NIBABEL_AVAILABLE or nib_module is None:
        raise DependencyError("nibabel is required for NIfTI metadata copy")

    reference = nib_module.load(str(source_path))

    # NOTE(review): the source header is handed over unchanged; whether its
    # stored data type suits ``target_data`` is not checked here - confirm
    # against nibabel's behaviour if the dtype changes during processing.
    derived = nib_module.Nifti1Image(
        target_data,
        reference.affine,  # type: ignore[attr-defined]
        reference.header,  # type: ignore[attr-defined]
    )

    return save_nifti(derived, target_path)