import numpy as np
from ..utils import BaseClassError
class _BaseFFT:
"""
A base class for FFTs.
"""
implem = "none"
ProcessingCls = BaseClassError
def __init__(self, shape, dtype, r2c=True, axes=None, normalize="rescale", **backend_options):
"""
Base class for Fast Fourier Transform (FFT).
Parameters
----------
shape: list of int
Shape of the input data
dtype: str or numpy.dtype
Data type of the input data
r2c: bool, optional
Whether to use real-to-complex transform for real-valued input. Default is True.
axes: list of int, optional
Axes along which FFT is computed.
* For 2D transform: axes=(1,0)
* For batched 1D transform of 2D image: axes=(-1,)
normalize: str, optional
Whether to normalize FFT and IFFT. Possible values are:
* "rescale": in this case, Fourier data is divided by "N"
before IFFT, so that IFFT(FFT(data)) = data.
This corresponds to numpy norm=None i.e norm="backward".
* "ortho": in this case, FFT and IFFT are adjoint of eachother,
the transform is unitary. Both FFT and IFFT are scaled with 1/sqrt(N).
* "none": no normalizatio is done : IFFT(FFT(data)) = data*N
Other parameters
-----------------
backend_options: dict, optional
Parameters to pass to CudaProcessing or OpenCLProcessing class.
"""
self._init_backend(backend_options)
self._set_dtypes(dtype, r2c)
self._set_shape_and_axes(shape, axes)
self._configure_batched_transform()
self._configure_normalization(normalize)
self._compute_fft_plans()
def _init_backend(self, backend_options):
self.processing = self.ProcessingCls(**backend_options)
def _set_dtypes(self, dtype, r2c):
self.dtype = np.dtype(dtype)
dtypes_mapping = {
np.dtype("float32"): np.complex64,
np.dtype("float64"): np.complex128,
np.dtype("complex64"): np.complex64,
np.dtype("complex128"): np.complex128,
}
if self.dtype not in dtypes_mapping:
raise ValueError("Invalid input data type: got %s" % self.dtype)
self.dtype_out = dtypes_mapping[self.dtype]
self.r2c = r2c
def _set_shape_and_axes(self, shape, axes):
# Input shape
if np.isscalar(shape):
shape = (shape,)
self.shape = shape
# Axes
default_axes = tuple(range(len(self.shape)))
if axes is None:
self.axes = default_axes
else:
self.axes = tuple(np.array(default_axes)[np.array(axes)])
# Output shape
shape_out = self.shape
if self.r2c:
reduced_dim = self.axes[-1] if self.axes is not None else -1
shape_out = list(shape_out)
shape_out[reduced_dim] = shape_out[reduced_dim] // 2 + 1
shape_out = tuple(shape_out)
self.shape_out = shape_out
def _configure_batched_transform(self):
pass
def _configure_normalization(self, normalize):
pass
def _compute_fft_plans(self):
pass
class _BaseVKFFT(_BaseFFT):
"""
FFT using VKFFT backend
"""
implem = "vkfft"
backend = "none"
ProcessingCls = BaseClassError
vkffs_cls = BaseClassError
def _configure_batched_transform(self):
if self.axes is not None and len(self.shape) == len(self.axes):
self.axes = None
return
if self.r2c:
# batched Real-to-complex transforms are supported only along fast axes
if not (is_fast_axes(len(self.shape), self.axes)):
raise ValueError("For %dD R2C, only batched transforms along fast axes are allowed" % (len(self.shape)))
self._vkfft_ndim = len(self.axes)
self.axes = None # vkfft still can do a batched transform by providing dim=XX, axes=None
def _configure_normalization(self, normalize):
self.normalize = normalize
self._vkfft_norm = {
"rescale": 1,
"backward": 1,
"ortho": "ortho",
"none": 0,
}.get(self.normalize, 1)
def _set_shape_and_axes(self, shape, axes):
super()._set_shape_and_axes(shape, axes)
self._vkfft_ndim = None
def _compute_fft_plans(self):
self._vkfft_plan = self.vkffs_cls(
self.shape,
self.dtype,
ndim=self._vkfft_ndim,
inplace=False,
norm=self._vkfft_norm,
r2c=self.r2c,
dct=False,
axes=self.axes,
strides=None,
**self._vkfft_other_init_kwargs,
)
def fft(self, array, output=None):
if output is None:
output = self.output_fft = self.processing.allocate_array(
"output_fft", self.shape_out, dtype=self.dtype_out
)
return self._vkfft_plan.fft(array, dest=output)
def ifft(self, array, output=None):
if output is None:
output = self.output_ifft = self.processing.allocate_array("output_ifft", self.shape, dtype=self.dtype)
return self._vkfft_plan.ifft(array, dest=output)
[docs]
def is_fast_axes(ndim, axes):
"""
Return true if "axes" are the fast dimensions
"""
all_axes = list(range(ndim))
axes = sorted([ax + ndim if ax < 0 else ax for ax in axes]) # transform "-1" to an actual axis index (1 for 2D)
return all_axes[-len(axes) :] == axes