Source code for nabu.reconstruction.rings

import numpy as np
from scipy.fft import rfft, irfft
from silx.image.tomography import get_next_power
from ..thirdparty.pore3d_deringer_munch import munchetal_filter
from ..utils import get_2D_3D_shape, get_num_threads, check_supported
from ..misc.fourier_filters import get_bandpass_filter

try:
    from algotom.prep.removal import remove_all_stripe

    __has_algotom__ = True
except ImportError:
    __has_algotom__ = False


class MunchDeringer:
    """
    Rings artefacts removal on sinograms with the combined wavelet-Fourier
    filter of Munch et al. (see the References section of ``__init__``).
    """

    def __init__(self, sigma, sinos_shape, levels=None, wname="db15", padding=None, padding_mode="edge"):
        """
        Initialize a "Munch Et Al" sinogram deringer. See References for more information.

        Parameters
        -----------
        sigma: float
            Standard deviation of the damping parameter.
            The higher value of sigma, the more important the filtering effect on the rings.
        sinos_shape: tuple
            Shape of the sinogram (or sinograms stack).
        levels: int, optional
            Number of wavelets decomposition levels.
            By default (None), the maximum number of decomposition levels is used.
        wname: str, optional
            Default is "db15" (Daubechies, 15 vanishing moments)
        padding: bool or tuple of two int, optional
            Horizontal padding to use for reducing the aliasing artefacts,
            in the form (pad_left, pad_right).
            True is a shorthand for (n_x // 2, n_x // 2); False or None (default)
            disables padding.
        padding_mode: str, optional
            Padding mode, forwarded to ``numpy.pad``. Default is "edge".

        Raises
        ------
        ValueError
            If the padding is not in the form (x1, x2), or if the wavelets
            filter is not available.

        References
        ----------
        B. Munch, P. Trtik, F. Marone, M. Stampanoni, Stripe and ring artifact removal with
        combined wavelet-Fourier filtering, Optics Express 17(10):8567-8591, 2009.
        """
        self._get_shapes(sinos_shape, padding)
        self.sigma = sigma
        self.levels = levels
        self.wname = wname
        self.padding_mode = padding_mode
        self._check_can_use_wavelets()

    def _get_shapes(self, sinos_shape, padding):
        """Store the sinograms geometry and normalize the ``padding`` argument."""
        n_z, n_a, n_x = get_2D_3D_shape(sinos_shape)
        self.sinos_shape = n_z, n_a, n_x
        self.n_angles = n_a
        self.n_z = n_z
        self.n_x = n_x
        # Handle "padding=True" or "padding=False"
        if isinstance(padding, bool):
            padding = (n_x // 2, n_x // 2) if padding else None
        if padding is not None:
            pad_x1, pad_x2 = padding
            if np.iterable(pad_x1) or np.iterable(pad_x2):
                raise ValueError("Expected padding in the form (x1, x2)")
            self.sino_padded_shape = (n_a, n_x + pad_x1 + pad_x2)
        self.padding = padding

    def _check_can_use_wavelets(self):
        """Raise a ValueError if the wavelets-based filter could not be imported."""
        if munchetal_filter is None:
            raise ValueError("Need pywavelets to use this class")

    def _crop_padding(self, padded):
        """Remove the horizontal padding from a (padded) 2D array."""
        if self.padding is None:
            return padded
        pad_left, pad_right = self.padding
        # An explicit stop index is needed: "-pad_right" would be "-0" for a
        # zero right padding, which selects an empty range instead of "up to the end".
        return padded[:, pad_left : padded.shape[1] - pad_right]

    def _destripe_2D(self, sino, output):
        """Filter one sinogram and write the result into ``output`` (returned)."""
        if self.padding is not None:
            sino = np.pad(sino, ((0, 0), self.padding), mode=self.padding_mode)
        res = munchetal_filter(sino, self.levels, self.sigma, wname=self.wname)
        output[:] = self._crop_padding(res)
        return output

    def remove_rings(self, sinos, output=None):
        """
        Main function to performs rings artefacts removal on sinogram(s).
        CAUTION: this function defaults to in-place processing, meaning that
        the sinogram(s) you pass will be overwritten.

        Parameters
        ----------
        sinos: numpy.ndarray
            Sinogram or stack of sinograms.
        output: numpy.ndarray, optional
            Output array. If set to None (default), the output overwrites the input.

        Returns
        -------
        output: numpy.ndarray
            The array holding the processed sinogram(s).
        """
        if output is None:
            output = sinos
        if sinos.ndim == 2:
            return self._destripe_2D(sinos, output)
        n_sinos = sinos.shape[0]
        for i in range(n_sinos):
            self._destripe_2D(sinos[i], output[i])
        return output
class VoDeringer:
    """
    Thin wrapper around "remove_all_stripe" (Nghia Vo's method).
    algotom must be installed to use this class.
    """

    def __init__(self, sinos_shape, **remove_all_stripe_options):
        """
        Parameters
        ----------
        sinos_shape: tuple of int
            Shape of the sinogram (or sinograms stack).
        remove_all_stripe_options:
            Extra keyword arguments, forwarded as-is to "remove_all_stripe"
            each time a sinogram is processed.
        """
        self._check_requirement()
        self._get_shapes(sinos_shape)
        self._remove_all_stripe_kwargs = remove_all_stripe_options

    def _check_requirement(self):
        """Raise an ImportError when algotom could not be imported."""
        if __has_algotom__:
            return
        raise ImportError("Need algotom")

    def _get_shapes(self, sinos_shape):
        """Store the sinograms shape as (n_z, n_angles, n_x) and its components."""
        n_sinos, n_projs, width = get_2D_3D_shape(sinos_shape)
        self.sinos_shape = (n_sinos, n_projs, width)
        self.n_angles = n_projs
        self.n_z = n_sinos
        self.n_x = width

    def remove_rings_sinogram(self, sino, output=None):
        """
        Process one sinogram. The result is copied into `output` when provided
        (and `output` is returned), otherwise the newly computed array is returned.
        """
        # remove_all_stripe works out-of-place
        filtered = remove_all_stripe(sino, **self._remove_all_stripe_kwargs)
        if output is None:
            return filtered
        output[:] = filtered[:]
        return output

    def remove_rings_sinograms(self, sinos, output=None):
        """
        Process a stack of sinograms, one at a time.
        When `output` is None (default), the input stack is overwritten.
        Returns the array holding the results.
        """
        destination = sinos if output is None else output
        n_sinos = sinos.shape[0]
        for idx in range(n_sinos):
            destination[idx] = self.remove_rings_sinogram(sinos[idx])
        return destination

    def remove_rings_radios(self, radios):
        """
        Process a stack of radios, by looking at it as a stack of sinograms:
        (n_a, n_z, n_x) --> (n_z, n_a, n_x)
        """
        as_sinos = np.moveaxis(radios, 1, 0)
        return self.remove_rings_sinograms(as_sinos)

    remove_rings = remove_rings_sinograms
class SinoMeanDeringer:
    """
    Rings correction by subtracting (or dividing by) the optionally filtered
    mean profile of each sinogram. Processing is done in-place.
    """

    supported_modes = ["subtract", "divide"]

    def __init__(self, sinos_shape, mode="subtract", filter_cutoff=None, padding_mode="edge", fft_num_threads=None):
        """
        Rings correction with mean subtraction/division.
        The principle of this method is to subtract (or divide) the sinogram by its mean along a certain axis.
        In short:
            sinogram -= filt(sinogram.mean(axis=0))
        where `filt` is some bandpass filter.

        Parameters
        ----------
        sinos_shape: tuple of int
            Sinograms shape, in the form (n_angles, n_x) or (n_sinos, n_angles, n_x)
        mode: str, optional
            Operation to do on the sinogram, either "subtract" or "divide"
        filter_cutoff: tuple, optional
            Cut-off of the bandpass filter applied on the sinogram profiles.
            Empty (default) means no filtering.
            Possible values forms are:
              - (sigma_low, sigma_high): two float values defining the standard deviation of
                gaussian(sigma_low) * (1 - gaussian(sigma_high)).
                High values of sigma mean stronger effect of associated filters.
              - ((cutoff_low, transition_low), (cutoff_high, transition_high))
                where "cutoff" is in normalized Nyquist frequency (0.5 is the maximum frequency),
                and "transition" is the width of filter decay in fraction of the cutoff frequency
        padding_mode: str, optional
            Padding mode when filtering the sinogram profile.
            Should be "constant" (i.e "zeros") for mathematical correctness,
            but in practice this yields a Gibbs effect when replicating the sinogram, so "edges" is recommended.
        fft_num_threads: int, optional
            How many threads to use for computing the fast Fourier transform when filtering the sinogram profile.
            Default is all the available threads.
        """
        self._get_shapes(sinos_shape)
        check_supported(mode, self.supported_modes, "operation mode")
        self.mode = mode
        self._init_filter(filter_cutoff, fft_num_threads, padding_mode)

    def _get_shapes(self, sinos_shape):
        """Store the sinograms shape as (n_z, n_angles, n_x) and its components."""
        n_z, n_a, n_x = get_2D_3D_shape(sinos_shape)
        self.sinos_shape = n_z, n_a, n_x
        self.n_angles = n_a
        self.n_z = n_z
        self.n_x = n_x

    def _init_filter(self, filter_cutoff, fft_num_threads, padding_mode):
        """Build the Fourier-domain bandpass filter, if ``filter_cutoff`` is provided."""
        self.filter_cutoff = filter_cutoff
        self._filter_f = None
        # compat: accept "edges" as an alias of numpy's "edge"
        if padding_mode == "edges":
            padding_mode = "edge"
        # Stored unconditionally, so that the attribute exists even without filter
        self.padding_mode = padding_mode
        if filter_cutoff is None:
            return
        self._filter_size = get_next_power(self.n_x * 2)
        self._filter_f = get_bandpass_filter(
            (1, self._filter_size),
            cutoff_lowpass=filter_cutoff[0],
            cutoff_highpass=filter_cutoff[1],
            use_rfft=True,
            data_type=np.float32,
        ).ravel()
        self._fft_n_threads = get_num_threads(fft_num_threads)
        # The profile is padded on both sides up to the filter size before the FFT
        size_diff = self._filter_size - self.n_x
        self._pad_left, self._pad_right = size_diff // 2, size_diff - size_diff // 2

    def _apply_filter(self, sino_profile):
        """Return the bandpass-filtered profile (or the profile itself if there is no filter)."""
        if self._filter_f is None:
            return sino_profile
        sino_profile = np.pad(sino_profile, (self._pad_left, self._pad_right), mode=self.padding_mode)
        sino_f = rfft(sino_profile, workers=self._fft_n_threads)
        sino_f *= self._filter_f
        # Crop the padding added above to get back a profile of the original size
        return irfft(sino_f, workers=self._fft_n_threads)[self._pad_left : -self._pad_right]

    def remove_rings_sinogram(self, sino, output=None):
        """
        Correct one sinogram, in-place.

        Parameters
        ----------
        sino: numpy.ndarray
            Sinogram. It is overwritten with the result.
        output: None
            Not supported: passing anything else than None raises NotImplementedError.

        Returns
        -------
        sino: numpy.ndarray
            The (modified) input sinogram.
        """
        if output is not None:
            raise NotImplementedError
        sino_profile = sino.mean(axis=0)
        sino_profile = self._apply_filter(sino_profile)
        if self.mode == "subtract":
            sino -= sino_profile
        elif self.mode == "divide":
            sino /= sino_profile
        return sino

    def remove_rings_sinograms(self, sinos, output=None):
        """
        Correct a sinogram or a stack of sinograms, in-place.

        Parameters
        ----------
        sinos: numpy.ndarray
            Sinogram or stack of sinograms. It is overwritten with the result.
        output: None
            Not supported: passing anything else than None raises NotImplementedError.

        Returns
        -------
        sinos: numpy.ndarray
            The (modified) input array.
        """
        if output is not None:
            raise NotImplementedError
        # A single sinogram must not be iterated over row by row
        if sinos.ndim == 2:
            return self.remove_rings_sinogram(sinos)
        for i in range(sinos.shape[0]):
            self.remove_rings_sinogram(sinos[i])
        return sinos

    remove_rings = remove_rings_sinograms