# Source code for nabu.preproc.flatfield_cuda

import numpy as np

from nabu.cuda.processing import CudaProcessing
from ..preproc.flatfield import FlatFieldArrays
from ..utils import get_cuda_srcfile
from ..io.reader import load_images_from_dataurl_dict
from ..cuda.utils import __has_pycuda__


class CudaFlatFieldArrays(FlatFieldArrays):
    """
    Flat-field normalization performed on GPU with the Cuda back-end.

    Flats and darks are given as in-memory arrays (as for FlatFieldArrays).
    They are uploaded to the device once, when the object is constructed.
    """

    def __init__(
        self,
        radios_shape,
        flats,
        darks,
        radios_indices=None,
        interpolation="linear",
        distortion_correction=None,
        nan_value=1.0,
        radios_srcurrent=None,
        flats_srcurrent=None,
        cuda_options=None,
    ):
        """
        Initialize a flat-field normalization CUDA process.
        Please read the documentation of nabu.preproc.flatfield.FlatField for help
        on the parameters.

        Other Parameters
        ----------------
        cuda_options: dict, optional
            Keyword arguments forwarded to CudaProcessing.

        Raises
        ------
        NotImplementedError
            If distortion_correction is not None (unsupported by the Cuda back-end).
        ValueError
            If interpolation is not "linear" (raised by _init_cuda_kernels).
        """
        # Reject unsupported options before the parent class does any work
        if distortion_correction is not None:
            raise NotImplementedError("Flats distortion correction is not implemented with the Cuda backend")
        super().__init__(
            radios_shape,
            flats,
            darks,
            radios_indices=radios_indices,
            interpolation=interpolation,
            distortion_correction=distortion_correction,
            radios_srcurrent=radios_srcurrent,
            flats_srcurrent=flats_srcurrent,
            nan_value=nan_value,
        )
        self.cuda_processing = CudaProcessing(**(cuda_options or {}))
        self._init_cuda_kernels()
        self._load_flats_and_darks_on_gpu()

    def _init_cuda_kernels(self):
        """
        Compile the "flatfield_normalization" kernel from flatfield.cu.

        Raises
        ------
        ValueError
            If self.interpolation is not "linear".
        """
        # TODO: implement other interpolation modes in the Cuda kernel
        if self.interpolation != "linear":
            raise ValueError("Interpolation other than linear is not yet implemented in the cuda back-end")
        self._cuda_fname = get_cuda_srcfile("flatfield.cu")
        # The numbers of flats/darks (and nan_value, when given) are passed as
        # compile-time preprocessor defines, so the kernel is specialized for them.
        options = [
            "-DN_FLATS=%d" % self.n_flats,
            "-DN_DARKS=%d" % self.n_darks,
        ]
        if self.nan_value is not None:
            options.append("-DNAN_VALUE=%f" % self.nan_value)
        # Signature: radios, flats, darks (pointers), nx, ny, n_radios (int32),
        # flats_indices, flats_weights (pointers) -- matches the call in normalize_radios
        self.cuda_kernel = self.cuda_processing.kernel(
            "flatfield_normalization", self._cuda_fname, signature="PPPiiiPP", options=options
        )
        # Frame dimensions as int32 scalars for the kernel (shape is (ny, nx))
        self._nx = np.int32(self.shape[1])
        self._ny = np.int32(self.shape[0])

    def _load_flats_and_darks_on_gpu(self):
        """
        Upload flats, darks and the per-radio flat indices/weights to device memory.

        Flats (resp. darks) are stored on device in the order given by
        self._sorted_flat_indices (resp. self._sorted_dark_indices), as float32.
        """
        # Flats
        self.d_flats = self.cuda_processing.allocate_array("d_flats", (self.n_flats,) + self.shape, np.float32)
        for i, flat_idx in enumerate(self._sorted_flat_indices):
            self.d_flats[i].set(np.ascontiguousarray(self.flats[flat_idx], dtype=np.float32))
        # Darks
        self.d_darks = self.cuda_processing.allocate_array("d_darks", (self.n_darks,) + self.shape, np.float32)
        for i, dark_idx in enumerate(self._sorted_dark_indices):
            self.d_darks[i].set(np.ascontiguousarray(self.darks[dark_idx], dtype=np.float32))
        # NOTE(review): d_darks_indices is uploaded but not passed to the kernel in
        # normalize_radios -- confirm whether it is used elsewhere.
        self.d_darks_indices = self.cuda_processing.to_device(
            "d_darks_indices", np.array(self._sorted_dark_indices, dtype=np.int32)
        )
        # Indices / weights used by the kernel to interpolate between flats for each radio.
        # Assumes the parent class already stores flats_idx/flats_weights with the dtypes
        # expected by flatfield.cu -- TODO confirm.
        self.d_flats_indices = self.cuda_processing.to_device("d_flats_indices", self.flats_idx)
        self.d_flats_weights = self.cuda_processing.to_device("d_flats_weights", self.flats_weights)

    def normalize_radios(self, radios):
        """
        Apply a flat-field correction, with the current parameters, to a stack of radios.

        The kernel gets no separate output buffer, so `radios` is modified in place.
        If current normalization is enabled, it is also rescaled in place.

        Parameters
        -----------
        radios: `pycuda.gpuarray.GPUArray`
            Radios chunk, in float32, with shape self.radios_shape.

        Returns
        -------
        radios: `pycuda.gpuarray.GPUArray`
            The same array, normalized.

        Raises
        ------
        ValueError
            If radios is not a device array, is not float32, or has the wrong shape.
        """
        if not (isinstance(radios, self.cuda_processing.array_class)):
            raise ValueError("Expected a pycuda.gpuarray (got %s)" % str(type(radios)))
        if radios.dtype != np.float32:
            raise ValueError("radios must be in float32 dtype (got %s)" % str(radios.dtype))
        if radios.shape != self.radios_shape:
            raise ValueError("Expected radios shape = %s but got %s" % (str(self.radios_shape), str(radios.shape)))
        self.cuda_kernel(
            radios,
            self.d_flats,
            self.d_darks,
            self._nx,
            self._ny,
            np.int32(self.n_radios),
            self.d_flats_indices,
            self.d_flats_weights,
        )
        if self.normalize_srcurrent:
            # Per-radio scaling by the current ratio computed by the parent class
            for i in range(self.n_radios):
                radios[i] *= self.srcurrent_ratios[i]
        return radios
# Alias: CudaFlatField refers to the array-based Cuda implementation.
# CudaFlatFieldDataUrls subclasses through this name.
CudaFlatField = CudaFlatFieldArrays
class CudaFlatFieldDataUrls(CudaFlatField):
    """
    Cuda flat-field normalization where flats and darks are given as data URLs.

    The images are read with load_images_from_dataurl_dict. The resulting arrays
    are then handed to the array-based Cuda implementation.
    """

    def __init__(
        self,
        radios_shape,
        flats,
        darks,
        radios_indices=None,
        interpolation="linear",
        distortion_correction=None,
        nan_value=1.0,
        radios_srcurrent=None,
        flats_srcurrent=None,
        cuda_options=None,
        **chunk_reader_kwargs,
    ):
        """
        Initialize a flat-field normalization CUDA process from data URLs.

        Parameters
        ----------
        flats, darks: dict
            Passed to load_images_from_dataurl_dict. Presumably these map image
            indices to DataUrl objects -- TODO confirm against the reader.
        **chunk_reader_kwargs
            Extra keyword arguments forwarded to load_images_from_dataurl_dict.

        The other parameters have the same meaning as in CudaFlatFieldArrays.
        """
        flats_arrays_dict = load_images_from_dataurl_dict(flats, **chunk_reader_kwargs)
        darks_arrays_dict = load_images_from_dataurl_dict(darks, **chunk_reader_kwargs)
        super().__init__(
            radios_shape,
            flats_arrays_dict,
            darks_arrays_dict,
            radios_indices=radios_indices,
            interpolation=interpolation,
            distortion_correction=distortion_correction,
            # nan_value must be forwarded, otherwise the parent's default is always
            # compiled into the kernel regardless of what the caller asked for.
            nan_value=nan_value,
            radios_srcurrent=radios_srcurrent,
            flats_srcurrent=flats_srcurrent,
            cuda_options=cuda_options,
        )