"""Input/Output utilities for the thesis framework."""
from pathlib import Path
from types import ModuleType
from typing import TYPE_CHECKING, List, Literal, Optional, Union, cast, overload
import numpy as np
from thesis.core.exceptions import DependencyError
from thesis.core.logging import get_logger
logger = get_logger(__name__)
if TYPE_CHECKING:
from nibabel.spatialimages import SpatialImage
# Public API of this module: the names exported by a star-import.
# NOTE(review): "copy_nifti_metadata" is listed here but its definition is not
# visible in this part of the file - confirm it is defined further down.
__all__ = [
    "load_nifti",
    "save_nifti",
    "load_bvals",
    "load_bvecs",
    "save_bvals",
    "save_bvecs",
    "check_file_exists",
    "ensure_directory",
    "get_file_info",
    "find_files",
    "copy_nifti_metadata",
]
# Optional dependency: try to import nibabel.
# ``nib_module`` holds the imported module (or None) and ``NIBABEL_AVAILABLE``
# records whether the import succeeded; the NIfTI helpers in this module check
# these flags before doing any NIfTI I/O.
try:
    import nibabel as nib
    nib_module: Optional[ModuleType] = nib
    NIBABEL_AVAILABLE = True
except ImportError:
    # Importing this module still succeeds without nibabel; only a warning is
    # logged here.
    nib_module = None
    NIBABEL_AVAILABLE = False
    logger.warning("nibabel not available - NIfTI I/O will not work")
# Typing-only overloads: the declared return type of ``load_nifti`` follows the
# literal value of ``as_array`` (True -> ndarray, False -> nibabel image).
@overload
def load_nifti(file_path: Union[str, Path], as_array: Literal[True] = True) -> np.ndarray: ...
@overload
def load_nifti(file_path: Union[str, Path], as_array: Literal[False]) -> "SpatialImage": ...
[docs]
def load_nifti(
    file_path: Union[str, Path], as_array: bool = True
) -> Union[np.ndarray, "SpatialImage"]:
    """
    Load a NIfTI image file.

    Args:
        file_path: Path to NIfTI file (.nii or .nii.gz)
        as_array: If True, returns the image data as a numpy array (the
            result of the image's ``get_fdata()``); if False, returns the
            nibabel image object itself

    Returns:
        Numpy array or nibabel image object

    Raises:
        DependencyError: If nibabel is not available
        FileNotFoundError: If file doesn't exist

    Example:
        >>> data = load_nifti("T1.nii.gz")
        >>> print(data.shape)
        (256, 256, 256)
    """
    # Single dependency guard: both module-level markers must agree that
    # nibabel is usable before any file access is attempted.
    if not NIBABEL_AVAILABLE or nib_module is None:
        raise DependencyError(
            "nibabel is required for NIfTI I/O. Install with: pip install nibabel"
        )

    nifti_path = Path(file_path)
    if not nifti_path.exists():
        raise FileNotFoundError(f"NIfTI file not found: {nifti_path}")

    try:
        img = nib_module.load(str(nifti_path))
        shape = img.shape  # type: ignore[attr-defined]
        logger.debug(f"Loaded NIfTI: {nifti_path} (shape: {shape})")
        if as_array:
            return cast(np.ndarray, img.get_fdata())  # type: ignore[attr-defined]
        return cast("SpatialImage", img)
    except Exception as e:
        # Log for context, then let the original exception propagate unchanged.
        logger.error(f"Error loading NIfTI file {nifti_path}: {e}")
        raise
[docs]
def save_nifti(
    data: Union[np.ndarray, "SpatialImage"],
    file_path: Union[str, Path],
    affine: Optional[np.ndarray] = None,
    header: Optional[object] = None,
) -> Path:
    """
    Save data as a NIfTI image file.

    Args:
        data: Numpy array or nibabel image to save. Anything exposing a
            ``get_fdata`` attribute is treated as an image and saved as-is.
        file_path: Path where to save the file; missing parent directories
            are created
        affine: Affine transformation matrix (required if data is array)
        header: Optional NIfTI header (only used when data is an array)

    Returns:
        Path to saved file

    Raises:
        DependencyError: If nibabel is not available
        ValueError: If affine is missing for array data

    Example:
        >>> data = np.random.rand(64, 64, 64)
        >>> affine = np.eye(4)
        >>> save_nifti(data, "output.nii.gz", affine=affine)
    """
    # Single dependency guard covering both image construction and saving.
    if not NIBABEL_AVAILABLE or nib_module is None:
        raise DependencyError(
            "nibabel is required for NIfTI I/O. Install with: pip install nibabel"
        )

    if hasattr(data, "get_fdata"):
        # Duck-typed as an existing image: affine/header arguments are not used.
        img = data
    else:
        if affine is None:
            raise ValueError("affine matrix is required when saving array data")
        img = nib_module.Nifti1Image(data, affine, header=header)

    # Touch the filesystem only after the inputs have been validated, so a
    # rejected call does not leave an empty output directory behind.
    out_path = Path(file_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    try:
        nib_module.save(img, str(out_path))  # type: ignore[arg-type]
        logger.debug(f"Saved NIfTI: {out_path} (shape: {img.shape})")
        return out_path
    except Exception as e:
        # Log for context, then let the original exception propagate unchanged.
        logger.error(f"Error saving NIfTI file {out_path}: {e}")
        raise
[docs]
def load_bvals(file_path: Union[str, Path]) -> np.ndarray:
    """
    Read b-values from a whitespace-delimited text file.

    Args:
        file_path: Location of the .bval file

    Returns:
        1D numpy array of b-values (a single value becomes a length-1 array,
        a 2D table is flattened)

    Raises:
        FileNotFoundError: If the file does not exist

    Example:
        >>> bvals = load_bvals("data.bval")
        >>> print(bvals.shape)
        (32,)
    """
    bval_path = Path(file_path)
    if not bval_path.exists():
        raise FileNotFoundError(f"b-values file not found: {bval_path}")

    try:
        # A lone scalar is promoted to one dimension; 1D/2D input is unchanged here.
        values = np.atleast_1d(np.loadtxt(bval_path))
        if values.ndim == 2:
            values = values.flatten()
        logger.debug(f"Loaded b-values: {bval_path} (n={len(values)})")
        return cast(np.ndarray, values)
    except Exception as e:
        logger.error(f"Error loading b-values from {bval_path}: {e}")
        raise
[docs]
def load_bvecs(file_path: Union[str, Path]) -> np.ndarray:
    """
    Load b-vectors from a text file.

    The file may be laid out as 3 rows by N columns or as N rows by 3
    columns; the latter is transposed. A 3x3 file is returned as read.

    Args:
        file_path: Path to .bvec file

    Returns:
        2D numpy array of shape (3, N) where N is number of directions

    Raises:
        FileNotFoundError: If the file does not exist
        ValueError: If neither dimension of the loaded data has length 3

    Example:
        >>> bvecs = load_bvecs("data.bvec")
        >>> print(bvecs.shape)
        (3, 32)
    """
    bvec_path = Path(file_path)
    if not bvec_path.exists():
        raise FileNotFoundError(f"b-vectors file not found: {bvec_path}")

    try:
        # np.loadtxt squeezes single-row / single-column files down to 1-D
        # (or 0-D); promote to 2-D so the shape checks below are always valid
        # and a single direction comes back as (3, 1) rather than (3,).
        bvecs = np.atleast_2d(np.loadtxt(bvec_path))

        # Ensure shape is (3, N)
        if bvecs.shape[0] != 3:
            if bvecs.shape[1] != 3:
                raise ValueError(f"Invalid b-vectors shape: {bvecs.shape}")
            bvecs = bvecs.T

        logger.debug(f"Loaded b-vectors: {bvec_path} (shape: {bvecs.shape})")
        return cast(np.ndarray, bvecs)
    except Exception as e:
        # Log for context, then let the original exception propagate unchanged.
        logger.error(f"Error loading b-vectors from {bvec_path}: {e}")
        raise
[docs]
def save_bvals(bvals: np.ndarray, file_path: Union[str, Path]) -> Path:
    """
    Write b-values to a text file using integer formatting.

    Missing parent directories of the target are created first.

    Args:
        bvals: 1D array of b-values
        file_path: Destination path

    Returns:
        Path to saved file
    """
    destination = Path(file_path)
    destination.parent.mkdir(exist_ok=True, parents=True)

    try:
        np.savetxt(destination, bvals, fmt="%d")
        logger.debug(f"Saved b-values: {destination}")
        return destination
    except Exception as e:
        logger.error(f"Error saving b-values to {destination}: {e}")
        raise
[docs]
def save_bvecs(bvecs: np.ndarray, file_path: Union[str, Path]) -> Path:
    """
    Save b-vectors to a text file.

    Values are written with six decimal places. The input is validated
    before anything is created on disk, so a rejected call has no
    filesystem side effects.

    Args:
        bvecs: 2D array (or array-like) of shape (3, N)
        file_path: Path where to save; missing parent directories are created

    Returns:
        Path to saved file

    Raises:
        ValueError: If the first dimension of ``bvecs`` is not 3
    """
    # Accept array-likes (e.g. nested lists) as well as ndarrays.
    vectors = np.asarray(bvecs)
    # A 0-D input has no first axis; report it as a shape error instead of
    # failing with IndexError on shape[0].
    if vectors.ndim == 0 or vectors.shape[0] != 3:
        raise ValueError(f"b-vectors must have shape (3, N), got {vectors.shape}")

    out_path = Path(file_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    try:
        np.savetxt(out_path, vectors, fmt="%.6f")
        logger.debug(f"Saved b-vectors: {out_path}")
        return out_path
    except Exception as e:
        # Log for context, then let the original exception propagate unchanged.
        logger.error(f"Error saving b-vectors to {out_path}: {e}")
        raise
[docs]
def check_file_exists(file_path: Union[str, Path], raise_error: bool = True) -> bool:
    """
    Report whether a path refers to an existing regular file.

    Args:
        file_path: Path to check
        raise_error: When True a missing file raises instead of returning False

    Returns:
        True if the file exists; False otherwise (only when raise_error=False)

    Raises:
        FileNotFoundError: If the file doesn't exist and raise_error=True
    """
    candidate = Path(file_path)
    # is_file() is already False for paths that do not exist.
    if candidate.is_file():
        return True
    if raise_error:
        raise FileNotFoundError(f"File not found: {candidate}")
    return False
[docs]
def ensure_directory(dir_path: Union[str, Path], parents: bool = True) -> Path:
    """
    Create a directory if it is missing and return it as a Path.

    An already existing directory is left as it is.

    Args:
        dir_path: Directory path
        parents: Whether missing parent directories are created as well

    Returns:
        Path object for the directory

    Example:
        >>> output_dir = ensure_directory("./results/patient_001")
    """
    target = Path(dir_path)
    target.mkdir(exist_ok=True, parents=parents)
    logger.debug(f"Ensured directory exists: {target}")
    return target
[docs]
def get_file_info(file_path: Union[str, Path]) -> dict[str, object]:
    """
    Get information about a file.

    Args:
        file_path: Path to file

    Returns:
        Dictionary with the keys ``path``, ``name``, ``size_bytes``,
        ``size_mb``, ``modified``, ``is_file``, ``is_dir`` and ``suffix``.
        For paths whose suffix is ``.nii`` or ``.gz`` the keys ``shape``,
        ``dtype`` and ``ndim`` are added when nibabel is available and can
        read the file.

    Raises:
        FileNotFoundError: If the path does not exist

    Example:
        >>> info = get_file_info("data.nii.gz")
        >>> print(info["size_mb"])
    """
    path = Path(file_path)
    if not path.exists():
        raise FileNotFoundError(f"File not found: {path}")

    stat = path.stat()
    info: dict[str, object] = {
        "path": str(path),
        "name": path.name,
        "size_bytes": stat.st_size,
        "size_mb": stat.st_size / (1024 * 1024),
        "modified": stat.st_mtime,
        "is_file": path.is_file(),
        "is_dir": path.is_dir(),
        "suffix": path.suffix,
    }

    # Add NIfTI-specific info if applicable. The suffix test comes first so
    # other files never consult the nibabel markers.
    if (
        path.suffix in (".nii", ".gz")
        and NIBABEL_AVAILABLE
        and nib_module is not None
    ):
        try:
            img = nib_module.load(str(path))
            info["shape"] = img.shape  # type: ignore[attr-defined]
            info["dtype"] = str(img.get_data_dtype())  # type: ignore[attr-defined]
            info["ndim"] = len(img.shape)  # type: ignore[attr-defined]
        except Exception as e:
            # Best-effort: the generic info is still returned, but the reason
            # the image metadata is missing is recorded instead of discarded.
            logger.debug(f"Could not read NIfTI metadata from {path}: {e}")

    return info
[docs]
def find_files(
    directory: Union[str, Path], pattern: str = "*", recursive: bool = False
) -> List[Path]:
    """
    Collect the files under a directory whose names match a glob pattern.

    Args:
        directory: Directory to search
        pattern: Glob pattern to match
        recursive: If True, subdirectories are searched as well

    Returns:
        Sorted list of matching paths that are files; an empty list when
        the directory does not exist (a warning is logged in that case)

    Example:
        >>> nii_files = find_files("./data", "*.nii.gz", recursive=True)
    """
    root = Path(directory)
    if not root.exists():
        logger.warning(f"Directory not found: {root}")
        return []

    candidates = root.rglob(pattern) if recursive else root.glob(pattern)
    # Directories and other non-file entries are dropped from the result.
    matches = sorted(entry for entry in candidates if entry.is_file())

    logger.debug(f"Found {len(matches)} files matching '{pattern}' in {root}")
    return matches