"""Source code for nabu.processing.fftshift: FFT-shift operators for 1D/2D arrays."""

import numpy as np
from ..utils import BaseClassError, get_opencl_srcfile, updiv
from ..opencl.kernel import OpenCLKernel
from ..opencl.processing import OpenCLProcessing
from pyopencl.tools import dtype_to_ctype as cl_dtype_to_ctype


class FFTshiftBase:
    """
    Backend-agnostic FFT-shift operator for 1D or 2D arrays.

    Subclasses provide the class attributes ``KernelCls``, ``ProcessingCls``,
    ``dtype_to_ctype`` and ``backend``, and implement the hooks
    ``_configure_kenel_initialization()`` and ``_configure_kernel_call()``
    (which must set ``self._kernel_kwargs``).
    """

    KernelCls = BaseClassError
    ProcessingCls = BaseClassError
    dtype_to_ctype = BaseClassError
    backend = "none"

    def __init__(self, shape, dtype, dst_dtype=None, axes=None, **backend_options):
        """
        Parameters
        ----------
        shape: tuple
            Array shape - can be 1D or 2D. 3D is not supported.
        dtype: str or numpy.dtype
            Data type, eg. "f", numpy.complex64, ...
        dst_dtype: str or numpy.dtype
            Output data type. If not provided (default), the shift is done in-place.
        axes: tuple, optional
            Currently not used by this class; accepted for API compatibility.
            NOTE(review): the kernels launched are "fftshift_x" / "fftshift_x_inplace",
            which suggests a shift along the last axis only - confirm against fftshift.cl.

        Other parameters
        ----------------
        backend_options:
            named arguments to pass to CudaProcessing or OpenCLProcessing

        Raises
        ------
        ValueError
            If ``shape`` is neither 1D nor 2D.
        NotImplementedError
            For an in-place shift (``dst_dtype=None``) when the last dimension is odd.
        """
        # Validating `axes` stays disabled: with the default axes=None, the check
        # `if axes not in [1, (1,), (-1,)]: raise NotImplementedError` would always fire.
        self.processing = self.ProcessingCls(**backend_options)
        self.shape = shape
        if len(self.shape) not in [1, 2]:
            raise ValueError("Expected 1D or 2D array")
        self.dtype = np.dtype(dtype)
        self.dst_dtype = dst_dtype
        if dst_dtype is None:
            self._configure_inplace_shift()
        else:
            self._configure_out_of_place_shift()
        self._configure_kenel_initialization()
        self._fftshift_kernel = self.KernelCls(*self._kernel_init_args, **self._kernel_init_kwargs)
        self._configure_kernel_call()

    def _to_ctype(self, dtype):
        """Return the C type name of `dtype`, using the backend's ``dtype_to_ctype``."""
        # Look the converter up on the class, not on the instance: a plain function stored
        # as a class attribute (e.g. ``dtype_to_ctype = cl_dtype_to_ctype``) would otherwise
        # be bound as a method and receive ``self`` as its first argument.
        return type(self).dtype_to_ctype(np.dtype(dtype))

    def _get_width_and_height(self):
        """Return (shape[-1], number of rows) as int32, treating a 1D shape as one row."""
        # The original code passed shape[1] then shape[0]; for a 1D shape that raised
        # IndexError even though 1D shapes are accepted by __init__.
        width = self.shape[-1]
        height = self.shape[0] if len(self.shape) == 2 else 1
        return np.int32(width), np.int32(height)

    def _configure_inplace_shift(self):
        """Select the in-place kernel and its compile options."""
        self.inplace = True
        # in-place on odd-sized array is more difficult - see fftshift.cl
        if self.shape[-1] & 1:
            raise NotImplementedError("In-place fftshift is not supported when the last dimension is odd")
        self._kernel_init_args = [
            "fftshift_x_inplace",
        ]
        self._kernel_init_kwargs = {
            "options": [
                "-DDTYPE=%s" % self._to_ctype(self.dtype),
            ],
        }

    def _configure_out_of_place_shift(self):
        """Select the out-of-place kernel; add a cast flag when input/output complexity differ."""
        self.inplace = False
        dst_dtype = np.dtype(self.dst_dtype)
        options = [
            "-DDTYPE=%s" % self._to_ctype(self.dtype),
            "-DDTYPE_OUT=%s" % self._to_ctype(dst_dtype),
        ]
        input_is_complex = np.issubdtype(self.dtype, np.complexfloating)
        output_is_complex = np.issubdtype(dst_dtype, np.complexfloating)
        if output_is_complex and not input_is_complex:
            options.append("-DCAST_TO_COMPLEX")
        elif input_is_complex and not output_is_complex:
            options.append("-DCAST_TO_REAL")
        self._kernel_init_args = [
            "fftshift_x",
        ]
        self._kernel_init_kwargs = {"options": options}

    def _call_fftshift_inplace(self, arr, direction):
        """Launch the in-place kernel on `arr` (direction 1: fftshift, -1: ifftshift) and return `arr`."""
        width, height = self._get_width_and_height()
        self._fftshift_kernel(  # pylint: disable=E1102
            arr, width, height, np.int32(direction), **self._kernel_kwargs
        )
        return arr

    def _call_fftshift_out_of_place(self, arr, dst, direction):
        """Launch the out-of-place kernel from `arr` into `dst` (allocated if None) and return `dst`."""
        if dst is None:
            dst = self.processing.allocate_array("dst", arr.shape, dtype=self.dst_dtype)
        width, height = self._get_width_and_height()
        self._fftshift_kernel(  # pylint: disable=E1102
            arr, dst, width, height, np.int32(direction), **self._kernel_kwargs
        )
        return dst

    def fftshift(self, arr, dst=None):
        """
        Apply fftshift to `arr`.

        In in-place mode `dst` is ignored and `arr` is returned; otherwise the result is
        written to `dst` (allocated when None) and `dst` is returned.
        """
        if self.inplace:
            return self._call_fftshift_inplace(arr, 1)
        return self._call_fftshift_out_of_place(arr, dst, 1)

    def ifftshift(self, arr, dst=None):
        """
        Apply ifftshift to `arr`.

        In in-place mode `dst` is ignored and `arr` is returned; otherwise the result is
        written to `dst` (allocated when None) and `dst` is returned.
        """
        if self.inplace:
            return self._call_fftshift_inplace(arr, -1)
        return self._call_fftshift_out_of_place(arr, dst, -1)
class OpenCLFFTshift(FFTshiftBase):
    """FFT-shift operator running on the OpenCL backend."""

    KernelCls = OpenCLKernel
    ProcessingCls = OpenCLProcessing
    # staticmethod: a plain function stored as a class attribute would be bound to the
    # instance, so ``self.dtype_to_ctype(dtype)`` would pass ``self`` as the dtype.
    dtype_to_ctype = staticmethod(cl_dtype_to_ctype)
    backend = "opencl"

    def _configure_kenel_initialization(self):
        """Add the OpenCL context, kernel source file and command queue to the kernel constructor arguments."""
        self._kernel_init_args.append(self.processing.ctx)
        self._kernel_init_kwargs.update(
            {
                "filename": get_opencl_srcfile("fftshift.cl"),
                "queue": self.processing.queue,
            }
        )

    def _configure_kernel_call(self):
        """Compute the launch configuration (global/local sizes) for the kernel."""
        # TODO in-place fftshift needs to launch only arr.size//2 threads
        block = (16, 16, 1)
        # Treat a 1D shape as a single row so the global size is always (width, height),
        # rounded up to a multiple of the block size. 2D shapes are unchanged.
        shape_2d = (1,) * (2 - len(self.shape)) + tuple(self.shape)
        grid = [updiv(size, bsize) * bsize for size, bsize in zip(shape_2d[::-1], block)]
        # NOTE(review): local_size has 3 entries while global_size has 2; presumably
        # OpenCLKernel reconciles this before enqueueing - confirm.
        self._kernel_kwargs = {"global_size": grid, "local_size": block}