from __future__ import annotations

from linalg.Vector import Vector

import typing
import numpy as np
import numbers


class Matrix:
    """Two-dimensional matrix backed by a NumPy array.

    Supports addition with matrix-like operands, scalar multiplication,
    matrix-matrix products and matrix-vector products with the vector on the
    right (``Matrix * Vector``); vectors on the left are rejected.
    """

    # Underlying 2-D NumPy array holding the entries, one row per constructor argument.
    elements: np.ndarray

    _SHAPE_ERROR = 'Matrix can\'t have different lengths or shapes and must be two-dimensional'

    def __init__(self, *elements: ARRAY_LIKE) -> None:
        """Build a matrix from its rows.

        :param elements: one array-like per row; all rows must have equal length.
        :raises ValueError: if the rows do not form a two-dimensional array.
        """
        try:
            self.elements = np.array(elements)
        except ValueError as err:
            # Recent NumPy versions reject ragged rows themselves; report them
            # with the same message as any other non-2-D input.
            raise ValueError(self._SHAPE_ERROR) from err

        if self.elements.ndim != 2:
            raise ValueError(self._SHAPE_ERROR)

    def __add__(self, addend: ADDEND) -> Matrix:
        """Return the element-wise sum of this matrix and *addend*.

        :param addend: a Matrix, an ``np.matrix`` or a list of rows.
        :raises TypeError: for a plain ``np.ndarray`` or any other unsupported type.
        :raises ValueError: if the two shapes differ.
        """
        if isinstance(addend, np.matrix):
            addend = Matrix(*np.asarray(addend))
        elif isinstance(addend, np.ndarray):
            # Plain ndarrays are deliberately rejected; only np.matrix is accepted.
            raise TypeError('Argument must be of type ' + str(ADDEND))
        elif isinstance(addend, list):
            addend = Matrix(*addend)

        if not isinstance(addend, Matrix):
            raise TypeError('Argument must be of type ' + str(ADDEND))

        if self.elements.shape != addend.elements.shape:
            raise ValueError('Matrices types are not compatible')

        return Matrix(*(self.elements + addend.elements))

    def __radd__(self, addend: ADDEND) -> Matrix:
        """Reflected addition; addition is commutative, so delegate to ``__add__``."""
        return self.__add__(addend)

    def __mul__(self, factor: FACTOR) -> typing.Union[Vector, Matrix]:
        """Multiply this matrix (left operand) by *factor*.

        - number: scalar multiplication, returns a Matrix;
        - ``np.matrix``, Matrix or list of rows: matrix product, returns a Matrix;
        - Vector or flat list: matrix-vector product, returns a Vector.

        :raises TypeError: for a plain (non-matrix) ``np.ndarray``.
        :raises ValueError: for any other unsupported operand type.
        """
        if isinstance(factor, np.ndarray) and not isinstance(factor, np.matrix):
            raise TypeError('Argument must be of type ' + str(FACTOR))

        if isinstance(factor, np.matrix):
            # np.matrix is always 2-D, so the product is a Matrix.
            return Matrix(*(self.elements @ np.asarray(factor)))

        if isinstance(factor, numbers.Number):
            return Matrix(*(self.elements * factor))

        if isinstance(factor, Vector):
            return Vector(*self.elements.dot(factor.elements))

        if isinstance(factor, list):
            array = np.array(factor)
            if array.ndim == 1:
                # A flat list is treated as a column vector.
                return Vector(*self.elements.dot(array))
            return Matrix(*(self.elements @ array))

        if isinstance(factor, Matrix):
            return Matrix(*(self.elements @ factor.elements))

        raise ValueError("Unexpected error in Matrix class (mul)")

    def __rmul__(self, factor: FACTOR) -> Matrix:
        """Multiply *factor* (left operand) by this matrix.

        - number: scalar multiplication;
        - ``np.matrix``, Matrix or list of rows: matrix product ``factor @ self``.

        :raises TypeError: for a plain ``np.ndarray``, a Vector, or a flat list
            (vectors on the left are not supported).
        :raises ValueError: for any other unsupported operand type.
        """
        if isinstance(factor, np.ndarray) and not isinstance(factor, np.matrix):
            raise TypeError('Argument must be of type ' + str(FACTOR))

        if isinstance(factor, np.matrix):
            # Reached for ``np.matrix * Matrix``: np.matrix defers to __rmul__.
            return Matrix(*(np.asarray(factor) @ self.elements))

        if isinstance(factor, Vector):
            raise TypeError('Multiplication with vector from left is not supported')

        if isinstance(factor, list) and np.array(factor).ndim == 1:
            raise TypeError('Multiplication with vector from left is not supported')

        if isinstance(factor, numbers.Number):
            return Matrix(*(self.elements * factor))

        if isinstance(factor, list):
            return Matrix(*(np.array(factor) @ self.elements))

        if isinstance(factor, Matrix):
            # Left operand times this matrix (previously factor @ factor).
            return Matrix(*(factor.elements @ self.elements))

        raise ValueError("Unexpected error in Matrix class (mul)")

    def transpose(self) -> Matrix:
        """Return a new matrix with rows and columns swapped."""
        return Matrix(*np.transpose(self.elements))

    def homog(self) -> Matrix:
        """Return the homogeneous extension of this matrix.

        A zero column and a final row ``[0, ..., 0, 1]`` are appended, so an
        (m x n) matrix becomes (m+1 x n+1) with a 1 in the bottom-right corner.
        """
        numRows, numCols = self.elements.shape
        newCol = np.zeros(numRows)
        newRow = np.append(np.zeros(numCols), 1)
        return self.appendCol(newCol).appendRow(newRow)

    def inv(self) -> Matrix:
        """Return the inverse matrix.

        :raises numpy.linalg.LinAlgError: propagated from ``np.linalg.inv``
            (e.g. for a singular matrix).
        """
        return Matrix(*np.linalg.inv(self.elements))

    def getRow(self, n: int) -> np.ndarray:
        """Return row *n* as a 1-D NumPy array (a view, not a copy)."""
        return self.elements[n]

    def getCol(self, n: int) -> np.ndarray:
        """Return column *n* as a 1-D NumPy array (a view, not a copy)."""
        return self.elements[:, n]

    def appendCol(self, column: ARRAY_LIKE) -> Matrix:
        """Return a new matrix with *column* (one value per row) added on the right."""
        # Wrap each value so np.append sees an (m x 1) block.
        c = [[x] for x in column]
        return Matrix(*np.append(self.elements, c, axis=1))

    def appendRow(self, row: ARRAY_LIKE) -> Matrix:
        """Return a new matrix with *row* added at the bottom."""
        return Matrix(*np.vstack([self.elements, row]))

    def dropCol(self, n: int) -> Matrix:
        """Return a new matrix without column *n*."""
        return Matrix(*np.delete(self.elements, n, axis=1))

    def dropRow(self, n: int) -> Matrix:
        """Return a new matrix without row *n*."""
        return Matrix(*np.delete(self.elements, n, axis=0))

    def insertCol(self, pos: int, column: ARRAY_LIKE) -> Matrix:
        """Return a new matrix with *column* inserted before column index *pos*."""
        return Matrix(*np.insert(self.elements, pos, column, axis=1))

    def insertRow(self, pos: int, column: ARRAY_LIKE) -> Matrix:
        """Return a new matrix with a row inserted before row index *pos*.

        Note: the parameter is named ``column`` for historical reasons; it holds
        the row to insert. The name is kept for keyword-argument callers.
        """
        return Matrix(*np.insert(self.elements, pos, column, axis=0))

    def __getitem__(self, key):
        """Index the underlying array directly (row, element or NumPy slice)."""
        return self.elements[key]

    def __eq__(self, other: object) -> bool:
        """Return True if *other* has the same shape and entries as this matrix.

        Non-Matrix operands are converted with ``Matrix(*other)``. An
        ``np.matrix`` is first converted to a plain array. Operands that cannot
        be converted compare unequal instead of raising.
        """
        if isinstance(other, np.matrix):
            other = np.asarray(other)
        if not isinstance(other, Matrix):
            try:
                other = Matrix(*other)
            except (TypeError, ValueError):
                return False
        # array_equal checks shapes first, so a 1xN matrix no longer equals an
        # MxN matrix through broadcasting.
        return np.array_equal(self.elements, other.elements)

    def __len__(self) -> int:
        """Return the number of rows."""
        return len(self.elements)

    def __str__(self) -> str:
        """Return NumPy's string rendering of the entries."""
        return str(self.elements)


# NumPy's matrix class. Referencing np.matrix directly avoids instantiating a
# throwaway np.matrix at import time, which NumPy flags with a
# PendingDeprecationWarning.
TYPE_NP_MATRIX: type = np.matrix
# One row of a matrix: a flat list of numbers or a NumPy array.
ARRAY_LIKE = typing.Union[typing.List[float], typing.List[int], typing.List[np.ndarray], np.ndarray]
# Operands Matrix addition accepts: a list of rows, an np.matrix or a Matrix.
MATRIX_LIKE = typing.Union[typing.List[ARRAY_LIKE], np.matrix, Matrix]
# Right-hand operand type for Matrix addition; str(ADDEND) appears in TypeError messages.
ADDEND = MATRIX_LIKE
# Operand type for Matrix multiplication; str(FACTOR) appears in TypeError messages.
FACTOR = typing.Union[typing.List[typing.List[float]],
                      typing.List[typing.List[int]],
                      np.ndarray, np.matrix, Matrix, ARRAY_LIKE, Vector, float]
