diff --git a/datasetgenerate.py b/datasetgenerate.py
index a3c179b7b02361bcf7bd5e2ef61da99da5283808..0ac1f1e7d4d2f1b9b2c1801daeef4eddea80dfd4 100644
--- a/datasetgenerate.py
+++ b/datasetgenerate.py
@@ -1,5 +1,3 @@
-#!/usr/bin/env python
-# coding: utf-8
 import collections
 import pathlib
 import random
@@ -20,33 +18,17 @@
 import torch.nn.functional as F
 from torchvision import transforms
 from torch.autograd import Variable
-device = "cuda" if torch.cuda.is_available() else "cpu"
-
-# image_to_palette: image -> corresponding palettes
-# image_to_palette = collections.defaultdict(set)
-#
-# with open("/datasets/dribbble/dribbble/dribbble_designs - dribbble_designs.tsv", 'r') as tsvfile:
-#     for line in tsvfile:
-#         # print(line.strip().split('\t')) # ['uid', 'url', 'media_path', 'description', 'comments', 'tags', 'color_palette', 'likes', 'saves', 'date', 'collection_time', 'write_date']
-#         _, _, media_path, _, _, _, color_palette, *_ = line.strip().split('\t')
-#         media_filename = media_path.split('/')[-1]
-#         image_to_palette[media_filename] = set(color_palette[1:-1].split(','))
-
 def extract_dominant_colors(img):
-    # Convert the image to a 2D array of RGB values
     rgb_img = img.reshape(-1, 3)
 
-    # Use KMeans clustering to extract dominant colors
     criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
     k = 5
     ret, label, center = cv2.kmeans(np.float32(rgb_img), k, None, criteria, 10, cv2.KMEANS_RANDOM_CENTERS)
 
-    # Get the dominant colors and their percentages
    unique, counts = np.unique(label, return_counts=True)
     dominant_colors = center[unique].tolist()
     percentages = counts / sum(counts) * 100
 
-    # Create a palette list with the dominant colors and their hex codes
     palette = []
     for color in dominant_colors:
         hex_code = "#{:02x}{:02x}{:02x}".format(*map(int, color))
@@ -54,21 +36,8 @@ def extract_dominant_colors(img):
     return palette
 
-# In[5]:
-
-
-# get data
-# - visualize image
-# - visualize the color palette
 def viz_color_palette(hexcodes):
-    """
-    visualize color palette
-    """
-    # hexcodes = list(hexcodes)
-    # while len(hexcodes) < 6:
-    #     hexcodes = hexcodes + hexcodes
-    # hexcodes = hexcodes[:6]
 
     palette = []
     for hexcode in hexcodes:
@@ -78,54 +47,8 @@ def viz_color_palette(hexcodes):
     palette = np.array(palette)[np.newaxis, :, :]
     return palette
 
-
-# def viz_image(path, image_to_palette: Dict):
-#     """
-#     visualize image
-#     visualize palette (using image_to_palette)
-#     """
-#     assert pathlib.Path(path).name in image_to_palette
-#     img = cv2.imread(str(path))
-#     img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
-#     palette = viz_color_palette(image_to_palette[pathlib.Path(path).name])
-#
-#     # visualize image
-#     plt.imshow(img)
-#     plt.show()
-#
-#     # visualize color palette
-#     print(palette.shape)
-#     plt.imshow(palette)
-#     # print(palette.shape)
-#     plt.axis('off')
-#     plt.show()
-#
-#     return
-#
-#
-# pathlist = pathlib.Path("/datasets/dribbble_half/dribbble_half/data").glob("*.png")
-# for path in pathlist:
-#     viz_image(path, image_to_palette)
-#     break
-
-
-# ### Generate Dataset
-
-# In[ ]:
-
-
-# generate augmented images for training
-# - input/: images with original palette
-# - output/: images with new palette
-# - old_palette/: pickled files of original palette
-# - new_palette/: pickled files of new palette
-
 def augment_image(img, title, hue_shift):
-    # plt.imshow(img)
-    # plt.title(f"Original {title} (in RGB)")
-    # plt.show()
-
-    # RGB -> HSV -> hue-shift
     img_HSV = matplotlib.colors.rgb_to_hsv(img)
     a_2d_index = np.array([[1,0,0] for _ in range(img_HSV.shape[1])]).astype('bool')
     img_HSV[:, a_2d_index] = (img_HSV[:, a_2d_index] + hue_shift) % 1
@@ -135,7 +58,6 @@ def augment_image(img, title, hue_shift):
     plt.title(f"New {title} (in RGB)")
     plt.show()
 
-    # fixed original luminance
     img = img.astype(np.float) / 255.0
     new_img = new_img.astype(np.float) / 255.0
     ori_img_LAB = rgb2lab(img)
@@ -150,10 +72,8 @@ def augment_image(img, title, hue_shift):
     return new_img_augmented
 
 images = [f for f in os.listdir() if f.endswith(".jpg")]
-#print(images)
 
 for path in images:
-    #print(path)
-    # assert pathlib.Path(path).name in image_to_palette
+
     img = cv2.imread(str(path))
     img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
     palette = viz_color_palette(extract_dominant_colors(img))
@@ -162,559 +82,9 @@ for path in images:
     augmented_image = augment_image(img, "Image", hue_shift)
     augmented_palette = augment_image(palette, "Palette", hue_shift)
 
-    file_input = path[:-4] + '_converted_to_BGR'
-    file_output = path[:-4] + '_shift_luminance_corrected'
-
-    cv2.imwrite(file_input + '.jpg', img)
-    pickle.dump(palette, open(f'data/train/old_palette/{file_input}.pkl', 'wb'))
-    cv2.imwrite(file_output + '.jpg', augmented_image)
-    pickle.dump(augmented_palette, open(f'data/train/new_palette/{file_output}.pkl', 'wb'))
-    #print("this is done")
-
-
-# ### Feature Encoder (FE) and Recoloring Decoder (RD)
-
-# In[6]:
-
+    path_stem = path[:-4]
 
-# image = cv2.imread("/home/jovyan/work/data/train/input/0a02bab21de25f3ea1345dacf2a23300.png")
-# plt.imshow(image)
-# plt.show()
-#
-#
-# # In[7]:
-#
-#
-# from functools import partial
-# class Conv2dAuto(nn.Conv2d):
-#     def __init__(self, *args, **kwargs):
-#         super().__init__(*args, **kwargs)
-#         self.padding = (self.kernel_size[0] // 2, self.kernel_size[1] // 2) # dynamic add padding based on the kernel_size
-#
-# conv3x3 = partial(Conv2dAuto, kernel_size=3, bias=False)
-#
-# def activation_func(activation):
-#     return nn.ModuleDict([
-#         ['relu', nn.ReLU(inplace=True)],
-#         ['leaky_relu', nn.LeakyReLU(negative_slope=0.01, inplace=True)],
-#         ['selu', nn.SELU(inplace=True)],
-#         ['none', nn.Identity()]
-#     ])[activation]
-#
-# def conv_bn(in_channels, out_channels, conv, *args, **kwargs):
-#     return nn.Sequential(conv(in_channels, out_channels, *args, **kwargs), nn.InstanceNorm2d(out_channels))
-#
-# class ResidualBlock(nn.Module):
-#     def __init__(self, in_channels, out_channels, activation='relu'):
-#         super().__init__()
-#         self.in_channels, self.out_channels, self.activation = in_channels, out_channels, activation
-#         self.blocks = nn.Identity()
-#         self.activate = activation_func(activation)
-#         self.shortcut = nn.Identity()
-#
-#     def forward(self, x):
-#         residual = x
-#         if self.should_apply_shortcut: residual = self.shortcut(x)
-#         x = self.blocks(x)
-#         x += residual
-#         x = self.activate(x)
-#         return x
-#
-#     @property
-#     def should_apply_shortcut(self):
-#         return self.in_channels != self.out_channels
-#
-# class ResNetResidualBlock(ResidualBlock):
-#     def __init__(self, in_channels, out_channels, expansion=1, downsampling=1, conv=conv3x3, *args, **kwargs):
-#         super().__init__(in_channels, out_channels, *args, **kwargs)
-#         self.expansion, self.downsampling, self.conv = expansion, downsampling, conv
-#         self.shortcut = nn.Sequential(
-#             nn.Conv2d(self.in_channels, self.expanded_channels, kernel_size=1,
-#                       stride=self.downsampling, bias=False),
-#             nn.BatchNorm2d(self.expanded_channels)) if self.should_apply_shortcut else None
-#
-#
-#     @property
-#     def expanded_channels(self):
-#         return self.out_channels * self.expansion
-#
-#     @property
-#     def should_apply_shortcut(self):
-#         return self.in_channels != self.expanded_channels
-#
-# class ResNetBasicBlock(ResNetResidualBlock):
-#     """
-#     Basic ResNet block composed by two layers of 3x3conv/batchnorm/activation
-#     """
-#     expansion = 1
-#     def __init__(self, in_channels, out_channels, *args, **kwargs):
-#         super().__init__(in_channels, out_channels, *args, **kwargs)
-#         self.blocks = nn.Sequential(
-#             conv_bn(self.in_channels, self.out_channels, conv=self.conv, bias=False, stride=self.downsampling),
-#             activation_func(self.activation),
-#             conv_bn(self.out_channels, self.expanded_channels, conv=self.conv, bias=False),
-#         )
-#
-# class ResNetLayer(nn.Module):
-#     """
-#     A ResNet layer composed by `n` blocks stacked one after the other
-#     """
-#     def __init__(self, in_channels, out_channels, block=ResNetBasicBlock, n=1, *args, **kwargs):
-#         super().__init__()
-#         # 'We perform downsampling directly by convolutional layers that have a stride of 2.'
-#         downsampling = 2 if in_channels != out_channels else 1
-#         self.blocks = nn.Sequential(
-#             block(in_channels , out_channels, *args, **kwargs, downsampling=downsampling),
-#             *[block(out_channels * block.expansion,
-#               out_channels, downsampling=1, *args, **kwargs) for _ in range(n - 1)]
-#         )
-#
-#     def forward(self, x):
-#         x = self.blocks(x)
-#         return x
-#
-# class FeatureEncoder(nn.Module):
-#     def __init__(self):
-#         super(FeatureEncoder, self).__init__()
-#
-#         # convolutional
-#         self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
-#         self.norm1_1 = nn.InstanceNorm2d(64)
-#         self.pool1 = torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
-#
-#         # residual blocks
-#         self.res1 = ResNetLayer(64, 128, block=ResNetBasicBlock, n=1)
-#         self.res2 = ResNetLayer(128, 256, block=ResNetBasicBlock, n=1)
-#         self.res3 = ResNetLayer(256, 512, block=ResNetBasicBlock, n=1)
-#
-#     def forward(self, x):
-#         x = F.relu(self.norm1_1(self.conv1_1(x)))
-#         c4 = self.pool1(x)
-#         c3 = self.res1(c4)
-#         c2 = self.res2(c3)
-#         c1 = self.res3(c2)
-#         return c1, c2, c3, c4
-#
-#
-# # In[8]:
-#
-#
-# def double_conv(in_channels, out_channels):
-#     return nn.Sequential(
-#         nn.Conv2d(in_channels, out_channels, 3, padding=1),
-#         nn.InstanceNorm2d(out_channels),
-#         nn.Conv2d(out_channels, out_channels, 3, padding=1),
-#         nn.InstanceNorm2d(out_channels),
-#     )
-#
-# class RecoloringDecoder(nn.Module):
-#     # c => (bz, channel, h, w)
-#     # [Pt, c1]: (18 + 512) -> (256)
-#     # [c2, d1]: (256 + 256) -> (128)
-#     # [Pt, c3, d2]: (18 + 128 + 128) -> (64)
-#     # [Pt, c4, d3]: (18 + 64 + 64) -> 64
-#     # [Illu, d4]: (1 + 64) -> 3
-#
-#     def __init__(self):
-#         super().__init__()
-#         self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
-#
-#         self.dconv_up_4 = double_conv(18 + 512, 256)
-#         self.dconv_up_3 = double_conv(256 + 256, 128)
-#         self.dconv_up_2 = double_conv(18 + 128 + 128, 64)
-#         self.dconv_up_1 = double_conv(18 + 64 + 64, 64)
-#         self.conv_last = nn.Conv2d(1 + 64, 3, 3, padding=1)
-#
-#
-#     def forward(self, c1, c2, c3, c4, target_palettes_1d, illu):
-#         bz, h, w = c1.shape[0], c1.shape[2], c1.shape[3]
-#         target_palettes = torch.ones(bz, 18, h, w).float().to(device)
-#         target_palettes = target_palettes.reshape(h, w, bz * 18) * target_palettes_1d
-#         target_palettes = target_palettes.permute(2, 0, 1).reshape(bz, 18, h, w)
-#
-#         # concatenate target_palettes with c1
-#         x = torch.cat((c1.float(), target_palettes.float()), 1)
-#         x = self.dconv_up_4(x)
-#         x = self.upsample(x)
-#
-#         # concatenate c2 with x
-#         x = torch.cat([c2, x], dim=1)
-#         x = self.dconv_up_3(x)
-#         x = self.upsample(x)
-#
-#         # concatenate target_palettes and c3 with x
-#         bz, h, w = x.shape[0], x.shape[2], x.shape[3]
-#         target_palettes = torch.ones(bz, 18, h, w).float().to(device)
-#         target_palettes = target_palettes.reshape(h, w, bz * 18) * target_palettes_1d
-#         target_palettes = target_palettes.permute(2, 0, 1).reshape(bz, 18, h, w)
-#         x = torch.cat([target_palettes.float(), c3, x], dim=1)
-#         x = self.dconv_up_2(x)
-#         x = self.upsample(x)
-#
-#         # concatenate target_palettes and c4 with x
-#         bz, h, w = x.shape[0], x.shape[2], x.shape[3]
-#         target_palettes = torch.ones(bz, 18, h, w).float().to(device)
-#         target_palettes = target_palettes.reshape(h, w, bz * 18) * target_palettes_1d
-#         target_palettes = target_palettes.permute(2, 0, 1).reshape(bz, 18, h, w)
-#         x = torch.cat([target_palettes.float(), c4, x], dim=1)
-#         x = self.dconv_up_1(x)
-#         x = self.upsample(x)
-#         illu = illu.view(illu.size(0), 1, illu.size(1), illu.size(2))
-#         x = torch.cat((x, illu), dim = 1)
-#         x = self.conv_last(x)
-#         return x
-#
-#
-# # In[9]:
-#
-#
-# from torch.utils.data import Dataset, DataLoader
-# import pathlib
-#
-# def get_illuminance(img):
-#     """
-#     Get the luminance of an image. Shape: (h, w)
-#     """
-#     img = img.permute(1, 2, 0) # (h, w, channel)
-#     img = img.numpy()
-#     img = img.astype(np.float) / 255.0
-#     img_LAB = rgb2lab(img)
-#     img_L = img_LAB[:,:,0] # luminance # (h, w)
-#     return torch.from_numpy(img_L)
-#
-# class ColorTransferDataset(Dataset):
-#     def __init__(self, data_folder, transform):
-#         super().__init__()
-#         self.data_folder = data_folder
-#         self.transform = transform
-#
-#     def __len__(self):
-#         output_folder = self.data_folder/"output"
-#         return len(list(output_folder.glob("*")))
-#
-#     def __getitem__(self, idx):
-#         input_img_folder = self.data_folder/"input"
-#         old_palette = self.data_folder/"old_palette"
-#         new_palette = self.data_folder/"new_palette"
-#         output_img_folder = self.data_folder/"output"
-#         files = list(output_img_folder.glob("*"))
-#
-#         f = files[idx]
-#         ori_image = transform(cv2.imread(str(input_img_folder/f.name)))
-#         new_image = transform(cv2.imread(str(output_img_folder/f.name)))
-#         illu = get_illuminance(ori_image)
-#
-#         new_palette = pickle.load(open(str(new_palette/f.stem) +'.pkl', 'rb'))
-#         new_palette = new_palette[:, :6, :].ravel() / 255.0
-#
-#         old_palette = pickle.load(open(str(old_palette/f.stem) +'.pkl', 'rb'))
-#         old_palette = old_palette[:, :6, :].ravel() / 255.0
-#
-#         ori_image = ori_image.double()
-#         new_image = new_image.double()
-#         illu = illu.double()
-#         new_palette = torch.from_numpy(new_palette).double()
-#         old_palette = torch.from_numpy(old_palette).double()
-#
-#         return ori_image, new_image, illu, new_palette, old_palette
-#
-# def viz_color_palette(hexcodes):
-#     """
-#     visualize color palette
-#     """
-#     hexcodes = list(hexcodes)
-#     while len(hexcodes) < 6:
-#         hexcodes = hexcodes + hexcodes
-#     hexcodes = hexcodes[:6]
-#
-#     palette = []
-#     for hexcode in hexcodes:
-#         rgb = np.array(list(int(hexcode.lstrip('#')[i:i+2], 16) for i in (0, 2, 4)))
-#         palette.append(rgb)
-#
-#     palette = np.array(palette)[np.newaxis, :, :]
-#     return palette
-#
-# def viz_image_ori_new_out(ori, palette, new, out):
-#     """
-#     visualize original image, input palette, true new image, and output image from the model.
-# """ -# ori = ori.detach().cpu().numpy() -# new = new.detach().cpu().numpy() -# out = out.detach().cpu().numpy() -# palette = palette.detach().cpu().numpy() -# -# plt.imshow(np.transpose(ori, (1,2,0)), interpolation='nearest') -# plt.title("Original Image") -# plt.show() -# -# palette = palette.reshape((1, 6, 3)) -# plt.imshow(palette, interpolation='nearest') -# plt.title("Palette") -# plt.show() -# -# # plt.imshow((np.transpose(out, (1,2,0)) * 255).astype(np.uint8)) -# plt.imshow((np.transpose(out, (1,2,0)))) -# plt.title("Output Image") -# plt.show() -# -# plt.imshow(np.transpose(new, (1,2,0)), interpolation='nearest') -# plt.title("True Image") -# plt.show() -# -# -# # ### train FE and RD -# -# # In[15]: -# -# -# # hyperparameters -# bz = 16 -# epoches = 1000 -# lr = 0.0002 -# -# # pre-processsing -# transform = transforms.Compose([ -# transforms.ToPILImage(), -# transforms.Resize((432, 288)), -# transforms.ToTensor(), -# ]) -# -# -# # dataset and dataloader -# train_data = ColorTransferDataset(pathlib.Path("/home/jovyan/work/data/train"), transform) -# train_loader = DataLoader(train_data, batch_size=bz) -# -# # create model, criterion and optimzer -# FE = FeatureEncoder().float().to(device) -# RD = RecoloringDecoder().float().to(device) -# criterion = nn.MSELoss() -# optimizer = torch.optim.AdamW(list(FE.parameters()) + list(RD.parameters()), lr=lr, weight_decay=4e-3) -# -# -# # In[ ]: -# -# -# # train FE and RD -# min_loss = float('inf') -# for e in range(epoches): -# total_loss = 0. -# for i_batch, sampled_batched in enumerate(tqdm(train_loader)): -# ori_image, new_image, illu, new_palette, ori_palette = sampled_batched -# palette = new_palette.flatten() -# c1, c2, c3, c4 = FE.forward(ori_image.float().to(device)) -# out = RD.forward(c1, c2, c3, c4, palette.float().to(device), illu.float().to(device)) -# -# optimizer.zero_grad() -# loss = criterion(out, new_image.float().to(device)) -# loss.backward() -# optimizer.step() -# total_loss += loss.item() -# print(e, total_loss) -# -# if total_loss < min_loss: -# min_loss = total_loss -# state = { -# 'epoch': e, -# 'FE': FE.state_dict(), -# 'RD': RD.state_dict(), -# 'optimizer': optimizer.state_dict(), -# } -# torch.save(state, "/home/jovyan/work/saved_models/FE_RD.pth") -# -# -# # In[17]: -# -# -# # load model from saved model file -# state = torch.load("/home/jovyan/work/saved_models/FE_RD.pth") -# FE = FeatureEncoder().float().to(device) -# RD = RecoloringDecoder().float().to(device) -# FE.load_state_dict(state['FE']) -# RD.load_state_dict(state['RD']) -# optimizer.load_state_dict(state['optimizer']) -# -# for i_batch, sampled_batched in enumerate(train_loader): -# ori_image, new_image, illu, new_palette, ori_palette = sampled_batched -# flat_palette = new_palette.flatten() -# c1, c2, c3, c4 = FE.forward(ori_image.float().to(device)) -# out = RD.forward(c1, c2, c3, c4, flat_palette.float().to(device), illu.float().to(device)) -# break -# -# idx = 3 -# viz_image_ori_new_out(ori_image[idx], new_palette[idx], new_image[idx], out[idx]) -# -# -# # ## Adversarial Training -# -# # In[18]: -# -# -# # discriminator model for adversarial training -# class Discriminator(nn.Module): -# def __init__(self, input_channel): -# super().__init__() -# self.conv1 = nn.Conv2d(input_channel, 64, kernel_size=4, stride=2) -# self.norm1 = nn.InstanceNorm2d(64) -# self.conv2 = nn.Conv2d(64, 64, kernel_size=4, stride=2) -# self.norm2 = nn.InstanceNorm2d(64) -# self.conv3 = nn.Conv2d(64, 64, kernel_size=4, stride=2) -# self.norm3 = nn.InstanceNorm2d(64) -# 
-#         self.conv4 = nn.Conv2d(64, 64, kernel_size=4, stride=2)
-#         self.norm4 = nn.InstanceNorm2d(64)
-#         self.linear = nn.Linear(25600, 1)
-#
-#     def forward(self, x):
-#         x = F.leaky_relu(self.norm1(self.conv1(x)))
-#         x = F.leaky_relu(self.norm2(self.conv2(x)))
-#         x = F.leaky_relu(self.norm3(self.conv3(x)))
-#         x = F.leaky_relu(self.norm4(self.conv4(x)))
-#         x = x.view(x.size(0), -1) # flatten
-#         x = self.linear(x)
-#         return torch.sigmoid(x)
-#
-#
-# # In[19]:
-#
-#
-# # load model from saved model file
-# state = torch.load("/home/jovyan/work/saved_models/FE_RD.pth")
-# FE = FeatureEncoder().float().to(device)
-# RD = RecoloringDecoder().float().to(device)
-# FE.load_state_dict(state['FE'])
-# RD.load_state_dict(state['RD'])
-# optimizer.load_state_dict(state['optimizer'])
-#
-# # freeze FE
-# for param in FE.parameters():
-#     param.requires_grad = False
-#
-#
-# # In[12]:
-#
-#
-# def merge(img, palette_1d):
-#     """
-#     replicating Palette spatially and concatenating in depth to the image.
-#     necessary for the adversarial training
-#     """
-#     img = img.to(device)
-#     palette_1d = palette_1d.to(device)
-#
-#     palette_1d = palette_1d.flatten()
-#     bz, h, w = img.shape[0], img.shape[2], img.shape[3]
-#     palettes = torch.ones(bz, 18, h, w).float().to(device)
-#     palettes = palettes.reshape(h, w, bz * 18) * palette_1d
-#     palettes = palettes.permute(2, 0, 1).reshape(bz, 18, h, w)
-#
-#     # concatenate target_palettes with c1
-#     x = torch.cat((img.float(), palettes.float()), 1)
-#     return x
-#
-#
-# # In[20]:
-#
-#
-# bz = 16
-# epoches = 1000
-# lr = 0.0002
-#
-# # pre-processsing
-# transform = transforms.Compose([
-#     transforms.ToPILImage(),
-#     transforms.Resize((432, 288)),
-#     transforms.ToTensor(),
-# ])
-#
-# train_data = ColorTransferDataset(pathlib.Path("/home/jovyan/work/data/train"), transform)
-# train_loader = DataLoader(train_data, batch_size=bz)
-#
-# D = Discriminator(input_channel=21).to(device)
-# adversarial_loss = torch.nn.BCELoss()
-# optimizer_G = torch.optim.Adam(RD.parameters(), betas=(0.5, 0.999))
-# optimizer_D = torch.optim.Adam(D.parameters(), betas=(0.5, 0.999))
-# train_loader = DataLoader(train_data, batch_size=bz)
-#
-#
-# # In[ ]:
-#
-#
-# Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
-# lambda_MSE_loss = 10
-#
-# min_g_loss = float('inf')
-# for e in range(epoches):
-#     total_g_loss = 0
-#     total_d_loss = 0
-#     for i_batch, sampled_batched in enumerate(tqdm(train_loader)):
-#         ori_image, new_image, illu, new_palette, ori_palette = sampled_batched
-#         flat_palette = new_palette.flatten()
-#         c1, c2, c3, c4 = FE.forward(ori_image.float().to(device))
-#         out_image = RD.forward(c1, c2, c3, c4, flat_palette.float().to(device), illu.float().to(device))
-#
-#         valid = Variable(Tensor(ori_image.size(0), 1).fill_(1.0), requires_grad=False)
-#         fake = Variable(Tensor(new_image.size(0), 1).fill_(0.0), requires_grad=False)
-#
-#         ori_i_new_p = merge(ori_image, new_palette)
-#         ori_i_ori_p = merge(ori_image, ori_palette)
-#         out_i_ori_p = merge(out_image, ori_palette)
-#         out_i_new_p = merge(out_image, new_palette)
-#
-#         # generator loss
-#         optimizer_G.zero_grad()
-#         g_loss = adversarial_loss(D(out_i_new_p), valid) + lambda_MSE_loss * criterion(out_image, new_image.float().to(device))
-#         g_loss.backward()
-#         optimizer_G.step()
-#
-#         # discriminator loss
-#         real_loss = adversarial_loss(D(ori_i_ori_p), valid)
-#         fake_loss = adversarial_loss(D(ori_i_new_p.detach()), fake) + adversarial_loss(D(out_i_ori_p.detach()), fake) + adversarial_loss(D(out_i_new_p.detach()), fake)
-#
-#         optimizer_D.zero_grad()
-#         d_loss = (real_loss + fake_loss) / 4
-#         d_loss.backward()
-#         optimizer_D.step()
-#
-#         total_g_loss += g_loss
-#         total_d_loss += d_loss
-#
-#     if total_g_loss < min_g_loss:
-#         min_g_loss = total_g_loss
-#         state = {
-#             'epoch': e,
-#             'FE': FE.state_dict(),
-#             'RD': RD.state_dict(),
-#             'D': D.state_dict(),
-#             'optimizer': optimizer.state_dict(),
-#         }
-#         torch.save(state, "/home/jovyan/work/saved_models/adv_FE_RD.pth")
-#
-#     print(f"{e}: Generator loss of {total_g_loss:.3f}; Discrimator loss of {total_d_loss:.3f}")
-#
-#
-#
-# # In[21]:
-#
-#
-# # load model from saved model file
-# state = torch.load("/home/jovyan/work/saved_models/adv_FE_RD.pth")
-# FE = FeatureEncoder().float().to(device)
-# RD = RecoloringDecoder().float().to(device)
-# FE.load_state_dict(state['FE'])
-# RD.load_state_dict(state['RD'])
-# optimizer.load_state_dict(state['optimizer'])
-#
-# for i_batch, sampled_batched in enumerate(train_loader):
-#     ori_image, new_image, illu, new_palette, ori_palette = sampled_batched
-#     print(ori_image.shape, illu.shape, new_palette.shape)
-#     flat_palette = new_palette.flatten()
-#     c1, c2, c3, c4 = FE.forward(ori_image.float().to(device))
-#     out = RD.forward(c1, c2, c3, c4, flat_palette.float().to(device), illu.float().to(device))
-#     break
-#
-# idx = 3
-# viz_image_ori_new_out(ori_image[idx], new_palette[idx], new_image[idx], out[idx])
-#
-#
-# # In[ ]:
-#
+    cv2.imwrite(f'data/train/input/{path}', img)
+    pickle.dump(palette, open(f'data/train/old_palette/{path_stem}.pkl', 'wb'))
+    cv2.imwrite(f'data/train/output/{path}', augmented_image)
+    pickle.dump(augmented_palette, open(f'data/train/new_palette/{path_stem}.pkl', 'wb'))
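Note on the new write paths (not part of the patch): the loop now writes into data/train/input, data/train/output, data/train/old_palette and data/train/new_palette, but cv2.imwrite typically just returns False and open(..., 'wb') raises FileNotFoundError when the target directory is missing, so those folders have to exist before the script runs. A minimal sketch of a setup step, assuming the directory layout implied by the paths above:

    import os

    # Hypothetical preparation step: create the dataset folders the loop writes into.
    for sub in ("input", "output", "old_palette", "new_palette"):
        os.makedirs(os.path.join("data", "train", sub), exist_ok=True)

Two points worth flagging for a follow-up: img has already been converted to RGB while cv2.imwrite interprets arrays as BGR, so the saved input images keep the red/blue channel swap the removed code also had; and the untouched context lines still use np.float, an alias that was removed in NumPy 1.24.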