"""MRtrix3 path resolution and preparation.
Resolves DWI image, gradient table, brain mask, T1, and per-patient output
locations from the pipeline config. Reuses the HCP path conventions since
the upstream data layout is identical (HCP-style preprocessed subjects);
only the tractography backend differs.
"""
from pathlib import Path
from typing import TypedDict
from thesis.core.config import PipelineConfig
from thesis.core.context import ProcessingContext
from thesis.core.logging import get_logger
from thesis.core.utils import resolve_path
from ...hcp.config.values import resolve_hcp_value
from ..common import format_patient_path, resolve_t1_path, resolve_with_fallback
from ..config.values import resolve_mrtrix3_value
logger = get_logger(__name__)
[docs]
class MRtrix3Paths(TypedDict):
    """Mapping of every filesystem location the MRtrix3 workflow reads.

    Produced by ``prepare_mrtrix3_paths``; all values are ``Path`` objects.
    """

    # Patient input root directory.
    input_dir: Path
    # Directory expected to hold the preprocessed diffusion data.
    diffusion_dir: Path
    # Diffusion-weighted image volume.
    dwi_image: Path
    # Gradient direction table.
    bvec: Path
    # b-value table.
    bval: Path
    # Brain mask image.
    mask_path: Path
    # T1-weighted anatomical image.
    t1_path: Path
# Default filenames looked up inside the diffusion directory when the
# configuration does not override them (HCP preprocessed naming).
DEFAULT_DWI_NAME = "data.nii.gz"  # diffusion-weighted image
DEFAULT_BVEC_NAME = "bvecs"  # gradient directions
DEFAULT_BVAL_NAME = "bvals"  # b-values
DEFAULT_MASK_NAME = "nodif_brain_mask.nii.gz"  # brain mask
def _resolve_diffusion_file(
    diffusion_dir: Path,
    input_dir: Path,
    context: ProcessingContext,
    primary_name: str,
    fallback_names: tuple[str, ...] = (),
) -> Path:
    """Locate a file inside ``diffusion_dir`` with optional alternative names.

    Candidate names are tried in order, ``primary_name`` first and then
    ``fallback_names``. For each name, ``diffusion_dir / name`` is checked
    first; if it is missing, ``resolve_with_fallback`` is asked to find the
    name using ``input_dir`` and the context's ``data_dir`` (when set) as
    fallback bases. Useful for DWI data where filenames vary slightly between
    datasets (``data.nii.gz`` vs ``data.nii``).

    A warning is logged whenever the returned file is not the preferred
    ``diffusion_dir / primary_name``, so a substituted input shows up in the
    logs instead of being picked silently.

    Args:
        diffusion_dir: Primary directory expected to contain the file.
        input_dir: Patient input root, used as a fallback search base.
        context: Processing context (for ``data_dir`` fallback).
        primary_name: Preferred filename.
        fallback_names: Additional filenames to try in order.

    Returns:
        The first existing candidate. If none exists, returns
        ``diffusion_dir / primary_name``, which may not exist; the verifier
        will surface that as an error.
    """
    preferred = diffusion_dir / primary_name
    # dict.fromkeys de-duplicates while keeping order, so a configured name
    # that repeats one of the built-in fallbacks is only probed once.
    candidates = dict.fromkeys((primary_name, *fallback_names))
    # ``data_dir`` may be unset on the context; never hand ``None`` to the
    # fallback resolver as a search root.
    fallback_bases = [
        base for base in (input_dir, getattr(context, "data_dir", None)) if base is not None
    ]
    for name in candidates:
        found = diffusion_dir / name
        if not found.exists():
            found = resolve_with_fallback(name, diffusion_dir, fallback_bases)
            if not found.exists():
                continue
        if found != preferred:
            logger.warning("Diffusion file {} not found, using {}", preferred, found)
        return found
    return preferred
[docs]
def prepare_mrtrix3_paths(config: PipelineConfig, context: ProcessingContext) -> MRtrix3Paths:
    """Resolve filesystem paths required by the MRtrix3 workflow.

    The diffusion directory comes from the HCP ``diffusion_dir`` setting
    (default ``T1w/Diffusion``, formatted with the patient id), resolved
    relative to ``context.input_dir`` or to the current working directory
    when no input directory is set. DWI, bvec and bval filenames are
    resolved via ``resolve_mrtrix3_value``, each with built-in alternative
    names. The brain mask follows the HCP ``mask_path``/``mask_name``
    settings.

    Args:
        config: Pipeline configuration.
        context: Per-patient processing context.

    Returns:
        Mapping of resolved paths. A returned path may not exist; missing
        inputs are left for the verifier to report.
    """
    diffusion_dirname = format_patient_path(
        resolve_hcp_value(config, "diffusion_dir", "T1w/Diffusion"), context.patient_id
    )
    input_dir = (
        Path(context.input_dir).resolve() if context.input_dir is not None else Path(".").resolve()
    )
    diffusion_dir = resolve_path(input_dir, diffusion_dirname)

    dwi_image = _resolve_diffusion_file(
        diffusion_dir,
        input_dir,
        context,
        resolve_mrtrix3_value(config, "dwi_name", DEFAULT_DWI_NAME),
        fallback_names=("data.nii", "dwi.nii.gz", "dwi.nii"),
    )
    bvec = _resolve_diffusion_file(
        diffusion_dir,
        input_dir,
        context,
        resolve_mrtrix3_value(config, "bvec_name", DEFAULT_BVEC_NAME),
        fallback_names=("bvec", "bvecs.txt"),
    )
    bval = _resolve_diffusion_file(
        diffusion_dir,
        input_dir,
        context,
        resolve_mrtrix3_value(config, "bval_name", DEFAULT_BVAL_NAME),
        fallback_names=("bval", "bvals.txt"),
    )
    mask_path = _resolve_mask_path(config, context, input_dir, diffusion_dir)

    return {
        "input_dir": input_dir,
        "diffusion_dir": diffusion_dir,
        "dwi_image": dwi_image,
        "bvec": bvec,
        "bval": bval,
        "mask_path": mask_path,
        "t1_path": resolve_t1_path(config, context),
    }


def _resolve_mask_path(
    config: PipelineConfig,
    context: ProcessingContext,
    input_dir: Path,
    diffusion_dir: Path,
) -> Path:
    """Resolve the brain-mask path from the HCP ``mask_path``/``mask_name`` settings.

    An explicit ``mask_path`` (formatted with the patient id) is resolved
    against ``input_dir`` via ``resolve_with_fallback``, with the context's
    output and data directories as fallback bases. Otherwise
    ``diffusion_dir / mask_name`` is used; if that file is missing,
    ``mask_name`` is looked up directly under ``input_dir`` and then the
    context's ``data_dir``. If nothing is found, the diffusion-dir path is
    returned unchanged.
    """
    # ``data_dir`` may be absent/unset on the context; treat both as "no base".
    data_dir = getattr(context, "data_dir", None)
    mask_name = resolve_hcp_value(config, "mask_name", DEFAULT_MASK_NAME)
    mask_path_cfg = resolve_hcp_value(config, "mask_path")

    if mask_path_cfg:
        mask_path_cfg = format_patient_path(mask_path_cfg, context.patient_id)
        bases = [base for base in (context.output_dir, data_dir) if base is not None]
        return resolve_with_fallback(mask_path_cfg, input_dir, bases)

    mask_path = diffusion_dir / mask_name
    if mask_path.exists():
        return mask_path
    for base in (input_dir, data_dir):
        if base is None:
            continue
        candidate = Path(base).resolve() / mask_name
        if candidate.exists():
            logger.warning(
                "Brain mask not found at {}, using fallback {}",
                mask_path,
                candidate,
            )
            return candidate
    return mask_path