from linalg.Vector import Vector
from linalg.Matrix import Matrix

import unittest
import numpy as np
import warnings


class MatrixClass(unittest.TestCase):
    """Unit tests for ``linalg.Matrix``.

    Covers construction, addition, scalar/matrix/vector multiplication,
    transposition, row/column access, string conversion and conversion to
    homogeneous coordinates.
    """

    def test_init_matrix(self):
        """Rows of equal length build a ``Matrix`` without raising."""
        M = Matrix([1, 2], [3, 3], [3, 4])

        self.assertIsInstance(M, Matrix)

    def test_init_raises(self):
        """Ragged rows and a nested-list argument are rejected with ValueError."""
        # Scope the filter to this test only: a bare ``simplefilter`` call
        # would stay installed for every test that runs afterwards.
        # NOTE(review): presumably the ragged rows trigger a warning before
        # the ValueError is raised -- confirm against Matrix.__init__.
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            self.assertRaises(ValueError, lambda: Matrix([1, 2, 1], [3, 3], [3, 4]))
            self.assertRaises(ValueError, lambda: Matrix([[1, 2, 3]]))

    def test_addition_with_scalar_raises(self):
        """Adding a scalar on either side raises TypeError."""
        self.assertRaises(TypeError, lambda: Matrix([1, 2], [2, 1]) + 1)
        self.assertRaises(TypeError, lambda: 1 + Matrix([1, 2], [2, 1]))

    def test_addition_of_incompatible_matrices_raises(self):
        """Adding operands of different shapes raises ValueError.

        Checked for Matrix + Matrix and for a nested list on either side.
        """
        M = Matrix([1, 2], [2, 3])
        N = Matrix([1, 2], [1, -3], [0, 0])

        self.assertRaises(ValueError, lambda: M + N)
        self.assertRaises(ValueError, lambda: N + M)
        self.assertRaises(ValueError, lambda: M + [[1, 2], [1, -3], [0, 0]])
        self.assertRaises(ValueError, lambda: [[1, 2], [2, 3]] + N)

    def test_addition_with_matrix(self):
        """Addition is commutative, element-wise, and accepts nested lists."""
        M1 = Matrix([1, 2], [2, 3])
        M2 = Matrix([1, 2], [1, -3])
        M1_plus_M2 = Matrix([2, 4], [3, 0])

        self.assertTrue((M1 + M2 == M2 + M1).all())
        self.assertTrue((M1 + M2 == M1_plus_M2).all())
        self.assertTrue((M1 + [[1, 2], [1, -3]] == M1_plus_M2).all())
        self.assertTrue(([[1, 2], [2, 3]] + M2 == M1_plus_M2).all())

    def test_multiply_with_scalar(self):
        """Scalar multiplication scales every entry and works from both sides."""
        M = Matrix([1, 2, 3],
                   [2, 0, 1])

        N = Matrix([-2, -4, -6],
                   [-4, 0, -2])

        Z = Matrix([0, 0, 0],
                   [0, 0, 0])

        self.assertTrue(M * 0 == Z)
        self.assertTrue(1 * M == M * 1)
        self.assertTrue(1 * M == M)
        self.assertTrue(2 * M == [[2, 4, 6], [4, 0, 2]])
        self.assertTrue((-2) * M == N)

    def test_multiply_with_incompatible_matrix_raises(self):
        """Multiplying a 2x2 matrix by a 3x2 operand raises ValueError."""
        M = Matrix([1, 2],
                   [2, 0])

        N = Matrix([1, 2],
                   [2, 0],
                   [1, 1])

        self.assertRaises(ValueError, lambda: M * N)
        self.assertRaises(ValueError, lambda: M * [[1, 2], [2, 0], [1, 1]])

    def test_multiply_with_matrix(self):
        """``*`` is the matrix product: order matters and a Matrix is returned."""
        M = Matrix([1, 2],
                   [2, 0])

        N = Matrix([3, 1],
                   [2, 3])

        MN = Matrix([7, 7], [6, 2])
        NM = Matrix([5, 6], [8, 4])

        self.assertTrue(M * N == MN)
        self.assertTrue(N * M == NM)
        self.assertIsInstance(M * [[3, 1], [2, 3]], Matrix)

    def test_multiply_with_vector_from_right(self):
        """Matrix * Vector yields the transformed Vector."""
        M = Matrix([1, 2],
                   [2, 0])

        V = Vector(-1, 5)
        MV = Vector(9, -2)

        self.assertIsInstance(M * V, Vector)
        self.assertTrue(M * V == MV)

    def test_multiply_distributivity(self):
        """Matrix multiplication is associative: (A * B) * C == A * (B * C).

        The method name says "distributivity", but the property checked here
        is associativity; the name is kept so existing test IDs stay stable.
        """
        A = Matrix([1, 2],
                   [2, 0])

        B = Matrix([3, 1],
                   [2, 3])

        C = Matrix([-1],
                   [2])

        self.assertTrue((A * B) * C == A * (B * C))

    def test_transpose_method(self):
        """``transpose()`` swaps rows with columns and returns a Matrix."""
        M = Matrix([1, 2],
                   [-2, 0])

        self.assertTrue(M.transpose() == [[1, -2], [2, 0]])
        self.assertIsInstance(M.transpose(), Matrix)

    def test_getItem_implementation(self):
        """Indexing with an integer returns the corresponding row."""
        M = Matrix([1, 2],
                   [-2, 0])

        self.assertTrue((M[0] == [1, 2]).all())
        self.assertTrue((M[1] == [-2, 0]).all())

    def test_getRow_method(self):
        """``getRow(i)`` returns the i-th row."""
        M = Matrix([1, 2],
                   [-2, 0])

        self.assertTrue((M.getRow(0) == [1, 2]).all())
        self.assertTrue((M.getRow(1) == [-2, 0]).all())

    def test_getCol_method(self):
        """``getCol(j)`` returns the j-th column."""
        M = Matrix([1, 2],
                   [-2, 0])

        self.assertTrue((M.getCol(0) == [1, -2]).all())
        self.assertTrue((M.getCol(1) == [2, 0]).all())

    def test_str_implementation(self):
        """``str(Matrix)`` matches the string form of the equivalent NumPy array."""
        M = Matrix([1, 2],
                   [-2, 0])

        self.assertEqual(str(M), str(np.array([[1, 2], [-2, 0]])))

    def test_transform_matrix_in_homogeneous_coordinates(self):
        """``homog()`` adds a zero row and zero column with a 1 in the corner.

        Checked for a square (2x2) and a non-square (3x2) matrix.
        """
        M = Matrix([1, 2],
                   [-2, 0])

        M_h = Matrix([1, 2, 0],
                     [-2, 0, 0],
                     [0, 0, 1])

        N = Matrix([0, 0],
                   [0, 0],
                   [0, 0])

        N_h = Matrix([0, 0, 0],
                     [0, 0, 0],
                     [0, 0, 0],
                     [0, 0, 1])

        self.assertTrue(M.homog() == M_h)
        self.assertTrue(N.homog() == N_h)

# TODO: add tests for appendRow, appendCol and inv