Source code for nabu.resources.dataset_analyzer

import os
from bisect import bisect_left
import numpy as np
from silx.io import get_data
from silx.io.url import DataUrl
from tomoscan.esrf.scan.edfscan import EDFTomoScan
from tomoscan.esrf.scan.nxtomoscan import NXtomoScan
from ..utils import check_supported
from ..io.utils import get_compacted_dataslices
from .utils import is_hdf5_extension, get_values_from_file
from .logger import LoggerOrPrint

from ..pipeline.utils import nabu_env_settings


class DatasetAnalyzer:
    """
    Base class for datasets analyzers.

    Subclasses set the ``_scanner`` class attribute to the tomoscan class used
    to read the dataset, and ``kind`` to a short name of the dataset format.
    """

    _scanner = None
    kind = "none"

    def __init__(self, location, extra_options=None, logger=None):
        """
        Initialize a Dataset analyzer.

        Parameters
        ----------
        location: str
            Dataset location (directory or file name)
        extra_options: dict, optional
            Extra options on how to interpret the dataset.
        logger: logging object, optional
            Logger. If not set, messages will just be printed in stdout.
        """
        self.logger = LoggerOrPrint(logger)
        self.location = location
        self._set_extra_options(extra_options)
        self._get_excluded_projections()
        self._set_default_dataset_values()
        self._init_dataset_scan()
        self._finish_init()

    def _set_extra_options(self, extra_options):
        """
        Store the user-provided options on top of the default ones in `self.extra_options`.
        """
        if extra_options is None:
            extra_options = {}
        # COMPAT.
        advanced_options = {
            "force_flatfield": False,
            "output_dir": None,
            "exclude_projections": None,
            "hdf5_entry": None,
            "nx_version": 1.0,
        }
        # --
        advanced_options.update(extra_options)
        self.extra_options = advanced_options

    # pylint: disable=E1136
    def _get_excluded_projections(self):
        """
        Interpret the "exclude_projections" option.

        When the exclusion is given as indices, they are read from the file right away.
        Otherwise (exclusion by angles), the scan angles have to be known first, so the
        exclusion is deferred to `_init_dataset_scan`.
        """
        self._ignore_projections_indices = None
        self._need_rebuild_tomoscan_object_to_exclude_projections = False
        excluded_projs = self.extra_options["exclude_projections"]
        if excluded_projs is None:
            return
        if excluded_projs["type"] == "indices":
            projs_idx = get_values_from_file(excluded_projs["file"], any_size=True).astype(np.int32).tolist()
            self._ignore_projections_indices = projs_idx
        else:
            self._need_rebuild_tomoscan_object_to_exclude_projections = True

    @staticmethod
    def _indices_in_angular_range(angle_min, angle_max, rot_angles_deg, projections_indices):
        """
        Return the projections indices whose angle (in degrees) is within [angle_min, angle_max].

        `rot_angles_deg[i]` is the angle of the projection of index `projections_indices[i]`.
        """
        return [
            int(proj_idx)
            for proj_idx, angle in zip(projections_indices, rot_angles_deg)
            if angle_min <= angle <= angle_max
        ]

    @staticmethod
    def _indices_of_excluded_angles(excluded_angles, rot_angles_deg, projections_indices):
        """
        Return the projections indices matching each angle (in degrees) of `excluded_angles`.

        `rot_angles_deg[i]` is the angle of the projection of index `projections_indices[i]`.
        For each excluded angle, the first projection whose angle is greater than or equal
        to it is picked (bisection, so `rot_angles_deg` has to be sorted in increasing order).
        Angles greater than all the scan angles are skipped.
        """
        n_angles = min(len(rot_angles_deg), len(projections_indices))
        selected = []
        for excluded_angle in excluded_angles:
            position = bisect_left(rot_angles_deg, excluded_angle)
            if position < n_angles:
                # 'position' is a position in the angles array, not a projection index:
                # it has to be converted to the corresponding projection index
                selected.append(int(projections_indices[position]))
        return selected

    def _init_dataset_scan(self, **kwargs):
        """
        Build the tomoscan object (`self.dataset_scanner`) and fetch the basic dataset
        information from it (projections, flats, darks, number of angles, radios shape).

        Raises
        ------
        ValueError
            If called on the base class, i.e. when no scanner class is defined.
        """
        if self._scanner is None:
            raise ValueError("Base class")
        if self._scanner is NXtomoScan:
            if self.extra_options.get("hdf5_entry", None) is not None:
                kwargs["entry"] = self.extra_options["hdf5_entry"]
            if self.extra_options.get("nx_version", None) is not None:
                kwargs["nx_version"] = self.extra_options["nx_version"]
        if self._scanner is EDFTomoScan:
            # Assume 1 frame per file (otherwise too long to open each file)
            kwargs["n_frames"] = 1

        self.dataset_scanner = self._scanner(  # pylint: disable=E1102
            self.location, ignore_projections=self._ignore_projections_indices, **kwargs
        )
        self.projections = self.dataset_scanner.projections

        # ---
        if self._need_rebuild_tomoscan_object_to_exclude_projections:
            # pylint: disable=E1136
            exclude_projs = self.extra_options["exclude_projections"]
            rot_angles_deg = np.rad2deg(self.rotation_angles)
            self._rotation_angles = None  # prevent caching
            # tomoscan only supports ignore_projections=<list of integers>
            # However this is cumbersome to use, it's more convenient to use angular range or list of angles
            # But having angles instead of indices implies to already have information on current scan angular range
            projections_indices = np.array(sorted(self.dataset_scanner.projections.keys()))
            ignore_projections_indices = []
            if exclude_projs["type"] == "angular_range":
                exclude_angle_min, exclude_angle_max = exclude_projs["range"]
                ignore_projections_indices = self._indices_in_angular_range(
                    exclude_angle_min, exclude_angle_max, rot_angles_deg, projections_indices
                )
            elif exclude_projs["type"] == "angles":
                excluded_angles = get_values_from_file(exclude_projs["file"], any_size=True).astype(np.float32).tolist()
                ignore_projections_indices = self._indices_of_excluded_angles(
                    excluded_angles, rot_angles_deg, projections_indices
                )
            # Rebuild the dataset_scanner instance
            self._ignore_projections_indices = ignore_projections_indices
            self.dataset_scanner = self._scanner(  # pylint: disable=E1102
                self.location, ignore_projections=self._ignore_projections_indices, **kwargs
            )
        # ---

        if self._ignore_projections_indices is not None:
            self.logger.info("Excluding projections: %s" % str(self._ignore_projections_indices))

        if nabu_env_settings.skip_tomoscan_checks:
            self.logger.warning(
                " WARNING: according to nabu_env_settings.skip_tomoscan_checks, skipping virtual layout integrity check of tomoscan which is time consuming"
            )
            self.dataset_scanner.set_check_behavior(run_check=False, raise_error=False)

        self.projections = self.dataset_scanner.projections
        self.flats = self.dataset_scanner.flats
        self.darks = self.dataset_scanner.darks
        self.n_angles = len(self.dataset_scanner.projections)
        self.radio_dims = (self.dataset_scanner.dim_1, self.dataset_scanner.dim_2)
        self._radio_dims_notbinned = self.radio_dims  # COMPAT

    def _finish_init(self):
        """
        Hook called at the end of `__init__`. Does nothing in the base class.
        """
        pass

    def _set_default_dataset_values(self):
        """
        Reset to None all the values that are either set by the user or lazily read from the dataset.
        """
        self._detector_tilt = None
        self.translations = None
        self.ctf_translations = None
        self.axis_position = None
        self._rotation_angles = None
        self.z_per_proj = None
        self.x_per_proj = None
        self._energy = None
        self._pixel_size = None
        self._distance = None
        self._flats_srcurrent = None
        self._projections_srcurrent = None

    @property
    def energy(self):
        """
        Return the energy in kev.
        """
        if self._energy is None:
            self._energy = self.dataset_scanner.energy
        return self._energy

    @energy.setter
    def energy(self, val):
        self._energy = val

    @property
    def distance(self):
        """
        Return the sample-detector distance in meters.
        """
        if self._distance is None:
            self._distance = abs(self.dataset_scanner.distance)
        return self._distance

    @distance.setter
    def distance(self, val):
        self._distance = val

    @property
    def pixel_size(self):
        """
        Return the pixel size in microns.
        """
        # TODO X and Y pixel size
        if self._pixel_size is None:
            self._pixel_size = self.dataset_scanner.pixel_size * 1e6
        return self._pixel_size

    @pixel_size.setter
    def pixel_size(self, val):
        self._pixel_size = val

    def _get_rotation_angles(self):
        return self._rotation_angles  # None by default

    @property
    def rotation_angles(self):
        """
        Return the rotation angles in radians.
        """
        return self._get_rotation_angles()

    @rotation_angles.setter
    def rotation_angles(self, angles):
        self._rotation_angles = angles

    def _is_halftomo(self):
        return None  # base class

    @property
    def is_halftomo(self):
        """
        Indicates whether the current dataset was performed with half acquisition.
        """
        return self._is_halftomo()

    @property
    def detector_tilt(self):
        """
        Return the detector tilt in degrees
        """
        return self._detector_tilt

    @detector_tilt.setter
    def detector_tilt(self, tilt):
        self._detector_tilt = tilt

    def _get_srcurrent(self, indices):
        """
        Return the electric current values at the given frames indices (numpy array of integers),
        as a float32 array. Return None if the current is not available, or if there are
        not enough values for the requested indices.
        """
        srcurrent = self.dataset_scanner.electric_current
        if srcurrent is None or len(srcurrent) == 0:
            return None
        srcurrent_all = np.array(srcurrent)
        if np.any(indices >= len(srcurrent_all)):
            self.logger.error("Something wrong with SRCurrent: not enough values!")
            return None
        return srcurrent_all[indices].astype("f")

    @property
    def projections_srcurrent(self):
        """
        Return the synchrotron electric current for each projection.
        """
        if self._projections_srcurrent is None:
            projections_indices = np.array(sorted(self.projections.keys()))
            self._projections_srcurrent = self._get_srcurrent(projections_indices)
        return self._projections_srcurrent

    @projections_srcurrent.setter
    def projections_srcurrent(self, val):
        self._projections_srcurrent = val

    @property
    def flats_srcurrent(self):
        """
        Return the synchrotron electric current for each flat image.
        """
        if self._flats_srcurrent is None:
            flats_indices = np.array(sorted(self.flats.keys()))
            self._flats_srcurrent = self._get_srcurrent(flats_indices)
        return self._flats_srcurrent

    @flats_srcurrent.setter
    def flats_srcurrent(self, val):
        self._flats_srcurrent = val

    def check_defined_attribute(self, name, error_msg=None):
        """
        Utility function to check that a given attribute is defined.

        Parameters
        ----------
        name: str
            Name of the attribute to check
        error_msg: str, optional
            Error message to use instead of the default one.

        Raises
        ------
        ValueError
            If the attribute does not exist or is None.
        """
        if getattr(self, name, None) is None:
            raise ValueError(error_msg or "No information on %s was found in the dataset" % name)
class EDFDatasetAnalyzer(DatasetAnalyzer):
    """
    EDF Dataset analyzer for legacy ESRF acquisitions
    """

    _scanner = EDFTomoScan
    kind = "edf"

    def _finish_init(self):
        # Last initialization step: discard the projections considered as not used
        self.remove_unused_radios()

    def remove_unused_radios(self):
        """
        Remove "unused" radios.
        This is used for legacy ESRF scans.

        Returns
        -------
        radios_not_used: list
            Indices of the projections that were removed from `self.projections`.
        """
        # Extraneous projections are assumed to be on the end
        ordered_indices = sorted(self.projections.keys())
        first_index = ordered_indices[0]
        kept_range = range(first_index, len(self.projections))
        # Collect the indices first, so that the dict is not modified while being iterated on
        radios_not_used = [proj_idx for proj_idx in self.projections.keys() if proj_idx not in kept_range]
        for proj_idx in radios_not_used:
            self.projections.pop(proj_idx)
        return radios_not_used

    def _get_flats_darks(self):
        return

    @property
    def hdf5_entry(self):
        """
        Return the HDF5 entry of the current dataset.
        Not applicable for EDF (return None)
        """
        return None

    def _is_halftomo(self):
        return None

    def _get_rotation_angles(self):
        # Angles already computed, or set by the user
        if self._rotation_angles is not None:
            return self._rotation_angles
        scan_range = self.dataset_scanner.scan_range
        if scan_range is None:
            return self._rotation_angles
        # The scan range (in degrees) is considered as a full turn if it is closer to 360 than to 180.
        # In this case the last angle of the range is included.
        is_full_turn = abs(scan_range - 360) < abs(scan_range - 180)
        angles_deg = np.linspace(0, scan_range, num=len(self.projections), endpoint=is_full_turn, dtype="f")
        self._rotation_angles = np.deg2rad(angles_deg)
        return self._rotation_angles
class HDF5DatasetAnalyzer(DatasetAnalyzer):
    """
    HDF5 dataset analyzer
    """

    _scanner = NXtomoScan
    kind = "hdf5"

    def _translation_in_pixels(self, raw_values):
        """
        Keep the values of `raw_values` corresponding to projections, and rescale them
        as `1e6 * value / pixel_size`.
        """
        # NOTE(review): presumably the raw translations are in meters, so that the result
        # is in pixels (pixel_size being in microns) - confirm against tomoscan.
        raw_data = np.array(raw_values)
        projs_idx = np.array(list(self.projections.keys()))
        filtered_data = raw_data[projs_idx]
        return 1.0e6 * filtered_data / self.pixel_size

    @property
    def z_translation(self):
        """
        Return the translation along z of each projection, rescaled by the pixel size.
        """
        return self._translation_in_pixels(self.dataset_scanner.z_translation)

    @property
    def x_translation(self):
        """
        Return the translation along x of each projection, rescaled by the pixel size.
        """
        return self._translation_in_pixels(self.dataset_scanner.x_translation)

    def _get_rotation_angles(self):
        if self._rotation_angles is None:
            angles = np.array(self.dataset_scanner.rotation_angle)
            # Only keep the angles of the frames that are projections
            projs_idx = np.array(list(self.projections.keys()))
            angles = angles[projs_idx]
            self._rotation_angles = np.deg2rad(angles)
        return self._rotation_angles

    def _get_dataset_hdf5_url(self):
        """
        Build a DataUrl (without data slice) from the file path and data path of the first
        available frame: projections first, then flats, then darks.

        Raises
        ------
        ValueError
            If the dataset has no projections, no flats and no darks.
        """
        if len(self.projections) > 0:
            frames_to_take = self.projections
        elif len(self.flats) > 0:
            frames_to_take = self.flats
        elif len(self.darks) > 0:
            frames_to_take = self.darks
        else:
            raise ValueError("No projections, no flats and no darks ?!")
        first_proj_idx = sorted(frames_to_take.keys())[0]
        first_proj_url = frames_to_take[first_proj_idx]
        return DataUrl(
            file_path=first_proj_url.file_path(), data_path=first_proj_url.data_path(), data_slice=None, scheme="silx"
        )

    @property
    def dataset_hdf5_url(self):
        return self._get_dataset_hdf5_url()

    @property
    def hdf5_entry(self):
        """
        Return the HDF5 entry of the current dataset
        """
        return self.dataset_scanner.entry

    def _is_halftomo(self):
        # Best effort: if the field of view cannot be read for any reason, return None ("unknown").
        # 'Exception' is used instead of a bare 'except', so that KeyboardInterrupt and SystemExit
        # are not swallowed.
        try:
            is_halftomo = self.dataset_scanner.field_of_view.value.lower() == "half"
        except Exception:
            is_halftomo = None
        return is_halftomo

    def get_data_slices(self, what):
        """
        Return indices in the data volume where images correspond to a given kind.

        Parameters
        ----------
        what: str
            Which keys to get. Can be "projections", "flats", "darks"

        Returns
        --------
        slices: list of slice
            A list where each item is a slice (or None for images without data slice).
            Each slice appears only once, in the order it is first encountered.
        """
        check_supported(what, ["projections", "flats", "darks"], "image type")
        images = getattr(self, what)  # dict
        # slice() objects are not hashable on all Python versions: use (start, stop) tuples
        # to remove duplicates. A dict is used (rather than a set) to keep a deterministic order.
        unique_bounds = {}
        for du in get_compacted_dataslices(images).values():
            data_slice = du.data_slice()
            bounds = None if data_slice is None else (data_slice.start, data_slice.stop)
            unique_bounds[bounds] = None
        return [None if bounds is None else slice(bounds[0], bounds[1]) for bounds in unique_bounds]
def analyze_dataset(dataset_path, extra_options=None, logger=None):
    """
    Build the dataset analyzer suited to a dataset location.

    A directory is assumed to be an EDF dataset. Otherwise, the location has to be
    an existing file with a HDF5 extension.

    Parameters
    ----------
    dataset_path: str
        Dataset location (directory or file name)
    extra_options: dict, optional
        Extra options, passed to the dataset analyzer class.
    logger: logging object, optional
        Logger, passed to the dataset analyzer class.

    Returns
    -------
    dataset_structure: EDFDatasetAnalyzer or HDF5DatasetAnalyzer

    Raises
    ------
    ValueError
        If the location does not exist, or if it is a file without a HDF5 extension.
    """
    if os.path.isdir(dataset_path):
        # directory -> assuming EDF
        analyzer_cls = EDFDatasetAnalyzer
    else:
        if not os.path.isfile(dataset_path):
            raise ValueError("Error: %s no such file or directory" % dataset_path)
        file_extension = os.path.splitext(dataset_path)[-1].replace(".", "")
        if not is_hdf5_extension(file_extension):
            raise ValueError("Error: expected a HDF5 file")
        analyzer_cls = HDF5DatasetAnalyzer
    return analyzer_cls(dataset_path, extra_options=extra_options, logger=logger)
def get_radio_pair(dataset_info, radio_angles: tuple, return_indices=False):
    """
    Get the two radios whose angles are the closest to radio_angles[0] and radio_angles[1].

    The requested angles are compared to the dataset rotation angles shifted so that the
    smallest one is zero.

    Parameters
    ----------
    dataset_info: `DatasetAnalyzer` instance
        Data structure with the dataset information
    radio_angles: tuple
        tuple of two elements: angles (in radian) to get
    return_indices: bool, optional
        Whether to return radios indices along with the radios array.

    Returns
    -------
    res: array or tuple
        If return_indices is True, return a tuple (radios, indices).
        Otherwise, return an array with the radios.

    Raises
    ------
    TypeError
        If radio_angles is not a tuple of two floats.
    ValueError
        If the dataset has no information on rotation angles.
    """
    if not (isinstance(radio_angles, tuple) and len(radio_angles) == 2):
        raise TypeError("radio_angles should be a tuple of two elements.")
    if not all(isinstance(angle, (np.floating, float)) for angle in radio_angles):
        raise TypeError(
            f"radio_angles should be float. Get {type(radio_angles[0])} and {type(radio_angles[1])} instead"
        )

    angles = dataset_info.rotation_angles
    if angles is None:
        # Fail with an explicit message rather than with an AttributeError on None
        raise ValueError("No information on rotation angles was found in the dataset")
    angles = angles - angles.min()

    projections_indices = sorted(dataset_info.projections.keys())
    # For each requested angle, pick the projection with the closest angle
    radios_indices = [projections_indices[np.argmin(np.abs(angles - angle))] for angle in radio_angles]

    # radio_dims is reversed to get the shape of one image
    radios = np.zeros((len(radios_indices),) + dataset_info.radio_dims[::-1], "f")
    for i, radio_idx in enumerate(radios_indices):
        radios[i] = get_data(dataset_info.projections[radio_idx]).astype("f")

    if return_indices:
        return radios, radios_indices
    return radios