import collections
import pathlib
import random
import os
import pickle
from typing import Dict, Tuple, Sequence

import cv2
from skimage.color import rgb2lab, lab2rgb
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms
from torch.autograd import Variable

def extract_dominant_colors(img):
    rgb_img = img.reshape(-1, 3)

    criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
    k = 5
    ret, label, center = cv2.kmeans(np.float32(rgb_img), k, None, criteria, 10, cv2.KMEANS_RANDOM_CENTERS)

    unique, counts = np.unique(label, return_counts=True)
    dominant_colors = center[unique].tolist()
    percentages = counts / sum(counts) * 100

    palette = []
    for color in dominant_colors:
        hex_code = "#{:02x}{:02x}{:02x}".format(*map(int, color))
        palette.append(hex_code)

    return palette


def viz_color_palette(hexcodes):

    palette = []
    for hexcode in hexcodes:
        rgb = np.array(list(int(hexcode.lstrip('#')[i:i+2], 16) for i in (0, 2, 4)))
        palette.append(rgb)

    palette = np.array(palette)[np.newaxis, :, :]
    return palette

def augment_image(img, title, hue_shift):

    img_HSV = matplotlib.colors.rgb_to_hsv(img)
    a_2d_index = np.array([[1,0,0] for _ in range(img_HSV.shape[1])]).astype('bool')
    img_HSV[:, a_2d_index] = (img_HSV[:, a_2d_index] + hue_shift) % 1

    new_img = matplotlib.colors.hsv_to_rgb(img_HSV).astype(int)
    plt.imshow(new_img)
    plt.title(f"New {title} (in RGB)")
    plt.show()

    img = img.astype(np.float) / 255.0
    new_img = new_img.astype(np.float) / 255.0
    ori_img_LAB = rgb2lab(img)
    new_img_LAB = rgb2lab(new_img)
    new_img_LAB[:, :, 0] = ori_img_LAB[:, :, 0]
    new_img_augmented = (lab2rgb(new_img_LAB)*255.0).astype(int)
    plt.imshow(new_img_augmented)
    plt.title(f"New {title} (in RGB) with Fixed Luminance")
    plt.show()
    plt.close()

    return new_img_augmented

images = [f for f in os.listdir() if f.endswith(".jpg")]
for path in images:

    img = cv2.imread(str(path))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    palette = viz_color_palette(extract_dominant_colors(img))

    hue_shift = random.random()
    augmented_image = augment_image(img, "Image", hue_shift)
    augmented_palette = augment_image(palette, "Palette", hue_shift)

    path_stem = path[:-4]

    cv2.imwrite(f'data/train/input/{path}', img)
    pickle.dump(palette, open(f'data/train/old_palette/{path_stem}.pkl', 'wb'))
    cv2.imwrite(f'data/train/output/{path}', augmented_image)
    pickle.dump(augmented_palette, open(f'data/train/new_palette/{path_stem}.pkl', 'wb'))