import numpy as np
from ..utils import docstring, get_cuda_srcfile, updiv
from ..cuda.processing import CudaProcessing, __has_pycuda__
from ..processing.padding_cuda import CudaPadding
from ..processing.fft_cuda import get_fft_class, get_available_fft_implems
from ..processing.transpose import CudaTranspose
from ..thirdparty.tomocupy_remove_stripe import remove_all_stripe_pycuda, __have_tomocupy_deringer__
from .rings import MunchDeringer, SinoMeanDeringer, VoDeringer
if __has_pycuda__:
import pycuda.gpuarray as garray
from ..cuda.kernel import CudaKernel
try:
from pycudwt import Wavelets
__have_pycudwt__ = True
except ImportError:
__have_pycudwt__ = False
[docs]
class CudaMunchDeringer(MunchDeringer):
def __init__(
self,
sigma,
sinos_shape,
levels=None,
wname="db15",
padding=None,
padding_mode="edge",
fft_backend="skcuda",
cuda_options=None,
):
"""
Initialize a "Munch Et Al" sinogram deringer with the Cuda backend.
See References for more information.
Parameters
-----------
sigma: float
Standard deviation of the damping parameter. The higher value of sigma,
the more important the filtering effect on the rings.
levels: int, optional
Number of wavelets decomposition levels.
By default (None), the maximum number of decomposition levels is used.
wname: str, optional
Default is "db15" (Daubechies, 15 vanishing moments)
sinos_shape: tuple, optional
Shape of the sinogram (or sinograms stack).
References
----------
B. Munch, P. Trtik, F. Marone, M. Stampanoni, Stripe and ring artifact removal with
combined wavelet-Fourier filtering, Optics Express 17(10):8567-8591, 2009.
"""
super().__init__(sigma, sinos_shape, levels=levels, wname=wname, padding=padding, padding_mode=padding_mode)
self._check_can_use_wavelets()
self.cuda_processing = CudaProcessing(**(cuda_options or {}))
self.ctx = self.cuda_processing.ctx
self._init_pycudwt()
self._init_padding()
self._init_fft(fft_backend)
self._setup_fw_kernel()
def _check_can_use_wavelets(self):
if not (__have_pycudwt__):
raise ValueError("Needs pycudwt to use this class")
def _init_padding(self):
if self.padding is None:
return
self.padder = CudaPadding(
self.sinos_shape[1:],
((0, 0), self.padding),
mode=self.padding_mode,
cuda_options={"ctx": self.cuda_processing.ctx},
)
def _init_fft(self, fft_backend):
self.fft_cls = get_fft_class(backend=fft_backend)
# For all k >= 1, we perform a batched (I)FFT along axis 0 on an array
# of shape (n_a/2^k, n_x/2^k) (up to DWT size rounding)
if self.fft_cls.implem == "vkfft":
self._create_plans_vkfft()
else:
self._create_plans_skfft()
def _create_plans_skfft(self):
self._fft_plans = {}
for level, d_vcoeff in self._d_vertical_coeffs.items():
self._fft_plans[level] = self.fft_cls(d_vcoeff.shape, np.float32, r2c=True, axes=(0,), ctx=self.ctx)
def _create_plans_vkfft(self):
"""
VKFFT does not support batched R2C transforms along axis 0 ("slow axis").
We can either use C2C (faster, but needs more memory) or transpose the arrays to do R2C along axis=1.
Here we transpose the arrays.
"""
self._fft_plans = {}
self._transpose_forward_1 = {}
self._transpose_forward_2 = {}
self._transpose_inverse_1 = {}
self._transpose_inverse_2 = {}
for level, d_vcoeff in self._d_vertical_coeffs.items():
shape = d_vcoeff.shape
# Normally, a batched 1D fft on 2D data of shape (Ny, Nx) along axis 0 returns an array of shape (Ny/2+1, Nx):
#
# (Ny, Nx) --[fft_0]--> (Ny/2, Nx)
# f32 c64
#
# In this case, we can only do batched 1D transform along axis 1, so we have to trick with transposes:
#
# (Ny, Nx) --[T]--> (Nx, Ny) --[fft_1]--> (Nx, Ny/2) --[T]--> (Ny/2, Nx)
# f32 f32 c64 c64
#
# (In both cases IFFT is done the same way from right to left)
self._transpose_forward_1[level] = CudaTranspose(shape, np.float32, ctx=self.ctx)
self._fft_plans[level] = self.fft_cls(shape[::-1], np.float32, r2c=True, ctx=self.ctx)
self._transpose_forward_2[level] = CudaTranspose((shape[1], shape[0] // 2 + 1), np.complex64, ctx=self.ctx)
self._transpose_inverse_1[level] = CudaTranspose((shape[0] // 2 + 1, shape[1]), np.complex64, ctx=self.ctx)
self._transpose_inverse_2[level] = CudaTranspose(shape[::-1], np.float32, ctx=self.ctx)
def _init_pycudwt(self):
if self.levels is None:
self.levels = 100 # will be clipped by pycudwt
sino_shape = self.sinos_shape[1:] if self.padding is None else self.sino_padded_shape
self.cudwt = Wavelets(np.zeros(sino_shape, "f"), self.wname, self.levels)
self.levels = self.cudwt.levels
# Access memory allocated by "pypwt" from pycuda
self._d_sino = garray.empty(sino_shape, np.float32, gpudata=self.cudwt.image_int_ptr())
self._get_vertical_coeffs()
def _get_vertical_coeffs(self):
self._d_vertical_coeffs = {}
# Transfer the (0-memset) coefficients in order to get all the shapes
coeffs = self.cudwt.coeffs
for i in range(self.cudwt.levels):
shape = coeffs[i + 1][1].shape
self._d_vertical_coeffs[i + 1] = garray.empty(
shape, np.float32, gpudata=self.cudwt.coeff_int_ptr(3 * i + 2)
)
def _setup_fw_kernel(self):
self._fw_kernel = CudaKernel(
"kern_fourierwavelets",
filename=get_cuda_srcfile("fourier_wavelets.cu"),
signature="Piif",
)
def _apply_fft(self, level):
d_coeffs = self._d_vertical_coeffs[level]
# All the memory is allocated (or re-used) under the hood
if self.fft_cls.implem == "vkfft":
d_coeffs_t = self._transpose_forward_1[level](
d_coeffs
) # allocates self._transpose_forward_1[level].processing.dst
d_coeffs_t_f = self._fft_plans[level].fft(d_coeffs_t) # allocates self._fft_plans[level].output_fft
d_coeffs_f = self._transpose_forward_2[level](
d_coeffs_t_f
) # allocates self._transpose_forward_2[level].processing.dst
else:
d_coeffs_f = self._fft_plans[level].fft(d_coeffs)
return d_coeffs_f
def _apply_ifft(self, d_coeffs_f, level):
d_coeffs = self._d_vertical_coeffs[level]
if self.fft_cls.implem == "vkfft":
d_coeffs_t_f = self._transpose_inverse_1[level](d_coeffs_f, dst=self._fft_plans[level].output_fft)
d_coeffs_t = self._fft_plans[level].ifft(
d_coeffs_t_f, output=self._transpose_forward_1[level].processing.dst
)
self._transpose_inverse_2[level](d_coeffs_t, dst=d_coeffs)
else:
self._fft_plans[level].ifft(d_coeffs_f, output=d_coeffs)
def _destripe_2D(self, d_sino, output):
if not (d_sino.flags.c_contiguous):
sino = self.cuda_processing.allocate_array("_d_sino", d_sino.shape, np.float32)
sino[:] = d_sino[:]
else:
sino = d_sino
if self.padding is not None:
sino = self.padder.pad(sino)
# set the "image" for DWT (memcpy D2D)
self._d_sino.set(sino)
# perform forward DWT
self.cudwt.forward()
for i in range(self.cudwt.levels):
level = i + 1
Ny, Nx = self._d_vertical_coeffs[level].shape
# Batched FFT along axis 0
d_vertical_coeffs_f = self._apply_fft(level)
# Dampen the wavelets coefficients
self._fw_kernel(d_vertical_coeffs_f, Nx, Ny, self.sigma)
# IFFT
self._apply_ifft(d_vertical_coeffs_f, level)
# Finally, inverse DWT
self.cudwt.inverse()
d_out = self._d_sino
if self.padding is not None:
d_out = self._d_sino[:, self.padding[0] : -self.padding[1]] # memcpy2D
output.set(d_out)
return output
[docs]
def can_use_cuda_deringer():
"""
Check wether cuda implementation of deringer can be used.
Checking for installed modules is not enough, as for example pyvkfft can be installed without cuda devices
"""
can_do_fft = get_available_fft_implems() != []
return can_do_fft and __have_pycudwt__
[docs]
class CudaVoDeringer(VoDeringer):
"""
An interface to topocupy's "remove_all_stripe".
"""
def _check_requirement(self):
if not (__have_tomocupy_deringer__):
raise ImportError("need cupy")
[docs]
def remove_rings_radios(self, radios):
return remove_all_stripe_pycuda(radios, **self._remove_all_stripe_kwargs)
[docs]
def remove_rings_sinograms(self, sinos):
radios = sinos.transpose(axes=(1, 0, 2)) # view, no copy
self.remove_rings_radios(radios)
return sinos
[docs]
def remove_rings_sinogram(self, sino):
radios = sino.reshape(sino.shape[0], 1, sino.shape[1]) # no copy
self.remove_rings_radios(radios)
return sino
remove_rings = remove_rings_sinograms
[docs]
class CudaSinoMeanDeringer(SinoMeanDeringer):
@docstring(SinoMeanDeringer)
def __init__(
self,
sinos_shape,
mode="subtract",
filter_cutoff=None,
padding_mode="edge",
fft_num_threads=None,
**cuda_options,
):
self.processing = CudaProcessing(**(cuda_options or {}))
super().__init__(sinos_shape, mode, filter_cutoff, padding_mode, fft_num_threads)
self._init_kernels()
def _init_kernels(self):
self.d_sino_profile = self.processing.allocate_array("sino_profile", self.n_x)
self._mean_kernel = self.processing.kernel(
"vertical_mean",
filename=get_cuda_srcfile("normalization.cu"),
signature="PPiii",
)
self._mean_kernel_block = (32, 1, 1)
self._mean_kernel_grid = [updiv(self.sinos_shape[-1], self._mean_kernel_block[0]), 1, 1]
self._mean_kernel_args = [self.d_sino_profile, np.int32(self.n_x), np.int32(self.n_angles), np.int32(1)]
self._mean_kernel_kwargs = {
"grid": self._mean_kernel_grid,
"block": self._mean_kernel_block,
}
self._op_kernel = self.processing.kernel(
"inplace_generic_op_3Dby1D",
filename=get_cuda_srcfile("ElementOp.cu"),
signature="PPiii",
options=["-DGENERIC_OP=%d" % (3 if self.mode == "divide" else 1)],
)
self._op_kernel_block = (16, 16, 1)
self._op_kernel_grid = [updiv(a, b) for a, b in zip(self.sinos_shape[1:][::-1], self._op_kernel_block[:-1])] + [
1
]
self._op_kernel_args = [self.d_sino_profile, np.int32(self.n_x), np.int32(self.n_angles), np.int32(1)]
self._op_kernel_kwargs = {
"grid": self._op_kernel_grid,
"block": self._op_kernel_block,
}
def _init_filter(self, filter_cutoff, fft_num_threads, padding_mode):
super()._init_filter(filter_cutoff, fft_num_threads, padding_mode)
if filter_cutoff is None:
return
self._d_filter_f = self.processing.to_device("_filter_f", self._filter_f)
self.padder = CudaPadding(
(self.n_x, 1),
((self._pad_left, self._pad_right), (0, 0)),
mode=self.padding_mode,
cuda_options={"ctx": self.processing.ctx},
)
fft_cls = get_fft_class()
self._fft = fft_cls(self._filter_size, np.float32, r2c=True)
def _apply_filter(self, sino_profile):
if self._filter_f is None:
return sino_profile
sino_profile = sino_profile.reshape((-1, 1)) # view
sino_profile_p = self.padder.pad(sino_profile).ravel()
sino_profile_f = self._fft.fft(sino_profile_p)
sino_profile_f *= self._d_filter_f
self._fft.ifft(sino_profile_f, output=sino_profile_p)
self.d_sino_profile[:] = sino_profile_p[self._pad_left : -self._pad_right]
return self.d_sino_profile
[docs]
def remove_rings_sinogram(self, sino, output=None):
#
if output is not None:
raise NotImplementedError
#
if not (sino.flags.c_contiguous):
d_sino = self.processing.allocate_array("d_sino", sino.shape, np.float32)
d_sino[:] = sino[:]
else:
d_sino = sino
self._mean_kernel(d_sino, *self._mean_kernel_args, **self._mean_kernel_kwargs)
self._apply_filter(self.d_sino_profile)
self._op_kernel(d_sino, *self._op_kernel_args, **self._op_kernel_kwargs)
if not (sino.flags.c_contiguous):
sino[:] = self.processing.d_sino[:]
return sino
[docs]
def remove_rings_sinograms(self, sinograms):
for i in range(sinograms.shape[0]):
self.remove_rings_sinogram(sinograms[i])