1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244 | import numpy as np
from scipy.fft import rfft, irfft
from silx.image.tomography import get_next_power
from ..thirdparty.pore3d_deringer_munch import munchetal_filter
from ..utils import get_2D_3D_shape, get_num_threads, check_supported
from ..misc.fourier_filters import get_bandpass_filter
class MunchDeringer:
    """
    Rings artefacts removal on sinograms using the "Munch et al." filter.

    The filtering itself is delegated to `munchetal_filter`. This class stores the
    filter parameters, optionally pads each sinogram horizontally before filtering,
    and crops the filtered result back to the original width.
    """

    def __init__(self, sigma, sinos_shape, levels=None, wname="db15", padding=None, padding_mode="edge"):
        """
        Initialize a "Munch Et Al" sinogram deringer. See References for more information.

        Parameters
        -----------
        sigma: float
            Standard deviation of the damping parameter. The higher value of sigma,
            the more important the filtering effect on the rings.
        sinos_shape: tuple
            Shape of the sinogram (or sinograms stack).
        levels: int, optional
            Number of wavelets decomposition levels.
            By default (None), the maximum number of decomposition levels is used.
        wname: str, optional
            Default is "db15" (Daubechies, 15 vanishing moments)
        padding: bool or tuple of two int, optional
            Horizontal padding to use for reducing the aliasing artefacts, in the form (x1, x2).
            True means "pad by half the sinogram width on each side".
            False or None (default) disables padding.
        padding_mode: str, optional
            Padding mode, passed as `mode` to numpy.pad. Default is "edge".

        Raises
        ------
        ValueError
            If `munchetal_filter` is not available, or if `padding` is not in the form (x1, x2).

        References
        ----------
        B. Munch, P. Trtik, F. Marone, M. Stampanoni, Stripe and ring artifact removal with
        combined wavelet-Fourier filtering, Optics Express 17(10):8567-8591, 2009.
        """
        self._get_shapes(sinos_shape, padding)
        self.sigma = sigma
        self.levels = levels
        self.wname = wname
        self.padding_mode = padding_mode
        self._check_can_use_wavelets()

    def _get_shapes(self, sinos_shape, padding):
        """Store the sinograms dimensions and normalize `padding` to None or (x1, x2)."""
        n_z, n_a, n_x = get_2D_3D_shape(sinos_shape)
        self.sinos_shape = n_z, n_a, n_x
        self.n_angles = n_a
        self.n_z = n_z
        self.n_x = n_x
        # Handle "padding=True" or "padding=False"
        if isinstance(padding, bool):
            padding = (n_x // 2, n_x // 2) if padding else None
        # Without padding, the "padded" sinogram is simply the sinogram itself.
        # Defining it unconditionally avoids an AttributeError on later access.
        self.sino_padded_shape = (n_a, n_x)
        if padding is not None:
            pad_x1, pad_x2 = padding
            if np.iterable(pad_x1) or np.iterable(pad_x2):
                raise ValueError("Expected padding in the form (x1, x2)")
            self.sino_padded_shape = (n_a, n_x + pad_x1 + pad_x2)
        self.padding = padding

    def _check_can_use_wavelets(self):
        """Raise ValueError if the filter implementation could not be imported."""
        if munchetal_filter is None:
            raise ValueError("Need pywavelets to use this class")

    def _crop_padding(self, padded):
        """Return `padded` with the horizontal padding removed (no-op if padding is disabled)."""
        if self.padding is None:
            return padded
        pad_left, pad_right = self.padding
        # Use an explicit end index: a "-pad_right" end bound would give an
        # empty slice when pad_right == 0 (e.g. padding=(10, 0)).
        return padded[:, pad_left : padded.shape[1] - pad_right]

    def _destripe_2D(self, sino, output):
        """
        Filter one 2D sinogram and write the result in `output`.

        `output` may be the same array as `sino` (in-place processing).
        """
        if self.padding is not None:
            # Pad only along the horizontal (last) axis
            sino = np.pad(sino, ((0, 0), self.padding), mode=self.padding_mode)
        res = munchetal_filter(sino, self.levels, self.sigma, wname=self.wname)
        output[:] = self._crop_padding(res)
        return output

    def remove_rings(self, sinos, output=None):
        """
        Main function to performs rings artefacts removal on sinogram(s).
        CAUTION: this function defaults to in-place processing, meaning that
        the sinogram(s) you pass will be overwritten.

        Parameters
        ----------
        sinos: numpy.ndarray
            Sinogram or stack of sinograms.
        output: numpy.ndarray, optional
            Output array. If set to None (default), the output overwrites the input.

        Returns
        -------
        output: numpy.ndarray
            The array holding the processed sinogram(s).
        """
        if output is None:
            output = sinos
        if sinos.ndim == 2:
            return self._destripe_2D(sinos, output)
        # Stack of sinograms: process them one by one
        for sino, out in zip(sinos, output):
            self._destripe_2D(sino, out)
        return output
class VoDeringer:
    """
    An interface to Nghia Vo's "remove_all_stripe".
    Needs algotom to run.
    """

    def __init__(self, sinos_shape, **remove_all_stripe_options):
        """
        Initialize a VoDeringer.

        Parameters
        ----------
        sinos_shape: tuple of int
            Shape of the sinogram (or sinograms stack).
        remove_all_stripe_options:
            Extra keyword arguments, forwarded as-is to algotom's `remove_all_stripe`
            on each call.
        """
        self._init_lib()
        self._get_shapes(sinos_shape)
        self._remove_all_stripe_kwargs = remove_all_stripe_options

    def _init_lib(self):
        """Import the algotom backend lazily and keep a reference to it."""
        # Importing this is time-consuming, because of the numba initialization
        from algotom.prep.removal import remove_all_stripe

        self._remove_all_stripe = remove_all_stripe

    def _get_shapes(self, sinos_shape):
        """Store the sinograms dimensions (n_z, n_angles, n_x)."""
        n_z, n_a, n_x = get_2D_3D_shape(sinos_shape)
        self.sinos_shape = n_z, n_a, n_x
        self.n_angles = n_a
        self.n_z = n_z
        self.n_x = n_x

    def remove_rings_sinogram(self, sino, output=None):
        """
        Remove rings artefacts on one 2D sinogram.

        Parameters
        ----------
        sino: numpy.ndarray
            Sinogram to process.
        output: numpy.ndarray, optional
            If provided, the result is copied into this array, which is then returned.
            Otherwise the newly computed array is returned.
        """
        new_sino = self._remove_all_stripe(sino, **self._remove_all_stripe_kwargs)  # out-of-place
        if output is not None:
            output[:] = new_sino[:]
            return output
        return new_sino

    def remove_rings_sinograms(self, sinos, output=None):
        """
        Remove rings artefacts on a sinogram or a stack of sinograms.
        CAUTION: this function defaults to in-place processing, meaning that
        the sinogram(s) you pass will be overwritten.

        Parameters
        ----------
        sinos: numpy.ndarray
            Sinogram (2D) or stack of sinograms (3D).
        output: numpy.ndarray, optional
            Output array. If set to None (default), the output overwrites the input.

        Returns
        -------
        output: numpy.ndarray
            The array holding the processed sinogram(s).
        """
        if output is None:
            output = sinos
        if sinos.ndim == 2:
            # A single sinogram: do not iterate over its rows
            return self.remove_rings_sinogram(sinos, output=output)
        for i in range(sinos.shape[0]):
            output[i] = self.remove_rings_sinogram(sinos[i])
        return output

    def remove_rings_radios(self, radios):
        """
        Remove rings artefacts on a stack of radios.

        The stack is re-ordered as sinograms with numpy.moveaxis, which returns a view:
        the processing therefore overwrites `radios`. The returned array is that
        sinogram-ordered view.
        """
        sinos = np.moveaxis(radios, 1, 0)  # (n_a, n_z, n_x) --> (n_z, n_a, n_x)
        return self.remove_rings_sinograms(sinos)

    remove_rings = remove_rings_sinograms
class SinoMeanDeringer:
    """
    Rings correction by subtracting (or dividing by) the mean profile of each sinogram.

    The mean profile is computed along the first axis of the sinogram and can
    optionally be bandpass-filtered before being applied. Processing is in-place.
    """

    supported_modes = ["subtract", "divide"]

    def __init__(self, sinos_shape, mode="subtract", filter_cutoff=None, padding_mode="edge", fft_num_threads=None):
        """
        Rings correction with mean subtraction/division.
        The principle of this method is to subtract (or divide) the sinogram by its mean along a certain axis.
        In short:
            sinogram -= filt(sinogram.mean(axis=0))
        where `filt` is some bandpass filter.

        Parameters
        ----------
        sinos_shape: tuple of int
            Sinograms shape, in the form (n_angles, n_x) or (n_sinos, n_angles, n_x)
        mode: str, optional
            Operation to do on the sinogram, either "subtract" or "divide"
        filter_cutoff: tuple, optional
            Cut-off of the bandpass filter applied on the sinogram profiles.
            Empty (default) means no filtering.
            Possible values forms are:
                - (sigma_low, sigma_high): two float values defining the standard deviation of
                  gaussian(sigma_low) * (1 - gaussian(sigma_high)).
                  High values of sigma mean stronger effect of associated filters.
                - ((cutoff_low, transition_low), (cutoff_high, transition_high))
                  where "cutoff" is in normalized Nyquist frequency (0.5 is the maximum frequency),
                  and "transition" is the width of filter decay in fraction of the cutoff frequency
        padding_mode: str, optional
            Padding mode when filtering the sinogram profile.
            Should be "constant" (i.e "zeros") for mathematical correctness,
            but in practice this yields a Gibbs effect when replicating the sinogram, so "edge" is recommended.
            The value "edges" is accepted as an alias of "edge".
        fft_num_threads: int, optional
            How many threads to use for computing the fast Fourier transform when filtering the sinogram profile.
            Default is all the available threads.
        """
        self._get_shapes(sinos_shape)
        check_supported(mode, self.supported_modes, "operation mode")
        self.mode = mode
        self._init_filter(filter_cutoff, fft_num_threads, padding_mode)

    def _get_shapes(self, sinos_shape):
        """Store the sinograms dimensions (n_z, n_angles, n_x)."""
        n_z, n_a, n_x = get_2D_3D_shape(sinos_shape)
        self.sinos_shape = n_z, n_a, n_x
        self.n_angles = n_a
        self.n_z = n_z
        self.n_x = n_x

    def _init_filter(self, filter_cutoff, fft_num_threads, padding_mode):
        """
        Build the Fourier-domain bandpass filter and the padding parameters.

        When `filter_cutoff` is None, no filter is built (`self._filter_f` stays None)
        and the mean profile is used unfiltered.
        """
        self.filter_cutoff = filter_cutoff
        self._filter_f = None
        if filter_cutoff is None:
            return

        # The profile is padded to a size of at least twice its length before the FFT
        self._filter_size = get_next_power(self.n_x * 2)
        self._filter_f = get_bandpass_filter(
            (1, self._filter_size),
            cutoff_lowpass=filter_cutoff[0],
            cutoff_highpass=filter_cutoff[1],
            use_rfft=True,
            data_type=np.float32,
        ).ravel()
        self._fft_n_threads = get_num_threads(fft_num_threads)

        # compat: numpy.pad names this mode "edge"
        if padding_mode == "edges":
            padding_mode = "edge"
        self.padding_mode = padding_mode

        # Split the extra samples between left and right (right gets the odd one)
        size_diff = self._filter_size - self.n_x
        self._pad_left, self._pad_right = size_diff // 2, size_diff - size_diff // 2

    def _apply_filter(self, sino_profile):
        """Return the bandpass-filtered profile, or the profile itself if no filter is configured."""
        if self._filter_f is None:
            return sino_profile

        sino_profile = np.pad(sino_profile, (self._pad_left, self._pad_right), mode=self.padding_mode)
        sino_f = rfft(sino_profile, workers=self._fft_n_threads)
        sino_f *= self._filter_f
        # Crop the padding to get back to the original profile length
        return irfft(sino_f, workers=self._fft_n_threads)[self._pad_left : -self._pad_right]  # ascontiguousarray ?

    def remove_rings_sinogram(self, sino, output=None):
        """
        Correct one 2D sinogram, in-place.

        Parameters
        ----------
        sino: numpy.ndarray
            Sinogram to process. It is overwritten with the result.
        output: None
            Out-of-place processing is not supported: must be None.

        Returns
        -------
        sino: numpy.ndarray
            The input array, modified in-place.

        Raises
        ------
        NotImplementedError
            If `output` is not None.
        """
        if output is not None:
            raise NotImplementedError("Out-of-place processing is not supported: 'output' must be None")

        sino_profile = self._apply_filter(sino.mean(axis=0))
        if self.mode == "subtract":
            sino -= sino_profile
        elif self.mode == "divide":
            sino /= sino_profile
        return sino

    def remove_rings_sinograms(self, sinos, output=None):
        """
        Correct a sinogram or a stack of sinograms, in-place.

        Parameters
        ----------
        sinos: numpy.ndarray
            Sinogram (2D) or stack of sinograms (3D). It is overwritten with the result.
        output: None
            Out-of-place processing is not supported: must be None.

        Returns
        -------
        sinos: numpy.ndarray
            The input array, modified in-place.

        Raises
        ------
        NotImplementedError
            If `output` is not None.
        """
        if output is not None:
            raise NotImplementedError("Out-of-place processing is not supported: 'output' must be None")

        if sinos.ndim == 2:
            # A single sinogram: iterating over it would wrongly process each row separately
            return self.remove_rings_sinogram(sinos)
        for sino in sinos:
            self.remove_rings_sinogram(sino)
        return sinos

    remove_rings = remove_rings_sinograms
|