from distutils.version import StrictVersion
from typing import Optional, Union
import logging
import functools
import numpy
from scipy.ndimage import affine_transform
from tomoscan.scanbase import TomoScanBase
from tomoscan.volumebase import VolumeBase
from nxtomo.utils.transformation import build_matrix, UDDetTransformation
from silx.utils.enum import Enum as _Enum
from scipy.fft import rfftn as local_fftn
from scipy.fft import irfftn as local_ifftn
from silx.utils.enum import Enum as _Enum
from nxtomo.utils.transformation import build_matrix, UDDetTransformation
from tomoscan.scanbase import TomoScanBase
from .overlap import OverlapStitchingStrategy, ZStichOverlapKernel
from .alignment import AlignmentAxis1, AlignmentAxis2, PaddedRawData
from ..misc import fourier_filters
from ..estimation.alignment import AlignmentBase
from ..resources.dataset_analyzer import HDF5DatasetAnalyzer
from ..resources.nxflatfield import update_dataset_info_flats_darks
_logger = logging.getLogger(__name__)

# Optional dependency: ITK, used by `find_shift_with_itk`.
has_itk = True
try:
    import itk
except ImportError:
    has_itk = False

# Optional dependency: scikit-image, used for the "skimage" shift search.
__has_sk_phase_correlation__ = True
try:
    from skimage.registration import phase_cross_correlation
except ImportError:
    __has_sk_phase_correlation__ = False
    _logger.warning(
        "Unable to load skimage. Please install it if you want to use it for finding shifts from `find_relative_shifts`"
    )
[docs]
class ShiftAlgorithm(_Enum):
    """All generic shift search algorithm

    The generic algorithms (NABU_FFT, SKIMAGE, ITK_IMG_REG_V4) work on any pair of frames.
    NONE means "no refinement".
    """

    NABU_FFT = "nabu-fft"
    SKIMAGE = "skimage"
    ITK_IMG_REG_V4 = "itk-img-reg-v4"
    NONE = "None"
    # For a shift search on radios along axis 2 (x axis in image space), the existing nabu
    # center-of-rotation algorithms (growing-window, sliding-window, ...) can be reused.
    CENTERED = "centered"
    GLOBAL = "global"
    SLIDING_WINDOW = "sliding-window"
    GROWING_WINDOW = "growing-window"
    SINO_COARSE_TO_FINE = "sino-coarse-to-fine"
    COMPOSITE_COARSE_TO_FINE = "composite-coarse-to-fine"

    @classmethod
    def from_value(cls, value):
        """Return the member matching `value`; an empty string or None is mapped to NONE."""
        if value in (None, ""):
            return ShiftAlgorithm.NONE
        return super().from_value(value=value)
[docs]
def test_overlap_stitching_strategy(overlap_1, overlap_2, stitching_strategies):
    """
    Stitch the two overlaps with all the requested strategies.

    :param numpy.ndarray overlap_1: first overlap. Its dimension 1 is used as the stitcher frame width.
    :param numpy.ndarray overlap_2: second overlap
    :param stitching_strategies: iterable of OverlapStitchingStrategy members (or their values)
    :return: a dictionary with the stitching strategy value as key and a result dict as value.
        Result dict keys are: 'stitching', 'weights_overlap_1', 'weights_overlap_2'
    """
    res = {}
    for strategy in stitching_strategies:
        s = OverlapStitchingStrategy.from_value(strategy)
        stitcher = ZStichOverlapKernel(
            stitching_strategy=s,
            frame_width=overlap_1.shape[1],
        )
        stitched_overlap, w1, w2 = stitcher.stitch(overlap_1, overlap_2, check_input=True)
        res[s.value] = {
            "stitching": stitched_overlap,
            "weights_overlap_1": w1,
            "weights_overlap_2": w2,
        }
    return res
[docs]
def find_frame_relative_shifts(
    overlap_upper_frame: numpy.ndarray,
    overlap_lower_frame: numpy.ndarray,
    estimated_shifts,
    x_cross_correlation_function=None,
    y_cross_correlation_function=None,
    x_shifts_params: Optional[dict] = None,
    y_shifts_params: Optional[dict] = None,
):
    """
    Refine the relative shift between two overlap frames.

    Processing steps:

    1. If a low-pass and/or high-pass cutoff is provided, both frames are band-pass filtered.
       Missing cutoffs default to low_pass=1 and high_pass=20.
    2. The requested shift search algorithm(s) are run on the (filtered) frames.
    3. The shift along axis 0 is taken from `y_cross_correlation_function` and the shift along
       axis 1 from `x_cross_correlation_function`. These corrections are added to `estimated_shifts`.

    :param numpy.ndarray overlap_upper_frame: overlap area of the upper frame
    :param numpy.ndarray overlap_lower_frame: overlap area of the lower frame (same shape as the upper one)
    :param estimated_shifts: 'a priori' shift estimation as (y_shift, x_shift)
    :param x_cross_correlation_function: ShiftAlgorithm (or its value) used to refine the x shift. None or "" means no refinement.
    :param y_cross_correlation_function: ShiftAlgorithm (or its value) used to refine the y shift. None or "" means no refinement.
    :param dict x_shifts_params: options of the x shift search. Only the low-pass / high-pass filter keys are read here.
    :param dict y_shifts_params: options of the y shift search. Only the low-pass / high-pass filter keys are read here.
    :return: (y_shift, x_shift). Each value is converted with `int()`, so it is truncated toward zero.
    :rtype: tuple
    :raises ValueError: if a requested algorithm is not handled by this function
    """
    from nabu.stitching.config import (
        KEY_LOW_PASS_FILTER,
        KEY_HIGH_PASS_FILTER,
    )  # avoid cyclic import

    x_cross_correlation_function = ShiftAlgorithm.from_value(x_cross_correlation_function)
    y_cross_correlation_function = ShiftAlgorithm.from_value(y_cross_correlation_function)
    if x_shifts_params is None:
        x_shifts_params = {}
    if y_shifts_params is None:
        y_shifts_params = {}

    def _str_to_int(value):
        # cutoffs may be provided as (possibly quoted) strings, e.g. from a configuration file
        if isinstance(value, str):
            value = value.lstrip("'").lstrip('"')
            value = value.rstrip("'").rstrip('"')
            value = int(value)
        return value

    # apply filtering if any. A single filter is used for both searches; x parameters take precedence over y ones.
    low_pass = _str_to_int(x_shifts_params.get(KEY_LOW_PASS_FILTER, y_shifts_params.get(KEY_LOW_PASS_FILTER, None)))
    high_pass = _str_to_int(x_shifts_params.get(KEY_HIGH_PASS_FILTER, y_shifts_params.get(KEY_HIGH_PASS_FILTER, None)))
    if low_pass is not None or high_pass is not None:
        if low_pass is None:
            low_pass = 1
        if high_pass is None:
            high_pass = 20
        _logger.info(f"filter image for shift search (low_pass={low_pass}, high_pass={high_pass})")
        upper_shape = overlap_upper_frame.shape[-2:]
        lower_shape = overlap_lower_frame.shape[-2:]
        img_filter = fourier_filters.get_bandpass_filter(
            upper_shape,
            cutoff_lowpass=low_pass,
            cutoff_highpass=high_pass,
            use_rfft=True,
            data_type=overlap_upper_frame.dtype,
        )
        # `s` must be given to irfftn: otherwise the last axis of the output has length 2 * (m - 1),
        # where m is the rfft length. An odd-width frame would then lose one column.
        overlap_upper_frame = local_ifftn(
            local_fftn(overlap_upper_frame, axes=(-2, -1)) * img_filter, s=upper_shape, axes=(-2, -1)
        ).real
        overlap_lower_frame = local_ifftn(
            local_fftn(overlap_lower_frame, axes=(-2, -1)) * img_filter, s=lower_shape, axes=(-2, -1)
        ).real

    # compute shifts
    initial_shifts = numpy.array(estimated_shifts).copy()

    def skimage_proxy(img1, img2):
        if not __has_sk_phase_correlation__:
            raise ValueError("scikit-image not installed. Cannot do phase correlation from it")
        else:
            found_shift, _, _ = phase_cross_correlation(reference_image=img1, moving_image=img2, space="real")
            return -found_shift

    # every entry is a callable without arguments returning (y_shift, x_shift)
    shift_methods = {
        ShiftAlgorithm.NABU_FFT: functools.partial(
            find_shift_correlate, img1=overlap_upper_frame, img2=overlap_lower_frame
        ),
        ShiftAlgorithm.SKIMAGE: functools.partial(skimage_proxy, img1=overlap_upper_frame, img2=overlap_lower_frame),
        ShiftAlgorithm.ITK_IMG_REG_V4: functools.partial(
            find_shift_with_itk, img1=overlap_upper_frame, img2=overlap_lower_frame
        ),
        ShiftAlgorithm.NONE: functools.partial(lambda: (0.0, 0.0)),
    }

    res_algo = {}
    # each distinct algorithm is run only once, even if it is requested for both x and y
    for shift_alg in set((x_cross_correlation_function, y_cross_correlation_function)):
        if shift_alg not in shift_methods:
            raise ValueError(f"requested image alignment function not handled ({shift_alg})")
        try:
            res_algo[shift_alg] = shift_methods[shift_alg]()
        except Exception as e:
            # best effort: a failing algorithm falls back on "no correction" instead of aborting
            _logger.error(f"Failed to find shift from {shift_alg.value}. Error is {e}")
            res_algo[shift_alg] = (0, 0)

    extra_shifts = (
        res_algo[y_cross_correlation_function][0],
        res_algo[x_cross_correlation_function][1],
    )
    final_rel_shifts = numpy.array(extra_shifts) + initial_shifts
    return tuple([int(shift) for shift in final_rel_shifts])
[docs]
def find_volumes_relative_shifts(
    upper_volume: VolumeBase,
    lower_volume: VolumeBase,
    estimated_shifts,
    dim_axis_1: int,
    dtype,
    flip_ud_upper_frame: bool = False,
    flip_ud_lower_frame: bool = False,
    slice_for_shift: Union[int, str] = "middle",
    x_cross_correlation_function=None,
    y_cross_correlation_function=None,
    x_shifts_params: Optional[dict] = None,
    y_shifts_params: Optional[dict] = None,
    alignment_axis_2="center",
    alignment_axis_1="center",
):
    """
    Deduce the relative shift between two volumes from one slice (along axis 1) of each.

    Each volume is placed inside a virtual axis 1 of size `dim_axis_1`, according to `alignment_axis_1`.
    Both volumes are then sliced at the same index; an index outside of a volume gives an empty frame.
    The overlap areas (along axis 0) of the two slices are cropped to a common width along axis 2,
    according to `alignment_axis_2`, and given to `find_frame_relative_shifts`.

    :param VolumeBase upper_volume: reference volume
    :param VolumeBase lower_volume: volume for which the shift relative to `upper_volume` is computed
    :param tuple estimated_shifts: 'a priori' shift estimation as (y_shift, x_shift)
    :param int dim_axis_1: axis 1 dimension (to handle axis 1 alignment)
    :param dtype: data type of the empty frames returned for indices outside of a volume
    :param bool flip_ud_upper_frame: flip the upper slice up-down before searching the shift
    :param bool flip_ud_lower_frame: flip the lower slice up-down before searching the shift
    :param Union[int,str] slice_for_shift: index along axis 1 of the slice to use, or one of ('first', 'middle', 'last')
    :param x_cross_correlation_function: see `find_frame_relative_shifts`
    :param y_cross_correlation_function: see `find_frame_relative_shifts`
    :param dict x_shifts_params: see `find_frame_relative_shifts`
    :param dict y_shifts_params: see `find_frame_relative_shifts`. Also provides the window size
        (KEY_WINDOW_SIZE, default 400) used to crop the overlap area along axis 0.
    :param alignment_axis_2: AlignmentAxis2 (or its value) used to crop frames of different widths
    :param alignment_axis_1: AlignmentAxis1 (or its value) used to place each volume along axis 1
    :return: (y_shift, x_shift) as returned by `find_frame_relative_shifts`
    :raises ValueError: on an unknown `slice_for_shift` string, an unhandled axis 2 alignment or inconsistent overlaps
    :raises TypeError: on an unhandled axis 1 alignment
    """
    if y_shifts_params is None:
        y_shifts_params = {}
    if x_shifts_params is None:
        x_shifts_params = {}

    alignment_axis_2 = AlignmentAxis2.from_value(alignment_axis_2)
    alignment_axis_1 = AlignmentAxis1.from_value(alignment_axis_1)
    assert dim_axis_1 > 0, "dim_axis_1 <= 0"

    if isinstance(slice_for_shift, str):
        if slice_for_shift == "first":
            slice_for_shift = 0
        elif slice_for_shift == "last":
            # last valid index: dim_axis_1 itself is out of range
            slice_for_shift = dim_axis_1 - 1
        elif slice_for_shift == "middle":
            slice_for_shift = dim_axis_1 // 2
        else:
            raise ValueError("invalid slice provided to search shift", slice_for_shift)

    def get_slice_along_axis_1(volume: VolumeBase, index: int):
        """Return the slice `index` of the virtual axis 1, or an empty frame if it falls outside of `volume`."""
        assert isinstance(index, int), f"index should be an int, {type(index)} provided"
        volume_shape = volume.get_volume_shape()
        if alignment_axis_1 is AlignmentAxis1.BACK:
            front_empty_width = dim_axis_1 - volume_shape[1]
            if index < front_empty_width:
                return PaddedRawData.get_empty_frame(shape=(volume_shape[0], volume_shape[2]), dtype=dtype)
            else:
                return volume.get_slice(index=index - front_empty_width, axis=1)
        elif alignment_axis_1 is AlignmentAxis1.FRONT:
            if index >= volume_shape[1]:
                return PaddedRawData.get_empty_frame(shape=(volume_shape[0], volume_shape[2]), dtype=dtype)
            else:
                return volume.get_slice(index=index, axis=1)
        elif alignment_axis_1 is AlignmentAxis1.CENTER:
            front_empty_width = (dim_axis_1 - volume_shape[1]) // 2
            # volume data covers [front_empty_width, data_end) of the virtual axis 1
            data_end = front_empty_width + volume_shape[1]
            if index < front_empty_width or index >= data_end:
                return PaddedRawData.get_empty_frame(shape=(volume_shape[0], volume_shape[2]), dtype=dtype)
            else:
                return volume.get_slice(index=index - front_empty_width, axis=1)
        else:
            raise TypeError(f"unmanaged alignment mode {alignment_axis_1.value}")

    upper_frame = get_slice_along_axis_1(upper_volume, index=slice_for_shift)
    lower_frame = get_slice_along_axis_1(lower_volume, index=slice_for_shift)
    if flip_ud_upper_frame:
        upper_frame = numpy.flipud(upper_frame.copy())
    if flip_ud_lower_frame:
        lower_frame = numpy.flipud(lower_frame.copy())

    from nabu.stitching.config import KEY_WINDOW_SIZE  # avoid cyclic import

    # Window of about w_window_size rows around half of the estimated y overlap:
    # taken from the bottom of the upper frame and from the top of the lower frame.
    w_window_size = int(y_shifts_params.get(KEY_WINDOW_SIZE, 400))
    start_overlap = max(estimated_shifts[0] // 2 - w_window_size // 2, 0)
    end_overlap = min(estimated_shifts[0] // 2 + w_window_size // 2, min(upper_frame.shape[0], lower_frame.shape[0]))

    if start_overlap == 0:
        overlap_upper_frame = upper_frame[-end_overlap:]
    else:
        overlap_upper_frame = upper_frame[-end_overlap:-start_overlap]
    overlap_lower_frame = lower_frame[start_overlap:end_overlap]

    # align if necessary
    if overlap_upper_frame.shape[1] != overlap_lower_frame.shape[1]:
        overlap_frame_width = min(overlap_upper_frame.shape[1], overlap_lower_frame.shape[1])
        if alignment_axis_2 is AlignmentAxis2.CENTER:
            upper_frame_left_pos = overlap_upper_frame.shape[1] // 2 - overlap_frame_width // 2
            upper_frame_right_pos = upper_frame_left_pos + overlap_frame_width
            overlap_upper_frame = overlap_upper_frame[:, upper_frame_left_pos:upper_frame_right_pos]

            lower_frame_left_pos = overlap_lower_frame.shape[1] // 2 - overlap_frame_width // 2
            lower_frame_right_pos = lower_frame_left_pos + overlap_frame_width
            overlap_lower_frame = overlap_lower_frame[:, lower_frame_left_pos:lower_frame_right_pos]
        elif alignment_axis_2 is AlignmentAxis2.LEFT:
            overlap_upper_frame = overlap_upper_frame[:, :overlap_frame_width]
            overlap_lower_frame = overlap_lower_frame[:, :overlap_frame_width]
        elif alignment_axis_2 is AlignmentAxis2.RIGTH:
            overlap_upper_frame = overlap_upper_frame[:, -overlap_frame_width:]
            overlap_lower_frame = overlap_lower_frame[:, -overlap_frame_width:]
        else:
            raise ValueError(f"Alignement {alignment_axis_2.value} is not handled")

    if not overlap_upper_frame.shape == overlap_lower_frame.shape:
        raise ValueError(f"Fail to get consistant overlap ({overlap_upper_frame.shape} vs {overlap_lower_frame.shape})")

    return find_frame_relative_shifts(
        overlap_upper_frame=overlap_upper_frame,
        overlap_lower_frame=overlap_lower_frame,
        estimated_shifts=estimated_shifts,
        x_cross_correlation_function=x_cross_correlation_function,
        y_cross_correlation_function=y_cross_correlation_function,
        x_shifts_params=x_shifts_params,
        y_shifts_params=y_shifts_params,
    )
from nabu.pipeline.estimators import estimate_cor
[docs]
def find_projections_relative_shifts(
    upper_scan: TomoScanBase,
    lower_scan: TomoScanBase,
    estimated_shifts,
    flip_ud_upper_frame: bool = False,
    flip_ud_lower_frame: bool = False,
    projection_for_shift: Union[int, str] = "middle",
    invert_order: bool = False,
    x_cross_correlation_function=None,
    y_cross_correlation_function=None,
    x_shifts_params: Optional[dict] = None,
    y_shifts_params: Optional[dict] = None,
) -> tuple:
    """
    deduce the relative shift between the two scans.

    Expected behavior:

    * if `x_cross_correlation_function` is one of the nabu center-of-rotation algorithms
      (centered, global, sliding-window, growing-window, sino-coarse-to-fine, composite-coarse-to-fine),
      the a priori x shift is replaced by the difference of the CoR estimated on each scan, and no
      further x refinement is done
    * the overlap area is computed from `estimated_shifts`, on one flat-fielded projection of each scan
    * the (optional) shift search functions are called on the overlap area to polish the x and y shifts

    :param TomoScanBase upper_scan: reference scan
    :param TomoScanBase lower_scan: scan for which the shift relative to `upper_scan` is computed
    :param tuple estimated_shifts: 'a priori' shift estimation as (y_shift, x_shift). y_shift must be positive or null.
    :param bool flip_ud_upper_frame: flip the upper projection up-down, in addition to the scan detector transformations
    :param bool flip_ud_lower_frame: flip the lower projection up-down, in addition to the scan detector transformations
    :param Union[int,str] projection_for_shift: index of the projection to use (position in the sorted projection
        indices of each scan) or str. If str must be in (`middle`, `first`, `last`) or castable to an int
    :param bool invert_order: are projections inverted between the two scans (case if rotation angle are inverted)
    :param x_cross_correlation_function: ShiftAlgorithm (or its value) used to refine the x shift
    :param y_cross_correlation_function: ShiftAlgorithm (or its value) used to refine the y shift
    :param dict x_shifts_params: options of the x shift search. Also used as `cor_options` by the CoR-based algorithms.
    :param dict y_shifts_params: options of the y shift search. Provides the window size (KEY_WINDOW_SIZE, default 400).
    :return: relative shift of lower_scan with upper_scan as reference: (y_shift, x_shift)
    :rtype: tuple
    :raises ValueError: if estimated y shift is negative, `projection_for_shift` is an invalid string or overlaps are inconsistent
    :raises TypeError: if `projection_for_shift` is neither a str nor an int
    :warning: projections are flipped up-down when the detector transformations or `flip_ud_*` request it.
        The returned shift refers to the flipped frames.
    """
    if x_shifts_params is None:
        x_shifts_params = {}
    if y_shifts_params is None:
        y_shifts_params = {}

    if estimated_shifts[0] < 0:
        raise ValueError("y_overlap_px is expected to be stricktly positive")

    x_cross_correlation_function = ShiftAlgorithm.from_value(x_cross_correlation_function)
    y_cross_correlation_function = ShiftAlgorithm.from_value(y_cross_correlation_function)

    # { handle specific use case (finding shift on scan) - when using nabu COR algorithms (for axis 2)
    if x_cross_correlation_function in (
        ShiftAlgorithm.SINO_COARSE_TO_FINE,
        ShiftAlgorithm.COMPOSITE_COARSE_TO_FINE,
        ShiftAlgorithm.CENTERED,
        ShiftAlgorithm.GLOBAL,
        ShiftAlgorithm.GROWING_WINDOW,
        ShiftAlgorithm.SLIDING_WINDOW,
    ):
        cor_options = x_shifts_params.copy()
        # drop options that estimate_cor does not expect (it may call 'literal_eval' on the remaining ones)
        cor_options.pop("img_reg_method", None)

        upper_scan_dataset_info = HDF5DatasetAnalyzer(
            location=upper_scan.master_file, extra_options={"hdf5_entry": upper_scan.entry}
        )
        update_dataset_info_flats_darks(upper_scan_dataset_info, flatfield_mode=1)
        upper_scan_pos = estimate_cor(
            method=x_cross_correlation_function.value,
            dataset_info=upper_scan_dataset_info,
            cor_options=cor_options,
        )

        lower_scan_dataset_info = HDF5DatasetAnalyzer(
            location=lower_scan.master_file, extra_options={"hdf5_entry": lower_scan.entry}
        )
        update_dataset_info_flats_darks(lower_scan_dataset_info, flatfield_mode=1)
        lower_scan_pos = estimate_cor(
            method=x_cross_correlation_function.value,
            dataset_info=lower_scan_dataset_info,
            cor_options=cor_options,
        )

        estimated_shifts = tuple(
            [
                estimated_shifts[0],
                (lower_scan_pos - upper_scan_pos),
            ]
        )
        x_cross_correlation_function = ShiftAlgorithm.NONE
    # } else we will compute shift from the flat projections

    def get_flat_fielded_proj(
        scan: TomoScanBase, proj_index: int, reverse: bool, transformation_matrix: Optional[numpy.ndarray]
    ):
        """Return the flat-field corrected projection `proj_index` of `scan`, flipped up-down if requested."""
        # proj_index is a position in the sorted projection frame indices of *this* scan
        frame_index = sorted(scan.projections.keys(), reverse=reverse)[proj_index]
        ff = scan.flat_field_correction(
            (scan.projections[frame_index],),
            (frame_index,),
        )[0]
        assert ff.ndim == 2, f"expects a single 2D frame. Get something with {ff.ndim} dimensions"
        if transformation_matrix is not None:
            assert (
                transformation_matrix.ndim == 2
            ), f"expects a 2D transformation matrix. Get a {transformation_matrix.ndim} D"
            # NOTE(review): only an up-down flip (-1 at [2, 2]) is applied here.
            # Other transformations encoded in the matrix are ignored - confirm this is intended.
            if numpy.isclose(transformation_matrix[2, 2], -1):
                ff = numpy.flipud(ff)
        return ff

    if isinstance(projection_for_shift, str):
        if projection_for_shift.lower() == "first":
            projection_for_shift = 0
        elif projection_for_shift.lower() == "last":
            projection_for_shift = -1
        elif projection_for_shift.lower() == "middle":
            projection_for_shift = len(upper_scan.projections) // 2
        else:
            try:
                projection_for_shift = int(projection_for_shift)
            except ValueError:
                raise ValueError(
                    f"{projection_for_shift} cannot be cast to an int and is not one of the possible ('first', 'last', 'middle')"
                )
    elif not isinstance(projection_for_shift, (int, numpy.number)):
        raise TypeError(
            f"projection_for_shift is expected to be an int. Not {type(projection_for_shift)} - {projection_for_shift}"
        )

    upper_scan_transformations = list(upper_scan.get_detector_transformations(tuple()))
    if flip_ud_upper_frame:
        upper_scan_transformations.append(UDDetTransformation())
    upper_scan_trans_matrix = build_matrix(upper_scan_transformations)
    lower_scan_transformations = list(lower_scan.get_detector_transformations(tuple()))
    if flip_ud_lower_frame:
        lower_scan_transformations.append(UDDetTransformation())
    lower_scan_trans_matrix = build_matrix(lower_scan_transformations)

    upper_proj = get_flat_fielded_proj(
        upper_scan,
        projection_for_shift,
        reverse=False,
        transformation_matrix=upper_scan_trans_matrix,
    )
    lower_proj = get_flat_fielded_proj(
        lower_scan,
        projection_for_shift,
        reverse=invert_order,
        transformation_matrix=lower_scan_trans_matrix,
    )

    from nabu.stitching.config import KEY_WINDOW_SIZE  # avoid cyclic import

    # Window of about w_window_size rows around half of the estimated y overlap:
    # taken from the bottom of the upper projection and from the top of the lower projection.
    w_window_size = int(y_shifts_params.get(KEY_WINDOW_SIZE, 400))
    start_overlap = max(estimated_shifts[0] // 2 - w_window_size // 2, 0)
    end_overlap = min(estimated_shifts[0] // 2 + w_window_size // 2, min(upper_proj.shape[0], lower_proj.shape[0]))

    if start_overlap == 0:
        overlap_upper_frame = upper_proj[-end_overlap:]
    else:
        overlap_upper_frame = upper_proj[-end_overlap:-start_overlap]
    overlap_lower_frame = lower_proj[start_overlap:end_overlap]

    if not overlap_upper_frame.shape == overlap_lower_frame.shape:
        raise ValueError(f"Fail to get consistant overlap ({overlap_upper_frame.shape} vs {overlap_lower_frame.shape})")

    return find_frame_relative_shifts(
        overlap_upper_frame=overlap_upper_frame,
        overlap_lower_frame=overlap_lower_frame,
        estimated_shifts=estimated_shifts,
        x_cross_correlation_function=x_cross_correlation_function,
        y_cross_correlation_function=y_cross_correlation_function,
        x_shifts_params=x_shifts_params,
        y_shifts_params=y_shifts_params,
    )
[docs]
def find_shift_correlate(img1, img2, padding_mode="reflect"):
    """
    Find the shift between two images from the peak of their FFT-based cross correlation.

    :param img1: reference image
    :param img2: image to compare with `img1`
    :param str padding_mode: padding mode forwarded to `AlignmentBase._compute_correlation_fft`
    :return: the refined peak position (vertical, horizontal) with its sign inverted
    """
    aligner = AlignmentBase()
    correlation = aligner._compute_correlation_fft(
        img1,
        img2,
        padding_mode,
    )

    frame_shape = img1.shape[-2:]
    n_rows = frame_shape[-2]
    n_cols = frame_shape[-1]
    # fftfreq(n, 1 / n) yields the signed integer offset (0, 1, ..., -1) of each correlation sample
    vertical_coords = numpy.fft.fftfreq(n_rows, 1 / n_rows)
    horizontal_coords = numpy.fft.fftfreq(n_cols, 1 / n_cols)

    peak_values, peak_v, peak_h = aligner.extract_peak_region_2d(
        correlation, cc_vs=vertical_coords, cc_hs=horizontal_coords
    )
    return -aligner.refine_max_position_2d(peak_values, peak_v, peak_h)
[docs]
def _version_tuple(version: str, n_components: int = 3) -> tuple:
    """
    Convert a dotted version string such as "5.3.0" into a tuple of ints such as (5, 3, 0).

    Only the leading digits of each dot-separated component are kept. Parsing stops after the first
    component that carries a non-numeric suffix, so "5.4rc01" gives (5, 4, 0). The result is right-padded
    with zeros to at least `n_components` items, so "4.9" compares equal to "4.9.0".
    """
    import re  # local import: only needed by this helper

    numbers = []
    for component in str(version).split("."):
        match = re.match(r"[0-9]+", component)
        if match is None:
            break
        numbers.append(int(match.group()))
        if match.end() != len(component):
            # suffix such as "rc01": ignore everything after it
            break
    numbers.extend([0] * (n_components - len(numbers)))
    return tuple(numbers)


def find_shift_with_itk(img1: numpy.ndarray, img2: numpy.ndarray) -> tuple:
    """
    Find the translation between two 2D images with ITK image registration (v4 framework).

    Adapted from https://examples.itk.org/src/registration/common/perform2dtranslationregistrationwithmeansquares/documentation

    :param numpy.ndarray img1: fixed (reference) image
    :param numpy.ndarray img2: moving image, same dtype as `img1`
    :return: (y_shift, x_shift) rounded with numpy.round. (0, 0) is returned if ITK is not installed
        or is older than 4.9.0.
    :raises ValueError: if the two images do not share the same dtype or are not both 2D
    """
    if not img1.dtype == img2.dtype:
        raise ValueError("the two images are expected to have the same type")
    if not img1.ndim == img2.ndim == 2:
        raise ValueError("the two images are expected to 2D numpy arrays")
    if not has_itk:
        _logger.warning("itk is not installed. Please install it to find shift with it")
        return (0, 0)
    # compare versions numerically (distutils' StrictVersion is deprecated and removed from Python 3.12)
    if _version_tuple(itk.Version.GetITKVersion()) < (4, 9, 0):
        _logger.error("ITK 4.9.0 is required to find shift with it.")
        return (0, 0)

    pixel_type = itk.ctype("float")
    img1 = numpy.ascontiguousarray(img1, dtype=numpy.float32)
    img2 = numpy.ascontiguousarray(img2, dtype=numpy.float32)

    dimension = 2
    image_type = itk.Image[pixel_type, dimension]

    fixed_image = itk.PyBuffer[image_type].GetImageFromArray(img1)
    moving_image = itk.PyBuffer[image_type].GetImageFromArray(img2)

    transform_type = itk.TranslationTransform[itk.D, dimension]
    initial_transform = transform_type.New()

    optimizer = itk.RegularStepGradientDescentOptimizerv4.New(
        LearningRate=4,
        MinimumStepLength=0.001,
        RelaxationFactor=0.5,
        NumberOfIterations=200,
    )

    metric = itk.MeanSquaresImageToImageMetricv4[image_type, image_type].New()

    registration = itk.ImageRegistrationMethodv4.New(
        FixedImage=fixed_image,
        MovingImage=moving_image,
        Metric=metric,
        Optimizer=optimizer,
        InitialTransform=initial_transform,
    )

    # the moving image starts without any translation
    moving_initial_transform = transform_type.New()
    initial_parameters = moving_initial_transform.GetParameters()
    initial_parameters[0] = 0
    initial_parameters[1] = 0
    moving_initial_transform.SetParameters(initial_parameters)
    registration.SetMovingInitialTransform(moving_initial_transform)

    identity_transform = transform_type.New()
    identity_transform.SetIdentity()
    registration.SetFixedInitialTransform(identity_transform)

    # single resolution level, no smoothing, no shrinking
    registration.SetNumberOfLevels(1)
    registration.SetSmoothingSigmasPerLevel([0])
    registration.SetShrinkFactorsPerLevel([1])

    registration.Update()

    transform = registration.GetTransform()
    final_parameters = transform.GetParameters()
    # parameter 0 is the translation along x, parameter 1 along y: swap them to return (y, x)
    translation_along_x = final_parameters.GetElement(0)
    translation_along_y = final_parameters.GetElement(1)

    return numpy.round(translation_along_y), numpy.round(translation_along_x)