# Source code for nabu.preproc.flatfield

from bisect import bisect_left
from multiprocessing.pool import ThreadPool

import numpy as np

from ..io.reader import load_images_from_dataurl_dict
from ..utils import check_supported, get_num_threads

class FlatFieldArrays:
    """
    A class for flat-field normalization
    """

    # The variable below will be True for the derived class
    # which is tailored to the helical case
    _full_shape = False

    _supported_interpolations = ["linear", "nearest"]

    def __init__(
        self,
        radios_shape: tuple,
        flats,
        darks,
        radios_indices=None,
        interpolation: str = "linear",
        distortion_correction=None,
        nan_value=1.0,
        radios_srcurrent=None,
        flats_srcurrent=None,
        n_threads=None,
    ):
        """
        Initialize a flat-field normalization process.

        Parameters
        ----------
        radios_shape: tuple
            A tuple describing the shape of the radios stack, in the form
            `(n_radios, n_z, n_x)`. A 2-tuple `(n_z, n_x)` is accepted and treated as
            a stack of one radio.
        flats: dict
            Dictionary where each key is the flat index, and the value is a
            numpy.ndarray of the flat image.
        darks: dict
            Dictionary where each key is the dark index, and the value is a
            numpy.ndarray of the dark image. Exactly one dark is supported.
        radios_indices: array of int, optional
            Array containing the radios indices in the scan. `radios_indices[0]` is
            the index of the first radio, and so on.
            Default is `0, 1, ..., n_radios - 1`.
        interpolation: str, optional
            Interpolation method for flat-field ("linear" or "nearest").
            See below for more details.
        distortion_correction: DistortionCorrection, optional
            A DistortionCorrection object. If provided, it is used to correct flat
            distortions based on each radio.
        nan_value: float, optional
            Which float value is used to replace nan/inf after flat-field.
            If None, invalid values are left untouched.
        radios_srcurrent: array, optional
            Array with the same shape as radios_indices. Each item contains the
            synchrotron electric current. If not None, normalization with current is
            applied. Please refer to "Notes" for more information on this normalization.
        flats_srcurrent: array, optional
            Array with the same length as "flats". Each item is a measurement of the
            synchrotron electric current for the corresponding flat. The items must be
            ordered in the same order as the sorted flats indices (`sorted(flats.keys())`).
            This parameter must be used along with 'radios_srcurrent'; if only one of
            the two is provided, no current normalization is done.
            Please refer to "Notes" for more information on this normalization.
        n_threads: int or None, optional
            Number of threads to use for flat-field correction.
            Half of the value returned by `get_num_threads(n_threads)` is used
            (at least one). Threads are only used when this yields more than 2.

        Important
        ----------
        `flats` and `darks` are expected to be a dictionary with integer keys
        (the flats/darks indices) and numpy array values.
        For example, `load_images_from_dataurl_dict` (as used by `FlatFieldDataUrls`)
        builds such dictionaries from DataUrls.

        Notes
        ------
        Usually, when doing a scan, only one or a few darks/flats are acquired.
        However, the flat-field normalization has to be performed on each radio,
        although incoming beam can fluctuate between projections.
        The usual way to overcome this is to interpolate between flats.
        If interpolation="nearest", the first flat is used for the first
        radios subset, the second flat is used for the second radios subset,
        and so on.
        If interpolation="linear", the normalization is done as a linear
        function of the radio index.

        The normalization with synchrotron electric current is done as follows.
        Let s = sr/sr_max denote the ratio between current and maximum current,
        D be the dark-current frame, and X' be the normalized frame. Then:
            srcurrent_normalization(X) = X' = (X - D)/s + D
            flatfield_normalization(X') = (X' - D)/(F' - D) = (X - D) / (F - D) * sF/sX
        So current normalization boils down to a scalar multiplication after flat-field.
        """
        if self._full_shape:
            # This never happens in this base class. It is used by the derived class
            # for the helical case, which needs to keep the full shape.
            if radios_indices is not None:
                radios_shape = (len(radios_indices),) + radios_shape[1:]
        self._set_parameters(radios_shape, radios_indices, interpolation, nan_value)
        self._set_flats_and_darks(flats, darks)
        self._precompute_flats_indices_weights()
        self._configure_srcurrent_normalization(radios_srcurrent, flats_srcurrent)
        self.distortion_correction = distortion_correction
        # Use half of the available threads, but at least one.
        # (`min(1, ...)` would cap this at 1 and disable the threaded path.)
        self.n_threads = max(1, get_num_threads(n_threads) // 2)

    def _set_parameters(self, radios_shape, radios_indices, interpolation, nan_value):
        """Store the radios geometry, indices, interpolation mode and nan replacement value."""
        self._set_radios_shape(radios_shape)
        if radios_indices is None:
            radios_indices = np.arange(0, self.n_radios, dtype=np.int32)
        else:
            radios_indices = np.array(radios_indices, dtype=np.int32)
            self._check_radios_and_indices_congruence(radios_indices)
        self.radios_indices = radios_indices
        self.interpolation = interpolation
        check_supported(interpolation, self._supported_interpolations, "Interpolation mode")
        self.nan_value = nan_value
        # Maps a radio index (in the scan) to its position in the radios stack
        self._radios_idx_to_pos = dict(zip(self.radios_indices, np.arange(self.radios_indices.size)))

    def _set_radios_shape(self, radios_shape):
        """Normalize `radios_shape` to 3 dimensions and derive n_radios and the frame shape."""
        if len(radios_shape) == 2:
            self.radios_shape = (1,) + radios_shape
        elif len(radios_shape) == 3:
            self.radios_shape = radios_shape
        else:
            raise ValueError("Expected radios to have 2 or 3 dimensions")
        n_radios, n_z, n_x = self.radios_shape
        self.n_radios = n_radios
        self.n_angles = n_radios
        self.shape = (n_z, n_x)

    def _set_flats_and_darks(self, flats, darks):
        """Validate the flats/darks and stack the flats (sorted by index) into `flats_arr`."""
        self._check_frames(flats, "flats", 1, 9999)
        self.n_flats = len(flats)
        self.flats = flats
        self._sorted_flat_indices = sorted(self.flats.keys())

        if self._full_shape:
            # This never happens in this base class. It is used by the derived class
            # for the helical case, which needs to keep the full shape.
            self.shape = flats[self._sorted_flat_indices[0]].shape

        # Flat index (in the scan) -> position in self.flats_arr
        self._flat2arrayidx = dict(zip(self._sorted_flat_indices, np.arange(self.n_flats)))
        self.flats_arr = np.zeros((self.n_flats,) + self.shape, "f")
        for i, idx in enumerate(self._sorted_flat_indices):
            self.flats_arr[i] = self.flats[idx]

        self._check_frames(darks, "darks", 1, 1)
        self.darks = darks
        self.n_darks = len(darks)
        self._sorted_dark_indices = sorted(self.darks.keys())
        self._dark = None

    def _check_frames(self, frames, frames_type, min_frames_required, max_frames_supported):
        """Raise ValueError if the number of frames is out of range or a frame has a wrong shape."""
        n_frames = len(frames)
        if n_frames < min_frames_required:
            raise ValueError("Need at least %d %s" % (min_frames_required, frames_type))
        if n_frames > max_frames_supported:
            raise ValueError(
                "Flat-fielding with more than %d %s is not supported" % (max_frames_supported, frames_type)
            )
        self._check_frame_shape(frames, frames_type)

    def _check_frame_shape(self, frames, frames_type):
        """Raise ValueError if any frame's shape differs from `self.shape`."""
        for frame_idx, frame in frames.items():
            if frame.shape != self.shape:
                raise ValueError(
                    "Invalid shape for %s %s: expected %s, but got %s"
                    % (frames_type, frame_idx, str(self.shape), str(frame.shape))
                )

    def _check_radios_and_indices_congruence(self, radios_indices):
        """Raise ValueError if the number of radios indices differs from n_radios."""
        if radios_indices.size != self.n_radios:
            raise ValueError(
                "Expected radios_indices to have length %s = n_radios, but got length %d"
                % (self.n_radios, radios_indices.size)
            )

    def _precompute_flats_indices_weights(self):
        """
        Build two arrays of shape (n_radios, 2): "indices" (self.flats_idx) and
        "weights" (self.flats_weights).
        These arrays contain pre-computed information so that the interpolated flat
        for the radio at position i is obtained with

            flat_interpolated = weight_prev * flat_prev + weight_next * flat_next

        where
            weight_prev, weight_next = weights[i, 0], weights[i, 1]
            idx_prev, idx_next = indices[i, 0], indices[i, 1]
            flat_prev, flat_next = flats_arr[idx_prev], flats_arr[idx_next]

        In words:
           - If a projection has an index between two flats, the equivalent flat is
             a linear interpolation between "previous flat" and "next flat"
             (or the nearest of the two, with interpolation="nearest").
           - If a projection has the same index as a flat, or lies outside the range
             of flats indices, only one flat is used for normalization. This is encoded
             as idx_next = -1 and weights (1, 0).
        """

        def _interp_linear(idx, prev_next):
            if len(prev_next) == 1:
                # Single candidate flat (exact match or outside the flats range)
                weights = (1, 0)
                f_idx = (self._flat2arrayidx[prev_next[0]], -1)
            else:
                prev_idx, next_idx = prev_next
                delta = next_idx - prev_idx
                w1 = 1 - (idx - prev_idx) / delta
                w2 = 1 - (next_idx - idx) / delta
                weights = (w1, w2)
                f_idx = (self._flat2arrayidx[prev_idx], self._flat2arrayidx[next_idx])
            return f_idx, weights

        def _interp_nearest(idx, prev_next):
            if len(prev_next) == 1:
                # Single candidate flat (exact match or outside the flats range)
                weights = (1, 0)
                f_idx = (self._flat2arrayidx[prev_next[0]], -1)
            else:
                prev_idx, next_idx = prev_next
                # Ties go to the next flat
                idx_to_take = prev_idx if abs(idx - prev_idx) < abs(idx - next_idx) else next_idx
                weights = (1, 0)
                f_idx = (self._flat2arrayidx[idx_to_take], -1)
            return f_idx, weights

        self.flats_idx = np.zeros((self.n_radios, 2), dtype=np.int32)
        self.flats_weights = np.zeros((self.n_radios, 2), dtype=np.float32)
        for i, idx in enumerate(self.radios_indices):
            prev_next = self.get_previous_next_indices(self._sorted_flat_indices, idx)
            if self.interpolation == "nearest":
                f_idx, weights = _interp_nearest(idx, prev_next)
            elif self.interpolation == "linear":
                f_idx, weights = _interp_linear(idx, prev_next)

            self.flats_idx[i] = f_idx
            self.flats_weights[i] = weights

    # pylint: disable=E1307
    def _configure_srcurrent_normalization(self, radios_srcurrent, flats_srcurrent):
        """
        Pre-compute the per-radio scalar `flat_current / radio_current` used for
        synchrotron current normalization. Nothing is done unless both
        `radios_srcurrent` and `flats_srcurrent` are provided.
        """
        self.normalize_srcurrent = False
        if radios_srcurrent is None or flats_srcurrent is None:
            return
        radios_srcurrent = np.array(radios_srcurrent)
        if radios_srcurrent.size != self.n_radios:
            raise ValueError(
                "Expected 'radios_srcurrent' to have %d elements but got %d" % (self.n_radios, radios_srcurrent.size)
            )
        flats_srcurrent = np.array(flats_srcurrent)
        if flats_srcurrent.size != self.n_flats:
            raise ValueError(
                "Expected 'flats_srcurrent' to have %d elements but got %d" % (self.n_flats, flats_srcurrent.size)
            )
        self.normalize_srcurrent = True
        self.radios_srcurrent = radios_srcurrent
        self.flats_srcurrent = flats_srcurrent
        self.srcurrent_ratios = np.zeros(self.n_radios, "f")
        # Flats SRCurrent is obtained with "nearest" interp, to emulate an already-done flats SR current normalization
        for i, radio_idx in enumerate(self.radios_indices):
            flat_idx = self.get_nearest_index(self._sorted_flat_indices, radio_idx)
            flat_srcurrent = self.flats_srcurrent[self._flat2arrayidx[flat_idx]]
            self.srcurrent_ratios[i] = flat_srcurrent / self.radios_srcurrent[i]

    @staticmethod
    def get_previous_next_indices(arr, idx):
        """
        Return the neighbours of `idx` in the sorted sequence `arr`.

        Returns a 1-tuple `(value,)` when `idx` is in `arr` or lies outside
        `[arr[0], arr[-1]]` (then the closest end is returned), otherwise the
        2-tuple `(previous, next)` of values surrounding `idx`.
        """
        pos = bisect_left(arr, idx)
        if pos == len(arr):  # outside range
            return (arr[-1],)
        if arr[pos] == idx:
            return (idx,)
        if pos == 0:
            return (arr[0],)
        return arr[pos - 1], arr[pos]

    @staticmethod
    def get_nearest_index(arr, idx):
        """
        Return the value of the sorted sequence `arr` closest to `idx`.

        An exact match returns `idx` itself, values outside the range of `arr`
        return the closest end, and ties go to the next (larger) value,
        consistently with the "nearest" flats interpolation.
        """
        pos = bisect_left(arr, idx)
        if pos == len(arr):
            # idx is after the last element
            return arr[-1]
        if pos == 0 or arr[pos] == idx:
            # exact match, or idx is before the first element
            # (arr[pos - 1] would wrap around to the last element)
            return arr[pos]
        prev_val, next_val = arr[pos - 1], arr[pos]
        return prev_val if idx - prev_val < next_val - idx else next_val

    @staticmethod
    def interp(pos, indices, weights, array, slice_y=slice(None, None), slice_x=slice(None, None)):
        """
        Interpolate between two values. The interpolator consists in pre-computed arrays such that

        prev, next = indices[pos]
        w1, w2 = weights[pos]
        interpolated_value = w1 * array[prev] + w2 * array[next]

        `next == -1` means that only `array[prev]` is used (returned as-is, so the
        result may share memory with `array`).

        When `slice_y`/`slice_x` are not the default full slices:
          - scalar weights (as computed by this class): the frames `array[prev]`,
            `array[next]` are cropped to `[slice_y, slice_x]`;
          - per-pixel weights: the weights are cropped to `[slice_y, slice_x]`
            and the frames are used as they are (unchanged legacy behavior;
            presumably the caller provides frames matching the cropped weights).
        """
        prev_idx = indices[pos, 0]
        next_idx = indices[pos, 1]
        w1 = weights[pos, 0]
        w2 = weights[pos, 1]

        crop_frames = False
        if slice_y != slice(None, None) or slice_x != slice(None, None):
            if np.ndim(w1) == 0:
                # Scalar weights cannot be sliced: crop the frames instead
                crop_frames = True
            else:
                w1 = w1[slice_y, slice_x]
                w2 = w2[slice_y, slice_x]

        prev_frame = array[prev_idx]
        if crop_frames:
            prev_frame = prev_frame[slice_y, slice_x]

        if next_idx == -1:
            return prev_frame

        next_frame = array[next_idx]
        if crop_frames:
            next_frame = next_frame[slice_y, slice_x]
        return w1 * prev_frame + w2 * next_frame

    def get_flat(self, pos, dtype=np.float32, slice_y=slice(None, None), slice_x=slice(None, None)):
        """
        Return the (interpolated) flat for the radio at position `pos` in the radios stack,
        optionally restricted to the `[slice_y, slice_x]` region.
        When no interpolation is needed and `dtype` matches, the result may share memory
        with the stored flats: do not modify it in-place.
        """
        flat = self.interp(pos, self.flats_idx, self.flats_weights, self.flats_arr, slice_y=slice_y, slice_x=slice_x)
        if flat.dtype != dtype:
            flat = np.ascontiguousarray(flat, dtype=dtype)
        return flat

    def get_dark(self):
        """Return the first dark (by sorted index) as a contiguous float32 array, cached after the first call."""
        if self._dark is None:
            first_dark_idx = self._sorted_dark_indices[0]
            dark = np.ascontiguousarray(self.darks[first_dark_idx], dtype=np.float32)
            self._dark = dark
        return self._dark

    def remove_invalid_values(self, img):
        """Replace, in-place, non-finite values of `img` with `nan_value` (no-op if nan_value is None)."""
        if self.nan_value is None:
            return
        invalid_mask = np.logical_not(np.isfinite(img))
        img[invalid_mask] = self.nan_value

    def normalize_radios(self, radios):
        """
        Apply a flat-field normalization, with the current parameters, to a stack of radios.
        The processing is done in-place, meaning that the radios content is overwritten.

        Parameters
        -----------
        radios: numpy.ndarray
            Radios chunk. `radios[i]` is normalized with the flat of the radio at
            position i, for i in range(n_radios).

        Returns
        -------
        radios: numpy.ndarray
            The same array, normalized in-place.
        """
        do_flats_distortion_correction = self.distortion_correction is not None
        dark = self.get_dark()

        def apply_flatfield(i):
            radio_data = radios[i]
            radio_data -= dark
            flat = self.get_flat(i)
            # `flat - dark` makes a new array, so the stored flats are never modified
            flat = flat - dark
            if do_flats_distortion_correction:
                flat = self.distortion_correction.estimate_and_correct(flat, radio_data)
            np.divide(radio_data, flat, out=radio_data)
            # Current normalization is applied before replacing invalid values, so that
            # invalid pixels end up exactly equal to `nan_value` (as in normalize_single_radio)
            if self.normalize_srcurrent:
                radio_data *= self.srcurrent_ratios[i]
            self.remove_invalid_values(radio_data)

        if self.n_threads > 2:
            with ThreadPool(self.n_threads) as tp:
                tp.map(apply_flatfield, range(self.n_radios))
        else:
            for i in range(self.n_radios):
                apply_flatfield(i)

        return radios

    def normalize_single_radio(
        self, radio, radio_idx, dtype=np.float32, slice_y=slice(None, None), slice_x=slice(None, None)
    ):
        """
        Apply a flat-field normalization to a single projection image, in-place.

        Parameters
        ----------
        radio: numpy.ndarray
            Projection image, already restricted to the `[slice_y, slice_x]` region.
        radio_idx: int
            Index of the radio in the scan (one of `radios_indices`).
        dtype: numpy dtype, optional
            Data type of the flat used for normalization.
        slice_y, slice_x: slice, optional
            Region of the dark/flat frames corresponding to `radio`.

        Returns
        -------
        radio: numpy.ndarray
            The same array, normalized in-place.
        """
        dark = self.get_dark()[slice_y, slice_x]
        radio -= dark
        radio_pos = self._radios_idx_to_pos[radio_idx]
        flat = self.get_flat(radio_pos, dtype=dtype, slice_y=slice_y, slice_x=slice_x)
        flat = flat - dark
        if self.distortion_correction is not None:
            flat = self.distortion_correction.estimate_and_correct(flat, radio)
        radio /= flat
        if self.normalize_srcurrent:
            radio *= self.srcurrent_ratios[radio_pos]
        self.remove_invalid_values(radio)
        return radio
# Alias: `FlatField` is the same class object as `FlatFieldArrays`
# (it is also the base class of `FlatFieldDataUrls`).
FlatField = FlatFieldArrays
class FlatFieldDataUrls(FlatField):
    """
    Flat-field normalization where flats and darks are given as DataUrls.
    The images are loaded once, at construction time.
    """

    def __init__(
        self,
        radios_shape: tuple,
        flats: dict,
        darks: dict,
        radios_indices=None,
        interpolation: str = "linear",
        distortion_correction=None,
        nan_value=1.0,
        radios_srcurrent=None,
        flats_srcurrent=None,
        n_threads=None,
        **chunk_reader_kwargs,
    ):
        """
        Initialize a flat-field normalization process with DataUrls.

        Parameters
        ----------
        radios_shape: tuple
            A tuple describing the shape of the radios stack, in the form
            `(n_radios, n_z, n_x)`.
        flats: dict
            Dictionary where the key is the flat index, and the value is a
            DataUrl pointing to the flat.
        darks: dict
            Dictionary where the key is the dark index, and the value is a
            DataUrl pointing to the dark.
        radios_indices: array, optional
            Array containing the radios indices. `radios_indices[0]` is the index
            of the first radio, and so on.
        interpolation: str, optional
            Interpolation method for flat-field. See below for more details.
        distortion_correction: DistortionCorrection, optional
            A DistortionCorrection object. If provided, it is used to correct
            flat distortions based on each radio.
        nan_value: float, optional
            Which float value is used to replace nan/inf after flat-field.
        radios_srcurrent: array, optional
            Synchrotron current for each radio. Used along with `flats_srcurrent`
            for current normalization (see the parent class documentation).
        flats_srcurrent: array, optional
            Synchrotron current for each flat, in the order of the sorted flats
            indices. Used along with `radios_srcurrent`.
        n_threads: int or None, optional
            Number of threads for flat-field correction, passed to the parent class.

        Other Parameters
        ----------------
        The other named parameters are passed to `load_images_from_dataurl_dict`
        when loading both the flats and the darks.
        Please read its documentation for more information.

        Notes
        ------
        Usually, when doing a scan, only one or a few darks/flats are acquired.
        However, the flat-field normalization has to be performed on each radio,
        although incoming beam can fluctuate between projections.
        The usual way to overcome this is to interpolate between flats.
        If interpolation="nearest", the first flat is used for the first
        radios subset, the second flat is used for the second radios subset,
        and so on.
        If interpolation="linear", the normalization is done as a linear
        function of the radio index.
        """
        flats_arrays_dict = load_images_from_dataurl_dict(flats, **chunk_reader_kwargs)
        darks_arrays_dict = load_images_from_dataurl_dict(darks, **chunk_reader_kwargs)
        super().__init__(
            radios_shape,
            flats_arrays_dict,
            darks_arrays_dict,
            radios_indices=radios_indices,
            interpolation=interpolation,
            distortion_correction=distortion_correction,
            nan_value=nan_value,
            radios_srcurrent=radios_srcurrent,
            flats_srcurrent=flats_srcurrent,
            n_threads=n_threads,
        )