import os
import numpy as np
from silx.io.url import DataUrl
from silx.io import get_data
from tomoscan.esrf.scan.edfscan import EDFTomoScan
from tomoscan.esrf.scan.nxtomoscan import NXtomoScan
from ..utils import check_supported, indices_to_slices
from ..io.reader import EDFStackReader, NXDarksFlats, NXTomoReader
from ..io.utils import get_compacted_dataslices
from .utils import get_values_from_file, is_hdf5_extension
from .logger import LoggerOrPrint
from ..pipeline.utils import nabu_env_settings
[docs]
class DatasetAnalyzer:
_scanner = None
kind = "none"
"""
Base class for datasets analyzers.
"""
def __init__(self, location, extra_options=None, logger=None):
"""
Initialize a Dataset analyzer.
Parameters
----------
location: str
Dataset location (directory or file name)
extra_options: dict, optional
Extra options on how to interpret the dataset.
logger: logging object, optional
Logger. If not set, messages will just be printed in stdout.
"""
self.logger = LoggerOrPrint(logger)
self.location = location
self._set_extra_options(extra_options)
self._get_excluded_projections()
self._set_default_dataset_values()
self._init_dataset_scan()
self._finish_init()
def _set_extra_options(self, extra_options):
if extra_options is None:
extra_options = {}
# COMPAT.
advanced_options = {
"force_flatfield": False,
"output_dir": None,
"exclude_projections": None,
"hdf5_entry": None,
# "nx_version": 1.0,
}
# --
advanced_options.update(extra_options)
self.extra_options = advanced_options
# pylint: disable=E1136
def _get_excluded_projections(self):
excluded_projs = self.extra_options["exclude_projections"]
self._ignore_projections = None
if excluded_projs is None:
return
if excluded_projs["type"] == "angular_range":
excluded_projs["type"] = "range" # compat with tomoscan #pylint: disable=E1137
values = excluded_projs["range"]
for ignore_kind, dtype in {"indices": np.int32, "angles": np.float32}.items():
if excluded_projs["type"] == ignore_kind:
values = get_values_from_file(excluded_projs["file"], any_size=True).astype(dtype).tolist()
self._ignore_projections = {"kind": excluded_projs["type"], "values": values} # pylint: disable=E0606
def _init_dataset_scan(self, **kwargs):
if self._scanner is None:
raise ValueError("Base class")
if self._scanner is NXtomoScan:
if self.extra_options.get("hdf5_entry", None) is not None:
kwargs["entry"] = self.extra_options["hdf5_entry"]
if self.extra_options.get("nx_version", None) is not None:
kwargs["nx_version"] = self.extra_options["nx_version"]
if self._scanner is EDFTomoScan:
# Assume 1 frame per file (otherwise too long to open each file)
kwargs["n_frames"] = 1
self.dataset_scanner = self._scanner( # pylint: disable=E1102
self.location, ignore_projections=self._ignore_projections, **kwargs
)
if self._ignore_projections is not None:
self.logger.info("Excluding projections: %s" % str(self._ignore_projections))
if nabu_env_settings.skip_tomoscan_checks:
self.logger.warning(
" WARNING: according to nabu_env_settings.skip_tomoscan_checks, skipping virtual layout integrity check of tomoscan which is time consuming"
)
self.dataset_scanner.set_check_behavior(run_check=False, raise_error=False)
self.raw_flats = self.dataset_scanner.flats
self.raw_darks = self.dataset_scanner.darks
self.n_angles = len(self.dataset_scanner.projections)
self.radio_dims = (self.dataset_scanner.dim_1, self.dataset_scanner.dim_2)
self._radio_dims_notbinned = self.radio_dims # COMPAT
def _finish_init(self):
pass
def _set_default_dataset_values(self):
self._detector_tilt = None
self.translations = None
self.ctf_translations = None
self.axis_position = None
self._rotation_angles = None
self.z_per_proj = None
self.x_per_proj = None
self._energy = None
self._pixel_size = None
self._distance = None
self._flats_srcurrent = None
self._projections = None
self._projections_srcurrent = None
self._reduced_flats = None
self._reduced_darks = None
@property
def energy(self):
"""
Return the energy in kev.
"""
if self._energy is None:
self._energy = self.dataset_scanner.energy
return self._energy
@energy.setter
def energy(self, val):
self._energy = val
@property
def distance(self):
"""
Return the sample-detector distance in meters.
"""
if self._distance is None:
self._distance = abs(self.dataset_scanner.distance)
return self._distance
@distance.setter
def distance(self, val):
self._distance = val
@property
def pixel_size(self):
"""
Return the pixel size in microns.
"""
# TODO X and Y pixel size
if self._pixel_size is None:
self._pixel_size = self.dataset_scanner.pixel_size * 1e6
return self._pixel_size
@pixel_size.setter
def pixel_size(self, val):
self._pixel_size = val
def _get_rotation_angles(self):
return self._rotation_angles # None by default
@property
def rotation_angles(self):
"""
Return the rotation angles in radians.
"""
return self._get_rotation_angles()
@rotation_angles.setter
def rotation_angles(self, angles):
self._rotation_angles = angles
def _is_halftomo(self):
return None # base class
@property
def is_halftomo(self):
"""
Indicates whether the current dataset was performed with half acquisition.
"""
return self._is_halftomo()
@property
def detector_tilt(self):
"""
Return the detector tilt in degrees
"""
return self._detector_tilt
@detector_tilt.setter
def detector_tilt(self, tilt):
self._detector_tilt = tilt
def _get_srcurrent(self, frame_type):
# To be implemented by inheriting class
return None
@property
def projections(self):
if self._projections is None:
self._projections = self.dataset_scanner.projections
return self._projections
@projections.setter
def projections(self, val):
raise ValueError
@property
def projections_srcurrent(self):
"""
Return the synchrotron electric current for each projection.
"""
if self._projections_srcurrent is None:
self._projections_srcurrent = self._get_srcurrent("radios") # pylint: disable=E1128
return self._projections_srcurrent
@projections_srcurrent.setter
def projections_srcurrent(self, val):
self._projections_srcurrent = val
@property
def flats_srcurrent(self):
"""
Return the synchrotron electric current for each flat image.
"""
if self._flats_srcurrent is None:
self._flats_srcurrent = self._get_srcurrent("flats") # pylint: disable=E1128
return self._flats_srcurrent
@flats_srcurrent.setter
def flats_srcurrent(self, val):
self._flats_srcurrent = val
[docs]
def check_defined_attribute(self, name, error_msg=None):
"""
Utility function to check that a given attribute is defined.
"""
if getattr(self, name, None) is None:
raise ValueError(error_msg or str("No information on %s was found in the dataset" % name))
@property
def flats(self):
"""
Return the REDUCED flat-field images. Either by reducing (median) the raw flats, or a user-defined reduced flats.
"""
if self._reduced_flats is None:
self._reduced_flats = self.get_reduced_flats()
return self._reduced_flats
@flats.setter
def flats(self, val):
self._reduced_flats = val
@property
def darks(self):
"""
Return the REDUCED flat-field images. Either by reducing (mean) the raw darks, or a user-defined reduced darks.
"""
if self._reduced_darks is None:
self._reduced_darks = self.get_reduced_darks()
return self._reduced_darks
@darks.setter
def darks(self, val):
self._reduced_darks = val
[docs]
class EDFDatasetAnalyzer(DatasetAnalyzer):
"""
EDF Dataset analyzer for legacy ESRF acquisitions
"""
_scanner = EDFTomoScan
kind = "edf"
def _finish_init(self):
pass
def _get_flats_darks(self):
return
@property
def hdf5_entry(self):
"""
Return the HDF5 entry of the current dataset.
Not applicable for EDF (return None)
"""
return None
def _is_halftomo(self):
return None
def _get_rotation_angles(self):
return np.deg2rad(self.dataset_scanner.rotation_angle())
[docs]
def get_reduced_flats(self, **reader_kwargs):
if self.raw_flats in [None, {}]:
raise FileNotFoundError("No reduced flat ('refHST') found in %s" % self.location)
# A few notes:
# (1) In principle we could do the reduction (mean/median) from raw frames (ref_xxxx_yyyy)
# but for legacy datasets it's always already done (by fasttomo3), and EDF support is supposed to be dropped on our side
# (2) We use EDFStackReader class to handle the possible additional data modifications
# (eg. subsampling, binning, distortion correction...)
# (3) The following spawns one reader instance per file, which is not elegant,
# but in principle there are typically 1-2 reduced flats in a scan
readers = {k: EDFStackReader([self.raw_flats[k].file_path()], **reader_kwargs) for k in self.raw_flats.keys()}
return {k: readers[k].load_data()[0] for k in self.raw_flats.keys()}
[docs]
def get_reduced_darks(self, **reader_kwargs):
# See notes in get_reduced_flats() above
if self.raw_darks in [None, {}]:
raise FileNotFoundError("No reduced dark ('darkend.edf' or 'dark.edf') found in %s" % self.location)
readers = {k: EDFStackReader([self.raw_darks[k].file_path()], **reader_kwargs) for k in self.raw_darks.keys()}
return {k: readers[k].load_data()[0] for k in self.raw_darks.keys()}
@property
def files(self):
return sorted([u.file_path() for u in self.dataset_scanner.projections.values()])
[docs]
def get_reader(self, **kwargs):
return EDFStackReader(self.files, **kwargs)
[docs]
class HDF5DatasetAnalyzer(DatasetAnalyzer):
"""
HDF5 dataset analyzer
"""
_scanner = NXtomoScan
kind = "nx"
# We could import the 1000+ LoC nxtomo.nxobject.nxdetector.ImageKey... or we can do this
_image_key_value = {"flats": 1, "darks": 2, "radios": 0}
#
@property
def z_translation(self):
raw_data = np.array(self.dataset_scanner.z_translation)
projs_idx = np.array(list(self.projections.keys()))
filtered_data = raw_data[projs_idx]
return 1.0e6 * filtered_data / self.pixel_size
@property
def x_translation(self):
raw_data = np.array(self.dataset_scanner.x_translation)
projs_idx = np.array(list(self.projections.keys()))
filtered_data = raw_data[projs_idx]
return 1.0e6 * filtered_data / self.pixel_size
def _get_rotation_angles(self):
if self._rotation_angles is None:
angles = np.array(self.dataset_scanner.rotation_angle)
projs_idx = np.array(list(self.projections.keys()))
angles = angles[projs_idx]
self._rotation_angles = np.deg2rad(angles)
return self._rotation_angles
def _get_dataset_hdf5_url(self):
if len(self.projections) > 0:
frames_to_take = self.projections
elif len(self.raw_flats) > 0:
frames_to_take = self.raw_flats
elif len(self.raw_darks) > 0:
frames_to_take = self.raw_darks
else:
raise ValueError("No projections, no flats and no darks ?!")
first_proj_idx = sorted(frames_to_take.keys())[0]
first_proj_url = frames_to_take[first_proj_idx]
return DataUrl(
file_path=first_proj_url.file_path(), data_path=first_proj_url.data_path(), data_slice=None, scheme="silx"
)
@property
def dataset_hdf5_url(self):
return self._get_dataset_hdf5_url()
@property
def hdf5_entry(self):
"""
Return the HDF5 entry of the current dataset
"""
return self.dataset_scanner.entry
def _is_halftomo(self):
try:
is_halftomo = self.dataset_scanner.field_of_view.value.lower() == "half"
except:
is_halftomo = None
return is_halftomo
[docs]
def get_data_slices(self, what):
"""
Return indices in the data volume where images correspond to a given kind.
Parameters
----------
what: str
Which keys to get. Can be "projections", "flats", "darks"
Returns
--------
slices: list of slice
A list where each item is a slice.
"""
name_to_attr = {
"projections": self.projections,
"flats": self.raw_flats,
"darks": self.raw_darks,
}
check_supported(what, name_to_attr.keys(), "image type")
images = name_to_attr[what] # dict
# we can't directly use set() on slice() object (unhashable). Use tuples
slices = set()
for du in get_compacted_dataslices(images).values():
if du.data_slice() is not None:
s = (du.data_slice().start, du.data_slice().stop)
else:
s = None
slices.add(s)
slices_list = [slice(item[0], item[1]) if item is not None else None for item in list(slices)]
return slices_list
def _select_according_to_frame_type(self, data, frame_type):
if data is None:
return None
return data[self.dataset_scanner.image_key_control == self._image_key_value[frame_type]]
[docs]
def get_reduced_flats(self, method="median", force_reload=False, **reader_kwargs):
dkrf_reader = NXDarksFlats(
self.dataset_hdf5_url.file_path(), data_path=self.dataset_hdf5_url.data_path(), **reader_kwargs
)
return dkrf_reader.get_reduced_flats(method=method, force_reload=force_reload, as_dict=True)
[docs]
def get_reduced_darks(self, method="mean", force_reload=False, **reader_kwargs):
dkrf_reader = NXDarksFlats(
self.dataset_hdf5_url.file_path(), data_path=self.dataset_hdf5_url.data_path(), **reader_kwargs
)
return dkrf_reader.get_reduced_darks(method=method, force_reload=force_reload, as_dict=True)
def _get_srcurrent(self, frame_type):
return self._select_according_to_frame_type(self.dataset_scanner.electric_current, frame_type)
[docs]
def frames_slices(self, frame_type):
"""
Return a list of slice objects corresponding to the data corresponding to "frame_type".
For example, if the dataset flats are located at indices [1, 2, ..., 99], then
frame_slices("flats") will return [slice(0, 100)].
"""
return indices_to_slices(
np.where(self.dataset_scanner.image_key_control == self._image_key_value[frame_type])[0]
)
[docs]
def get_reader(self, **kwargs):
return NXTomoReader(self.dataset_hdf5_url.file_path(), data_path=self.dataset_hdf5_url.data_path(), **kwargs)
[docs]
def analyze_dataset(dataset_path, extra_options=None, logger=None):
if not (os.path.isdir(dataset_path)):
if not (os.path.isfile(dataset_path)):
raise ValueError("Error: %s no such file or directory" % dataset_path)
if not (is_hdf5_extension(os.path.splitext(dataset_path)[-1].replace(".", ""))):
raise ValueError("Error: expected a HDF5 file")
dataset_analyzer_class = HDF5DatasetAnalyzer
else: # directory -> assuming EDF
dataset_analyzer_class = EDFDatasetAnalyzer
dataset_structure = dataset_analyzer_class(dataset_path, extra_options=extra_options, logger=logger)
return dataset_structure
[docs]
def get_radio_pair(dataset_info, radio_angles: tuple, return_indices=False):
"""
Get closest radios at radio_angles[0] and radio_angles[1]
angles must be in angles
Parameters
----------
dataset_info: `DatasetAnalyzer` instance
Data structure with the dataset information
radio_angles: tuple
tuple of two elements: angles (in radian) to get
return_indices: bool, optional
Whether to return radios indices along with the radios array.
Returns
-------
res: array or tuple
If return_indices is True, return a tuple (radios, indices).
Otherwise, return an array with the radios.
"""
if not (isinstance(radio_angles, tuple) and len(radio_angles) == 2):
raise TypeError("radio_angles should be a tuple of two elements.")
if not isinstance(radio_angles[0], (np.floating, float)) or not isinstance(radio_angles[1], (np.floating, float)):
raise TypeError(
f"radio_angles should be float. Get {type(radio_angles[0])} and {type(radio_angles[1])} instead"
)
radios_indices = []
radios_indices = sorted(dataset_info.projections.keys())
angles = dataset_info.rotation_angles
angles = angles - angles.min()
i_radio_1 = np.argmin(np.abs(angles - radio_angles[0]))
i_radio_2 = np.argmin(np.abs(angles - radio_angles[1]))
radios_indices = [radios_indices[i_radio_1], radios_indices[i_radio_2]]
n_radios = 2
radios = np.zeros((n_radios,) + dataset_info.radio_dims[::-1], "f")
for i in range(n_radios):
radio_idx = radios_indices[i]
radios[i] = get_data(dataset_info.projections[radio_idx]).astype("f")
if return_indices:
return radios, radios_indices
else:
return radios