import math

import matplotlib.pyplot as plt # type: ignore
import matplotlib as mpl # type: ignore
import numpy as np

from linalg.Matrix import Matrix

# --- Figure / viewport -------------------------------------------------------
PLOT_DPI = 100           # default DPI for new figures and for saved images
X_RANGE = (-10, 10)      # default visible x interval
Y_RANGE = (-10, 10)      # default visible y interval

# --- Colours of the untransformed (canonical) canvas ------------------------
MAIN_AXIS_COLOR = "#4C4C4C"              # main x/y axes: dark grey
CANONICAL_GRID_COLOR = "#D3D3D3"         # identity grid: light grey
CANONICAL_X_VECTOR_COLOR = "red"         # first standard basis vector
CANONICAL_Y_VECTOR_COLOR = "seagreen"    # second standard basis vector
CANONICAL_MAP_COLOR = "#000000"          # user-drawn objects: black

# --- Colours of the image under a linear map --------------------------------
LIN_MAP_GRID_COLOR = "darkblue"          # transformed grid
LIN_MAP_X_VECTOR_COLOR = "darkred"       # image of the first basis vector
LIN_MAP_Y_VECTOR_COLOR = "darkgreen"     # image of the second basis vector
LIN_MAP_MAP_COLOR = "#0000ff"            # transformed user objects: blue

# Import-time side effect: figures created afterwards use PLOT_DPI by default.
mpl.rcParams["figure.dpi"] = PLOT_DPI


class CoordinateSystem:
    """A 2-D Cartesian plotting canvas for visualising linear maps.

    All drawing goes through the global ``matplotlib.pyplot`` module (kept as
    ``self.plot``), so every instance draws into pyplot's current figure.

    Points and lines drawn via ``drawPoint``/``drawLine`` (and arrows drawn
    with ``store=True``) are remembered so that ``setTransformationMatrix``
    can draw their images under a linear map.
    """

    def __init__(self, xRange=X_RANGE, yRange=Y_RANGE, axis='off', drawCanonicalBasis=False):
        """Configure the axes and draw the canonical (identity) grid.

        Args:
            xRange: pair of integers bounding the visible x interval; ``min``
                and ``max`` are taken, so the order does not matter. A tick is
                placed at every integer in the interval.
            yRange: same as ``xRange`` for the y axis.
            axis: value forwarded to ``plt.axis`` (default ``'off'`` hides the
                axis lines and labels).
            drawCanonicalBasis: when True, also draw arrows for (1, 0) and
                (0, 1).
        """
        self.plot = plt
        self.xRange = xRange
        self.yRange = yRange
        # Objects remembered so setTransformationMatrix can redraw their images.
        self.points = []
        self.lines = []
        self.arrows = []

        plt.axis(axis)

        plt.ylim(min(yRange), max(yRange))
        plt.xlim(min(xRange), max(xRange))

        # range() needs integer bounds, so xRange / yRange must contain ints.
        plt.xticks(range(min(xRange), max(xRange) + 1))
        plt.yticks(range(min(yRange), max(yRange) + 1))

        # Identity grid plus emphasised main axes, all behind other artists.
        self.drawGrid(Matrix([1, 0], [0, 1]), color=CANONICAL_GRID_COLOR)
        plt.axhline(y=0, zorder=0, color=MAIN_AXIS_COLOR)
        plt.axvline(x=0, zorder=0, color=MAIN_AXIS_COLOR)

        if drawCanonicalBasis:
            self.drawPositionArrow(1, 0, color=CANONICAL_X_VECTOR_COLOR, zorder=100)
            self.drawPositionArrow(0, 1, color=CANONICAL_Y_VECTOR_COLOR, zorder=100)

    @staticmethod
    def _isZeroVector(v):
        """Return True when both components of the 2-D vector ``v`` are zero."""
        return v[0] == 0 and v[1] == 0

    def drawGrid(self, M: Matrix, **kwargs):
        """Draw the lattice spanned by the columns of ``M`` as infinite lines.

        With e1, e2 the first and second columns of ``M``: for every integer i
        in the y range, a line through ``i * e2`` with direction e1 is drawn;
        for every integer i in the x range, a line through ``i * e1`` with
        direction e2 is drawn.

        A zero column (singular ``M``) provides no direction. matplotlib's
        ``axline`` raises ``ValueError`` at render time for a line through two
        identical points, so that family of lines is skipped.

        Args:
            M: 2x2 matrix whose columns span the grid.
            **kwargs: forwarded to ``plt.axline`` (e.g. ``color``).
        """
        e1 = M.getCol(0)
        e2 = M.getCol(1)

        # Lines parallel to e1 (images of the horizontal grid lines).
        if not self._isZeroVector(e1):
            for i in range(min(self.yRange), max(self.yRange) + 1):
                xy1 = e2 * i
                self.drawInfiniteLine(xy1, xy1 + e1, zorder=0, **kwargs)

        # Lines parallel to e2 (images of the vertical grid lines).
        if not self._isZeroVector(e2):
            for i in range(min(self.xRange), max(self.xRange) + 1):
                xy1 = e1 * i
                self.drawInfiniteLine(xy1, xy1 + e2, zorder=0, **kwargs)

    def drawArrow(self, x, y, dx, dy, store=False, width=0.1, head_width=0.3, **kwargs):
        """Draw an arrow from (x, y) to (x + dx, y + dy).

        Args:
            x, y: tail position.
            dx, dy: displacement from tail to head; the head is included in
                this length (``length_includes_head=True``).
            store: remember the arrow as ``[x, y, dx, dy]`` so that
                ``setTransformationMatrix`` draws its image.
            width, head_width: forwarded to ``plt.arrow``.
            **kwargs: further ``plt.arrow`` options (e.g. ``color``, ``zorder``).
        """
        if store:
            self.arrows.append([x, y, dx, dy])

        self.plot.arrow(x, y, dx, dy, width=width, head_width=head_width,
                        length_includes_head=True, **kwargs)

    def drawPositionArrow(self, x, y, store=False, **kwargs):
        """Draw the position vector of (x, y), i.e. an arrow from the origin."""
        self.drawArrow(0, 0, x, y, store=store, **kwargs)

    def drawLine(self, p1, p2, color=CANONICAL_MAP_COLOR, *, store=True, **kwargs):
        """Draw the segment from ``p1`` to ``p2``.

        Args:
            p1, p2: end points, indexable as ``p[0]`` / ``p[1]``.
            color: matplotlib colour, passed as the ``fmt`` argument of
                ``plt.plot``.
            store: remember the segment so ``setTransformationMatrix`` draws
                its image (default True, matching the previous behaviour).
            **kwargs: forwarded to ``plt.plot``.
        """
        if store:
            self.lines.append([p1, p2])
        self.plot.plot([p1[0], p2[0]], [p1[1], p2[1]], color, **kwargs)

    def drawLines(self, pair, color=CANONICAL_MAP_COLOR, **kwargs):
        """Draw every segment ``(p1, p2)`` in the iterable ``pair`` via ``drawLine``."""
        for p in pair:
            self.drawLine(p[0], p[1], color, **kwargs)

    def drawInfiniteLine(self, p1, p2, **kwargs):
        """Draw the infinite line through ``p1`` and ``p2`` (``plt.axline``)."""
        self.plot.axline(p1, p2, **kwargs)

    def drawDirectionArrow(self, p1, p2, store=False, **kwargs):
        """Draw an arrow from point ``p1`` to point ``p2``."""
        self.drawArrow(p1[0], p1[1], p2[0] - p1[0], p2[1] - p1[1], store=store, **kwargs)

    def drawLinearCombination(self, scalar, V, **kwargs):
        """Draw ``scalar * V`` as a head-to-tail chain of arrows and return a continuation.

        Starting at the origin, ``V`` (or ``-V`` when ``scalar`` is negative) is
        drawn ``floor(|scalar|)`` times head-to-tail. One shorter arrow follows
        for any fractional remainder.

        The returned function ``f(scalar, V, **kwargs)`` continues from where
        the previous chain ended. Further terms of a linear combination can
        therefore be appended, e.g.
        ``f = cs.drawLinearCombination(2, v1); f(-1.5, v2)``.

        Args:
            scalar: finite real coefficient.
            V: 2-D vector (anything ``np.array`` accepts).
            **kwargs: forwarded to ``drawArrow`` (e.g. ``color``).

        Returns:
            The continuation function described above.

        Raises:
            ValueError: if ``scalar`` is NaN or infinite, for which the step
                loop would never terminate.
        """
        currentPosition = np.array([0, 0])

        def f(scalar, V, **kwargs):
            nonlocal currentPosition
            if not math.isfinite(scalar):
                raise ValueError(f"scalar must be finite, got {scalar!r}")
            v = np.array(V)
            sign = 1 if scalar >= 0 else -1
            step = abs(scalar)
            # Whole copies of V first...
            while step >= 1:
                newPosition = currentPosition + sign * v
                self.drawDirectionArrow(currentPosition, newPosition, **kwargs)
                currentPosition = newPosition
                step -= 1
            # ...then the fractional remainder, if any.
            if step > 0:
                newPosition = currentPosition + sign * step * v
                self.drawDirectionArrow(currentPosition, newPosition, **kwargs)
                currentPosition = newPosition

        f(scalar, V, **kwargs)

        return f

    def setTransformationMatrix(self, M, gColor=LIN_MAP_GRID_COLOR, tColor=LIN_MAP_MAP_COLOR, **kwargs):
        """Overlay the image of the canvas under the linear map ``M``.

        This draws:

        - the transformed grid;
        - the images of the basis vectors (the columns of ``M``);
        - the images of every remembered point, line and arrow.

        The transformed objects are drawn but NOT remembered. Calling this
        again with another matrix therefore transforms only the originally
        drawn objects, not the images from earlier calls.

        Args:
            M: 2x2 matrix supporting ``getCol`` and ``M * vector``.
            gColor: colour of the transformed grid.
            tColor: colour of the transformed points, lines and arrows.
            **kwargs: forwarded to ``drawGrid``.
        """
        self.drawGrid(M, color=gColor, **kwargs)

        # Images of the basis vectors are the columns of M.
        i_head = M.getCol(0)
        j_head = M.getCol(1)
        self.drawPositionArrow(i_head[0], i_head[1], color=LIN_MAP_X_VECTOR_COLOR, zorder=100)
        self.drawPositionArrow(j_head[0], j_head[1], color=LIN_MAP_Y_VECTOR_COLOR, zorder=100)

        for p in self.points:
            self.drawPoint(M * p, color=tColor, store=False)

        for p1, p2 in self.lines:
            self.drawLine(M * p1, M * p2, color=tColor, store=False)

        # Arrows are stored as [x, y, dx, dy]. By linearity, the image of the
        # displacement (dx, dy) is simply M * (dx, dy).
        for x, y, dx, dy in self.arrows:
            arrowStartPosition = M * [x, y]
            arrowDisplacement = M * [dx, dy]
            self.drawArrow(*arrowStartPosition, *arrowDisplacement, color=tColor)

    def drawPoint(self, p, color=CANONICAL_MAP_COLOR, marker='+', markersize=1.5, *, store=True, **kwargs):
        """Draw a single marker at ``p``.

        Args:
            p: point, indexable as ``p[0]`` / ``p[1]``.
            color: matplotlib colour, passed as the ``fmt`` argument of
                ``plt.plot``.
            marker, markersize: forwarded to ``plt.plot``.
            store: remember the point so ``setTransformationMatrix`` draws its
                image (default True, matching the previous behaviour).
            **kwargs: forwarded to ``plt.plot``.
        """
        if store:
            self.points.append(p)
        self.plot.plot(p[0], p[1], color, marker=marker, markersize=markersize, **kwargs)

    def drawPoints(self, P, color=CANONICAL_MAP_COLOR, **kwargs):
        """Draw every point in the iterable ``P`` via ``drawPoint``."""
        for p in P:
            self.drawPoint(p, color, **kwargs)

    def drawCircle(self, n, color=CANONICAL_MAP_COLOR, **kwargs):
        """Draw ``n`` evenly spaced points on the unit circle."""
        self.drawPoints(CoordinateSystem.getCirclePositions(n), color=color, **kwargs)

    def drawRectangle(self, p1, p2, p3, p4, **kwargs):
        """Draw the closed quadrilateral p1 -> p2 -> p3 -> p4 -> p1."""
        self.drawLine(p1, p2, **kwargs)
        self.drawLine(p2, p3, **kwargs)
        self.drawLine(p3, p4, **kwargs)
        self.drawLine(p4, p1, **kwargs)

    def drawSquare(self, pStart, length, **kwargs):
        """Draw an axis-aligned square with corner ``pStart`` and side ``length``."""
        p1 = pStart
        p2 = [pStart[0] + length, pStart[1]]
        p3 = [pStart[0] + length, pStart[1] + length]
        p4 = [pStart[0], pStart[1] + length]
        self.drawRectangle(p1, p2, p3, p4, **kwargs)

    def show(self):
        """Display the current pyplot figure."""
        self.plot.show()

    def save(self, filename, dpi=PLOT_DPI, *args, **kwargs):
        """Save the current pyplot figure to ``filename`` via ``plt.savefig``."""
        self.plot.savefig(filename, dpi=dpi, *args, **kwargs)

    @staticmethod
    def getCirclePositions(n):
        """Return ``n`` distinct, evenly spaced ``(cos, sin)`` points on the unit circle.

        Angles start at 0 and advance by 2*pi/n. ``endpoint=False`` excludes
        the angle 2*pi, which would otherwise duplicate the point at angle 0.
        Without it, n == 2 would produce the same point twice.
        """
        return [(np.cos(phi), np.sin(phi))
                for phi in np.linspace(0, 2 * np.pi, n, endpoint=False)]