diff --git a/3d_scatter_animation.mp4 b/3d_scatter_animation.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..7664c89f8da8a686c1bdc73b8222732252c51c96
Binary files /dev/null and b/3d_scatter_animation.mp4 differ
diff --git a/datasetgenerate.py b/datasetgenerate.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3c179b7b02361bcf7bd5e2ef61da99da5283808
--- /dev/null
+++ b/datasetgenerate.py
@@ -0,0 +1,720 @@
+#!/usr/bin/env python
+# coding: utf-8
+import collections
+import pathlib
+import random
+import os
+import pickle
+from typing import Dict, Tuple, Sequence
+
+import cv2
+from skimage.color import rgb2lab, lab2rgb
+import matplotlib
+import matplotlib.pyplot as plt
+import numpy as np
+from tqdm import tqdm
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+from torchvision import transforms
+from torch.autograd import Variable
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# image_to_palette: image -> corresponding palettes
+# image_to_palette = collections.defaultdict(set)
+#
+# with open("/datasets/dribbble/dribbble/dribbble_designs - dribbble_designs.tsv", 'r') as tsvfile:
+#     for line in tsvfile:
+#         # print(line.strip().split('\t')) # ['uid', 'url', 'media_path', 'description', 'comments', 'tags', 'color_palette', 'likes', 'saves', 'date', 'collection_time', 'write_date']
+#         _, _, media_path, _, _, _, color_palette, *_ = line.strip().split('\t')
+#         media_filename = media_path.split('/')[-1]
+#         image_to_palette[media_filename] = set(color_palette[1:-1].split(','))
+
+def extract_dominant_colors(img):
+    # Convert the image to a 2D array of RGB values
+    rgb_img = img.reshape(-1, 3)
+
+    # Use KMeans clustering to extract dominant colors
+    criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
+    k = 5
+    ret, label, center = cv2.kmeans(np.float32(rgb_img), k, None, criteria, 10, cv2.KMEANS_RANDOM_CENTERS)
+
+    # Get the dominant colors and their percentages
+    unique, counts = np.unique(label, return_counts=True)
+    dominant_colors = center[unique].tolist()
+    percentages = counts / sum(counts) * 100
+
+    # Create a palette list with the dominant colors and their hex codes
+    palette = []
+    for color in dominant_colors:
+        hex_code = "#{:02x}{:02x}{:02x}".format(*map(int, color))
+        palette.append(hex_code)
+
+    return palette
+
+# In[5]:
+
+
+# get data
+# - visualize image
+# - visualize the color palette
+
+def viz_color_palette(hexcodes):
+    """
+    visualize color palette
+    """
+    # pad/cycle the hex codes to exactly six colors so every pickled palette
+    # has the fixed 6-color (18-value) layout the downstream loader expects
+    hexcodes = list(hexcodes)
+    while len(hexcodes) < 6:
+        hexcodes = hexcodes + hexcodes
+    hexcodes = hexcodes[:6]
+
+    palette = []
+    for hexcode in hexcodes:
+        rgb = np.array(list(int(hexcode.lstrip('#')[i:i+2], 16) for i in (0, 2, 4)))
+        palette.append(rgb)
+
+    palette = np.array(palette)[np.newaxis, :, :]
+    return palette
+
+
+# def viz_image(path, image_to_palette: Dict):
+#     """
+#     visualize image
+#     visualize palette (using image_to_palette)
+#     """
+#     assert pathlib.Path(path).name in image_to_palette
+#     img = cv2.imread(str(path))
+#     img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+#     palette = viz_color_palette(image_to_palette[pathlib.Path(path).name])
+#
+#     # visualize image
+#     plt.imshow(img)
+#     plt.show()
+#
+#     # visualize color palette
+#     print(palette.shape)
+#     plt.imshow(palette)
+#     # print(palette.shape)
+#     plt.axis('off')
+#     plt.show()
+#
+#     return
+
+
+# pathlist = pathlib.Path("/datasets/dribbble_half/dribbble_half/data").glob("*.png")
+# for path in pathlist:
+#     viz_image(path, image_to_palette)
+#     break
+
+
+# ### Generate Dataset
+
+# In[ ]:
+
+
+# generate augmented images for training
+# - input/: images with original palette
+# - output/: images with new palette
+# - old_palette/: pickled files of original palette
+# - new_palette/: pickled files of new palette
+
+def augment_image(img, title, hue_shift):
+    # plt.imshow(img)
+    # plt.title(f"Original {title} (in RGB)")
+    # plt.show()
+
+    # RGB -> HSV -> hue-shift
+    img_HSV = matplotlib.colors.rgb_to_hsv(img)
+    a_2d_index = np.array([[1,0,0] for _ in range(img_HSV.shape[1])]).astype('bool')
+    img_HSV[:, a_2d_index] = (img_HSV[:, a_2d_index] + hue_shift) % 1
+
+    new_img = matplotlib.colors.hsv_to_rgb(img_HSV).astype(int)
+    plt.imshow(new_img)
+    plt.title(f"New {title} (in RGB)")
+    plt.show()
+
+    # fixed original luminance
+    img = img.astype(np.float64) / 255.0
+    new_img = new_img.astype(np.float64) / 255.0
+    ori_img_LAB = rgb2lab(img)
+    new_img_LAB = rgb2lab(new_img)
+    new_img_LAB[:, :, 0] = ori_img_LAB[:, :, 0]
+    # uint8 output so it can be displayed and written out with cv2.imwrite
+    new_img_augmented = (lab2rgb(new_img_LAB)*255.0).astype(np.uint8)
+    plt.imshow(new_img_augmented)
+    plt.title(f"New {title} (in RGB) with Fixed Luminance")
+    plt.show()
+    plt.close()
+
+    return new_img_augmented
+
+images = [f for f in os.listdir() if f.endswith(".jpg")]
+#print(images)
+# make sure the palette output folders exist before pickling into them
+os.makedirs('data/train/old_palette', exist_ok=True)
+os.makedirs('data/train/new_palette', exist_ok=True)
+for path in images:
+    #print(path)
+    # assert pathlib.Path(path).name in image_to_palette
+    img = cv2.imread(str(path))
+    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+    palette = viz_color_palette(extract_dominant_colors(img))
+
+    hue_shift = random.random()
+    augmented_image = augment_image(img, "Image", hue_shift)
+    augmented_palette = augment_image(palette, "Palette", hue_shift)
+
+    file_input = path[:-4] + '_converted_to_BGR'
+    file_output = path[:-4] + '_shift_luminance_corrected'
+
+    cv2.imwrite(file_input + '.jpg', img)
+    pickle.dump(palette, open(f'data/train/old_palette/{file_input}.pkl', 'wb'))
+    cv2.imwrite(file_output + '.jpg', augmented_image)
+    pickle.dump(augmented_palette, open(f'data/train/new_palette/{file_output}.pkl', 'wb'))
+    #print("this is done")
+
+
+# ### Feature Encoder (FE) and Recoloring Decoder (RD)
+
+# In[6]:
+
+
+# image = cv2.imread("/home/jovyan/work/data/train/input/0a02bab21de25f3ea1345dacf2a23300.png")
+# plt.imshow(image)
+# plt.show()
+#
+#
+# # In[7]:
+#
+#
+# from functools import partial
+# class Conv2dAuto(nn.Conv2d):
+#     def __init__(self, *args, **kwargs):
+#         super().__init__(*args, **kwargs)
+#         self.padding = (self.kernel_size[0] // 2, self.kernel_size[1] // 2) # dynamic add padding based on the kernel_size
+#
+# conv3x3 = partial(Conv2dAuto, kernel_size=3, bias=False)
+#
+# def activation_func(activation):
+#     return nn.ModuleDict([
+#         ['relu', nn.ReLU(inplace=True)],
+#         ['leaky_relu', nn.LeakyReLU(negative_slope=0.01, inplace=True)],
+#         ['selu', nn.SELU(inplace=True)],
+#         ['none', nn.Identity()]
+#     ])[activation]
+#
+# def conv_bn(in_channels, out_channels, conv, *args, **kwargs):
+#     return nn.Sequential(conv(in_channels, out_channels, *args, **kwargs), nn.InstanceNorm2d(out_channels))
+#
+# class ResidualBlock(nn.Module):
+#     def __init__(self, in_channels, out_channels, activation='relu'):
+#         super().__init__()
+#         self.in_channels, self.out_channels, self.activation = in_channels, out_channels, activation
+#         self.blocks = nn.Identity()
+#         self.activate = activation_func(activation)
+#         self.shortcut = nn.Identity()
+#
+#     def forward(self, x):
+#         residual = x
+#         if self.should_apply_shortcut: residual = self.shortcut(x)
+#         x = self.blocks(x)
+#         x += residual
+#         x = self.activate(x)
+#         return x
+#
+#
@property +# def should_apply_shortcut(self): +# return self.in_channels != self.out_channels +# +# class ResNetResidualBlock(ResidualBlock): +# def __init__(self, in_channels, out_channels, expansion=1, downsampling=1, conv=conv3x3, *args, **kwargs): +# super().__init__(in_channels, out_channels, *args, **kwargs) +# self.expansion, self.downsampling, self.conv = expansion, downsampling, conv +# self.shortcut = nn.Sequential( +# nn.Conv2d(self.in_channels, self.expanded_channels, kernel_size=1, +# stride=self.downsampling, bias=False), +# nn.BatchNorm2d(self.expanded_channels)) if self.should_apply_shortcut else None +# +# +# @property +# def expanded_channels(self): +# return self.out_channels * self.expansion +# +# @property +# def should_apply_shortcut(self): +# return self.in_channels != self.expanded_channels +# +# class ResNetBasicBlock(ResNetResidualBlock): +# """ +# Basic ResNet block composed by two layers of 3x3conv/batchnorm/activation +# """ +# expansion = 1 +# def __init__(self, in_channels, out_channels, *args, **kwargs): +# super().__init__(in_channels, out_channels, *args, **kwargs) +# self.blocks = nn.Sequential( +# conv_bn(self.in_channels, self.out_channels, conv=self.conv, bias=False, stride=self.downsampling), +# activation_func(self.activation), +# conv_bn(self.out_channels, self.expanded_channels, conv=self.conv, bias=False), +# ) +# +# class ResNetLayer(nn.Module): +# """ +# A ResNet layer composed by `n` blocks stacked one after the other +# """ +# def __init__(self, in_channels, out_channels, block=ResNetBasicBlock, n=1, *args, **kwargs): +# super().__init__() +# # 'We perform downsampling directly by convolutional layers that have a stride of 2.' +# downsampling = 2 if in_channels != out_channels else 1 +# self.blocks = nn.Sequential( +# block(in_channels , out_channels, *args, **kwargs, downsampling=downsampling), +# *[block(out_channels * block.expansion, +# out_channels, downsampling=1, *args, **kwargs) for _ in range(n - 1)] +# ) +# +# def forward(self, x): +# x = self.blocks(x) +# return x +# +# class FeatureEncoder(nn.Module): +# def __init__(self): +# super(FeatureEncoder, self).__init__() +# +# # convolutional +# self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1) +# self.norm1_1 = nn.InstanceNorm2d(64) +# self.pool1 = torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0) +# +# # residual blocks +# self.res1 = ResNetLayer(64, 128, block=ResNetBasicBlock, n=1) +# self.res2 = ResNetLayer(128, 256, block=ResNetBasicBlock, n=1) +# self.res3 = ResNetLayer(256, 512, block=ResNetBasicBlock, n=1) +# +# def forward(self, x): +# x = F.relu(self.norm1_1(self.conv1_1(x))) +# c4 = self.pool1(x) +# c3 = self.res1(c4) +# c2 = self.res2(c3) +# c1 = self.res3(c2) +# return c1, c2, c3, c4 +# +# +# # In[8]: +# +# +# def double_conv(in_channels, out_channels): +# return nn.Sequential( +# nn.Conv2d(in_channels, out_channels, 3, padding=1), +# nn.InstanceNorm2d(out_channels), +# nn.Conv2d(out_channels, out_channels, 3, padding=1), +# nn.InstanceNorm2d(out_channels), +# ) +# +# class RecoloringDecoder(nn.Module): +# # c => (bz, channel, h, w) +# # [Pt, c1]: (18 + 512) -> (256) +# # [c2, d1]: (256 + 256) -> (128) +# # [Pt, c3, d2]: (18 + 128 + 128) -> (64) +# # [Pt, c4, d3]: (18 + 64 + 64) -> 64 +# # [Illu, d4]: (1 + 64) -> 3 +# +# def __init__(self): +# super().__init__() +# self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) +# +# self.dconv_up_4 = double_conv(18 + 512, 256) +# self.dconv_up_3 = double_conv(256 + 256, 128) +# 
self.dconv_up_2 = double_conv(18 + 128 + 128, 64) +# self.dconv_up_1 = double_conv(18 + 64 + 64, 64) +# self.conv_last = nn.Conv2d(1 + 64, 3, 3, padding=1) +# +# +# def forward(self, c1, c2, c3, c4, target_palettes_1d, illu): +# bz, h, w = c1.shape[0], c1.shape[2], c1.shape[3] +# target_palettes = torch.ones(bz, 18, h, w).float().to(device) +# target_palettes = target_palettes.reshape(h, w, bz * 18) * target_palettes_1d +# target_palettes = target_palettes.permute(2, 0, 1).reshape(bz, 18, h, w) +# +# # concatenate target_palettes with c1 +# x = torch.cat((c1.float(), target_palettes.float()), 1) +# x = self.dconv_up_4(x) +# x = self.upsample(x) +# +# # concatenate c2 with x +# x = torch.cat([c2, x], dim=1) +# x = self.dconv_up_3(x) +# x = self.upsample(x) +# +# # concatenate target_palettes and c3 with x +# bz, h, w = x.shape[0], x.shape[2], x.shape[3] +# target_palettes = torch.ones(bz, 18, h, w).float().to(device) +# target_palettes = target_palettes.reshape(h, w, bz * 18) * target_palettes_1d +# target_palettes = target_palettes.permute(2, 0, 1).reshape(bz, 18, h, w) +# x = torch.cat([target_palettes.float(), c3, x], dim=1) +# x = self.dconv_up_2(x) +# x = self.upsample(x) +# +# # concatenate target_palettes and c4 with x +# bz, h, w = x.shape[0], x.shape[2], x.shape[3] +# target_palettes = torch.ones(bz, 18, h, w).float().to(device) +# target_palettes = target_palettes.reshape(h, w, bz * 18) * target_palettes_1d +# target_palettes = target_palettes.permute(2, 0, 1).reshape(bz, 18, h, w) +# x = torch.cat([target_palettes.float(), c4, x], dim=1) +# x = self.dconv_up_1(x) +# x = self.upsample(x) +# illu = illu.view(illu.size(0), 1, illu.size(1), illu.size(2)) +# x = torch.cat((x, illu), dim = 1) +# x = self.conv_last(x) +# return x +# +# +# # In[9]: +# +# +# from torch.utils.data import Dataset, DataLoader +# import pathlib +# +# def get_illuminance(img): +# """ +# Get the luminance of an image. 
Shape: (h, w) +# """ +# img = img.permute(1, 2, 0) # (h, w, channel) +# img = img.numpy() +# img = img.astype(np.float) / 255.0 +# img_LAB = rgb2lab(img) +# img_L = img_LAB[:,:,0] # luminance # (h, w) +# return torch.from_numpy(img_L) +# +# class ColorTransferDataset(Dataset): +# def __init__(self, data_folder, transform): +# super().__init__() +# self.data_folder = data_folder +# self.transform = transform +# +# def __len__(self): +# output_folder = self.data_folder/"output" +# return len(list(output_folder.glob("*"))) +# +# def __getitem__(self, idx): +# input_img_folder = self.data_folder/"input" +# old_palette = self.data_folder/"old_palette" +# new_palette = self.data_folder/"new_palette" +# output_img_folder = self.data_folder/"output" +# files = list(output_img_folder.glob("*")) +# +# f = files[idx] +# ori_image = transform(cv2.imread(str(input_img_folder/f.name))) +# new_image = transform(cv2.imread(str(output_img_folder/f.name))) +# illu = get_illuminance(ori_image) +# +# new_palette = pickle.load(open(str(new_palette/f.stem) +'.pkl', 'rb')) +# new_palette = new_palette[:, :6, :].ravel() / 255.0 +# +# old_palette = pickle.load(open(str(old_palette/f.stem) +'.pkl', 'rb')) +# old_palette = old_palette[:, :6, :].ravel() / 255.0 +# +# ori_image = ori_image.double() +# new_image = new_image.double() +# illu = illu.double() +# new_palette = torch.from_numpy(new_palette).double() +# old_palette = torch.from_numpy(old_palette).double() +# +# return ori_image, new_image, illu, new_palette, old_palette +# +# def viz_color_palette(hexcodes): +# """ +# visualize color palette +# """ +# hexcodes = list(hexcodes) +# while len(hexcodes) < 6: +# hexcodes = hexcodes + hexcodes +# hexcodes = hexcodes[:6] +# +# palette = [] +# for hexcode in hexcodes: +# rgb = np.array(list(int(hexcode.lstrip('#')[i:i+2], 16) for i in (0, 2, 4))) +# palette.append(rgb) +# +# palette = np.array(palette)[np.newaxis, :, :] +# return palette +# +# def viz_image_ori_new_out(ori, palette, new, out): +# """ +# visualize original image, input palette, true new image, and output image from the model. 
+# """ +# ori = ori.detach().cpu().numpy() +# new = new.detach().cpu().numpy() +# out = out.detach().cpu().numpy() +# palette = palette.detach().cpu().numpy() +# +# plt.imshow(np.transpose(ori, (1,2,0)), interpolation='nearest') +# plt.title("Original Image") +# plt.show() +# +# palette = palette.reshape((1, 6, 3)) +# plt.imshow(palette, interpolation='nearest') +# plt.title("Palette") +# plt.show() +# +# # plt.imshow((np.transpose(out, (1,2,0)) * 255).astype(np.uint8)) +# plt.imshow((np.transpose(out, (1,2,0)))) +# plt.title("Output Image") +# plt.show() +# +# plt.imshow(np.transpose(new, (1,2,0)), interpolation='nearest') +# plt.title("True Image") +# plt.show() +# +# +# # ### train FE and RD +# +# # In[15]: +# +# +# # hyperparameters +# bz = 16 +# epoches = 1000 +# lr = 0.0002 +# +# # pre-processsing +# transform = transforms.Compose([ +# transforms.ToPILImage(), +# transforms.Resize((432, 288)), +# transforms.ToTensor(), +# ]) +# +# +# # dataset and dataloader +# train_data = ColorTransferDataset(pathlib.Path("/home/jovyan/work/data/train"), transform) +# train_loader = DataLoader(train_data, batch_size=bz) +# +# # create model, criterion and optimzer +# FE = FeatureEncoder().float().to(device) +# RD = RecoloringDecoder().float().to(device) +# criterion = nn.MSELoss() +# optimizer = torch.optim.AdamW(list(FE.parameters()) + list(RD.parameters()), lr=lr, weight_decay=4e-3) +# +# +# # In[ ]: +# +# +# # train FE and RD +# min_loss = float('inf') +# for e in range(epoches): +# total_loss = 0. +# for i_batch, sampled_batched in enumerate(tqdm(train_loader)): +# ori_image, new_image, illu, new_palette, ori_palette = sampled_batched +# palette = new_palette.flatten() +# c1, c2, c3, c4 = FE.forward(ori_image.float().to(device)) +# out = RD.forward(c1, c2, c3, c4, palette.float().to(device), illu.float().to(device)) +# +# optimizer.zero_grad() +# loss = criterion(out, new_image.float().to(device)) +# loss.backward() +# optimizer.step() +# total_loss += loss.item() +# print(e, total_loss) +# +# if total_loss < min_loss: +# min_loss = total_loss +# state = { +# 'epoch': e, +# 'FE': FE.state_dict(), +# 'RD': RD.state_dict(), +# 'optimizer': optimizer.state_dict(), +# } +# torch.save(state, "/home/jovyan/work/saved_models/FE_RD.pth") +# +# +# # In[17]: +# +# +# # load model from saved model file +# state = torch.load("/home/jovyan/work/saved_models/FE_RD.pth") +# FE = FeatureEncoder().float().to(device) +# RD = RecoloringDecoder().float().to(device) +# FE.load_state_dict(state['FE']) +# RD.load_state_dict(state['RD']) +# optimizer.load_state_dict(state['optimizer']) +# +# for i_batch, sampled_batched in enumerate(train_loader): +# ori_image, new_image, illu, new_palette, ori_palette = sampled_batched +# flat_palette = new_palette.flatten() +# c1, c2, c3, c4 = FE.forward(ori_image.float().to(device)) +# out = RD.forward(c1, c2, c3, c4, flat_palette.float().to(device), illu.float().to(device)) +# break +# +# idx = 3 +# viz_image_ori_new_out(ori_image[idx], new_palette[idx], new_image[idx], out[idx]) +# +# +# # ## Adversarial Training +# +# # In[18]: +# +# +# # discriminator model for adversarial training +# class Discriminator(nn.Module): +# def __init__(self, input_channel): +# super().__init__() +# self.conv1 = nn.Conv2d(input_channel, 64, kernel_size=4, stride=2) +# self.norm1 = nn.InstanceNorm2d(64) +# self.conv2 = nn.Conv2d(64, 64, kernel_size=4, stride=2) +# self.norm2 = nn.InstanceNorm2d(64) +# self.conv3 = nn.Conv2d(64, 64, kernel_size=4, stride=2) +# self.norm3 = nn.InstanceNorm2d(64) +# 
self.conv4 = nn.Conv2d(64, 64, kernel_size=4, stride=2) +# self.norm4 = nn.InstanceNorm2d(64) +# self.linear = nn.Linear(25600, 1) +# +# def forward(self, x): +# x = F.leaky_relu(self.norm1(self.conv1(x))) +# x = F.leaky_relu(self.norm2(self.conv2(x))) +# x = F.leaky_relu(self.norm3(self.conv3(x))) +# x = F.leaky_relu(self.norm4(self.conv4(x))) +# x = x.view(x.size(0), -1) # flatten +# x = self.linear(x) +# return torch.sigmoid(x) +# +# +# # In[19]: +# +# +# # load model from saved model file +# state = torch.load("/home/jovyan/work/saved_models/FE_RD.pth") +# FE = FeatureEncoder().float().to(device) +# RD = RecoloringDecoder().float().to(device) +# FE.load_state_dict(state['FE']) +# RD.load_state_dict(state['RD']) +# optimizer.load_state_dict(state['optimizer']) +# +# # freeze FE +# for param in FE.parameters(): +# param.requires_grad = False +# +# +# # In[12]: +# +# +# def merge(img, palette_1d): +# """ +# replicating Palette spatially and concatenating in depth to the image. +# necessary for the adversarial training +# """ +# img = img.to(device) +# palette_1d = palette_1d.to(device) +# +# palette_1d = palette_1d.flatten() +# bz, h, w = img.shape[0], img.shape[2], img.shape[3] +# palettes = torch.ones(bz, 18, h, w).float().to(device) +# palettes = palettes.reshape(h, w, bz * 18) * palette_1d +# palettes = palettes.permute(2, 0, 1).reshape(bz, 18, h, w) +# +# # concatenate target_palettes with c1 +# x = torch.cat((img.float(), palettes.float()), 1) +# return x +# +# +# # In[20]: +# +# +# bz = 16 +# epoches = 1000 +# lr = 0.0002 +# +# # pre-processsing +# transform = transforms.Compose([ +# transforms.ToPILImage(), +# transforms.Resize((432, 288)), +# transforms.ToTensor(), +# ]) +# +# train_data = ColorTransferDataset(pathlib.Path("/home/jovyan/work/data/train"), transform) +# train_loader = DataLoader(train_data, batch_size=bz) +# +# D = Discriminator(input_channel=21).to(device) +# adversarial_loss = torch.nn.BCELoss() +# optimizer_G = torch.optim.Adam(RD.parameters(), betas=(0.5, 0.999)) +# optimizer_D = torch.optim.Adam(D.parameters(), betas=(0.5, 0.999)) +# train_loader = DataLoader(train_data, batch_size=bz) +# +# +# # In[ ]: +# +# +# Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor +# lambda_MSE_loss = 10 +# +# min_g_loss = float('inf') +# for e in range(epoches): +# total_g_loss = 0 +# total_d_loss = 0 +# for i_batch, sampled_batched in enumerate(tqdm(train_loader)): +# ori_image, new_image, illu, new_palette, ori_palette = sampled_batched +# flat_palette = new_palette.flatten() +# c1, c2, c3, c4 = FE.forward(ori_image.float().to(device)) +# out_image = RD.forward(c1, c2, c3, c4, flat_palette.float().to(device), illu.float().to(device)) +# +# valid = Variable(Tensor(ori_image.size(0), 1).fill_(1.0), requires_grad=False) +# fake = Variable(Tensor(new_image.size(0), 1).fill_(0.0), requires_grad=False) +# +# ori_i_new_p = merge(ori_image, new_palette) +# ori_i_ori_p = merge(ori_image, ori_palette) +# out_i_ori_p = merge(out_image, ori_palette) +# out_i_new_p = merge(out_image, new_palette) +# +# # generator loss +# optimizer_G.zero_grad() +# g_loss = adversarial_loss(D(out_i_new_p), valid) + lambda_MSE_loss * criterion(out_image, new_image.float().to(device)) +# g_loss.backward() +# optimizer_G.step() +# +# # discriminator loss +# real_loss = adversarial_loss(D(ori_i_ori_p), valid) +# fake_loss = adversarial_loss(D(ori_i_new_p.detach()), fake) + adversarial_loss(D(out_i_ori_p.detach()), fake) + adversarial_loss(D(out_i_new_p.detach()), fake) +# +# 
optimizer_D.zero_grad() +# d_loss = (real_loss + fake_loss) / 4 +# d_loss.backward() +# optimizer_D.step() +# +# total_g_loss += g_loss +# total_d_loss += d_loss +# +# if total_g_loss < min_g_loss: +# min_g_loss = total_g_loss +# state = { +# 'epoch': e, +# 'FE': FE.state_dict(), +# 'RD': RD.state_dict(), +# 'D': D.state_dict(), +# 'optimizer': optimizer.state_dict(), +# } +# torch.save(state, "/home/jovyan/work/saved_models/adv_FE_RD.pth") +# +# print(f"{e}: Generator loss of {total_g_loss:.3f}; Discrimator loss of {total_d_loss:.3f}") +# +# +# +# # In[21]: +# +# +# # load model from saved model file +# state = torch.load("/home/jovyan/work/saved_models/adv_FE_RD.pth") +# FE = FeatureEncoder().float().to(device) +# RD = RecoloringDecoder().float().to(device) +# FE.load_state_dict(state['FE']) +# RD.load_state_dict(state['RD']) +# optimizer.load_state_dict(state['optimizer']) +# +# for i_batch, sampled_batched in enumerate(train_loader): +# ori_image, new_image, illu, new_palette, ori_palette = sampled_batched +# print(ori_image.shape, illu.shape, new_palette.shape) +# flat_palette = new_palette.flatten() +# c1, c2, c3, c4 = FE.forward(ori_image.float().to(device)) +# out = RD.forward(c1, c2, c3, c4, flat_palette.float().to(device), illu.float().to(device)) +# break +# +# idx = 3 +# viz_image_ori_new_out(ori_image[idx], new_palette[idx], new_image[idx], out[idx]) +# +# +# # In[ ]: +# diff --git a/paletteget.py b/paletteget.py new file mode 100644 index 0000000000000000000000000000000000000000..0cc85aae68f6ae240b95e12ee5cdd027d6d3dd0b --- /dev/null +++ b/paletteget.py @@ -0,0 +1,48 @@ +import numpy as np +import cv2 +import matplotlib.pyplot as plt +import os + +def extract_dominant_colors(img): + # Convert the image to a 2D array of RGB values + rgb_img = img.reshape(-1, 3) + + # Use KMeans clustering to extract dominant colors + criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0) + k = 5 + ret, label, center = cv2.kmeans(np.float32(rgb_img), k, None, criteria, 10, cv2.KMEANS_RANDOM_CENTERS) + + # Get the dominant colors and their percentages + unique, counts = np.unique(label, return_counts=True) + dominant_colors = center[unique].tolist() + percentages = counts / sum(counts) * 100 + + # Create a palette list with the dominant colors and their hex codes + palette = [] + for color in dominant_colors: + hex_code = "#{:02x}{:02x}{:02x}".format(*map(int, color)) + palette.append((color, hex_code)) + + return palette, percentages + +# Get a list of all image files in the current directory +images = [f for f in os.listdir() if f.endswith(".jpg")] + +# Extract the dominant colors from each image +output = [] +for image_file in images: + img = cv2.imread(image_file) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + + palette, percentages = extract_dominant_colors(img) + + output.append((image_file, palette, percentages)) + +# Write the output to a file +with open("colour_output.txt", "w") as f: + for image_file, palette, percentages in output: + f.write("Image: " + image_file + "\n") + f.write("Palette:\n") + for i, (color, hex_code) in enumerate(palette): + f.write(" Color: " + str(color) + " Hex code: " + hex_code + " Percentage: {:.2f}%\n".format(percentages[i])) + f.write("\n") diff --git a/palettegrabber.py b/palettegrabber.py new file mode 100644 index 0000000000000000000000000000000000000000..5b4dffafc8d23e0e6102fd17515fab556c53d788 --- /dev/null +++ b/palettegrabber.py @@ -0,0 +1,45 @@ +import numpy as np +import cv2 +import matplotlib.pyplot as plt + +# Step 1: Load the 
image +img = cv2.imread("IMG_0704.jpg") +img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + +# Step 2: Convert the image to a 2D array of RGB values +rgb_img = img.reshape(-1, 3) + +# Step 3: Use KMeans clustering to extract dominant colors +criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0) +k = 6 +ret, label, center = cv2.kmeans(np.float32(rgb_img), k, None, criteria, 10, cv2.KMEANS_RANDOM_CENTERS) + +# Step 4: Plot the dominant colors +center = np.uint8(center) +res = center[label.flatten()] +output_image = res.reshape((img.shape)) + +# Plot the image +plt.imshow(output_image) +plt.axis("off") +plt.show() + +# Plot a pie chart +unique, counts = np.unique(label, return_counts=True) +dominant_colors = center[unique].tolist() +percentages = counts / sum(counts) * 100 +normalized_colors = [list(map(lambda x: x/255, color)) for color in dominant_colors] +patches, texts = plt.pie(percentages, colors=normalized_colors, startangle=90) +plt.axis("equal") +plt.tight_layout() +plt.show() + +# Print a palette list +palette = [] +for color in dominant_colors: + hex_code = "#{:02x}{:02x}{:02x}".format(*map(int, color)) + palette.append((color, hex_code)) + +print("Palette:") +for i, (color, hex_code) in enumerate(palette): + print(" Color:", color, " Hex code:", hex_code, " Percentage: {:.2f}%".format(percentages[i])) diff --git a/paletteindex.py b/paletteindex.py new file mode 100644 index 0000000000000000000000000000000000000000..7cf3cb329a7705479c4c865aa3a0b2c6b2443a0a --- /dev/null +++ b/paletteindex.py @@ -0,0 +1,68 @@ +import numpy as np +import cv2 +import matplotlib.pyplot as plt +import os + +def extract_dominant_colors(img): + # Convert the image to a 2D array of RGB values + rgb_img = img.reshape(-1, 3) + + # Use KMeans clustering to extract dominant colors + criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0) + k = 5 + ret, label, center = cv2.kmeans(np.float32(rgb_img), k, None, criteria, 10, cv2.KMEANS_RANDOM_CENTERS) + + # Get the dominant colors and their percentages + unique, counts = np.unique(label, return_counts=True) + dominant_colors = center[unique].tolist() + percentages = counts / sum(counts) * 100 + + # Create a palette list with the dominant colors and their hex codes + palette = [] + for color in dominant_colors: + hex_code = "#{:02x}{:02x}{:02x}".format(*map(int, color)) + palette.append(hex_code) + + return palette + + +def viz_color_palette(hexcodes): + """ + visualize color palette + """ + #hexcodes = list(hexcodes) + # while len(hexcodes) < 6: + # hexcodes = hexcodes + hexcodes + # hexcodes = hexcodes[:6] + + print(hexcodes) + palette = [] + for hexcode in hexcodes: + rgb = np.array(list(int(hexcode.lstrip('#')[i:i+2], 16) for i in (0, 2, 4))) + palette.append(rgb) + + palette = np.array(palette)[np.newaxis, :, :] + return palette[0] + + +# Get a list of all image files in the current directory +images = [f for f in os.listdir() if f.endswith(".jpg")] + +# Extract the dominant colors from each image +output = [] +for image_file in images: + img = cv2.imread(image_file) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + + palette = extract_dominant_colors(img) + print(viz_color_palette(palette)) + output.append((image_file, palette)) + +# # Write the output to a file +# with open("colour_output.txt", "w") as f: +# for image_file, palette, percentages in output: +# f.write("Image: " + image_file + "\n") +# f.write("Palette:\n") +# for i, (color, hex_code) in enumerate(palette): +# f.write(" Color: " + str(color) + " Hex code: " + 
hex_code + " Percentage: {:.2f}%\n".format(percentages[i])) +# f.write("\n") diff --git a/scatterplotter.py b/scatterplotter.py new file mode 100644 index 0000000000000000000000000000000000000000..6dfc5a4c6a7666caf1adc972ac0c50a7e1944f82 --- /dev/null +++ b/scatterplotter.py @@ -0,0 +1,85 @@ +import os +import numpy as np +from PIL import Image +import matplotlib.pyplot as plt +from tqdm import tqdm +import threading +import warnings +import matplotlib.animation as animation +from mpl_toolkits.mplot3d import Axes3D + +warnings.filterwarnings("ignore") +# Define the path to the folder containing the image folders +folder_path = 'C:/Users/kirk/Desktop/artbench-10-imagefolder-split/test/' + +# Define a function to process an image and return its dominant colors +def process_image(file_path): + # Open the image and convert it to a numpy array + image = Image.open(file_path) + image_array = np.array(image) + # Reshape the array to be a list of pixels + pixel_list = image_array.reshape(-1, image_array.shape[-1]) + # Calculate the dominant colors using k-means clustering with k=5 + from sklearn.cluster import KMeans + kmeans = KMeans(n_clusters=5, random_state=0).fit(pixel_list) + # Get the RGB values of the most dominant color + most_common_color = kmeans.cluster_centers_[np.argmax(np.unique(kmeans.labels_, return_counts=True)[1])] + return most_common_color + +# Define a function to process a folder of images using threads +def process_folder(folder_name, results): + # Create a list to store the dominant colors for each image in the folder + dominant_colors = [] + # Loop through each file in the folder and append the dominant colors to the list + file_paths = [folder_path + folder_name + '/' + file_name for file_name in os.listdir(folder_path + folder_name) if file_name.endswith('.jpg')] + for file_path in tqdm(file_paths, desc=folder_name): + color = process_image(file_path) + dominant_colors.append(color) + # If any dominant colors were found, concatenate them into a single array + if len(dominant_colors) > 0: + dominant_colors = np.array(dominant_colors) + results.append((folder_name, dominant_colors)) + + +# Process each folder using threads +results = [] +threads = [] +for folder_name in os.listdir(folder_path): + if not os.path.isdir(folder_path + folder_name): + continue + thread = threading.Thread(target=process_folder, args=(folder_name, results)) + thread.start() + threads.append(thread) + +# Wait for all the threads to finish +for thread in threads: + thread.join() + +# Create a 3D scatter plot of the dominant colors for each folder +fig = plt.figure() +ax = fig.add_subplot(111, projection='3d') +for folder_name, dominant_colors in results: + if len(dominant_colors) > 0: + # Extract the R, G, and B values from the dominant colors array + R = dominant_colors[:, 0] + G = dominant_colors[:, 1] + B = dominant_colors[:, 2] + ax.scatter(R, G, B, label=folder_name) +# Add labels and a legend to the plot +ax.set_xlabel('Red') +ax.set_ylabel('Green') +ax.set_zlabel('Blue') +ax.legend() + +# Set up the animation +def rotate(angle): + ax.view_init(azim=angle) + return fig, + +angles = np.linspace(0, 360, 360) +rot_animation = animation.FuncAnimation(fig, rotate, frames=angles, interval=50) + +# Save the animation as a mp4 video +rot_animation.save('3d_scatter_animation.mp4', fps=30, extra_args=['-vcodec', 'libx264']) + +plt.show() diff --git a/visualizer.py b/visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..cb6fdffa42c2f9b94dd24c66a6ac9942e6c32412 --- /dev/null 
+++ b/visualizer.py @@ -0,0 +1,54 @@ +import os +import numpy as np +import matplotlib.pyplot as plt +from sklearn.manifold import TSNE +from PIL import Image +from tqdm import tqdm + +directory = 'C:/Users/kirk/Desktop/artbench-10-imagefolder-split/train/' + +# Define the paths to the image folders +folder_paths = [os.path.join(directory, folder) for folder in os.listdir(directory) if os.path.isdir(os.path.join(directory, folder))] + +# Initialize empty lists for images and labels +images = [] +labels = [] + + +# Create a dictionary to map folder names to integer labels +folder_labels = {'art_nouveau': 0, 'baroque': 1, 'expressionism': 2, 'impressionism': 3, + 'post_impressionism': 4, 'realism': 5, 'renaissance': 6, 'romanticism': 7, + 'surrealism': 8, 'ukiyo_e': 9} + +# Loop through each folder and load the images +for folder_path in folder_paths: + # Get the label for this folder + folder_name = os.path.basename(folder_path) + label = folder_labels[folder_name] + for filename in tqdm(os.listdir(folder_path), desc=f'Loading images from {folder_name}', unit='image'): + # Load the image as a numpy array + img = np.array(Image.open(os.path.join(folder_path, filename))) + # Add the image to the list + images.append(img) + # Add the label to the list + labels.append(label) + + + + +# Convert the lists to numpy arrays +images = np.array(images) +labels = np.array(labels) + +# Reshape the images to 1D arrays +images = images.reshape((images.shape[0], -1)) + +# Initialize t-SNE with default parameters +tsne = TSNE() + +# Fit and transform the data +embeddings = tsne.fit_transform(images) + +# Plot the embeddings with different colors for each cluster +plt.scatter(embeddings[:, 0], embeddings[:, 1], c=labels) +plt.show()