import numpy as np
from ..cuda.utils import __has_pycuda__
from ..cuda.processing import CudaProcessing
from ..processing.muladd_cuda import CudaMulAdd
from .shift import VerticalShift
[docs]
class CudaVerticalShift(VerticalShift):
def __init__(self, radios_shape, shifts, **cuda_options):
"""
Vertical Shifter, Cuda backend.
"""
super().__init__(radios_shape, shifts)
self.cuda_processing = CudaProcessing(**(cuda_options or {}))
self._init_cuda_arrays()
def _init_cuda_arrays(self):
interp_infos_arr = np.zeros((len(self.interp_infos), 2), "f")
self._d_interp_infos = self.cuda_processing.to_device("_d_interp_infos", interp_infos_arr)
self._d_radio_new = self.cuda_processing.allocate_array("_d_radio_new", self.radios_shape[1:], "f")
self._d_radio = self.cuda_processing.allocate_array("_d_radio", self.radios_shape[1:], "f")
self.muladd_kernel = CudaMulAdd(ctx=self.cuda_processing.ctx)
[docs]
def apply_vertical_shifts(self, radios, iangles, output=None):
"""
Parameters
----------
radios: 3D pycuda.gpuarray.GPUArray
The input radios. If the optional parameter is not given, they are modified in-place
iangles: a sequence of integers
Must have the same lenght as radios.
It contains the index at which the shift is found in `self.shifts`
given by `shifts` argument in the initialisation of the object.
output: 3D pycuda.gpuarray.GPUArray, optional
If given, it will be modified to contain the shifted radios.
Must be of the same shape of `radios`.
"""
self._check(radios, iangles)
n_a, n_z, n_x = radios.shape
assert n_z == self.radios_shape[1]
x_slice = slice(0, n_x) # slice(None, None)
def nonempty_subregion(region):
if region is None:
return True
z_slice = region[0]
return z_slice.stop - z_slice.start > 0
d_radio_new = self._d_radio_new
d_radio = self._d_radio
for ia in iangles:
d_radio_new.fill(0)
d_radio[:] = radios[ia, :, :] # mul-add kernel won't work with pycuda view
S0, f = self.interp_infos[ia]
f = np.float32(f)
s0 = S0
if s0 > 0:
# newradio[:-s0] = radio[s0:] * (1 - f)
dst_region = (slice(0, n_z - s0), x_slice)
other_region = (slice(s0, n_z), x_slice)
elif s0 == 0:
# newradio[:] = radio[s0:] * (1 - f)
dst_region = None
other_region = (slice(s0, n_z), x_slice)
else:
# newradio[-s0:] = radio[:s0] * (1 - f)
dst_region = (slice(-s0, n_z), x_slice)
other_region = (slice(0, n_z + s0), x_slice)
if all([nonempty_subregion(reg) for reg in [dst_region, other_region]]):
self.muladd_kernel(
d_radio_new,
d_radio,
1,
1 - f,
dst_region=dst_region,
other_region=other_region,
)
s0 = S0 + 1
if s0 > 0:
# newradio[:-s0] += radio[s0:] * f
dst_region = (slice(0, n_z - s0), x_slice)
other_region = (slice(s0, n_z), x_slice)
elif s0 == 0:
# newradio[:] += radio[s0:] * f
dst_region = None
other_region = (slice(s0, n_z), x_slice)
else:
# newradio[-s0:] += radio[:s0] * f
dst_region = (slice(-s0, n_z), x_slice)
other_region = (slice(0, n_z + s0), x_slice)
if all([nonempty_subregion(reg) for reg in [dst_region, other_region]]):
self.muladd_kernel(d_radio_new, d_radio, 1, f, dst_region=dst_region, other_region=other_region)
if output is None:
radios[ia, :, :] = d_radio_new[:, :]
else:
output[ia, :, :] = d_radio_new[:, :]