# Source code for nabu.stitching.overlap

import numpy
import logging
from typing import Optional, Union
from silx.utils.enum import Enum as _Enum
from nabu.misc import fourier_filters
from scipy.fft import rfftn as local_fftn
from scipy.fft import irfftn as local_ifftn
from tomoscan.utils.geometry import BoundingBox1D

_logger = logging.getLogger(__name__)


class OverlapStitchingStrategy(_Enum):
    """Available policies to blend the overlapping region of two frames."""

    MEAN = "mean"
    COSINUS_WEIGHTS = "cosinus weights"
    LINEAR_WEIGHTS = "linear weights"
    CLOSEST = "closest"
    IMAGE_MINIMUM_DIVERGENCE = "image minimum divergence"
    HIGHER_SIGNAL = "higher signal"
# Strategy used by default when none is requested explicitly.
DEFAULT_OVERLAP_STRATEGY = OverlapStitchingStrategy.COSINUS_WEIGHTS

# default overlap to be taken for stitching. If None: take the largest possible area
DEFAULT_OVERLAP_SIZE = None  # could also be an int
class OverlapKernelBase:
    """Base class of every overlap-stitching kernel."""
class ImageStichOverlapKernel(OverlapKernelBase):
    """
    Kernel stitching two 2D images along a single axis (0 or 1 in image space),
    blending the overlap according to the requested :class:`OverlapStitchingStrategy`.
    """

    # default cutoff frequency used by the 'image minimum divergence' strategy
    DEFAULT_HIGH_FREQUENCY_THRESHOLD = 2

    def __init__(
        self,
        stitching_axis: int,
        frame_unstitched_axis_size: int,
        stitching_strategy: OverlapStitchingStrategy = DEFAULT_OVERLAP_STRATEGY,
        overlap_size: int = DEFAULT_OVERLAP_SIZE,
        extra_params: Optional[dict] = None,
    ) -> None:
        """
        :param stitching_axis: axis along which stitching is operated. Must be in (0, 1)
        :param frame_unstitched_axis_size: according to the stitching axis the stitched frame
                                           will always have a constant size:

                                           * If stitching_axis == 0 then it will be the frame width
                                           * If stitching_axis == 1 then it will be the frame height
        :param stitching_strategy: strategy / algorithm to use in order to generate the stitching
        :param overlap_size: size (int) of the overlap (stitching) between the two images
        :param extra_params: possibly extra parameters to operate the stitching
        :raises TypeError: if overlap_size is not an int or frame_unstitched_axis_size
                           is not a positive int
        :raises ValueError: if stitching_axis is not 0 or 1
        """
        from nabu.stitching.config import KEY_THRESHOLD_FREQUENCY  # avoid acyclic import

        # BUG FIX: original condition was `not isinstance(overlap_size, int) and overlap_size > 0`
        # which crashed with an obscure `None > 0` TypeError on the default (None) and
        # accepted non-positive non-int values; the message also duplicated the value.
        if not isinstance(overlap_size, int):
            raise TypeError(
                f"overlap_size is expected to be a positive int - not {overlap_size} ({type(overlap_size)})"
            )
        if not isinstance(frame_unstitched_axis_size, int) or not frame_unstitched_axis_size > 0:
            # BUG FIX: original message referred to 'frame_width' and duplicated the value
            raise TypeError(
                f"frame_unstitched_axis_size is expected to be a positive int - not {frame_unstitched_axis_size} ({type(frame_unstitched_axis_size)})"
            )
        if stitching_axis not in (0, 1):
            raise ValueError(
                "stitching_axis is expected to be the axis along which stitching must be done. It should be '0' or '1'"
            )
        self._stitching_axis = stitching_axis
        # negative overlaps are tolerated and interpreted as their absolute value
        self._overlap_size = abs(overlap_size)
        self._frame_unstitched_axis_size = frame_unstitched_axis_size
        self._stitching_strategy = OverlapStitchingStrategy.from_value(stitching_strategy)
        # weight maps are computed lazily (see compute_weights)
        self._weights_img_1 = None
        self._weights_img_2 = None
        if extra_params is None:
            extra_params = {}
        self._high_frequency_threshold = extra_params.get(
            KEY_THRESHOLD_FREQUENCY, self.DEFAULT_HIGH_FREQUENCY_THRESHOLD
        )

    def __str__(self) -> str:
        return f"z-stitching kernel (policy={self.stitching_strategy.value}, overlap_size={self.overlap_size}, frame={self._frame_unstitched_axis_size})"

    @staticmethod
    def __check_img(img, name):
        """raise ValueError if `img` is not a 2D numpy array"""
        # BUG FIX: original condition `not isinstance(...) and img.ndim == 2` could
        # never flag a wrong dimensionality (and crashed on inputs without `ndim`)
        if not isinstance(img, numpy.ndarray) or img.ndim != 2:
            raise ValueError(f"{name} is expected to be 2D numpy array")

    @property
    def stitched_axis(self) -> int:
        # axis along which the stitching is operated
        return self._stitching_axis

    @property
    def unstitched_axis(self) -> int:
        """
        util function. The kernel is operating stitching on images along a single
        axis (`stitching_axis`). This property is returning the other axis.
        """
        return 1 if self.stitched_axis == 0 else 0

    @property
    def overlap_size(self) -> int:
        return self._overlap_size

    @overlap_size.setter
    def overlap_size(self, size: int):
        if not isinstance(size, int):
            raise TypeError(f"height expects a int ({type(size)} provided instead)")
        if not size >= 0:
            raise ValueError(f"height is expected to be positive")
        self._overlap_size = abs(size)
        # update weights if needed (they were already computed with the previous size)
        if self._weights_img_1 is not None or self._weights_img_2 is not None:
            self.compute_weights()

    @property
    def img_2(self) -> Optional[numpy.ndarray]:
        # BUG FIX: `_img_2` is never assigned anywhere in this class, so the original
        # property unconditionally raised AttributeError. Return None when unset.
        return getattr(self, "_img_2", None)

    @property
    def weights_img_1(self) -> Optional[numpy.ndarray]:
        return self._weights_img_1

    @property
    def weights_img_2(self) -> Optional[numpy.ndarray]:
        return self._weights_img_2

    @property
    def stitching_strategy(self) -> OverlapStitchingStrategy:
        return self._stitching_strategy

    def compute_weights(self):
        """
        Precompute the two per-pixel weight maps (one per image) used to blend
        the overlap, according to the selected strategy.

        Strategies with data-dependent treatment (IMAGE_MINIMUM_DIVERGENCE,
        HIGHER_SIGNAL) use no constant weights and leave the maps untouched.
        """
        if self.stitching_strategy is OverlapStitchingStrategy.MEAN:
            weights_img_1 = numpy.ones(self._overlap_size) * 0.5
            weights_img_2 = weights_img_1[::-1]
        elif self.stitching_strategy is OverlapStitchingStrategy.CLOSEST:
            # first image dominates the first half of the overlap, second image the rest
            n_item = self._overlap_size // 2 + self._overlap_size % 2
            weights_img_1 = numpy.concatenate(
                [
                    numpy.ones(n_item),
                    numpy.zeros(self._overlap_size - n_item),
                ]
            )
            weights_img_2 = numpy.concatenate(
                [
                    numpy.zeros(n_item),
                    numpy.ones(self._overlap_size - n_item),
                ]
            )
        elif self.stitching_strategy is OverlapStitchingStrategy.LINEAR_WEIGHTS:
            weights_img_1 = numpy.linspace(1.0, 0.0, self._overlap_size)
            weights_img_2 = weights_img_1[::-1]
        elif self.stitching_strategy is OverlapStitchingStrategy.COSINUS_WEIGHTS:
            # cos^2 + sin^2 == 1: the two weight maps always sum to one
            angles = numpy.linspace(0.0, numpy.pi / 2.0, self._overlap_size)
            weights_img_1 = numpy.cos(angles) ** 2
            weights_img_2 = numpy.sin(angles) ** 2
        elif self.stitching_strategy in (
            OverlapStitchingStrategy.IMAGE_MINIMUM_DIVERGENCE,
            OverlapStitchingStrategy.HIGHER_SIGNAL,
        ):
            # those strategies are not using constant weights but have treatments
            # depending on the provided img_1 and img_2 during stitching
            return
        else:
            raise NotImplementedError(f"{self.stitching_strategy} not implemented")

        # broadcast the 1D weight profiles to full 2D frames along the unstitched axis
        if self._stitching_axis == 0:
            self._weights_img_1 = weights_img_1.reshape(-1, 1) * numpy.ones(self._frame_unstitched_axis_size).reshape(
                1, -1
            )
            self._weights_img_2 = weights_img_2.reshape(-1, 1) * numpy.ones(self._frame_unstitched_axis_size).reshape(
                1, -1
            )
        elif self._stitching_axis == 1:
            self._weights_img_1 = weights_img_1.reshape(1, -1) * numpy.ones(self._frame_unstitched_axis_size).reshape(
                -1, 1
            )
            self._weights_img_2 = weights_img_2.reshape(1, -1) * numpy.ones(self._frame_unstitched_axis_size).reshape(
                -1, 1
            )
        else:
            raise ValueError(f"stitching_axis should be in (0, 1). {self._stitching_axis} provided")

    def stitch(self, img_1, img_2, check_input=True) -> tuple:
        """
        Compute overlap region from the defined strategy.

        :param img_1: first 2D image to blend
        :param img_2: second 2D image to blend - must have the same shape as img_1
        :param check_input: if True validate that both inputs are 2D arrays of the same shape
        :return: (stitched overlap, weights of img_1 or None, weights of img_2 or None)
        """
        if check_input:
            self.__check_img(img_1, "img_1")
            self.__check_img(img_2, "img_2")
            if img_1.shape != img_2.shape:
                raise ValueError(
                    f"both images are expected to be of the same shape to apply stitch ({img_1.shape} vs {img_2.shape})"
                )

        if self._stitching_strategy is OverlapStitchingStrategy.IMAGE_MINIMUM_DIVERGENCE:
            return (
                compute_image_minimum_divergence(
                    img_1=img_1,
                    img_2=img_2,
                    high_frequency_threshold=self._high_frequency_threshold,
                    stitching_axis=self.stitched_axis,
                ),
                None,
                None,
            )
        elif self._stitching_strategy is OverlapStitchingStrategy.HIGHER_SIGNAL:
            return (
                compute_image_higher_signal(
                    img_1=img_1,
                    img_2=img_2,
                ),
                None,
                None,
            )
        else:
            # constant-weight strategies: lazily compute weight maps then blend
            if self.weights_img_1 is None or self.weights_img_2 is None:
                self.compute_weights()
            return (
                img_1 * self.weights_img_1 + img_2 * self.weights_img_2,
                self.weights_img_1,
                self.weights_img_2,
            )
def compute_image_minimum_divergence(
    img_1: numpy.ndarray, img_2: numpy.ndarray, high_frequency_threshold, stitching_axis: int
):
    """
    Algorithm to improve treatment of high frequency.

    It splits the two images into two parts: high frequency and low frequency.
    The two low frequency parts are stitched using a 'sinusoidal' / cosinus
    weights approach while the two high frequency parts are stitched by taking
    the lower divergent pixels.

    :param img_1: first 2D image to stitch
    :param img_2: second 2D image to stitch - must have the same shape as img_1
    :param high_frequency_threshold: cutoff used to split low / high frequencies
    :param stitching_axis: axis along which the stitching is done (0 or 1)
    :raises ValueError: if stitching_axis is not 0 or 1
    """

    def split_image(image: numpy.ndarray, threshold: int) -> tuple:
        """split an image to return (low_frequency, high_frequency)"""
        lowpass_filter = fourier_filters.get_lowpass_filter(
            image.shape[-2:],
            cutoff_par=threshold,
            use_rfft=True,
            data_type=image.dtype,
        )
        highpass_filter = fourier_filters.get_highpass_filter(
            image.shape[-2:],
            cutoff_par=threshold,
            use_rfft=True,
            data_type=image.dtype,
        )
        low_fre_part = local_ifftn(local_fftn(image, axes=(-2, -1)) * lowpass_filter, axes=(-2, -1)).real
        high_fre_part = local_ifftn(local_fftn(image, axes=(-2, -1)) * highpass_filter, axes=(-2, -1)).real
        return (low_fre_part, high_fre_part)

    low_freq_img_1, high_freq_img_1 = split_image(img_1, threshold=high_frequency_threshold)
    low_freq_img_2, high_freq_img_2 = split_image(img_2, threshold=high_frequency_threshold)

    # handle low frequency: blend with cosinus weights over the full overlap
    if stitching_axis == 0:
        frame_cst_size = img_1.shape[1]
        overlap_size = img_1.shape[0]
    elif stitching_axis == 1:
        frame_cst_size = img_1.shape[0]
        overlap_size = img_1.shape[1]
    else:
        # BUG FIX: the original raised `ValueError("")` with an empty message
        raise ValueError(f"stitching_axis should be in (0, 1). Get {stitching_axis}")
    low_freq_stitching_kernel = ImageStichOverlapKernel(
        frame_unstitched_axis_size=frame_cst_size,
        stitching_strategy=OverlapStitchingStrategy.COSINUS_WEIGHTS,
        overlap_size=overlap_size,
        stitching_axis=stitching_axis,
    )
    low_freq_stitched = low_freq_stitching_kernel.stitch(
        img_1=low_freq_img_1,
        img_2=low_freq_img_2,
        check_input=False,
    )[0]

    # handle high frequency: per pixel, keep the value closest to the global mean
    mean_high_frequency = numpy.mean([high_freq_img_1, high_freq_img_2])
    assert numpy.isscalar(mean_high_frequency)
    high_freq_distance_img_1 = numpy.abs(high_freq_img_1 - mean_high_frequency)
    high_freq_distance_img_2 = numpy.abs(high_freq_img_2 - mean_high_frequency)
    # BUG FIX: the original selected the *distances* themselves; the documented
    # strategy is to keep the less divergent pixel *values*
    high_freq_stitched = numpy.where(
        high_freq_distance_img_1 >= high_freq_distance_img_2,
        high_freq_img_2,
        high_freq_img_1,
    )

    # merge back low and high frequencies together
    return low_freq_stitched + high_freq_stitched
def compute_image_higher_signal(img_1: numpy.ndarray, img_2: numpy.ndarray):
    """
    the higher signal will pick pixel on the image having the higher signal.
    A use case is that if there is some artefacts on images which creates stripes
    (from scintillator artefacts for example) it could be removed from this method
    """
    # note: to be think about. But maybe it can be interesting to rescale img_1 and img_2
    # to ge something more coherent
    keep_first = img_1 >= img_2
    return numpy.where(keep_first, img_1, img_2)
def check_overlaps(frames: Union[tuple, numpy.ndarray], positions: tuple, axis: int, raise_error: bool):
    """
    check over frames if there is a single overlap over juxtaposed frames (at most and at least)

    :param frames: list of ordered / sorted frames along axis to test (from higher to lower)
    :param positions: positions of frames in 3D space as (position axis 0, position axis 1, position axis 2)
    :param axis: axis to check
    :param raise_error: if True then raise an error if two frames don't have at least and at most one overlap. Else log an error
    """
    if not isinstance(frames, (tuple, numpy.ndarray)):
        raise TypeError(f"frames is expected to be a tuple or a numpy array. Get {type(frames)} instead")
    # BUG FIX: original boolean (`not isinstance(...) and len(...) == 3`) could never
    # raise for a non-tuple and crashed on inputs without len()
    if not isinstance(positions, tuple):
        raise TypeError(f"positions is expected to be a tuple of 3 elements. Get {type(positions)} instead")
    assert isinstance(axis, int), "axis is expected to be an int"
    assert isinstance(raise_error, bool), "raise_error is expected to be a bool"

    def treat_error(error_msg: str):
        # either raise or only log the issue depending on `raise_error`
        if raise_error:
            raise ValueError(error_msg)
        else:
            # BUG FIX: the original logged `raise_error` (a bool) instead of the message
            _logger.error(error_msg)

    if axis == 0:
        axis_frame_space = 0
    elif axis == 1:
        axis_frame_space = 1
    elif axis == 2:
        # BUG FIX: the original formatted the (still unbound) `axis_frame_space` here
        raise NotImplementedError(f"overlap check along axis {axis}")
    else:
        # BUG FIX: any other axis value left `axis_frame_space` undefined
        raise ValueError(f"axis should be in (0, 1, 2). Get {axis}")

    # convert each frame to appropriate bounding box according to the axis
    def convert_to_bb(frame: numpy.ndarray, position: tuple, axis: int):
        """build the 1D bounding box of `frame` centered on `position` along `axis`"""
        assert isinstance(axis, int)
        assert isinstance(position, tuple), f"position expected a tuple. Get {type(position)} instead"
        assert len(position) == 3, f"Expect to have three items for the position. Get {len(position)}"
        start_frame = position[axis] - frame.shape[axis_frame_space] // 2
        end_frame = start_frame + frame.shape[axis_frame_space]
        return BoundingBox1D(start_frame, end_frame)

    bounding_boxes = {
        convert_to_bb(frame=frame, position=position, axis=axis): position
        for frame, position in zip(frames, positions)
    }

    def get_frame_index(my_bb) -> str:
        """human readable ('1st', '2nd', ...) index of the frame owning `my_bb`"""
        bb_index = tuple(bounding_boxes.keys()).index(my_bb) + 1
        if bb_index in (1, 21, 31):
            return f"{bb_index}st"
        elif bb_index in (2, 22, 32):
            return f"{bb_index}nd"
        elif bb_index in (3, 23, 33):
            # BUG FIX: the original compared `bb_index == (3, 23, 33)` - always False
            return f"{bb_index}rd"
        else:
            return f"{bb_index}th"

    # check that theres an overlap between two juxtaposed bb (or frame at the end)
    all_bounding_boxes = tuple(bounding_boxes.keys())
    bb_with_expected_overlap = [
        (bb_frame, bb_next_frame)
        for bb_frame, bb_next_frame in zip(all_bounding_boxes[:-1], all_bounding_boxes[1:])
    ]
    for bb_pair in bb_with_expected_overlap:
        bb_frame, bb_next_frame = bb_pair
        if bb_frame.max < bb_next_frame.min:
            treat_error(f"provided frames seems un sorted (from the higher to the lower)")
        if bb_frame.min < bb_next_frame.min:
            treat_error(
                f"Seems like {get_frame_index(bb_frame)} frame is fully overlaping with frame {get_frame_index(bb_next_frame)}"
            )
        if bb_frame.get_overlap(bb_next_frame) is None:
            treat_error(
                f"no overlap found between two juxtaposed frames - {get_frame_index(bb_frame)} and {get_frame_index(bb_next_frame)}"
            )

    # check there is no overlap between none juxtaposed bb
    def pick_all_none_juxtaposed_bb(index, my_bounding_boxes: tuple):
        """return all the bounding boxes to check for the index 'index':
        :return: (tested_bounding_box, bounding_boxes_to_test)
        """
        indexed_bbs = dict(enumerate(my_bounding_boxes))
        # direct neighbours (index - 1, index + 1) are allowed - even expected - to overlap
        not_juxtaposed = dict(
            filter(
                lambda pair: pair[0] not in (index - 1, index, index + 1),
                indexed_bbs.items(),
            )
        )
        return indexed_bbs[index], not_juxtaposed.values()

    bb_without_expected_overlap = [
        pick_all_none_juxtaposed_bb(index, all_bounding_boxes) for index in range(len(all_bounding_boxes))
    ]
    for bb_pair in bb_without_expected_overlap:
        bb_frame, bb_not_juxtaposed_frames = bb_pair
        for bb_not_juxtaposed_frame in bb_not_juxtaposed_frames:
            if bb_frame.get_overlap(bb_not_juxtaposed_frame) is not None:
                treat_error(
                    f"overlap found between two frames not juxtaposed - {bounding_boxes[bb_frame]} and {bounding_boxes[bb_not_juxtaposed_frame]}"
                )