from __future__ import annotations

import typing

import imageio
import matplotlib.pyplot as plt
import numpy as np
from scipy import fftpack

from linalg.Matrix import Matrix
from linalg.Vector import Vector
from linalg.linmap.maps import identity


def dct2Matrix(N: int) -> Matrix:
    """Build the N x N unnormalised DCT-II matrix.

    Entry (row k, column n) is ``cos(pi * (n + 0.5) * k / N)``. No scaling
    factors are applied, so the matrix is not orthonormal.

    The matrix starts as ``identity(N)`` and every entry is overwritten in
    place, which costs O(N^2) work.
    """
    M = identity(N)
    for row in range(N):
        for col in range(N):
            M[row][col] = np.cos(np.pi * (col + 0.5) * row / N)
    return M


def dct3Matrix(N: int) -> Matrix:
    """Return the inverse of ``dct2Matrix(N)``, as computed by ``Matrix.inv()``.

    This is the exact inverse of the *unnormalised* DCT-II. It therefore
    differs from the textbook unscaled DCT-III by normalisation factors.
    """
    forward = dct2Matrix(N)
    return forward.inv()


def dct2(v: ARRAY_LIKE) -> Vector:
    """Apply the unnormalised DCT-II matrix of size ``len(v)`` to ``v``.

    The full matrix is built explicitly with ``dct2Matrix``, so this is meant
    for small inputs and clarity rather than speed.
    """
    size = len(v)
    return dct2Matrix(size) * Vector(*v)


def dct3(v: ARRAY_LIKE) -> Vector:
    """Apply ``dct3Matrix(len(v))`` to ``v``.

    ``dct3Matrix`` is the inverse of the DCT-II matrix, so this undoes ``dct2``.
    """
    size = len(v)
    return dct3Matrix(size) * Vector(*v)


def dct2_2d(m: MATRIX_LIKE) -> np.ndarray:
    return fftpack.dct(fftpack.dct(m, axis=0, norm='ortho'), axis=1, norm='ortho')


def dct3_2d(m: MATRIX_LIKE) -> np.ndarray:
    return fftpack.idct(fftpack.idct(m, axis=0, norm='ortho'), axis=1, norm='ortho')


def _apply_blockwise(array: np.ndarray,
                     transform: typing.Callable[[np.ndarray], np.ndarray],
                     block_size: int) -> np.ndarray:
    """Return a new float array holding ``transform`` applied to each tile.

    Tiles are ``block_size`` x ``block_size`` over the first two axes. Any
    trailing axes, such as colour channels, are passed through whole. Tiles
    on the bottom and right edges are smaller when the image size is not a
    multiple of ``block_size``.
    """
    result = np.zeros(array.shape)
    rows, cols = array.shape[0], array.shape[1]
    for i in range(0, rows, block_size):
        for j in range(0, cols, block_size):
            tile = (slice(i, i + block_size), slice(j, j + block_size))
            result[tile] = transform(array[tile])
    return result


def dct2_blockWise(filename: str, threshold=0.01, *, block_size: int = 8) -> np.ndarray:
    """Compress an image with a thresholded block-wise DCT and return its reconstruction.

    The image is read from ``filename`` and split into
    ``block_size`` x ``block_size`` tiles. Each tile is transformed with
    ``dct2_2d``. Every coefficient whose magnitude is not greater than
    ``threshold`` times the largest coefficient magnitude is set to zero.
    Each tile is then inverted with ``dct3_2d``.

    The function also prints the percentage of coefficients that were kept.

    Args:
        filename: Path of the image to read with ``imageio.imread``.
        threshold: Relative cut-off, as a fraction of the peak coefficient
            magnitude.
        block_size: Tile edge length in pixels. Defaults to 8.

    Returns:
        The reconstructed image as a float array with the same shape as the
        input.

    Raises:
        ValueError: If ``block_size`` is not positive.
    """
    if block_size <= 0:
        raise ValueError('block_size must be a positive integer, got %r' % (block_size,))

    image = np.asarray(imageio.imread(filename), dtype=float)

    coefficients = _apply_blockwise(image, dct2_2d, block_size)

    # Compare magnitudes against the peak *magnitude*, not the signed max,
    # so the cut-off stays meaningful when the dominant coefficient is negative.
    # `initial=0.0` keeps np.max from raising on an empty image.
    peak = np.max(np.abs(coefficients), initial=0.0)
    kept = coefficients * (np.abs(coefficients) > threshold * peak)

    reconstructed = _apply_blockwise(kept, dct3_2d, block_size)

    # Divide by the total number of coefficients so that colour images
    # (H x W x C) still report a fraction in [0, 1].
    total = kept.size
    keeping = np.count_nonzero(kept) / total if total else 0.0
    print('filename: %s, keeping: %f%%' % (filename, keeping * 100))

    return reconstructed


def plot_dct2_image(filename: str, threshold=0.01, destination=None):
    """Show an image next to its thresholded block-wise DCT reconstruction.

    The reconstruction comes from ``dct2_blockWise(filename, threshold)``.
    Both panels are drawn with a gray colormap and no axes. If
    ``destination`` is truthy, the figure is saved there before it is shown.
    """
    original = imageio.imread(filename)
    reconstructed = dct2_blockWise(filename, threshold)
    _, axes = plt.subplot_mosaic([['orig', 'dct']])

    panels = (
        ('orig', original, 'Original'),
        ('dct', reconstructed, 'DCT (threshold=%f)' % threshold),
    )
    for key, picture, title in panels:
        panel = axes[key]
        panel.imshow(picture, cmap='gray')
        panel.axis('off')
        panel.set_title(title)

    if destination:
        plt.savefig(destination, bbox_inches='tight')

    plt.show()


# Annotated input type of the 1-D routines (dct2 / dct3): a list of floats,
# a NumPy array, or a project Vector. The aliases are defined after their
# first use in annotations; this works only because
# `from __future__ import annotations` makes annotations lazy.
ARRAY_LIKE = typing.Union[
    typing.List[float],
    np.ndarray,
    Vector,
]
# Annotated input type of the 2-D routines (dct2_2d / dct3_2d): a list of
# ARRAY_LIKE rows, a NumPy matrix, or a project Matrix.
MATRIX_LIKE = typing.Union[
    typing.List[ARRAY_LIKE],
    np.matrix,
    Matrix,
]
