"""
nabu.pipeline.estimators: helper classes/functions to estimate parameters of a dataset
(center of rotation, detector tilt, etc).
"""
import inspect
import numpy as np
import scipy.fft # pylint: disable=E0611
from silx.io import get_data
import math
from scipy import ndimage as nd
from ..preproc.flatfield import FlatField
from ..estimation.cor import (
CenterOfRotation,
CenterOfRotationAdaptiveSearch,
CenterOfRotationSlidingWindow,
CenterOfRotationGrowingWindow,
CenterOfRotationOctaveAccurate,
)
from ..estimation.cor_sino import SinoCorInterface, CenterOfRotationFourierAngles, CenterOfRotationVo
from ..estimation.tilt import CameraTilt
from ..estimation.utils import is_fullturn_scan
from ..resources.logger import LoggerOrPrint
from ..resources.utils import extract_parameters
from ..utils import check_supported, deprecation_warning, get_num_threads, is_int, is_scalar
from ..resources.dataset_analyzer import get_radio_pair
from ..processing.rotation import Rotation
from ..preproc.ccd import Log, CCDFilter
from ..misc import fourier_filters
from .params import cor_methods, tilt_methods
[docs]
def estimate_cor(method, dataset_info, do_flatfield=True, cor_options=None, logger=None):
    """
    High level function to compute the center of rotation (COR)

    Parameters
    ----------
    method: str
        Name of the method to be used for computing the center of rotation.
        Must be one of the keys of `nabu.pipeline.params.cor_methods`.
    dataset_info: `nabu.resources.dataset_analyzer.DatasetAnalyzer`
        Dataset information structure
    do_flatfield: bool, optional
        If True, apply flat field before computing the center of rotation.
    cor_options: dict or str, optional
        Options for the estimator. A string is parsed with `extract_parameters(cor_options, sep=";")`.
        It can contain the following keys (among others):
           * slice_idx: index of the slice to use for computing the sinogram (for sinogram based algorithms)
           * subsampling: subsampling step used when building the sinogram (for sinogram based algorithms)
           * radio_angles: angles of the radios to use (for radio based algorithms)
           * take_log: whether to take the logarithm of the sinogram (for sinogram based algorithms)
    logger: Logger, optional
        logging object

    Returns
    -------
    cor: float
        Estimated center of rotation, in absolute number of pixels
        (i.e. from the center of the first pixel column).

    Raises
    ------
    ValueError
        If cor_options is a string that cannot be parsed.
    TypeError
        If cor_options is neither a dict nor a str.
    """
    logger = LoggerOrPrint(logger)
    cor_options = cor_options or {}
    check_supported(method, list(cor_methods.keys()), "COR estimation method")
    method = cor_methods[method]

    # Extract CoR parameters from configuration file
    if isinstance(cor_options, str):
        try:
            cor_options = extract_parameters(cor_options, sep=";")
        except Exception as exc:
            msg = "Could not extract parameters from cor_options: %s" % (str(exc))
            logger.fatal(msg)
            raise ValueError(msg) from exc
    elif not isinstance(cor_options, dict):
        raise TypeError(f"cor_options is expected to be a dict or a str. {type(cor_options)} provided")

    # Dispatch. COR estimation is always expressed in absolute number of pixels (i.e. from the center of the first pixel column)
    if method in CORFinder.search_methods:
        cor_finder = CORFinder(
            method,
            dataset_info,
            do_flatfield=do_flatfield,
            cor_options=cor_options,
            radio_angles=cor_options.get("radio_angles", (0.0, np.pi)),
            logger=logger,
        )
    elif method in SinoCORFinder.search_methods:
        cor_finder = SinoCORFinder(
            method,
            dataset_info,
            slice_idx=cor_options.get("slice_idx", "middle"),
            subsampling=cor_options.get("subsampling", 10),
            do_flatfield=do_flatfield,
            take_log=cor_options.get("take_log", True),
            cor_options=cor_options,
            logger=logger,
        )
    else:
        # Only forward the options that CompositeCORFinder.__init__ accepts (as keyword arguments with defaults).
        # "cor_options" and "logger" are passed explicitly below, so they are removed to avoid duplicates.
        composite_options = update_func_kwargs(CompositeCORFinder, cor_options)
        for what in ["cor_options", "logger"]:
            composite_options.pop(what, None)
        cor_finder = CompositeCORFinder(
            dataset_info,
            cor_options=cor_options,
            logger=logger,
            **composite_options,
        )
    return cor_finder.find_cor()
[docs]
class CORFinderBase:
    """
    A base class for CoR estimators.
    It does common tasks like data reading, flatfield, etc.
    """

    # Mapping {method_name: {"class": estimator_class, ...}}, to be defined by subclasses
    search_methods = {}

    def __init__(self, method, dataset_info, do_flatfield=True, cor_options=None, logger=None):
        """
        Initialize a CORFinder object.

        Parameters
        ----------
        method: str
            Name of the CoR estimation method, must be a key of `self.search_methods`.
        dataset_info: `nabu.resources.dataset_analyzer.DatasetAnalyzer`
            Dataset information structure
        do_flatfield: bool, optional
            Whether to apply flat-field normalization.
        cor_options: dict, optional
            Options of the CoR estimation. The key "side" is used to determine the lookup side.
        logger: Logger, optional
            Logging object
        """
        check_supported(method, self.search_methods, "CoR estimation method")
        self.method = method
        self.cor_options = cor_options or {}
        self.logger = LoggerOrPrint(logger)
        self.dataset_info = dataset_info
        self.do_flatfield = do_flatfield
        self.shape = dataset_info.radio_dims[::-1]
        self._get_lookup_side()
        self._init_cor_finder()

    def _get_lookup_side(self):
        """
        Get the "initial guess" where the center-of-rotation (CoR) should be estimated.
        For example 'center' means that CoR search will be done near the middle of the detector, i.e center column.

        The result is stored in `self._lookup_side`:
          - a user-provided scalar is kept as is;
          - side="from_file" is resolved with the CoR metadata of the dataset, falling back to
            "right" (half-acquisition) or "center" (otherwise) when the metadata is missing or zero;
          - any other value (string or None) is kept as is.
        """
        lookup_side = self.cor_options.get("side", None)
        self._lookup_side = lookup_side

        # User-provided scalar
        if not (isinstance(lookup_side, str)) and np.isscalar(lookup_side):
            return

        default_lookup_side = "right" if self.dataset_info.is_halftomo else "center"

        # By default in nabu config, side='from_file' meaning that we inspect the dataset information for CoR metadata
        if lookup_side == "from_file":
            initial_cor_pos = self.dataset_info.dataset_scanner.x_rotation_axis_pixel_position  # relative pos in pixels
            if initial_cor_pos is None or initial_cor_pos == 0:
                self.logger.warning("Could not get an initial estimate for center of rotation in data file")
                lookup_side = default_lookup_side
            else:
                lookup_side = initial_cor_pos
            # Store the resolved value (including the fallback), not the raw metadata which may be None/0
            self._lookup_side = lookup_side

    def _init_cor_finder(self):
        """Instantiate the estimator class registered for `self.method` in `self.search_methods`."""
        cor_finder_cls = self.search_methods[self.method]["class"]
        self.cor_finder = cor_finder_cls(verbose=False, logger=self.logger, extra_options=None)
[docs]
class CORFinder(CORFinderBase):
    """
    Find the Center of Rotation with methods based on two (180-degrees opposed) radios.
    """

    search_methods = {
        "centered": {
            "class": CenterOfRotation,
        },
        "global": {
            "class": CenterOfRotationAdaptiveSearch,
            "default_kwargs": {"low_pass": 1, "high_pass": 20},
        },
        "sliding-window": {
            "class": CenterOfRotationSlidingWindow,
        },
        "growing-window": {
            "class": CenterOfRotationGrowingWindow,
        },
        "octave-accurate": {
            "class": CenterOfRotationOctaveAccurate,
        },
    }

    def __init__(
        self, method, dataset_info, do_flatfield=True, cor_options=None, logger=None, radio_angles=(0.0, np.pi)
    ):
        """
        Initialize a CORFinder object.

        Parameters
        ----------
        method: str
            Name of the CoR estimation method, one of the keys of `CORFinder.search_methods`.
        dataset_info: `nabu.resources.dataset_analyzer.DatasetAnalyzer`
            Dataset information structure
        do_flatfield: bool, optional
            Whether to apply flat-field normalization to the two radios.
        cor_options: dict, optional
            Options forwarded to the estimator's find_shift() (only the ones it accepts).
        logger: Logger, optional
            Logging object
        radio_angles: tuple
            Angles of the two radios to use to find the cor (default: 0 and pi).

        Raises
        ------
        ValueError
            If method is "octave-accurate" and the scan is a half-acquisition one.
        """
        super().__init__(method, dataset_info, do_flatfield=do_flatfield, cor_options=cor_options, logger=logger)
        # octave-accurate does not support half-acquisition scans,
        # but information on field of view is only known here with the "dataset_info" object.
        # Do the check before reading/normalizing the radios, so that we fail fast without useless I/O.
        if self.dataset_info.is_halftomo and method == "octave-accurate":
            raise ValueError("The CoR estimator 'octave-accurate' does not support half-acquisition scans")
        self._radio_angles = radio_angles
        self._init_radios()
        self._apply_flatfield()
        self._apply_tilt()

    def _init_radios(self):
        """Read the pair of radios at self._radio_angles, along with their indices."""
        self.radios, self._radios_indices = get_radio_pair(
            self.dataset_info, radio_angles=self._radio_angles, return_indices=True
        )

    def _apply_flatfield(self):
        """Normalize self.radios in-place with the dataset flats/darks (if do_flatfield)."""
        if not (self.do_flatfield):
            return
        self.flatfield = FlatField(
            self.radios.shape,
            flats=self.dataset_info.flats,
            darks=self.dataset_info.darks,
            radios_indices=self._radios_indices,
            interpolation="linear",
        )
        self.flatfield.normalize_radios(self.radios)

    def _apply_tilt(self):
        """Rotate each radio by the detector tilt (if any is known for the dataset)."""
        tilt = self.dataset_info.detector_tilt
        if tilt is None:
            return
        self.logger.debug("COREstimator: applying detector tilt correction of %f degrees" % tilt)
        rot = Rotation(self.shape, tilt)
        for i in range(self.radios.shape[0]):
            self.radios[i] = rot.rotate(self.radios[i])

    def find_cor(self):
        """
        Find the center of rotation.

        Returns
        -------
        cor: float
            The estimated center of rotation for the current dataset.
        """
        self.logger.info("Estimating center of rotation")
        # All find_shift() methods in self.search_methods have the same API with "img_1" and "img_2"
        cor_exec_kwargs = update_func_kwargs(self.cor_finder.find_shift, self.cor_options)
        cor_exec_kwargs["return_relative_to_middle"] = False
        # ----- FIXME -----
        # 'self.cor_options' can contain 'side="from_file"', and we should not modify it directly
        # because it's entered by the user.
        # Either make a copy of self.cor_options, or change the inspect() mechanism
        if cor_exec_kwargs.get("side", None) == "from_file":
            cor_exec_kwargs["side"] = self._lookup_side or "center"
        # ------
        if self._lookup_side is not None:
            cor_exec_kwargs["side"] = self._lookup_side
        self.logger.debug("%s.find_shift(%s)" % (self.cor_finder.__class__.__name__, str(cor_exec_kwargs)))
        # The second radio is mirrored horizontally so that both images see the object from the same side
        shift = self.cor_finder.find_shift(self.radios[0], np.fliplr(self.radios[1]), **cor_exec_kwargs)
        return shift
# Backward-compatible alias: COREstimator is the very same class object as CORFinder
COREstimator = CORFinder
[docs]
class SinoCORFinder(CORFinderBase):
    """
    A class for finding Center of Rotation based on 360 degrees sinograms.
    This class handles the steps of building the sinogram from raw radios.
    """

    search_methods = {
        "sino-coarse-to-fine": {
            "class": SinoCorInterface,
        },
        "sino-sliding-window": {
            "class": CenterOfRotationSlidingWindow,
        },
        "sino-growing-window": {
            "class": CenterOfRotationGrowingWindow,
        },
        "fourier-angles": {"class": CenterOfRotationFourierAngles},
        "vo": {
            "class": CenterOfRotationVo,
        },
    }

    def __init__(
        self,
        method,
        dataset_info,
        do_flatfield=True,
        take_log=True,
        cor_options=None,
        logger=None,
        slice_idx="middle",
        subsampling=10,
    ):
        """
        Initialize a SinoCORFinder object.

        Other parameters
        ----------------
        The following keys can be set in cor_options.

        slice_idx: int or str
            Which slice index to take for building the sinogram.
            For example slice_idx=0 means that we extract the first line of each projection.
            Value can also be "first", "top", "middle", "last", "bottom".
        subsampling: int
            subsampling strategy when building sinograms.
            As building the complete sinogram from raw projections might be tedious, the reading is done with subsampling.
            A positive integer value means the subsampling step (i.e `projections[::subsampling]`).
        """
        super().__init__(method, dataset_info, do_flatfield=do_flatfield, cor_options=cor_options, logger=logger)
        self._set_slice_idx(slice_idx)
        self._set_subsampling(subsampling)
        self._load_raw_sinogram()
        self._flatfield(do_flatfield)
        self._get_sinogram(take_log)

    def _check_360(self):
        """Raise a ValueError if the scan is not a full-turn (360 degrees) scan."""
        if not is_fullturn_scan(self.dataset_info.rotation_angles):
            raise ValueError("Sinogram-based Center of Rotation estimation can only be used for 360 degrees scans")

    def _set_slice_idx(self, slice_idx):
        """Convert slice_idx (int, or one of "top"/"first"/"middle"/"bottom"/"last") into a row index."""
        n_z = self.dataset_info.radio_dims[1]
        if isinstance(slice_idx, str):
            str_to_idx = {"top": 0, "first": 0, "middle": n_z // 2, "bottom": n_z - 1, "last": n_z - 1}
            check_supported(slice_idx, str_to_idx.keys(), "slice location")
            slice_idx = str_to_idx[slice_idx]
        self.slice_idx = slice_idx

    def _set_subsampling(self, subsampling):
        """
        Set the projection indices and angles to use, keeping one projection every `subsampling`.

        Raises
        ------
        NotImplementedError
            For negative or non-integer values.
        ValueError
            For subsampling == 0.
        """
        projs_idx = sorted(self.dataset_info.projections.keys())
        self.subsampling = None
        if is_int(subsampling):
            if subsampling < 0:  # Total number of angles
                raise NotImplementedError
            if subsampling == 0:
                raise ValueError("subsampling must be a positive integer, got 0")
            self.projs_indices = projs_idx[::subsampling]
            self.angles = self.dataset_info.rotation_angles[::subsampling]
            self.subsampling = subsampling
        else:  # Angular step
            raise NotImplementedError()

    def _load_raw_sinogram(self):
        """Read one detector row (self.slice_idx) of every `subsampling`-th projection."""
        if self.slice_idx is None:
            raise ValueError("Unknow slice index")
        reader_kwargs = {
            "sub_region": (slice(None, None, self.subsampling), slice(self.slice_idx, self.slice_idx + 1), slice(None))
        }
        if self.dataset_info.kind == "edf":
            # Extend (do not replace) the reader options: without "sub_region" the whole stack would be read,
            # and the data would no longer match self.projs_indices / self.slice_idx
            reader_kwargs["n_reading_threads"] = get_num_threads()
        self.data_reader = self.dataset_info.get_reader(**reader_kwargs)
        self._radios = self.data_reader.load_data()

    def _flatfield(self, do_flatfield):
        """Apply flat-field in-place on the loaded rows, using the same row of flats/darks."""
        self.do_flatfield = bool(do_flatfield)
        if not self.do_flatfield:
            return
        flats = {k: arr[self.slice_idx : self.slice_idx + 1, :] for k, arr in self.dataset_info.flats.items()}
        darks = {k: arr[self.slice_idx : self.slice_idx + 1, :] for k, arr in self.dataset_info.darks.items()}
        flatfield = FlatField(
            self._radios.shape,
            flats,
            darks,
            radios_indices=self.projs_indices,
        )
        flatfield.normalize_radios(self._radios)

    def _get_sinogram(self, take_log):
        """Build self.sinogram (one row per loaded projection), optionally taking the logarithm."""
        sinogram = self._radios[:, 0, :].copy()
        if take_log:
            log = Log(self._radios.shape, clip_min=1e-6, clip_max=10.0)
            log.take_logarithm(sinogram)
        self.sinogram = sinogram

    @staticmethod
    def _split_sinogram(sinogram):
        """Split a sinogram into its first and second halves (same number of rows each)."""
        n_a_2 = sinogram.shape[0] // 2
        img_1, img_2 = sinogram[:n_a_2], sinogram[n_a_2:]
        # "Handle" odd number of projections
        if img_2.shape[0] > img_1.shape[0]:
            img_2 = img_2[:-1, :]
        #
        return img_1, img_2

    def find_cor(self):
        """
        Find the center of rotation from the sinogram.

        Returns
        -------
        cor: float
            The estimated center of rotation (absolute position, since return_relative_to_middle=False).
        """
        self.logger.info("Estimating center of rotation")
        cor_exec_kwargs = update_func_kwargs(self.cor_finder.find_shift, self.cor_options)
        cor_exec_kwargs["return_relative_to_middle"] = False
        # FIXME
        # 'self.cor_options' can contain 'side="from_file"', and we should not modify it directly
        # because it's entered by the user.
        # Either make a copy of self.cor_options, or change the inspect() mechanism
        # Use .get(): "side" is only present if the estimator's find_shift() has such a keyword argument
        if cor_exec_kwargs.get("side", None) == "from_file":
            cor_exec_kwargs["side"] = self._lookup_side or "center"
        #
        if self._lookup_side is not None:
            cor_exec_kwargs["side"] = self._lookup_side

        if self.method == "fourier-angles":
            cor_exec_args = [self.sinogram]
            # The sinogram was built from projections[::subsampling]: pass the matching (subsampled) angles
            cor_exec_kwargs["angles"] = self.angles
        elif self.method == "vo":
            cor_exec_args = [self.sinogram]
            cor_exec_kwargs["halftomo"] = self.dataset_info.is_halftomo
            cor_exec_kwargs["is_360"] = is_fullturn_scan(self.dataset_info.rotation_angles)
        else:
            # For these methods relying on find_shift() with two images, the sinogram needs to be split in two
            img_1, img_2 = self._split_sinogram(self.sinogram)
            cor_exec_args = [img_1, np.fliplr(img_2)]
        self.logger.debug("%s.find_shift(%s)" % (self.cor_finder.__class__.__name__, str(cor_exec_kwargs)))
        shift = self.cor_finder.find_shift(*cor_exec_args, **cor_exec_kwargs)
        return shift
# Backward-compatible alias: SinoCOREstimator is the very same class object as SinoCORFinder
SinoCOREstimator = SinoCORFinder
[docs]
class CompositeCORFinder(CORFinderBase):
    """
    Class and method to prepare sinogram and calculate COR
    The pseudo sinogram is built with shrunk radios taken every theta_interval degrees

    Compared to first writing by Christian Nemoz:
        - gives the same result of the original octave script on the datasets so far tested
        - The meaning of parameter n_subsampling_y (alias subsampling_y) is now the number of lines which are taken from
          every radio. This is more meaningful in terms of amount of collected information because it
          does not depend on the radio size. Moreover this is what was done in the octave script
        - The spike_threshold has been added with default to 0.04
        - The angular sampling is every 5 degree by default, as it is now the case also in the octave script
        - The finding of the optimal overlap is done by looping over the possible overlaps, within the search range
          set by cor_options["side"].
          After a first testing phase, this part, which is the time consuming part, can be accelerated
          by several order of magnitude without modifying the final result
    """

    search_methods = {
        "composite-coarse-to-fine": {
            "class": CenterOfRotation,  # Hack. Not used. Everything is done in the find_cor() func.
        }
    }
    _default_cor_options = {"low_pass": 0.4, "high_pass": 10, "side": "near", "near_pos": 0, "near_width": 40}

    def __init__(
        self,
        dataset_info,
        oversampling=4,
        theta_interval=5,
        n_subsampling_y=40,
        take_log=True,
        cor_options=None,
        spike_threshold=0.04,
        logger=None,
        norm_order=1,
    ):
        """
        Initialize a CompositeCORFinder object.

        Parameters
        ----------
        dataset_info: `nabu.resources.dataset_analyzer.DatasetAnalyzer`
            Dataset information structure
        oversampling: int
            Horizontal oversampling factor of the radios.
        theta_interval: float
            Angular sampling (in degrees) of the probed radios.
        n_subsampling_y: int
            Number of detector lines taken from every radio.
        take_log: bool
            Whether to take the logarithm of the (flat-fielded) radios.
        cor_options: dict, optional
            Overrides of `_default_cor_options` ("low_pass", "high_pass", "side", "near_pos", "near_width").
        spike_threshold: float or None
            Threshold of the median-clip spike correction. None disables it.
        logger: Logger, optional
            Logging object
        norm_order: int
            Norm (1 or 2) of the error metric.

        Raises
        ------
        ValueError
            If norm_order is not 1 or 2, or if the angular span is below 1.2 half-turns.
        """
        super().__init__(
            "composite-coarse-to-fine", dataset_info, do_flatfield=True, cor_options=cor_options, logger=logger
        )
        if norm_order not in [1, 2]:
            raise ValueError(f"The norm order (norm_order parameter) must be either 1 or 2. You passed {norm_order}")
        self.norm_order = norm_order
        self.sx, self.sy = self.dataset_info.radio_dims

        default_cor_options = self._default_cor_options.copy()
        default_cor_options.update(self.cor_options)
        self.cor_options = default_cor_options

        # the algorithm can work for angular ranges larger than 1.2*pi
        # up to an arbitrarily number of turns as it is the case in helical scans
        self.spike_threshold = spike_threshold
        # the following line is necessary for multi-turns scan because the encoders is always
        # in the interval 0-360
        self.unwrapped_rotation_angles = np.unwrap(self.dataset_info.rotation_angles)

        self.angle_min = self.unwrapped_rotation_angles.min()
        self.angle_max = self.unwrapped_rotation_angles.max()

        if (self.angle_max - self.angle_min) < 1.2 * np.pi:
            raise ValueError(
                f"""Sinogram-based Center of Rotation estimation can only be used for scans over more than 180 degrees.
                Your angular span was barely above 180 degrees, it was in fact {((self.angle_max - self.angle_min)/np.pi):.2f} x 180
                and it is not considered to be enough by the discriminating condition which requires at least 1.2 half-turns
                """
            )
        useful_span = min(np.pi, (self.angle_max - self.angle_min) - np.pi)
        # readapt theta_interval accordingly if the span is smaller than pi
        if useful_span < np.pi:
            theta_interval = theta_interval * useful_span / np.pi

        self.take_log = take_log
        self.ovs = oversampling
        self.theta_interval = theta_interval

        target_sampling_y = np.round(np.linspace(0, self.sy - 1, n_subsampling_y + 2)).astype(int)[1:-1]

        if self.spike_threshold is not None:
            # take also one line below and on above for each line
            # to provide appropriate margin
            self.sampling_y = np.zeros([3 * len(target_sampling_y)], "i")
            self.sampling_y[0::3] = np.maximum(0, target_sampling_y - 1)
            self.sampling_y[2::3] = np.minimum(self.sy - 1, target_sampling_y + 1)
            self.sampling_y[1::3] = target_sampling_y

            self.ccd_correction = CCDFilter((len(self.sampling_y), self.sx), median_clip_thresh=self.spike_threshold)
        else:
            self.sampling_y = target_sampling_y

        self.nproj = self.dataset_info.n_angles

        # Probe the angles whose opposite (angle + pi) is still within the scan
        my_condition = np.less(self.unwrapped_rotation_angles + np.pi, self.angle_max) * np.less(
            self.unwrapped_rotation_angles, self.angle_min + useful_span
        )

        possibly_probed_angles = self.unwrapped_rotation_angles[my_condition]
        possibly_probed_indices = np.arange(len(self.unwrapped_rotation_angles))[my_condition]

        # The step must be at least 1: with few projections (or a small theta_interval) it could round to 0,
        # which is an invalid slice step.
        self.dproj = int(max(1, round(len(possibly_probed_angles) / np.rad2deg(useful_span) * self.theta_interval)))

        self.probed_angles = possibly_probed_angles[:: self.dproj]
        self.probed_indices = possibly_probed_indices[:: self.dproj]

        self.absolute_indices = sorted(self.dataset_info.projections.keys())

        my_flats = self.dataset_info.flats

        if my_flats is not None and len(list(my_flats.keys())):
            self.use_flat = True
            self.flatfield = FlatField(
                (len(self.absolute_indices), self.sy, self.sx),
                self.dataset_info.flats,
                self.dataset_info.darks,
                radios_indices=self.absolute_indices,
            )
        else:
            self.use_flat = False

        self.mlog = Log((1,) + (self.sy, self.sx), clip_min=1e-6, clip_max=10.0)
        self.rcor_abs = round(self.sx / 2.0)
        self.cor_acc = round(self.sx / 2.0)

        self.nprobed = len(self.probed_angles)

        # initialize sinograms and radios arrays
        # First half: probed radios; second half: their (interpolated) opposite radios
        self.sino = np.zeros([2 * self.nprobed * n_subsampling_y, (self.sx - 1) * self.ovs + 1], "f")
        self._loaded = False
        self.high_pass = self.cor_options["high_pass"]
        img_filter = fourier_filters.get_bandpass_filter(
            (self.sino.shape[0] // 2, self.sino.shape[1]),
            cutoff_lowpass=self.cor_options["low_pass"] * self.ovs,
            cutoff_highpass=self.high_pass * self.ovs,
            use_rfft=False,  # rfft changes the image dimensions lenghts to even if odd
            data_type=np.float64,
        )

        # we are interested in filtering only along the x dimension only
        img_filter[:] = img_filter[0]
        self.img_filter = img_filter

    def _oversample(self, radio):
        """oversampling in the horizontal direction"""
        if self.ovs == 1:
            return radio
        else:
            ovs_2D = [1, self.ovs]
            return oversample(radio, ovs_2D)

    def _get_cor_options(self, cor_options):
        """Merge cor_options (dict or ";"-separated string) with the defaults into self.cor_options."""
        default_dict = self._default_cor_options.copy()
        if self.dataset_info.is_halftomo:
            default_dict["side"] = "right"

        if cor_options is None or cor_options == "":
            cor_options = {}
        if isinstance(cor_options, str):
            try:
                cor_options = extract_parameters(cor_options, sep=";")
            except Exception as exc:
                msg = "Could not extract parameters from cor_options: %s" % (str(exc))
                self.logger.fatal(msg)
                raise ValueError(msg)
        default_dict.update(cor_options)
        cor_options = default_dict

        self.cor_options = cor_options

    def get_radio(self, image_num):
        """
        Read projection `image_num` (a key of dataset_info.projections), apply flat-field and log if enabled,
        and return the subsampled detector lines (spike-corrected when spike_threshold is not None).
        """
        # radio_dataset_idx = self.absolute_indices[image_num]
        radio_dataset_idx = image_num
        data_url = self.dataset_info.projections[radio_dataset_idx]
        radio = get_data(data_url).astype(np.float64)
        if self.use_flat:
            self.flatfield.normalize_single_radio(radio, radio_dataset_idx, dtype=radio.dtype)
        if self.take_log:
            self.mlog.take_logarithm(radio)

        radio = radio[self.sampling_y]
        if self.spike_threshold is not None:
            self.ccd_correction.median_clip_correction(radio, output=radio)
            # keep only the target lines (the neighbour lines were only needed as a margin for the correction)
            radio = radio[1::3]
        return radio

    def get_sino(self, reload=False):
        """
        Build sinogram (composite image) from the radio files.

        The result is cached: further calls return the same array, unless reload=True.
        """
        if self._loaded and not reload:
            return self.sino

        sorting_indexes = np.argsort(self.unwrapped_rotation_angles)

        sorted_all_angles = self.unwrapped_rotation_angles[sorting_indexes]
        sorted_angle_indexes = np.arange(len(self.unwrapped_rotation_angles))[sorting_indexes]

        irad = 0
        for prob_a, prob_i in zip(self.probed_angles, self.probed_indices):
            radio1 = self.get_radio(self.absolute_indices[prob_i])
            other_angle = prob_a + np.pi

            insertion_point = np.searchsorted(sorted_all_angles, other_angle)
            if insertion_point > 0 and insertion_point < len(sorted_all_angles):
                # Linear interpolation between the two radios surrounding the opposite angle
                other_i_l = sorted_angle_indexes[insertion_point - 1]
                other_i_h = sorted_angle_indexes[insertion_point]
                radio_l = self.get_radio(self.absolute_indices[other_i_l])
                radio_h = self.get_radio(self.absolute_indices[other_i_h])
                f = (other_angle - sorted_all_angles[insertion_point - 1]) / (
                    sorted_all_angles[insertion_point] - sorted_all_angles[insertion_point - 1]
                )
                radio2 = (1 - f) * radio_l + f * radio_h
            else:
                if insertion_point == 0:
                    other_i = sorted_angle_indexes[0]
                elif insertion_point == len(sorted_all_angles):
                    other_i = sorted_angle_indexes[insertion_point - 1]
                radio2 = self.get_radio(self.absolute_indices[other_i])  # pylint: disable=E0606

            self.sino[irad : irad + radio1.shape[0], :] = self._oversample(radio1)
            self.sino[
                irad + self.nprobed * radio1.shape[0] : irad + self.nprobed * radio1.shape[0] + radio1.shape[0], :
            ] = self._oversample(radio2)

            irad = irad + radio1.shape[0]

        self.sino[np.isnan(self.sino)] = 0.0001  # ?
        self._loaded = True
        return self.sino

    def find_cor(self, reload=False):
        """
        Find the center of rotation by searching the overlap minimizing the error metric between
        the two halves of the composite sinogram.

        Parameters
        ----------
        reload: bool
            Whether to rebuild the composite sinogram even if it was already built.

        Returns
        -------
        cor_abs: float
            Absolute position of the center of rotation, in pixels.

        Raises
        ------
        ValueError
            If cor_options["side"] is invalid, or if the search range is empty.
        """
        self.logger.info("Estimating center of rotation")
        self.logger.debug("%s.find_shift(%s)" % (self.__class__.__name__, self.cor_options))
        self.sinogram = self.get_sino(reload=reload)

        dim_v, dim_h = self.sinogram.shape
        assert dim_v % 2 == 0, " this should not happen "
        dim_v = dim_v // 2

        radio1 = self.sinogram[:dim_v]
        radio2 = self.sinogram[dim_v:]

        orig_sy, orig_ovsd_sx = radio1.shape

        radio1 = scipy.fft.ifftn(
            scipy.fft.fftn(radio1, axes=(-2, -1)) * self.img_filter, axes=(-2, -1)
        ).real  # TODO: convolute only along x
        radio2 = scipy.fft.ifftn(
            scipy.fft.fftn(radio2, axes=(-2, -1)) * self.img_filter, axes=(-2, -1)
        ).real  # TODO: convolute only along x

        tmp_sy, ovsd_sx = radio1.shape
        assert orig_sy == tmp_sy and orig_ovsd_sx == ovsd_sx, "this should not happen"

        cor_side = self.cor_options["side"]
        # NOTE(review): the right-hand margin below uses ovs*ovs*high_pass*3 ("left", "all") while the
        # left-hand one uses ovs*high_pass*3 ("right", "all"). Kept as is; confirm the asymmetry is intended.
        if cor_side == "center":
            overlap_min = max(round(ovsd_sx - ovsd_sx / 3), 4)
            overlap_max = min(round(ovsd_sx + ovsd_sx / 3), 2 * ovsd_sx - 4)
        elif cor_side == "right":
            overlap_min = max(4, self.ovs * self.high_pass * 3)
            overlap_max = ovsd_sx
        elif cor_side == "left":
            overlap_min = ovsd_sx
            overlap_max = min(2 * ovsd_sx - 4, 2 * ovsd_sx - self.ovs * self.ovs * self.high_pass * 3)
        elif cor_side == "all":
            overlap_min = max(4, self.ovs * self.high_pass * 3)
            overlap_max = min(2 * ovsd_sx - 4, 2 * ovsd_sx - self.ovs * self.ovs * self.high_pass * 3)
        elif is_scalar(cor_side):
            near_pos = cor_side
            near_width = self.cor_options["near_width"]
            overlap_min = max(4, ovsd_sx - 2 * self.ovs * (near_pos + near_width))
            overlap_max = min(2 * ovsd_sx - 4, ovsd_sx - 2 * self.ovs * (near_pos - near_width))
        # COMPAT.
        elif cor_side == "near":
            deprecation_warning(
                "using side='near' is deprecated, use side=<a scalar> instead",
                do_print=True,
                func_name="composite_near_pos",
            )
            near_pos = self.cor_options["near_pos"]
            near_width = self.cor_options["near_width"]
            overlap_min = max(4, ovsd_sx - 2 * self.ovs * (near_pos + near_width))
            overlap_max = min(2 * ovsd_sx - 4, ovsd_sx - 2 * self.ovs * (near_pos - near_width))
        # ---
        else:
            raise ValueError("Invalid option 'side=%s'" % self.cor_options["side"])

        if overlap_min > overlap_max:
            message = (
                "There is no safe search range in find_cor once the margins corresponding to the high_pass filter"
                " are discarded. Try reducing the high_pass parameter in cor_options"
                " (or, for a scalar 'side', check the 'side' and 'near_width' values)."
            )
            raise ValueError(message)

        self.logger.info(
            "looking for overlap from min %.2f and max %.2f\n" % (overlap_min / self.ovs, overlap_max / self.ovs)
        )

        best_overlap = overlap_min
        best_error = np.inf

        blurred_radio1 = nd.gaussian_filter(abs(radio1), [0, self.high_pass])
        blurred_radio2 = nd.gaussian_filter(abs(radio2), [0, self.high_pass])

        # Adopt a 'safe' margin at the border of the overlap, considering high_pass value (possibly float).
        # Loop-invariant, so computed once. A zero margin must not become an empty "[: -0]" slice.
        margin = int(math.ceil(self.ovs * self.high_pass * 2))
        stop = -margin if margin > 0 else None

        for z in range(int(overlap_min), int(overlap_max) + 1):
            if z <= ovsd_sx:
                my_z = z
                my_radio1 = radio1
                my_radio2 = radio2
                my_blurred_radio1 = blurred_radio1
                my_blurred_radio2 = blurred_radio2
            else:
                # Overlaps beyond the detector width are handled by mirroring both images
                my_z = ovsd_sx - (z - ovsd_sx)
                my_radio1 = np.fliplr(radio1)
                my_radio2 = np.fliplr(radio2)
                my_blurred_radio1 = np.fliplr(blurred_radio1)
                my_blurred_radio2 = np.fliplr(blurred_radio2)

            common_left = np.fliplr(my_radio1[:, ovsd_sx - my_z :])[:, :stop]
            common_right = my_radio2[:, ovsd_sx - my_z : stop]
            common_blurred_left = np.fliplr(my_blurred_radio1[:, ovsd_sx - my_z :])[:, :stop]
            common_blurred_right = my_blurred_radio2[:, ovsd_sx - my_z : stop]

            if common_right.size == 0:
                continue

            error = self.error_metric(common_right, common_left, common_blurred_right, common_blurred_left)

            # "<=": on ties, the last tested overlap wins; a NaN error never replaces the current best
            if error <= best_error:
                best_overlap = z
                best_error = error

        offset = (ovsd_sx - best_overlap) / self.ovs / 2
        cor_abs = (self.sx - 1) / 2 + offset

        return cor_abs

    def error_metric(self, common_right, common_left, common_blurred_right, common_blurred_left):
        """Dispatch to the L1 or L2 error metric according to self.norm_order."""
        if self.norm_order == 2:
            return self.error_metric_l2(common_right, common_left)
        elif self.norm_order == 1:
            return self.error_metric_l1(common_right, common_left, common_blurred_right, common_blurred_left)
        else:
            assert False, "this cannot happen"

    def error_metric_l2(self, common_right, common_left):
        """Squared L2 norm of the difference, normalized by the product of the two L2 norms."""
        common = common_right - common_left

        tmp = np.linalg.norm(common)
        norm_diff2 = tmp * tmp

        norm_right = np.linalg.norm(common_right)
        norm_left = np.linalg.norm(common_left)

        res = norm_diff2 / (norm_right * norm_left)

        return res

    def error_metric_l1(self, common_right, common_left, common_blurred_right, common_blurred_left):
        """Mean absolute difference, normalized pixel-wise by the sum of the blurred absolute images."""
        common = (common_right - common_left) / (common_blurred_right + common_blurred_left)

        res = abs(common).mean()

        return res
[docs]
def oversample(radio, ovs_s):
    """
    Oversample a 2D image by (bi)linear interpolation, with an independent factor per axis.

    The first and last samples of each axis remain the extremal samples of the new axis, so an
    axis of length n becomes (n - 1) * factor + 1 samples long.

    Parameters
    ----------
    radio: 2D array
        Image to oversample.
    ovs_s: sequence of two ints
        Oversampling factors along axis 0 and axis 1.

    Returns
    -------
    result: numpy.ndarray
        Oversampled image, of dtype float32.
    """
    step_y, step_x = ovs_s[0], ovs_s[1]
    out_shape = [(radio.shape[0] - 1) * step_y + 1, (radio.shape[1] - 1) * step_x + 1]
    result = np.zeros(out_shape, "f")

    # The original samples land exactly on the strided grid of the output array.
    result[::step_y, ::step_x] = radio

    for k in range(step_y):
        g = k / step_y
        # Original rows bracketing the output rows k, k + step_y, ...
        # For k == 0 the output rows ARE original rows: "low" and "high" then coincide and the
        # interpolation weights (1 - g) + g = 1 simply add the same contribution twice.
        rows_lo = slice(0, -step_y if k else None, step_y)
        rows_hi = slice(step_y if k else 0, None, step_y)
        for i in range(step_x):
            if k == 0 and i == 0:
                # Already filled by the strided copy above.
                continue
            f = i / step_x
            cols_lo = slice(0, -step_x if i else None, step_x)
            cols_hi = slice(step_x if i else 0, None, step_x)
            # Only original-grid samples are read here, so the in-place filling is safe.
            top = (1 - f) * result[rows_lo, cols_lo] + f * result[rows_lo, cols_hi]
            bottom = (1 - f) * result[rows_hi, cols_lo] + f * result[rows_hi, cols_hi]
            result[k::step_y, i::step_x] = (1 - g) * top + g * bottom
    return result
# Backward-compatible alias: CompositeCOREstimator is the very same class object as CompositeCORFinder
CompositeCOREstimator = CompositeCORFinder
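# Signature-introspection helpers (somewhat inelegant): they use `inspect` to select, from a user-provided
# options dict, the keyword arguments (those having a default value) accepted by a given callable.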
[docs]
def get_default_kwargs(func):
    """
    Return the parameters of `func` that have a default value, as a {name: default} dict.

    Parameters without a default value (including ``*args`` and ``**kwargs``) are skipped.
    The dict follows the order of the signature. For a class, the signature of its
    constructor (without ``self``) is used, as given by `inspect.signature`.
    """
    params = inspect.signature(func).parameters
    # Compare by identity with the "no default" sentinel: "!=" would call the default value's
    # __ne__, which is ambiguous (elementwise) for e.g. numpy arrays.
    return {name: param.default for name, param in params.items() if param.default is not inspect.Parameter.empty}
[docs]
def update_func_kwargs(func, options):
    """
    Build keyword arguments for `func`: its default values, overridden by the matching entries of `options`.

    Entries of `options` that are not parameters of `func` with a default value are ignored.
    The returned dict follows the order of the signature of `func`.
    """
    merged = get_default_kwargs(func)
    overrides = {name: value for name, value in options.items() if name in merged}
    merged.update(overrides)
    return merged
[docs]
def get_class_name(class_object):
    """
    Return the bare name of a class, e.g. "CORFinder" for `nabu.pipeline.estimators.CORFinder`.

    For classes, the ``__name__`` attribute is used, which also works for classes whose string
    representation has no dotted module path (e.g. builtins such as `int`, or Enum classes).
    For other objects, the name is extracted from their string representation, as before.
    """
    if isinstance(class_object, type):
        return class_object.__name__
    return str(class_object).split(".")[-1].strip(">").strip("'").strip('"')
[docs]
class DetectorTiltEstimator:
    """
    Helper class for detector tilt estimation.
    It automatically chooses the right radios and performs flat-field.
    """

    default_tilt_method = "1d-correlation"
    # Given a tilt angle "a", the maximum deviation caused by the tilt (in pixels) is
    #  N/2 * |sin(a)|  where N is the number of pixels
    # We ignore tilts causing less than 0.25 pixel deviation: N/2*|sin(a)| < tilt_threshold
    tilt_threshold = 0.25

    def __init__(self, dataset_info, do_flatfield=True, logger=None, autotilt_options=None):
        """
        Initialize a detector tilt estimator helper.

        Parameters
        ----------
        dataset_info: `dataset_info` object
            Data structure with the dataset information.
        do_flatfield: bool, optional
            Whether to perform flat field on radios.
        logger: `Logger` object, optional
            Logger object
        autotilt_options: dict or str, optional
            named arguments to pass to the detector tilt estimator class.
            A string is parsed with `extract_parameters`.
            The special key "threshold" overrides `tilt_threshold` and is not passed to the estimator.
        """
        self._set_params(dataset_info, do_flatfield, logger, autotilt_options)
        self.radios, self.radios_indices = get_radio_pair(dataset_info, radio_angles=(0.0, np.pi), return_indices=True)
        self._init_flatfield()
        self._apply_flatfield()

    def _set_params(self, dataset_info, do_flatfield, logger, autotilt_options):
        """Store the constructor parameters and parse autotilt_options."""
        self.dataset_info = dataset_info
        self.do_flatfield = bool(do_flatfield)
        self.logger = LoggerOrPrint(logger)
        self._get_autotilt_options(autotilt_options)

    def _init_flatfield(self):
        """Create the FlatField object for the two radios (if do_flatfield)."""
        if not (self.do_flatfield):
            return
        self.flatfield = FlatField(
            self.radios.shape,
            flats=self.dataset_info.flats,
            darks=self.dataset_info.darks,
            radios_indices=self.radios_indices,
            interpolation="linear",
        )

    def _apply_flatfield(self):
        """Normalize self.radios in-place (if do_flatfield)."""
        if not (self.do_flatfield):
            return
        self.flatfield.normalize_radios(self.radios)

    def _get_autotilt_options(self, autotilt_options):
        """
        Set self.autotilt_options from a dict or a string (or None).
        A "threshold" entry is removed from the options and stored in self.tilt_threshold.

        Raises
        ------
        ValueError
            If a string cannot be parsed.
        """
        if autotilt_options is None:
            self.autotilt_options = None
            return
        if isinstance(autotilt_options, dict):
            # Work on a copy: "threshold" is popped below and must not be removed from the caller's dict
            autotilt_options = autotilt_options.copy()
        else:
            try:
                autotilt_options = extract_parameters(autotilt_options)
            except Exception as exc:
                msg = "Could not extract parameters from autotilt_options: %s" % (str(exc))
                self.logger.fatal(msg)
                raise ValueError(msg) from exc
        self.autotilt_options = autotilt_options
        if "threshold" in autotilt_options:
            self.tilt_threshold = autotilt_options.pop("threshold")

    def find_tilt(self, tilt_method=None):
        """
        Find the detector tilt.

        Parameters
        ----------
        tilt_method: str, optional
            Which tilt estimation method to use.

        Returns
        -------
        camera_tilt: float or None
            Estimated tilt angle in degrees, or None if the resulting deviation is below self.tilt_threshold.
        """
        if tilt_method is None:
            tilt_method = self.default_tilt_method
        check_supported(tilt_method, set(tilt_methods.values()), "tilt estimation method")
        self.logger.info("Estimating detector tilt angle")
        autotilt_params = {
            "roi_yxhw": None,
            "median_filt_shape": None,
            "padding_mode": None,
            "peak_fit_radius": 1,
            "high_pass": None,
            "low_pass": None,
        }
        autotilt_params.update(self.autotilt_options or {})
        self.logger.debug("%s(%s)" % ("CameraTilt", str(autotilt_params)))

        tilt_calc = CameraTilt()
        _tilt_cor_position, camera_tilt = tilt_calc.compute_angle(
            self.radios[0], np.fliplr(self.radios[1]), method=tilt_method, **autotilt_params
        )
        self.logger.info("Estimated detector tilt angle: %f degrees" % camera_tilt)
        # Ignore too small tilts
        # NOTE(review): this uses N * |sin(a)| (doubled for half-acquisition), whereas the class-level comment
        # states N/2 * |sin(a)|. Kept unchanged; confirm which one is intended.
        max_deviation = np.max(self.dataset_info.radio_dims) * np.abs(np.sin(np.deg2rad(camera_tilt)))
        if self.dataset_info.is_halftomo:
            max_deviation *= 2
        if max_deviation < self.tilt_threshold:
            self.logger.info(
                "Estimated tilt angle (%.3f degrees) results in %.2f maximum pixels shift, which is below threshold (%.2f pixel). Ignoring the tilt, no correction will be done."
                % (camera_tilt, max_deviation, self.tilt_threshold)
            )
            camera_tilt = None
        return camera_tilt
# Backward-compatible alias: TiltFinder is the very same class object as DetectorTiltEstimator
TiltFinder = DetectorTiltEstimator