Skip to content
Snippets Groups Projects
datasetgenerate.py 2.68 KiB
Newer Older
  • Learn to ignore specific revisions
  • import collections
    import pathlib
    import random
    import os
    import pickle
    from typing import Dict, Tuple, Sequence
    
    import cv2
    from skimage.color import rgb2lab, lab2rgb
    import matplotlib
    import matplotlib.pyplot as plt
    import numpy as np
    from tqdm import tqdm
    
    import torch
    from torch import nn
    import torch.nn.functional as F
    from torchvision import transforms
    from torch.autograd import Variable
    
    def extract_dominant_colors(img):
        rgb_img = img.reshape(-1, 3)
    
        criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
        k = 5
        ret, label, center = cv2.kmeans(np.float32(rgb_img), k, None, criteria, 10, cv2.KMEANS_RANDOM_CENTERS)
    
        unique, counts = np.unique(label, return_counts=True)
        dominant_colors = center[unique].tolist()
        percentages = counts / sum(counts) * 100
    
        palette = []
        for color in dominant_colors:
            hex_code = "#{:02x}{:02x}{:02x}".format(*map(int, color))
            palette.append(hex_code)
    
        return palette
    
    
    def viz_color_palette(hexcodes):
    
        palette = []
        for hexcode in hexcodes:
            rgb = np.array(list(int(hexcode.lstrip('#')[i:i+2], 16) for i in (0, 2, 4)))
            palette.append(rgb)
    
        palette = np.array(palette)[np.newaxis, :, :]
        return palette
    
    def augment_image(img, title, hue_shift):
    
        img_HSV = matplotlib.colors.rgb_to_hsv(img)
        a_2d_index = np.array([[1,0,0] for _ in range(img_HSV.shape[1])]).astype('bool')
        img_HSV[:, a_2d_index] = (img_HSV[:, a_2d_index] + hue_shift) % 1
    
        new_img = matplotlib.colors.hsv_to_rgb(img_HSV).astype(int)
        plt.imshow(new_img)
        plt.title(f"New {title} (in RGB)")
        plt.show()
    
        img = img.astype(np.float) / 255.0
        new_img = new_img.astype(np.float) / 255.0
        ori_img_LAB = rgb2lab(img)
        new_img_LAB = rgb2lab(new_img)
        new_img_LAB[:, :, 0] = ori_img_LAB[:, :, 0]
        new_img_augmented = (lab2rgb(new_img_LAB)*255.0).astype(int)
        plt.imshow(new_img_augmented)
        plt.title(f"New {title} (in RGB) with Fixed Luminance")
        plt.show()
        plt.close()
    
        return new_img_augmented
    
    images = [f for f in os.listdir() if f.endswith(".jpg")]
    for path in images:
    
        img = cv2.imread(str(path))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        palette = viz_color_palette(extract_dominant_colors(img))
    
        hue_shift = random.random()
        augmented_image = augment_image(img, "Image", hue_shift)
        augmented_palette = augment_image(palette, "Palette", hue_shift)
    
    
    Kirk Ford's avatar
    Kirk Ford committed
        path_stem = path[:-4]
    
    Kirk Ford's avatar
    Kirk Ford committed
        cv2.imwrite(f'data/train/input/{path}', img)
        pickle.dump(palette, open(f'data/train/old_palette/{path_stem}.pkl', 'wb'))
        cv2.imwrite(f'data/train/output/{path}', augmented_image)
        pickle.dump(augmented_palette, open(f'data/train/new_palette/{path_stem}.pkl', 'wb'))