"""
nabu.pipeline.estimators: helper classes/functions to estimate parameters of a dataset
(center of rotation, detector tilt, etc).
"""
import inspect
import numpy as np
import scipy.fft # pylint: disable=E0611
from silx.io import get_data
from typing import Union, Optional
import math
from numbers import Real
from scipy import ndimage as nd
from ..preproc.flatfield import FlatFieldDataUrls
from ..estimation.cor import (
CenterOfRotation,
CenterOfRotationAdaptiveSearch,
CenterOfRotationSlidingWindow,
CenterOfRotationGrowingWindow,
CenterOfRotationFourierAngles,
CenterOfRotationOctaveAccurate,
)
from ..estimation.cor_sino import SinoCorInterface
from ..estimation.tilt import CameraTilt
from ..estimation.utils import is_fullturn_scan
from ..resources.logger import LoggerOrPrint
from ..resources.utils import extract_parameters
from ..utils import check_supported, is_int
from .params import tilt_methods
from ..resources.dataset_analyzer import get_radio_pair
from ..processing.rotation import Rotation
from ..io.reader import ChunkReader
from ..preproc.ccd import Log, CCDFilter
from ..misc import fourier_filters
from .params import cor_methods
from ..io.reader import load_images_from_dataurl_dict
[docs]
def estimate_cor(method, dataset_info, do_flatfield=True, cor_options: Optional[Union[str, dict]] = None, logger=None):
    """
    Estimate the center of rotation (CoR) of a dataset with the requested method.

    Parameters
    ----------
    method: str
        Name of the estimation method. Must be one of the keys of `cor_methods`.
    dataset_info:
        Dataset information structure, forwarded to the selected finder class.
    do_flatfield: bool, optional
        Forwarded to `CORFinder` and `SinoCORFinder`. It is not forwarded to `CompositeCORFinder`.
    cor_options: str or dict, optional
        Extra options. A string is parsed with `extract_parameters(..., sep=";")`,
        a dict is used as is. None (or any falsy value) means "no options".
    logger: optional
        Logging object, wrapped in `LoggerOrPrint`.

    Returns
    -------
    estimated_cor:
        The value returned by the `find_cor()` method of the selected finder.

    Raises
    ------
    ValueError
        If `cor_options` is a string that cannot be parsed.
    TypeError
        If `cor_options` is neither a string nor a dict.
    """
    logger = LoggerOrPrint(logger)
    cor_options = cor_options or {}
    check_supported(method, list(cor_methods.keys()), "COR estimation method")
    method = cor_methods[method]

    # Extract CoR parameters from configuration file
    if isinstance(cor_options, str):
        try:
            cor_options = extract_parameters(cor_options, sep=";")
        except Exception as exc:
            msg = "Could not extract parameters from cor_options: %s" % (str(exc))
            logger.fatal(msg)
            # Keep the original parsing error attached as the explicit cause.
            raise ValueError(msg) from exc
    elif not isinstance(cor_options, dict):
        raise TypeError(f"cor_options_str is expected to be a dict or a str. {type(cor_options)} provided")

    # Dispatch. COR estimation is always expressed in absolute number of pixels
    # (i.e. from the center of the first pixel column)
    if method in CORFinder.search_methods:
        cor_finder = CORFinder(
            method,
            dataset_info,
            do_flatfield=do_flatfield,
            cor_options=cor_options,
            logger=logger,
        )
    elif method in SinoCORFinder.search_methods:
        cor_finder = SinoCORFinder(
            method,
            dataset_info,
            slice_idx=cor_options.get("slice_idx", "middle"),
            subsampling=cor_options.get("subsampling", 10),
            do_flatfield=do_flatfield,
            cor_options=cor_options,
            logger=logger,
        )
    else:
        # Any other method is handled by CompositeCORFinder. Its constructor keyword
        # arguments are picked from cor_options; "cor_options" and "logger" are passed explicitly.
        composite_options = update_func_kwargs(CompositeCORFinder, cor_options)
        for what in ["cor_options", "logger"]:
            composite_options.pop(what, None)
        cor_finder = CompositeCORFinder(
            dataset_info,
            cor_options=cor_options,
            logger=logger,
            **composite_options,
        )
    return cor_finder.find_cor()
[docs]
class CORFinderBase:
    """
    A base class for CoR estimators.
    It does common tasks like data reading, flatfield, etc.
    """

    # Mapping "method name" -> description dict. The code below reads the keys
    # "class" (mandatory) and "default_args" (optional). Subclasses override it.
    search_methods = {}

    def __init__(self, method, dataset_info, do_flatfield=True, cor_options=None, logger=None):
        """
        Initialize a CORFinder object.

        Parameters
        ----------
        method: str
            Name of the estimation method, must be a key of `search_methods`.
        dataset_info: `nabu.resources.dataset_analyzer.DatasetAnalyzer`
            Dataset information structure
        do_flatfield: bool, optional
            Stored as `self.do_flatfield`.
        cor_options: dict, optional
            Options for the CoR finder. See `_init_cor_finder`.
        logger: optional
            Logging object, wrapped in `LoggerOrPrint`.
        """
        check_supported(method, self.search_methods, "CoR estimation method")
        self.logger = LoggerOrPrint(logger)
        self.dataset_info = dataset_info
        self.do_flatfield = do_flatfield
        self.shape = dataset_info.radio_dims[::-1]
        self._init_cor_finder(method, cor_options)

    def _init_cor_finder(self, method, cor_options):
        """
        Normalize the options "side" and "near_pos", instantiate the underlying CoR finder
        and prepare the positional/keyword arguments passed later to its `find_shift` method.

        Parameters
        ----------
        method: str
            Key of `search_methods`.
        cor_options: dict or None
            User options. The dict is copied, the caller's object is not modified.

        Raises
        ------
        TypeError
            If `cor_options` is neither None nor a dict.
        """
        self.method = method
        if not isinstance(cor_options, (type(None), dict)):
            raise TypeError(
                f"cor_options is expected to be an optional instance of dict. Get {cor_options} ({type(cor_options)}) instead"
            )
        self.cor_options = {}
        if isinstance(cor_options, dict):
            self.cor_options.update(cor_options)

        # tomotools internal meeting 07 feb 2024: Merge of options 'near_pos' and 'side'.
        # See [minutes](https://gitlab.esrf.fr/tomotools/minutes/-/blob/master/minutes-20240207.md?ref_type=heads)
        detector_width = self.dataset_info.radio_dims[0]
        default_lookup_side = "right" if self.dataset_info.is_halftomo else "center"
        near_init = self.cor_options.get("side", None)
        if near_init is None:
            near_init = default_lookup_side

        if near_init == "from_file":
            try:
                near_pos = self.dataset_info.dataset_scanner.estimated_cor_frm_motor  # relative pos in pixels
                if isinstance(near_pos, Real):
                    self.cor_options.update({"near_pos": int(near_pos)})
                else:
                    near_init = default_lookup_side
            except Exception:
                # Was a bare "except:", which also swallowed KeyboardInterrupt/SystemExit.
                self.logger.warning(
                    "COR estimation from motor position absent from NX file. Global search is performed."
                )
                near_init = default_lookup_side
        elif isinstance(near_init, Real):
            # A number given as "side" is interpreted as "near_pos"
            self.cor_options.update({"near_pos": int(near_init)})
            near_init = "near"
        elif near_init == "near":  # Legacy
            # Use .get(): a missing "near_pos" must trigger the fallback below, not a KeyError.
            if not isinstance(self.cor_options.get("near_pos", None), Real):
                self.logger.warning("Side option set to 'near' but no 'near_pos' option set.")
                self.logger.warning("Set side to right if HA, center otherwise.")
                near_init = default_lookup_side
        elif near_init in ("left", "right", "center", "all"):
            pass
        else:
            # Only a warning is emitted: the unsupported value is kept as is.
            self.logger.warning(
                f"COR option 'side' received {near_init} and should be either 'from_file' (default), 'left', 'right', 'center', 'near' or a number."
            )

        if isinstance(self.cor_options.get("near_pos", None), Real):
            # Check validity of near_pos
            if np.abs(self.cor_options["near_pos"]) > detector_width / 2:
                self.logger.warning(
                    "Relative COR passed is greater than half the size of the detector. Did you enter a absolute COR position?"
                )
                self.logger.warning("Instead, the center of the detector is used.")
                self.cor_options["near_pos"] = 0
            # Set side from the sign of near_pos if passed.
            near_init = "left" if self.cor_options["near_pos"] < 0.0 else "right"
        self.cor_options.update({"side": near_init})

        cor_class = self.search_methods[method]["class"]
        self.cor_finder = cor_class(logger=self.logger, cor_options=self.cor_options)

        lookup_side = self.cor_options.get("side", default_lookup_side)
        angles = self.dataset_info.rotation_angles

        # Copy, so that the class-level "default_args" list is never modified.
        self.cor_exec_args = list(self.search_methods[method].get("default_args", []))
        # Some classes have mandatory positional arguments ("side", and "angles" for the Fourier-angles method)
        # TODO - it would be more elegant to have it as a kwarg...
        if len(self.cor_exec_args) > 0:
            if cor_class in (CenterOfRotationSlidingWindow, CenterOfRotationOctaveAccurate):
                self.cor_exec_args[0] = lookup_side
            elif cor_class in (CenterOfRotationFourierAngles,):
                self.cor_exec_args[0] = angles
                self.cor_exec_args[1] = lookup_side

        # Keyword arguments of find_shift(), with defaults overridden by the matching entries of cor_options
        self.cor_exec_kwargs = update_func_kwargs(self.cor_finder.find_shift, self.cor_options)
[docs]
class CORFinder(CORFinderBase):
    """
    Center of Rotation estimator working on a pair of radios.
    With the default `radio_angles`, the radios requested are the ones at 0 and pi radians.
    """

    search_methods = {
        "centered": {"class": CenterOfRotation},
        "global": {
            "class": CenterOfRotationAdaptiveSearch,
            "default_kwargs": {"low_pass": 1, "high_pass": 20},
        },
        "sliding-window": {"class": CenterOfRotationSlidingWindow, "default_args": ["center"]},
        "growing-window": {"class": CenterOfRotationGrowingWindow},
        "octave-accurate": {"class": CenterOfRotationOctaveAccurate, "default_args": ["center"]},
    }

    def __init__(
        self, method, dataset_info, do_flatfield=True, cor_options=None, logger=None, radio_angles: tuple = (0.0, np.pi)
    ):
        """
        Build the finder: fetch the radio pair, then apply flat-field and detector tilt corrections.

        Parameters
        ----------
        method: str
            Name of the estimation method, must be a key of `search_methods`.
        dataset_info: `nabu.resources.dataset_analyzer.DatasetAnalyzer`
            Dataset information structure
        do_flatfield: bool, optional
            Whether the radios are flat-field normalized.
        cor_options: dict, optional
            Options forwarded to the base class.
        logger: optional
            Logging object.
        radio_angles: tuple, optional
            Angles passed to `get_radio_pair` to select the two radios.
        """
        super().__init__(method, dataset_info, do_flatfield=do_flatfield, cor_options=cor_options, logger=logger)
        self._radio_angles = radio_angles
        self._init_radios()
        self._init_flatfield()
        self._apply_flatfield()
        self._apply_tilt()

    def _init_radios(self):
        """Fetch the radio pair and the corresponding indices from the dataset."""
        radio_pair, pair_indices = get_radio_pair(
            self.dataset_info, radio_angles=self._radio_angles, return_indices=True
        )
        self.radios = radio_pair
        self._radios_indices = pair_indices

    def _init_flatfield(self):
        """Create the flat-field object, only when flat-field is enabled."""
        if self.do_flatfield:
            self.flatfield = FlatFieldDataUrls(
                self.radios.shape,
                flats=self.dataset_info.flats,
                darks=self.dataset_info.darks,
                radios_indices=self._radios_indices,
                interpolation="linear",
                convert_float=True,
            )

    def _apply_flatfield(self):
        """Normalize the radios with the flat-field object, only when flat-field is enabled."""
        if self.do_flatfield:
            self.flatfield.normalize_radios(self.radios)

    def _apply_tilt(self):
        """Rotate each radio by the detector tilt of the dataset, if one is defined."""
        tilt_angle = self.dataset_info.detector_tilt
        if tilt_angle is not None:
            self.logger.debug("COREstimator: applying detector tilt correction of %f degrees" % tilt_angle)
            rotation = Rotation(self.shape, tilt_angle)
            n_radios = self.radios.shape[0]
            for radio_idx in range(n_radios):
                self.radios[radio_idx] = rotation.rotate(self.radios[radio_idx])

    def find_cor(self):
        """
        Find the center of rotation.

        Returns
        -------
        cor: float
            The estimated center of rotation for the current dataset:
            half of `self.shape[1]` plus the shift returned by the underlying finder.
        """
        self.logger.info("Estimating center of rotation")
        finder_name = self.cor_finder.__class__.__name__
        self.logger.debug("%s.find_shift(%s)" % (finder_name, str(self.cor_exec_kwargs)))
        first_radio = self.radios[0]
        # The second radio is flipped horizontally before being compared to the first one
        flipped_second_radio = np.fliplr(self.radios[1])
        shift = self.cor_finder.find_shift(
            first_radio, flipped_second_radio, *self.cor_exec_args, **self.cor_exec_kwargs
        )
        return self.shape[1] / 2 + shift
# Alias: `COREstimator` is a second module-level name bound to the `CORFinder` class.
COREstimator = CORFinder
[docs]
class SinoCORFinder(CORFinderBase):
    """
    Center of Rotation estimator based on a 360 degrees sinogram.
    The sinogram is built by this class from one line of (a subset of) the raw radios.
    """

    search_methods = {
        "sino-coarse-to-fine": {"class": SinoCorInterface},
        "sino-sliding-window": {"class": CenterOfRotationSlidingWindow, "default_args": ["right"]},
        "sino-growing-window": {"class": CenterOfRotationGrowingWindow},
        "fourier-angles": {"class": CenterOfRotationFourierAngles, "default_args": [None, "center"]},
    }

    def __init__(
        self, method, dataset_info, slice_idx="middle", subsampling=10, do_flatfield=True, cor_options=None, logger=None
    ):
        """
        Initialize a SinoCORFinder object: check the scan range, read the data,
        apply the flat-field and build the sinogram.

        Parameters
        ----------
        method: str
            Name of the estimation method, must be a key of `search_methods`.
        dataset_info:
            Dataset information structure.
        slice_idx: int or str
            Which slice index to take for building the sinogram.
            For example slice_idx=0 means that we extract the first line of each projection.
            Value can also be "first", "top", "middle", "last", "bottom".
        subsampling: int
            Subsampling strategy when reading the projections.
            A non-negative integer is the subsampling step (i.e `projections[::subsampling]`).
            A negative integer means that -subsampling projections are taken in total.
            Non-integer values are not implemented (NotImplementedError).
        do_flatfield: bool, optional
            Whether to normalize the data with flat-field.
        cor_options: dict, optional
            Options forwarded to the base class.
        logger: optional
            Logging object.
        """
        super().__init__(method, dataset_info, do_flatfield=do_flatfield, cor_options=cor_options, logger=logger)
        self._check_360()
        self._set_slice_idx(slice_idx)
        self._set_subsampling(subsampling)
        self._load_raw_sinogram()
        self._flatfield(do_flatfield)
        self._get_sinogram()

    def _check_360(self):
        """Raise ValueError unless the scan range is 360 or the rotation angles describe a full turn."""
        declared_full_turn = self.dataset_info.dataset_scanner.scan_range == 360
        if declared_full_turn or is_fullturn_scan(self.dataset_info.rotation_angles):
            return
        raise ValueError("Sinogram-based Center of Rotation estimation can only be used for 360 degrees scans")

    def _set_slice_idx(self, slice_idx):
        """Store the slice index, translating a named location into a number when a string is given."""
        n_z = self.dataset_info.radio_dims[1]
        if isinstance(slice_idx, str):
            last_line = n_z - 1
            named_locations = {"top": 0, "first": 0, "middle": n_z // 2, "bottom": last_line, "last": last_line}
            check_supported(slice_idx, named_locations.keys(), "slice location")
            slice_idx = named_locations[slice_idx]
        self.slice_idx = slice_idx

    def _set_subsampling(self, subsampling):
        """Select the indices of the projections to read, according to `subsampling`."""
        projs_idx = sorted(self.dataset_info.projections.keys())
        if not is_int(subsampling):
            # Angular step
            raise NotImplementedError()
        if subsampling < 0:
            # Total number of angles
            # NOTE(review): `self.angles` is not defined in this branch - confirm that nothing reads it here.
            n_angles = -subsampling
            evenly_spaced = np.linspace(projs_idx[0], projs_idx[-1], n_angles, endpoint=True)
            self.projs_indices = np.round(evenly_spaced).astype(np.int32).tolist()
        else:
            # Subsampling step
            self.projs_indices = projs_idx[::subsampling]
            self.angles = self.dataset_info.rotation_angles[::subsampling]

    def _load_raw_sinogram(self):
        """Read the selected line of each selected projection into `self._radios`."""
        if self.slice_idx is None:
            raise ValueError("Unknow slice index")
        # Subsample projections
        self.files = {proj_idx: self.dataset_info.projections[proj_idx] for proj_idx in self.projs_indices}
        one_line_region = (None, None, self.slice_idx, self.slice_idx + 1)
        self.data_reader = ChunkReader(
            self.files,
            sub_region=one_line_region,
            convert_float=True,
        )
        self.data_reader.load_files()
        self._radios = self.data_reader.files_data

    def _flatfield(self, do_flatfield):
        """Normalize the loaded data with flat-field, only when `do_flatfield` is true."""
        self.do_flatfield = bool(do_flatfield)
        if self.do_flatfield:
            one_line_region = (None, None, self.slice_idx, self.slice_idx + 1)
            normalizer = FlatFieldDataUrls(
                self._radios.shape,
                self.dataset_info.flats,
                self.dataset_info.darks,
                radios_indices=self.projs_indices,
                sub_region=one_line_region,
            )
            normalizer.normalize_radios(self._radios)

    def _get_sinogram(self):
        """Build the sinogram: copy of the single loaded line of each radio, followed by a logarithm."""
        log_processing = Log(self._radios.shape, clip_min=1e-6, clip_max=10.0)
        sino = self._radios[:, 0, :].copy()
        log_processing.take_logarithm(sino)
        self.sinogram = sino

    @staticmethod
    def _split_sinogram(sinogram):
        """
        Split a sinogram in two halves along its first axis.
        When the number of rows is odd, the last row of the second half is dropped,
        so that both halves have the same number of rows.
        """
        half_n_angles = sinogram.shape[0] // 2
        first_half = sinogram[:half_n_angles]
        second_half = sinogram[half_n_angles:]
        # "Handle" odd number of projections
        if second_half.shape[0] > first_half.shape[0]:
            second_half = second_half[:-1, :]
        return first_half, second_half

    def find_cor(self):
        """
        Find the center of rotation from the two halves of the sinogram.

        Returns
        -------
        cor: float
            Half of `self.shape[1]` plus the shift returned by the underlying finder.
        """
        self.logger.info("Estimating center of rotation")
        finder_name = self.cor_finder.__class__.__name__
        self.logger.debug("%s.find_shift(%s)" % (finder_name, str(self.cor_exec_kwargs)))
        first_half, second_half = self._split_sinogram(self.sinogram)
        # NOTE(review): for "fourier-angles", the angles in cor_exec_args are the full `rotation_angles`
        # of the dataset while the sinogram is subsampled - confirm this is what find_shift expects.
        shift = self.cor_finder.find_shift(
            first_half, np.fliplr(second_half), *self.cor_exec_args, **self.cor_exec_kwargs
        )
        return self.shape[1] / 2 + shift
# Alias: `SinoCOREstimator` is a second module-level name bound to the `SinoCORFinder` class.
SinoCOREstimator = SinoCORFinder
[docs]
class CompositeCORFinder(CORFinderBase):
    """
    Class and method to prepare sinogram and calculate COR
    The pseudo sinogram is built with shrunk radios taken every theta_interval degrees

    Compared to first writing by Christian Nemoz:
      - gives the same result of the original octave script on the dataset so far tested
      - The meaning of parameter n_subsampling_y (alias subsampling_y) is now the number of lines which are taken from
        every radio. This is more meaningful in terms of amount of collected information because it
        does not depend on the radio size. Moreover this is what was done in the octave script
      - The spike_threshold has been added with default to 0.04
      - The angular sampling is every 5 degree by default, as it is now the case also in the octave script
      - The finding of the optimal overlap is done by looping over the possible overlaps.
        After a first testing phase, this part, which is the time consuming part, can be accelerated
        by several orders of magnitude without modifying the final result
    """

    search_methods = {
        "composite-coarse-to-fine": {
            "class": CenterOfRotation,  # Hack. Not used. Everything is done in the find_cor() func.
        }
    }
    _default_cor_options = {"low_pass": 0.4, "high_pass": 10, "side": "center", "near_pos": 0, "near_width": 20}

    def __init__(
        self,
        dataset_info,
        oversampling=4,
        theta_interval=5,
        n_subsampling_y=10,
        take_log=True,
        cor_options=None,
        spike_threshold=0.04,
        logger=None,
        norm_order=1,
    ):
        """
        Initialize a CompositeCORFinder object.

        Parameters
        ----------
        dataset_info:
            Dataset information structure.
        oversampling: int, optional
            Horizontal oversampling factor applied to each radio.
        theta_interval: float, optional
            Angular sampling, in degrees. It is scaled down when the usable angular span is smaller than pi.
        n_subsampling_y: int, optional
            Number of lines taken from every radio.
        take_log: bool, optional
            Whether to take the logarithm of the radios.
        cor_options: dict, optional
            Options, merged on top of `_default_cor_options`.
        spike_threshold: float or None, optional
            Threshold passed to `CCDFilter` as `median_clip_thresh`. None disables this correction.
        logger: optional
            Logging object.
        norm_order: int, optional
            1 or 2. Selects the error metric used in `find_cor`.

        Raises
        ------
        ValueError
            If norm_order is neither 1 nor 2, or if the angular span of the scan is below 1.2 * pi.
        """
        super().__init__(
            "composite-coarse-to-fine", dataset_info, do_flatfield=True, cor_options=cor_options, logger=logger
        )
        if norm_order not in [1, 2]:
            raise ValueError(f"The norm order (norm_order parameter) must be either 1 or 2. You passed {norm_order}")
        self.norm_order = norm_order

        self.sx, self.sy = self.dataset_info.radio_dims

        # Options normalized by the base class take precedence over the class defaults
        default_cor_options = self._default_cor_options.copy()
        default_cor_options.update(self.cor_options)
        self.cor_options = default_cor_options

        # the algorithm can work for angular ranges larger than 1.2*pi
        # up to an arbitrary number of turns as it is the case in helical scans
        self.spike_threshold = spike_threshold
        # the following line is necessary for multi-turns scan because the encoders is always
        # in the interval 0-360
        self.unwrapped_rotation_angles = np.unwrap(self.dataset_info.rotation_angles)

        self.angle_min = self.unwrapped_rotation_angles.min()
        self.angle_max = self.unwrapped_rotation_angles.max()
        angular_span = self.angle_max - self.angle_min

        if angular_span < 1.2 * np.pi:
            raise ValueError(
                f"""Sinogram-based Center of Rotation estimation can only be used for scans over more than 180 degrees.
                Your angular span was barely above 180 degrees, it was in fact {(angular_span / np.pi):.2f} x 180
                and it is not considered to be enough by the discriminating condition which requires at least 1.2 half-turns
                """
            )
        useful_span = min(np.pi, angular_span - np.pi)
        # readapt theta_interval accordingly if the span is smaller than pi
        if useful_span < np.pi:
            theta_interval = theta_interval * useful_span / np.pi

        self.take_log = take_log
        self.ovs = oversampling
        self.theta_interval = theta_interval

        target_sampling_y = np.round(np.linspace(0, self.sy - 1, n_subsampling_y + 2)).astype(int)[1:-1]

        if self.spike_threshold is not None:
            # take also one line below and one above for each line
            # to provide appropriate margin
            self.sampling_y = np.zeros([3 * len(target_sampling_y)], "i")
            self.sampling_y[0::3] = np.maximum(0, target_sampling_y - 1)
            self.sampling_y[2::3] = np.minimum(self.sy - 1, target_sampling_y + 1)
            self.sampling_y[1::3] = target_sampling_y

            self.ccd_correction = CCDFilter((len(self.sampling_y), self.sx), median_clip_thresh=self.spike_threshold)
        else:
            self.sampling_y = target_sampling_y

        self.nproj = self.dataset_info.n_angles

        # Keep the angles which have a counterpart at +pi inside the scan, within the useful span
        my_condition = np.less(self.unwrapped_rotation_angles + np.pi, self.angle_max) & np.less(
            self.unwrapped_rotation_angles, self.angle_min + useful_span
        )

        possibly_probed_angles = self.unwrapped_rotation_angles[my_condition]
        possibly_probed_indices = np.arange(len(self.unwrapped_rotation_angles))[my_condition]

        # At least 1: a rounding to 0 would be an invalid (zero) slice step below.
        self.dproj = max(1, round(len(possibly_probed_angles) / np.rad2deg(useful_span) * self.theta_interval))

        self.probed_angles = possibly_probed_angles[:: self.dproj]
        self.probed_indices = possibly_probed_indices[:: self.dproj]

        self.absolute_indices = sorted(self.dataset_info.projections.keys())

        my_flats = load_images_from_dataurl_dict(self.dataset_info.flats)

        if my_flats is not None and len(list(my_flats.keys())):
            self.use_flat = True
            self.flatfield = FlatFieldDataUrls(
                (len(self.absolute_indices), self.sy, self.sx),
                self.dataset_info.flats,
                self.dataset_info.darks,
                radios_indices=self.absolute_indices,
                dtype=np.float64,
            )
        else:
            self.use_flat = False

        self.mlog = Log((1,) + (self.sy, self.sx), clip_min=1e-6, clip_max=10.0)
        self.rcor_abs = round(self.sx / 2.0)
        self.cor_acc = round(self.sx / 2.0)

        self.nprobed = len(self.probed_angles)

        # initialize sinograms and radios arrays
        self.sino = np.zeros([2 * self.nprobed * n_subsampling_y, (self.sx - 1) * self.ovs + 1], "f")
        self._loaded = False
        self.high_pass = self.cor_options["high_pass"]
        img_filter = fourier_filters.get_bandpass_filter(
            (self.sino.shape[0] // 2, self.sino.shape[1]),
            cutoff_lowpass=self.cor_options["low_pass"] * self.ovs,
            cutoff_highpass=self.high_pass * self.ovs,
            use_rfft=False,  # rfft changes the image dimensions lengths to even if odd
            data_type=np.float64,
        )

        # we are interested in filtering along the x dimension only
        img_filter[:] = img_filter[0]
        self.img_filter = img_filter

    def _oversample(self, radio):
        """oversampling in the horizontal direction"""
        if self.ovs == 1:
            return radio
        ovs_2D = [1, self.ovs]
        return oversample(radio, ovs_2D)

    def _get_cor_options(self, cor_options):
        """
        Build `self.cor_options` from the class defaults and the given options (dict, or string to parse).
        This method is not called by the rest of this class.
        """
        default_dict = self._default_cor_options.copy()
        if self.dataset_info.is_halftomo:
            default_dict["side"] = "right"

        if cor_options is None or cor_options == "":
            cor_options = {}
        if isinstance(cor_options, str):
            try:
                cor_options = extract_parameters(cor_options, sep=";")
            except Exception as exc:
                msg = "Could not extract parameters from cor_options: %s" % (str(exc))
                self.logger.fatal(msg)
                raise ValueError(msg) from exc
        default_dict.update(cor_options)
        cor_options = default_dict

        self.cor_options = cor_options

    def get_radio(self, image_num):
        """
        Read one radio and return its sampled lines.

        The radio is read as float64, flat-field normalized if flats are available,
        passed through the logarithm if `take_log` is set, then restricted to the lines `sampling_y`.
        When `spike_threshold` is not None, the median-clip correction is applied and
        only the middle line of each group of three is kept.

        Parameters
        ----------
        image_num:
            Key of the radio in `dataset_info.projections`.
        """
        radio_dataset_idx = image_num
        data_url = self.dataset_info.projections[radio_dataset_idx]
        radio = get_data(data_url).astype(np.float64)
        if self.use_flat:
            self.flatfield.normalize_single_radio(radio, radio_dataset_idx, dtype=radio.dtype)
        if self.take_log:
            self.mlog.take_logarithm(radio)

        radio = radio[self.sampling_y]
        if self.spike_threshold is not None:
            self.ccd_correction.median_clip_correction(radio, output=radio)
            radio = radio[1::3]
        return radio

    def get_sino(self, reload=False):
        """
        Build sinogram (composite image) from the radio files

        The first half of the rows holds the oversampled radios at the probed angles,
        the second half holds the radios at the probed angles + pi (linearly interpolated
        between the two neighbouring angles when possible).

        Parameters
        ----------
        reload: bool, optional
            If False and the sinogram was already built, the stored array is returned without reading the data again.
        """
        if self._loaded and not reload:
            return self.sino

        sorting_indexes = np.argsort(self.unwrapped_rotation_angles)

        sorted_all_angles = self.unwrapped_rotation_angles[sorting_indexes]
        sorted_angle_indexes = np.arange(len(self.unwrapped_rotation_angles))[sorting_indexes]

        irad = 0
        for prob_a, prob_i in zip(self.probed_angles, self.probed_indices):
            radio1 = self.get_radio(self.absolute_indices[prob_i])
            other_angle = prob_a + np.pi

            insertion_point = np.searchsorted(sorted_all_angles, other_angle)
            if insertion_point > 0 and insertion_point < len(sorted_all_angles):
                # The opposite angle lies between two acquired angles: interpolate linearly
                other_i_l = sorted_angle_indexes[insertion_point - 1]
                other_i_h = sorted_angle_indexes[insertion_point]
                radio_l = self.get_radio(self.absolute_indices[other_i_l])
                radio_h = self.get_radio(self.absolute_indices[other_i_h])
                f = (other_angle - sorted_all_angles[insertion_point - 1]) / (
                    sorted_all_angles[insertion_point] - sorted_all_angles[insertion_point - 1]
                )
                radio2 = (1 - f) * radio_l + f * radio_h
            else:
                # Outside of the acquired range: take the closest extremal angle
                if insertion_point == 0:
                    other_i = sorted_angle_indexes[0]
                elif insertion_point == len(sorted_all_angles):
                    other_i = sorted_angle_indexes[insertion_point - 1]
                radio2 = self.get_radio(self.absolute_indices[other_i])

            n_lines = radio1.shape[0]
            second_half_start = irad + self.nprobed * n_lines
            self.sino[irad : irad + n_lines, :] = self._oversample(radio1)
            self.sino[second_half_start : second_half_start + n_lines, :] = self._oversample(radio2)

            irad = irad + n_lines

        self.sino[np.isnan(self.sino)] = 0.0001  # ?

        # Mark the sinogram as built, so that the early return above is effective on next calls.
        self._loaded = True
        return self.sino

    def find_cor(self, reload=False):
        """
        Find the center of rotation by scanning the possible overlaps between the two halves of the composite sinogram.

        Parameters
        ----------
        reload: bool, optional
            Forwarded to `get_sino`.

        Returns
        -------
        cor_abs: float
            `(sx - 1) / 2` plus the offset corresponding to the overlap with the smallest error.

        Raises
        ------
        ValueError
            If the "side" option is not supported, or if the search range is empty.
        """
        self.logger.info("Estimating center of rotation")
        self.logger.debug("%s.find_shift(%s)" % (self.__class__.__name__, self.cor_options))

        self.sinogram = self.get_sino(reload=reload)

        dim_v, dim_h = self.sinogram.shape
        assert dim_v % 2 == 0, " this should not happen "
        dim_v = dim_v // 2

        radio1 = self.sinogram[:dim_v]
        radio2 = self.sinogram[dim_v:]

        orig_sy, orig_ovsd_sx = radio1.shape

        radio1 = scipy.fft.ifftn(
            scipy.fft.fftn(radio1, axes=(-2, -1)) * self.img_filter, axes=(-2, -1)
        ).real  # TODO: convolute only along x
        radio2 = scipy.fft.ifftn(
            scipy.fft.fftn(radio2, axes=(-2, -1)) * self.img_filter, axes=(-2, -1)
        ).real  # TODO: convolute only along x

        tmp_sy, ovsd_sx = radio1.shape
        assert orig_sy == tmp_sy and orig_ovsd_sx == ovsd_sx, "this should not happen"

        side = self.cor_options["side"]
        if side == "center":
            overlap_min = max(round(ovsd_sx - ovsd_sx / 3), 4)
            overlap_max = min(round(ovsd_sx + ovsd_sx / 3), 2 * ovsd_sx - 4)
        elif side == "right":
            overlap_min = max(4, self.ovs * self.high_pass * 3)
            overlap_max = ovsd_sx
        elif side == "left":
            overlap_min = ovsd_sx
            # NOTE(review): the factor "ovs * ovs" differs from the single "ovs" used for overlap_min
            # in the "right" case - confirm whether this is intended.
            overlap_max = min(2 * ovsd_sx - 4, 2 * ovsd_sx - self.ovs * self.ovs * self.high_pass * 3)
        elif side == "all":
            overlap_min = max(4, self.ovs * self.high_pass * 3)
            # NOTE(review): same "ovs * ovs" factor as in the "left" case - confirm.
            overlap_max = min(2 * ovsd_sx - 4, 2 * ovsd_sx - self.ovs * self.ovs * self.high_pass * 3)
        elif side == "near":
            near_pos = self.cor_options["near_pos"]
            near_width = self.cor_options["near_width"]
            overlap_min = max(4, ovsd_sx - 2 * self.ovs * (near_pos + near_width))
            overlap_max = min(2 * ovsd_sx - 4, ovsd_sx - 2 * self.ovs * (near_pos - near_width))
        else:
            message = f""" The cor options "side" can only have one of the possible values ["center", "right", "left", "all", "near"].
            But it has the value "{side}" instead
            """
            raise ValueError(message)

        if overlap_min > overlap_max:
            message = """ There is no safe search range in find_cor once the margins corresponding to the high_pass filter are discarded.
            Try reducing the low_pass parameter in cor_options
            """
            raise ValueError(message)

        self.logger.info(
            "looking for overlap from min %.2f and max %.2f\n" % (overlap_min / self.ovs, overlap_max / self.ovs)
        )

        best_overlap = overlap_min
        best_error = np.inf

        blurred_radio1 = nd.gaussian_filter(abs(radio1), [0, self.high_pass])
        blurred_radio2 = nd.gaussian_filter(abs(radio2), [0, self.high_pass])

        # 'safe' margin considering high_pass value (possibly float). It does not depend on the overlap.
        margin = int(math.ceil(self.ovs * self.high_pass * 2))

        for z in range(int(overlap_min), int(overlap_max) + 1):
            if z <= ovsd_sx:
                my_z = z
                my_radio1 = radio1
                my_radio2 = radio2
                my_blurred_radio1 = blurred_radio1
                my_blurred_radio2 = blurred_radio2
            else:
                # Overlaps larger than the width are handled by mirroring the images horizontally
                my_z = ovsd_sx - (z - ovsd_sx)
                my_radio1 = np.fliplr(radio1)
                my_radio2 = np.fliplr(radio2)
                my_blurred_radio1 = np.fliplr(blurred_radio1)
                my_blurred_radio2 = np.fliplr(blurred_radio2)

            start = ovsd_sx - my_z
            common_left = np.fliplr(my_radio1[:, start:])[:, :-margin]
            common_right = my_radio2[:, start:-margin]

            common_blurred_left = np.fliplr(my_blurred_radio1[:, start:])[:, :-margin]
            common_blurred_right = my_blurred_radio2[:, start:-margin]

            if common_right.size == 0:
                continue

            error = self.error_metric(common_right, common_left, common_blurred_right, common_blurred_left)

            # "<=": on equal error, the latest (largest) overlap is kept
            if error <= best_error:
                best_overlap = z
                best_error = error

        offset = (ovsd_sx - best_overlap) / self.ovs / 2
        cor_abs = (self.sx - 1) / 2 + offset

        return cor_abs

    def error_metric(self, common_right, common_left, common_blurred_right, common_blurred_left):
        """Dispatch to the L1 or L2 error metric, according to `self.norm_order`."""
        if self.norm_order == 2:
            return self.error_metric_l2(common_right, common_left)
        if self.norm_order == 1:
            return self.error_metric_l1(common_right, common_left, common_blurred_right, common_blurred_left)
        # Not reachable through __init__, which validates norm_order
        raise ValueError(f"Unsupported norm order: {self.norm_order}")

    def error_metric_l2(self, common_right, common_left):
        """Squared norm of the difference, divided by the product of the norms of both images."""
        common = common_right - common_left

        tmp = np.linalg.norm(common)
        norm_diff2 = tmp * tmp

        norm_right = np.linalg.norm(common_right)
        norm_left = np.linalg.norm(common_left)

        res = norm_diff2 / (norm_right * norm_left)

        return res

    def error_metric_l1(self, common_right, common_left, common_blurred_right, common_blurred_left):
        """Mean absolute value of the difference, divided pixel-wise by the sum of the blurred images."""
        common = (common_right - common_left) / (common_blurred_right + common_blurred_left)

        res = abs(common).mean()

        return res
[docs]
def oversample(radio, ovs_s):
    """
    Oversample a 2D image by linear interpolation along both axes.

    `ovs_s[0]` and `ovs_s[1]` are the oversampling factors along axis 0 and axis 1.
    The first and last points of each axis of the input stay the first and last points
    of the corresponding axis of the output, which is a new float32 array of shape
    `((n0 - 1) * ovs_s[0] + 1, (n1 - 1) * ovs_s[1] + 1)`.
    """
    step0 = ovs_s[0]
    step1 = ovs_s[1]
    out_shape = [(radio.shape[0] - 1) * step0 + 1, (radio.shape[1] - 1) * step1 + 1]
    result = np.zeros(out_shape, "f")

    # The original samples fall exactly on the strided positions of the new grid.
    result[::step0, ::step1] = radio

    for k in range(step0):
        # Interpolation weight along axis 0
        g = k / step0
        # For k == 0 the target rows are the original rows: both row selections are then the same
        # full set of rows, and since (1 - g) + g == 1 the axis-0 interpolation is a no-op.
        # Otherwise "upper" stops one original row before the end and "lower" starts one original row later.
        upper_rows = slice(0, -step0 if k else None, step0)
        lower_rows = slice(step0 if k else 0, None, step0)

        for i in range(step1):
            if k == 0 and i == 0:
                # Already filled by the strided assignment above
                continue
            # Interpolation weight along axis 1, with the same convention as for axis 0
            f = i / step1
            left_cols = slice(0, -step1 if i else None, step1)
            right_cols = slice(step1 if i else 0, None, step1)

            result[k::step0, i::step1] = (1 - g) * (
                (1 - f) * result[upper_rows, left_cols] + f * result[upper_rows, right_cols]
            ) + g * ((1 - f) * result[lower_rows, left_cols] + f * result[lower_rows, right_cols])
    return result
# Alias: `CompositeCOREstimator` is a second module-level name bound to the `CompositeCORFinder` class.
CompositeCOREstimator = CompositeCORFinder
# Some heavily inelegant things going on here
[docs]
def get_default_kwargs(func):
    """
    Return the parameters of a callable which have a default value.

    Parameters
    ----------
    func: callable
        Any object supported by `inspect.signature`.

    Returns
    -------
    res: dict
        Mapping "parameter name" -> "default value", in the order of the signature.
    """
    params = inspect.signature(func).parameters
    # Identity test against the public sentinel: "!=" would call the comparison operator of the
    # default value, which fails for objects such as numpy arrays (element-wise comparison).
    return {
        param_name: param.default
        for param_name, param in params.items()
        if param.default is not inspect.Parameter.empty
    }
[docs]
def update_func_kwargs(func, options):
    """
    Return the default keyword arguments of `func`, overridden by the entries of `options`
    whose name matches one of them. Entries of `options` with another name are ignored.
    """
    kwargs = get_default_kwargs(func)
    overrides = {name: value for name, value in options.items() if name in kwargs}
    kwargs.update(overrides)
    return kwargs
[docs]
def get_class_name(class_object):
    """
    Return the short name of a class, without module or enclosing scope.

    Parameters
    ----------
    class_object:
        Normally a class. For other objects, the name is extracted from `str(class_object)`.
    """
    if isinstance(class_object, type):
        # Parsing str(cls) breaks when the representation has no dot,
        # e.g. "<class 'int'>" would give "<class 'int".
        return class_object.__name__
    return str(class_object).split(".")[-1].strip(">").strip("'").strip('"')
[docs]
class DetectorTiltEstimator:
    """
    Helper class for detector tilt estimation.
    It automatically chooses the right radios and performs flat-field.
    """

    default_tilt_method = "1d-correlation"
    # Given a tilt angle "a", the maximum deviation caused by the tilt (in pixels) is
    # N/2 * |sin(a)| where N is the number of pixels
    # We ignore tilts causing less than 0.25 pixel deviation: N/2*|sin(a)| < tilt_threshold
    # NOTE(review): find_tilt() computes N * |sin(a)| (no 1/2 factor) - confirm which one is intended.
    tilt_threshold = 0.25

    def __init__(self, dataset_info, do_flatfield=True, logger=None, autotilt_options=None):
        """
        Initialize a detector tilt estimator helper.

        Parameters
        ----------
        dataset_info: `dataset_info` object
            Data structure with the dataset information.
        do_flatfield: bool, optional
            Whether to perform flat field on radios.
        logger: `Logger` object, optional
            Logger object
        autotilt_options: dict or str, optional
            named arguments to pass to the detector tilt estimator class.
            A dict is used as is (copied), any other value is parsed with `extract_parameters`.
            The entry "threshold", if present, overrides `tilt_threshold` for this instance.
        """
        self._set_params(dataset_info, do_flatfield, logger, autotilt_options)
        self.radios, self.radios_indices = get_radio_pair(dataset_info, radio_angles=(0.0, np.pi), return_indices=True)
        self._init_flatfield()
        self._apply_flatfield()

    def _set_params(self, dataset_info, do_flatfield, logger, autotilt_options):
        """Store the constructor parameters and parse the auto-tilt options."""
        self.dataset_info = dataset_info
        self.do_flatfield = bool(do_flatfield)
        self.logger = LoggerOrPrint(logger)
        self._get_autotilt_options(autotilt_options)

    def _init_flatfield(self):
        """Create the flat-field object, only when flat-field is enabled."""
        if not (self.do_flatfield):
            return
        self.flatfield = FlatFieldDataUrls(
            self.radios.shape,
            flats=self.dataset_info.flats,
            darks=self.dataset_info.darks,
            radios_indices=self.radios_indices,
            interpolation="linear",
            convert_float=True,
        )

    def _apply_flatfield(self):
        """Normalize the radios with the flat-field object, only when flat-field is enabled."""
        if not (self.do_flatfield):
            return
        self.flatfield.normalize_radios(self.radios)

    def _get_autotilt_options(self, autotilt_options):
        """
        Set `self.autotilt_options` (and possibly `self.tilt_threshold`) from the user options.

        Raises
        ------
        ValueError
            If the options are not a dict and cannot be parsed by `extract_parameters`.
        """
        if autotilt_options is None:
            self.autotilt_options = None
            return
        if isinstance(autotilt_options, dict):
            # Already a mapping, as documented in __init__. Work on a copy: "threshold" is popped below.
            autotilt_options = dict(autotilt_options)
        else:
            try:
                autotilt_options = extract_parameters(autotilt_options)
            except Exception as exc:
                msg = "Could not extract parameters from autotilt_options: %s" % (str(exc))
                self.logger.fatal(msg)
                raise ValueError(msg) from exc
        self.autotilt_options = autotilt_options
        if "threshold" in autotilt_options:
            # "threshold" is consumed here: it is not forwarded to the tilt computation
            self.tilt_threshold = autotilt_options.pop("threshold")

    def find_tilt(self, tilt_method=None):
        """
        Find the detector tilt.

        Parameters
        ----------
        tilt_method: str, optional
            Which tilt estimation method to use. Default is `default_tilt_method`.

        Returns
        -------
        camera_tilt: float or None
            The estimated tilt angle in degrees, or None if the corresponding
            deviation is below `tilt_threshold`.
        """
        if tilt_method is None:
            tilt_method = self.default_tilt_method
        check_supported(tilt_method, set(tilt_methods.values()), "tilt estimation method")
        self.logger.info("Estimating detector tilt angle")
        autotilt_params = {
            "roi_yxhw": None,
            "median_filt_shape": None,
            "padding_mode": None,
            "peak_fit_radius": 1,
            "high_pass": None,
            "low_pass": None,
        }
        autotilt_params.update(self.autotilt_options or {})
        self.logger.debug("%s(%s)" % ("CameraTilt", str(autotilt_params)))

        tilt_calc = CameraTilt()
        # The first returned value (CoR position) is not used here
        _, camera_tilt = tilt_calc.compute_angle(
            self.radios[0], np.fliplr(self.radios[1]), method=tilt_method, **autotilt_params
        )
        self.logger.info("Estimated detector tilt angle: %f degrees" % camera_tilt)
        # Ignore too small tilts
        max_deviation = np.max(self.dataset_info.radio_dims) * np.abs(np.sin(np.deg2rad(camera_tilt)))
        if self.dataset_info.is_halftomo:
            max_deviation *= 2
        if max_deviation < self.tilt_threshold:
            self.logger.info(
                "Estimated tilt angle (%.3f degrees) results in %.2f maximum pixels shift, which is below threshold (%.2f pixel). Ignoring the tilt, no correction will be done."
                % (camera_tilt, max_deviation, self.tilt_threshold)
            )
            camera_tilt = None
        return camera_tilt
# Alias: `TiltFinder` is a second module-level name bound to the `DetectorTiltEstimator` class.
TiltFinder = DetectorTiltEstimator