#!/usr/bin/env python
# /*##########################################################################
#
# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
# ###########################################################################*/
"""Find contours examples

.. note:: This module has an optional dependency with sci-kit image library.
    You might need to install it if you don't already have it.
"""

import logging
import sys
import numpy
import time

from silx.gui import qt
import silx.gui.plot
from silx.gui.colors import Colormap
import silx.image.bilinear

logging.basicConfig()
_logger = logging.getLogger("find_contours")


# scikit-image is optional: MarchingSquaresSciKitImage stays None whenever
# either skimage itself or the silx wrapper around it cannot be imported.
MarchingSquaresSciKitImage = None
try:
    import skimage
except ImportError:
    _logger.debug("Error while importing skimage", exc_info=True)
    skimage = None
else:
    try:
        from silx.image.marchingsquares._skimage import MarchingSquaresSciKitImage
    except ImportError:
        _logger.debug(
            "Error while importing MarchingSquaresSciKitImage", exc_info=True
        )
        MarchingSquaresSciKitImage = None


def rescale_image(image, shape):
    """Resample ``image`` to ``shape`` with silx bilinear interpolation.

    Output pixel indices are mapped linearly onto the input grid so that the
    first and last rows/columns of the output land on the first and last
    rows/columns of the input.

    :param image: 2D input array
    :param shape: (height, width) of the result
    :return: the interpolated array returned by ``map_coordinates``
    """
    out_rows, out_cols = shape[0], shape[1]
    rows, cols = numpy.ogrid[:out_rows, :out_cols]
    # Source coordinate of every output row / column
    src_rows = rows * 1.0 * (image.shape[0] - 1) / (out_rows - 1)
    src_cols = cols * 1.0 * (image.shape[1] - 1) / (out_cols - 1)
    interpolator = silx.image.bilinear.BilinearImage(image)
    # Broadcast the 1D coordinate axes to full 2D grids
    # TODO: could be optimized using strides
    grid_cols = numpy.zeros_like(src_rows) + src_cols
    grid_rows = numpy.zeros_like(src_cols) + src_rows
    return interpolator.map_coordinates((grid_rows, grid_cols))


def create_spiral(size, nb=1, freq=100):
    """Create a ``size x size`` float32 image of a spiral pattern.

    The value is ``sin(2 * nb * angle + distance * freq * half / 100)`` where
    ``angle`` and ``distance`` are computed from the image center in
    coordinates normalized by ``half = size // 2``.

    :param int size: width and height of the image, in pixels
    :param nb: multiplier of the angular term
    :param freq: multiplier of the radial term
    :return: float32 array of shape ``(size, size)``
    """
    half = size // 2
    # ``size - half`` (not ``half``) as upper bound keeps odd sizes at
    # ``size`` pixels instead of ``size - 1``; even sizes are unchanged.
    y, x = numpy.ogrid[-half : size - half, -half : size - half]
    coef = 1.0 / half
    # The small offset on x avoids a division by zero in y / x
    y, x = y * coef, x * coef + 0.0001
    distance = numpy.sqrt(x * x + y * y)
    angle = numpy.arctan(y / x)
    return numpy.sin(
        angle * nb * 2 + distance * freq * half / 100, dtype=numpy.float32
    )


def create_magnetic_field(size, x1=0.0, y1=0.0, x2=0.0, y2=0.0):
    """Create a ``size x size`` field built from the distances to two points.

    Coordinates are normalized by ``half = size // 2``. ``distance1`` is the
    distance to the point ``(-x1, -y1)`` and ``distance2`` to ``(-x2, -y2)``.
    The result is ``(arctan2(distance1, distance2) - pi / 4) * 1000``.

    :param int size: width and height of the image, in pixels
    :param x1, y1: offsets defining the first point
    :param x2, y2: offsets defining the second point
    :return: array of shape ``(size, size)``
    """
    half = size // 2
    # ``size - half`` keeps odd sizes at ``size`` pixels; even sizes unchanged
    yy, xx = numpy.ogrid[-half : size - half, -half : size - half]
    coef = 1.0 / half
    yy1, xx1 = (yy + half * y1) * coef, (xx + half * x1) * coef
    distance1 = numpy.sqrt(xx1 * xx1 + yy1 * yy1)
    yy2, xx2 = (yy + half * y2) * coef, (xx + half * x2) * coef
    distance2 = numpy.sqrt(xx2 * xx2 + yy2 * yy2)
    return (numpy.arctan2(distance1, distance2) - numpy.pi * 0.25) * 1000


def create_gravity_field(size, objects):
    """Create a ``size x size`` float32 log-potential of point masses.

    For each ``(x, y, m)`` in ``objects``, ``m / distance`` is accumulated,
    where ``distance`` is measured to the point ``(-x, -y)`` in coordinates
    normalized by ``half = size // 2``. The result is
    ``log(sum) * 1000``.

    :param int size: width and height of the image, in pixels
    :param objects: iterable of ``(x, y, m)`` tuples
    :return: float32 array of shape ``(size, size)``
    """
    half = size // 2
    # ``size - half`` makes the grid exactly (size, size) so the in-place add
    # into ``result`` also works for odd sizes (previously a broadcast error)
    yy, xx = numpy.ogrid[-half : size - half, -half : size - half]
    coef = 1.0 / half

    def distance(x, y):
        yy1, xx1 = (yy + half * y) * coef, (xx + half * x) * coef
        return numpy.sqrt(xx1**2 + yy1**2)

    result = numpy.zeros((size, size), dtype=numpy.float32)
    for x, y, m in objects:
        result += m / distance(x, y)
    return numpy.log(result) * 1000


def create_gradient(size, dx=0, dy=0, sx=1.0, sy=1.0):
    """Create a ``size x size`` radial gradient ``sqrt(sx*x**2 + sy*y**2)``.

    ``x`` and ``y`` are normalized by ``half = size // 2`` and centered at
    ``(dx, dy)``. Unequal ``sx``/``sy`` give elliptic iso-lines.

    :param int size: width and height of the image, in pixels
    :param dx, dy: center of the gradient in normalized coordinates
    :param sx, sy: weights of the x and y terms
    :return: array of shape ``(size, size)``
    """
    half = size // 2
    # ``size - half`` keeps odd sizes at ``size`` pixels; even sizes unchanged
    yy, xx = numpy.ogrid[-half : size - half, -half : size - half]
    coef = 1.0 / half
    # Small offset on x kept from the original formula (no exact 0 at center)
    yy, xx = (yy - (dy * half)) * coef, (xx - (dx * half)) * coef + 0.0001
    return numpy.sqrt(xx * xx * sx + yy * yy * sy)


def create_composite_gradient(size, dx=0, dy=0, sx=1.0, sy=1.0):
    """Build a ``size x size`` image tiled from pieces of a larger gradient.

    Two columns by four rows of tiles are cut out of a bigger
    :func:`create_gradient` image and pasted into a zero image, leaving
    one-pixel zero seams between tiles. The source windows are not
    contiguous, so the gradient is not continuous across tile borders.

    :param int size: width and height of the result, in pixels
    :param dx, dy, sx, sy: forwarded to :func:`create_gradient`
    :return: float64 array of shape ``(size, size)``
    """
    tile_height = (size - 4) // 4
    gap = 10
    source = create_gradient(size + tile_height + gap * 4, dx, dy, sx, sy)
    result = numpy.zeros((size, size))
    tile_width = (size - 2) // 2
    half_tile = tile_height // 2

    # Source column offsets of the (left, right) tile of each row; the
    # second row is shifted differently from the three others.
    source_columns = [
        (half_tile, half_tile + tile_width),
        (0, tile_width + tile_height),
        (half_tile, half_tile + tile_width),
        (half_tile, half_tile + tile_width),
    ]
    right_start = tile_width + 1
    for row, (left_src, right_src) in enumerate(source_columns):
        dst_top = row * (tile_height + 1)
        src_top = row * (tile_height + gap)
        dst_rows = slice(dst_top, dst_top + tile_height)
        src_band = source[src_top : src_top + tile_height]
        result[dst_rows, :tile_width] = src_band[
            :, left_src : left_src + tile_width
        ]
        result[dst_rows, right_start : right_start + tile_width] = src_band[
            :, right_src : right_src + tile_width
        ]
    return result


def create_value_noise(shape, octaves=8, weights=None, first_array=None):
    """Generate multi-octave value noise of the given 2D shape.

    Octave ``i`` is a ``t x t`` array of random ``uint8`` values, with ``t``
    starting at 2 and doubling each octave. It is upscaled to ``shape`` with
    :func:`rescale_image` and accumulated into the result.

    :param shape: (height, width) of the output
    :param int octaves: maximum number of octaves to accumulate
    :param weights: optional sequence giving, per octave, the exclusive upper
        bound of the random values; defaults to ``(256 >> i) - 1``
    :param first_array: optional array used instead of random values for the
        first octave
    :return: the accumulated noise
    """
    data = numpy.zeros(shape, dtype=numpy.float32)
    t = 2
    for i in range(octaves):
        # Stop once the octave resolution exceeds the output on both axes
        if t > shape[0] and t > shape[1]:
            break
        if i == 0 and first_array is not None:
            d = first_array
        else:
            w = (256 >> i) - 1 if weights is None else weights[i]
            if w <= 0:
                # numpy.random.randint requires a positive upper bound (the
                # default weights reach 0 from the 9th octave on); such an
                # octave cannot contribute anything, so skip it.
                t = t << 1
                continue
            d = numpy.random.randint(w, size=(t, t)).astype(dtype=numpy.uint8)
        data = data + rescale_image(d, shape)
        t = t << 1
    return data


def create_island(shape, summit, under_water):
    """Create a random island-like height map.

    Value noise, seeded with a bright centered square, is squared and then
    linearly rescaled so that heights span ``[-under_water, summit]``.

    :param shape: (height, width) of the result
    :param summit: highest value of the result
    :param under_water: depth (as a positive number) of the lowest value
    :return: the height map
    """
    # 4x4 seed with a bright 2x2 core to force a centric shape
    seed = numpy.zeros((4, 4), dtype=numpy.uint8)
    seed[1:3, 1:3] = 255
    octave_weights = [255]
    octave_weights.extend((256 >> octave) - 1 for octave in range(8))
    heights = create_value_noise(
        shape, octaves=7, first_array=seed, weights=octave_weights
    )
    # Squaring emphasizes high values: steeper slopes
    heights *= heights
    # Rescale to [-under_water, summit]
    heights -= heights.min()
    scale = (summit + under_water) / heights.max()
    return heights * scale - under_water


def createRgbaMaskImage(mask, color):
    """Generate an RGBA image where a custom color is applied to the
    locations selected by the mask. The non-masked part of the image is
    transparent (all channels zero).

    :param mask: 2D array; non-zero (or True) entries select the pixels to
        color. It is converted to bool first: indexing with an integer array
        would select whole rows by index instead of individual pixels.
    :param color: sequence of 4 values written into the uint8 RGBA channels
    :return: uint8 array of shape ``(height, width, 4)``
    """
    mask = numpy.asarray(mask, dtype=bool)
    image = numpy.zeros((mask.shape[0], mask.shape[1], 4), dtype=numpy.uint8)
    image[mask] = numpy.asarray(color)
    return image


class FindContours(qt.QMainWindow):
    """
    Demo window comparing marching squares iso-contour implementations.

    A side panel allows to generate synthetic images, to select the
    implementation used to extract iso-contours (silx, silx with min/max
    cache, or scikit-image when available), to optionally pass the plot's
    mask to the algorithm, and to pick a custom iso-level with a slider.
    The time spent in ``find_contours`` and the number of polygons and
    points found are displayed.
    """

    # Colors cycled over the iso-levels of the gradient examples
    _RAINBOW_COLORS = (
        "#9400D3",
        "#4B0082",
        "#0000FF",
        "#00FF00",
        "#FFFF00",
        "#FF7F00",
        "#FF0000",
    )

    def __init__(self, filenames=None):
        """
        :param filenames: Unused, kept for compatibility of the constructor
            signature.
        """
        qt.QMainWindow.__init__(self)
        self.setWindowTitle("Silx find contours example")

        self.__plot = silx.gui.plot.Plot2D(parent=self)
        # Placeholder items; their data is replaced in place later
        dummy = numpy.array([[0]])
        self.__plot.addImage(dummy, legend="image", z=-10, replace=False)
        dummy = numpy.array([[[0, 0, 0, 0]]])
        self.__plot.addImage(dummy, legend="iso-pixels", z=0, replace=False)

        self.__algo = None
        self.__polygons = []
        self.__customPolygons = []
        self.__image = None
        self.__mask = None
        self.__customValue = None
        # Levels and style of the last __drawContours call, reused by
        # __redrawContours when the implementation or mask changes
        self.__values = None
        self.__lineStyleCallback = None
        self.__colormap = Colormap("gray")

        mainPanel = qt.QWidget(self)
        layout = qt.QHBoxLayout()
        layout.addWidget(self.__createConfigurationPanel(self))
        layout.addWidget(self.__plot)
        mainPanel.setLayout(layout)

        self.setCentralWidget(mainPanel)

    def __createConfigurationPanel(self, parent):
        """Create the side panel: image kinds, contour options,
        implementation choice, processing statistics and custom level."""
        panel = qt.QWidget(parent=parent)
        layout = qt.QVBoxLayout()
        panel.setLayout(layout)

        # Image kinds

        self.__kind = qt.QButtonGroup(self)
        self.__kind.setExclusive(True)

        group = qt.QGroupBox(self)
        group.setTitle("Image")
        layout.addWidget(group)
        groupLayout = qt.QVBoxLayout(group)

        kinds = (
            ("Island", self.generateIsland),
            ("Gravity", self.generateGravityField),
            ("Magnetic", self.generateMagneticField),
            ("Spiral", self.generateSpiral),
            ("Gradient", self.generateGradient),
            ("Composite gradient", self.generateCompositeGradient),
        )
        for index, (text, slot) in enumerate(kinds):
            button = qt.QRadioButton(parent=panel)
            button.setText(text)
            button.clicked.connect(slot)
            button.setCheckable(True)
            if index == 0:
                button.setChecked(True)
            groupLayout.addWidget(button)
            self.__kind.addButton(button)

        button = qt.QPushButton(parent=panel)
        button.setText("Generate a new image")
        button.clicked.connect(self.generate)
        groupLayout.addWidget(button)

        # Contours

        group = qt.QGroupBox(self)
        group.setTitle("Contours")
        layout.addWidget(group)
        groupLayout = qt.QVBoxLayout(group)

        button = qt.QCheckBox(parent=panel)
        button.setText("Use the plot's mask")
        button.setCheckable(True)
        button.setChecked(True)
        button.clicked.connect(self.updateContours)
        groupLayout.addWidget(button)
        self.__useMaskButton = button

        button = qt.QPushButton(parent=panel)
        button.setText("Update contours")
        button.clicked.connect(self.updateContours)
        groupLayout.addWidget(button)

        # Implementations

        group = qt.QGroupBox(self)
        group.setTitle("Implementation")
        layout.addWidget(group)
        groupLayout = qt.QVBoxLayout(group)

        self.__impl = qt.QButtonGroup(self)
        self.__impl.setExclusive(True)

        button = qt.QRadioButton(parent=panel)
        button.setText("silx")
        button.clicked.connect(self.updateContours)
        button.setCheckable(True)
        button.setChecked(True)
        groupLayout.addWidget(button)
        self.__implMerge = button
        self.__impl.addButton(button)

        button = qt.QRadioButton(parent=panel)
        button.setText("silx with cache")
        button.clicked.connect(self.updateContours)
        button.setCheckable(True)
        groupLayout.addWidget(button)
        self.__implMergeCache = button
        self.__impl.addButton(button)

        button = qt.QRadioButton(parent=panel)
        button.setText("skimage")
        button.clicked.connect(self.updateContours)
        button.setCheckable(True)
        groupLayout.addWidget(button)
        self.__implSkimage = button
        self.__impl.addButton(button)
        if MarchingSquaresSciKitImage is None:
            button.setEnabled(False)
            button.setToolTip("skimage is not installed or not compatible")

        # Processing

        group = qt.QGroupBox(self)
        group.setTitle("Processing")
        layout.addWidget(group)
        group.setLayout(self.__createInfoLayout(group))

        # Custom level

        group = qt.QGroupBox(self)
        group.setTitle("Custom level")
        layout.addWidget(group)
        groupLayout = qt.QVBoxLayout(group)

        self.__value = qt.QSlider(panel)
        self.__value.setOrientation(qt.Qt.Horizontal)
        # Slider tracking is left enabled (Qt default), so valueChanged is
        # also emitted while dragging; connecting sliderMoved as well would
        # compute the custom contours twice per move.
        self.__value.valueChanged.connect(self.__updateCustomContours)
        groupLayout.addWidget(self.__value)

        return panel

    def __createInfoLayout(self, parent):
        """Create the grid displaying time, points and polygons counts."""
        layout = qt.QGridLayout()

        header = qt.QLabel(parent=parent)
        header.setText("Time: ")
        label = qt.QLabel(parent=parent)
        label.setText("")
        layout.addWidget(header, 0, 0)
        layout.addWidget(label, 0, 1)
        self.__timeLabel = label

        header = qt.QLabel(parent=parent)
        header.setText("Nb points: ")
        label = qt.QLabel(parent=parent)
        label.setText("")
        layout.addWidget(header, 1, 0)
        layout.addWidget(label, 1, 1)
        self.__pointsLabel = label

        header = qt.QLabel(parent=parent)
        header.setText("Nb polygons: ")
        label = qt.QLabel(parent=parent)
        label.setText("")
        layout.addWidget(header, 2, 0)
        layout.addWidget(label, 2, 1)
        self.__polygonsLabel = label

        return layout

    def __cleanCustomContour(self):
        """Remove the custom level curves and clear the iso-pixels overlay."""
        for name in self.__customPolygons:
            self.__plot.removeCurve(name)
        self.__customPolygons = []
        item = self.__plot.getImage(legend="iso-pixels")
        item.setData([[[0, 0, 0, 0]]])

    def __cleanPolygons(self):
        """Remove the curves of the main contour levels."""
        for name in self.__polygons:
            self.__plot.removeCurve(name)
        # Reset the list so legends do not accumulate across redraws
        self.__polygons = []

    def clean(self):
        """Remove every contour item and forget the current image, mask,
        contour levels and statistics."""
        self.__cleanCustomContour()
        self.__cleanPolygons()
        self.__image = None
        self.__mask = None
        # Levels belong to the previous image: do not redraw them on a new one
        self.__values = None
        self.__lineStyleCallback = None
        self.__timeLabel.setText("")
        self.__polygonsLabel.setText("")
        self.__pointsLabel.setText("")

    def updateContours(self):
        """Recompute the main contours and the custom level contour."""
        self.__redrawContours()
        self.updateCustomContours()

    def __updateCustomContours(self, value):
        """Slot of the custom level slider."""
        self.__customValue = value
        self.updateCustomContours()

    def updateCustomContours(self):
        """Draw the iso-pixels and the contours of the custom level."""
        if self.__algo is None:
            return
        value = self.__customValue
        self.__cleanCustomContour()
        if value is None:
            return

        # iso pixels
        iso_pixels = self.__algo.find_pixels(value)
        if len(iso_pixels) != 0:
            # Boolean mask: createRgbaMaskImage uses it to select pixels
            mask = numpy.zeros(self.__image.shape, dtype=bool)
            mask[iso_pixels[:, 0], iso_pixels[:, 1]] = True
            rgba = createRgbaMaskImage(mask, color=numpy.array([255, 0, 0, 128]))
            item = self.__plot.getImage(legend="iso-pixels")
            item.setData(rgba)

        # iso contours
        polygons = self.__algo.find_contours(value)
        for ipolygon, polygon in enumerate(polygons):
            if len(polygon) == 0:
                continue
            x = polygon[:, 1] + 0.5
            y = polygon[:, 0] + 0.5
            legend = "custom-polygon-%d" % ipolygon
            self.__customPolygons.append(legend)
            self.__plot.addCurve(
                x=x,
                y=y,
                linestyle="--",
                color="red",
                linewidth=2.0,
                legend=legend,
                resetzoom=False,
            )

    def __updateAlgo(self, image, mask=None):
        """Store ``image``/``mask`` and instantiate the selected marching
        squares implementation (None if it is not available).

        When ``mask`` is None and "Use the plot's mask" is checked, the
        selection mask of the plot's mask tool is used instead.
        """
        if mask is None and self.__useMaskButton.isChecked():
            mask = self.__plot.getMaskToolsDockWidget().getSelectionMask()

        self.__image = image
        self.__mask = mask

        implButton = self.__impl.checkedButton()
        if implButton == self.__implMerge:
            from silx.image.marchingsquares import MarchingSquaresMergeImpl

            self.__algo = MarchingSquaresMergeImpl(self.__image, self.__mask)
        elif implButton == self.__implMergeCache:
            from silx.image.marchingsquares import MarchingSquaresMergeImpl

            self.__algo = MarchingSquaresMergeImpl(
                self.__image, self.__mask, use_minmax_cache=True
            )
        elif (
            implButton == self.__implSkimage and MarchingSquaresSciKitImage is not None
        ):
            self.__algo = MarchingSquaresSciKitImage(self.__image, self.__mask)
        else:
            _logger.error("No algorithm available")
            self.__algo = None

    def setData(self, image, mask=None, value=0.0):
        """Display a new image and prepare the contour algorithm for it.

        :param image: 2D image to display and process
        :param mask: Optional mask given to the algorithm. When None, the
            plot's mask is used if "Use the plot's mask" is checked.
        :param value: Unused, kept for compatibility.
        """
        self.clean()

        self.__updateAlgo(image, mask=mask)

        # image
        item = self.__plot.getImage(legend="image")
        item.setData(image)
        item.setColormap(self.__colormap)

        self.__plot.resetZoom()

    def __redrawContours(self):
        """Recreate the algorithm (implementation or mask may have changed)
        and redraw the last drawn set of contour levels."""
        if self.__image is None:
            return
        self.__updateAlgo(self.__image)
        if self.__algo is None:
            return
        self.__cleanPolygons()
        self.__drawContours(self.__values, self.__lineStyleCallback)

    def __drawContours(self, values, lineStyleCallback=None):
        """Compute and draw the iso-contours of the current image.

        :param values: Iterable of iso-levels, or None to draw nothing
        :param lineStyleCallback: Optional callable
            ``(value, ivalue, ipolygon) -> dict`` returning extra keyword
            arguments for ``addCurve``; solid black lines are used otherwise.
        """
        if self.__algo is None:
            return

        # Remember the levels so that __redrawContours can reuse them
        self.__values = values
        self.__lineStyleCallback = lineStyleCallback
        if values is None:
            return

        nbTime = 0
        nbPolygons = 0
        nbPoints = 0

        # iso contours
        ipolygon = 0
        for ivalue, value in enumerate(values):
            startTime = time.perf_counter()
            polygons = self.__algo.find_contours(value)
            nbTime += time.perf_counter() - startTime
            nbPolygons += len(polygons)
            for polygon in polygons:
                if len(polygon) == 0:
                    continue
                nbPoints += len(polygon)
                # (row, col) shifted by half a pixel, presumably to match
                # the pixel centers of the displayed image
                x = polygon[:, 1] + 0.5
                y = polygon[:, 0] + 0.5
                legend = "polygon-%d" % ipolygon
                if lineStyleCallback is not None:
                    extraStyle = lineStyleCallback(value, ivalue, ipolygon)
                else:
                    extraStyle = {"linestyle": "-", "linewidth": 1.0, "color": "black"}
                self.__polygons.append(legend)
                self.__plot.addCurve(
                    x=x, y=y, legend=legend, resetzoom=False, **extraStyle
                )
                ipolygon += 1

        self.__timeLabel.setText("%0.3fs" % nbTime)
        self.__polygonsLabel.setText("%d" % nbPolygons)
        self.__pointsLabel.setText("%d" % nbPoints)

    def __setSliderSilently(self, minimum, maximum, value):
        """Set the custom level slider range, then its value, without
        emitting valueChanged (which would recompute custom contours for an
        intermediate, possibly clamped, value)."""
        wasBlocked = self.__value.blockSignals(True)
        try:
            # Range first: QSlider clamps setValue to the current range
            self.__value.setRange(minimum, maximum)
            self.__value.setValue(value)
        finally:
            self.__value.blockSignals(wasBlocked)

    def __defineDefaultValues(self, value=None):
        """Configure the custom level slider for the current image and draw
        the custom contour at ``value`` (the image mean by default)."""
        # Do not use min and max to avoid creating iso contours on small
        # and many artefacts
        if value is None:
            value = self.__image.mean()
        div = 12
        vmin = self.__image.min()
        delta = (self.__image.max() - vmin) / div
        minv = vmin + delta
        maxv = vmin + delta * (div - 1)
        self.__setSliderSilently(
            int(numpy.floor(minv)), int(numpy.ceil(maxv)), int(numpy.round(value))
        )
        self.__customValue = value
        self.updateCustomContours()

    def __rainbowStyle(self, value, ivalue, ipolygon):
        """Line style callback cycling over _RAINBOW_COLORS by level index."""
        colors = self._RAINBOW_COLORS
        color = colors[ivalue % len(colors)]
        return {"linestyle": "-", "linewidth": 2.0, "color": color}

    @staticmethod
    def __gradientLevels(image):
        """Return levels spaced by (max - min) / 9 from the image minimum,
        skipping the minimum itself and keeping at most 7 of them."""
        delta = (image.max() - image.min()) / 9.0
        values = numpy.arange(image.min(), image.max(), delta)
        return values[1:8]

    def generate(self):
        """Regenerate an image of the currently selected kind."""
        self.__kind.checkedButton().click()

    def generateSpiral(self):
        """Display a random spiral (custom level only, no main contours)."""
        shape = 512
        nb_spiral = numpy.random.randint(1, 8)
        freq = numpy.random.randint(2, 50)
        image = create_spiral(shape, nb_spiral, freq)
        image *= 1000.0
        self.__colormap = Colormap("cool")
        self.setData(image=image, mask=None)
        self.__defineDefaultValues()

    def generateIsland(self):
        """Display a random island height map with one contour every 200
        units of height from -800 to 4800."""
        shape = (512, 512)
        image = create_island(shape, summit=4808.72, under_water=1500)
        self.__colormap = Colormap("terrain")
        self.setData(image=image, mask=None)

        values = range(-800, 5000, 200)

        def styleCallback(value, ivalue, ipolygon):
            if value == 0:
                style = {"linestyle": "-", "linewidth": 1.0, "color": "black"}
            elif value % 1000 == 0:
                style = {"linestyle": "--", "linewidth": 0.5, "color": "black"}
            else:
                style = {"linestyle": "--", "linewidth": 0.1, "color": "black"}
            return style

        self.__drawContours(values, styleCallback)

        self.__setSliderSilently(0, 5000, 0)
        self.__updateCustomContours(0)

    def generateMagneticField(self):
        """Display a random two-point field with alternating black/white
        contours."""
        shape = 512
        x1 = numpy.random.random() * 2 - 1
        y1 = numpy.random.random() * 2 - 1
        x2 = numpy.random.random() * 2 - 1
        y2 = numpy.random.random() * 2 - 1
        image = create_magnetic_field(shape, x1, y1, x2, y2)
        self.__colormap = Colormap("coolwarm")
        self.setData(image=image, mask=None)

        maximum = int(max(abs(image.max()), abs(image.min())))
        # range() rejects a zero step, which happens when maximum < 20
        step = max(1, maximum // 20)
        values = range(-maximum, maximum, step)

        def styleCallback(value, ivalue, ipolygon):
            if (ivalue % 2) == 0:
                style = {"linestyle": "-", "linewidth": 0.5, "color": "black"}
            else:
                style = {"linestyle": "-", "linewidth": 0.5, "color": "white"}
            return style

        self.__drawContours(values, styleCallback)
        self.__defineDefaultValues(value=0)

    def generateGravityField(self):
        """Display the log-potential of 2 to 9 random point masses."""
        shape = 512
        nb = numpy.random.randint(2, 10)
        objects = []
        for _ in range(nb):
            x = numpy.random.random() * 2 - 1
            y = numpy.random.random() * 2 - 1
            m = numpy.random.random() * 10 + 1.0
            objects.append((x, y, m))
        image = create_gravity_field(shape, objects)
        self.__colormap = Colormap("inferno")
        self.setData(image=image, mask=None)

        delta = (image.max() - image.min()) / 30.0
        values = numpy.arange(image.min(), image.max(), delta)

        def styleCallback(value, ivalue, ipolygon):
            return {"linestyle": "-", "linewidth": 0.1, "color": "white"}

        self.__drawContours(values, styleCallback)
        self.__defineDefaultValues()

    def generateGradient(self):
        """Display a random elliptic gradient with rainbow contours."""
        shape = 512
        dx = numpy.random.random() * 2 - 1
        dy = numpy.random.random() * 2 - 1
        sx = numpy.random.randint(10, 5000) / 10.0
        sy = numpy.random.randint(10, 5000) / 10.0
        image = create_gradient(shape, dx=dx, dy=dy, sx=sx, sy=sy)
        image *= 1000.0

        self.__colormap = Colormap("Greys")
        self.setData(image=image, mask=None)
        self.__drawContours(self.__gradientLevels(image), self.__rainbowStyle)
        self.__defineDefaultValues()

    def generateCompositeGradient(self):
        """Display a random tiled gradient with rainbow contours."""
        shape = 512
        hole = 1 / 4.0
        dx = numpy.random.random() * hole - hole / 2.0
        dy = numpy.random.random() * hole - hole * 2
        sx = numpy.random.random() * 10.0 + 1
        sy = numpy.random.random() * 10.0 + 1
        image = create_composite_gradient(shape, dx, dy, sx, sy)
        image *= 1000.0

        self.__colormap = Colormap("Greys")
        self.setData(image=image, mask=None)
        self.__drawContours(self.__gradientLevels(image), self.__rainbowStyle)
        self.__defineDefaultValues()


def main():
    """Run the find contours demo window.

    :return: the exit code of the Qt event loop
    """
    app = qt.QApplication([])
    # Install silx's Qt exception handler for uncaught exceptions
    sys.excepthook = qt.exceptionHandler
    demo = FindContours()
    demo.generateIsland()
    demo.show()
    exitCode = app.exec()
    # Schedule deletion of the application to avoid QTimer warnings at exit
    app.deleteLater()
    return exitCode


if __name__ == "__main__":
    result = main()
    sys.exit(result)
