from math import pi
import numpy as np
from scipy.fft import rfft, irfft
from silx.image.tomography import compute_fourier_filter, get_next_power
from ..processing.padding_base import PaddingBase
from ..utils import check_supported, get_num_threads
# # COMPAT.
# from .filtering_cuda import CudaSinoFilter
# SinoFilter = deprecated_class(
# "From version 2023.1, 'filtering_cuda.CudaSinoFilter' should be used instead of 'filtering.SinoFilter'. In the future, 'filtering.SinoFilter' will be a numpy-only class.",
# do_print=True,
# )(CudaSinoFilter)
# #
[docs]
class SinoFilter:
    """
    A class for sinogram filtering.

    It does the following:
      - pad input array
      - Fourier transform each row
      - multiply with a 1D or 2D filter
      - inverse Fourier transform
    """

    available_filters = [
        "ramlak",
        "shepp-logan",
        "cosine",
        "hamming",
        "hann",
        "tukey",
        "lanczos",
        "hilbert",
    ]
    available_padding_modes = PaddingBase.supported_modes
    default_extra_options = {"cutoff": 1.0, "fft_threads": 0}  # use all threads by default

    def __init__(
        self,
        sino_shape,
        filter_name=None,
        padding_mode="zeros",
        extra_options=None,
    ):
        """
        Initialize a sinogram filter.

        Parameters
        ----------
        sino_shape: tuple of int
            Shape of the sinogram(s) to filter: (n_angles, width) or
            (n_sinos, n_angles, width).
        filter_name: str, optional
            Name of the filter. When empty or None, "ram-lak" is used.
        padding_mode: str, optional
            Padding mode. "zeros" and "edges" are accepted as aliases of
            "constant" and "edge" respectively.
        extra_options: dict, optional
            Entries overriding `default_extra_options`.
        """
        self._init_extra_options(extra_options)
        self._set_padding_mode(padding_mode)
        self._calculate_shapes(sino_shape)
        self._init_fft()
        self._allocate_memory()
        self._compute_filter(filter_name)

    def _init_extra_options(self, extra_options):
        """Merge user-provided options on top of a copy of the class defaults."""
        self.extra_options = self.default_extra_options.copy()
        self.extra_options.update(extra_options or {})

    def _set_padding_mode(self, padding_mode):
        """Translate legacy mode names, validate the mode and store it."""
        # Compat.
        if padding_mode == "edges":
            padding_mode = "edge"
        if padding_mode == "zeros":
            padding_mode = "constant"
        #
        check_supported(padding_mode, self.available_padding_modes, "padding mode")
        self.padding_mode = padding_mode

    def _calculate_shapes(self, sino_shape):
        """
        Derive the padded and Fourier-domain shapes from the sinogram shape.

        Raises
        ------
        ValueError
            If `sino_shape` does not have 2 or 3 dimensions.
        """
        # Normalize to a tuple: _check_array() compares against arr.shape (a tuple),
        # so a shape given as a list would otherwise never match.
        sino_shape = tuple(sino_shape)
        self.ndim = len(sino_shape)
        if self.ndim not in (2, 3):
            raise ValueError("Invalid sinogram number of dimensions")
        n_angles, dwidth = sino_shape[-2:]
        self.sino_shape = sino_shape
        self.n_angles = n_angles
        self.dwidth = dwidth
        # Make sure to use int() here, otherwise pycuda/pyopencl will crash in some cases
        self.dwidth_padded = int(get_next_power(2 * self.dwidth))
        # Only the last (detector width) axis changes size
        self.sino_padded_shape = sino_shape[:-1] + (self.dwidth_padded,)
        # Real-to-complex transform keeps N//2 + 1 coefficients along the last axis
        self.sino_f_shape = self.sino_padded_shape[:-1] + (self.dwidth_padded // 2 + 1,)
        self.pad_left = (self.dwidth_padded - self.dwidth) // 2
        self.pad_right = self.dwidth_padded - self.dwidth - self.pad_left

    def _init_fft(self):
        """Hook called during construction; nothing to set up in this implementation."""
        pass

    def _allocate_memory(self):
        """Hook called during construction; nothing to allocate in this implementation."""
        pass

    def set_filter(self, h_filt, normalize=True):
        """
        Set a filter for sinogram filtering.

        Parameters
        ----------
        h_filt: numpy.ndarray
            Array containing the filter. Each line of the sinogram will be filtered with
            this filter. It has to be the Real-to-Complex Fourier Transform
            of some real filter, padded to 2*sinogram_width.
        normalize: bool, optional
            Whether to normalize (multiply) the filter with pi/num_angles.
            Any truthy value enables the normalization.

        Raises
        ------
        ValueError
            If the filter size differs from the last dimension of `sino_f_shape`.
        """
        expected_size = self.sino_f_shape[-1]
        if h_filt.size != expected_size:
            raise ValueError(
                "\nInvalid filter size: expected %d, got %d.\n"
                "Please check that the filter is the Fourier R2C transform of\n"
                "some real 1D filter.\n" % (expected_size, h_filt.size)
            )
        if not np.iscomplexobj(h_filt):
            print("Warning: expected a complex Fourier filter")
        if normalize:
            # Out-of-place product: never modifies the caller's array, and does not
            # fail on integer-typed filters as an in-place "*=" with a float would.
            h_filt = h_filt * float(pi / self.n_angles)
        # astype() returns a new array, so the stored filter never aliases the input
        self.filter_f = h_filt.astype(np.complex64)

    def _compute_filter(self, filter_name):
        """Build the Fourier-domain filter for `filter_name` and install it with set_filter()."""
        self.filter_name = filter_name or "ram-lak"
        # TODO add this one into silx
        if self.filter_name == "hilbert":
            freqs = np.fft.fftfreq(self.dwidth_padded)
            filter_f = 1.0 / (2 * pi * 1j) * np.sign(freqs)
        #
        else:
            filter_f = compute_fourier_filter(
                self.dwidth_padded,
                self.filter_name,
                cutoff=self.extra_options["cutoff"],
            )
        filter_f = filter_f[: self.dwidth_padded // 2 + 1]  # R2C
        self.set_filter(filter_f, normalize=True)

    def _check_array(self, arr):
        """
        Validate the data type and shape of an input sinogram.

        Raises
        ------
        ValueError
            If `arr` is not float32 or its shape differs from `sino_shape`.
        """
        if arr.dtype != np.float32:
            raise ValueError("Expected data type = numpy.float32")
        if arr.shape != self.sino_shape:
            raise ValueError("Expected sinogram shape %s, got %s" % (self.sino_shape, arr.shape))

    def filter_sino(self, sino, output=None, no_output=False):
        """
        Perform the sinogram filtering.

        Parameters
        ----------
        sino: numpy.ndarray
            Input sinogram (2D or 3D), with data type float32 and shape `sino_shape`.
        output: numpy.ndarray, optional
            Output array. If not provided, a new float32 array is allocated.
        no_output: bool, optional
            Accepted in the signature but not used by this implementation.

        Returns
        -------
        numpy.ndarray
            The filtered sinogram (`output` itself when it was provided).
        """
        self._check_array(sino)
        n_workers = get_num_threads(self.extra_options["fft_threads"])
        # Pad only the last (detector width) axis, on its right side (FFT-friendly layout).
        # One (before, after) pair per axis, so that 3D stacks are handled as well.
        pad_width = [(0, 0)] * (self.ndim - 1) + [(0, self.dwidth_padded - self.dwidth)]
        sino_padded = np.pad(sino, pad_width, mode=self.padding_mode)
        # Transform along the last axis: each sinogram row is filtered independently
        sino_padded_f = rfft(sino_padded, axis=-1, workers=n_workers)
        sino_padded_f *= self.filter_f
        # The output length is given explicitly, irfft would otherwise assume an even length
        sino_filtered = irfft(sino_padded_f, n=self.dwidth_padded, axis=-1, workers=n_workers)
        if output is None:
            res = np.zeros(self.sino_shape, dtype=np.float32)
        else:
            res = output
        # Crop the padding; the ellipsis covers both the 2D and 3D cases
        res[:] = sino_filtered[..., : self.dwidth]  # pylint: disable=E1126
        return res

    __call__ = filter_sino
[docs]
def filter_sinogram(
    sinogram,
    padded_width,
    filter_name="ramlak",
    padding_mode="constant",
    normalize=True,
    filter_cutoff=1.0,
    **padding_kwargs,
):
    """
    Filter a single sinogram in the Fourier domain.

    The rows are padded on the right up to `padded_width`, transformed with a
    real-to-complex FFT, multiplied by the requested filter, transformed back
    and cropped to the original width.

    Parameters
    ----------
    sinogram: numpy.ndarray
        Sinogram, two dimensional array with shape (n_angles, sino_width)
    padded_width: int
        Width to use for padding. Must be greater than sinogram width (i.e than sinogram.shape[-1])
    filter_name: str, optional
        Which filter to use. Default is ramlak (roughly equivalent to abs(nu) in frequency domain)
    padding_mode: str, optional
        Which padding mode to use. Default is zero-padding.
    normalize: bool, optional
        Whether to multiply the filtered sinogram with pi/n_angles
    filter_cutoff: float, optional
        frequency cutoff for filter
    **padding_kwargs:
        Extra keyword arguments forwarded to numpy.pad.

    Returns
    -------
    numpy.ndarray
        The filtered sinogram, cropped to the first `sino_width` columns.
    """
    n_rows, sino_width = sinogram.shape

    # Extend each row on its right side only
    extra_columns = padded_width - sino_width
    padded = np.pad(sinogram, ((0, 0), (0, extra_columns)), mode=padding_mode, **padding_kwargs)

    # Full-length filter, scaled in place, then reduced to the half-spectrum used by rfft
    full_filter = compute_fourier_filter(padded_width, filter_name, cutoff=filter_cutoff)
    if normalize:
        full_filter *= np.pi / n_rows
    half_filter = full_filter[: padded_width // 2 + 1]  # R2C

    spectrum = rfft(padded, axis=1)
    spectrum *= half_filter
    filtered_padded = irfft(spectrum, axis=1)
    return filtered_padded[:, :sino_width]  # pylint: disable=E1126