Skip to content

nabu.estimation.distortion

[docs] module nabu.estimation.distortion

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
import numpy as np
import scipy.interpolate
from .translation import DetectorTranslationAlongBeam
from ..misc.filters import correct_spikes
from ..resources.logger import LoggerOrPrint


def _interpolate_tile_grid(grid, v_nodes, h_nodes, v_ticks, h_ticks, spline_degree):
    """
    Interpolate a 2D grid of per-node values at every (row, column) pixel position.

    Parameters
    ----------
    grid: np.ndarray
        2D array of shape (len(v_nodes), len(h_nodes)), one value per grid node.
    v_nodes, h_nodes: np.ndarray
        Strictly increasing row (vertical) and column (horizontal) coordinates of the grid nodes.
    v_ticks, h_ticks: np.ndarray
        Row and column coordinates at which the spline is evaluated.
    spline_degree: int
        Degree of the spline along both axes.

    Returns
    -------
    values: np.ndarray
        Array of shape (len(v_ticks), len(h_ticks)).
    """
    spline = scipy.interpolate.RectBivariateSpline(v_nodes, h_nodes, grid, kx=spline_degree, ky=spline_degree)
    # The spline was built with rows as its first coordinate, so it must be evaluated (rows, columns).
    return spline(v_ticks, h_ticks)


def estimate_flat_distortion(
    flat,
    image,
    tile_size=100,
    interpolation_kind="linear",
    padding_mode="edge",
    correction_spike_threshold=None,
    logger=None,
):
    """
    Estimate the wavefront distortion on a flat image, from another image.

    The image is split into square tiles. For each tile, the shift between the image tile and the
    flat tile at the same position is found by correlation. The per-tile shifts are then
    interpolated at every pixel.

    Parameters
    ----------
    flat: np.ndarray
        The flat-field image to be corrected.
    image: np.ndarray
        The image to correlate the flat against.
    tile_size: int
        The wavefront corrections are calculated by correlating the image to the flat,
        region by region. The regions are square tiles of side tile_size.
    interpolation_kind: "linear" or "cubic"
        The interpolation used to go from per-tile shifts to per-pixel shifts
        (spline of degree 1 or 3).
    padding_mode: str
        Padding mode. Must be valid for np.pad.
        The corrections found for the tiles give the shift at the center of each tile.
        To interpolate the corrections at the position of every pixel, one must also cover
        the border of the extremal tiles. This is done by padding the grid of shifts with
        a width of 1, using the mode given by 'padding_mode'. The padded nodes are placed
        at the image borders.
    correction_spike_threshold: float, optional
        By default it is None and no spike correction is performed on the grid of shifts
        found by correlation. If set to a float, a spike removal is applied using this threshold.
    logger: optional
        Passed to LoggerOrPrint. Tiles whose correlation fails with a "positions are outside"
        error are reported at debug level and treated as unshifted.

    Returns
    -------
    coordinates: np.ndarray
        An array of shape (image.shape[0], image.shape[1], 2), where coordinates[i, j] contains
        the (row, column) position in the image "flat" which correlates to the pixel (i, j) in
        the image "image", i.e. (i - shift_v(i, j), j - shift_h(i, j)).

    Raises
    ------
    KeyError
        If interpolation_kind is neither "linear" nor "cubic".
    ValueError
        If the image is too small to contain enough tiles for the requested interpolation.
    """
    logger = LoggerOrPrint(logger)
    # Validate the interpolation kind before running the (expensive) correlation loop.
    spline_degree = {"linear": 1, "cubic": 3}[interpolation_kind]

    # NOTE(review): a tile is only used if it starts strictly before shape - tile_size,
    # so a tile ending exactly on the image border is skipped.
    starts_r = np.array(range(0, image.shape[0] - tile_size, tile_size))
    starts_c = np.array(range(0, image.shape[1] - tile_size, tile_size))

    # After padding by one node on each side, the spline needs more than `spline_degree` nodes per axis.
    min_tiles = max(1, spline_degree - 1)
    if len(starts_r) < min_tiles or len(starts_c) < min_tiles:
        raise ValueError(
            f"Image of shape {image.shape} is too small for tile_size={tile_size} with "
            f"interpolation_kind={interpolation_kind!r}: need at least {min_tiles} tile(s) per axis, "
            f"got {len(starts_r)} x {len(starts_c)}"
        )

    # Per-tile vertical (row) and horizontal (column) shifts.
    shifts_v = np.zeros((len(starts_r), len(starts_c)), np.float32)
    shifts_h = np.zeros((len(starts_r), len(starts_c)), np.float32)

    shift_finder = DetectorTranslationAlongBeam()
    for ir, r in enumerate(starts_r):
        for ic, c in enumerate(starts_c):
            tile = (slice(r, r + tile_size), slice(c, c + tile_size))
            try:
                _, _, shifts_vh_per_img = shift_finder.find_shift(
                    np.array([image[tile], flat[tile]]),
                    np.array([0, 1]),
                    return_shifts=True,
                    low_pass=(1.0, 0.3),
                    high_pass=(tile_size, tile_size * 0.3),
                )
            except ValueError as e:
                if "positions are outside" not in str(e):
                    raise
                # Correlation could not locate the tile: keep a zero shift for it.
                logger.debug(str(e))
            else:
                shifts_v[ir, ic], shifts_h[ir, ic] = shifts_vh_per_img[1]

    # Tiles yielding NaN shifts are treated as unshifted.
    shifts_v[np.isnan(shifts_v)] = 0
    shifts_h[np.isnan(shifts_h)] = 0

    if correction_spike_threshold is not None:
        shifts_v = correct_spikes(shifts_v, correction_spike_threshold)
        shifts_h = correct_spikes(shifts_h, correction_spike_threshold)

    # TODO implement the previous spikes correction in CCDCorrection - median_clip

    # Extend the grid by one node on each side, so that the interpolation covers the image borders.
    shifts_v = np.pad(shifts_v, ((1, 1), (1, 1)), mode=padding_mode)
    shifts_h = np.pad(shifts_h, ((1, 1), (1, 1)), mode=padding_mode)

    # Node positions: the image border, the center of every tile, then the opposite image border.
    v_nodes = np.concatenate([[0.0], starts_r + tile_size * 0.5, [image.shape[0]]])
    h_nodes = np.concatenate([[0.0], starts_c + tile_size * 0.5, [image.shape[1]]])

    v_ticks = np.arange(image.shape[0]).astype(np.float32)
    h_ticks = np.arange(image.shape[1]).astype(np.float32)

    shifts_v = _interpolate_tile_grid(shifts_v, v_nodes, h_nodes, v_ticks, h_ticks, spline_degree)
    shifts_h = _interpolate_tile_grid(shifts_h, v_nodes, h_nodes, v_ticks, h_ticks, spline_degree)

    unshifted_v, unshifted_h = np.meshgrid(v_ticks, h_ticks, indexing="ij")

    # Stack along the last axis -> shape (rows, columns, 2).
    coordinates = np.stack([unshifted_v - shifts_v, unshifted_h - shifts_h], axis=-1)

    return coordinates