import pytest

import numpy as np
from numpy.testing import assert_equal, assert_allclose

from glue.core import Data, DataCollection
from glue.core.coordinates import Coordinates, IdentityCoordinates
from glue.core.link_helpers import LinkSame
from glue.core.exceptions import IncompatibleDataException, IncompatibleAttribute

from ..state import ImageViewerState, ImageLayerState, AggregateSlice


class SimpleCoordinates(Coordinates):
    """Toy 3-D coordinate system where world = 2.5 * pixel on every axis.

    The axis correlation matrix couples pixel/world axes 0 and 1 together and
    leaves axis 2 independent. Used below as a dataset that is not on the
    same pixel grid as the reference data.
    """

    def __init__(self):
        super().__init__(pixel_n_dim=3, world_n_dim=3)

    def pixel_to_world_values(self, *args):
        # Uniform scaling: each world value is 2.5x the matching pixel value
        return tuple(2.5 * pixel for pixel in args)

    def world_to_pixel_values(self, *args):
        # Inverse of the scaling above (1 / 2.5 == 0.4)
        return tuple(0.4 * world for world in args)

    @property
    def axis_correlation_matrix(self):
        corr = np.zeros((self.world_n_dim, self.pixel_n_dim), dtype=bool)
        # Axes 0 and 1 form one correlated group...
        corr[:2, :2] = True
        # ...while axis 2 is only correlated with itself
        corr[2, 2] = True
        return corr


class TestImageViewerState:

    def setup_method(self, method):
        self.state = ImageViewerState()

    def test_pixel_world_linking(self):
        # The pixel attributes (x_att/y_att) and the world attributes
        # (x_att_world/y_att_world) must stay in sync, and x and y must never
        # end up pointing at the same axis.

        data = Data(label='data', x=[[1, 2], [3, 4]], y=[[5, 6], [7, 8]],
                    coords=IdentityCoordinates(n_dim=2))
        self.state.layers.append(ImageLayerState(layer=data, viewer_state=self.state))

        world_0, world_1 = data.world_component_ids
        pixel_0, pixel_1 = data.pixel_component_ids

        state = self.state
        state.reference_data = data

        # World -> pixel propagation
        state.x_att_world = world_0
        state.y_att_world = world_1

        assert state.x_att is pixel_0
        assert state.y_att is pixel_1

        # Giving x the axis currently used by y swaps the two
        state.x_att_world = world_1
        assert state.x_att is pixel_1
        assert state.y_att is pixel_0
        assert state.y_att_world is world_0

        # ...and likewise when giving y the axis currently used by x
        state.y_att_world = world_1
        assert state.x_att is pixel_0
        assert state.x_att_world is world_0
        assert state.y_att is pixel_1

        # Pixel -> world propagation, with the same swapping behavior
        state.x_att = pixel_1
        assert state.x_att_world is world_1
        assert state.y_att is pixel_0
        assert state.y_att_world is world_0

        state.y_att = pixel_1
        assert state.y_att_world is world_1
        assert state.x_att is pixel_0
        assert state.x_att_world is world_0


class TestSlicingAggregation:
    """Reduce a 5-D dataset to a 2-D image via slicing and aggregation."""

    def setup_method(self, method):
        self.viewer_state = ImageViewerState()
        self.data = Data(x=np.ones((3, 4, 5, 6, 7)))
        self.layer_state = ImageLayerState(layer=self.data, viewer_state=self.viewer_state)
        self.viewer_state.layers.append(self.layer_state)
        self.pix = self.data.pixel_component_ids

    def _set_axes(self, x_index, y_index):
        # Set x_att first, then y_att, using pixel axis indices of self.data
        self.viewer_state.x_att = self.pix[x_index]
        self.viewer_state.y_att = self.pix[y_index]

    def test_default(self):
        # By default the last two axes are displayed and every other axis
        # is sliced at index 0
        assert self.viewer_state.x_att == self.pix[4]
        assert self.viewer_state.y_att == self.pix[3]
        assert self.viewer_state.slices == (0, 0, 0, 0, 0)
        assert self.layer_state.get_sliced_data().shape == (6, 7)

    def test_flipped(self):
        # Swapping x and y must transpose the extracted image
        self._set_axes(3, 4)
        assert self.viewer_state.slices == (0, 0, 0, 0, 0)
        assert self.layer_state.get_sliced_data().shape == (7, 6)

    def test_slice_preserved(self):
        # The slice indices must survive changes to the displayed axes
        expected_slices = (1, 3, 2, 5, 4)
        self.viewer_state.slices = expected_slices
        cases = [
            # (x axis, y axis, expected wcsaxes_slice, expected image shape)
            (2, 4, ['y', 5, 'x', 3, 1], (7, 5)),
            (2, 1, [4, 5, 'x', 'y', 1], (4, 5)),
            (0, 4, ['y', 5, 2, 3, 'x'], (7, 3)),
        ]
        for x_index, y_index, expected_wcs, expected_shape in cases:
            self._set_axes(x_index, y_index)
            assert self.viewer_state.slices == expected_slices
            assert self.viewer_state.wcsaxes_slice == expected_wcs
            assert self.layer_state.get_sliced_data().shape == expected_shape

    def test_aggregation(self):
        # AggregateSlice entries collapse a dimension with a function rather
        # than picking a single index
        mean_all = AggregateSlice(slice(None), 0, np.mean)
        sum_three = AggregateSlice(slice(2, 5), 3, np.sum)
        self.viewer_state.slices = (mean_all, 3, 2, sum_three, 4)
        self._set_axes(2, 4)
        assert self.viewer_state.slices == (mean_all, 3, 2, sum_three, 4)
        assert self.viewer_state.wcsaxes_slice == ['y', 3, 'x', 3, 0]
        sliced = self.layer_state.get_sliced_data()
        assert sliced.shape == (7, 5)
        # Mean of ones is 1, then summing the 3 indices 2..4 gives 3
        assert_equal(sliced, 3)


class TestReprojection:
    """Extract images from datasets that are linked to a 4-D reference dataset.

    The datasets are matched to the reference by pixel links (same grid,
    possibly reordered or with fewer dimensions) or by world links. The
    world-linked case requires actual reprojection.
    """

    def _link_pixel_axes(self, other, pairs):
        # For each (ref_axis, other_axis) pair, in order, add a LinkSame
        # between pixel axis ref_axis of data1 and pixel axis other_axis of
        # ``other``.
        for ref_axis, other_axis in pairs:
            self.data_collection.add_link(LinkSame(self.data1.pixel_component_ids[ref_axis],
                                                   other.pixel_component_ids[other_axis]))

    def setup_method(self, method):

        self.data_collection = DataCollection()

        self.array = np.arange(3024).reshape((6, 7, 8, 9))

        # The reference dataset. Shape is (6, 7, 8, 9).
        self.data1 = Data(x=self.array, coords=IdentityCoordinates(n_dim=4))
        self.data_collection.append(self.data1)

        # A dataset with the same shape but not linked. Shape is (6, 7, 8, 9).
        self.data2 = Data(x=self.array)
        self.data_collection.append(self.data2)

        # A dataset with the same number of dimensions but in a different
        # order, linked to the first. Shape is (9, 7, 6, 8).
        self.data3 = Data(x=np.moveaxis(self.array, (3, 1, 0, 2), (0, 1, 2, 3)))
        self.data_collection.append(self.data3)
        self._link_pixel_axes(self.data3, [(0, 2), (1, 1), (2, 3), (3, 0)])

        # A dataset with fewer dimensions, linked to the first one. Shape is
        # (8, 7, 6)
        self.data4 = Data(x=self.array[:, :, :, 0].transpose())
        self.data_collection.append(self.data4)
        self._link_pixel_axes(self.data4, [(0, 2), (1, 1), (2, 0)])

        # A dataset with even fewer dimensions, linked to the first one. Shape
        # is (8, 6)
        self.data5 = Data(x=self.array[:, 0, :, 0].transpose())
        self.data_collection.append(self.data5)
        self._link_pixel_axes(self.data5, [(0, 1), (2, 0)])

        # A dataset that is not on the same pixel grid and requires reprojection
        self.data6 = Data()
        self.data6.coords = SimpleCoordinates()
        self.array_nonaligned = np.arange(60).reshape((5, 3, 4))
        self.data6['x'] = np.array(self.array_nonaligned)
        self.data_collection.append(self.data6)
        self.data_collection.add_link(LinkSame(self.data1.world_component_ids[0],
                                               self.data6.world_component_ids[1]))
        self.data_collection.add_link(LinkSame(self.data1.world_component_ids[1],
                                               self.data6.world_component_ids[2]))
        self.data_collection.add_link(LinkSame(self.data1.world_component_ids[2],
                                               self.data6.world_component_ids[0]))

        # One layer per dataset, in the order data1..data6 (so layers[i]
        # corresponds to data{i+1})
        self.viewer_state = ImageViewerState()
        for data in (self.data1, self.data2, self.data3,
                     self.data4, self.data5, self.data6):
            self.viewer_state.layers.append(ImageLayerState(viewer_state=self.viewer_state, layer=data))

        self.viewer_state.reference_data = self.data1

    def test_default_axis_order(self):

        # Start off with a combination of x/y that means that only one of the
        # other datasets will be matched.

        self.viewer_state.x_att = self.data1.pixel_component_ids[3]
        self.viewer_state.y_att = self.data1.pixel_component_ids[2]
        self.viewer_state.slices = (3, 2, 4, 1)

        image = self.viewer_state.layers[0].get_sliced_data()
        assert_equal(image, self.array[3, 2, :, :])

        # data2 is not linked at all
        with pytest.raises(IncompatibleAttribute):
            self.viewer_state.layers[1].get_sliced_data()

        image = self.viewer_state.layers[2].get_sliced_data()
        assert_equal(image, self.array[3, 2, :, :])

        # data4 and data5 have no axis linked to pixel axis 3 of the reference
        with pytest.raises(IncompatibleDataException):
            self.viewer_state.layers[3].get_sliced_data()

        with pytest.raises(IncompatibleDataException):
            self.viewer_state.layers[4].get_sliced_data()

    def test_transpose_axis_order(self):

        # Next make it so the x/y axes correspond to the dimensions with length
        # 6 and 8 which most datasets will be compatible with, and this also
        # requires a transposition.

        self.viewer_state.x_att = self.data1.pixel_component_ids[0]
        self.viewer_state.y_att = self.data1.pixel_component_ids[2]
        self.viewer_state.slices = (3, 2, 4, 1)

        image = self.viewer_state.layers[0].get_sliced_data()
        assert_equal(image, self.array[:, 2, :, 1].transpose())

        with pytest.raises(IncompatibleAttribute):
            self.viewer_state.layers[1].get_sliced_data()

        image = self.viewer_state.layers[2].get_sliced_data()
        assert_equal(image, self.array[:, 2, :, 1].transpose())

        image = self.viewer_state.layers[3].get_sliced_data()
        assert_equal(image, self.array[:, 2, :, 0].transpose())

        image = self.viewer_state.layers[4].get_sliced_data()
        assert_equal(image, self.array[:, 0, :, 0].transpose())

    def test_transpose_axis_order_view(self):

        # As for the previous test, but this time with a view applied

        self.viewer_state.x_att = self.data1.pixel_component_ids[0]
        self.viewer_state.y_att = self.data1.pixel_component_ids[2]
        self.viewer_state.slices = (3, 2, 4, 1)

        # view[0] applies to the image rows (y), view[1] to the columns (x)
        view = [slice(1, None, 2), slice(None, None, 3)]

        image = self.viewer_state.layers[0].get_sliced_data(view=view)
        assert_equal(image, self.array[::3, 2, 1::2, 1].transpose())

        with pytest.raises(IncompatibleAttribute):
            self.viewer_state.layers[1].get_sliced_data(view=view)

        image = self.viewer_state.layers[2].get_sliced_data(view=view)
        assert_equal(image, self.array[::3, 2, 1::2, 1].transpose())

        image = self.viewer_state.layers[3].get_sliced_data(view=view)
        assert_equal(image, self.array[::3, 2, 1::2, 0].transpose())

        image = self.viewer_state.layers[4].get_sliced_data(view=view)
        assert_equal(image, self.array[::3, 0, 1::2, 0].transpose())

    def test_reproject(self):

        # Test a case where the data needs to actually be reprojected, using
        # the same axes, slices and view as test_transpose_axis_order_view.

        self.viewer_state.x_att = self.data1.pixel_component_ids[0]
        self.viewer_state.y_att = self.data1.pixel_component_ids[2]
        self.viewer_state.slices = (3, 2, 4, 1)

        view = [slice(1, None, 2), slice(None, None, 3)]

        actual = self.viewer_state.layers[5].get_sliced_data(view=view)

        # The data to be reprojected is 3-dimensional. The axes we have set
        # correspond to 1 (for x) and 0 (for y). The third dimension of the
        # data to be reprojected should be sliced. This is linked with the
        # second dimension of the original data, for which the slice index is
        # 2. Since the data to be reprojected has coordinates that are 2.5 times
        # those of the reference data, this means the slice index should be 0.8,
        # which rounded corresponds to 1.
        expected = self.array_nonaligned[:, :, 1]

        # Now in the frame of the reference data, the data to show are indices
        # [0, 3] along x and [1, 3, 5, 7] along y. Applying the transformation,
        # this gives values of [0, 1.2] and [0.4, 1.2, 2, 2.8] for x and y,
        # and rounded, this gives [0, 1] and [0, 1, 2, 3]. As a reminder, in the
        # data to reproject, dimension 0 is y and dimension 1 is x
        expected = expected[:4, :2]

        # Let's make sure this works!
        assert_equal(actual, expected)

    def test_too_many_dimensions(self):

        # If we change the reference data, then the first dataset won't be
        # visible anymore because it has too many dimensions

        self.viewer_state.reference_data = self.data4

        with pytest.raises(IncompatibleAttribute):
            self.viewer_state.layers[0].get_sliced_data()

        self.viewer_state.reference_data = self.data6

        with pytest.raises(IncompatibleAttribute):
            self.viewer_state.layers[0].get_sliced_data()


def test_update_x_att_and_y_att():

    # Regression test: when the reference data changes, y_att must already be
    # updated by the time any callback fires for x_att (and vice versa), so
    # that x_att and y_att always belong to the same dataset.

    viewer_state = ImageViewerState()

    datasets = []
    for _ in range(2):
        dataset = Data(x=np.ones((3, 4, 5)))
        viewer_state.layers.append(ImageLayerState(layer=dataset, viewer_state=viewer_state))
        datasets.append(dataset)
    data1, data2 = datasets

    def check_consistency(*args, **kwargs):
        # Fail as soon as x_att and y_att refer to different datasets
        assert viewer_state.x_att.parent is viewer_state.y_att.parent

    viewer_state.add_global_callback(check_consistency)
    for prop in ('x_att', 'y_att', 'x_att_world', 'y_att_world', 'slices'):
        viewer_state.add_callback(prop, check_consistency)

    for reference in (data1, data2):
        viewer_state.reference_data = reference
        assert viewer_state.x_att is reference.pixel_component_ids[2]
        assert viewer_state.y_att is reference.pixel_component_ids[1]


def test_attribute_units():

    # v_min/v_max must be converted when attribute_display_unit changes, and
    # the percentile setting must be kept until a limit is set by hand.

    state = ImageViewerState()

    dataset = Data(x=np.arange(100).reshape((10, 10)))
    dataset.get_component('x').units = 'km'

    layer = ImageLayerState(layer=dataset, viewer_state=state)
    state.layers.append(layer)

    # Full range in the native unit (km)
    assert layer.percentile == 100
    assert layer.v_min == 0
    assert layer.v_max == 99

    # Switching to metres scales the limits by 1000
    layer.attribute_display_unit = 'm'

    assert layer.v_min == 0
    assert layer.v_max == 99000

    assert layer.percentile == 100

    # Percentile limits are computed in the current display unit
    layer.percentile = 95

    assert_allclose(layer.v_min, 2475)
    assert_allclose(layer.v_max, 96525)

    assert layer.percentile == 95

    layer.attribute_display_unit = 'km'

    assert_allclose(layer.v_min, 2.475)
    assert_allclose(layer.v_max, 96.525)

    # Setting a limit by hand switches percentile to 'Custom'
    layer.attribute_display_unit = 'm'

    layer.v_max = 50000

    assert layer.percentile == 'Custom'

    # Custom limits are still converted when the unit changes
    layer.attribute_display_unit = 'km'

    assert_allclose(layer.v_min, 2.475)
    assert_allclose(layer.v_max, 50)


def test_stretch_global():

    # Test the option of using a global stretch (limits computed from the
    # whole cube) vs a per-slice stretch (limits computed from the slice
    # currently shown).

    viewer_state = ImageViewerState()

    data1 = Data(x=np.arange(1000).reshape((10, 10, 10)))

    layer_state = ImageLayerState(layer=data1, viewer_state=viewer_state)
    viewer_state.layers.append(layer_state)

    assert layer_state.stretch_global is True

    # Global 100% stretch spans all values 0..999
    assert layer_state.percentile == 100
    assert layer_state.v_min == 0
    assert layer_state.v_max == 999

    assert viewer_state.slices == (0, 0, 0)

    # Per-slice stretch: slice 0 holds values 0..99
    layer_state.stretch_global = False

    assert layer_state.v_min == 0
    assert layer_state.v_max == 99

    # Slice 9 holds values 900..999
    viewer_state.slices = (9, 0, 0)

    assert layer_state.v_min == 900
    assert layer_state.v_max == 999

    # A 90% percentile gives interpolated (non-integer) limits, so compare
    # with a tolerance rather than exact float equality
    layer_state.percentile = 90

    assert_allclose(layer_state.v_min, 904.95)
    assert_allclose(layer_state.v_max, 994.05)

    layer_state.stretch_global = True

    assert_allclose(layer_state.v_min, 49.95)
    assert_allclose(layer_state.v_max, 949.05)


def test_attribute_units_invalid():

    # Regression test: building a layer state for a dataset whose component
    # carries an unrecognized unit string used to crash; it must not raise.

    state = ImageViewerState()

    dataset = Data(x=np.arange(100).reshape((10, 10)))
    component = dataset.get_component('x')
    component.units = 'banana'

    ImageLayerState(layer=dataset, viewer_state=state)
