Source code for nabu.stitching.overlap

# coding: utf-8
# /*##########################################################################
#
# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
# ###########################################################################*/

__authors__ = ["H. Payno"]
__license__ = "MIT"
__date__ = "10/05/2022"


import numpy
import logging
from typing import Optional, Union
from silx.utils.enum import Enum as _Enum
from nabu.misc import fourier_filters
from scipy.fft import rfftn as local_fftn
from scipy.fft import irfftn as local_ifftn
from tomoscan.utils.geometry import BoundingBox1D

_logger = logging.getLogger(__name__)


class OverlapStitchingStrategy(_Enum):
    """Policies available to merge the overlapping area of two frames."""

    # both images weighted by 0.5 over the whole overlap
    MEAN = "mean"
    # img_1 weighted by cos^2 and img_2 by sin^2, angle going from 0 to pi/2 along the overlap
    COSINUS_WEIGHTS = "cosinus weights"
    # linear ramp: img_1 weight goes from 1 to 0 along the overlap, img_2 the complement
    LINEAR_WEIGHTS = "linear weights"
    # first half of the overlap (rounded up) taken from img_1, the rest from img_2
    CLOSEST = "closest"
    # low frequencies merged with cosinus weights, high frequencies picked per pixel
    IMAGE_MINIMUM_DIVERGENCE = "image minimum divergence"
    # per pixel, keep the image with the highest value
    HIGHER_SIGNAL = "higher signal"
# default policy used to merge the overlapping area of two frames
DEFAULT_OVERLAP_STRATEGY = OverlapStitchingStrategy.COSINUS_WEIGHTS
# default overlap size to be taken for stitching. If None: take the largest possible area.
# Could also be an int.
# NOTE(review): ZStichOverlapKernel.__init__ only accepts an int overlap_size, so callers
# presumably resolve None to an actual size before building a kernel - confirm.
DEFAULT_OVERLAP_SIZE = None
class OverlapKernelBase:
    """Common base class of the overlap (stitching) kernels. It defines no behavior by itself."""
class ZStichOverlapKernel(OverlapKernelBase):
    """
    class used to define overlap between two scans and create stitch between frames (`stitch` function)
    """

    DEFAULT_HIGH_FREQUENCY_THRESHOLD = 2

    def __init__(
        self,
        frame_width: int,
        stitching_strategy: OverlapStitchingStrategy = DEFAULT_OVERLAP_STRATEGY,
        overlap_size: int = DEFAULT_OVERLAP_SIZE,
        extra_params: Optional[dict] = None,
    ) -> None:
        """
        :param frame_width: width of the frames to stitch. Weights are broadcast along this
            dimension, giving (overlap_size, frame_width) weight arrays.
        :param stitching_strategy: policy used to merge the overlap (converted with
            `OverlapStitchingStrategy.from_value`).
        :param overlap_size: number of rows of the overlap region. Must be a strictly positive int.
        :param extra_params: optional dict. Only the KEY_THRESHOLD_FREQUENCY entry is read here
            (high frequency threshold, used by the 'image minimum divergence' strategy).
        :raises TypeError: if overlap_size or frame_width is not a strictly positive int
        """
        from nabu.stitching.config import KEY_THRESHOLD_FREQUENCY  # avoid acylic import

        # `or` (and not `and`): reject non-int values as well as non-positive ints,
        # consistently with the frame_width check below
        if not isinstance(overlap_size, int) or not overlap_size > 0:
            raise TypeError(
                f"overlap_size is expected to be a positive int, not {overlap_size} ({type(overlap_size)})"
            )
        if not isinstance(frame_width, int) or not frame_width > 0:
            raise TypeError(
                f"frame_width is expected to be a positive int, not {frame_width} ({type(frame_width)})"
            )

        self._overlap_size = abs(overlap_size)
        self._frame_width = frame_width
        self._stitching_strategy = OverlapStitchingStrategy.from_value(stitching_strategy)
        self._weights_img_1 = None
        self._weights_img_2 = None
        # backing attribute of the `img_2` property (nothing in this class assigns it);
        # initialised so that reading the property does not raise AttributeError
        self._img_2 = None
        if extra_params is None:
            extra_params = {}
        self._high_frequency_threshold = extra_params.get(
            KEY_THRESHOLD_FREQUENCY, self.DEFAULT_HIGH_FREQUENCY_THRESHOLD
        )

    def __str__(self) -> str:
        return f"z-stitching kernel (policy={self.stitching_strategy.value}, overlap_size={self.overlap_size}, frame={self._frame_width})"

    @staticmethod
    def __check_img(img, name):
        """raise a ValueError if `img` is not a 2D numpy array"""
        # `or`: short-circuits before `.ndim` for non-array inputs
        if not isinstance(img, numpy.ndarray) or img.ndim != 2:
            raise ValueError(f"{name} is expected to be 2D numpy array")

    @property
    def overlap_size(self) -> int:
        return self._overlap_size

    @overlap_size.setter
    def overlap_size(self, size: int):
        if not isinstance(size, int):
            raise TypeError(f"overlap_size expects a int ({type(size)} provided instead)")
        if not size >= 0:
            raise ValueError("overlap_size is expected to be positive")
        self._overlap_size = abs(size)
        # weights depend on the overlap size: refresh them if they were already computed
        if self._weights_img_1 is not None or self._weights_img_2 is not None:
            self.compute_weights()

    @property
    def img_2(self) -> numpy.ndarray:
        return self._img_2

    @property
    def weights_img_1(self) -> Optional[numpy.ndarray]:
        return self._weights_img_1

    @property
    def weights_img_2(self) -> Optional[numpy.ndarray]:
        return self._weights_img_2

    @property
    def stitching_strategy(self) -> OverlapStitchingStrategy:
        return self._stitching_strategy

    def compute_weights(self):
        """
        Compute the weights applied to img_1 and img_2 by the constant-weight strategies
        (MEAN, CLOSEST, LINEAR_WEIGHTS, COSINUS_WEIGHTS) and store them as
        (overlap_size, frame_width) arrays. One weight is defined per overlap row and repeated
        along the frame width.

        IMAGE_MINIMUM_DIVERGENCE and HIGHER_SIGNAL do not use weights: nothing is computed.

        :raises NotImplementedError: if the strategy is unknown
        """
        if self.stitching_strategy is OverlapStitchingStrategy.MEAN:
            weights_img_1 = numpy.ones(self._overlap_size) * 0.5
            weights_img_2 = weights_img_1[::-1]
        elif self.stitching_strategy is OverlapStitchingStrategy.CLOSEST:
            # first half (rounded up) of the overlap comes from img_1, the rest from img_2
            n_item = self._overlap_size // 2 + self._overlap_size % 2
            weights_img_1 = numpy.concatenate(
                [
                    numpy.ones(n_item),
                    numpy.zeros(self._overlap_size - n_item),
                ]
            )
            weights_img_2 = numpy.concatenate(
                [
                    numpy.zeros(n_item),
                    numpy.ones(self._overlap_size - n_item),
                ]
            )
        elif self.stitching_strategy is OverlapStitchingStrategy.LINEAR_WEIGHTS:
            weights_img_1 = numpy.linspace(1.0, 0.0, self._overlap_size)
            # complement rather than reverse: for a single-row overlap linspace returns [1.0]
            # and reversing it would give both images a weight of 1 (doubling the signal)
            weights_img_2 = 1.0 - weights_img_1
        elif self.stitching_strategy is OverlapStitchingStrategy.COSINUS_WEIGHTS:
            angles = numpy.linspace(0.0, numpy.pi / 2.0, self._overlap_size)
            weights_img_1 = numpy.cos(angles) ** 2
            weights_img_2 = numpy.sin(angles) ** 2
        elif self.stitching_strategy in (
            OverlapStitchingStrategy.IMAGE_MINIMUM_DIVERGENCE,
            OverlapStitchingStrategy.HIGHER_SIGNAL,
        ):
            # those strategies are not using constant weights but have treatments depending on the provided img_1 and img_2 during stitching
            return
        else:
            raise NotImplementedError(f"{self.stitching_strategy} not implemented")
        # broadcast the per-row weights along the frame width
        self._weights_img_1 = weights_img_1.reshape(-1, 1) * numpy.ones(self._frame_width).reshape(1, -1)
        self._weights_img_2 = weights_img_2.reshape(-1, 1) * numpy.ones(self._frame_width).reshape(1, -1)

    def stitch(self, img_1, img_2, check_input=True) -> tuple:
        """
        Compute overlap region from the defined strategy.

        :param img_1: first overlap region
        :param img_2: second overlap region (same shape as img_1)
        :param check_input: if True, check that both inputs are 2D numpy arrays of the same shape
        :return: (stitched region, weights_img_1, weights_img_2). Both weights are None for the
            IMAGE_MINIMUM_DIVERGENCE and HIGHER_SIGNAL strategies.
        :raises ValueError: if check_input is True and the inputs are invalid
        """
        if check_input:
            self.__check_img(img_1, "img_1")
            self.__check_img(img_2, "img_2")
            if img_1.shape != img_2.shape:
                raise ValueError(
                    f"both images are expected to be of the same shape to apply stitch ({img_1.shape} vs {img_2.shape})"
                )

        if self._stitching_strategy is OverlapStitchingStrategy.IMAGE_MINIMUM_DIVERGENCE:
            return (
                compute_image_minimum_divergence(
                    img_1=img_1, img_2=img_2, high_frequency_threshold=self._high_frequency_threshold
                ),
                None,
                None,
            )
        elif self._stitching_strategy is OverlapStitchingStrategy.HIGHER_SIGNAL:
            return (
                compute_image_higher_signal(img_1=img_1, img_2=img_2),
                None,
                None,
            )
        else:
            # weights are computed lazily on first use
            if self.weights_img_1 is None or self.weights_img_2 is None:
                self.compute_weights()
            return (
                img_1 * self.weights_img_1 + img_2 * self.weights_img_2,
                self.weights_img_1,
                self.weights_img_2,
            )
def compute_image_minimum_divergence(img_1: numpy.ndarray, img_2: numpy.ndarray, high_frequency_threshold):
    """
    Algorithm to improve treatment of high frequency.

    It splits the two images into two parts: high frequency and low frequency.
    The two low frequency parts are stitched using a 'sinusoidal' / cosinus weights approach.
    The two high frequency parts are stitched by taking, for each pixel, the value of the image
    whose high frequency part diverges the least from the mean of both high frequency parts.

    :param img_1: first overlap region. Expected 2D: img_1.shape[0] is used as overlap size and
        img_1.shape[1] as frame width.
    :param img_2: second overlap region, same shape as img_1
    :param high_frequency_threshold: cut-off given as `cutoff_par` to the nabu fourier
        low-pass / high-pass filters
    :return: stitched overlap region
    """

    # split low and high frequencies
    def split_image(image: numpy.ndarray, threshold: int) -> tuple:
        """split an image to return (low_frequency, high_frequency)"""
        spatial_shape = image.shape[-2:]
        lowpass_filter = fourier_filters.get_lowpass_filter(
            spatial_shape,
            cutoff_par=threshold,
            use_rfft=True,
            data_type=image.dtype,
        )
        highpass_filter = fourier_filters.get_highpass_filter(
            spatial_shape,
            cutoff_par=threshold,
            use_rfft=True,
            data_type=image.dtype,
        )
        # forward transform computed once, shared by both filters
        image_fft = local_fftn(image, axes=(-2, -1))
        # `s` restores the original size of the last axis: without it irfftn returns an even
        # length and drops one column for odd-width images
        low_fre_part = local_ifftn(image_fft * lowpass_filter, s=spatial_shape, axes=(-2, -1)).real
        high_fre_part = local_ifftn(image_fft * highpass_filter, s=spatial_shape, axes=(-2, -1)).real
        return (low_fre_part, high_fre_part)

    low_freq_img_1, high_freq_img_1 = split_image(img_1, threshold=high_frequency_threshold)
    low_freq_img_2, high_freq_img_2 = split_image(img_2, threshold=high_frequency_threshold)

    # handle low frequency
    low_freq_stitching_kernel = ZStichOverlapKernel(
        frame_width=img_1.shape[1],
        stitching_strategy=OverlapStitchingStrategy.COSINUS_WEIGHTS,
        overlap_size=img_1.shape[0],
    )
    low_freq_stitched = low_freq_stitching_kernel.stitch(
        img_1=low_freq_img_1,
        img_2=low_freq_img_2,
        check_input=False,
    )[0]

    # handle high frequency
    mean_high_frequency = numpy.mean([high_freq_img_1, high_freq_img_2])
    assert numpy.isscalar(mean_high_frequency)
    high_freq_distance_img_1 = numpy.abs(high_freq_img_1 - mean_high_frequency)
    high_freq_distance_img_2 = numpy.abs(high_freq_img_2 - mean_high_frequency)
    # keep the high frequency *values* (not the distances) of the least divergent image.
    # On ties img_2 is kept.
    high_freq_stitched = numpy.where(
        high_freq_distance_img_1 >= high_freq_distance_img_2,
        high_freq_img_2,
        high_freq_img_1,
    )

    # merge back low and high frequencies together
    def merge_images(low_freq: numpy.ndarray, high_freq: numpy.ndarray) -> numpy.ndarray:
        """merge two part of an image. The low frequency part with the high frequency part"""
        return low_freq + high_freq

    return merge_images(low_freq_stitched, high_freq_stitched)
def compute_image_higher_signal(img_1: numpy.ndarray, img_2: numpy.ndarray):
    """
    Pixel-wise selection of the image having the higher signal.

    Use case: removing artefacts that create stripes on one of the images
    (scintillator artefacts for example).
    Ties (img_1 == img_2) keep img_1. Where the comparison is False (including NaN), img_2 is kept.
    """
    # note: to be thought about. Rescaling img_1 and img_2 beforehand might give
    # something more coherent
    keep_first = img_1 >= img_2
    return numpy.where(keep_first, img_1, img_2)
def check_overlaps(frames: Union[tuple, numpy.ndarray], positions: tuple, axis: int, raise_error: bool):
    """
    Check that each frame overlaps its juxtaposed frames (at least and at most one overlap with
    the next frame), and that it does not overlap any non-juxtaposed frame.

    :param frames: tuple or numpy array of ordered / sorted frames along axis to test (from higher to lower)
    :param positions: position of each frame in 3D space as (position axis 0, position axis 1, position axis 2)
    :param axis: axis to check
    :param raise_error: if True, raise a ValueError on the first inconsistency found. Else log an error
        for each inconsistency.
    :raises TypeError: if frames or positions are not of the expected type
    :raises ValueError: if an inconsistency is found and raise_error is True
    """
    if not isinstance(frames, (tuple, numpy.ndarray)):
        raise TypeError(f"frames is expected to be a tuple or a numpy array. Get {type(frames)} instead")
    if not isinstance(positions, tuple):
        raise TypeError(f"positions is expected to be a tuple. Get {type(positions)} instead")
    assert isinstance(axis, int), "axis is expected to be an int"
    assert isinstance(raise_error, bool), "raise_error is expected to be a bool"

    def treat_error(error_msg: str):
        if raise_error:
            raise ValueError(error_msg)
        # log the message itself (not the raise_error flag)
        _logger.error(error_msg)

    # convert each frame to appropriate bounding box according to the axis
    def convert_to_bb(frame: numpy.ndarray, position: tuple, axis: int):
        """return the 1D bounding box covered by `frame` along `axis`, centered on position[axis]"""
        assert isinstance(axis, int)
        assert isinstance(position, tuple), f"position expected a tuple. Get {type(position)} instead"
        start_frame = position[axis] - frame.shape[axis] // 2
        end_frame = start_frame + frame.shape[axis]
        return BoundingBox1D(start_frame, end_frame)

    # bounding box -> frame position (equal bounding boxes collapse into a single entry)
    bounding_boxes = {
        convert_to_bb(frame=frame, position=position, axis=axis): position
        for frame, position in zip(frames, positions)
    }
    all_bounding_boxes = tuple(bounding_boxes.keys())

    def get_frame_index(my_bb) -> str:
        """return the 1-based ordinal ('1st', '2nd', '3rd', '11th', ...) of `my_bb` for error messages"""
        bb_index = all_bounding_boxes.index(my_bb) + 1
        if 10 <= bb_index % 100 <= 20:
            suffix = "th"
        else:
            suffix = {1: "st", 2: "nd", 3: "rd"}.get(bb_index % 10, "th")
        return f"{bb_index}{suffix}"

    # check that there is an overlap between two juxtaposed bounding boxes
    for bb_frame, bb_next_frame in zip(all_bounding_boxes[:-1], all_bounding_boxes[1:]):
        if bb_frame.max < bb_next_frame.min:
            treat_error("provided frames seems un sorted (from the higher to the lower)")
        if bb_frame.min < bb_next_frame.min:
            treat_error(
                f"Seems like {get_frame_index(bb_frame)} frame is fully overlaping with frame {get_frame_index(bb_next_frame)}"
            )
        if bb_frame.get_overlap(bb_next_frame) is None:
            treat_error(
                f"no overlap found between two juxtaposed frames - {get_frame_index(bb_frame)} and {get_frame_index(bb_next_frame)}"
            )

    # check there is no overlap between non-juxtaposed bounding boxes
    for index, bb_frame in enumerate(all_bounding_boxes):
        for other_index, bb_not_juxtaposed_frame in enumerate(all_bounding_boxes):
            if abs(other_index - index) <= 1:
                # same frame or juxtaposed frame: overlap expected
                continue
            if bb_frame.get_overlap(bb_not_juxtaposed_frame) is not None:
                treat_error(
                    f"overlap found between two frames not juxtaposed - {bounding_boxes[bb_frame]} and {bounding_boxes[bb_not_juxtaposed_frame]}"
                )