# pylint: skip-file
"""
This file is a "GPU" (through cupy) implementation of "remove_all_stripe".
The original method is implemented by Nghia Vo in the algotom project: https://github.com/algotom/algotom/blob/master/algotom/prep/removal.py
The implementation using cupy is done by Viktor Nikitin in the tomocupy project: https://github.com/tomography/tomocupy/blame/main/src/tomocupy/remove_stripe.py
License follows.
For now we can't rely on off-the-shelf tomocupy as it's not packaged in pypi, and compilation is quite tedious.
"""
# *************************************************************************** #
# Copyright © 2022, UChicago Argonne, LLC #
# All Rights Reserved #
# Software Name: Tomocupy #
# By: Argonne National Laboratory #
# #
# OPEN SOURCE LICENSE #
# #
# Redistribution and use in source and binary forms, with or without #
# modification, are permitted provided that the following conditions are met: #
# #
# 1. Redistributions of source code must retain the above copyright notice, #
# this list of conditions and the following disclaimer. #
# 2. Redistributions in binary form must reproduce the above copyright #
# notice, this list of conditions and the following disclaimer in the #
# documentation and/or other materials provided with the distribution. #
# 3. Neither the name of the copyright holder nor the names of its #
# contributors may be used to endorse or promote products derived #
# from this software without specific prior written permission. #
# #
# #
# *************************************************************************** #
# DISCLAIMER #
# #
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS #
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT #
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS #
# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT #
# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, #
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED #
# TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR #
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF #
# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING #
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS #
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #
# *************************************************************************** #
try:
import cupy as cp
import pywt
from cupyx.scipy.ndimage import median_filter
from cupyx.scipy import signal
from cupyx.scipy.ndimage import binary_dilation
from cupyx.scipy.ndimage import uniform_filter1d
__have_tomocupy_deringer__ = True
except ImportError as err:
__have_tomocupy_deringer__ = False
__tomocupy_deringer_import_error__ = err
###### Ring removal with wavelet filtering (adapted for cupy from pytroch_wavelet package https://pytorch-wavelets.readthedocs.io/)################################################################################
def _reflect(x, minx, maxx):
"""Reflect the values in matrix *x* about the scalar values *minx* and
*maxx*. Hence a vector *x* containing a long linearly increasing series is
converted into a waveform which ramps linearly up and down between *minx*
and *maxx*. If *x* contains integers and *minx* and *maxx* are (integers +
0.5), the ramps will have repeated max and min samples.
.. codeauthor:: Rich Wareham <rjw57@cantab.net>, Aug 2013
.. codeauthor:: Nick Kingsbury, Cambridge University, January 1999.
"""
x = cp.asanyarray(x)
rng = maxx - minx
rng_by_2 = 2 * rng
mod = cp.fmod(x - minx, rng_by_2)
normed_mod = cp.where(mod < 0, mod + rng_by_2, mod)
out = cp.where(normed_mod >= rng, rng_by_2 - normed_mod, normed_mod) + minx
return cp.array(out, dtype=x.dtype)
def _mypad(x, pad, value=0):
""" Function to do numpy like padding on Arrays. Only works for 2-D
padding.
Inputs:
x (array): Array to pad
pad (tuple): tuple of (left, right, top, bottom) pad sizes
"""
# Vertical only
if pad[0] == 0 and pad[1] == 0:
m1, m2 = pad[2], pad[3]
l = x.shape[-2]
xe = _reflect(cp.arange(-m1, l+m2, dtype='int32'), -0.5, l-0.5)
return x[:, :, xe]
# horizontal only
elif pad[2] == 0 and pad[3] == 0:
m1, m2 = pad[0], pad[1]
l = x.shape[-1]
xe = _reflect(cp.arange(-m1, l+m2, dtype='int32'), -0.5, l-0.5)
return x[:, :, :, xe]
def _conv2d(x, w, stride, pad, groups=1):
""" Convolution (equivalent pytorch.conv2d)
"""
if pad != 0:
x = cp.pad(x, ((0, 0), (0, 0), (pad, pad), (pad, pad)), 'constant')
b, ci, hi, wi = x.shape
co, _, hk, wk = w.shape
ho = int(cp.floor(1 + (hi - hk) / stride[0]))
wo = int(cp.floor(1 + (wi - wk) / stride[1]))
out = cp.zeros([b, co, ho, wo], dtype='float32')
x = cp.expand_dims(x, axis=1)
w = cp.expand_dims(w, axis=0)
chunk = ci//groups
chunko = co//groups
for g in range(groups):
for ii in range(hk):
for jj in range(wk):
x_windows = x[:, :, g*chunk:(g+1)*chunk, ii:ho *
stride[0]+ii:stride[0], jj:wo*stride[1]+jj:stride[1]]
out[:, g*chunko:(g+1)*chunko] += cp.sum(x_windows *
w[:, g*chunko:(g+1)*chunko, :, ii:ii+1, jj:jj+1], axis=2)
return out
def _conv_transpose2d(x, w, stride, pad, bias=None, groups=1):
""" Transposed convolution (equivalent pytorch.conv_transpose2d)
"""
b, co, ho, wo = x.shape
co, ci, hk, wk = w.shape
hi = (ho-1)*stride[0]+hk
wi = (wo-1)*stride[1]+wk
out = cp.zeros([b, ci, hi, wi], dtype='float32')
chunk = ci//groups
chunko = co//groups
for g in range(groups):
for ii in range(hk):
for jj in range(wk):
x_windows = x[:, g*chunko:(g+1)*chunko]
out[:, g*chunk:(g+1)*chunk, ii:ho*stride[0]+ii:stride[0], jj:wo*stride[1] +
jj:stride[1]] += x_windows * w[g*chunko:(g+1)*chunko, :, ii:ii+1, jj:jj+1]
if pad != 0:
out = out[:, :, pad[0]:out.shape[2]-pad[0], pad[1]:out.shape[3]-pad[1]]
return out
[docs]
def afb1d(x, h0, h1='zero', dim=-1):
""" 1D analysis filter bank (along one dimension only) of an image
Parameters
----------
x (array): 4D input with the last two dimensions the spatial input
h0 (array): 4D input for the lowpass filter. Should have shape (1, 1,
h, 1) or (1, 1, 1, w)
h1 (array): 4D input for the highpass filter. Should have shape (1, 1,
h, 1) or (1, 1, 1, w)
dim (int) - dimension of filtering. d=2 is for a vertical filter (called
column filtering but filters across the rows). d=3 is for a
horizontal filter, (called row filtering but filters across the
columns).
Returns
-------
lohi: lowpass and highpass subbands concatenated along the channel
dimension
"""
C = x.shape[1]
# Convert the dim to positive
d = dim % 4
s = (2, 1) if d == 2 else (1, 2)
N = x.shape[d]
L = h0.size
L2 = L // 2
shape = [1, 1, 1, 1]
shape[d] = L
h = cp.concatenate([h0.reshape(*shape), h1.reshape(*shape)]*C, axis=0)
# Calculate the pad size
outsize = pywt.dwt_coeff_len(N, L, mode='symmetric')
p = 2 * (outsize - 1) - N + L
pad = (0, 0, p//2, (p+1)//2) if d == 2 else (p//2, (p+1)//2, 0, 0)
x = _mypad(x, pad=pad)
lohi = _conv2d(x, h, stride=s, pad=0, groups=C)
return lohi
[docs]
def sfb1d(lo, hi, g0, g1='zero', dim=-1):
""" 1D synthesis filter bank of an image Array
"""
C = lo.shape[1]
d = dim % 4
L = g0.size
shape = [1, 1, 1, 1]
shape[d] = L
N = 2*lo.shape[d]
s = (2, 1) if d == 2 else (1, 2)
g0 = cp.concatenate([g0.reshape(*shape)]*C, axis=0)
g1 = cp.concatenate([g1.reshape(*shape)]*C, axis=0)
pad = (L-2, 0) if d == 2 else (0, L-2)
y = _conv_transpose2d(cp.asarray(lo), cp.asarray(g0), stride=s, pad=pad, groups=C) + \
_conv_transpose2d(cp.asarray(hi), cp.asarray(g1),
stride=s, pad=pad, groups=C)
return y
[docs]
class DWTForward():
""" Performs a 2d DWT Forward decomposition of an image
Args:
wave (str): Which wavelet to use.
"""
def __init__(self, wave='db1'):
super().__init__()
wave = pywt.Wavelet(wave)
h0_col, h1_col = wave.dec_lo, wave.dec_hi
h0_row, h1_row = h0_col, h1_col
self.h0_col = cp.array(h0_col).astype('float32')[
::-1].reshape((1, 1, -1, 1))
self.h1_col = cp.array(h1_col).astype('float32')[
::-1].reshape((1, 1, -1, 1))
self.h0_row = cp.array(h0_row).astype('float32')[
::-1].reshape((1, 1, 1, -1))
self.h1_row = cp.array(h1_row).astype('float32')[
::-1].reshape((1, 1, 1, -1))
[docs]
def apply(self, x):
""" Forward pass of the DWT.
Args:
x (array): Input of shape :math:`(N, C_{in}, H_{in}, W_{in})`
Returns:
(yl, yh)
tuple of lowpass (yl) and bandpass (yh) coefficients.
yh is a list of scale coefficients. yl has shape
:math:`(N, C_{in}, H_{in}', W_{in}')` and yh has shape
:math:`list(N, C_{in}, 3, H_{in}'', W_{in}'')`. The new
dimension in yh iterates over the LH, HL and HH coefficients.
Note:
:math:`H_{in}', W_{in}', H_{in}'', W_{in}''` denote the correctly
downsampled shapes of the DWT pyramid.
"""
# Do a multilevel transform
# Do 1 level of the transform
lohi = afb1d(x, self.h0_row, self.h1_row, dim=3)
y = afb1d(lohi, self.h0_col, self.h1_col, dim=2)
s = y.shape
y = y.reshape(s[0], -1, 4, s[-2], s[-1]) # pylint: disable=E1121 # this might blow up in the future
x = cp.ascontiguousarray(y[:, :, 0])
yh = cp.ascontiguousarray(y[:, :, 1:])
return x, yh
[docs]
class DWTInverse():
""" Performs a 2d DWT Inverse reconstruction of an image
Args:
wave (str): Which wavelet to use.
"""
def __init__(self, wave='db1'):
super().__init__()
wave = pywt.Wavelet(wave)
g0_col, g1_col = wave.rec_lo, wave.rec_hi
g0_row, g1_row = g0_col, g1_col
# Prepare the filters
self.g0_col = cp.array(g0_col).astype('float32').reshape((1, 1, -1, 1))
self.g1_col = cp.array(g1_col).astype('float32').reshape((1, 1, -1, 1))
self.g0_row = cp.array(g0_row).astype('float32').reshape((1, 1, 1, -1))
self.g1_row = cp.array(g1_row).astype('float32').reshape((1, 1, 1, -1))
[docs]
def apply(self, coeffs):
"""
Args:
coeffs (yl, yh): tuple of lowpass and bandpass coefficients, where:
yl is a lowpass array of shape :math:`(N, C_{in}, H_{in}',
W_{in}')` and yh is a list of bandpass arrays of shape
:math:`list(N, C_{in}, 3, H_{in}'', W_{in}'')`. I.e. should match
the format returned by DWTForward
Returns:
Reconstructed input of shape :math:`(N, C_{in}, H_{in}, W_{in})`
Note:
:math:`H_{in}', W_{in}', H_{in}'', W_{in}''` denote the correctly
downsampled shapes of the DWT pyramid.
"""
yl, yh = coeffs
lh = yh[:, :, 0]
hl = yh[:, :, 1]
hh = yh[:, :, 2]
lo = sfb1d(yl, lh, self.g0_col, self.g1_col, dim=2)
hi = sfb1d(hl, hh, self.g0_col, self.g1_col, dim=2)
yl = sfb1d(lo, hi, self.g0_row, self.g1_row, dim=3)
return yl
[docs]
def remove_stripe_fw(data, sigma, wname, level):
"""Remove stripes with wavelet filtering"""
[nproj, nz, ni] = data.shape
nproj_pad = nproj + nproj // 8
xshift = int((nproj_pad - nproj) // 2)
# Accepts all wave types available to PyWavelets
xfm = DWTForward(wave=wname)
ifm = DWTInverse(wave=wname)
# Wavelet decomposition.
cc = []
sli = cp.zeros([nz, 1, nproj_pad, ni], dtype='float32')
sli[:, 0, (nproj_pad - nproj)//2:(nproj_pad + nproj) //
2] = data.astype('float32').swapaxes(0, 1)
for k in range(level):
sli, c = xfm.apply(sli)
cc.append(c)
# FFT
fcV = cp.fft.fft(cc[k][:, 0, 1], axis=1)
_, my, mx = fcV.shape
# Damping of ring artifact information.
y_hat = cp.fft.ifftshift((cp.arange(-my, my, 2) + 1) / 2)
damp = -cp.expm1(-y_hat**2 / (2 * sigma**2))
fcV *= cp.tile(damp, (mx, 1)).swapaxes(0, 1)
# Inverse FFT.
cc[k][:, 0, 1] = cp.fft.ifft(fcV, my, axis=1).real
# Wavelet reconstruction.
for k in range(level)[::-1]:
shape0 = cc[k][0, 0, 1].shape
sli = sli[:, :, :shape0[0], :shape0[1]]
sli = ifm.apply((sli, cc[k]))
data = sli[:, 0, (nproj_pad - nproj)//2:(nproj_pad + nproj) //
2, :ni].astype(data.dtype) # modified
data = data.swapaxes(0, 1)
return data
######## Titarenko ring removal ############################################################################################################################################################################
[docs]
def remove_stripe_ti(data, beta, mask_size):
"""Remove stripes with a new method by V. Titareno """
gamma = beta*((1-beta)/(1+beta)
)**cp.abs(cp.fft.fftfreq(data.shape[-1])*data.shape[-1])
gamma[0] -= 1
v = cp.mean(data, axis=0)
v = v-v[:, 0:1]
v = cp.fft.irfft(cp.fft.rfft(v)*cp.fft.rfft(gamma))
mask = cp.zeros(v.shape, dtype=v.dtype)
mask_size = mask_size*mask.shape[1]
mask[:, mask.shape[1]//2-mask_size//2:mask.shape[1]//2+mask_size//2] = 1
data[:] += v*mask
return data
######## Optimized version for Vo-all ring removal in tomopy################################################################################################################################################################
def _rs_sort(sinogram, size, matindex, dim):
"""
Remove stripes using the sorting technique.
"""
sinogram = cp.transpose(sinogram)
matcomb = cp.asarray(cp.dstack((matindex, sinogram)))
# matsort = cp.asarray([row[row[:, 1].argsort()] for row in matcomb])
ids = cp.argsort(matcomb[:,:,1],axis=1)
matsort = matcomb.copy()
matsort[:,:,0] = cp.take_along_axis(matsort[:,:,0],ids,axis=1)
matsort[:,:,1] = cp.take_along_axis(matsort[:,:,1],ids,axis=1)
if dim == 1:
matsort[:, :, 1] = median_filter(matsort[:, :, 1], (size, 1))
else:
matsort[:, :, 1] = median_filter(matsort[:, :, 1], (size, size))
# matsortback = cp.asarray([row[row[:, 0].argsort()] for row in matsort])
ids = cp.argsort(matsort[:,:,0],axis=1)
matsortback = matsort.copy()
matsortback[:,:,0] = cp.take_along_axis(matsortback[:,:,0],ids,axis=1)
matsortback[:,:,1] = cp.take_along_axis(matsortback[:,:,1],ids,axis=1)
sino_corrected = matsortback[:, :, 1]
return cp.transpose(sino_corrected)
def _mpolyfit(x,y):
n= len(x)
x_mean = cp.mean(x)
y_mean = cp.mean(y)
Sxy = cp.sum(x*y) - n*x_mean*y_mean
Sxx = cp.sum(x*x) - n*x_mean*x_mean
slope = Sxy / Sxx
intercept = y_mean - slope*x_mean
return slope,intercept
def _detect_stripe(listdata, snr):
"""
Algorithm 4 in :cite:`Vo:18`. Used to locate stripes.
"""
numdata = len(listdata)
listsorted = cp.sort(listdata)[::-1]
xlist = cp.arange(0, numdata, 1.0)
ndrop = cp.int16(0.25 * numdata)
# (_slope, _intercept) = cp.polyfit(xlist[ndrop:-ndrop - 1],
# listsorted[ndrop:-ndrop - 1], 1)
(_slope, _intercept) = _mpolyfit(xlist[ndrop:-ndrop - 1], listsorted[ndrop:-ndrop - 1])
numt1 = _intercept + _slope * xlist[-1]
noiselevel = cp.abs(numt1 - _intercept)
noiselevel = cp.clip(noiselevel, 1e-6, None)
val1 = cp.abs(listsorted[0] - _intercept) / noiselevel
val2 = cp.abs(listsorted[-1] - numt1) / noiselevel
listmask = cp.zeros_like(listdata)
if (val1 >= snr):
upper_thresh = _intercept + noiselevel * snr * 0.5
listmask[listdata > upper_thresh] = 1.0
if (val2 >= snr):
lower_thresh = numt1 - noiselevel * snr * 0.5
listmask[listdata <= lower_thresh] = 1.0
return listmask
def _rs_large(sinogram, snr, size, matindex, drop_ratio=0.1, norm=True):
"""
Remove large stripes.
"""
drop_ratio = max(min(drop_ratio,0.8),0)# = cp.clip(drop_ratio, 0.0, 0.8)
(nrow, ncol) = sinogram.shape
ndrop = int(0.5 * drop_ratio * nrow)
sinosort = cp.sort(sinogram, axis=0)
sinosmooth = median_filter(sinosort, (1, size))
list1 = cp.mean(sinosort[ndrop:nrow - ndrop], axis=0)
list2 = cp.mean(sinosmooth[ndrop:nrow - ndrop], axis=0)
# listfact = cp.divide(list1,
# list2,
# out=cp.ones_like(list1),
# where=list2 != 0)
listfact = list1/list2
# Locate stripes
listmask = _detect_stripe(listfact, snr)
listmask = binary_dilation(listmask, iterations=1).astype(listmask.dtype)
matfact = cp.tile(listfact, (nrow, 1))
# Normalize
if norm is True:
sinogram = sinogram / matfact
sinogram1 = cp.transpose(sinogram)
matcombine = cp.asarray(cp.dstack((matindex, sinogram1)))
# matsort = cp.asarray([row[row[:, 1].argsort()] for row in matcombine])
ids = cp.argsort(matcombine[:,:,1],axis=1)
matsort = matcombine.copy()
matsort[:,:,0] = cp.take_along_axis(matsort[:,:,0],ids,axis=1)
matsort[:,:,1] = cp.take_along_axis(matsort[:,:,1],ids,axis=1)
matsort[:, :, 1] = cp.transpose(sinosmooth)
# matsortback = cp.asarray([row[row[:, 0].argsort()] for row in matsort])
ids = cp.argsort(matsort[:,:,0],axis=1)
matsortback = matsort.copy()
matsortback[:,:,0] = cp.take_along_axis(matsortback[:,:,0],ids,axis=1)
matsortback[:,:,1] = cp.take_along_axis(matsortback[:,:,1],ids,axis=1)
sino_corrected = cp.transpose(matsortback[:, :, 1])
listxmiss = cp.where(listmask > 0.0)[0]
sinogram[:, listxmiss] = sino_corrected[:, listxmiss]
return sinogram
def _rs_dead(sinogram, snr, size, matindex, norm=True):
"""
Remove unresponsive and fluctuating stripes.
"""
sinogram = cp.copy(sinogram) # Make it mutable
(nrow, _) = sinogram.shape
# sinosmooth = cp.apply_along_axis(uniform_filter1d, 0, sinogram, 10)
sinosmooth = uniform_filter1d(sinogram, 10, axis=0)
listdiff = cp.sum(cp.abs(sinogram - sinosmooth), axis=0)
listdiffbck = median_filter(listdiff, size)
listfact = listdiff/listdiffbck
listmask = _detect_stripe(listfact, snr)
listmask = binary_dilation(listmask, iterations=1).astype(listmask.dtype)
listmask[0:2] = 0.0
listmask[-2:] = 0.0
listx = cp.where(listmask < 1.0)[0]
listy = cp.arange(nrow)
matz = sinogram[:, listx]
listxmiss = cp.where(listmask > 0.0)[0]
# finter = interpolate.interp2d(listx.get(), listy.get(), matz.get(), kind='linear')
if len(listxmiss) > 0:
# sinogram_c[:, listxmiss.get()] = finter(listxmiss.get(), listy.get())
ids = cp.searchsorted(listx, listxmiss)
sinogram[:,listxmiss] = matz[:,ids-1]+(listxmiss-listx[ids-1])*(matz[:,ids]-matz[:,ids-1])/(listx[ids]-listx[ids-1])
# Remove residual stripes
if norm is True:
sinogram = _rs_large(sinogram, snr, size, matindex)
return sinogram
def _create_matindex(nrow, ncol):
"""
Create a 2D array of indexes used for the sorting technique.
"""
listindex = cp.arange(0.0, ncol, 1.0)
matindex = cp.tile(listindex, (nrow, 1))
return matindex
[docs]
def remove_all_stripe(tomo,
snr=3,
la_size=61,
sm_size=21,
dim=1):
"""
Remove all types of stripe artifacts from sinogram using Nghia Vo's
approach :cite:`Vo:18` (combination of algorithm 3,4,5, and 6).
Parameters
----------
tomo : ndarray
3D tomographic data.
snr : float
Ratio used to locate large stripes.
Greater is less sensitive.
la_size : int
Window size of the median filter to remove large stripes.
sm_size : int
Window size of the median filter to remove small-to-medium stripes.
dim : {1, 2}, optional
Dimension of the window.
Returns
-------
ndarray
Corrected 3D tomographic data.
"""
matindex = _create_matindex(tomo.shape[2], tomo.shape[0])
for m in range(tomo.shape[1]):
sino = tomo[:, m, :]
sino = _rs_dead(sino, snr, la_size, matindex)
sino = _rs_sort(sino, sm_size, matindex, dim)
tomo[:, m, :] = sino
return tomo
from ..cuda.utils import pycuda_to_cupy
[docs]
def remove_all_stripe_pycuda(radios, device_id=0, **kwargs):
"""
Nabu interface to "remove_all_stripe". In-place!
Parameters
----------
radios: pycuda.GPUArray
Stack of radios in the shape (n_angles, n_y, n_x)
so that sinogram number i is radios[:, i, :]
Other Parameters
----------------
See parameters of 'remove_all_stripe
"""
if getattr(remove_all_stripe, "_cupy_init", False) is False:
from cupy import cuda
cuda.Device(device_id).use()
setattr(remove_all_stripe, "_cupy_init", True)
cupy_radios = pycuda_to_cupy(radios) # no memory copy, the internal pointer is passed to pycuda
remove_all_stripe(cupy_radios, **kwargs)
return radios