Source code for nabu.stitching.alignment

import h5py
import numpy
from typing import Union
from silx.utils.enum import Enum as _Enum
from tomoscan.volumebase import VolumeBase
from tomoscan.esrf.volume.hdf5volume import HDF5Volume
from nabu.io.utils import DatasetReader


class AlignmentAxis2(_Enum):
    """
    Alignment strategy along the last axis (the width) of a data set.

    Used by `align_horizontally` to decide on which side(s) padding is added.
    """

    # padding is split between both sides
    CENTER = "center"
    # existing data is kept on the left, padding goes to the right
    LEFT = "left"
    # existing data is kept on the right, padding goes to the left.
    # NOTE: historical misspelling, kept as the canonical member for backward compatibility.
    RIGTH = "right"
    # correctly spelled alias of RIGTH (same value, so `RIGHT is RIGTH`)
    RIGHT = "right"
class AlignmentAxis1(_Enum):
    """
    Alignment strategy along axis 1.

    NOTE(review): no consumer of this enum is visible in this part of the file;
    member meaning is presumed from the names (front / center / back) - confirm against callers.
    """

    FRONT = "front"
    CENTER = "center"
    BACK = "back"
def align_horizontally(data: numpy.ndarray, alignment: AlignmentAxis2, new_width: int, pad_mode="constant"):
    """
    Pad `data` along its last axis to make sure the new data width will be `new_width`.

    Only the last axis is padded; any leading axes are left untouched, so the
    function works for data of any dimension >= 1.

    :param numpy.ndarray data: data to align. The width is read from `data.shape[-1]`
    :param AlignmentAxis2 alignment: alignment strategy (converted with `AlignmentAxis2.from_value`)
    :param int new_width: output data width
    :param pad_mode: forwarded to `numpy.pad` as `mode`
    :return: `data` itself if it already has the requested width, else the padded array
    :raises ValueError: if `data` is wider than `new_width` (cropping is not handled)
                        or if the alignment is not handled
    """
    current_width = data.shape[-1]
    alignment = AlignmentAxis2.from_value(alignment)

    if current_width > new_width:
        raise ValueError(f"data.shape[-1] ({data.shape[-1]}) > new_width ({new_width}). Unable to crop data")
    if current_width == new_width:
        return data

    # from here current_width < new_width, so both pad widths below are >= 0
    missing_width = new_width - current_width
    if alignment is AlignmentAxis2.CENTER:
        # when the missing width is odd the extra column goes to the right
        left_width = missing_width // 2
        right_width = missing_width - left_width
    elif alignment is AlignmentAxis2.LEFT:
        left_width = 0
        right_width = missing_width
    elif alignment is AlignmentAxis2.RIGTH:
        left_width = missing_width
        right_width = 0
    else:
        raise ValueError(f"alignment {alignment.value} is not handled")

    # no padding on leading axes, (left, right) padding on the last one
    pad_width = ((0, 0),) * (data.ndim - 1) + ((left_width, right_width),)
    return numpy.pad(
        data,
        pad_width=pad_width,
        mode=pad_mode,
    )
class PaddedRawData:
    """
    Util class to extend a data set along axis 1 when necessary.

    Must apply to a volume and to an HDF5 dataset / array.
    The idea behind it is to avoid loading all the data in memory: padding frames
    are created on access and the raw data is only read when it is sliced.

    :param data: 3D raw data (first, second and third shape items are used)
    :param tuple axis_1_pad_width: two non-negative elements - number of frames to
                                   add before and after the raw data along axis 1
    :raises ValueError: if `axis_1_pad_width` is not made of two elements >= 0
    """

    # annotation is quoted so it is not evaluated when the class is defined
    def __init__(self, data: "Union[numpy.ndarray, h5py.Dataset]", axis_1_pad_width: tuple) -> None:
        self._axis_1_pad_width = numpy.array(axis_1_pad_width)
        if not (
            self._axis_1_pad_width.size == 2 and self._axis_1_pad_width[0] >= 0 and self._axis_1_pad_width[1] >= 0
        ):
            raise ValueError(f"'axis_1_pad_width' expects to positive elements. Get {axis_1_pad_width}")
        self._raw_data = data
        self._raw_data_end = None
        # note: for now we return only frames with zeros for padded frames.
        # in the future we could imagine having a method and mirror the existing volume,
        # extend the closest frame, or get a mean value...
        self._empty_frame = None
        self._dtype = None
        self._shape = None
        self._raw_data_shape = self.raw_data.shape

    @staticmethod
    def get_empty_frame(shape, dtype):
        """Return a zero-filled array of the given shape and dtype (used for padded frames)."""
        return numpy.zeros(
            shape=shape,
            dtype=dtype,
        )

    @property
    def empty_frame(self):
        """Single padding frame of shape (shape[0], 1, shape[2]). Created on first access then cached."""
        if self._empty_frame is None:
            self._empty_frame = self.get_empty_frame(
                shape=(self.shape[0], 1, self.shape[2]),
                dtype=self.dtype,
            )
        return self._empty_frame

    @property
    def shape(self):
        """Shape of the padded data: raw data shape with both pad widths added to axis 1."""
        if self._shape is None:
            self._shape = (
                self._raw_data_shape[0],
                numpy.sum(self._axis_1_pad_width) + self._raw_data_shape[1],
                self._raw_data_shape[2],
            )
        return self._shape

    @property
    def raw_data(self):
        """The wrapped (un-padded) data."""
        return self._raw_data

    @property
    def raw_data_start(self):
        """Index along axis 1 of the padded data where the raw data starts."""
        return self._axis_1_pad_width[0]

    @property
    def raw_data_end(self):
        """Index along axis 1 of the padded data where the raw data ends (exclusive)."""
        if self._raw_data_end is None:
            self._raw_data_end = self._axis_1_pad_width[0] + self._raw_data_shape[1]
        return self._raw_data_end

    @property
    def dtype(self):
        """dtype of the raw data (also used for the padding frames)."""
        if self._dtype is None:
            self._dtype = self.raw_data.dtype
        return self._dtype

    def __getitem__(self, args):
        """
        Return the requested frames of the padded data as a numpy array.

        Only slicing along axis 1 is handled: `args` must be a tuple of three items,
        the first and third being `slice(None, None, None)`. The second item is either
        a slice with a step of 1 (or None) or a scalar; a scalar `i` is handled as
        `slice(i, i + 1, 1)`, so axis 1 is kept in the result.

        :raises ValueError: if the request is not a 3-item tuple, slices along another axis,
                            uses negative positions, is empty, is out of bounds or has a step != 1
        """
        # parenthesised on purpose: anything that is not a 3-item tuple must be rejected here
        if not (isinstance(args, tuple) and len(args) == 3):
            raise ValueError("only handles 3D slicing")
        if not (args[0] == slice(None, None, None) and args[2] == slice(None, None, None)):
            raise ValueError(
                "slicing only handled along axis 1. First and third tuple item are expected to be empty slice as slice(None, None, None)"
            )

        axis_1_index = args[1]
        if numpy.isscalar(axis_1_index):
            axis_1_index = slice(axis_1_index, axis_1_index + 1, 1)

        start = 0 if axis_1_index.start is None else axis_1_index.start
        stop = self.shape[1] if axis_1_index.stop is None else axis_1_index.stop
        step = axis_1_index.step

        # some tests
        if start < 0 or stop < 0:
            raise ValueError("only positive position are handled")
        if start >= stop:
            raise ValueError("start >= stop")
        if stop > self.shape[1]:
            raise ValueError("stop > self.shape[1]")
        if step not in (1, None):
            raise ValueError("for now PaddedVolume only handles steps of 1")

        # the request is split in up to three parts along axis 1:
        # padding before the raw data, the raw data itself, padding after the raw data.
        # `start` / `stop` are narrowed as the padding parts are consumed.
        first_part_array = None
        if start < self.raw_data_start and stop > start:
            stop_first_part = min(stop, self.raw_data_start)
            first_part_array = numpy.repeat(self.empty_frame, repeats=stop_first_part - start, axis=1)
            start = stop_first_part

        third_part_array = None
        if stop > self.raw_data_end and stop > start:
            start_third_part = max(start, self.raw_data_end)
            third_part_array = numpy.repeat(self.empty_frame, repeats=stop - start_third_part, axis=1)
            stop = self.raw_data_end

        second_part_array = None
        if start >= self.raw_data_start and stop >= self.raw_data_start and stop > start:
            # only this part reads the wrapped data
            second_part_array = self.raw_data[:, start - self.raw_data_start : stop - self.raw_data_start, :]

        parts = tuple(
            part for part in (first_part_array, second_part_array, third_part_array) if part is not None
        )
        return numpy.hstack(
            parts,
        )