+
+
+
+
+ Given clean line drawings, rough sketches or photographs of arbitrary resolution as input, our framework generates the corresponding vector line drawings directly. As shown in (b), the framework models a virtual pen surrounded by a dynamic window (red boxes), which moves while drawing the strokes. It learns to move around by scaling the window and sliding to an undrawn area to restart drawing (bottom example; the sliding trajectory is shown by the blue arrow). With our proposed stroke regularization mechanism, the framework is able to enlarge the window and draw long strokes for simplicity (top example).
+
+
+
+
+
+
+
+
+
Abstract
+
+ Vector line art plays an important role in graphic design; however, it is tedious to create manually.
+ We introduce a general framework to produce line drawings from a wide variety of images,
+ by learning a mapping from raster image space to vector image space.
+ Our approach is based on a recurrent neural network that draws the lines one by one.
+ A differentiable rasterization module allows for training with only supervised raster data.
+ We use a dynamic window around a virtual pen while drawing lines,
+ implemented with the proposed aligned cropping and differentiable pasting modules.
+ Furthermore, we develop a stroke regularization loss
+ that encourages the model to use fewer and longer strokes to simplify the resulting vector image.
+ Ablation studies and comparisons with existing methods corroborate the efficiency of our approach
+ which is able to generate visually better results in less computation time,
+ while generalizing better to a diversity of images and applications.
+
+
+
+
+
+ Our framework generates the parametrized strokes step by step in a recurrent manner.
+ It uses a dynamic window (dashed red boxes) around a virtual pen to draw the strokes,
+ and can both move and change the size of the window.
+ (a) Four main modules at each time step: aligned cropping, stroke generation, differentiable rendering and differentiable pasting.
+ (b) Architecture of the stroke generation module.
+ (c) Structural strokes predicted at each step;
+ movement-only steps, during which no stroke is drawn on the canvas, are illustrated by blue arrows.
+
+
+
+
+
+@article{mo2021virtualsketching,
+ title = {General Virtual Sketching Framework for Vector Line Art},
+ author = {Mo, Haoran and Simo-Serra, Edgar and Gao, Chengying and Zou, Changqing and Wang, Ruomei},
+ journal = {ACM Transactions on Graphics (Proceedings of ACM SIGGRAPH 2021)},
+ year = {2021},
+ volume = {40},
+ number = {4},
+ pages = {51:1--51:14}
+}
+
+
+
+
Related Work
+
+
+ Jean-Dominique Favreau, Florent Lafarge and Adrien Bousseau.
+ Fidelity vs. Simplicity: a Global Approach to Line Drawing Vectorization. SIGGRAPH 2016.
+ [Paper]
+ [Webpage]
+
+
+
+
+ Mikhail Bessmeltsev and Justin Solomon.
+ Vectorization of Line Drawings via PolyVector Fields. SIGGRAPH 2019.
+ [Paper]
+ [Code]
+
+
+
+
+ Edgar Simo-Serra, Satoshi Iizuka and Hiroshi Ishikawa.
+ Mastering Sketching: Adversarial Augmentation for Structured Prediction. SIGGRAPH 2018.
+ [Paper]
+ [Webpage]
+ [Code]
+
+
+
+
+ Zhewei Huang, Wen Heng and Shuchang Zhou.
+ Learning to Paint With Model-based Deep Reinforcement Learning. ICCV 2019.
+ [Paper]
+ [Code]
+
+
+
+
+
+
+
+
+
diff --git a/hi-arm/qmupd_vs/draw_tools.py b/hi-arm/qmupd_vs/draw_tools.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd699d1fb3ac09125fe060ae8565fb903e2837e5
--- /dev/null
+++ b/hi-arm/qmupd_vs/draw_tools.py
@@ -0,0 +1,657 @@
+import os
+import cv2
+from matplotlib import pyplot as plt
+import numpy as np
+from IPython.display import clear_output
+from scipy.interpolate import splprep, splev
+import shutil
+import glob
+import time
+import sys
+import numpy as np
+from PIL import Image
+import tensorflow as tf
+import cv2
+from utils import get_colors, draw, image_pasting_v3_testing
+from model_common_test import DiffPastingV3
+import random
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+
def fix_edge_contour(contour, im_shape):
    """
    Trim points that lie on the image border from both ends of a contour.

    Generated contours sometimes begin or end with a run of points hugging the
    image edge; those points are unwanted and are cut off.  Only the leading
    and trailing runs are removed: border points in the middle of the contour
    are kept.

    :param contour: numpy array of points where ``contour[i][0]`` is ``(x, y)``.
    :param im_shape: image shape; ``im_shape[0]`` is the height and
        ``im_shape[1]`` the width.
    :return: a new numpy array holding the remaining points.  It is empty when
        every point lies on the border (previously this raised ``IndexError``).
    """
    max_x = im_shape[1] - 1
    max_y = im_shape[0] - 1

    def on_border(point):
        x, y = point[0]
        return x == 0 or y == 0 or x == max_x or y == max_y

    points = contour.tolist()

    # Move two cursors inwards instead of deleting list items one at a time;
    # the bounds checks make an all-border contour collapse to an empty slice.
    start, stop = 0, len(points)
    while start < stop and on_border(points[start]):
        start += 1
    while stop > start and on_border(points[stop - 1]):
        stop -= 1

    return np.array(points[start:stop])
+
def getContourList(image, pen_width: int = 3, min_contour_len: int = 30, is_show: bool = False):
    """
    Extract a list of contours from a line image by repeatedly peeling contours off.

    The image is converted to grayscale and binarised at 127.  Then, in a loop,
    contours are detected, recorded and erased from a working copy (painted
    white with ``pen_width``) until nothing but white is left.

    :param image: input image as a numpy array (grayscale, 3-channel or 4-channel).
    :param pen_width: line thickness used when erasing detected contours.
    :param min_contour_len: contours with fewer points than this are dropped.
    :param is_show: when True, plot the source, the working copy and the
        accumulated contours after every pass.
    :return: list of contours, or ``None`` when ``image`` is ``None``.
    """
    if image is None:
        print("Can't read the image file.")
        return
    if len(image.shape) == 3:
        # Choose the conversion by channel count.  The old ``len(shape) == 4``
        # test could never be true for an image array.
        to_gray = cv2.COLOR_BGRA2GRAY if image.shape[2] == 4 else cv2.COLOR_BGR2GRAY
        image = cv2.cvtColor(image, to_gray)
    # Binarise.
    image = cv2.threshold(image, 127, 255, cv2.THRESH_BINARY)[1]

    # Working copy that detected contours are erased from, pass after pass, so
    # that the lines come out in an order convenient for continuous drawing.
    image_copy = image.copy()
    # White canvas collecting every detected contour (only used for display).
    image_with_contours = np.full_like(image_copy, 255)
    height, width = image_copy.shape[0], image_copy.shape[1]

    contour_list = []

    directions = [(0, 1), (0, -1), (1, 0), (-1, 0), (1, 1), (1, -1), (-1, 1), (-1, -1)]
    # Index ranges of the four sides inside a contour that traces the image frame.
    sec0 = (0, height)
    sec1 = (sec0[1] - 1, sec0[1] + width - 1)
    sec2 = (sec1[1] - 1, sec1[1] + height - 1)
    sec3 = (sec2[1] - 1, sec2[1] + width - 2)
    # Perimeter of the frame: two heights plus two widths (corners shared).
    frame_len = height * 2 + width * 2 - 4

    def is_image_frame(cnt):
        """Return True when ``cnt`` traces the four borders of the image."""
        return bool(len(cnt) >= frame_len and
                    (cnt[sec0[0]:sec0[1], :, 0] == 0).all() and
                    (cnt[sec1[0]:sec1[1], :, 1] == height - 1).all() and
                    (cnt[sec2[0]:sec2[1], :, 0] == width - 1).all() and
                    (cnt[sec3[0]:sec3[1], :, 1] == 0).all())

    while True:
        # findContours returns (image, contours, hierarchy) on OpenCV 3 and
        # (contours, hierarchy) on OpenCV 4; ``[-2]`` is the contours in both.
        contours = cv2.findContours(image_copy, cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE)[-2]

        # Stop when nothing is found.  An all-white image still yields one
        # contour, so that case has to stop as well.
        if len(contours) == 0 or (len(contours) == 1 and np.all(image_copy == 255)):
            break

        # Contour points are not guaranteed to sit exactly on black pixels.
        # With a 1-pixel pen the erase step below would then miss the line, so
        # snap every point lying on white to a neighbouring black pixel.
        print(f"pen width: {pen_width}")
        if pen_width == 1:
            for contour in contours:
                for point in contour:
                    x, y = point[0]
                    if image_copy[y, x] == 255:
                        for dx, dy in directions:
                            nx, ny = x + dx, y + dy
                            if nx >= 0 and ny >= 0 and nx < width and ny < height:
                                if image_copy[ny, nx] == 0:
                                    point[0][0] = nx
                                    point[0][1] = ny
                                    break

        cv2.drawContours(image_with_contours, contours, -1, 0, 1)
        # Erase the contours just found from the working copy.
        cv2.drawContours(image_copy, contours, -1, 255, pen_width)

        # Drop contours that are too short and the one tracing the image frame.
        contours = [cnt for cnt in contours
                    if len(cnt) >= min_contour_len and not is_image_frame(cnt)]
        # Longest first, so the long contours can be drawn first.
        contours.sort(key=lambda x: x.shape[0], reverse=True)
        contour_list.extend(contours)

        if is_show:
            # Replace the previous plot with the current state.
            clear_output(wait=True)

            plt.subplot(1, 3, 1)
            plt.imshow(image, cmap='gray', vmin=0, vmax=255)

            plt.subplot(1, 3, 2)
            plt.imshow(image_copy, cmap='gray', vmin=0, vmax=255)

            plt.subplot(1, 3, 3)
            plt.imshow(image_with_contours, cmap='gray', vmin=0, vmax=255)
            plt.show()

    # Trim border-hugging heads/tails and keep the trimmed result (it used to
    # be computed and then thrown away).
    for i in reversed(range(len(contour_list))):
        untrimmed = contour_list[i]
        trimmed = fix_edge_contour(contour=untrimmed, im_shape=image.shape)
        if len(trimmed) == 0 or len(trimmed) < min_contour_len:
            contour_list.pop(i)
        else:
            # Keep the dtype produced by OpenCV.
            contour_list[i] = trimmed.astype(untrimmed.dtype)
    return contour_list
+
def sortContoursList(contour_list):
    """
    Order contours for drawing.

    Rules:
    1. Sort by point count, longest first.  With 10 or fewer contours that
       order is returned as is.
    2. Otherwise take the longest third, shuffle it and keep half of it as the
       opening contours.  Long contours sketch the overall shape early, and the
       shuffle avoids always starting with the same ones.
    3. Append the remaining contours greedily: the next contour is the one
       whose first point is closest to the last point of the previous contour.

    The input list is left untouched; a new list is returned.  (The previous
    implementation sorted the argument in place and emptied it when it held
    more than 10 contours.)

    :param contour_list: list of numpy contours where ``c[i][0]`` is ``(x, y)``.
    :return: new list containing the same contour objects in drawing order.
    """
    by_length = sorted(contour_list, key=lambda c: c.shape[0], reverse=True)
    # Too few contours to need any strategy.
    if len(by_length) <= 10:
        return by_length

    # 1. Randomly pick half of the longest third as the opening contours.
    #    Indices are shuffled rather than the arrays so the selection is by
    #    position and stays correct even if an object occurs twice.
    head_size = len(by_length) // 3
    head_indices = list(range(head_size))
    np.random.shuffle(head_indices)
    opening = head_indices[:head_size // 2]
    opening_set = set(opening)

    ordered = [by_length[i] for i in opening]
    remaining = [c for i, c in enumerate(by_length) if i not in opening_set]

    # 2. Greedy nearest-start chaining from the end of the previous contour.
    while remaining:
        end_point = ordered[-1][-1][0]
        end_x, end_y = end_point[0], end_point[1]

        def squared_distance(index):
            start = remaining[index][0][0]
            return (start[0] - end_x) ** 2 + (start[1] - end_y) ** 2

        # min() keeps the first of equally near candidates, like the strict
        # ``<`` comparison it replaces, but needs no magic "infinite" distance.
        nearest = min(range(len(remaining)), key=squared_distance)
        ordered.append(remaining.pop(nearest))
    return ordered
+
def remove_overlap_and_near_contours(contours_list, image_size, extend_pixel , near_threshold=0.5, min_contour_length=10):
    """
    Remove contours that overlap, or run too close to, contours kept earlier.

    Drawing is simulated on a boolean occupancy grid.  Contours are processed
    in order; a contour is dropped when the fraction of its points that fall on
    already occupied cells is at least ``near_threshold``.  Every kept contour
    then marks a filled disc of radius ``extend_pixel`` around each of its
    points as occupied.

    :param contours_list: contours where ``c[i][0]`` is ``(x, y)``.
    :param image_size: grid size; the first coordinate of a point is checked
        against ``image_size[0]`` and the second against ``image_size[1]``.
    :param extend_pixel: radius, in pixels, of the disc marked around each point.
    :param near_threshold: overlap ratio at or above which a contour is dropped.
    :param min_contour_length: contours with fewer points are dropped.
    :return: list of the kept contours (as numpy arrays).
    """
    occupied = np.zeros((image_size[0], image_size[1]), dtype=np.bool_)

    # Offsets of all cells inside a disc of radius ``extend_pixel``.
    span = np.arange(-extend_pixel, extend_pixel + 1)
    off_i, off_j = np.meshgrid(span, span, indexing="ij")
    in_disc = off_i ** 2 + off_j ** 2 <= extend_pixel ** 2
    disc_i, disc_j = off_i[in_disc], off_j[in_disc]

    def inside(px, py):
        """Boolean mask of coordinates that fall inside the grid."""
        return (px >= 0) & (px < image_size[0]) & (py >= 0) & (py < image_size[1])

    new_contours_list = []
    for contour in contours_list:
        # Skip contours that are too short (or empty, which would divide by zero).
        if len(contour) == 0 or len(contour) < min_contour_length:
            continue

        points = np.asarray(contour)
        # Truncate towards zero, like int().
        xs = points[:, 0, 0].astype(np.int64)
        ys = points[:, 0, 1].astype(np.int64)

        # Count points landing on occupied cells.  Points outside the grid
        # cannot overlap anything; looking them up directly used to raise
        # IndexError (too large) or wrap around silently (negative).
        valid = inside(xs, ys)
        overlap_length = int(np.count_nonzero(occupied[xs[valid], ys[valid]]))

        # Too much overlap with what is already drawn: do not draw this one.
        if overlap_length / len(contour) >= near_threshold:
            continue
        new_contours_list.append(np.array(contour))

        # Mark every cell this contour passes near, for later queries.
        mark_x = (xs[:, None] + disc_i[None, :]).ravel()
        mark_y = (ys[:, None] + disc_j[None, :]).ravel()
        keep = inside(mark_x, mark_y)
        occupied[mark_x[keep], mark_y[keep]] = True
    return new_contours_list
+
+
def sample_and_smooth_contours(contour_list, interval: int = 5):
    """
    Smooth each contour with a cubic B-spline and resample it.

    A contour whose first and last points coincide is fitted as a periodic
    (closed) spline, any other contour as an open one.  The fitted curve is
    sampled at ``len(contour) // interval`` evenly spaced parameter values, but
    never at fewer than two so both ends of the curve survive.

    Contours with three points or fewer cannot be fitted by a cubic spline
    (``splprep`` requires more points than the degree); they are returned
    unchanged, converted to float, instead of raising.

    :param contour_list: list of contours, each reshapeable to ``(N, 2)``.
    :param interval: sampling interval in points of the original contour.
    :return: list of float arrays of shape ``(M, 1, 2)``.
    """
    spline_degree = 3
    f_contour_list = []
    for contour in contour_list:
        points = np.asarray(contour, dtype=float).reshape(-1, 2)
        if points.shape[0] <= spline_degree:
            # Too short to fit: pass through untouched.
            f_contour_list.append(points.reshape(-1, 1, 2))
            continue

        is_closed = bool((points[0] == points[-1]).all())
        # Fit a B-spline through the contour points.
        tck, u = splprep(points.T, k=spline_degree, s=1.0,
                         per=1 if is_closed else 0, quiet=1)

        # Number of resampled points.
        num = max(points.shape[0] // interval, 2)
        u_new = np.linspace(u.min(), u.max(), num)
        x_new, y_new = splev(u_new, tck, der=0)
        f_contour_list.append(np.array([x_new, y_new]).T.reshape(-1, 1, 2))
    return f_contour_list
+
+
def save_contour_points(contour_list, filepath):
    """
    Save contour points to a text file, one contour per line.

    Every point is written as ``x,y,`` so a line is a flat comma-separated list
    of coordinates with a trailing comma.  Missing parent directories are
    created; a bare file name (no directory part) is written to the current
    directory instead of failing in ``os.makedirs('')``.

    Usage:
        save_contour_points(f_contour_list, "../data/1_fake_data.txt")

    :param contour_list: iterable of contours where ``c[i][0]`` is ``(x, y)``.
    :param filepath: destination file; overwritten if it exists.
    """
    dirname = os.path.dirname(filepath)
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    with open(filepath, "w") as f:
        for contour in contour_list:
            parts = []
            for point in contour:
                x, y = point[0]
                parts.append(f"{x},{y},")
            # One write per contour instead of one per point.
            f.write("".join(parts) + "\n")
+
def load_contours_list(filename):
    """
    Load contours from a text file with one contour per line.

    A line is a flat comma-separated list ``x0,y0,x1,y1,...`` with an optional
    trailing comma.  Lines without any coordinate are skipped.

    :param filename: path of the file to read.
    :return: list of float arrays of shape ``(N, 1, 2)``.
    :raises ValueError: if a value is not a number, or a line holds an odd
        number of values (an ``x`` without its ``y``).
    """
    contours_list = []
    with open(filename, "r") as f:
        for line_no, line in enumerate(f, start=1):
            points = line.strip().split(",")
            # Drop the empty field left by the trailing comma.
            if points[-1] == '':
                points = points[:-1]
            if len(points) % 2 != 0:
                # Used to surface as a bare IndexError on points[i+1].
                raise ValueError(
                    f"{filename}, line {line_no}: expected x,y pairs but found "
                    f"{len(points)} values")
            # Skip contours of length 0.
            if not points:
                continue
            coords = np.array([float(value) for value in points])
            contours_list.append(coords.reshape(-1, 1, 2))
    print(f"Load {len(contours_list)} contours.")
    return contours_list
+
def _prepare_empty_dir(path):
    """Create ``path`` if it is missing, otherwise delete the files directly inside it."""
    if not os.path.exists(path):
        os.makedirs(path)
        return
    # Glob *inside* the directory.  The previous pattern ``path + '*'`` only
    # did that when ``path`` ended with a separator; otherwise it matched the
    # directory itself and any sibling sharing the prefix.
    for entry in glob.glob(os.path.join(path, '*')):
        if os.path.isfile(entry):
            os.remove(entry)


def generate_style_image(image_name, dataroot, output_dir):
    """
    Run the QMUPD style model on one image through ``qmupd_single_image.py``.

    The image is copied into an emptied ``dataroot`` (a stop-gap way of handing
    it to the script), ``output_dir`` is emptied, and the script is started
    with ``python3`` from the current working directory.  The script's exit
    status is not checked.

    :param image_name: path of the input image.
    :param dataroot: directory given to the script as ``--dataroot``; existing
        files in it are deleted.
    :param output_dir: directory the result is expected in; existing files in
        it are deleted.
    :return: ``<output_dir>/<image base name>_fake.png``.  The path is built
        from the arguments; the file's existence is not verified.
    """
    import subprocess  # local import: not among this module's top-level imports

    # ================== settings ==================
    exp = 'QMUPD_model'
    epoch = '200'
    gpu_id = '-1'
    netga = 'resnet_style2_9blocks'
    model0_res = 0
    model1_res = 0
    imgsize = 512
    base_image = os.path.splitext(os.path.basename(image_name))[0]

    # Stop-gap: hand the image to the script by placing it alone in dataroot.
    _prepare_empty_dir(dataroot)
    shutil.copy(image_name, dataroot)

    # Clear previous results.
    _prepare_empty_dir(output_dir)

    # ==================== command ==================
    vec = [0, 1, 0]
    svec = '%d,%d,%d' % (vec[0], vec[1], vec[2])
    img1 = 'imagesstyle%d-%d-%d' % (vec[0], vec[1], vec[2])
    print('results/%s/test_%s/index%s.html' % (exp, epoch, img1[6:]))
    # Argument list instead of a shell string: paths containing spaces or shell
    # metacharacters are passed through verbatim.
    command = [
        'python3', 'qmupd_single_image.py',
        '--dataroot', dataroot,
        '--name', exp,
        '--model', 'test',
        '--output_nc', '1',
        '--no_dropout',
        '--model_suffix', '_A',
        '--netga', netga,
        '--model0_res', str(model0_res),
        '--model1_res', str(model1_res),
        '--num_test', '1000',
        '--epoch', epoch,
        '--style_control', '1',
        '--imagefolder', img1,
        '--sinput', 'svec',
        '--svec', svec,
        '--crop_size', str(imgsize),
        '--load_size', str(imgsize),
        '--gpu_ids', gpu_id,
    ]
    subprocess.run(command, check=False)
    return os.path.join(output_dir, f'{base_image}_fake.png')
+
+
def display_strokes_final(sess, pasting_func, data, init_cursor, image_size, infer_lengths, init_width,
                          save_base,
                          cursor_type='next', min_window_size=32, raster_size=128):
    """
    Replay predicted strokes, save an order-coloured preview and collect pen contours.

    For every stroke the window size and cursor position are advanced, the
    stroke is rasterised with ``draw`` and pasted onto the full canvas with
    ``image_pasting_v3_testing``.  Strokes whose pen state is 0 are drawn; for
    those, the points returned by ``draw`` are mapped to canvas coordinates and
    appended to the current contour.  A new contour is started when a drawn
    stroke begins more than 2 pixels away from the last recorded point.

    Side effect: writes ``output_order_with_overlap.png`` into ``save_base``.

    :param sess: session object forwarded to ``image_pasting_v3_testing``.
    :param pasting_func: pasting function forwarded to ``image_pasting_v3_testing``.
    :param data: per-stroke array.  Column 0 is the pen state (0 = draw); of the
        remaining columns, 0:2 are read as x1y1, 2:4 as x2y2, 4 as the width
        and 5 as the window scaling.
    :param init_cursor: initial cursor position of every round, or a single
        1-D position.
    :param image_size: side length of the square canvas.
    :param infer_lengths: number of strokes in each round.
    :param init_width: stroke width at the start of every round.
    :param save_base: directory the preview image is written to.
    :param cursor_type: only ``'next'`` is supported.
    :param min_window_size: lower bound of the window size.
    :param raster_size: window size at the start of every round.
    :return: list of contours, each an array of ``[[x, y]]`` points.
    :raises Exception: if ``cursor_type`` is not ``'next'``.
    """
    canvas = np.zeros((image_size, image_size), dtype=np.float32)  # [0.0-BG, 1.0-stroke]
    canvas2_temp = np.zeros((image_size, image_size), dtype=np.float32)  # debug canvas of sampled points
    drawn_region = np.zeros_like(canvas)
    overlap_region = np.zeros_like(canvas)
    canvas_color_with_overlap = np.zeros((image_size, image_size, 3), dtype=np.float32)
    canvas_color_wo_overlap = np.zeros((image_size, image_size, 3), dtype=np.float32)
    canvas_color_with_moving = np.zeros((image_size, image_size, 3), dtype=np.float32)

    cursor_idx = 0

    if init_cursor.ndim == 1:
        init_cursor = [init_cursor]

    stroke_count = len(data)
    color_rgb_set = get_colors(stroke_count)  # list of (3,) in [0, 255]
    color_idx = 0

    valid_stroke_count = stroke_count - np.sum(data[:, 0]).astype(np.int32) + len(init_cursor)
    valid_color_rgb_set = get_colors(valid_stroke_count)  # list of (3,) in [0, 255]
    valid_color_idx = -1

    contours_list = []
    for round_idx in range(len(infer_lengths)):
        contour = []
        round_length = infer_lengths[round_idx]

        cursor_pos = init_cursor[cursor_idx]  # (2)
        cursor_idx += 1
        prev_width = init_width
        prev_scaling = 1.0
        prev_window_size = float(raster_size)  # (1)
        # Last point appended to any contour in this round.
        last_point = None
        # One iteration per stroke.
        for round_inner_i in range(round_length):
            stroke_idx = np.sum(infer_lengths[:round_idx]).astype(np.int32) + round_inner_i

            curr_window_size_raw = prev_scaling * prev_window_size
            curr_window_size_raw = np.maximum(curr_window_size_raw, min_window_size)
            curr_window_size_raw = np.minimum(curr_window_size_raw, image_size)

            pen_state = data[stroke_idx, 0]
            stroke_params = data[stroke_idx, 1:]  # (8)
            x1y1, x2y2, width2, scaling2 = stroke_params[0:2], stroke_params[2:4], stroke_params[4], stroke_params[5]
            x0y0 = np.zeros_like(x2y2)  # (2), [-1.0, 1.0]
            x0y0 = np.divide(np.add(x0y0, 1.0), 2.0)  # (2), [0.0, 1.0]
            x2y2 = np.divide(np.add(x2y2, 1.0), 2.0)  # (2), [0.0, 1.0]
            widths = np.stack([prev_width, width2], axis=0)  # (2)
            stroke_params_proc = np.concatenate([x0y0, x1y1, x2y2, widths], axis=-1)  # (8)

            next_width = stroke_params[4]
            next_scaling = stroke_params[5]
            next_window_size = next_scaling * curr_window_size_raw
            next_window_size = np.maximum(next_window_size, min_window_size)
            next_window_size = np.minimum(next_window_size, image_size)

            prev_width = next_width * curr_window_size_raw / next_window_size
            prev_scaling = next_scaling
            prev_window_size = curr_window_size_raw

            f = stroke_params_proc.tolist()  # (8)
            f += [1.0, 1.0]
            gt_stroke_img, contour_deatil = draw(f)  # (H, W), [0.0-stroke, 1.0-BG]
            gt_stroke_img_large = image_pasting_v3_testing(1.0 - gt_stroke_img, cursor_pos,
                                                           image_size,
                                                           curr_window_size_raw,
                                                           pasting_func, sess)  # [0.0-BG, 1.0-stroke]
            is_overlap = False

            if pen_state == 0:
                canvas += gt_stroke_img_large  # [0.0-BG, 1.0-stroke]

                curr_drawn_stroke_region = np.zeros_like(gt_stroke_img_large)
                curr_drawn_stroke_region[gt_stroke_img_large > 0.5] = 1
                intersection = drawn_region * curr_drawn_stroke_region
                # Regard a stroke with >50% overlapping area as an overlapped
                # stroke.  A stroke with no pixel above 0.5 has zero area and
                # cannot overlap (this used to evaluate 0/0).
                stroke_area = np.sum(curr_drawn_stroke_region)
                if stroke_area > 0 and np.sum(intersection) / stroke_area > 0.5:
                    # enlarge the stroke a bit for better visualization
                    overlap_region[gt_stroke_img_large > 0] += 1
                    is_overlap = True

                drawn_region[gt_stroke_img_large > 0.5] = 1

            color_rgb = color_rgb_set[color_idx]  # (3) in [0, 255]
            color_idx += 1

            color_rgb = np.reshape(color_rgb, (1, 1, 3)).astype(np.float32)
            color_stroke = np.expand_dims(gt_stroke_img_large, axis=-1) * (1.0 - color_rgb / 255.0)
            canvas_color_with_moving = canvas_color_with_moving * np.expand_dims((1.0 - gt_stroke_img_large),
                                                                                 axis=-1) + color_stroke  # (H, W, 3)

            if pen_state == 0:
                valid_color_idx += 1
                valid_color_rgb = valid_color_rgb_set[valid_color_idx]  # (3) in [0, 255]

                valid_color_rgb = np.reshape(valid_color_rgb, (1, 1, 3)).astype(np.float32)
                valid_color_stroke = np.expand_dims(gt_stroke_img_large, axis=-1) * (1.0 - valid_color_rgb / 255.0)
                canvas_color_with_overlap = canvas_color_with_overlap * np.expand_dims((1.0 - gt_stroke_img_large),
                                                                                       axis=-1) + valid_color_stroke  # (H, W, 3)
                if not is_overlap:
                    canvas_color_wo_overlap = canvas_color_wo_overlap * np.expand_dims((1.0 - gt_stroke_img_large),
                                                                                       axis=-1) + valid_color_stroke  # (H, W, 3)

            # update cursor_pos based on cursor_type
            new_cursor_offsets = stroke_params[2:4] * (float(curr_window_size_raw) / 2.0)  # patch-level
            new_cursor_offset_next = new_cursor_offsets

            # important!!!  swap the two components of the offset
            new_cursor_offset_next = np.concatenate([new_cursor_offset_next[1:2], new_cursor_offset_next[0:1]], axis=-1)

            cursor_pos_large = cursor_pos * float(image_size)

            stroke_position_next = cursor_pos_large + new_cursor_offset_next  # (2), large-level

            if cursor_type == 'next':
                cursor_pos_large = stroke_position_next  # (2), large-level
            else:
                raise Exception('Unknown cursor_type')

            cursor_pos_large = np.minimum(np.maximum(cursor_pos_large, 0.0), float(image_size - 1))  # (2), large-level
            if pen_state == 0:
                # Start position of this stroke on the canvas, kept inside the image.
                cursor_pos_fact = np.minimum(np.maximum(cursor_pos * float(image_size), 0.0), float(image_size - 1))
                # If this stroke does not start where the previous one ended,
                # close the current contour and start a new one.  Starts within
                # 2 pixels stay on the same contour to reduce pen lifts of the
                # robot arm.
                if last_point is not None:
                    if (int(cursor_pos_fact[0]) != int(last_point[0]) or int(cursor_pos_fact[1]) != int(last_point[1])):
                        if np.linalg.norm(cursor_pos_fact - last_point) > 2:
                            # Never store contours of length 0.
                            if len(contour) > 0:
                                contours_list.append(np.array(contour))
                            else:
                                print("==========contour length is 0, in position 1")
                            contour = []

                for x in contour_deatil:
                    x = np.array(x)
                    point_pos = (x[0] - 128) * curr_window_size_raw / 256 + cursor_pos_fact
                    # Keep the point inside the canvas on both sides; the lower
                    # bound used to be missing, letting negative coordinates out.
                    point_pos[0] = min(max(point_pos[0], 0), image_size - 1)
                    point_pos[1] = min(max(point_pos[1], 0), image_size - 1)
                    # De-duplicate: skip points landing on the same pixel as the previous one.
                    if (last_point is None
                            or int(point_pos[0]) != int(last_point[0])
                            or int(point_pos[1]) != int(last_point[1])):
                        contour.append(np.array([[point_pos[0], point_pos[1]]]))
                        last_point = point_pos
                        cv2.circle(canvas2_temp, (int(point_pos[0]), int(point_pos[1])), 1, (255, 255, 0), 1)

            cursor_pos = cursor_pos_large / float(image_size)

        # End of the round: flush the contour in progress (never store length 0).
        if len(contour) > 0:
            contours_list.append(np.array(contour))

    canvas_color_with_overlap = 255 - np.round(canvas_color_with_overlap * 255.0).astype(np.uint8)
    canvas_color_wo_overlap = 255 - np.round(canvas_color_wo_overlap * 255.0).astype(np.uint8)
    canvas_color_with_moving = 255 - np.round(canvas_color_with_moving * 255.0).astype(np.uint8)

    canvas_color_png = Image.fromarray(canvas_color_with_overlap, 'RGB')
    canvas_color_save_path = os.path.join(save_base, 'output_order_with_overlap.png')
    canvas_color_png.save(canvas_color_save_path, 'PNG')
    return contours_list
+
def drawContours(contours_list, cavas_size):
    """
    Render contours as black dots on a white canvas.

    Every point of every contour is drawn as a small circle.  The previous
    version skipped the last point of each contour, a guard left over from a
    line-segment variant that needed ``contour[i + 1]``.

    :param contours_list: contours where ``c[i][0]`` is ``(x, y)``.
    :param cavas_size: shape of the canvas array to create, e.g. ``(512, 512, 3)``.
    :return: ``uint8`` image of shape ``cavas_size``.
    """
    image = np.full(cavas_size, 255, dtype=np.uint8)
    color = (0, 0, 0)
    for contour in contours_list:
        for point in contour:
            cv2.circle(image, (int(point[0][0]), int(point[0][1])), 1, color, 1)
    return image
+
+
def getContourList_v2(npz_path):
    """
    Build pen contours from a saved stroke file (``.npz``).

    Reads ``strokes_data``, ``init_cursors``, ``image_size``, ``round_length``
    and ``init_width`` from the archive and replays them with
    ``display_strokes_final``.  The preview image is written to a directory
    named after the file (without extension) next to it; the directory is
    created if needed.

    :param npz_path: path of the ``.npz`` file; must not be empty.
    :return: list of contours as returned by ``display_strokes_final``.
    """
    assert npz_path != ''

    min_window_size = 32
    raster_size = 128

    # os.path instead of manual '/' searching: also correct for a file in the
    # filesystem root ('/x.npz' used to map to the relative directory 'x').
    file_base, file_name = os.path.split(npz_path)
    file_name = os.path.splitext(file_name)[0]

    regenerate_base = os.path.join(file_base or './', file_name)
    os.makedirs(regenerate_base, exist_ok=True)

    # differentiable pasting graph
    paste_v3_func = DiffPastingV3(raster_size)

    tfconfig = tf.ConfigProto()
    tfconfig.gpu_options.allow_growth = True
    sess = tf.InteractiveSession(config=tfconfig)
    try:
        sess.run(tf.global_variables_initializer())

        # NOTE(security): allow_pickle=True unpickles object arrays; only load
        # stroke files from a trusted source.
        with np.load(npz_path, encoding='latin1', allow_pickle=True) as data:
            strokes_data = data['strokes_data']
            init_cursors = data['init_cursors']
            image_size = data['image_size']
            round_length = data['round_length']
            init_width = data['init_width']

        if round_length.ndim == 0:
            round_lengths = [round_length]
        else:
            round_lengths = round_length
        print('Processing ...')
        contours_list = display_strokes_final(sess, paste_v3_func,
                                              strokes_data, init_cursors, image_size, round_lengths, init_width,
                                              regenerate_base,
                                              min_window_size=min_window_size, raster_size=raster_size)
    finally:
        # Release the session even when replaying fails; it was never closed before.
        sess.close()
    return contours_list
+# # mian
+# if __name__ == "__main__":
+# # 读取图片
+# im = cv2.imread("../data/1_fake.png",cv2.IMREAD_GRAYSCALE)
+# # 获取轮廓列表
+# contour_list = getContourList(im, is_show=True)
+# # 对轮廓列表进行排序
+# contour_list = sortContoursList(contour_list)
+# # 平滑拟合并采样轮廓
+# f_contour_list = sample_and_smooth_contours(im, contour_list, is_show=True)
+# # 保存轮廓点到文件中,每个轮廓占一行,x和y坐标用逗号分割,点之间用逗号分割
+# save_contour_points(f_contour_list, "../data/1_fake_data.txt")
+
+
+
+
if __name__ == '__main__':
    # Manual visual check: load saved contour points, reorder them, drop
    # overlapping contours and show every stage in its own OpenCV window.
    file = "./robot_data/contour_points/image_e1b3f4a3-08f1-4d52-ab40-c5badf38b46e_fake_contour_points.txt"
    contours_lists = load_contours_list(file)
    contours_lists = sortContoursList(contours_lists)
    cv2.imshow("sorted", drawContours(contours_lists, (512, 512,3)))
    contours_lists = remove_overlap_and_near_contours(contours_lists, (512, 512), 3, 0.9, 5)
    # contours_lists = remove_overlap_and_near_contours(contours_lists, (512, 512), 4, 0.7)
    cv2.imshow("remove overlap", drawContours(contours_lists, (512, 512,3)))
    # save_contour_points(contours_lists, "./image_e1b3f4a3-08f1-4d52-ab40-c5badf38b46e_fake_0_contour_points_sorted.txt")
    #contours_lists = sample_and_smooth_contours(contours_lists, 10)
    # NOTE: the smoothing call above is commented out, so this window shows the
    # same contours as the "remove overlap" window.
    cv2.imshow("sample and smooth", drawContours(contours_lists, (512, 512,3)))
    # Block until a key is pressed so the windows stay open.
    cv2.waitKey(0)
\ No newline at end of file
diff --git a/hi-arm/qmupd_vs/environment.yaml b/hi-arm/qmupd_vs/environment.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c962535006e16726b9db9e434d26c8240b49700a
--- /dev/null
+++ b/hi-arm/qmupd_vs/environment.yaml
@@ -0,0 +1,115 @@
+name: vsketch
+channels:
+ - pytorch
+ - http://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
+ - http://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free
+ - http://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge
+ - http://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/
+ - http://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/
+ - http://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge/
+dependencies:
+ - _libgcc_mutex=0.1=conda_forge
+ - _openmp_mutex=4.5=2_kmp_llvm
+ - blas=1.0=mkl
+ - ca-certificates=2024.3.11=h06a4308_0
+ - cairo=1.14.8=0
+ - certifi=2016.2.28=py36_0
+ - cpuonly=2.0=0
+ - cudatoolkit=10.0.130=0
+ - cycler=0.10.0=py36_0
+ - dbus=1.10.20=0
+ - dominate=2.4.0=py_0
+ - expat=2.1.0=0
+ - fftw=3.3.9=h5eee18b_2
+ - fontconfig=2.12.1=3
+ - freetype=2.5.5=2
+ - glib=2.50.2=1
+ - gst-plugins-base=1.8.0=0
+ - gstreamer=1.8.0=0
+ - hdf5=1.10.2=hc401514_3
+ - icu=54.1=0
+ - jbig=2.1=0
+ - jpeg=9b=0
+ - ld_impl_linux-64=2.38=h1181459_1
+ - libblas=3.9.0=1_h6e990d7_netlib
+ - libcblas=3.9.0=3_h893e4fe_netlib
+ - libffi=3.4.4=h6a678d5_0
+ - libgcc=7.2.0=h69d50b8_2
+ - libgcc-ng=13.2.0=h807b86a_5
+ - libgfortran=3.0.0=1
+ - libgfortran-ng=7.5.0=ha8ba4b0_17
+ - libgfortran4=7.5.0=ha8ba4b0_17
+ - libgomp=13.2.0=h807b86a_5
+ - libiconv=1.14=0
+ - liblapack=3.9.0=3_h893e4fe_netlib
+ - libopenblas=0.3.18=hf726d26_0
+ - libpng=1.6.39=h5eee18b_0
+ - libstdcxx-ng=11.2.0=h1234567_1
+ - libtiff=4.0.6=3
+ - libwebp-base=1.3.2=h5eee18b_0
+ - libxcb=1.12=1
+ - libxml2=2.9.4=0
+ - llvm-openmp=14.0.6=h9e868ea_0
+ - lz4-c=1.9.4=h6a678d5_0
+ - matplotlib=2.0.2=np113py36_0
+ - mkl=2017.0.3=0
+ - ncurses=6.4=h6a678d5_0
+ - olefile=0.46=pyhd3eb1b0_0
+ - opencv=3.4.1=py36h6fd60c2_1
+ - openssl=1.0.2l=0
+ - pip=21.3.1
+ - pcre=8.39=1
+ - pillow=4.2.1=py36_0
+ - pixman=0.34.0=0
+ - pyparsing=2.2.0=py36_0
+ - pyqt=5.6.0=py36_2
+ - python=3.6.2=0
+ - python-dateutil=2.6.1=py36_0
+ - python_abi=3.6=2_cp36m
+ - pytorch-mutex=1.0=cpu
+ - pytz=2017.2=py36_0
+ - qt=5.6.2=5
+ - readline=6.2=2
+ - scipy=0.19.1=np113py36_0
+ - setuptools=36.4.0=py36_1
+ - sip=4.18=py36_0
+ - sqlite=3.13.0=0
+ - tk=8.5.18=0
+ - wheel=0.29.0=py36_0
+ - xz=5.2.3=0
+ - zlib=1.2.13=h5eee18b_0
+ - zstd=1.3.3=h84994c4_0
+ - pip:
+ - absl-py==1.4.0
+ - astor==0.8.1
+ - cached-property==1.5.2
+ - cairocffi==1.0.0
+ - cffi==1.15.1
+ - dataclasses==0.8
+ - gast==0.5.4
+ - gizeh==0.1.11
+ - grpcio==1.48.2
+ - h5py==3.1.0
+ - importlib-metadata==4.8.3
+ - importlib-resources==5.4.0
+ - keras-applications==1.0.8
+ - keras-preprocessing==1.1.2
+ - markdown==3.3.7
+ - munch==4.0.0
+ - numpy==1.17.0
+ - opencv-python==3.4.2.16
+ - pip==21.3.1
+ - pretrainedmodels==0.7.4
+ - protobuf==3.19.6
+ - pycparser==2.21
+ - six==1.16.0
+ - tensorboard==1.12.2
+ - tensorflow==1.12.0
+ - termcolor==1.1.0
+ - torch==1.2.0+cpu
+ - torchvision==0.4.0+cpu
+ - tqdm==4.64.1
+ - typing-extensions==4.1.1
+ - werkzeug==2.0.3
+ - zipp==3.6.0
+prefix: /home/qian/anaconda3/envs/vsketch
diff --git a/hi-arm/qmupd_vs/examples/celebahq-11103.jpg b/hi-arm/qmupd_vs/examples/celebahq-11103.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..02c594956835579c76f00cf41dca9d803d56f4d4
Binary files /dev/null and b/hi-arm/qmupd_vs/examples/celebahq-11103.jpg differ
diff --git a/hi-arm/qmupd_vs/examples/celebahq-11918.jpg b/hi-arm/qmupd_vs/examples/celebahq-11918.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..21852a3f266515c979a2b74b3c52b17dbe341940
Binary files /dev/null and b/hi-arm/qmupd_vs/examples/celebahq-11918.jpg differ
diff --git a/hi-arm/qmupd_vs/examples/celebahq-15556.jpg b/hi-arm/qmupd_vs/examples/celebahq-15556.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..12503163767fefec021d37e593009992ad87799c
Binary files /dev/null and b/hi-arm/qmupd_vs/examples/celebahq-15556.jpg differ
diff --git a/hi-arm/qmupd_vs/examples/celebahq-25033.jpg b/hi-arm/qmupd_vs/examples/celebahq-25033.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..536e7fc045052f7a51e908b21d582485d8eda519
Binary files /dev/null and b/hi-arm/qmupd_vs/examples/celebahq-25033.jpg differ
diff --git a/hi-arm/qmupd_vs/examples/celebahq-2524.jpg b/hi-arm/qmupd_vs/examples/celebahq-2524.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..b145b96861f531eb6003e198445742694295ac73
Binary files /dev/null and b/hi-arm/qmupd_vs/examples/celebahq-2524.jpg differ
diff --git a/hi-arm/qmupd_vs/examples/celebahq-26036.jpg b/hi-arm/qmupd_vs/examples/celebahq-26036.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..1053b05098a6a66216ce14daedea0eee114737ae
Binary files /dev/null and b/hi-arm/qmupd_vs/examples/celebahq-26036.jpg differ
diff --git a/hi-arm/qmupd_vs/examples/celebahq-27799.jpg b/hi-arm/qmupd_vs/examples/celebahq-27799.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..bfc7c6467d9e13b7fbd69237326c150a2554323d
Binary files /dev/null and b/hi-arm/qmupd_vs/examples/celebahq-27799.jpg differ
diff --git a/hi-arm/qmupd_vs/examples/celebahq-4797.jpg b/hi-arm/qmupd_vs/examples/celebahq-4797.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..cf18592f8eb43ba7f80232291b4d3da51bda62b4
Binary files /dev/null and b/hi-arm/qmupd_vs/examples/celebahq-4797.jpg differ
diff --git a/hi-arm/qmupd_vs/examples/celebahq-7235.jpg b/hi-arm/qmupd_vs/examples/celebahq-7235.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..1f58117681d314434e9a0480afff9ab0a21e2839
Binary files /dev/null and b/hi-arm/qmupd_vs/examples/celebahq-7235.jpg differ
diff --git a/hi-arm/qmupd_vs/examples/celebahq-896.jpg b/hi-arm/qmupd_vs/examples/celebahq-896.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..cfa58028e8fca031925d8178a2751906b3d79fa2
Binary files /dev/null and b/hi-arm/qmupd_vs/examples/celebahq-896.jpg differ
diff --git a/hi-arm/qmupd_vs/hyper_parameters.py b/hi-arm/qmupd_vs/hyper_parameters.py
new file mode 100644
index 0000000000000000000000000000000000000000..66a3fa9f938d0b35f09a84811cb0058cda94e6a6
--- /dev/null
+++ b/hi-arm/qmupd_vs/hyper_parameters.py
@@ -0,0 +1,341 @@
+import tensorflow as tf
+
+
+#############################################
+# Common parameters
+#############################################
+
+FLAGS = tf.app.flags.FLAGS
+
+# Command-line string flags as (name, default value, help text) rows; every
+# row is registered with tf.app.flags by the loop below.
+_STRING_FLAG_SPECS = (
+    ('dataset_dir', 'datasets',
+     'The directory of sketch data of the dataset.'),
+    ('log_root', 'outputs/log',
+     'Directory to store tensorboard.'),
+    ('log_img_root', 'outputs/log_img',
+     'Directory to store intermediate output images.'),
+    ('snapshot_root', 'outputs/snapshot',
+     'Directory to store model checkpoints.'),
+    ('neural_renderer_path',
+     'outputs/snapshot/pretrain_neural_renderer/renderer_300000.tfmodel',
+     'Path to the neural renderer model.'),
+    ('perceptual_model_root', 'outputs/snapshot/pretrain_perceptual_model',
+     'Directory to store perceptual model.'),
+    ('data', '',
+     'The dataset type.'),
+)
+for _flag_name, _flag_default, _flag_help in _STRING_FLAG_SPECS:
+    tf.app.flags.DEFINE_string(_flag_name, _flag_default, _flag_help)
+
+
+def _shared_hparam_values():
+    """Return the hyper-parameter values common to all three training presets.
+
+    A fresh dict (holding fresh list objects) is built on every call, so a
+    caller may modify the result without affecting later calls.
+    """
+    return dict(
+        eval_every=5000,
+
+        max_seq_len=48,
+        batch_size=20,
+        gpus=[0, 1],
+        loop_per_gpu=1,
+
+        stroke_num_loss_weight_end=0.0,
+        decrease_stop_steps=40000,
+
+        perc_loss_fuse_type='add',  # ['max', 'add', 'raw_add', 'weighted_sum']
+
+        early_pen_loss_type='move',  # ['head', 'tail', 'move']
+        early_pen_length=7,
+
+        min_width=0.01,
+        min_window_size=32,
+        max_scaling=2.0,
+
+        encode_cursor_type='value',
+
+        image_size_small=128,
+
+        cropping_type='v3',  # ['v2', 'v3']
+        pasting_type='v3',  # ['v2', 'v3']
+        pasting_diff=True,
+
+        concat_win_size=True,
+
+        encoder_type='conv13_c3',
+        # ['conv10', 'conv10_deep', 'conv13', 'conv10_c3', 'conv10_deep_c3', 'conv13_c3']
+        # ['conv13_c3_attn']
+        # ['combine33', 'combine43', 'combine53', 'combineFC']
+
+        outside_loss_weight=10.0,
+        win_size_outside_loss_weight=10.0,
+
+        resize_method='AREA',  # ['BILINEAR', 'NEAREST_NEIGHBOR', 'BICUBIC', 'AREA']
+
+        concat_cursor=True,
+
+        use_softargmax=True,
+        soft_beta=10,  # value for the soft argmax
+
+        raster_loss_weight=1.0,
+
+        dec_rnn_size=256,  # Size of decoder.
+        dec_model='hyper',  # Decoder: lstm, layer_norm or hyper.
+        bin_gt=True,
+
+        stop_accu_grad=True,
+
+        random_cursor=True,
+        cursor_type='next',
+
+        raster_size=128,
+
+        pix_drop_kp=1.0,  # Dropout keep rate
+        add_coordconv=True,
+        position_format='abs',
+        raster_loss_base_type='perceptual',  # [l1, mse, perceptual]
+
+        grad_clip=1.0,  # Gradient clipping. Recommend leaving at 1.0.
+
+        learning_rate=0.0001,  # Learning rate.
+        decay_rate=0.9999,  # Learning rate decay per minibatch.
+        decay_power=0.9,
+        min_learning_rate=0.000001,  # Minimum learning rate.
+
+        use_recurrent_dropout=True,  # Dropout with memory loss. Recommended
+        recurrent_dropout_prob=0.90,  # Probability of recurrent dropout keep.
+        use_input_dropout=False,  # Input dropout. Recommend leaving False.
+        input_dropout_prob=0.90,  # Probability of input dropout keep.
+        use_output_dropout=False,  # Output dropout. Recommend leaving False.
+        output_dropout_prob=0.90,  # Probability of output dropout keep.
+
+        model_mode='train',  # ['train', 'eval', 'sample']
+    )
+
+
+def _make_hparams(**overrides):
+    """Build an HParams object from the shared values plus ``overrides``.
+
+    Keys in ``overrides`` either replace a shared value or add a
+    preset-specific hyper-parameter.
+    """
+    values = _shared_hparam_values()
+    values.update(overrides)
+    return tf.contrib.training.HParams(**values)
+
+
+def get_default_hparams_clean():
+    """Return the default HParams for the 'clean_line_drawings' data set.
+
+    Only the settings that differ from the shared defaults are listed here.
+    """
+    return _make_hparams(
+        program_name='new_train_clean_line_drawings',
+        data_set='clean_line_drawings',  # Our dataset.
+
+        input_channel=1,
+
+        num_steps=75040,  # Total number of steps of training.
+        save_every=75000,
+
+        sn_loss_type='increasing',  # ['decreasing', 'fixed', 'increasing']
+        stroke_num_loss_weight=0.02,
+        increase_start_steps=25000,
+
+        perc_loss_layers=['ReLU1_2', 'ReLU2_2', 'ReLU3_3', 'ReLU5_1'],
+
+        init_cursor_on_undrawn_pixel=False,
+
+        early_pen_loss_weight=0.1,
+
+        image_size_large=278,
+
+        vary_thickness=False,
+    )
+
+
+def get_default_hparams_rough():
+    """Return the default HParams for the 'rough_sketches' data set.
+
+    Only the settings that differ from the shared defaults are listed here.
+    """
+    return _make_hparams(
+        program_name='new_train_rough_sketches',
+        data_set='rough_sketches',  # ['rough_sketches', 'faces']
+
+        input_channel=3,
+
+        num_steps=90040,  # Total number of steps of training.
+        save_every=90000,
+
+        sn_loss_type='increasing',  # ['decreasing', 'fixed', 'increasing']
+        stroke_num_loss_weight=0.1,
+        increase_start_steps=25000,
+
+        photo_prob_type='one',  # ['increasing', 'zero', 'one']
+        photo_prob_start_step=35000,
+
+        perc_loss_layers=['ReLU2_2', 'ReLU3_3', 'ReLU5_1'],
+
+        early_pen_loss_weight=0.2,
+
+        image_size_large=278,
+    )
+
+
+def get_default_hparams_normal():
+    """Return the default HParams for the 'faces' data set.
+
+    Only the settings that differ from the shared defaults are listed here.
+    """
+    return _make_hparams(
+        program_name='new_train_faces',
+        data_set='faces',  # ['rough_sketches', 'faces']
+
+        input_channel=3,
+
+        num_steps=90040,  # Total number of steps of training.
+        save_every=90000,
+
+        sn_loss_type='fixed',  # ['decreasing', 'fixed', 'increasing']
+        stroke_num_loss_weight=0.0,
+        increase_start_steps=0,
+
+        photo_prob_type='interpolate',  # ['increasing', 'zero', 'one', 'interpolate']
+        photo_prob_start_step=30000,
+        photo_prob_end_step=60000,
+
+        perc_loss_layers=['ReLU2_2', 'ReLU3_3', 'ReLU4_2', 'ReLU5_1'],
+
+        early_pen_loss_weight=0.2,
+
+        image_size_large=256,
+    )
diff --git a/hi-arm/qmupd_vs/main.py b/hi-arm/qmupd_vs/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..aaa3f12c95d784b4d7c31fe6cfa46de1c55ca3f4
--- /dev/null
+++ b/hi-arm/qmupd_vs/main.py
@@ -0,0 +1,574 @@
+
+from camera_tools import CameraApp
+import draw_tools
+import cv2
+import os
+from options.test_options import TestOptions
+from data import create_dataset
+from models import create_model
+from util.visualizer import save_images
+import shutil
+import os, glob
+import warnings
+import util
+import paramiko
+
+#================== settings ==================
+# Module-level configuration constants. Only `dataroot` and `output_dir` are
+# referenced by the code reviewed in this file section; the remaining names
+# look like options for the style-transfer model and may be read elsewhere
+# (TODO confirm before removing any of them).
+exp = 'QMUPD_model'
+epoch='200'
+# Directory that uploaded images are saved to and that is handed to
+# draw_tools.generate_style_image.
+dataroot = 'robot_data/dataset/'
+gpu_id = '-1'
+netga = 'resnet_style2_9blocks'
+model0_res = 0
+model1_res = 0
+imgsize = 512
+extraflag = ' --netga %s --model0_res %d --model1_res %d' % (netga, model0_res, model1_res)
+# Directory handed to generate_style_image as its output location and used as
+# the image base directory when running the stroke model.
+output_dir = 'robot_data/output/'
+
+import numpy as np
+import os
+import tensorflow as tf
+from six.moves import range
+from PIL import Image
+import argparse
+
+import hyper_parameters as hparams
+from model_common_test import DiffPastingV3, VirtualSketchingModel
+from utils import reset_graph, load_checkpoint, update_hyperparams, draw, \
+ save_seq_data, image_pasting_v3_testing, draw_strokes
+from dataset_utils import load_dataset_testing
+
+# Hide every CUDA device from this process so that TensorFlow runs on the CPU.
+os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
+
+
+def move_cursor_to_undrawn(current_pos_list, input_image_, patch_size,
+                           move_min_dist, move_max_dist, trial_times=20):
+    """Randomly relocate each cursor, preferring spots whose patch has strokes.
+
+    For every cursor up to ``trial_times`` random jumps are drawn. The first
+    jump whose ``patch_size`` window contains any stroke pixel of the input
+    image is accepted; if none qualifies, the last jump drawn is used.
+
+    :param current_pos_list: (select_times, 1, 2), [0.0, 1.0)
+    :param input_image_: (1, image_size, image_size, 3), [0-stroke, 1-BG]
+    :param patch_size: side length, in pixels, of the window that is checked
+    :param move_min_dist: lower bound of the jump per axis, as a proportion
+        (halved and scaled by the image size)
+    :param move_max_dist: exclusive upper bound of the jump, same convention
+    :param trial_times: maximum number of jumps tried per cursor
+    :return: new_cursor_pos: (select_times, 1, 2), [0.0, 1.0)
+    """
+    # Flip the polarity so that strokes are 1 and the background is 0.
+    stroke_map = 1.0 - input_image_[0]  # (image_size, image_size, 3)
+    side = stroke_map.shape[0]
+
+    def propose(cursor):
+        # cursor: (2), [0.0, 1.0) -> jumped cursor, (2), [0.0, 1.0)
+        cursor_px = cursor * side
+        lower = int(move_min_dist / 2.0 * side)
+        upper = int(move_max_dist / 2.0 * side)
+        magnitude = np.random.randint(lower, upper, size=cursor_px.shape)
+        direction = np.random.randint(0, 1 + 1, size=cursor_px.shape)
+        direction = np.where(direction == 0, -1, direction)
+
+        moved_px = cursor_px + magnitude * direction
+        moved_px = np.minimum(np.maximum(moved_px, 0), side - 1)  # keep on the image
+        return moved_px.astype(np.float32) / float(side)
+
+    def window_has_stroke(cursor):
+        # True when the window centred on `cursor` holds any stroke pixel.
+        center_px = np.round(cursor * float(side)).astype(np.int32)
+        left = center_px[0] - patch_size // 2
+        top = center_px[1] - patch_size // 2
+        right = left + patch_size
+        bottom = top + patch_size
+        # Clamp the window to the image bounds.
+        left, top, right, bottom = [max(0, min(edge, side))
+                                    for edge in (left, top, right, bottom)]
+        window = stroke_map[top:bottom, left:right]
+        return bool(np.sum(window) > 0.0)
+
+    relocated = []
+    for entry in current_pos_list:
+        origin = entry[0]
+        for attempt in range(trial_times):
+            candidate = propose(origin)  # (2), [0.0, 1.0)
+            out_of_trials = attempt == trial_times - 1
+            if window_has_stroke(candidate) or out_of_trials:
+                relocated.append(candidate)
+                break
+
+    assert len(relocated) == current_pos_list.shape[0]
+    return np.stack(relocated, axis=0)[:, np.newaxis]  # (select_times, 1, 2), [0.0, 1.0)
+
+
+def sample(sess, model, input_photos, init_cursor, image_size, init_len, seq_lens,
+ state_dependent, pasting_func, round_stop_state_num,
+ min_dist_p, max_dist_p):
+ """Run the stroke model step by step and rasterise the drawn strokes.
+
+ Drawing proceeds in rounds, one per entry of ``seq_lens``. The first round
+ starts at ``init_cursor``; every later round starts at a cursor chosen by
+ ``move_cursor_to_undrawn``. A round ends after ``round_stop_state_num``
+ consecutive non-drawing steps or when its step budget is used up.
+
+ :param sess: TensorFlow session used to evaluate the model tensors
+ :param model: model exposing the placeholders/tensors fed and fetched below, plus ``hps``
+ :param input_photos: input image batch, tiled ``select_times`` times before feeding
+ :param init_cursor: initial cursor; axis 0 is squeezed away to give (select_times, 1, 2)
+ :param image_size: side length of the full canvas
+ :param init_len: when ``state_dependent`` is false, the RNN state is reset every ``init_len`` steps
+ :param seq_lens: iterable with the maximum number of steps of each round
+ :param state_dependent: if false, periodically reset the RNN state (see ``init_len``)
+ :param pasting_func: pasting graph handed to ``image_pasting_v3_testing``
+ :param round_stop_state_num: consecutive non-drawing steps that end a round
+ :param min_dist_p: forwarded to ``move_cursor_to_undrawn`` as ``move_min_dist``
+ :param max_dist_p: forwarded to ``move_cursor_to_undrawn`` as ``move_max_dist``
+ :return: (params_list, state_raw_list, state_soft_list, curr_canvas,
+ window_size_list, round_cursor_list, round_length_real_list)
+ """
+ # Only a single sample is drawn; the per-sample lists below have one entry.
+ select_times = 1
+ curr_canvas = np.zeros(dtype=np.float32,
+ shape=(select_times, image_size, image_size)) # [0.0-BG, 1.0-stroke]
+
+ initial_state = sess.run(model.initial_state)
+
+ params_list = [[] for _ in range(select_times)]
+ state_raw_list = [[] for _ in range(select_times)]
+ state_soft_list = [[] for _ in range(select_times)]
+ window_size_list = [[] for _ in range(select_times)]
+
+ round_cursor_list = []
+ round_length_real_list = []
+
+ input_photos_tiles = np.tile(input_photos, (select_times, 1, 1, 1))
+
+ for cursor_i, seq_len in enumerate(seq_lens):
+ if cursor_i == 0:
+ cursor_pos = np.squeeze(init_cursor, axis=0) # (select_times, 1, 2)
+ else:
+ cursor_pos = move_cursor_to_undrawn(cursor_pos, input_photos, model.hps.raster_size,
+ min_dist_p, max_dist_p) # (select_times, 1, 2)
+ # NOTE(review): main_testing prepends the initial cursor to this list and
+ # asserts the total equals the number of rounds, so this append presumably
+ # belongs to the relocation branch only - confirm against the original indentation.
+ round_cursor_list.append(cursor_pos)
+
+ # Each round starts from the initial RNN state, the minimum stroke width
+ # and a window of raster_size with scaling 1.
+ prev_state = initial_state
+ prev_width = np.stack([model.hps.min_width for _ in range(select_times)], axis=0)
+ prev_scaling = np.ones((select_times), dtype=np.float32) # (N)
+ prev_window_size = np.ones((select_times), dtype=np.float32) * model.hps.raster_size # (N)
+
+ # Number of consecutive steps in which no stroke was drawn.
+ continuous_one_state_num = 0
+
+ for i in range(seq_len):
+ if not state_dependent and i % init_len == 0:
+ prev_state = initial_state
+
+ # Window for this step, clamped to [min_window_size, image_size].
+ curr_window_size = prev_scaling * prev_window_size # (N)
+ curr_window_size = np.maximum(curr_window_size, model.hps.min_window_size)
+ curr_window_size = np.minimum(curr_window_size, image_size)
+
+ feed = {
+ model.initial_state: prev_state,
+ model.input_photo: input_photos_tiles,
+ model.curr_canvas_hard: curr_canvas.copy(),
+ model.cursor_position: cursor_pos,
+ model.image_size: image_size,
+ model.init_width: prev_width,
+ model.init_scaling: prev_scaling,
+ model.init_window_size: prev_window_size,
+ }
+
+ o_other_params_list, o_pen_list, o_pred_params_list, next_state_list = \
+ sess.run([model.other_params, model.pen_ras, model.pred_params, model.final_state], feed_dict=feed)
+ # o_other_params: (N, 6), o_pen: (N, 2), pred_params: (N, 1, 7), next_state: (N, 1024)
+ # o_other_params: [tanh*2, sigmoid*2, tanh*2, sigmoid*2]
+
+ # Index of the most likely pen state; index 0 is the state in which a stroke is drawn.
+ idx_eos_list = np.argmax(o_pen_list, axis=1) # (N)
+
+ output_i = 0
+ idx_eos = idx_eos_list[output_i]
+
+ eos = [0, 0]
+ eos[idx_eos] = 1
+
+ # Record this step as [pen flag, six stroke parameters].
+ other_params = o_other_params_list[output_i].tolist() # (6)
+ params_list[output_i].append([eos[1]] + other_params)
+ state_raw_list[output_i].append(o_pen_list[output_i][1])
+ state_soft_list[output_i].append(o_pred_params_list[output_i, 0, 0])
+ window_size_list[output_i].append(curr_window_size[output_i])
+
+ # draw the stroke and add to the canvas
+ x1y1, x2y2, width2 = o_other_params_list[output_i, 0:2], o_other_params_list[output_i, 2:4], \
+ o_other_params_list[output_i, 4]
+ x0y0 = np.zeros_like(x2y2) # (2), [-1.0, 1.0]
+ x0y0 = np.divide(np.add(x0y0, 1.0), 2.0) # (2), [0.0, 1.0]
+ x2y2 = np.divide(np.add(x2y2, 1.0), 2.0) # (2), [0.0, 1.0]
+ widths = np.stack([prev_width[output_i], width2], axis=0) # (2)
+ o_other_params_proc = np.concatenate([x0y0, x1y1, x2y2, widths], axis=-1).tolist() # (8)
+
+ if idx_eos == 0:
+ # Render the stroke in the patch and paste it onto the full canvas.
+ f = o_other_params_proc + [1.0, 1.0]
+ pred_stroke_img, _ = draw(f) # (raster_size, raster_size), [0.0-stroke, 1.0-BG]
+ pred_stroke_img_large = image_pasting_v3_testing(1.0 - pred_stroke_img,
+ cursor_pos[output_i, 0],
+ image_size,
+ curr_window_size[output_i],
+ pasting_func, sess) # [0.0-BG, 1.0-stroke]
+ curr_canvas[output_i] += pred_stroke_img_large # [0.0-BG, 1.0-stroke]
+
+ continuous_one_state_num = 0
+ else:
+ continuous_one_state_num += 1
+
+ # Overlapping strokes add up, so clamp the canvas back into [0, 1].
+ curr_canvas = np.clip(curr_canvas, 0.0, 1.0)
+
+ next_width = o_other_params_list[:, 4] # (N)
+ next_scaling = o_other_params_list[:, 5]
+ next_window_size = next_scaling * curr_window_size # (N)
+ next_window_size = np.maximum(next_window_size, model.hps.min_window_size)
+ next_window_size = np.minimum(next_window_size, image_size)
+
+ prev_state = next_state_list
+ # Rescale the width by the ratio of the current window to the next one.
+ prev_width = next_width * curr_window_size / next_window_size # (N,)
+ prev_scaling = next_scaling # (N)
+ prev_window_size = curr_window_size
+
+ # update cursor_pos based on hps.cursor_type
+ new_cursor_offsets = o_other_params_list[:, 2:4] * (
+ np.expand_dims(curr_window_size, axis=-1) / 2.0) # (N, 2), patch-level
+ new_cursor_offset_next = new_cursor_offsets
+
+ # important!!!
+ # The two offset components are swapped before being added to the cursor.
+ new_cursor_offset_next = np.concatenate([new_cursor_offset_next[:, 1:2], new_cursor_offset_next[:, 0:1]],
+ axis=-1)
+
+ cursor_pos_large = cursor_pos * float(image_size)
+ stroke_position_next = cursor_pos_large[:, 0, :] + new_cursor_offset_next # (N, 2), large-level
+
+ if model.hps.cursor_type == 'next':
+ cursor_pos_large = stroke_position_next # (N, 2), large-level
+ else:
+ raise Exception('Unknown cursor_type')
+
+ cursor_pos_large = np.minimum(np.maximum(cursor_pos_large, 0.0),
+ float(image_size - 1)) # (N, 2), large-level
+ cursor_pos_large = np.expand_dims(cursor_pos_large, axis=1) # (N, 1, 2)
+ cursor_pos = cursor_pos_large / float(image_size)
+
+ # End the round after too many consecutive non-drawing steps or when the
+ # step budget is exhausted; record how many steps the round really took.
+ if continuous_one_state_num >= round_stop_state_num or i == seq_len - 1:
+ round_length_real_list.append(i + 1)
+ break
+
+ return params_list, state_raw_list, state_soft_list, curr_canvas, window_size_list, \
+ round_cursor_list, round_length_real_list
+
+
+def main_testing(test_image_base_dir, test_dataset, test_image_name,
+                 sampling_base_dir, model_base_dir, model_name,
+                 sampling_num,
+                 min_dist_p, max_dist_p,
+                 longer_infer_lens, round_stop_state_num,
+                 draw_seq=False, draw_order=False,
+                 state_dependent=True):
+    """Run the stroke model on one image and save its outputs.
+
+    For each of ``sampling_num`` samplings the input image is written as
+    ``<name>_input.png``, the stroke sequence is stored via ``save_seq_data``
+    and the rendered prediction via ``draw_strokes``. Everything goes under
+    ``<sampling_base_dir>/<test_dataset>__<model_name>``.
+
+    :param test_image_base_dir: directory handed to ``load_dataset_testing``
+    :param test_dataset: data set name; also part of the output directory name
+    :param test_image_name: file name of the image to process
+    :param sampling_base_dir: root directory for the outputs
+    :param model_base_dir: directory that contains the model snapshot folders
+    :param model_name: snapshot folder name inside ``model_base_dir``
+    :param sampling_num: number of samplings to run
+    :param min_dist_p: forwarded to ``sample``
+    :param max_dist_p: forwarded to ``sample``
+    :param longer_infer_lens: per-round step budgets, forwarded to ``sample``
+    :param round_stop_state_num: forwarded to ``sample``
+    :param draw_seq: forwarded to ``draw_strokes`` as ``save_seq``
+    :param draw_order: forwarded to ``draw_strokes`` as ``draw_order``
+    :param state_dependent: forwarded to ``sample``
+    """
+    model_params_default = hparams.get_default_hparams_rough()
+    model_params = update_hyperparams(model_params_default, model_base_dir, model_name, infer_dataset=test_dataset)
+
+    [test_set, eval_hps_model, sample_hps_model] = \
+        load_dataset_testing(test_image_base_dir, test_dataset, test_image_name, model_params)
+
+    # Text before the first '.'. Unlike slicing with str.find, this keeps the
+    # whole name when it contains no dot, and it matches how the callers in
+    # this file derive the prefix of the output files.
+    test_image_raw_name = test_image_name.split('.')[0]
+    model_dir = os.path.join(model_base_dir, model_name)
+
+    reset_graph()
+    sampling_model = VirtualSketchingModel(sample_hps_model)
+
+    # differentiable pasting graph
+    paste_v3_func = DiffPastingV3(sample_hps_model.raster_size)
+
+    tfconfig = tf.ConfigProto()
+    tfconfig.gpu_options.allow_growth = True
+    sess = tf.InteractiveSession(config=tfconfig)
+    try:
+        sess.run(tf.global_variables_initializer())
+
+        # loads the weights from checkpoint into our model
+        snapshot_step = load_checkpoint(sess, model_dir, gen_model_pretrain=True)
+        print('snapshot_step', snapshot_step)
+        sampling_dir = os.path.join(sampling_base_dir, test_dataset + '__' + model_name)
+        os.makedirs(sampling_dir, exist_ok=True)
+
+        for sampling_i in range(sampling_num):
+            input_photos, init_cursors, test_image_size = test_set.get_test_image()
+            # input_photos: (1, image_size, image_size, 3), [0-stroke, 1-BG]
+            # init_cursors: (N, 1, 2), in size [0.0, 1.0)
+
+            print()
+            print(test_image_name, ', image_size:', test_image_size, ', sampling_i:', sampling_i)
+            print('Processing ...')
+
+            if init_cursors.ndim == 3:
+                init_cursors = np.expand_dims(init_cursors, axis=0)
+
+            input_photos = input_photos[0:1, :, :, :]
+
+            ori_img = (input_photos.copy()[0] * 255.0).astype(np.uint8)
+            ori_img_png = Image.fromarray(ori_img, 'RGB')
+            ori_img_png.save(os.path.join(sampling_dir, test_image_raw_name + '_input.png'), 'PNG')
+
+            # decoding for sampling
+            strokes_raw_out_list, _, _, _, _, round_new_cursors, round_new_lengths = sample(
+                sess, sampling_model, input_photos, init_cursors, test_image_size,
+                eval_hps_model.max_seq_len, longer_infer_lens, state_dependent, paste_v3_func,
+                round_stop_state_num, min_dist_p, max_dist_p)
+
+            print('## round_lengths:', len(round_new_lengths), ':', round_new_lengths)
+
+            output_i = 0
+            strokes_raw_out = np.stack(strokes_raw_out_list[output_i], axis=0)
+
+            # Start cursor of every round: the initial one, then each relocated one.
+            multi_cursors = [init_cursors[0, output_i, 0]]
+            for round_cursor in round_new_cursors:
+                multi_cursors.append(round_cursor[output_i, 0])  # (2)
+            assert len(multi_cursors) == len(round_new_lengths)
+
+            print('strokes_raw_out', strokes_raw_out.shape)
+
+            flag_list = strokes_raw_out[:, 0].astype(np.int32)  # (N)
+            drawing_len = len(flag_list) - np.sum(flag_list)
+            assert drawing_len >= 0
+
+            print('Saving results ...')
+            # Save the results.
+            sample_name = test_image_raw_name + '_' + str(sampling_i)
+            print("================", sampling_dir, sample_name)
+            save_seq_data(sampling_dir, sample_name,
+                          strokes_raw_out, multi_cursors,
+                          test_image_size, round_new_lengths, eval_hps_model.min_width)
+
+            draw_strokes(strokes_raw_out, sampling_dir, sample_name + '_pred.png',
+                         ori_img, test_image_size,
+                         multi_cursors, round_new_lengths, eval_hps_model.min_width, eval_hps_model.cursor_type,
+                         sample_hps_model.raster_size, sample_hps_model.min_window_size,
+                         sess,
+                         pasting_func=paste_v3_func,
+                         save_seq=draw_seq, draw_order=draw_order)
+    finally:
+        # Release the session even when inference fails; this function runs
+        # once per HTTP request, so an unclosed session would accumulate.
+        sess.close()
+
+
+def generate_simple_order_line(model_name, test_image_name, sampling_num):
+    """Run ``main_testing`` on an image from ``output_dir`` with fixed settings.
+
+    The data set is 'rough_sketches', results go below 'robot_data/sampling'
+    and the model snapshot is looked up below 'outputs/snapshot'.
+
+    :param model_name: snapshot folder name of the model to load
+    :param test_image_name: file name of the image inside ``output_dir``
+    :param sampling_num: number of samplings to run
+    """
+    # set numpy output to something sensible
+    np.set_printoptions(precision=8, edgeitems=6, linewidth=200, suppress=True)
+
+    main_testing(
+        output_dir, 'rough_sketches', test_image_name,
+        'robot_data/sampling', 'outputs/snapshot', model_name, sampling_num,
+        min_dist_p=0.3, max_dist_p=0.9,
+        draw_seq=False, draw_order=True,
+        state_dependent=False,
+        longer_infer_lens=[128] * 10,  # ten rounds of at most 128 steps each
+        round_stop_state_num=12)
+
+
+def decode_npz_file(npz_file):
+    """Load the stroke-sequence arrays stored in an ``.npz`` archive.
+
+    :param npz_file: path of the archive to read
+    :return: tuple ``(strokes_data, init_cursors, image_size, round_length,
+        init_width)`` holding the arrays stored under those keys
+    :raises KeyError: if one of the keys is missing from the archive
+    """
+    # NOTE(security): allow_pickle=True lets the archive execute arbitrary code
+    # while loading; only open files produced by this application.
+    # The context manager closes the archive's file handle once the arrays
+    # have been read into memory.
+    with np.load(npz_file, encoding='latin1', allow_pickle=True) as data:
+        return (data['strokes_data'], data['init_cursors'], data['image_size'],
+                data['round_length'], data['init_width'])
+
+def scp_transfer(host, port, username, password, local_path, remote_path):
+    """Upload a local file to a remote host over SFTP.
+
+    :param host: host name or address of the SSH server
+    :param port: port of the SSH server
+    :param username: login name
+    :param password: login password
+    :param local_path: path of the file to upload
+    :param remote_path: destination path on the remote host
+    """
+    # Create the SSH client.
+    ssh = paramiko.SSHClient()
+
+    # Allow connecting to hosts that are not listed in known_hosts.
+    # NOTE(security): this accepts any host key without verification.
+    ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+
+    try:
+        # Connect to the SSH server and open an SFTP channel on it.
+        ssh.connect(host, port, username, password)
+        sftp = ssh.open_sftp()
+        try:
+            # Upload the file.
+            sftp.put(local_path, remote_path)
+        finally:
+            sftp.close()
+    finally:
+        # Close the connection even when connecting or uploading failed.
+        ssh.close()
+
+from flask import Flask, request, send_from_directory
+
+# Flask application on which the HTTP routes below are registered.
+# NOTE(review): the __main__ section at the end of this file rebinds the name
+# `app` to a CameraApp instance.
+app = Flask(__name__)
+
+@app.route('/upload', methods=['GET', 'POST'])
+def create_upload_file():
+    """Save the uploaded form file 'file' into ``dataroot``.
+
+    A POST request stores the file under its own base name. Every request
+    that is not rejected is answered with "OK".
+
+    :return: "OK", or an error message with status 400 when the upload has no
+        usable file name
+    """
+    if request.method == 'POST':
+        file = request.files['file']
+        # The file name is chosen by the client: keep only its last path
+        # component so that it cannot point outside of dataroot.
+        filename = os.path.basename(file.filename or '')
+        if not filename:
+            return "missing file name", 400
+        file.save(os.path.join(dataroot, filename))
+    return "OK"
+import time
+
+
+def transparence2white(img):
+    """Turn fully transparent pixels into opaque white, in place.
+
+    Only images with four channels are changed: every pixel whose last
+    channel (alpha) is 0 is set to [255, 255, 255, 255]. Arrays with any
+    other layout are returned untouched.
+
+    :param img: image array, indexed as (row, column[, channel])
+    :return: the same array object that was passed in
+    """
+    if img.ndim == 3 and img.shape[2] == 4:
+        # One vectorised assignment instead of a Python loop over all pixels.
+        img[img[..., 3] == 0] = [255, 255, 255, 255]
+    return img
+
+@app.route('/sketch', methods=['POST'])
+def sketch():
+    """Turn a matted image into a contour drawing.
+
+    The form field 'image_path' names an image below the matting project
+    directory. The image is copied into the input folder, transparent pixels
+    are made white, and it is resized to 512x512. It then goes through
+    ``draw_tools.generate_style_image`` and the stroke model, and the
+    resulting contours are sorted, filtered, drawn and saved.
+
+    :return: dict with 'sketch_image_url' (contour image path) and
+        'seq_data_file' (stroke sequence archive path), or an error message
+        with status 400 for a missing, invalid or unreadable image
+    """
+    image_path = request.form.get("image_path")
+    if not image_path:
+        return "missing form field: image_path", 400
+    print("image_path:", image_path)
+    matting_root = "/home/qian/projects/robot_sketch_draw/image-matting"
+    # image_path comes from the client: resolve it below matting_root and
+    # refuse anything that escapes that directory (e.g. through '..').
+    src_path = os.path.normpath(os.path.join(matting_root, image_path.lstrip('/')))
+    if not src_path.startswith(matting_root + os.sep):
+        return "invalid image_path", 400
+    filename = os.path.basename(src_path)
+    print("将文件存放到input文件夹下")
+    print("src_path:", src_path)
+    filepath = os.path.join(dataroot, "../input", filename)
+    shutil.copyfile(src_path, filepath)
+
+    png_image = cv2.imread(filepath, cv2.IMREAD_UNCHANGED)
+    if png_image is None:
+        return "unable to read image", 400
+    # Store the processed image as .jpg; only the extension is rewritten.
+    stem, extension = os.path.splitext(filepath)
+    if extension == ".png":
+        filepath = stem + ".jpg"
+    # Make the transparent background of the PNG white.
+    png_image = transparence2white(png_image)
+    # Resize to 512*512.
+    png_image = cv2.resize(png_image, (512, 512))
+    cv2.imwrite(filepath, png_image)
+
+    outimage_path = draw_tools.generate_style_image(filepath, dataroot, output_dir)
+    outimage_path = outimage_path.split('/')[-1]
+    generate_simple_order_line("pretrain_rough_sketches", outimage_path, 1)
+
+    prx = outimage_path.split('.')[0]
+    out_png_image = os.path.join("robot_data/contour_images/", f"{prx}.png")
+    seq_data_file = os.path.join("robot_data/sampling/rough_sketches__pretrain_rough_sketches/seq_data/", f"{prx}_0.npz")
+    contours_list = draw_tools.getContourList_v2(seq_data_file)
+    contours_list = draw_tools.sortContoursList(contours_list)
+    # Tuning values, as described by the original author's note: 3 is the
+    # neighbourhood expansion (controls line sparsity), 0.9 is the overlap or
+    # nearness ratio above which a contour is removed, and 5 is the shortest
+    # contour length that is kept. Exact semantics live in draw_tools.
+    contours_list = draw_tools.remove_overlap_and_near_contours(contours_list, (512, 512), 3, 0.9, 5)
+    # Draw the contour image and save it.
+    contour_image = draw_tools.drawContours(contours_list, (512, 512,3))
+    # Make sure the output folders exist before anything is written to them.
+    os.makedirs("robot_data/contour_images", exist_ok=True)
+    os.makedirs("robot_data/contour_points", exist_ok=True)
+    cv2.imwrite(f"robot_data/contour_images/{prx}.png", contour_image)
+    draw_tools.save_contour_points(contours_list, f"robot_data/contour_points/{prx}_contour_points.txt")
+    return {
+        "sketch_image_url": out_png_image,
+        "seq_data_file": seq_data_file
+    }
+
+@app.route('/drawing', methods=['GET', 'POST'])
+def drawing():
+    """Send the contour points that belong to a stroke sequence to the robot.
+
+    The form field 'seq_data_file' names a sequence archive such as
+    '.../seq_data/<prefix>_0.npz'. The matching
+    'robot_data/contour_points/<prefix>_contour_points.txt' file is uploaded
+    through ``scp_transfer``.
+
+    :return: "OK" on success, an error message with status 400 when the form
+        field is missing, or with status 404 when the contour file is absent
+    """
+    seq_data_file = request.form.get("seq_data_file")
+    if not seq_data_file:
+        return "missing form field: seq_data_file", 400
+    print("seq_data_file:", seq_data_file)
+    # TODO: temporary code that maps the sequence file onto the contour points
+    # path: drop the directory part and the trailing '_<index>.npz' component.
+    prx = "_".join(seq_data_file.split('/')[-1].split('_')[:-1])
+    contours_list_path = f"./robot_data/contour_points/{prx}_contour_points.txt"
+    print(contours_list_path)
+    if not os.path.isfile(contours_list_path):
+        return "contour points not found", 404
+    # NOTE(security): host and credentials are hard-coded; they should come
+    # from configuration instead.
+    scp_transfer('192.168.253.95', 22, "root", "root", contours_list_path, "/home/robot/Work/system/bspline.txt")
+    return "OK"
+
+@app.route('/')
+def hello():
+ """Answer requests for the root URL with the fixed text "hello"."""
+ return "hello"
+
+
+# The URL rule must declare the `filename` variable that the view function
+# takes; without it every request fails because the argument is never passed.
+# The `path` converter also accepts names that contain slashes.
+@app.route('/files/<path:filename>')
+def serve_file(filename):
+    """Send the requested file from the directory '' via ``send_from_directory``.
+
+    :param filename: path of the file, taken from the URL after '/files/'
+    """
+    return send_from_directory('', filename)
+
+# Command-line entry point: take the image named by --name (or capture one
+# with CameraApp when no name is given), run the style image generation and
+# the stroke model, show the contours after each processing stage, and save
+# the final contour points.
+if __name__ == '__main__':
+ warnings.filterwarnings("ignore", category=FutureWarning)
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--sample', '-s', type=int, default=1, help="The number of outputs.")
+ parser.add_argument('--name', '-n', type=str, default="", help="The name of the image.")
+ args = parser.parse_args()
+ if args.name == "":
+ # NOTE(review): this rebinds the module-level Flask `app` name.
+ app = CameraApp() # Create the CameraApp object, which starts the program
+ # # # Get the image name
+ # # # image_name = "./robot_data/input/1714032527749.jpg"
+ # # # image = cv2.imread(image_name, cv2.IMREAD_COLOR)
+ image = app.last_photo
+ image_name = app.last_photo_name
+ else:
+ image_name = args.name
+
+ filepath = image_name
+ outimage_path = draw_tools.generate_style_image(filepath, dataroot, output_dir)
+ outimage_path = outimage_path.split('/')[-1]
+ # outimage_path = "robot_data/output/1714032527749_fake.png"
+ # print(data)
+ generate_simple_order_line("pretrain_rough_sketches", outimage_path, 1)
+ prx = outimage_path.split('.')[0]
+ out_png_image = os.path.join("robot_data/sampling/rough_sketches__pretrain_rough_sketches/", f"{prx}_0_pred.png")
+ seq_data_file = os.path.join("robot_data/sampling/rough_sketches__pretrain_rough_sketches/seq_data/", f"{prx}_0.npz")
+ # strokes_data, init_cursors, image_size, _, _ = decode_npz_file(seq_data_file)
+ contours_list = draw_tools.getContourList_v2(seq_data_file)
+ cv2.imshow("origin contours", draw_tools.drawContours(contours_list, (512, 512,3)))
+ contours_list = draw_tools.sortContoursList(contours_list)
+ cv2.imshow("sorted contours", draw_tools.drawContours(contours_list, (512, 512,3)))
+ # Tuning values, as described by the original author's note: 4 is the
+ # neighbourhood expansion (controls line sparsity), 0.7 is the overlap or
+ # nearness ratio above which a contour is removed, and 10 is the shortest
+ # contour length that is kept. Exact semantics live in draw_tools.
+ contours_list = draw_tools.remove_overlap_and_near_contours(contours_list, (512, 512), 4, 0.7, 10)
+ cv2.imshow("remove overlap contours", draw_tools.drawContours(contours_list, (512, 512,3)))
+ # Smoothing and sampling
+ #contours_lists = draw_tools.sample_and_smooth_contours(contours_list, 10)
+ # simple_image = cv2.imread(out_png_image)
+ # contours_list = draw_tools.getContourList(simple_image, 4, 100, 1)
+ draw_tools.save_contour_points(contours_list, f"robot_data/contour_points/{prx}_contour_points.txt")
+ # Keep the preview windows open until a key is pressed.
+ cv2.waitKey(0)
+ # return "OK"/
+ # print("image_name:", image_name)
+ #image_name = "robot_data/input/1714032527749.jpg"
+ # # Generate the style image
+ # import uvicorn
+ # default_bind_host = "0.0.0.0"
+ # uvicorn.run(app, host=default_bind_host, port=8002)
+
+
diff --git a/hi-arm/qmupd_vs/model_common_test.py b/hi-arm/qmupd_vs/model_common_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff5c40dfdf03363b9d14515fe96e5ed9bbd15ce2
--- /dev/null
+++ b/hi-arm/qmupd_vs/model_common_test.py
@@ -0,0 +1,607 @@
+import rnn
+import tensorflow as tf
+
+from subnet_tf_utils import generative_cnn_encoder, generative_cnn_encoder_deeper, generative_cnn_encoder_deeper13, \
+ generative_cnn_c3_encoder, generative_cnn_c3_encoder_deeper, generative_cnn_c3_encoder_deeper13, \
+ generative_cnn_c3_encoder_combine33, generative_cnn_c3_encoder_combine43, \
+ generative_cnn_c3_encoder_combine53, generative_cnn_c3_encoder_combineFC, \
+ generative_cnn_c3_encoder_deeper13_attn
+
+
+class DiffPastingV3(object):
+    """Builds a TF graph that pastes a square patch canvas into a full-size image.
+
+    The inputs are the placeholders created in ``__init__``; the result is
+    exposed as ``self.pasted_image``.
+    """
+
+    def __init__(self, raster_size):
+        """Create the input placeholders and build the pasting graph.
+
+        Args:
+            raster_size: side length of the patch canvas; kept as a float in
+                ``self.raster_size_a``.
+        """
+        self.patch_canvas = tf.compat.v1.placeholder(dtype=tf.float32,
+                                                     shape=(None, None, 1))  # (raster_size, raster_size, 1), [0.0-BG, 1.0-stroke]
+        # NOTE(review): `(2)` is the int 2, not a 1-tuple; presumably `(2,)` was meant -- confirm.
+        self.cursor_pos_a = tf.compat.v1.placeholder(dtype=tf.float32, shape=(2))  # (2), float32, in large size
+        self.image_size_a = tf.compat.v1.placeholder(dtype=tf.int32, shape=())  # ()
+        self.window_size_a = tf.compat.v1.placeholder(dtype=tf.float32, shape=())  # (), float32, with grad
+        self.raster_size_a = float(raster_size)
+
+        self.pasted_image = self.image_pasting_sampling_v3()
+        # (image_size, image_size, 1), [0.0-BG, 1.0-stroke]
+
+    def image_pasting_sampling_v3(self):
+        """Resample the patch onto an integer-aligned box and pad it to image size.
+
+        Returns:
+            The pasted image tensor, sliced to (image_size_a, image_size_a, 1).
+        """
+        # Extra border, ceil(window / 2) pixels, added on every side before the final slice.
+        padding_size = tf.cast(tf.ceil(self.window_size_a / 2.0), tf.int32)
+
+        # Corners of the window centred on the cursor, in full-image coordinates.
+        x1y1_a = self.cursor_pos_a - self.window_size_a / 2.0  # (2), float32
+        x2y2_a = self.cursor_pos_a + self.window_size_a / 2.0  # (2), float32
+
+        # Snap the window outwards to integer pixel bounds.
+        x1y1_a_floor = tf.floor(x1y1_a)  # (2)
+        x2y2_a_ceil = tf.ceil(x2y2_a)  # (2)
+
+        # Centre and size of the snapped box, re-expressed in patch (raster) coordinates.
+        cursor_pos_b_oricoord = (x1y1_a_floor + x2y2_a_ceil) / 2.0  # (2)
+        cursor_pos_b = (cursor_pos_b_oricoord - x1y1_a) / self.window_size_a * self.raster_size_a  # (2)
+        raster_size_b = (x2y2_a_ceil - x1y1_a_floor)  # (x, y)
+        image_size_b = self.raster_size_a
+        window_size_b = self.raster_size_a * (raster_size_b / self.window_size_a)  # (x, y)
+
+        cursor_b_x, cursor_b_y = tf.split(cursor_pos_b, 2, axis=-1)  # (1)
+
+        # Box corners in [y1, x1, y2, x2] order, normalised by (image_size_b - 1) for crop_and_resize.
+        y1_b = cursor_b_y - (window_size_b[1] - 1.) / 2.
+        x1_b = cursor_b_x - (window_size_b[0] - 1.) / 2.
+        y2_b = y1_b + (window_size_b[1] - 1.)
+        x2_b = x1_b + (window_size_b[0] - 1.)
+        boxes_b = tf.concat([y1_b, x1_b, y2_b, x2_b], axis=-1)  # (4)
+        boxes_b = boxes_b / tf.cast(image_size_b - 1, tf.float32)  # with grad to window_size_a
+
+        # cumsum(ones(1)) - 1 == [0]: the single box refers to batch entry 0.
+        box_ind_b = tf.ones((1), dtype=tf.int32)  # (1)
+        box_ind_b = tf.cumsum(box_ind_b) - 1
+
+        patch_canvas = tf.expand_dims(self.patch_canvas,
+                                      axis=0)  # (1, raster_size, raster_size, 1), [0.0-BG, 1.0-stroke]
+        boxes_b = tf.expand_dims(boxes_b, axis=0)  # (1, 4)
+
+        # crop_size is given as [height, width], i.e. the (y, x) extent of the snapped box.
+        valid_canvas = tf.image.crop_and_resize(patch_canvas, boxes_b, box_ind_b,
+                                                crop_size=[raster_size_b[1], raster_size_b[0]])
+        valid_canvas = valid_canvas[0]  # (raster_size_b, raster_size_b, 1)
+
+        # Pad amounts are offset by padding_size -- presumably so they stay non-negative
+        # when the window overhangs the image border (confirm against callers).
+        pad_up = tf.cast(x1y1_a_floor[1], tf.int32) + padding_size
+        pad_down = self.image_size_a + padding_size - tf.cast(x2y2_a_ceil[1], tf.int32)
+        pad_left = tf.cast(x1y1_a_floor[0], tf.int32) + padding_size
+        pad_right = self.image_size_a + padding_size - tf.cast(x2y2_a_ceil[0], tf.int32)
+
+        paddings = [[pad_up, pad_down],
+                    [pad_left, pad_right],
+                    [0, 0]]
+        pad_img = tf.pad(valid_canvas, paddings=paddings, mode='CONSTANT',
+                         constant_values=0.0)  # (H_p, W_p, 1), [0.0-BG, 1.0-stroke]
+
+        # Drop the padding_size border again, leaving an image_size_a x image_size_a result.
+        pasted_image = pad_img[padding_size: padding_size + self.image_size_a,
+                               padding_size: padding_size + self.image_size_a, :]
+        # (image_size, image_size, 1), [0.0-BG, 1.0-stroke]
+        return pasted_image
+
+
+class VirtualSketchingModel(object):
+ def __init__(self, hps, gpu_mode=True, reuse=False):
+ """Initializer for the model.
+
+ Args:
+ hps: a HParams object containing model hyperparameters
+ gpu_mode: a boolean that when True, uses GPU mode.
+ reuse: a boolean that when true, attemps to reuse variables.
+ """
+ self.hps = hps
+ assert hps.model_mode in ['train', 'eval', 'eval_sample', 'sample']
+ # with tf.variable_scope('SCC', reuse=reuse):
+ if not gpu_mode:
+ with tf.device('/cpu:0'):
+ print('Model using cpu.')
+ self.build_model()
+ else:
+ print('-' * 100)
+ print('model_mode:', hps.model_mode)
+ print('Model using gpu.')
+ self.build_model()
+
+ def build_model(self):
+ """Define model architecture."""
+ self.config_model()
+
+ initial_state = self.get_decoder_inputs()
+ self.initial_state = initial_state
+
+ ## use pred as the prev points
+ print(self.image_size)
+ other_params, pen_ras, final_state = self.get_points_and_raster_image(self.image_size)
+
+ # other_params: (N * max_seq_len, 6)
+ # pen_ras: (N * max_seq_len, 2), after softmax
+
+ self.other_params = other_params # (N * max_seq_len, 6)
+ self.pen_ras = pen_ras # (N * max_seq_len, 2), after softmax
+ self.final_state = final_state
+
+ if not self.hps.use_softargmax:
+ pen_state_soft = pen_ras[:, 1:2] # (N * max_seq_len, 1)
+ else:
+ pen_state_soft = self.differentiable_argmax(pen_ras, self.hps.soft_beta) # (N * max_seq_len, 1)
+
+ pred_params = tf.concat([pen_state_soft, other_params], axis=1) # (N * max_seq_len, 7)
+ self.pred_params = tf.reshape(pred_params, shape=[-1, self.hps.max_seq_len, 7]) # (N, max_seq_len, 7)
+ # pred_params: (N, max_seq_len, 7)
+
+ def config_model(self):
+ if self.hps.model_mode == 'train':
+ self.global_step = tf.Variable(0, name='global_step', trainable=False)
+
+ if self.hps.dec_model == 'lstm':
+ dec_cell_fn = rnn.LSTMCell
+ elif self.hps.dec_model == 'layer_norm':
+ dec_cell_fn = rnn.LayerNormLSTMCell
+ elif self.hps.dec_model == 'hyper':
+ dec_cell_fn = rnn.HyperLSTMCell
+ else:
+ assert False, 'please choose a respectable cell'
+
+ use_recurrent_dropout = self.hps.use_recurrent_dropout
+ use_input_dropout = self.hps.use_input_dropout
+ use_output_dropout = self.hps.use_output_dropout
+
+ dec_cell = dec_cell_fn(
+ self.hps.dec_rnn_size,
+ use_recurrent_dropout=use_recurrent_dropout,
+ dropout_keep_prob=self.hps.recurrent_dropout_prob)
+
+ # dropout:
+ # print('Input dropout mode = %s.' % use_input_dropout)
+ # print('Output dropout mode = %s.' % use_output_dropout)
+ # print('Recurrent dropout mode = %s.' % use_recurrent_dropout)
+ if use_input_dropout:
+ print('Dropout to input w/ keep_prob = %4.4f.' % self.hps.input_dropout_prob)
+ dec_cell = tf.contrib.rnn.DropoutWrapper(
+ dec_cell, input_keep_prob=self.hps.input_dropout_prob)
+ if use_output_dropout:
+ print('Dropout to output w/ keep_prob = %4.4f.' % self.hps.output_dropout_prob)
+ dec_cell = tf.contrib.rnn.DropoutWrapper(
+ dec_cell, output_keep_prob=self.hps.output_dropout_prob)
+ self.dec_cell = dec_cell
+
+ self.input_photo = tf.compat.v1.placeholder(dtype=tf.float32,
+ shape=[self.hps.batch_size, None, None, self.hps.input_channel]) # [0.0-stroke, 1.0-BG]
+ self.init_cursor = tf.compat.v1.placeholder(
+ dtype=tf.float32,
+ shape=[self.hps.batch_size, 1, 2]) # (N, 1, 2), in size [0.0, 1.0)
+ self.init_width = tf.compat.v1.placeholder(
+ dtype=tf.float32,
+ shape=[self.hps.batch_size]) # (1), in [0.0, 1.0]
+ self.init_scaling = tf.compat.v1.placeholder(
+ dtype=tf.float32,
+ shape=[self.hps.batch_size]) # (N), in [0.0, 1.0]
+ self.init_window_size = tf.compat.v1.placeholder(
+ dtype=tf.float32,
+ shape=[self.hps.batch_size]) # (N)
+ self.image_size = tf.compat.v1.placeholder(dtype=tf.int32, shape=()) # ()
+
+ ###########################
+
+ def normalize_image_m1to1(self, in_img_0to1):
+ norm_img_m1to1 = tf.multiply(in_img_0to1, 2.0)
+ norm_img_m1to1 = tf.subtract(norm_img_m1to1, 1.0)
+ return norm_img_m1to1
+
+ def add_coords(self, input_tensor):
+ batch_size_tensor = tf.shape(input_tensor)[0] # get N size
+
+ xx_ones = tf.ones([batch_size_tensor, self.hps.raster_size], dtype=tf.int32) # e.g. (N, raster_size)
+ xx_ones = tf.expand_dims(xx_ones, -1) # e.g. (N, raster_size, 1)
+ xx_range = tf.tile(tf.expand_dims(tf.range(self.hps.raster_size), 0),
+ [batch_size_tensor, 1]) # e.g. (N, raster_size)
+ xx_range = tf.expand_dims(xx_range, 1) # e.g. (N, 1, raster_size)
+
+ xx_channel = tf.matmul(xx_ones, xx_range) # e.g. (N, raster_size, raster_size)
+ xx_channel = tf.expand_dims(xx_channel, -1) # e.g. (N, raster_size, raster_size, 1)
+
+ yy_ones = tf.ones([batch_size_tensor, self.hps.raster_size], dtype=tf.int32) # e.g. (N, raster_size)
+ yy_ones = tf.expand_dims(yy_ones, 1) # e.g. (N, 1, raster_size)
+ yy_range = tf.tile(tf.expand_dims(tf.range(self.hps.raster_size), 0),
+ [batch_size_tensor, 1]) # (N, raster_size)
+ yy_range = tf.expand_dims(yy_range, -1) # e.g. (N, raster_size, 1)
+
+ yy_channel = tf.matmul(yy_range, yy_ones) # e.g. (N, raster_size, raster_size)
+ yy_channel = tf.expand_dims(yy_channel, -1) # e.g. (N, raster_size, raster_size, 1)
+
+ xx_channel = tf.cast(xx_channel, 'float32') / (self.hps.raster_size - 1)
+ yy_channel = tf.cast(yy_channel, 'float32') / (self.hps.raster_size - 1)
+ # xx_channel = xx_channel * 2 - 1 # [-1, 1]
+ # yy_channel = yy_channel * 2 - 1
+
+ ret = tf.concat([
+ input_tensor,
+ xx_channel,
+ yy_channel,
+ ], axis=-1) # e.g. (N, raster_size, raster_size, 4)
+
+ return ret
+
+ def build_combined_encoder(self, patch_canvas, patch_photo, entire_canvas, entire_photo, cursor_pos,
+ image_size, window_size):
+ """
+ :param patch_canvas: (N, raster_size, raster_size, 1), [-1.0-stroke, 1.0-BG]
+ :param patch_photo: (N, raster_size, raster_size, 1/3), [-1.0-stroke, 1.0-BG]
+ :param entire_canvas: (N, image_size, image_size, 1), [0.0-stroke, 1.0-BG]
+ :param entire_photo: (N, image_size, image_size, 1/3), [0.0-stroke, 1.0-BG]
+ :param cursor_pos: (N, 1, 2), in size [0.0, 1.0)
+ :param window_size: (N, 1, 1), float, in large size
+ :return:
+ """
+ if self.hps.resize_method == 'BILINEAR':
+ resize_method = tf.image.ResizeMethod.BILINEAR
+ elif self.hps.resize_method == 'NEAREST_NEIGHBOR':
+ resize_method = tf.image.ResizeMethod.NEAREST_NEIGHBOR
+ elif self.hps.resize_method == 'BICUBIC':
+ resize_method = tf.image.ResizeMethod.BICUBIC
+ elif self.hps.resize_method == 'AREA':
+ resize_method = tf.image.ResizeMethod.AREA
+ else:
+ raise Exception('unknown resize_method', self.hps.resize_method)
+
+ patch_photo = tf.stop_gradient(patch_photo)
+ patch_canvas = tf.stop_gradient(patch_canvas)
+ cursor_pos = tf.stop_gradient(cursor_pos)
+ window_size = tf.stop_gradient(window_size)
+
+ entire_photo_small = tf.stop_gradient(tf.image.resize_images(entire_photo,
+ (self.hps.raster_size, self.hps.raster_size),
+ method=resize_method))
+ entire_canvas_small = tf.stop_gradient(tf.image.resize_images(entire_canvas,
+ (self.hps.raster_size, self.hps.raster_size),
+ method=resize_method))
+ entire_photo_small = self.normalize_image_m1to1(entire_photo_small) # [-1.0-stroke, 1.0-BG]
+ entire_canvas_small = self.normalize_image_m1to1(entire_canvas_small) # [-1.0-stroke, 1.0-BG]
+
+ if self.hps.encode_cursor_type == 'value':
+ cursor_pos_norm = tf.expand_dims(cursor_pos, axis=1) # (N, 1, 1, 2)
+ cursor_pos_norm = tf.tile(cursor_pos_norm, [1, self.hps.raster_size, self.hps.raster_size, 1])
+ cursor_info = cursor_pos_norm
+ else:
+ raise Exception('Unknown encode_cursor_type', self.hps.encode_cursor_type)
+
+ batch_input_combined = tf.concat([patch_photo, patch_canvas, entire_photo_small, entire_canvas_small, cursor_info],
+ axis=-1) # [N, raster_size, raster_size, 6/10]
+ batch_input_local = tf.concat([patch_photo, patch_canvas], axis=-1) # [N, raster_size, raster_size, 2/4]
+ batch_input_global = tf.concat([entire_photo_small, entire_canvas_small, cursor_info],
+ axis=-1) # [N, raster_size, raster_size, 4/6]
+
+ if self.hps.model_mode == 'train':
+ is_training = True
+ dropout_keep_prob = self.hps.pix_drop_kp
+ else:
+ is_training = False
+ dropout_keep_prob = 1.0
+
+ if self.hps.add_coordconv:
+ batch_input_combined = self.add_coords(batch_input_combined) # (N, in_H, in_W, in_dim + 2)
+ batch_input_local = self.add_coords(batch_input_local) # (N, in_H, in_W, in_dim + 2)
+ batch_input_global = self.add_coords(batch_input_global) # (N, in_H, in_W, in_dim + 2)
+
+ if 'combine' in self.hps.encoder_type:
+ if self.hps.encoder_type == 'combine33':
+ image_embedding, _ = generative_cnn_c3_encoder_combine33(batch_input_local, batch_input_global,
+ is_training, dropout_keep_prob) # (N, 128)
+ elif self.hps.encoder_type == 'combine43':
+ image_embedding, _ = generative_cnn_c3_encoder_combine43(batch_input_local, batch_input_global,
+ is_training, dropout_keep_prob) # (N, 128)
+ elif self.hps.encoder_type == 'combine53':
+ image_embedding, _ = generative_cnn_c3_encoder_combine53(batch_input_local, batch_input_global,
+ is_training, dropout_keep_prob) # (N, 128)
+ elif self.hps.encoder_type == 'combineFC':
+ image_embedding, _ = generative_cnn_c3_encoder_combineFC(batch_input_local, batch_input_global,
+ is_training, dropout_keep_prob) # (N, 256)
+ else:
+ raise Exception('Unknown encoder_type', self.hps.encoder_type)
+ else:
+ with tf.variable_scope('Combined_Encoder', reuse=tf.AUTO_REUSE):
+ if self.hps.encoder_type == 'conv10':
+ image_embedding, _ = generative_cnn_encoder(batch_input_combined, is_training, dropout_keep_prob) # (N, 128)
+ elif self.hps.encoder_type == 'conv10_deep':
+ image_embedding, _ = generative_cnn_encoder_deeper(batch_input_combined, is_training, dropout_keep_prob) # (N, 512)
+ elif self.hps.encoder_type == 'conv13':
+ image_embedding, _ = generative_cnn_encoder_deeper13(batch_input_combined, is_training, dropout_keep_prob) # (N, 128)
+ elif self.hps.encoder_type == 'conv10_c3':
+ image_embedding, _ = generative_cnn_c3_encoder(batch_input_combined, is_training, dropout_keep_prob) # (N, 128)
+ elif self.hps.encoder_type == 'conv10_deep_c3':
+ image_embedding, _ = generative_cnn_c3_encoder_deeper(batch_input_combined, is_training, dropout_keep_prob) # (N, 512)
+ elif self.hps.encoder_type == 'conv13_c3':
+ image_embedding, _ = generative_cnn_c3_encoder_deeper13(batch_input_combined, is_training, dropout_keep_prob) # (N, 128)
+ elif self.hps.encoder_type == 'conv13_c3_attn':
+ image_embedding, _ = generative_cnn_c3_encoder_deeper13_attn(batch_input_combined, is_training, dropout_keep_prob) # (N, 128)
+ else:
+ raise Exception('Unknown encoder_type', self.hps.encoder_type)
+ return image_embedding
+
+ def build_seq_decoder(self, dec_cell, actual_input_x, initial_state):
+ rnn_output, last_state = self.rnn_decoder(dec_cell, initial_state, actual_input_x)
+ rnn_output_flat = tf.reshape(rnn_output, [-1, self.hps.dec_rnn_size])
+
+ pen_n_out = 2
+ params_n_out = 6
+
+ with tf.variable_scope('DEC_RNN_out_pen', reuse=tf.AUTO_REUSE):
+ output_w_pen = tf.get_variable('output_w', [self.hps.dec_rnn_size, pen_n_out])
+ output_b_pen = tf.get_variable('output_b', [pen_n_out], initializer=tf.constant_initializer(0.0))
+ output_pen = tf.nn.xw_plus_b(rnn_output_flat, output_w_pen, output_b_pen) # (N, pen_n_out)
+
+ with tf.variable_scope('DEC_RNN_out_params', reuse=tf.AUTO_REUSE):
+ output_w_params = tf.get_variable('output_w', [self.hps.dec_rnn_size, params_n_out])
+ output_b_params = tf.get_variable('output_b', [params_n_out], initializer=tf.constant_initializer(0.0))
+ output_params = tf.nn.xw_plus_b(rnn_output_flat, output_w_params, output_b_params) # (N, params_n_out)
+
+ output = tf.concat([output_pen, output_params], axis=1) # (N, n_out)
+
+ return output, last_state
+
+ def get_mixture_coef(self, outputs):
+ z = outputs
+ z_pen_logits = z[:, 0:2] # (N, 2), pen states
+ z_other_params_logits = z[:, 2:] # (N, 6)
+
+ z_pen = tf.nn.softmax(z_pen_logits) # (N, 2)
+ if self.hps.position_format == 'abs':
+ x1y1 = tf.nn.sigmoid(z_other_params_logits[:, 0:2]) # (N, 2)
+ x2y2 = tf.tanh(z_other_params_logits[:, 2:4]) # (N, 2)
+ widths = tf.nn.sigmoid(z_other_params_logits[:, 4:5]) # (N, 1)
+ widths = tf.add(tf.multiply(widths, 1.0 - self.hps.min_width), self.hps.min_width)
+ scaling = tf.nn.sigmoid(z_other_params_logits[:, 5:6]) * self.hps.max_scaling # (N, 1), [0.0, max_scaling]
+ # scaling = tf.add(tf.multiply(scaling, (self.hps.max_scaling - self.hps.min_scaling) / self.hps.max_scaling),
+ # self.hps.min_scaling)
+ z_other_params = tf.concat([x1y1, x2y2, widths, scaling], axis=-1) # (N, 6)
+ else: # "rel"
+ raise Exception('Unknown position_format', self.hps.position_format)
+
+ r = [z_other_params, z_pen]
+ return r
+
+ ###########################
+
+ def get_decoder_inputs(self):
+ initial_state = self.dec_cell.zero_state(batch_size=self.hps.batch_size, dtype=tf.float32)
+ return initial_state
+
+ def rnn_decoder(self, dec_cell, initial_state, actual_input_x):
+ with tf.variable_scope("RNN_DEC", reuse=tf.AUTO_REUSE):
+ output, last_state = tf.nn.dynamic_rnn(
+ dec_cell,
+ actual_input_x,
+ initial_state=initial_state,
+ time_major=False,
+ swap_memory=True,
+ dtype=tf.float32)
+ return output, last_state
+
+ ###########################
+
+ def image_padding(self, ori_image, window_size, pad_value):
+ """
+ Pad with (bg)
+ :param ori_image:
+ :return:
+ """
+ paddings = [[0, 0],
+ [window_size // 2, window_size // 2],
+ [window_size // 2, window_size // 2],
+ [0, 0]]
+ pad_img = tf.pad(ori_image, paddings=paddings, mode='CONSTANT', constant_values=pad_value) # (N, H_p, W_p, k)
+ return pad_img
+
+ def image_cropping_fn(self, fn_inputs):
+ """
+ crop the patch
+ :return:
+ """
+ index_offset = self.hps.input_channel - 1
+ input_image = fn_inputs[:, :, 0:2 + index_offset] # (image_size, image_size, -), [0.0-BG, 1.0-stroke]
+ cursor_pos = fn_inputs[0, 0, 2 + index_offset:4 + index_offset] # (2), in [0.0, 1.0)
+ image_size = fn_inputs[0, 0, 4 + index_offset] # (), float32
+ window_size = tf.cast(fn_inputs[0, 0, 5 + index_offset], tf.int32) # ()
+
+ input_img_reshape = tf.expand_dims(input_image, axis=0)
+ pad_img = self.image_padding(input_img_reshape, window_size, pad_value=0.0)
+
+ cursor_pos = tf.cast(tf.round(tf.multiply(cursor_pos, image_size)), dtype=tf.int32)
+ x0, x1 = cursor_pos[0], cursor_pos[0] + window_size # ()
+ y0, y1 = cursor_pos[1], cursor_pos[1] + window_size # ()
+ patch_image = pad_img[:, y0:y1, x0:x1, :] # (1, window_size, window_size, 2/4)
+
+ # resize to raster_size
+ patch_image_scaled = tf.image.resize_images(patch_image, (self.hps.raster_size, self.hps.raster_size),
+ method=tf.image.ResizeMethod.AREA)
+ patch_image_scaled = tf.squeeze(patch_image_scaled, axis=0)
+ # patch_canvas_scaled: (raster_size, raster_size, 2/4), [0.0-BG, 1.0-stroke]
+
+ return patch_image_scaled
+
+ def image_cropping(self, cursor_position, input_img, image_size, window_sizes):
+ """
+ :param cursor_position: (N, 1, 2), float type, in size [0.0, 1.0)
+ :param input_img: (N, image_size, image_size, 2/4), [0.0-BG, 1.0-stroke]
+ :param window_sizes: (N, 1, 1), float32, with grad
+ """
+ input_img_ = input_img
+ window_sizes_non_grad = tf.stop_gradient(tf.round(window_sizes)) # (N, 1, 1), no grad
+
+ cursor_position_ = tf.reshape(cursor_position, (-1, 1, 1, 2)) # (N, 1, 1, 2)
+ cursor_position_ = tf.tile(cursor_position_, [1, image_size, image_size, 1]) # (N, image_size, image_size, 2)
+
+ image_size_ = tf.reshape(tf.cast(image_size, tf.float32), (1, 1, 1, 1)) # (1, 1, 1, 1)
+ image_size_ = tf.tile(image_size_, [self.hps.batch_size, image_size, image_size, 1])
+
+ window_sizes_ = tf.reshape(window_sizes_non_grad, (-1, 1, 1, 1)) # (N, 1, 1, 1)
+ window_sizes_ = tf.tile(window_sizes_, [1, image_size, image_size, 1]) # (N, image_size, image_size, 1)
+
+ fn_inputs = tf.concat([input_img_, cursor_position_, image_size_, window_sizes_],
+ axis=-1) # (N, image_size, image_size, 2/4 + 4)
+ curr_patch_imgs = tf.map_fn(self.image_cropping_fn, fn_inputs, parallel_iterations=32) # (N, raster_size, raster_size, -)
+ return curr_patch_imgs
+
+ def image_cropping_v3(self, cursor_position, input_img, image_size, window_sizes):
+ """
+ :param cursor_position: (N, 1, 2), float type, in size [0.0, 1.0)
+ :param input_img: (N, image_size, image_size, k), [0.0-BG, 1.0-stroke]
+ :param window_sizes: (N, 1, 1), float32, with grad
+ """
+ window_sizes_non_grad = tf.stop_gradient(window_sizes) # (N, 1, 1), no grad
+
+ cursor_pos = tf.multiply(cursor_position, tf.cast(image_size, tf.float32))
+ print(cursor_pos)
+ cursor_x, cursor_y = tf.split(cursor_pos, 2, axis=-1) # (N, 1, 1)
+
+ y1 = cursor_y - (window_sizes_non_grad - 1.0) / 2
+ x1 = cursor_x - (window_sizes_non_grad - 1.0) / 2
+ y2 = y1 + (window_sizes_non_grad - 1.0)
+ x2 = x1 + (window_sizes_non_grad - 1.0)
+ boxes = tf.concat([y1, x1, y2, x2], axis=-1) # (N, 1, 4)
+ boxes = tf.squeeze(boxes, axis=1) # (N, 4)
+ boxes = boxes / tf.cast(image_size - 1, tf.float32)
+
+ box_ind = tf.ones_like(cursor_x)[:, 0, 0] # (N)
+ box_ind = tf.cast(box_ind, dtype=tf.int32)
+ box_ind = tf.cumsum(box_ind) - 1
+
+ curr_patch_imgs = tf.image.crop_and_resize(input_img, boxes, box_ind,
+ crop_size=[self.hps.raster_size, self.hps.raster_size])
+ # (N, raster_size, raster_size, k), [0.0-BG, 1.0-stroke]
+ return curr_patch_imgs
+
+ def get_points_and_raster_image(self, image_size):
+     """Unroll the decoder for ``hps.max_seq_len`` steps and collect its outputs.
+
+     Each step crops a window around the cursor, encodes it together with
+     the whole image, runs one decoder step and stores the stroke
+     parameters and pen probabilities.
+
+     NOTE: inside the loop only ``prev_state`` is reassigned; the cursor,
+     width, scaling, window size and canvas keep their initial values for
+     every step.
+
+     Args:
+         image_size: scalar image side length.
+
+     Returns:
+         (other_params_, pen_ras_, prev_state): stroke params
+         (N * max_seq_len, 6), pen probabilities (N * max_seq_len, 2) and
+         the decoder state after the last step.
+
+     Raises:
+         Exception: if ``hps.cropping_type`` is not 'v3'.
+     """
+     ## generate the other_params and pen_ras and raster image for raster loss
+     prev_state = self.initial_state  # (N, dec_rnn_size * 3)
+
+     prev_width = self.init_width  # (N)
+     prev_width = tf.expand_dims(tf.expand_dims(prev_width, axis=-1), axis=-1)  # (N, 1, 1)
+
+     prev_scaling = self.init_scaling  # (N)
+     prev_scaling = tf.reshape(prev_scaling, (-1, 1, 1))  # (N, 1, 1)
+
+     prev_window_size = self.init_window_size  # (N)
+     prev_window_size = tf.reshape(prev_window_size, (-1, 1, 1))  # (N, 1, 1)
+
+     cursor_position_temp = self.init_cursor
+     self.cursor_position = cursor_position_temp  # (N, 1, 2), in size [0.0, 1.0)
+     cursor_position_loop = self.cursor_position
+
+     other_params_list = []
+     pen_ras_list = []
+
+     # The canvas starts empty (all zeros), shaped like one channel of the input photo.
+     curr_canvas_soft = tf.zeros_like(self.input_photo[:, :, :, 0])  # (N, image_size, image_size), [0.0-BG, 1.0-stroke]
+     curr_canvas_hard = tf.zeros_like(curr_canvas_soft)  # [0.0-BG, 1.0-stroke]
+
+     #### sampling part - start ####
+     self.curr_canvas_hard = curr_canvas_hard
+
+     # Only the 'v3' cropping is enabled; the older image_cropping path is commented out.
+     if self.hps.cropping_type == 'v3':
+         cropping_func = self.image_cropping_v3
+     # elif self.hps.cropping_type == 'v2':
+     #     cropping_func = self.image_cropping
+     else:
+         raise Exception('Unknown cropping_type', self.hps.cropping_type)
+
+     for time_i in range(self.hps.max_seq_len):
+         cursor_position_non_grad = tf.stop_gradient(cursor_position_loop)  # (N, 1, 2), in size [0.0, 1.0)
+
+         # Window size = scaling * previous window size, clamped to [min_window_size, image_size].
+         curr_window_size = tf.multiply(prev_scaling, tf.stop_gradient(prev_window_size))  # float, with grad
+         curr_window_size = tf.maximum(curr_window_size, tf.cast(self.hps.min_window_size, tf.float32))
+         curr_window_size = tf.minimum(curr_window_size, tf.cast(image_size, tf.float32))
+
+         ## patch-level encoding
+         # Here, we make the gradients from canvas_z to curr_canvas_hard be None to avoid recurrent gradient propagation.
+         curr_canvas_hard_non_grad = tf.stop_gradient(self.curr_canvas_hard)
+         curr_canvas_hard_non_grad = tf.expand_dims(curr_canvas_hard_non_grad, axis=-1)
+
+         # input_photo: (N, image_size, image_size, 1/3), [0.0-stroke, 1.0-BG]
+         # The photo is inverted to [0.0-BG, 1.0-stroke] so it matches the canvas before cropping.
+         crop_inputs = tf.concat([1.0 - self.input_photo, curr_canvas_hard_non_grad], axis=-1)  # (N, H_p, W_p, 1+1)
+
+         cropped_outputs = cropping_func(cursor_position_non_grad, crop_inputs, image_size, curr_window_size)
+         # Split the cropped stack back into photo channels and the single canvas channel.
+         index_offset = self.hps.input_channel - 1
+         curr_patch_inputs = cropped_outputs[:, :, :, 0:1 + index_offset]  # [0.0-BG, 1.0-stroke]
+         curr_patch_canvas_hard_non_grad = cropped_outputs[:, :, :, 1 + index_offset:2 + index_offset]
+         # (N, raster_size, raster_size, 1/3), [0.0-BG, 1.0-stroke]
+
+         curr_patch_inputs = 1.0 - curr_patch_inputs  # [0.0-stroke, 1.0-BG]
+         curr_patch_inputs = self.normalize_image_m1to1(curr_patch_inputs)
+         # (N, raster_size, raster_size, 1/3), [-1.0-stroke, 1.0-BG]
+
+         # Normalizing image
+         curr_patch_canvas_hard_non_grad = 1.0 - curr_patch_canvas_hard_non_grad  # [0.0-stroke, 1.0-BG]
+         curr_patch_canvas_hard_non_grad = self.normalize_image_m1to1(curr_patch_canvas_hard_non_grad)  # [-1.0-stroke, 1.0-BG]
+
+         ## image-level encoding
+         combined_z = self.build_combined_encoder(
+             curr_patch_canvas_hard_non_grad,
+             curr_patch_inputs,
+             1.0 - curr_canvas_hard_non_grad,
+             self.input_photo,
+             cursor_position_non_grad,
+             image_size,
+             curr_window_size)  # (N, z_size)
+         combined_z = tf.expand_dims(combined_z, axis=1)  # (N, 1, z_size)
+
+         # Window size relative to the image size and to the minimum window size (no grad).
+         curr_window_size_top_side_norm_non_grad = \
+             tf.stop_gradient(curr_window_size / tf.cast(image_size, tf.float32))
+         curr_window_size_bottom_side_norm_non_grad = \
+             tf.stop_gradient(curr_window_size / tf.cast(self.hps.min_window_size, tf.float32))
+         if not self.hps.concat_win_size:
+             combined_z = tf.concat([tf.stop_gradient(prev_width), combined_z], 2)  # (N, 1, 1+z_size)
+         else:
+             combined_z = tf.concat([tf.stop_gradient(prev_width),
+                                     curr_window_size_top_side_norm_non_grad,
+                                     curr_window_size_bottom_side_norm_non_grad,
+                                     combined_z],
+                                    2)  # (N, 1, 3+z_size)
+
+         if self.hps.concat_cursor:
+             prev_input_x = tf.concat([cursor_position_non_grad, combined_z], 2)  # cursor (2) prepended to combined_z
+         else:
+             prev_input_x = combined_z  # same shape as combined_z
+
+         h_output, next_state = self.build_seq_decoder(self.dec_cell, prev_input_x, prev_state)
+         # h_output: (N * 1, n_out), next_state: (N, dec_rnn_size * 3)
+         [o_other_params, o_pen_ras] = self.get_mixture_coef(h_output)
+         # o_other_params: (N * 1, 6)
+         # o_pen_ras: (N * 1, 2), after softmax
+
+         o_other_params = tf.reshape(o_other_params, [-1, 1, 6])  # (N, 1, 6)
+         o_pen_ras_raw = tf.reshape(o_pen_ras, [-1, 1, 2])  # (N, 1, 2)
+
+         other_params_list.append(o_other_params)
+         pen_ras_list.append(o_pen_ras_raw)
+
+         #### sampling part - end ####
+
+         # Only the recurrent state is carried over to the next step.
+         prev_state = next_state
+
+     # Stack the per-step outputs along the time axis, then flatten batch and time together.
+     other_params_ = tf.reshape(tf.concat(other_params_list, axis=1), [-1, 6])  # (N * max_seq_len, 6)
+     pen_ras_ = tf.reshape(tf.concat(pen_ras_list, axis=1), [-1, 2])  # (N * max_seq_len, 2)
+
+     return other_params_, pen_ras_, prev_state
+
+ def differentiable_argmax(self, input_pen, soft_beta):
+ """
+ Differentiable argmax trick.
+ :param input_pen: (N, n_class)
+ :return: pen_state: (N, 1)
+ """
+ def sign_onehot(x):
+ """
+ :param x: (N, n_class)
+ :return: (N, n_class)
+ """
+ y = tf.sign(tf.reduce_max(x, axis=-1, keepdims=True) - x)
+ y = (y - 1) * (-1)
+ return y
+
+ def softargmax(x, beta=1e2):
+ """
+ :param x: (N, n_class)
+ :param beta: 1e10 is the best. 1e2 is acceptable.
+ :return: (N)
+ """
+ x_range = tf.cumsum(tf.ones_like(x), axis=1) # (N, 2)
+ return tf.reduce_sum(tf.nn.softmax(x * beta) * x_range, axis=1) - 1
+
+ ## Better to use softargmax(beta=1e2). The sign_onehot's gradient is close to zero.
+ # pen_onehot = sign_onehot(input_pen) # one-hot form, (N * max_seq_len, 2)
+ # pen_state = pen_onehot[:, 1:2] # (N * max_seq_len, 1)
+ pen_state = softargmax(input_pen, soft_beta)
+ pen_state = tf.expand_dims(pen_state, axis=1) # (N * max_seq_len, 1)
+ return pen_state
diff --git a/hi-arm/qmupd_vs/model_common_train.py b/hi-arm/qmupd_vs/model_common_train.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7c22b33f45a9cbd7e69c878866cb2fb6dd81f7c
--- /dev/null
+++ b/hi-arm/qmupd_vs/model_common_train.py
@@ -0,0 +1,1193 @@
+import rnn
+import tensorflow as tf
+
+from subnet_tf_utils import generative_cnn_encoder, generative_cnn_encoder_deeper, generative_cnn_encoder_deeper13, \
+ generative_cnn_c3_encoder, generative_cnn_c3_encoder_deeper, generative_cnn_c3_encoder_deeper13, \
+ generative_cnn_c3_encoder_combine33, generative_cnn_c3_encoder_combine43, \
+ generative_cnn_c3_encoder_combine53, generative_cnn_c3_encoder_combineFC, \
+ generative_cnn_c3_encoder_deeper13_attn
+from rasterization_utils.NeuralRenderer import NeuralRasterizorStep
+from vgg_utils.VGG16 import vgg_net_slim
+
+
+class VirtualSketchingModel(object):
def __init__(self, hps, gpu_mode=True, reuse=False):
    """Store the hyperparameters and build the model graph.

    Args:
      hps: a HParams object containing model hyperparameters. Its
        `model_mode` must be one of 'train', 'eval', 'eval_sample', 'sample'.
      gpu_mode: a boolean; when False the graph is built inside a
        '/cpu:0' device scope, when True no device scope is applied here.
      reuse: a boolean; not read by this method.
    """
    self.hps = hps
    assert hps.model_mode in ['train', 'eval', 'eval_sample', 'sample']

    if gpu_mode:
        print('-' * 100)
        print('model_mode:', hps.model_mode)
        print('Model using gpu.')
        self.build_model()
        return

    with tf.device('/cpu:0'):
        print('Model using cpu.')
        self.build_model()
+
def build_model(self):
    """Define model architecture.

    Builds one tower per loop index in `range(self.total_loop)`; tower `t_i`
    is placed on GPU `self.hps.gpus[t_i // self.hps.loop_per_gpu]` and reuses
    the variables created by tower 0. Per-tower results are then aggregated:

    * `self.pred_raster_imgs` / `self.pred_raster_imgs_rgb`: predictions of
      all towers concatenated along axis 0.
    * `self.raster_cost`, `self.perc_relu_losses_raw`,
      `self.perc_relu_losses_norm`, `self.stroke_num_cost`,
      `self.pos_outside_cost`, `self.win_size_outside_cost`,
      `self.early_pen_states_cost`, `self.cost`: means over the towers.
      Per-tower losses are only collected in 'train'/'eval' mode and the
      total cost only in 'train' mode.

    In 'train' mode the collected per-tower gradients are finally passed to
    `self.build_training_op`.
    """
    self.config_model()

    initial_state = self.get_decoder_inputs()
    self.initial_state = initial_state
    self.initial_state_list = tf.split(self.initial_state, self.total_loop, axis=0)

    # Per-tower accumulators. They start empty, so no reset is needed on the
    # first loop iteration.
    total_loss_list = []
    ras_loss_list = []
    perc_relu_raw_list = []
    perc_relu_norm_list = []
    sn_loss_list = []
    cursor_outside_loss_list = []
    win_size_outside_loss_list = []
    early_state_loss_list = []

    tower_grads = []

    pred_raster_imgs_list = []
    pred_raster_imgs_rgb_list = []

    for t_i in range(self.total_loop):
        gpu_idx = t_i // self.hps.loop_per_gpu
        gpu_i = self.hps.gpus[gpu_idx]
        print(self.hps.model_mode, 'model, gpu:', gpu_i, ', loop:', t_i % self.hps.loop_per_gpu)
        with tf.device('/gpu:%d' % gpu_i):
            with tf.name_scope('GPU_%d' % gpu_i):
                # Every tower after the first shares the first tower's variables.
                if t_i > 0:
                    tf.get_variable_scope().reuse_variables()

                split_input_photo = self.input_photo_list[t_i]
                split_image_size = self.image_size[t_i]
                split_init_cursor = self.init_cursor_list[t_i]
                split_initial_state = self.initial_state_list[t_i]
                if self.hps.input_channel == 1:
                    # Single-channel input doubles as the target sketch.
                    split_target_sketch = split_input_photo
                else:
                    split_target_sketch = self.target_sketch_list[t_i]

                ## use pred as the prev points
                other_params, pen_ras, final_state, pred_raster_images, pred_raster_images_rgb, \
                    pos_before_max_min, win_size_before_max_min \
                    = self.get_points_and_raster_image(split_initial_state, split_init_cursor, split_input_photo,
                                                       split_image_size)
                # other_params: (N * max_seq_len, 6)
                # pen_ras: (N * max_seq_len, 2), after softmax
                # pos_before_max_min: (N, max_seq_len, 2), in image_size
                # win_size_before_max_min: (N, max_seq_len, 1), in image_size

                pred_raster_imgs = 1.0 - pred_raster_images  # (N, image_size, image_size), [0.0-stroke, 1.0-BG]
                pred_raster_imgs_rgb = 1.0 - pred_raster_images_rgb  # (N, image_size, image_size, 3)
                pred_raster_imgs_list.append(pred_raster_imgs)
                pred_raster_imgs_rgb_list.append(pred_raster_imgs_rgb)

                if not self.hps.use_softargmax:
                    pen_state_soft = pen_ras[:, 1:2]  # (N * max_seq_len, 1)
                else:
                    pen_state_soft = self.differentiable_argmax(pen_ras, self.hps.soft_beta)  # (N * max_seq_len, 1)

                pred_params = tf.concat([pen_state_soft, other_params], axis=1)  # (N * max_seq_len, 7)
                pred_params = tf.reshape(pred_params, shape=[-1, self.hps.max_seq_len, 7])  # (N, max_seq_len, 7)

                if self.hps.model_mode == 'train' or self.hps.model_mode == 'eval':
                    raster_cost, sn_cost, cursor_outside_cost, winsize_outside_cost, \
                        early_pen_states_cost, \
                        perc_relu_loss_raw, perc_relu_loss_norm = \
                        self.build_losses(split_target_sketch, pred_raster_imgs, pred_params,
                                          pos_before_max_min, win_size_before_max_min,
                                          split_image_size)
                    # perc_relu_loss_raw, perc_relu_loss_norm: (n_layers)

                    ras_loss_list.append(raster_cost)
                    perc_relu_raw_list.append(perc_relu_loss_raw)
                    perc_relu_norm_list.append(perc_relu_loss_norm)
                    sn_loss_list.append(sn_cost)
                    cursor_outside_loss_list.append(cursor_outside_cost)
                    win_size_outside_loss_list.append(winsize_outside_cost)
                    early_state_loss_list.append(early_pen_states_cost)

                    if self.hps.model_mode == 'train':
                        total_cost_split, grads_and_vars_split = self.build_training_op_split(
                            raster_cost, sn_cost, cursor_outside_cost, winsize_outside_cost,
                            early_pen_states_cost)
                        total_loss_list.append(total_cost_split)
                        tower_grads.append(grads_and_vars_split)

    # Average every loss over the towers.
    self.raster_cost = tf.reduce_mean(tf.stack(ras_loss_list, axis=0))
    self.perc_relu_losses_raw = tf.reduce_mean(tf.stack(perc_relu_raw_list, axis=0), axis=0)  # (n_layers)
    self.perc_relu_losses_norm = tf.reduce_mean(tf.stack(perc_relu_norm_list, axis=0), axis=0)  # (n_layers)
    self.stroke_num_cost = tf.reduce_mean(tf.stack(sn_loss_list, axis=0))
    self.pos_outside_cost = tf.reduce_mean(tf.stack(cursor_outside_loss_list, axis=0))
    self.win_size_outside_cost = tf.reduce_mean(tf.stack(win_size_outside_loss_list, axis=0))
    self.early_pen_states_cost = tf.reduce_mean(tf.stack(early_state_loss_list, axis=0))
    self.cost = tf.reduce_mean(tf.stack(total_loss_list, axis=0))

    self.pred_raster_imgs = tf.concat(pred_raster_imgs_list, axis=0)  # (N, image_size, image_size), [0.0-stroke, 1.0-BG]
    self.pred_raster_imgs_rgb = tf.concat(pred_raster_imgs_rgb_list, axis=0)  # (N, image_size, image_size, 3)

    if self.hps.model_mode == 'train':
        self.build_training_op(tower_grads)
+
def config_model(self):
    """Create the decoder cell, placeholders, bookkeeping variables and optimizer.

    Sets on `self` (in this creation order, which determines the
    auto-generated names of the unnamed variables and placeholders):

    * `global_step` ('train' mode only).
    * `dec_cell`: the RNN cell chosen by `hps.dec_model`
      ('lstm', 'layer_norm' or 'hyper'), optionally wrapped with input
      and/or output dropout.
    * `total_loop`: `len(hps.gpus) * hps.loop_per_gpu`.
    * placeholders `init_cursor`, `init_width`, `image_size`, the per-loop
      `input_photo_list` and, for 3-channel input, `target_sketch_list`;
      `init_cursor_list` is `init_cursor` split per loop.
    * non-trainable variables `stroke_num_loss_weight`,
      `early_pen_loss_start_idx`, `early_pen_loss_end_idx`
      ('train'/'eval' mode).
    * `perc_loss_mean_list`, `last_step_num`, `lr`, `optimizer`
      ('train' mode only).

    Raises:
      AssertionError: if `hps.dec_model` is not a supported cell type.
    """
    if self.hps.model_mode == 'train':
        self.global_step = tf.Variable(0, name='global_step', trainable=False)

    if self.hps.dec_model == 'lstm':
        dec_cell_fn = rnn.LSTMCell
    elif self.hps.dec_model == 'layer_norm':
        dec_cell_fn = rnn.LayerNormLSTMCell
    elif self.hps.dec_model == 'hyper':
        dec_cell_fn = rnn.HyperLSTMCell
    else:
        # Raised explicitly (instead of `assert False`) so the check is not
        # stripped under `python -O`; exception type and message are unchanged.
        raise AssertionError('please choose a respectable cell')

    use_recurrent_dropout = self.hps.use_recurrent_dropout
    use_input_dropout = self.hps.use_input_dropout
    use_output_dropout = self.hps.use_output_dropout

    dec_cell = dec_cell_fn(
        self.hps.dec_rnn_size,
        use_recurrent_dropout=use_recurrent_dropout,
        dropout_keep_prob=self.hps.recurrent_dropout_prob)

    # Optional input/output dropout wrappers around the decoder cell.
    if use_input_dropout:
        print('Dropout to input w/ keep_prob = %4.4f.' % self.hps.input_dropout_prob)
        dec_cell = tf.contrib.rnn.DropoutWrapper(
            dec_cell, input_keep_prob=self.hps.input_dropout_prob)
    if use_output_dropout:
        print('Dropout to output w/ keep_prob = %4.4f.' % self.hps.output_dropout_prob)
        dec_cell = tf.contrib.rnn.DropoutWrapper(
            dec_cell, output_keep_prob=self.hps.output_dropout_prob)
    self.dec_cell = dec_cell

    self.total_loop = len(self.hps.gpus) * self.hps.loop_per_gpu

    self.init_cursor = tf.placeholder(
        dtype=tf.float32,
        shape=[self.hps.batch_size, 1, 2])  # (N, 1, 2), in size [0.0, 1.0)
    self.init_width = tf.placeholder(
        dtype=tf.float32,
        shape=[1])  # (1), in [0.0, 1.0]
    # One image size per loop; an explicit list makes the rank-1 shape obvious.
    self.image_size = tf.placeholder(dtype=tf.int32, shape=[self.total_loop])  # (total_loop)

    self.init_cursor_list = tf.split(self.init_cursor, self.total_loop, axis=0)
    self.input_photo_list = []
    for loop_i in range(self.total_loop):
        input_photo_i = tf.placeholder(dtype=tf.float32, shape=[None, None, None, self.hps.input_channel])  # [0.0-stroke, 1.0-BG]
        self.input_photo_list.append(input_photo_i)

    if self.hps.input_channel == 3:
        self.target_sketch_list = []
        for loop_i in range(self.total_loop):
            target_sketch_i = tf.placeholder(dtype=tf.float32, shape=[None, None, None, 1])  # [0.0-stroke, 1.0-BG]
            self.target_sketch_list.append(target_sketch_i)

    if self.hps.model_mode == 'train' or self.hps.model_mode == 'eval':
        self.stroke_num_loss_weight = tf.Variable(0.0, trainable=False)
        self.early_pen_loss_start_idx = tf.Variable(0, dtype=tf.int32, trainable=False)
        self.early_pen_loss_end_idx = tf.Variable(0, dtype=tf.int32, trainable=False)

    if self.hps.model_mode == 'train':
        # One non-trainable scalar per perceptual-loss layer.
        self.perc_loss_mean_list = []
        for loop_i in range(len(self.hps.perc_loss_layers)):
            relu_loss_mean = tf.Variable(0.0, trainable=False)
            self.perc_loss_mean_list.append(relu_loss_mean)
        self.last_step_num = tf.Variable(0.0, trainable=False)

        with tf.variable_scope('train_op', reuse=tf.AUTO_REUSE):
            self.lr = tf.Variable(self.hps.learning_rate, trainable=False)
            self.optimizer = tf.train.AdamOptimizer(self.lr)
+
+ ###########################
+
def normalize_image_m1to1(self, in_img_0to1):
    """Map image values linearly via `x * 2 - 1` (so [0, 1] becomes [-1, 1])."""
    return tf.subtract(tf.multiply(in_img_0to1, 2.0), 1.0)
+
def add_coords(self, input_tensor):
    """Append two normalized coordinate channels to `input_tensor`.

    The first appended channel holds the column index and the second the row
    index of each position, both divided by `raster_size - 1`.

    :param input_tensor: e.g. (N, raster_size, raster_size, C)
    :return: e.g. (N, raster_size, raster_size, C + 2)
    """
    n_batch = tf.shape(input_tensor)[0]
    size = self.hps.raster_size
    norm = size - 1

    # Column indices: ones (N, size, 1) @ range (N, 1, size) -> every row is 0..size-1.
    col_ones = tf.expand_dims(tf.ones([n_batch, size], dtype=tf.int32), -1)
    col_index = tf.tile(tf.expand_dims(tf.range(size), 0), [n_batch, 1])  # (N, size)
    col_index = tf.expand_dims(col_index, 1)  # (N, 1, size)
    x_grid = tf.expand_dims(tf.matmul(col_ones, col_index), -1)  # (N, size, size, 1)

    # Row indices: range (N, size, 1) @ ones (N, 1, size) -> every column is 0..size-1.
    row_ones = tf.expand_dims(tf.ones([n_batch, size], dtype=tf.int32), 1)
    row_index = tf.tile(tf.expand_dims(tf.range(size), 0), [n_batch, 1])  # (N, size)
    row_index = tf.expand_dims(row_index, -1)  # (N, size, 1)
    y_grid = tf.expand_dims(tf.matmul(row_index, row_ones), -1)  # (N, size, size, 1)

    x_grid = tf.cast(x_grid, 'float32') / norm
    y_grid = tf.cast(y_grid, 'float32') / norm

    return tf.concat([input_tensor, x_grid, y_grid], axis=-1)
+
def build_combined_encoder(self, patch_canvas, patch_photo, entire_canvas, entire_photo, cursor_pos,
                           image_size, window_size):
    """
    Encode the local patch and the down-scaled full image into one embedding.

    All inputs are wrapped in `tf.stop_gradient` before being fed to the CNN.

    :param patch_canvas: (N, raster_size, raster_size, 1), [-1.0-stroke, 1.0-BG]
    :param patch_photo: (N, raster_size, raster_size, 1/3), [-1.0-stroke, 1.0-BG]
    :param entire_canvas: (N, image_size, image_size, 1), [0.0-stroke, 1.0-BG]
    :param entire_photo: (N, image_size, image_size, 1/3), [0.0-stroke, 1.0-BG]
    :param cursor_pos: (N, 1, 2), in size [0.0, 1.0)
    :param image_size: not read by this method
    :param window_size: (N, 1, 1), float, in large size
    :return: image embedding produced by the encoder selected via hps.encoder_type
    """
    resize_methods = {
        'BILINEAR': tf.image.ResizeMethod.BILINEAR,
        'NEAREST_NEIGHBOR': tf.image.ResizeMethod.NEAREST_NEIGHBOR,
        'BICUBIC': tf.image.ResizeMethod.BICUBIC,
        'AREA': tf.image.ResizeMethod.AREA,
    }
    if self.hps.resize_method not in resize_methods:
        raise Exception('unknown resize_method', self.hps.resize_method)
    resize_method = resize_methods[self.hps.resize_method]

    raster_hw = (self.hps.raster_size, self.hps.raster_size)

    patch_photo = tf.stop_gradient(patch_photo)
    patch_canvas = tf.stop_gradient(patch_canvas)
    cursor_pos = tf.stop_gradient(cursor_pos)
    window_size = tf.stop_gradient(window_size)

    # Global view: the whole photo/canvas shrunk to the raster resolution.
    entire_photo_small = tf.stop_gradient(
        tf.image.resize_images(entire_photo, raster_hw, method=resize_method))
    entire_canvas_small = tf.stop_gradient(
        tf.image.resize_images(entire_canvas, raster_hw, method=resize_method))
    entire_photo_small = self.normalize_image_m1to1(entire_photo_small)  # [-1.0-stroke, 1.0-BG]
    entire_canvas_small = self.normalize_image_m1to1(entire_canvas_small)  # [-1.0-stroke, 1.0-BG]

    if self.hps.encode_cursor_type != 'value':
        raise Exception('Unknown encode_cursor_type', self.hps.encode_cursor_type)
    # Broadcast the cursor coordinates to a constant 2-channel map.
    cursor_info = tf.tile(tf.expand_dims(cursor_pos, axis=1),
                          [1, self.hps.raster_size, self.hps.raster_size, 1])  # (N, raster_size, raster_size, 2)

    batch_input_combined = tf.concat(
        [patch_photo, patch_canvas, entire_photo_small, entire_canvas_small, cursor_info],
        axis=-1)  # [N, raster_size, raster_size, 6/10]
    batch_input_local = tf.concat([patch_photo, patch_canvas], axis=-1)  # [N, raster_size, raster_size, 2/4]
    batch_input_global = tf.concat([entire_photo_small, entire_canvas_small, cursor_info],
                                   axis=-1)  # [N, raster_size, raster_size, 4/6]

    is_training = self.hps.model_mode == 'train'
    dropout_keep_prob = self.hps.pix_drop_kp if is_training else 1.0

    if self.hps.add_coordconv:
        batch_input_combined = self.add_coords(batch_input_combined)  # (N, in_H, in_W, in_dim + 2)
        batch_input_local = self.add_coords(batch_input_local)  # (N, in_H, in_W, in_dim + 2)
        batch_input_global = self.add_coords(batch_input_global)  # (N, in_H, in_W, in_dim + 2)

    # Encoders taking separate local and global inputs.
    two_branch_encoders = {
        'combine33': generative_cnn_c3_encoder_combine33,
        'combine43': generative_cnn_c3_encoder_combine43,
        'combine53': generative_cnn_c3_encoder_combine53,
        'combineFC': generative_cnn_c3_encoder_combineFC,
    }
    # Encoders taking the single channel-concatenated input.
    single_branch_encoders = {
        'conv10': generative_cnn_encoder,
        'conv10_deep': generative_cnn_encoder_deeper,
        'conv13': generative_cnn_encoder_deeper13,
        'conv10_c3': generative_cnn_c3_encoder,
        'conv10_deep_c3': generative_cnn_c3_encoder_deeper,
        'conv13_c3': generative_cnn_c3_encoder_deeper13,
        'conv13_c3_attn': generative_cnn_c3_encoder_deeper13_attn,
    }

    encoder_type = self.hps.encoder_type
    if 'combine' in encoder_type:
        if encoder_type not in two_branch_encoders:
            raise Exception('Unknown encoder_type', self.hps.encoder_type)
        image_embedding, _ = two_branch_encoders[encoder_type](
            batch_input_local, batch_input_global, is_training, dropout_keep_prob)
    else:
        with tf.variable_scope('Combined_Encoder', reuse=tf.AUTO_REUSE):
            if encoder_type not in single_branch_encoders:
                raise Exception('Unknown encoder_type', self.hps.encoder_type)
            image_embedding, _ = single_branch_encoders[encoder_type](
                batch_input_combined, is_training, dropout_keep_prob)
    return image_embedding
+
def build_seq_decoder(self, dec_cell, actual_input_x, initial_state):
    """Run the decoder RNN and project its outputs to pen and stroke logits.

    :return: (output, last_state); `output` concatenates the 2 pen-state
        logits and the 6 stroke-parameter logits along axis 1, (N, 8).
    """
    rnn_output, last_state = self.rnn_decoder(dec_cell, initial_state, actual_input_x)
    flat_output = tf.reshape(rnn_output, [-1, self.hps.dec_rnn_size])

    def linear_head(scope_name, n_out):
        # Affine projection with its own 'output_w' / 'output_b' variables.
        with tf.variable_scope(scope_name, reuse=tf.AUTO_REUSE):
            weights = tf.get_variable('output_w', [self.hps.dec_rnn_size, n_out])
            bias = tf.get_variable('output_b', [n_out], initializer=tf.constant_initializer(0.0))
            return tf.nn.xw_plus_b(flat_output, weights, bias)  # (N, n_out)

    pen_logits = linear_head('DEC_RNN_out_pen', 2)
    param_logits = linear_head('DEC_RNN_out_params', 6)

    combined = tf.concat([pen_logits, param_logits], axis=1)  # (N, n_out)
    return combined, last_state
+
def get_mixture_coef(self, outputs):
    """Turn raw decoder outputs into pen probabilities and bounded stroke parameters.

    :param outputs: (N, 8); columns 0:2 are pen logits, columns 2:8 are
        stroke-parameter logits
    :return: [other_params (N, 6), pen (N, 2) after softmax]
    :raises Exception: if hps.position_format is not 'abs'
    """
    pen_logits = outputs[:, 0:2]  # (N, 2), pen states
    param_logits = outputs[:, 2:]  # (N, 6)

    pen_probs = tf.nn.softmax(pen_logits)  # (N, 2)

    if self.hps.position_format != 'abs':
        raise Exception('Unknown position_format', self.hps.position_format)

    first_point = tf.nn.sigmoid(param_logits[:, 0:2])  # (N, 2), in (0, 1)
    second_point = tf.tanh(param_logits[:, 2:4])  # (N, 2), in (-1, 1)
    # Width: sigmoid output rescaled into (min_width, 1.0).
    width_unit = tf.nn.sigmoid(param_logits[:, 4:5])  # (N, 1)
    widths = tf.add(tf.multiply(width_unit, 1.0 - self.hps.min_width), self.hps.min_width)
    scaling = tf.nn.sigmoid(param_logits[:, 5:6]) * self.hps.max_scaling  # (N, 1), [0.0, max_scaling]

    other_params = tf.concat([first_point, second_point, widths, scaling], axis=-1)  # (N, 6)
    return [other_params, pen_probs]
+
+ ###########################
+
def get_decoder_inputs(self):
    """Return the decoder cell's float32 zero state for `hps.batch_size` rows."""
    return self.dec_cell.zero_state(batch_size=self.hps.batch_size, dtype=tf.float32)
+
def rnn_decoder(self, dec_cell, initial_state, actual_input_x):
    """Unroll `dec_cell` over `actual_input_x` (batch-major) inside the 'RNN_DEC' scope.

    :return: (output, last_state) as produced by tf.nn.dynamic_rnn
    """
    rnn_options = dict(
        initial_state=initial_state,
        time_major=False,
        swap_memory=True,
        dtype=tf.float32)
    with tf.variable_scope("RNN_DEC", reuse=tf.AUTO_REUSE):
        step_outputs, final_state = tf.nn.dynamic_rnn(dec_cell, actual_input_x, **rnn_options)
    return step_outputs, final_state
+
+ ###########################
+
def image_padding(self, ori_image, window_size, pad_value):
    """
    Pad both spatial axes of a 4-D image by `window_size // 2` on each side.

    :param ori_image: (N, H, W, k)
    :param window_size: window size; half of it (floor) is the pad width
    :param pad_value: constant written into the padded border
    :return: (N, H_p, W_p, k)
    """
    half_window = window_size // 2
    border = [half_window, half_window]
    return tf.pad(ori_image,
                  paddings=[[0, 0], border, border, [0, 0]],
                  mode='CONSTANT',
                  constant_values=pad_value)
+
def image_cropping_fn(self, fn_inputs):
    """
    Crop one window-sized patch around the cursor and resize it to raster_size.

    `fn_inputs` is a single packed example whose channels are, in order:
    the image (input_channel + 1 channels), the normalized cursor (2),
    the image size (1) and the window size (1). The last three groups are
    read from pixel (0, 0).

    :param fn_inputs: (image_size, image_size, input_channel + 5)
    :return: (raster_size, raster_size, 2/4), [0.0-BG, 1.0-stroke]
    """
    shift = self.hps.input_channel - 1
    image = fn_inputs[:, :, 0:2 + shift]  # (image_size, image_size, 2/4), [0.0-BG, 1.0-stroke]
    cursor_norm = fn_inputs[0, 0, 2 + shift:4 + shift]  # (2), in [0.0, 1.0)
    image_size = fn_inputs[0, 0, 4 + shift]  # (), float32
    window_size = tf.cast(fn_inputs[0, 0, 5 + shift], tf.int32)  # ()

    # Pad by half a window with zeros so windows near the border stay in range.
    padded = self.image_padding(tf.expand_dims(image, axis=0), window_size, pad_value=0.0)

    # In padded coordinates the window starts at the (rounded) cursor pixel.
    cursor_px = tf.cast(tf.round(tf.multiply(cursor_norm, image_size)), dtype=tf.int32)
    left, right = cursor_px[0], cursor_px[0] + window_size  # ()
    top, bottom = cursor_px[1], cursor_px[1] + window_size  # ()
    window_patch = padded[:, top:bottom, left:right, :]  # (1, window_size, window_size, 2/4)

    # resize to raster_size
    resized = tf.image.resize_images(window_patch, (self.hps.raster_size, self.hps.raster_size),
                                     method=tf.image.ResizeMethod.AREA)
    return tf.squeeze(resized, axis=0)
+
def image_cropping(self, cursor_position, input_img, image_size, window_sizes):
    """
    Crop a window around each cursor, one example at a time via tf.map_fn.

    Cursor, image size and (rounded, gradient-stopped) window size are tiled
    into extra channels so each example carries its own scalars into
    `image_cropping_fn`.

    :param cursor_position: (N, 1, 2), float type, in size [0.0, 1.0)
    :param input_img: (N, image_size, image_size, 2/4), [0.0-BG, 1.0-stroke]
    :param image_size: side length of `input_img`
    :param window_sizes: (N, 1, 1), float32, with grad
    :return: (N, raster_size, raster_size, -)
    """
    rounded_windows = tf.stop_gradient(tf.round(window_sizes))  # (N, 1, 1), no grad
    per_example_tiling = [1, image_size, image_size, 1]

    cursor_map = tf.tile(tf.reshape(cursor_position, (-1, 1, 1, 2)),
                         per_example_tiling)  # (N, image_size, image_size, 2)

    size_map = tf.reshape(tf.cast(image_size, tf.float32), (1, 1, 1, 1))  # (1, 1, 1, 1)
    size_map = tf.tile(size_map, [self.hps.batch_size // self.total_loop, image_size, image_size, 1])

    window_map = tf.tile(tf.reshape(rounded_windows, (-1, 1, 1, 1)),
                         per_example_tiling)  # (N, image_size, image_size, 1)

    packed = tf.concat([input_img, cursor_map, size_map, window_map],
                       axis=-1)  # (N, image_size, image_size, 2/4 + 4)
    return tf.map_fn(self.image_cropping_fn, packed, parallel_iterations=32)
+
def image_cropping_v3(self, cursor_position, input_img, image_size, window_sizes):
    """
    Crop a window around each cursor with tf.image.crop_and_resize.

    No gradient flows into `window_sizes` from this crop.

    :param cursor_position: (N, 1, 2), float type, in size [0.0, 1.0)
    :param input_img: (N, image_size, image_size, k), [0.0-BG, 1.0-stroke]
    :param image_size: side length of `input_img`
    :param window_sizes: (N, 1, 1), float32, with grad
    :return: (N, raster_size, raster_size, k), [0.0-BG, 1.0-stroke]
    """
    frozen_windows = tf.stop_gradient(window_sizes)  # (N, 1, 1), no grad
    extent = frozen_windows - 1.0  # distance between first and last sampled pixel

    cursor_px = tf.multiply(cursor_position, tf.cast(image_size, tf.float32))
    center_x, center_y = tf.split(cursor_px, 2, axis=-1)  # (N, 1, 1)

    top = center_y - extent / 2
    left = center_x - extent / 2
    bottom = top + extent
    right = left + extent
    boxes = tf.squeeze(tf.concat([top, left, bottom, right], axis=-1), axis=1)  # (N, 4)
    # Box corners are expressed as fractions of (image_size - 1).
    boxes = boxes / tf.cast(image_size - 1, tf.float32)

    # Box i is cropped from image i: indices 0..N-1.
    box_ind = tf.cast(tf.ones_like(center_x)[:, 0, 0], dtype=tf.int32)  # (N)
    box_ind = tf.cumsum(box_ind) - 1

    return tf.image.crop_and_resize(input_img, boxes, box_ind,
                                    crop_size=[self.hps.raster_size, self.hps.raster_size])
+
def get_pixel_value(self, img, x, y):
    """
    Gather `img[b, y, x]` for every coordinate pair in `x` / `y`.

    Input
    -----
    - img: tensor of shape (B, H, W, C)
    - x: integer tensor of shape (B, H', W'), column indices
    - y: integer tensor of shape (B, H', W'), row indices

    Returns
    -------
    - output: tensor of shape (B, H', W', C)
    """
    coord_shape = tf.shape(x)
    n_batch, out_h, out_w = coord_shape[0], coord_shape[1], coord_shape[2]

    # Batch index of every output position, same shape as x / y.
    batch_index = tf.reshape(tf.range(0, n_batch), (n_batch, 1, 1))
    batch_index = tf.tile(batch_index, (1, out_h, out_w))

    gather_index = tf.stack([batch_index, y, x], 3)  # (B, H', W', 3)
    return tf.gather_nd(img, gather_index)
+
def image_pasting_nondiff_single(self, fn_inputs):
    """
    Paste one raster patch onto an empty image at a rounded cursor position.

    `fn_inputs` channels, in order: patch (1), cursor in large size (2),
    image size (1), window size (1); the scalars are read from pixel (0, 0).

    :param fn_inputs: (raster_size, raster_size, 5)
    :return: (image_size, image_size, 1), [0.0-BG, 1.0-stroke]
    """
    patch = fn_inputs[:, :, 0:1]  # (raster_size, raster_size, 1), [0.0-BG, 1.0-stroke]
    cursor = fn_inputs[0, 0, 1:3]  # (2), in large size
    image_size = tf.cast(fn_inputs[0, 0, 3], tf.int32)  # ()
    window_size = tf.cast(fn_inputs[0, 0, 4], tf.int32)  # ()

    # Scale the patch to the window resolution.
    window_patch = tf.image.resize_images(tf.expand_dims(patch, axis=0), (window_size, window_size),
                                          method=tf.image.ResizeMethod.BILINEAR)
    window_patch = tf.squeeze(window_patch, axis=0)  # (window_size, window_size, 1)

    cursor_px = tf.cast(tf.round(cursor), dtype=tf.int32)  # (2)
    col, row = cursor_px[0], cursor_px[1]

    # Surround the window with zeros so that it sits at the cursor, then cut
    # away the half-window offset.
    padded = tf.pad(window_patch,
                    paddings=[[row, image_size - row],
                              [col, image_size - col],
                              [0, 0]],
                    mode='CONSTANT',
                    constant_values=0.0)  # (H_p, W_p, 1), [0.0-BG, 1.0-stroke]

    offset = window_size // 2
    return padded[offset: offset + image_size, offset: offset + image_size, :]
+
def image_pasting_diff_single(self, fn_inputs):
    """
    Paste one raster patch onto an empty image at a fractional cursor position.

    The sub-pixel part of the cursor is handled by `image_pasting_diff_batch`;
    the integer part is handled by zero padding and cropping here.

    `fn_inputs` channels, in order: patch (1), cursor in large size (2),
    image size (1), window size (1); the scalars are read from pixel (0, 0).

    :param fn_inputs: (raster_size, raster_size, 5)
    :return: (image_size, image_size, 1), [0.0-BG, 1.0-stroke]
    """
    patch = fn_inputs[:, :, 0:1]  # (raster_size, raster_size, 1), [0.0-BG, 1.0-stroke]
    cursor = fn_inputs[0, 0, 1:3]  # (2), in large size
    image_size = tf.cast(fn_inputs[0, 0, 3], tf.int32)  # ()
    window_size = tf.cast(fn_inputs[0, 0, 4], tf.int32)  # ()
    col_f, row_f = cursor[0], cursor[1]

    window_patch = tf.image.resize_images(tf.expand_dims(patch, axis=0), (window_size, window_size),
                                          method=tf.image.ResizeMethod.BILINEAR)
    # window_patch: (1, window_size, window_size, 1)

    cursor_batched = tf.expand_dims(tf.expand_dims(cursor, axis=0), axis=0)  # (1, 1, 2)
    shifted = self.image_pasting_diff_batch(window_patch, cursor_batched, window_size)
    shifted = tf.squeeze(shifted, axis=0)  # (window_size + 1, window_size + 1, 1)

    row_floor = tf.cast(tf.floor(row_f), tf.int32)
    col_floor = tf.cast(tf.floor(col_f), tf.int32)
    above = row_floor
    below = image_size - 1 - tf.cast(tf.floor(row_f), tf.int32)
    before = col_floor
    after = image_size - 1 - tf.cast(tf.floor(col_f), tf.int32)

    padded = tf.pad(shifted,
                    paddings=[[above, below],
                              [before, after],
                              [0, 0]],
                    mode='CONSTANT',
                    constant_values=0.0)  # (H_p, W_p, 1), [0.0-BG, 1.0-stroke]

    offset = window_size // 2
    return padded[offset: offset + image_size, offset: offset + image_size, :]
+
def image_pasting_diff_single_v3(self, fn_inputs):
    """
    Paste one raster patch onto an empty image, keeping gradients w.r.t. the window size.

    Suffix `_a` refers to the large-image coordinate frame, suffix `_b` to
    the raster-patch coordinate frame. The window is expanded outwards to
    whole pixels in frame `a`; the matching region is sampled from the patch
    with crop_and_resize, then zero-padded and cropped to the image size.

    `fn_inputs` channels, in order: patch (1), cursor in large size (2),
    image size (1), window size (1); the scalars are read from pixel (0, 0).

    :param fn_inputs: (raster_size, raster_size, 5)
    :return: (image_size, image_size, 1), [0.0-BG, 1.0-stroke]
    """
    patch = fn_inputs[:, :, 0:1]  # (raster_size, raster_size, 1), [0.0-BG, 1.0-stroke]
    cursor_a = fn_inputs[0, 0, 1:3]  # (2), float32, in large size
    image_size_a = tf.cast(fn_inputs[0, 0, 3], tf.int32)  # ()
    window_a = fn_inputs[0, 0, 4]  # (), float32, with grad
    raster_a = float(self.hps.raster_size)

    # Margin added around the image while pasting and cut off again at the end.
    margin = tf.cast(tf.ceil(window_a / 2.0), tf.int32)

    # Fractional window corners and their outward-rounded integer versions.
    corner_min_a = cursor_a - window_a / 2.0  # (2), float32
    corner_max_a = cursor_a + window_a / 2.0  # (2), float32
    corner_min_floor = tf.floor(corner_min_a)  # (2)
    corner_max_ceil = tf.ceil(corner_max_a)  # (2)

    # Center and extent of the integer-aligned window, expressed in frame b.
    center_b_in_a = (corner_min_floor + corner_max_ceil) / 2.0  # (2)
    center_b = (center_b_in_a - corner_min_a) / window_a * raster_a  # (2)
    raster_b = (corner_max_ceil - corner_min_floor)  # (x, y)
    image_size_b = raster_a
    window_b = raster_a * (raster_b / window_a)  # (x, y)

    center_b_x, center_b_y = tf.split(center_b, 2, axis=-1)  # (1)

    top_b = center_b_y - (window_b[1] - 1.) / 2.
    left_b = center_b_x - (window_b[0] - 1.) / 2.
    bottom_b = top_b + (window_b[1] - 1.)
    right_b = left_b + (window_b[0] - 1.)
    boxes_b = tf.concat([top_b, left_b, bottom_b, right_b], axis=-1)  # (4)
    boxes_b = boxes_b / tf.cast(image_size_b - 1, tf.float32)  # with grad to window_a

    # Single box taken from the single image in the batch: index 0.
    box_ind_b = tf.ones((1), dtype=tf.int32)  # (1)
    box_ind_b = tf.cumsum(box_ind_b) - 1

    patch = tf.expand_dims(patch, axis=0)  # (1, raster_size, raster_size, 1), [0.0-BG, 1.0-stroke]
    boxes_b = tf.expand_dims(boxes_b, axis=0)  # (1, 4)

    aligned = tf.image.crop_and_resize(patch, boxes_b, box_ind_b,
                                       crop_size=[raster_b[1], raster_b[0]])
    aligned = aligned[0]  # (raster_size_b, raster_size_b, 1)

    above = tf.cast(corner_min_floor[1], tf.int32) + margin
    below = image_size_a + margin - tf.cast(corner_max_ceil[1], tf.int32)
    before = tf.cast(corner_min_floor[0], tf.int32) + margin
    after = image_size_a + margin - tf.cast(corner_max_ceil[0], tf.int32)

    padded = tf.pad(aligned,
                    paddings=[[above, below],
                              [before, after],
                              [0, 0]],
                    mode='CONSTANT',
                    constant_values=0.0)  # (H_p, W_p, 1), [0.0-BG, 1.0-stroke]

    return padded[margin: margin + image_size_a, margin: margin + image_size_a, :]
+
def image_pasting_diff_batch(self, patch_image, cursor_position, window_size):
    """
    Resample a patch on a grid shifted by the fractional part of the cursor.

    The patch is zero-padded by one pixel, then bilinearly sampled at
    positions (1 - frac) + k for k = 0..window_size along each axis, where
    frac is the fractional part of the cursor coordinate.

    :param patch_image: (N, window_size, window_size, 1), [0.0-BG, 1.0-stroke]
    :param cursor_position: (N, 1, 2), in large size
    :param window_size: side length of `patch_image`
    :return: (N, window_size + 1, window_size + 1, 1)
    """
    grid = window_size + 1

    one_pixel_border = [[0, 0], [1, 1], [1, 1], [0, 0]]
    padded_patch = tf.pad(patch_image, paddings=one_pixel_border, mode='CONSTANT',
                          constant_values=0.0)  # (N, window_size+2, window_size+2, 1), [0.0-BG, 1.0-stroke]

    cur_x, cur_y = cursor_position[:, :, 0:1], cursor_position[:, :, 1:2]  # (N, 1, 1)
    cur_x_floor, cur_y_floor = tf.floor(cur_x), tf.floor(cur_y)
    # First sampling position in padded-patch coordinates along each axis.
    sample_x, sample_y = 1.0 - (cur_x - cur_x_floor), 1.0 - (cur_y - cur_y_floor)  # (N, 1, 1)

    # x positions: start value followed by unit steps, accumulated along the last axis.
    unit_x = tf.tile(tf.ones_like(sample_x, dtype=tf.float32), [1, 1, window_size])  # (N, 1, window_size)
    sample_x = tf.concat([sample_x, unit_x], axis=-1)  # (N, 1, window_size + 1)
    sample_x = tf.tile(sample_x, [1, grid, 1])  # (N, window_size + 1, window_size + 1)
    sample_x = tf.cumsum(sample_x, axis=-1)  # (N, window_size + 1, window_size + 1)
    x_lo = tf.cast(tf.floor(sample_x), tf.int32)  # (N, window_size + 1, window_size + 1)
    x_hi = x_lo + 1  # (N, window_size + 1, window_size + 1)

    # y positions: same construction along axis 1.
    unit_y = tf.tile(tf.ones_like(sample_y, dtype=tf.float32), [1, window_size, 1])  # (N, window_size, 1)
    sample_y = tf.concat([sample_y, unit_y], axis=1)  # (N, window_size + 1, 1)
    sample_y = tf.tile(sample_y, [1, 1, grid])  # (N, window_size + 1, window_size + 1)
    sample_y = tf.cumsum(sample_y, axis=1)  # (N, window_size + 1, window_size + 1)
    y_lo = tf.cast(tf.floor(sample_y), tf.int32)  # (N, window_size + 1, window_size + 1)
    y_hi = y_lo + 1  # (N, window_size + 1, window_size + 1)

    # Pixel values at the four integer corners surrounding each sample.
    value_a = self.get_pixel_value(padded_patch, x_lo, y_lo)
    value_b = self.get_pixel_value(padded_patch, x_lo, y_hi)
    value_c = self.get_pixel_value(padded_patch, x_hi, y_lo)
    value_d = self.get_pixel_value(padded_patch, x_hi, y_hi)
    # (N, window_size + 1, window_size + 1, 1)

    x_lo = tf.cast(x_lo, tf.float32)
    x_hi = tf.cast(x_hi, tf.float32)
    y_lo = tf.cast(y_lo, tf.float32)
    y_hi = tf.cast(y_hi, tf.float32)

    # Bilinear weights, with a trailing channel axis for broadcasting.
    weight_a = tf.expand_dims((x_hi - sample_x) * (y_hi - sample_y), axis=3)
    weight_b = tf.expand_dims((x_hi - sample_x) * (sample_y - y_lo), axis=3)
    weight_c = tf.expand_dims((sample_x - x_lo) * (y_hi - sample_y), axis=3)
    weight_d = tf.expand_dims((sample_x - x_lo) * (sample_y - y_lo), axis=3)
    # (N, window_size + 1, window_size + 1, 1)

    return tf.add_n([weight_a * value_a,
                     weight_b * value_b,
                     weight_c * value_c,
                     weight_d * value_d])  # (N, window_size + 1, window_size + 1, 1)
+
def image_pasting(self, cursor_position_norm, patch_img, image_size, window_sizes, is_differentiable=False):
    """
    Paste each raster patch onto an empty full-size image at its cursor position.

    Cursor, image size and rounded window size are tiled into extra channels
    so each example carries its own scalars into the per-example pasting
    function run by tf.map_fn.

    :param cursor_position_norm: (N, 1, 2), float type, in size [0.0, 1.0)
    :param patch_img: (N, raster_size, raster_size), [0.0-BG, 1.0-stroke]
    :param image_size: side length of the output images
    :param window_sizes: (N, 1, 1), float32, with grad
    :param is_differentiable: selects `image_pasting_diff_single` when true,
        `image_pasting_nondiff_single` otherwise
    :return: (N, image_size, image_size)
    """
    raster_size = self.hps.raster_size
    per_example_tiling = [1, raster_size, raster_size, 1]

    cursor_px = tf.multiply(cursor_position_norm, tf.cast(image_size, tf.float32))  # in large size
    rounded_windows = tf.round(window_sizes)  # (N, 1, 1), no grad

    patch_channel = tf.expand_dims(patch_img, axis=-1)  # (N, raster_size, raster_size, 1)
    cursor_map = tf.tile(tf.reshape(cursor_px, (-1, 1, 1, 2)),
                         per_example_tiling)  # (N, raster_size, raster_size, 2)
    size_map = tf.reshape(tf.cast(image_size, tf.float32), (1, 1, 1, 1))  # (1, 1, 1, 1)
    size_map = tf.tile(size_map, [self.hps.batch_size // self.total_loop, raster_size,
                                  raster_size, 1])
    window_map = tf.tile(tf.reshape(rounded_windows, (-1, 1, 1, 1)),
                         per_example_tiling)  # (N, raster_size, raster_size, 1)

    packed = tf.concat([patch_channel, cursor_map, size_map, window_map],
                       axis=-1)  # (N, raster_size, raster_size, 5)

    paste_one = self.image_pasting_diff_single if is_differentiable else self.image_pasting_nondiff_single
    pasted = tf.map_fn(paste_one, packed, parallel_iterations=32)  # (N, image_size, image_size, 1)
    return tf.squeeze(pasted, axis=-1)  # (N, image_size, image_size)
+
+ def image_pasting_v3(self, cursor_position_norm, patch_img, image_size, window_sizes, is_differentiable=False):
+ """
+ Paste the patch_img into a full-size image based on cursor_position (only the differentiable path exists).
+ :param cursor_position_norm: (N, 1, 2), float type, in size [0.0, 1.0)
+ :param patch_img: (N, raster_size, raster_size), [0.0-BG, 1.0-stroke]
+ :param window_sizes: (N, 1, 1), float32, with grad
+ :return: (N, image_size, image_size), the pasted images
+ """
+ cursor_position = tf.multiply(cursor_position_norm, tf.cast(image_size, tf.float32)) # in large size
+
+ if is_differentiable:
+ patch_img_ = tf.expand_dims(patch_img, axis=-1) # (N, raster_size, raster_size, 1)
+ cursor_position_step = tf.reshape(cursor_position, (-1, 1, 1, 2)) # (N, 1, 1, 2)
+ cursor_position_step = tf.tile(cursor_position_step, [1, self.hps.raster_size, self.hps.raster_size,
+ 1]) # (N, raster_size, raster_size, 2)
+ image_size_tile = tf.reshape(tf.cast(image_size, tf.float32), (1, 1, 1, 1)) # (1, 1, 1, 1), single shared size
+ image_size_tile = tf.tile(image_size_tile, [self.hps.batch_size // self.total_loop, self.hps.raster_size,
+ self.hps.raster_size, 1]) # (N, raster_size, raster_size, 1), N = batch_size // total_loop
+ window_sizes_tile = tf.reshape(window_sizes, (-1, 1, 1, 1)) # (N, 1, 1, 1), not rounded here
+ window_sizes_tile = tf.tile(window_sizes_tile, [1, self.hps.raster_size, self.hps.raster_size, 1])
+
+ pasting_inputs = tf.concat([patch_img_, cursor_position_step, image_size_tile, window_sizes_tile],
+ axis=-1) # (N, raster_size, raster_size, 5): patch, cursor (2), image size, window size
+ curr_paste_imgs = tf.map_fn(self.image_pasting_diff_single_v3, pasting_inputs,
+ parallel_iterations=32) # (N, image_size, image_size, 1)
+ else:
+ raise Exception('Unfinished...')
+ curr_paste_imgs = tf.squeeze(curr_paste_imgs, axis=-1) # (N, image_size, image_size)
+ return curr_paste_imgs
+
+ def get_points_and_raster_image(self, initial_state, init_cursor, input_photo, image_size):
+ ## generate the other_params and pen_ras and raster image for raster loss
+ prev_state = initial_state # (N, dec_rnn_size * 3)
+
+ prev_width = self.init_width # (1)
+ prev_width = tf.expand_dims(tf.expand_dims(prev_width, axis=0), axis=0) # (1, 1, 1)
+ prev_width = tf.tile(prev_width, [self.hps.batch_size // self.total_loop, 1, 1]) # (N, 1, 1)
+
+ prev_scaling = tf.ones((self.hps.batch_size // self.total_loop, 1, 1)) # (N, 1, 1)
+ prev_window_size = tf.ones((self.hps.batch_size // self.total_loop, 1, 1),
+ dtype=tf.float32) * float(self.hps.raster_size) # (N, 1, 1)
+
+ cursor_position_temp = init_cursor
+ self.cursor_position = cursor_position_temp # (N, 1, 2), in size [0.0, 1.0)
+ cursor_position_loop = self.cursor_position
+
+ other_params_list = []
+ pen_ras_list = []
+
+ pos_before_max_min_list = []
+ win_size_before_max_min_list = []
+
+ curr_canvas_soft = tf.zeros_like(input_photo[:, :, :, 0]) # (N, image_size, image_size), [0.0-BG, 1.0-stroke]
+ curr_canvas_soft_rgb = tf.tile(tf.zeros_like(input_photo[:, :, :, 0:1]), [1, 1, 1, 3]) # (N, image_size, image_size, 3), [0.0-BG, 1.0-stroke]
+ curr_canvas_hard = tf.zeros_like(curr_canvas_soft) # [0.0-BG, 1.0-stroke]
+
+ #### sampling part - start ####
+ self.curr_canvas_hard = curr_canvas_hard
+
+ rasterizor_st = NeuralRasterizorStep(
+ raster_size=self.hps.raster_size,
+ position_format=self.hps.position_format)
+
+ if self.hps.cropping_type == 'v3':
+ cropping_func = self.image_cropping_v3
+ # elif self.hps.cropping_type == 'v2':
+ # cropping_func = self.image_cropping
+ else:
+ raise Exception('Unknown cropping_type', self.hps.cropping_type)
+
+ if self.hps.pasting_type == 'v3':
+ pasting_func = self.image_pasting_v3
+ # elif self.hps.pasting_type == 'v2':
+ # pasting_func = self.image_pasting
+ else:
+ raise Exception('Unknown pasting_type', self.hps.pasting_type)
+
+ for time_i in range(self.hps.max_seq_len):
+ cursor_position_non_grad = tf.stop_gradient(cursor_position_loop) # (N, 1, 2), in size [0.0, 1.0)
+
+ curr_window_size = tf.multiply(prev_scaling, tf.stop_gradient(prev_window_size)) # float, with grad
+ curr_window_size = tf.maximum(curr_window_size, tf.cast(self.hps.min_window_size, tf.float32))
+ curr_window_size = tf.minimum(curr_window_size, tf.cast(image_size, tf.float32))
+
+ ## patch-level encoding
+ # Here, we make the gradients from canvas_z to curr_canvas_hard be None to avoid recurrent gradient propagation.
+ curr_canvas_hard_non_grad = tf.stop_gradient(self.curr_canvas_hard)
+ curr_canvas_hard_non_grad = tf.expand_dims(curr_canvas_hard_non_grad, axis=-1)
+
+ # input_photo: (N, image_size, image_size, 1/3), [0.0-stroke, 1.0-BG]
+ crop_inputs = tf.concat([1.0 - input_photo, curr_canvas_hard_non_grad], axis=-1) # (N, H_p, W_p, 1/3+1)
+
+ cropped_outputs = cropping_func(cursor_position_non_grad, crop_inputs, image_size, curr_window_size)
+ index_offset = self.hps.input_channel - 1
+ curr_patch_inputs = cropped_outputs[:, :, :, 0:1 + index_offset] # [0.0-BG, 1.0-stroke]
+ curr_patch_canvas_hard_non_grad = cropped_outputs[:, :, :, 1 + index_offset:2 + index_offset]
+ # (N, raster_size, raster_size, 1), [0.0-BG, 1.0-stroke]
+
+ curr_patch_inputs = 1.0 - curr_patch_inputs # [0.0-stroke, 1.0-BG]
+ curr_patch_inputs = self.normalize_image_m1to1(curr_patch_inputs)
+ # (N, raster_size, raster_size, 1/3), [-1.0-stroke, 1.0-BG]
+
+ # Normalizing image
+ curr_patch_canvas_hard_non_grad = 1.0 - curr_patch_canvas_hard_non_grad # [0.0-stroke, 1.0-BG]
+ curr_patch_canvas_hard_non_grad = self.normalize_image_m1to1(curr_patch_canvas_hard_non_grad) # [-1.0-stroke, 1.0-BG]
+
+ ## image-level encoding
+ combined_z = self.build_combined_encoder(
+ curr_patch_canvas_hard_non_grad,
+ curr_patch_inputs,
+ 1.0 - curr_canvas_hard_non_grad,
+ input_photo,
+ cursor_position_non_grad,
+ image_size,
+ curr_window_size) # (N, z_size)
+ combined_z = tf.expand_dims(combined_z, axis=1) # (N, 1, z_size)
+
+ curr_window_size_top_side_norm_non_grad = \
+ tf.stop_gradient(curr_window_size / tf.cast(image_size, tf.float32))
+ curr_window_size_bottom_side_norm_non_grad = \
+ tf.stop_gradient(curr_window_size / tf.cast(self.hps.min_window_size, tf.float32))
+ if not self.hps.concat_win_size:
+ combined_z = tf.concat([tf.stop_gradient(prev_width), combined_z], 2) # (N, 1, 2+z_size)
+ else:
+ combined_z = tf.concat([tf.stop_gradient(prev_width),
+ curr_window_size_top_side_norm_non_grad,
+ curr_window_size_bottom_side_norm_non_grad,
+ combined_z],
+ 2) # (N, 1, 2+z_size)
+
+ if self.hps.concat_cursor:
+ prev_input_x = tf.concat([cursor_position_non_grad, combined_z], 2) # (N, 1, 2+2+z_size)
+ else:
+ prev_input_x = combined_z # (N, 1, 2+z_size)
+
+ h_output, next_state = self.build_seq_decoder(self.dec_cell, prev_input_x, prev_state)
+ # h_output: (N * 1, n_out), next_state: (N, dec_rnn_size * 3)
+ [o_other_params, o_pen_ras] = self.get_mixture_coef(h_output)
+ # o_other_params: (N * 1, 6)
+ # o_pen_ras: (N * 1, 2), after softmax
+
+ o_other_params = tf.reshape(o_other_params, [-1, 1, 6]) # (N, 1, 6)
+ o_pen_ras_raw = tf.reshape(o_pen_ras, [-1, 1, 2]) # (N, 1, 2)
+
+ other_params_list.append(o_other_params)
+ pen_ras_list.append(o_pen_ras_raw)
+
+ #### sampling part - end ####
+
+ if self.hps.model_mode == 'train' or self.hps.model_mode == 'eval' or self.hps.model_mode == 'eval_sample':
+ # use renderer here to convert the strokes to image
+ curr_other_params = tf.squeeze(o_other_params, axis=1) # (N, 6), (x1, y1)=[0.0, 1.0], (x2, y2)=[-1.0, 1.0]
+ x1y1, x2y2, width2, scaling = curr_other_params[:, 0:2], curr_other_params[:, 2:4],\
+ curr_other_params[:, 4:5], curr_other_params[:, 5:6]
+ x0y0 = tf.zeros_like(x2y2) # (N, 2), [-1.0, 1.0]
+ x0y0 = tf.div(tf.add(x0y0, 1.0), 2.0) # (N, 2), [0.0, 1.0]
+ x2y2 = tf.div(tf.add(x2y2, 1.0), 2.0) # (N, 2), [0.0, 1.0]
+ widths = tf.concat([tf.squeeze(prev_width, axis=1), width2], axis=1) # (N, 2)
+ curr_other_params = tf.concat([x0y0, x1y1, x2y2, widths], axis=-1) # (N, 8), (x0, y0)&(x2, y2)=[0.0, 1.0]
+ curr_stroke_image = rasterizor_st.raster_func_stroke_abs(curr_other_params)
+ # (N, raster_size, raster_size), [0.0-BG, 1.0-stroke]
+
+ curr_stroke_image_large = pasting_func(cursor_position_loop, curr_stroke_image,
+ image_size, curr_window_size,
+ is_differentiable=self.hps.pasting_diff)
+ # (N, image_size, image_size), [0.0-BG, 1.0-stroke]
+
+ ## soft
+ if not self.hps.use_softargmax:
+ curr_state_soft = o_pen_ras[:, 1:2] # (N, 1)
+ else:
+ curr_state_soft = self.differentiable_argmax(o_pen_ras, self.hps.soft_beta) # (N, 1)
+
+ curr_state_soft = tf.expand_dims(curr_state_soft, axis=1) # (N, 1, 1)
+
+ filter_curr_stroke_image_soft = tf.multiply(tf.subtract(1.0, curr_state_soft), curr_stroke_image_large)
+ # (N, image_size, image_size), [0.0-BG, 1.0-stroke]
+ curr_canvas_soft = tf.add(curr_canvas_soft, filter_curr_stroke_image_soft) # [0.0-BG, 1.0-stroke]
+
+ ## hard
+ curr_state_hard = tf.expand_dims(tf.cast(tf.argmax(o_pen_ras_raw, axis=-1), dtype=tf.float32),
+ axis=-1) # (N, 1, 1)
+ filter_curr_stroke_image_hard = tf.multiply(tf.subtract(1.0, curr_state_hard), curr_stroke_image_large)
+ # (N, image_size, image_size), [0.0-BG, 1.0-stroke]
+ self.curr_canvas_hard = tf.add(self.curr_canvas_hard, filter_curr_stroke_image_hard) # [0.0-BG, 1.0-stroke]
+ self.curr_canvas_hard = tf.clip_by_value(self.curr_canvas_hard, 0.0, 1.0) # [0.0-BG, 1.0-stroke]
+
+ next_width = o_other_params[:, :, 4:5]
+ next_scaling = o_other_params[:, :, 5:6]
+ next_window_size = tf.multiply(next_scaling, tf.stop_gradient(curr_window_size)) # float, with grad
+ window_size_before_max_min = next_window_size # (N, 1, 1), large-level
+ win_size_before_max_min_list.append(window_size_before_max_min)
+ next_window_size = tf.maximum(next_window_size, tf.cast(self.hps.min_window_size, tf.float32))
+ next_window_size = tf.minimum(next_window_size, tf.cast(image_size, tf.float32))
+
+ prev_state = next_state
+ prev_width = next_width * curr_window_size / next_window_size # (N, 1, 1)
+ prev_scaling = next_scaling # (N, 1, 1))
+ prev_window_size = curr_window_size
+
+ # update the cursor position
+ new_cursor_offsets = tf.multiply(o_other_params[:, :, 2:4],
+ tf.divide(curr_window_size, 2.0)) # (N, 1, 2), window-level
+ new_cursor_offset_next = new_cursor_offsets
+ new_cursor_offset_next = tf.concat([new_cursor_offset_next[:, :, 1:2], new_cursor_offset_next[:, :, 0:1]], axis=-1)
+
+ cursor_position_loop_large = tf.multiply(cursor_position_loop, tf.cast(image_size, tf.float32))
+
+ if self.hps.stop_accu_grad:
+ stroke_position_next = tf.stop_gradient(cursor_position_loop_large) + new_cursor_offset_next # (N, 1, 2), large-level
+ else:
+ stroke_position_next = cursor_position_loop_large + new_cursor_offset_next # (N, 1, 2), large-level
+
+ stroke_position_before_max_min = stroke_position_next # (N, 1, 2), large-level
+ pos_before_max_min_list.append(stroke_position_before_max_min)
+
+ if self.hps.cursor_type == 'next':
+ cursor_position_loop_large = stroke_position_next # (N, 1, 2), large-level
+ else:
+ raise Exception('Unknown cursor_type')
+
+ cursor_position_loop_large = tf.maximum(cursor_position_loop_large, 0.0)
+ cursor_position_loop_large = tf.minimum(cursor_position_loop_large, tf.cast(image_size - 1, tf.float32))
+ cursor_position_loop = tf.div(cursor_position_loop_large, tf.cast(image_size, tf.float32))
+
+ curr_canvas_soft = tf.clip_by_value(curr_canvas_soft, 0.0, 1.0) # (N, raster_size, raster_size), [0.0-BG, 1.0-stroke]
+
+ other_params_ = tf.reshape(tf.concat(other_params_list, axis=1), [-1, 6]) # (N * max_seq_len, 6)
+ pen_ras_ = tf.reshape(tf.concat(pen_ras_list, axis=1), [-1, 2]) # (N * max_seq_len, 2)
+ pos_before_max_min_ = tf.concat(pos_before_max_min_list, axis=1) # (N, max_seq_len, 2)
+ win_size_before_max_min_ = tf.concat(win_size_before_max_min_list, axis=1) # (N, max_seq_len, 1)
+
+ return other_params_, pen_ras_, prev_state, curr_canvas_soft, curr_canvas_soft_rgb, \
+ pos_before_max_min_, win_size_before_max_min_
+
+ def differentiable_argmax(self, input_pen, soft_beta):
+ """
+ Differentiable argmax trick.
+ :param input_pen: (N, n_class)
+ :return: pen_state: (N, 1)
+ """
+ def sign_onehot(x):
+ """
+ :param x: (N, n_class)
+ :return: (N, n_class)
+ """
+ y = tf.sign(tf.reduce_max(x, axis=-1, keepdims=True) - x)
+ y = (y - 1) * (-1)
+ return y
+
+ def softargmax(x, beta=1e2):
+ """
+ :param x: (N, n_class)
+ :param beta: 1e10 is the best. 1e2 is acceptable.
+ :return: (N)
+ """
+ x_range = tf.cumsum(tf.ones_like(x), axis=1) # (N, 2)
+ return tf.reduce_sum(tf.nn.softmax(x * beta) * x_range, axis=1) - 1
+
+ ## Better to use softargmax(beta=1e2). The sign_onehot's gradient is close to zero.
+ # pen_onehot = sign_onehot(input_pen) # one-hot form, (N * max_seq_len, 2)
+ # pen_state = pen_onehot[:, 1:2] # (N * max_seq_len, 1)
+ pen_state = softargmax(input_pen, soft_beta)
+ pen_state = tf.expand_dims(pen_state, axis=1) # (N * max_seq_len, 1)
+ return pen_state
+
+ def build_losses(self, target_sketch, pred_raster_imgs, pred_params,
+ pos_before_max_min, win_size_before_max_min, image_size):
+ def get_raster_loss(pred_imgs, gt_imgs, loss_type):
+ perc_layer_losses_raw = []
+ perc_layer_losses_weighted = []
+ perc_layer_losses_norm = []
+
+ if loss_type == 'l1':
+ ras_cost = tf.reduce_mean(tf.abs(tf.subtract(gt_imgs, pred_imgs))) # ()
+ elif loss_type == 'l1_small':
+ gt_imgs_small = tf.image.resize_images(tf.expand_dims(gt_imgs, axis=3), (32, 32))
+ pred_imgs_small = tf.image.resize_images(tf.expand_dims(pred_imgs, axis=3), (32, 32))
+ ras_cost = tf.reduce_mean(tf.abs(tf.subtract(gt_imgs_small, pred_imgs_small))) # ()
+ elif loss_type == 'mse':
+ ras_cost = tf.reduce_mean(tf.pow(tf.subtract(gt_imgs, pred_imgs), 2)) # ()
+ elif loss_type == 'perceptual':
+ return_map_pred = vgg_net_slim(pred_imgs, image_size)
+ return_map_gt = vgg_net_slim(gt_imgs, image_size)
+ perc_loss_type = 'l1' # [l1, mse]
+ weighted_map = {'ReLU1_1': 100.0, 'ReLU1_2': 100.0,
+ 'ReLU2_1': 100.0, 'ReLU2_2': 100.0,
+ 'ReLU3_1': 10.0, 'ReLU3_2': 10.0, 'ReLU3_3': 10.0,
+ 'ReLU4_1': 1.0, 'ReLU4_2': 1.0, 'ReLU4_3': 1.0,
+ 'ReLU5_1': 1.0, 'ReLU5_2': 1.0, 'ReLU5_3': 1.0}
+
+ for perc_layer in self.hps.perc_loss_layers:
+ if perc_loss_type == 'l1':
+ perc_layer_loss = tf.reduce_mean(tf.abs(tf.subtract(return_map_pred[perc_layer],
+ return_map_gt[perc_layer]))) # ()
+ elif perc_loss_type == 'mse':
+ perc_layer_loss = tf.reduce_mean(tf.pow(tf.subtract(return_map_pred[perc_layer],
+ return_map_gt[perc_layer]), 2)) # ()
+ else:
+ raise NameError('Unknown perceptual loss type:', perc_loss_type)
+ perc_layer_losses_raw.append(perc_layer_loss)
+
+ assert perc_layer in weighted_map
+ perc_layer_losses_weighted.append(perc_layer_loss * weighted_map[perc_layer])
+
+ for loop_i in range(len(self.hps.perc_loss_layers)):
+ perc_relu_loss_raw = perc_layer_losses_raw[loop_i] # ()
+
+ if self.hps.model_mode == 'train':
+ curr_relu_mean = (self.perc_loss_mean_list[loop_i] * self.last_step_num + perc_relu_loss_raw) / (self.last_step_num + 1.0)
+ relu_cost_norm = perc_relu_loss_raw / curr_relu_mean
+ else:
+ relu_cost_norm = perc_relu_loss_raw
+ perc_layer_losses_norm.append(relu_cost_norm)
+
+ perc_layer_losses_raw = tf.stack(perc_layer_losses_raw, axis=0)
+ perc_layer_losses_norm = tf.stack(perc_layer_losses_norm, axis=0)
+
+ if self.hps.perc_loss_fuse_type == 'max':
+ ras_cost = tf.reduce_max(perc_layer_losses_norm)
+ elif self.hps.perc_loss_fuse_type == 'add':
+ ras_cost = tf.reduce_mean(perc_layer_losses_norm)
+ elif self.hps.perc_loss_fuse_type == 'raw_add':
+ ras_cost = tf.reduce_mean(perc_layer_losses_raw)
+ elif self.hps.perc_loss_fuse_type == 'weighted_sum':
+ ras_cost = tf.reduce_mean(perc_layer_losses_weighted)
+ else:
+ raise NameError('Unknown perc_loss_fuse_type:', self.hps.perc_loss_fuse_type)
+
+ elif loss_type == 'triplet':
+ raise Exception('Solution for triplet loss is coming soon.')
+ else:
+ raise NameError('Unknown loss type:', loss_type)
+
+ if loss_type != 'perceptual':
+ for perc_layer_i in self.hps.perc_loss_layers:
+ perc_layer_losses_raw.append(tf.constant(0.0))
+ perc_layer_losses_norm.append(tf.constant(0.0))
+
+ perc_layer_losses_raw = tf.stack(perc_layer_losses_raw, axis=0)
+ perc_layer_losses_norm = tf.stack(perc_layer_losses_norm, axis=0)
+
+ return ras_cost, perc_layer_losses_raw, perc_layer_losses_norm
+
+ gt_raster_images = tf.squeeze(target_sketch, axis=3) # (N, raster_h, raster_w), [0.0-stroke, 1.0-BG]
+ raster_cost, perc_relu_losses_raw, perc_relu_losses_norm = \
+ get_raster_loss(pred_raster_imgs, gt_raster_images, loss_type=self.hps.raster_loss_base_type)
+
+ def get_stroke_num_loss(input_strokes):
+ ending_state = input_strokes[:, :, 0] # (N, seq_len)
+ stroke_num_loss_pre = tf.reduce_mean(ending_state) # larger is better, [0.0, 1.0]
+ stroke_num_loss = 1.0 - stroke_num_loss_pre # lower is better, [0.0, 1.0]
+ return stroke_num_loss
+
+ stroke_num_cost = get_stroke_num_loss(pred_params) # lower is better
+
+ def get_pos_outside_loss(pos_before_max_min_):
+ pos_after_max_min = tf.maximum(pos_before_max_min_, 0.0)
+ pos_after_max_min = tf.minimum(pos_after_max_min, tf.cast(image_size - 1, tf.float32)) # (N, max_seq_len, 2)
+ pos_outside_loss = tf.reduce_mean(tf.abs(pos_before_max_min_ - pos_after_max_min))
+ return pos_outside_loss
+
+ pos_outside_cost = get_pos_outside_loss(pos_before_max_min) # lower is better
+
+ def get_win_size_outside_loss(win_size_before_max_min_, min_window_size):
+ win_size_outside_top_loss = tf.divide(
+ tf.maximum(win_size_before_max_min_ - tf.cast(image_size, tf.float32), 0.0),
+ tf.cast(image_size, tf.float32)) # (N, max_seq_len, 1)
+ win_size_outside_bottom_loss = tf.divide(
+ tf.maximum(tf.cast(min_window_size, tf.float32) - win_size_before_max_min_, 0.0),
+ tf.cast(min_window_size, tf.float32)) # (N, max_seq_len, 1)
+ win_size_outside_loss = tf.reduce_mean(win_size_outside_top_loss + win_size_outside_bottom_loss)
+ return win_size_outside_loss
+
+ win_size_outside_cost = get_win_size_outside_loss(win_size_before_max_min, self.hps.min_window_size) # lower is better
+
+ def get_early_pen_states_loss(input_strokes, curr_start, curr_end):
+ # input_strokes: (N, max_seq_len, 7)
+ pred_early_pen_states = input_strokes[:, curr_start:curr_end, 0] # (N, curr_early_len)
+ pred_early_pen_states_min = tf.reduce_min(pred_early_pen_states, axis=1) # (N), should not be 1
+ early_pen_states_loss = tf.reduce_mean(pred_early_pen_states_min) # lower is better
+ return early_pen_states_loss
+
+ early_pen_states_cost = get_early_pen_states_loss(pred_params,
+ self.early_pen_loss_start_idx, self.early_pen_loss_end_idx)
+
+ return raster_cost, stroke_num_cost, pos_outside_cost, win_size_outside_cost, \
+ early_pen_states_cost, \
+ perc_relu_losses_raw, perc_relu_losses_norm
+
+ def build_training_op_split(self, raster_cost, sn_cost, cursor_outside_cost, win_size_outside_cost,
+ early_pen_states_cost):
+ total_cost = self.hps.raster_loss_weight * raster_cost + \
+ self.hps.early_pen_loss_weight * early_pen_states_cost + \
+ self.stroke_num_loss_weight * sn_cost + \
+ self.hps.outside_loss_weight * cursor_outside_cost + \
+ self.hps.win_size_outside_loss_weight * win_size_outside_cost
+
+ tvars = [var for var in tf.trainable_variables()
+ if 'raster_unit' not in var.op.name and 'VGG16' not in var.op.name]
+ gvs = self.optimizer.compute_gradients(total_cost, var_list=tvars)
+ return total_cost, gvs
+
+ def build_training_op(self, grad_list):
+     """Average the per-tower gradients in grad_list, clip each to [-grad_clip, grad_clip] and build self.train_op."""
+     with tf.variable_scope('train_op', reuse=tf.AUTO_REUSE):
+         gvs = self.average_gradients(grad_list)
+         g = self.hps.grad_clip
+
+         for grad, var in gvs:
+             print('>>', var.op.name)
+             if grad is None:
+                 print(' >> None value')
+
+         # tf.clip_by_value cannot take None, so variables without a gradient are reported above and skipped here.
+         capped_gvs = [(tf.clip_by_value(grad, -g, g), var) for grad, var in gvs if grad is not None]
+         self.train_op = self.optimizer.apply_gradients(capped_gvs, global_step=self.global_step, name='train_step')
+
+ def average_gradients(self, grads_list):
+     """
+     Compute the average gradients across towers.
+     :param grads_list: list(of length N_GPU) of list(grad, var)
+     :return: list of (averaged grad, var); a var whose grad is None on every tower is omitted
+     """
+     avg_grads = []
+     for grad_and_vars in zip(*grads_list):
+         # TF ops cannot take None, so towers that produced no gradient for this var are left out.
+         grads = [g for g, _ in grad_and_vars if g is not None]
+         if not grads:
+             continue
+
+         # Stacking on a new leading axis and reducing over it gives the element-wise mean.
+         grad = tf.reduce_mean(tf.stack(grads, axis=0), axis=0)
+
+         # As before, the variable is taken from the first tower's pair (assumes towers share variables).
+         avg_grads.append((grad, grad_and_vars[0][1]))
+
+     return avg_grads
\ No newline at end of file
diff --git a/hi-arm/qmupd_vs/models/__init__.py b/hi-arm/qmupd_vs/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc01113da66ff042bd1807b5bfdb70c4bce8d14c
--- /dev/null
+++ b/hi-arm/qmupd_vs/models/__init__.py
@@ -0,0 +1,67 @@
+"""This package contains modules related to objective functions, optimizations, and network architectures.
+
+To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
+You need to implement the following five functions:
+ -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
+ -- <set_input>: unpack data from dataset and apply preprocessing.
+ -- <forward>: produce intermediate results.
+ -- <optimize_parameters>: calculate loss, gradients, and update network weights.
+ -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
+
+In the function <__init__>, you need to define four lists:
+ -- self.loss_names (str list): specify the training losses that you want to plot and save.
+ -- self.model_names (str list): define networks used in our training.
+ -- self.visual_names (str list): specify the images that you want to display and save.
+ -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
+
+Now you can use the model class by specifying flag '--model dummy'.
+See our template model class 'template_model.py' for more details.
+"""
+
+import importlib
+from models.base_model import BaseModel
+
+
+ def find_model_using_name(model_name):
+     """Import the module "models/[model_name]_model.py" and return its model class.
+
+     The class is looked up by name ([model_name] without underscores + 'model',
+     compared case-insensitively) and has to be a subclass of BaseModel.
+     Raises SystemExit (non-zero exit status) when no such class is found.
+     """
+     model_filename = "models." + model_name + "_model"
+     modellib = importlib.import_module(model_filename)
+     model = None
+     target_model_name = model_name.replace('_', '') + 'model'
+     for name, cls in modellib.__dict__.items():
+         # isinstance guard: issubclass() raises TypeError when a same-named attribute is not a class.
+         if name.lower() == target_model_name.lower() and isinstance(cls, type) and issubclass(cls, BaseModel):
+             model = cls
+
+     if model is None:
+         # SystemExit with a string prints it to stderr and exits with status 1; the old exit(0) reported success.
+         raise SystemExit("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
+
+     return model
+
+
+ def get_option_setter(model_name):
+ """Return the static method <modify_commandline_options> of the model class found for model_name."""
+ model_class = find_model_using_name(model_name)
+ return model_class.modify_commandline_options
+
+
+ def create_model(opt):
+ """Create a model given the option.
+
+ This function looks up the model class named by opt.model and instantiates it with opt.
+ This is the main interface between this package and 'train.py'/'test.py'
+
+ Example:
+ >>> from models import create_model
+ >>> model = create_model(opt)
+ """
+ model = find_model_using_name(opt.model)
+ instance = model(opt)
+ print("model [%s] was created" % type(instance).__name__)
+ return instance
diff --git a/hi-arm/qmupd_vs/models/base_model.py b/hi-arm/qmupd_vs/models/base_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..d06337d4ee138db99a94032b40fe6ad9c8627f4b
--- /dev/null
+++ b/hi-arm/qmupd_vs/models/base_model.py
@@ -0,0 +1,248 @@
+import os
+import torch
+from collections import OrderedDict
+from abc import ABCMeta, abstractmethod
+from . import networks
+import pdb
+
+
+class BaseModel():
+ __metaclass__ = ABCMeta
+ """This class is an abstract base class (ABC) for models.
+ To create a subclass, you need to implement the following five functions:
+ -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
+ -- <set_input>: unpack data from dataset and apply preprocessing.
+ -- <forward>: produce intermediate results.
+ -- <optimize_parameters>: calculate losses, gradients, and update network weights.
+ -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
+ """
+
+ def __init__(self, opt):
+ """Initialize the BaseModel class.
+
+ Parameters:
+ opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
+
+ When creating your custom class, you need to implement your own initialization.
+ In this function, you should first call <BaseModel.__init__(self, opt)>
+ Then, you need to define four lists:
+ -- self.loss_names (str list): specify the training losses that you want to plot and save.
+ -- self.model_names (str list): define networks used in our training.
+ -- self.visual_names (str list): specify the images that you want to display and save.
+ -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
+ """
+ self.opt = opt
+ self.gpu_ids = opt.gpu_ids
+ self.isTrain = opt.isTrain
+ self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU
+ self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir
+ if opt.preprocess != 'scale_width': # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
+ torch.backends.cudnn.benchmark = True
+ self.loss_names = []
+ self.model_names = []
+ self.visual_names = []
+ self.optimizers = []
+ self.image_paths = []
+ self.metric = 0 # used for learning rate policy 'plateau'
+
+ @staticmethod
+ def modify_commandline_options(parser, is_train):
+ """Add new model-specific options, and rewrite default values for existing options.
+
+ Parameters:
+ parser -- original option parser
+ is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
+
+ Returns:
+ the modified parser.
+ """
+ return parser
+
+ @abstractmethod
+ def set_input(self, input):
+ """Unpack input data from the dataloader and perform necessary pre-processing steps.
+
+ Parameters:
+ input (dict): includes the data itself and its metadata information.
+ """
+ pass
+
+ @abstractmethod
+ def forward(self):
+ """Run forward pass; called by both functions <optimize_parameters> and <test>."""
+ pass
+
+ @abstractmethod
+ def optimize_parameters(self):
+ """Calculate losses, gradients, and update network weights; called in every training iteration"""
+ pass
+
+ def setup(self, opt):
+ """Load and print networks; create schedulers
+
+ Parameters:
+ opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
+ """
+ if self.isTrain:
+ self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
+ if not self.isTrain or opt.continue_train:
+ load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch
+ self.load_networks(load_suffix)
+ self.print_networks(opt.verbose)
+
+ def eval(self):
+ """Make models eval mode during test time"""
+ for name in self.model_names:
+ if isinstance(name, str):
+ net = getattr(self, 'net' + name)
+ net.eval()
+
+ def test(self):
+ """Forward function used in test time.
+
+ This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop
+ It also calls <compute_visuals> to produce additional visualization results
+ """
+ with torch.no_grad():
+ self.forward()
+ self.compute_visuals()
+
+ def compute_visuals(self):
+ """Calculate additional output images for visdom and HTML visualization"""
+ pass
+
+ def get_image_paths(self):
+ """ Return image paths that are used to load current data"""
+ return self.image_paths
+
+ def update_learning_rate(self):
+ """Update learning rates for all the networks; called at the end of every epoch"""
+ for scheduler in self.schedulers:
+ if self.opt.lr_policy == 'plateau':
+ scheduler.step(self.metric)
+ else:
+ scheduler.step()
+
+ lr = self.optimizers[0].param_groups[0]['lr']
+ print('learning rate = %.7f' % lr)
+
+ def get_current_visuals(self):
+ """Return visualization images. train.py will display these images with visdom, and save the images to a HTML"""
+ visual_ret = OrderedDict()
+ for name in self.visual_names:
+ if isinstance(name, str):
+ visual_ret[name] = getattr(self, name)
+ return visual_ret
+
+ def get_current_losses(self):
+ """Return training losses / errors. train.py will print out these errors on console, and save them to a file"""
+ errors_ret = OrderedDict()
+ for name in self.loss_names:
+ if isinstance(name, str):
+ errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number
+ return errors_ret
+
+ def save_networks(self, epoch):
+ """Save all the networks to the disk.
+
+ Parameters:
+ epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
+ """
+ for name in self.model_names:
+ if isinstance(name, str):
+ save_filename = '%s_net_%s.pth' % (epoch, name)
+ save_path = os.path.join(self.save_dir, save_filename)
+ net = getattr(self, 'net' + name)
+
+ if len(self.gpu_ids) > 0 and torch.cuda.is_available():
+ torch.save(net.module.cpu().state_dict(), save_path)
+ net.cuda(self.gpu_ids[0])
+ else:
+ torch.save(net.cpu().state_dict(), save_path)
+
+ def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
+ """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
+ key = keys[i]
+ if i + 1 == len(keys): # at the end, pointing to a parameter/buffer
+ if module.__class__.__name__.startswith('InstanceNorm') and \
+ (key == 'running_mean' or key == 'running_var'):
+ if getattr(module, key) is None:
+ state_dict.pop('.'.join(keys))
+ if module.__class__.__name__.startswith('InstanceNorm') and \
+ (key == 'num_batches_tracked'):
+ state_dict.pop('.'.join(keys))
+ else:
+ self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
+
+ def load_networks(self, epoch):
+ """Load all the networks from the disk.
+
+ Parameters:
+ epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
+ """
+ for name in self.model_names:
+ if isinstance(name, str):
+ load_filename = '%s_net_%s.pth' % (epoch, name)
+ load_path = os.path.join(self.save_dir, load_filename)
+ net = getattr(self, 'net' + name)
+ if isinstance(net, torch.nn.DataParallel):
+ net = net.module
+ print('loading the model from %s' % load_path)
+ # if you are using PyTorch newer than 0.4 (e.g., built from
+ # GitHub source), you can remove str() on self.device
+ state_dict = torch.load(load_path, map_location=str(self.device))
+ if hasattr(state_dict, '_metadata'):
+ del state_dict._metadata
+
+ # patch InstanceNorm checkpoints prior to 0.4
+ for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop
+ self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
+ net.load_state_dict(state_dict)
+ #param1 = {}
+ #for name, parameters in net.named_parameters():
+ # print(name,',',parameters.size())
+ # param1[name] = parameters.detach().cpu().numpy()
+ #pdb.set_trace()
+
+ def print_networks(self, verbose):
+ """Print the total number of parameters in the network and (if verbose) network architecture
+
+ Parameters:
+ verbose (bool) -- if verbose: print the network architecture
+ """
+ print('---------- Networks initialized -------------')
+ for name in self.model_names:
+ if isinstance(name, str):
+ net = getattr(self, 'net' + name)
+ num_params = 0
+ for param in net.parameters():
+ num_params += param.numel()
+ if verbose:
+ print(net)
+ print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
+ print('-----------------------------------------------')
+
+ def set_requires_grad(self, nets, requires_grad=False):
+ """Set requires_grad (default False) on every parameter of the given networks, e.g. to avoid unnecessary computations
+ Parameters:
+ nets (network list) -- a list of networks (a single network is also accepted; None entries are skipped)
+ requires_grad (bool) -- whether the networks require gradients or not
+ """
+ if not isinstance(nets, list):
+ nets = [nets]
+ for net in nets:
+ if net is not None:
+ for param in net.parameters():
+ param.requires_grad = requires_grad
+
+ # ===========================================================================================================
+ def masked(self, A, mask):
+     """Combine image A with mask according to opt.mask_type (0-3); raises ValueError for any other value."""
+     mask_type = self.opt.mask_type
+     if mask_type not in (0, 1, 2, 3):  # previously fell through and silently returned None
+         raise ValueError('unsupported mask_type: %r' % (mask_type,))
+     if mask_type == 2:  # A is left untouched; the mask is concatenated along dim 1
+         return torch.cat((A, mask), 1)
+     # A/2+0.5 and *2-1 rescale between [-1, 1] and [0, 1] (assumes A is in [-1, 1]); types 1/3 also add 1-mask.
+     masked = (A/2+0.5)*mask*2-1 if mask_type == 0 else ((A/2+0.5)*mask+1-mask)*2-1
+     return torch.cat((masked, mask), 1) if mask_type == 3 else masked
\ No newline at end of file
diff --git a/hi-arm/qmupd_vs/models/cycle_gan_cls_model.py b/hi-arm/qmupd_vs/models/cycle_gan_cls_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..8883fcec78f150470728571ae2c1c6f9fbbd0346
--- /dev/null
+++ b/hi-arm/qmupd_vs/models/cycle_gan_cls_model.py
@@ -0,0 +1,565 @@
+import torch
+import itertools
+from util.image_pool import ImagePool
+from .base_model import BaseModel
+from . import networks
+import models.dist_model as dm # numpy==1.14.3
+import torchvision.transforms as transforms
+import os
+from util.util import tensor2im, tensor2im2, save_image
+
+def truncate(fake_B,a=127.5):#[-1,1]
+ #return torch.round((fake_B+1)*a)/a-1
+ return ((fake_B+1)*a).int().float()/a-1
+
+class CycleGANClsModel(BaseModel):
+ """
+ This class implements the CycleGAN model, for learning image-to-image translation without paired data.
+
+ The model training requires '--dataset_mode unaligned' dataset.
+ By default, it uses a '--netG resnet_9blocks' ResNet generator,
+ a '--netD basic' discriminator (PatchGAN introduced by pix2pix),
+ and a least-square GANs objective ('--gan_mode lsgan').
+
+ CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf
+ """
+ @staticmethod
+ def modify_commandline_options(parser, is_train=True):
+ """Add model-specific options, and rewrite default values for existing options.
+
+ Parameters:
+ parser -- original option parser
+ is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
+
+ Returns:
+ the modified parser.
+
+ For CycleGAN, in addition to GAN losses, we introduce lambda_A, lambda_B, and lambda_identity for the following losses.
+ A (source domain), B (target domain).
+ Generators: G_A: A -> B; G_B: B -> A.
+ Discriminators: D_A: G_A(A) vs. B; D_B: G_B(B) vs. A.
+ Forward cycle loss: lambda_A * ||G_B(G_A(A)) - A|| (Eqn. (2) in the paper)
+ Backward cycle loss: lambda_B * ||G_A(G_B(B)) - B|| (Eqn. (2) in the paper)
+ Identity loss (optional): lambda_identity * (||G_A(B) - B|| * lambda_B + ||G_B(A) - A|| * lambda_A) (Sec 5.2 "Photo generation from paintings" in the paper)
+ Dropout is not used in the original CycleGAN paper.
+
+ This variant additionally registers options for the classification discriminator (--netda),
+ truncated cycle loss, HED/perceptual cycle loss, metric loss, region masks and style control.
+ """
+ parser.set_defaults(no_dropout=True) # default CycleGAN did not use dropout
+ parser.set_defaults(dataset_mode='unaligned_mask_stylecls')
+ parser.add_argument('--netda', type=str, default='basic_cls') # discriminator has two branches
+ parser.add_argument('--truncate', type=float, default=0.0, help='whether truncate in forward')
+ # training-only options: loss weights, cycle-loss variants and metric-model selection
+ if is_train:
+ parser.add_argument('--lambda_A', type=float, default=5.0, help='weight for cycle loss (A -> B -> A)')
+ parser.add_argument('--lambda_B', type=float, default=5.0, help='weight for cycle loss (B -> A -> B)')
+ parser.add_argument('--lambda_identity', type=float, default=0, help='use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1')
+ parser.add_argument('--perceptual_cycle', type=int, default=6, help='whether use perceptual similarity for cycle loss')
+ parser.add_argument('--use_hed', type=int, default=1, help='whether use hed processing for cycle loss')
+ parser.add_argument('--ntrunc_trunc', type=int, default=1, help='whether use both non-trunc version and trunc version')
+ parser.add_argument('--trunc_a', type=float, default=31.875, help='multiply which value to round when trunc')
+ parser.add_argument('--lambda_A_trunc', type=float, default=5.0, help='weight for cycle loss for trunc')
+ parser.add_argument('--hed_pretrained_mode', type=str, default='./checkpoints/network-bsds500.pytorch', help='path to the pretrained hed model')
+ parser.add_argument('--vgg_pretrained_mode', type=str, default='./checkpoints/vgg19.pth', help='path to the pretrained vgg model')
+ parser.add_argument('--lambda_G_A_l', type=float, default=0.5, help='weight for local GAN loss in G')
+ parser.add_argument('--style_loss_with_weight', type=int, default=0, help='whether multiply prob in style loss')
+ parser.add_argument('--metric', action='store_true', help='whether use metric loss for fakeB')
+ parser.add_argument('--metric_model_path', type=str, default='3/30_net_Regressor.pth', help='metric model path')
+ parser.add_argument('--lambda_metric', type=float, default=0.5, help='weight for metric loss')
+ parser.add_argument('--metricvec', action='store_true', help='whether use metric model with vec input')
+ parser.add_argument('--metric_resnext', action='store_true', help='whether use resnext as metric model')
+ parser.add_argument('--metric_resnet', action='store_true', help='whether use resnet as metric model')
+ parser.add_argument('--metric_inception', action='store_true', help='whether use inception as metric model')# the inception of transform_input=False
+ parser.add_argument('--metric_inmask', action='store_true', help='whether use inmask in metric model')
+ else:
+ parser.add_argument('--check_D', action='store_true', help='whether use check Ds outputs')
+ # NOTE(review): indentation was lost in this patch text. The mask and style options below are
+ # read unconditionally by other methods of this class (e.g. set_input reads opt.use_mask), so
+ # they are presumably registered outside the if/else above -- confirm against the real file.
+ # for masks
+ parser.add_argument('--use_mask', type=int, default=1, help='whether use mask for special face region')
+ parser.add_argument('--use_eye_mask', type=int, default=1, help='whether use mask for special face region')
+ parser.add_argument('--use_lip_mask', type=int, default=1, help='whether use mask for special face region')
+ parser.add_argument('--mask_type', type=int, default=3, help='use mask type, 0 outside black, 1 outside white')
+ # for style control
+ parser.add_argument('--style_control', type=int, default=1, help='use style_control')
+ parser.add_argument('--sfeature_mode', type=str, default='1vgg19_softmax', help='vgg19 softmax as feature')
+ parser.add_argument('--netga', type=str, default='resnet_style_9blocks', help='net arch for netG_A')
+ parser.add_argument('--model0_res', type=int, default=0, help='number of resblocks in model0 (before insert style)')
+ parser.add_argument('--model1_res', type=int, default=0, help='number of resblocks in model1 (after insert style, before 2 column merge)')
+ parser.add_argument('--one_hot', type=int, default=0, help='use one-hot for style code')
+
+ return parser
+
+ def __init__(self, opt):
+ """Initialize the model: loss/visual/model name lists, networks, losses and optimizers.
+
+ Parameters:
+ opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
+
+ NOTE(review): indentation was lost in this patch text, so the nesting of the
+ conditionals below has to be confirmed against the real file.
+ """
+ BaseModel.__init__(self, opt)
+ # names of the training losses to report; each entry 'X' corresponds to an attribute self.loss_X
+ self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B']
+ # names of the images to save/display; each entry is an attribute of this model
+ visual_names_A = ['real_A', 'fake_B', 'rec_A']
+ visual_names_B = ['real_B', 'fake_A', 'rec_B']
+ if self.isTrain and self.opt.lambda_identity > 0.0: # if identity loss is used, we also visualize idt_B=G_B(A) and idt_A=G_A(B)
+ visual_names_A.append('idt_B')
+ visual_names_B.append('idt_A')
+ if self.isTrain and self.opt.use_hed:
+ visual_names_A.append('real_A_hed')
+ visual_names_A.append('rec_A_hed')
+ if self.isTrain and self.opt.ntrunc_trunc:
+ visual_names_A.append('rec_At')
+ if self.opt.use_hed:
+ visual_names_A.append('rec_At_hed')
+ # replaces the list above: adds 'cycle_A2' (truncated cycle loss) and the total 'G'
+ self.loss_names = ['D_A', 'G_A', 'cycle_A', 'cycle_A2', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B', 'G']
+ # local (masked-region) discriminators: extra visuals and losses during training
+ if self.isTrain and self.opt.use_mask:
+ visual_names_A.append('fake_B_l')
+ visual_names_A.append('real_B_l')
+ self.loss_names += ['D_A_l', 'G_A_l']
+ if self.isTrain and self.opt.use_eye_mask:
+ visual_names_A.append('fake_B_le')
+ visual_names_A.append('real_B_le')
+ self.loss_names += ['D_A_le', 'G_A_le']
+ if self.isTrain and self.opt.use_lip_mask:
+ visual_names_A.append('fake_B_ll')
+ visual_names_A.append('real_B_ll')
+ self.loss_names += ['D_A_ll', 'G_A_ll']
+ if self.isTrain and self.opt.metric:
+ self.loss_names += ['metric']
+ #visual_names_B += ['fake_B2']
+ # at test time only the masked visuals are added (no losses)
+ if not self.isTrain and self.opt.use_mask:
+ visual_names_A.append('fake_B_l')
+ visual_names_A.append('real_B_l')
+ if not self.isTrain and self.opt.use_eye_mask:
+ visual_names_A.append('fake_B_le')
+ visual_names_A.append('real_B_le')
+ if not self.isTrain and self.opt.use_lip_mask:
+ visual_names_A.append('fake_B_ll')
+ visual_names_A.append('real_B_ll')
+ # classification-branch losses of D_A (set in backward_D_A and backward_G)
+ self.loss_names += ['D_A_cls','G_A_cls']
+
+ self.visual_names = visual_names_A + visual_names_B # combine visualizations for A and B
+ print(self.visual_names)
+ # names of the networks to save to / load from disk; each entry 'X' corresponds to self.netX
+ if self.isTrain:
+ self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']
+ if self.opt.use_mask:
+ self.model_names += ['D_A_l']
+ if self.opt.use_eye_mask:
+ self.model_names += ['D_A_le']
+ if self.opt.use_lip_mask:
+ self.model_names += ['D_A_ll']
+ else: # during test time, only load Gs
+ self.model_names = ['G_A', 'G_B']
+ if self.opt.check_D:
+ self.model_names += ['D_A', 'D_B']
+
+ # define networks (both Generators and discriminators)
+ # The naming is different from those used in the paper.
+ # Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
+ if not self.opt.style_control:
+ self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
+ not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
+ else:
+ # style-controlled generator: uses the --netga architecture and the model0/model1 block counts
+ print(opt.netga)
+ print('model0_res', opt.model0_res)
+ print('model1_res', opt.model1_res)
+ self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netga, opt.norm,
+ not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.model0_res, opt.model1_res)
+ self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.netG, opt.norm,
+ not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
+
+ #if self.isTrain: # define discriminators
+ if self.isTrain or self.opt.check_D: # define discriminators
+ # D_A uses the --netda architecture with n_class=3; the other discriminators use --netD
+ self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netda,
+ opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids, n_class=3)
+ self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD,
+ opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
+ if self.opt.use_mask:
+ # mask types 2 and 3 append the mask as an extra channel (see masked), hence +1
+ if self.opt.mask_type in [2, 3]:
+ output_nc = opt.output_nc + 1
+ else:
+ output_nc = opt.output_nc
+ self.netD_A_l = networks.define_D(output_nc, opt.ndf, opt.netD,
+ opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
+ if self.opt.use_eye_mask:
+ if self.opt.mask_type in [2, 3]:
+ output_nc = opt.output_nc + 1
+ else:
+ output_nc = opt.output_nc
+ self.netD_A_le = networks.define_D(output_nc, opt.ndf, opt.netD,
+ opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
+ if self.opt.use_lip_mask:
+ if self.opt.mask_type in [2, 3]:
+ output_nc = opt.output_nc + 1
+ else:
+ output_nc = opt.output_nc
+ self.netD_A_ll = networks.define_D(output_nc, opt.ndf, opt.netD,
+ opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
+
+ # optional frozen metric network; weights are loaded from ./checkpoints/metric/<metric_model_path>
+ if self.isTrain and self.opt.metric:
+ if not opt.metric_resnext and not opt.metric_resnet and not opt.metric_inception:
+ self.metric = networks.define_inception_v3a(init_weights_='./checkpoints/metric/'+self.opt.metric_model_path,gpu_ids_ = self.gpu_ids,vec=self.opt.metricvec)
+ elif opt.metric_resnext:
+ self.metric = networks.define_resnext101a(init_weights_='./checkpoints/metric/'+self.opt.metric_model_path,gpu_ids_ = self.gpu_ids,vec=self.opt.metricvec)
+ elif opt.metric_resnet:
+ self.metric = networks.define_resnet101a(init_weights_='./checkpoints/metric/'+self.opt.metric_model_path,gpu_ids_ = self.gpu_ids,vec=self.opt.metricvec)
+ elif opt.metric_inception:
+ self.metric = networks.define_inception3a(init_weights_='./checkpoints/metric/'+self.opt.metric_model_path,gpu_ids_ = self.gpu_ids,vec=self.opt.metricvec)
+ self.metric.eval()
+ self.set_requires_grad(self.metric, False)
+
+ # test-time discriminator check always uses the 'lsgan' loss, regardless of opt.gan_mode
+ if not self.isTrain and self.opt.check_D:
+ self.criterionGAN = networks.GANLoss('lsgan').to(self.device)
+
+ if self.isTrain:
+ if opt.lambda_identity > 0.0: # only works when input and output images have the same number of channels
+ assert(opt.input_nc == opt.output_nc)
+ self.fake_A_pool = ImagePool(opt.pool_size) # create image buffer to store previously generated images
+ self.fake_B_pool = ImagePool(opt.pool_size) # create image buffer to store previously generated images
+ # define loss functions
+ self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device) # define GAN loss.
+ self.criterionCycle = torch.nn.L1Loss()
+ self.criterionIdt = torch.nn.L1Loss()
+ self.criterionCls = torch.nn.CrossEntropyLoss()
+ # per-sample (unreduced) cross entropy, used when style_loss_with_weight is set
+ self.criterionCls2 = torch.nn.CrossEntropyLoss(reduction='none')
+ # initialize optimizers; both are appended to self.optimizers below
+ self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999))
+ # the chain below adds each local discriminator only if all earlier mask flags are on:
+ # with use_mask off, netD_A_le / netD_A_ll are never added to the optimizer even if enabled
+ if not self.opt.use_mask:
+ self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999))
+ elif not self.opt.use_eye_mask:
+ D_params = list(self.netD_A.parameters()) + list(self.netD_B.parameters()) + list(self.netD_A_l.parameters())
+ self.optimizer_D = torch.optim.Adam(D_params, lr=opt.lr, betas=(opt.beta1, 0.999))
+ elif not self.opt.use_lip_mask:
+ D_params = list(self.netD_A.parameters()) + list(self.netD_B.parameters()) + list(self.netD_A_l.parameters()) + list(self.netD_A_le.parameters())
+ self.optimizer_D = torch.optim.Adam(D_params, lr=opt.lr, betas=(opt.beta1, 0.999))
+ else:
+ D_params = list(self.netD_A.parameters()) + list(self.netD_B.parameters()) + list(self.netD_A_l.parameters()) + list(self.netD_A_le.parameters()) + list(self.netD_A_ll.parameters())
+ self.optimizer_D = torch.optim.Adam(D_params, lr=opt.lr, betas=(opt.beta1, 0.999))
+ self.optimizers.append(self.optimizer_G)
+ self.optimizers.append(self.optimizer_D)
+
+ # perceptual networks for the cycle loss: modes 1,2,3,6 use the LPIPS DistModel, modes 4,5,8 use VGG features
+ if self.opt.perceptual_cycle:
+ if self.opt.perceptual_cycle in [1,2,3,6]:
+ # NOTE(review): use_gpu=True is hard-coded here
+ self.lpips = dm.DistModel(opt,model='net-lin',net='alex',use_gpu=True)
+ elif self.opt.perceptual_cycle in [4,5,8]:
+ self.vgg = networks.define_VGG(init_weights_=opt.vgg_pretrained_mode, feature_mode_=True, gpu_ids_=self.gpu_ids) # using conv4_4 layer
+
+ if self.opt.use_hed:
+ # HED is built with opt.gpu_ids_p rather than self.gpu_ids (the previous call is kept commented out)
+ #self.hed = networks.define_HED(init_weights_=opt.hed_pretrained_mode, gpu_ids_=self.gpu_ids)
+ self.hed = networks.define_HED(init_weights_=opt.hed_pretrained_mode, gpu_ids_=self.opt.gpu_ids_p)
+ self.set_requires_grad(self.hed, False)
+
+
+ def set_input(self, input):
+ """Unpack input data from the dataloader and move the tensors to self.device.
+
+ Parameters:
+ input (dict): include the data itself and its metadata information.
+
+ The option 'direction' can be used to swap domain A and domain B.
+ Only 'A'/'B' and the path key are swapped; the mask, style, label and vec
+ entries are always read under their fixed keys.
+
+ NOTE(review): indentation was lost in this patch text; confirm which
+ assignments are nested under which condition against the real file.
+ """
+ AtoB = self.opt.direction == 'AtoB'
+ self.real_A = input['A' if AtoB else 'B'].to(self.device)
+ self.real_B = input['B' if AtoB else 'A'].to(self.device)
+ self.image_paths = input['A_paths' if AtoB else 'B_paths']
+ # region masks for both domains: face (mask), eye (maske), lip (maskl)
+ if self.opt.use_mask:
+ self.A_mask = input['A_mask'].to(self.device)
+ self.B_mask = input['B_mask'].to(self.device)
+ if self.opt.use_eye_mask:
+ self.A_maske = input['A_maske'].to(self.device)
+ self.B_maske = input['B_maske'].to(self.device)
+ if self.opt.use_lip_mask:
+ self.A_maskl = input['A_maskl'].to(self.device)
+ self.B_maskl = input['B_maskl'].to(self.device)
+ # style feature fed to G_A and class label used by D_A's classification branch
+ if self.opt.style_control:
+ self.real_B_style = input['B_style'].to(self.device)
+ self.real_B_label = input['B_label'].to(self.device)
+ if self.opt.isTrain and self.opt.style_loss_with_weight:
+ self.real_B_style0 = input['B_style0'].to(self.device)
+ # constant int64 class-index targets (0, 1, 2) shaped like real_B_label, used by the weighted style loss
+ self.zero = torch.zeros(self.real_B_label.size(),dtype=torch.int64).to(self.device)
+ self.one = torch.ones(self.real_B_label.size(),dtype=torch.int64).to(self.device)
+ self.two = 2*torch.ones(self.real_B_label.size(),dtype=torch.int64).to(self.device)
+ # extra inputs for the metric model
+ if self.opt.isTrain and self.opt.metricvec:
+ self.vec = input['vec'].to(self.device)
+ if self.opt.isTrain and self.opt.metric_inmask:
+ self.A_maskfg = input['A_maskfg'].to(self.device)
+
+ def forward(self):
+ """Run forward pass: compute fake_B, rec_A, fake_A, rec_B and the masked views of fake_B/real_B.
+
+ Called by optimize_parameters at the start of every training iteration.
+ """
+ if not self.opt.style_control:
+ self.fake_B = self.netG_A(self.real_A) # G_A(A)
+ else:
+ #print(torch.mean(self.real_B_style,(2,3)),'style_control')
+ #print(self.real_B_style,'style_control')
+ # with style control, G_A also receives the style feature of real_B
+ self.fake_B = self.netG_A(self.real_A, self.real_B_style)
+ self.rec_A = self.netG_B(self.fake_B) # G_B(G_A(A))
+ self.fake_A = self.netG_B(self.real_B) # G_B(B)
+ if not self.opt.style_control:
+ self.rec_B = self.netG_A(self.fake_A) # G_A(G_B(B))
+ else:
+ #print(torch.mean(self.real_B_style,(2,3)),'style_control')
+ self.rec_B = self.netG_A(self.fake_A, self.real_B_style) # -- cycle_B loss
+
+ # masked views for the local discriminators; fake_B uses the A-side masks, real_B the B-side masks
+ if self.opt.use_mask:
+ self.fake_B_l = self.masked(self.fake_B,self.A_mask)
+ self.real_B_l = self.masked(self.real_B,self.B_mask)
+ if self.opt.use_eye_mask:
+ self.fake_B_le = self.masked(self.fake_B,self.A_maske)
+ self.real_B_le = self.masked(self.real_B,self.B_maske)
+ if self.opt.use_lip_mask:
+ self.fake_B_ll = self.masked(self.fake_B,self.A_maskl)
+ self.real_B_ll = self.masked(self.real_B,self.B_maskl)
+
+ def backward_D_basic(self, netD, real, fake):
+ """Calculate GAN loss for the discriminator
+
+ Parameters:
+ netD (network) -- the discriminator D
+ real (tensor array) -- real images
+ fake (tensor array) -- images generated by a generator
+
+ Return the discriminator loss.
+ We also call loss_D.backward() to calculate the gradients.
+ """
+ # Real
+ pred_real = netD(real)
+ loss_D_real = self.criterionGAN(pred_real, True)
+ # Fake
+ pred_fake = netD(fake.detach())
+ loss_D_fake = self.criterionGAN(pred_fake, False)
+ # Combined loss and calculate gradients
+ loss_D = (loss_D_real + loss_D_fake) * 0.5
+ loss_D.backward()
+ return loss_D
+
+ def backward_D_basic_cls(self, netD, real, fake):
+ # Real
+ pred_real, pred_real_cls = netD(real)
+ loss_D_real = self.criterionGAN(pred_real, True)
+ if not self.opt.style_loss_with_weight:
+ loss_D_real_cls = self.criterionCls(pred_real_cls, self.real_B_label)
+ else:
+ loss_D_real_cls = torch.mean(self.real_B_style0[:,0] * self.criterionCls2(pred_real_cls, self.zero) + self.real_B_style0[:,1] * self.criterionCls2(pred_real_cls, self.one) + self.real_B_style0[:,2] * self.criterionCls2(pred_real_cls, self.two))
+ # Fake
+ pred_fake, pred_fake_cls = netD(fake.detach())
+ loss_D_fake = self.criterionGAN(pred_fake, False)
+ if not self.opt.style_loss_with_weight:
+ loss_D_fake_cls = self.criterionCls(pred_fake_cls, self.real_B_label)
+ else:
+ loss_D_fake_cls = torch.mean(self.real_B_style0[:,0] * self.criterionCls2(pred_fake_cls, self.zero) + self.real_B_style0[:,1] * self.criterionCls2(pred_fake_cls, self.one) + self.real_B_style0[:,2] * self.criterionCls2(pred_fake_cls, self.two))
+ # Combined loss and calculate gradients
+ loss_D = (loss_D_real + loss_D_fake) * 0.5
+ loss_D_cls = (loss_D_real_cls + loss_D_fake_cls) * 0.5
+ loss_D_total = loss_D + loss_D_cls
+ loss_D_total.backward()
+ return loss_D, loss_D_cls
+
+ def backward_D_A(self):
+ """Calculate GAN loss for discriminator D_A"""
+ fake_B = self.fake_B_pool.query(self.fake_B)
+ self.loss_D_A, self.loss_D_A_cls = self.backward_D_basic_cls(self.netD_A, self.real_B, fake_B)
+
+ def backward_D_A_l(self):
+ """Calculate GAN loss for discriminator D_A_l"""
+ fake_B = self.fake_B_pool.query(self.fake_B)
+ self.loss_D_A_l = self.backward_D_basic(self.netD_A_l, self.masked(self.real_B,self.B_mask), self.masked(fake_B,self.A_mask))
+
+ def backward_D_A_le(self):
+ """Calculate GAN loss for discriminator D_A_le"""
+ fake_B = self.fake_B_pool.query(self.fake_B)
+ self.loss_D_A_le = self.backward_D_basic(self.netD_A_le, self.masked(self.real_B,self.B_maske), self.masked(fake_B,self.A_maske))
+
+ def backward_D_A_ll(self):
+ """Calculate GAN loss for discriminator D_A_ll"""
+ fake_B = self.fake_B_pool.query(self.fake_B)
+ self.loss_D_A_ll = self.backward_D_basic(self.netD_A_ll, self.masked(self.real_B,self.B_maskl), self.masked(fake_B,self.A_maskl))
+
+ def backward_D_B(self):
+ """Calculate GAN loss for discriminator D_B"""
+ fake_A = self.fake_A_pool.query(self.fake_A)
+ self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
+
+ def update_process(self, epoch):
+ self.process = (epoch - 1) / float(self.opt.niter_decay + self.opt.niter)
+
+ def backward_G(self):
+ """Calculate the loss for generators G_A and G_B and backpropagate it.
+
+ Sums GAN, classification, local (masked) GAN, cycle, optional truncated-cycle,
+ identity and optional metric losses into self.loss_G, then calls backward().
+ Reads self.process, which is set by update_process.
+
+ NOTE(review): indentation was lost in this patch text; confirm the nesting of
+ the branches below against the real file.
+ """
+ lambda_idt = self.opt.lambda_identity
+ lambda_G_A_l = self.opt.lambda_G_A_l
+ lambda_A = self.opt.lambda_A
+ lambda_B = self.opt.lambda_B
+ lambda_A_trunc = self.opt.lambda_A_trunc
+ if self.opt.ntrunc_trunc:
+ # as self.process goes 0 -> 1, lambda_A decays to 10% of its value and
+ # lambda_A_trunc grows from 0 to 90% of its value
+ lambda_A = lambda_A * (1 - self.process * 0.9)
+ lambda_A_trunc = lambda_A_trunc * self.process * 0.9
+ self.lambda_As = [lambda_A, lambda_A_trunc]
+ # Identity loss
+ if lambda_idt > 0:
+ # G_A should be identity if real_B is fed: ||G_A(B) - B||
+ # NOTE(review): G_A is called without a style input here, unlike forward() when style_control is on -- confirm
+ self.idt_A = self.netG_A(self.real_B)
+ self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
+ # G_B should be identity if real_A is fed: ||G_B(A) - A||
+ self.idt_B = self.netG_B(self.real_A)
+ self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
+ else:
+ self.loss_idt_A = 0
+ self.loss_idt_B = 0
+
+ # GAN loss D_A(G_A(A))
+ pred_fake, pred_fake_cls = self.netD_A(self.fake_B)
+ self.loss_G_A = self.criterionGAN(pred_fake, True)
+ # classification loss: fake_B should be classified as the style of real_B
+ if not self.opt.style_loss_with_weight:
+ self.loss_G_A_cls = self.criterionCls(pred_fake_cls, self.real_B_label)
+ else:
+ self.loss_G_A_cls = torch.mean(self.real_B_style0[:,0] * self.criterionCls2(pred_fake_cls, self.zero) + self.real_B_style0[:,1] * self.criterionCls2(pred_fake_cls, self.one) + self.real_B_style0[:,2] * self.criterionCls2(pred_fake_cls, self.two))
+ # local GAN losses on the masked regions, all weighted by lambda_G_A_l
+ if self.opt.use_mask:
+ self.loss_G_A_l = self.criterionGAN(self.netD_A_l(self.fake_B_l), True) * lambda_G_A_l
+ if self.opt.use_eye_mask:
+ self.loss_G_A_le = self.criterionGAN(self.netD_A_le(self.fake_B_le), True) * lambda_G_A_l
+ if self.opt.use_lip_mask:
+ self.loss_G_A_ll = self.criterionGAN(self.netD_A_ll(self.fake_B_ll), True) * lambda_G_A_l
+ # GAN loss D_B(G_B(B))
+ self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
+ # Forward cycle loss || G_B(G_A(A)) - A||
+ # perceptual_cycle selects the distance: 0 = L1, 1/2 = LPIPS (2 on grayscale),
+ # 3/6 = LPIPS on HED outputs, 4 = VGG features, 5/8 = VGG features of HED outputs.
+ # NOTE(review): modes 3, 5, 6 and 8 carry extra conditions (use_hed, ntrunc_trunc); if those
+ # are not met no branch assigns self.loss_cycle_A in this call. Modes 2-5 never compute
+ # loss_cycle_A2 even when ntrunc_trunc is on -- confirm intended option combinations.
+ if self.opt.perceptual_cycle == 0:
+ self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
+ if self.opt.ntrunc_trunc:
+ # second reconstruction from the truncated (quantized) fake_B
+ self.rec_At = self.netG_B(truncate(self.fake_B,self.opt.trunc_a))
+ self.loss_cycle_A2 = self.criterionCycle(self.rec_At, self.real_A) * lambda_A_trunc
+ else:
+ if self.opt.perceptual_cycle == 1:
+ self.loss_cycle_A = self.lpips.forward_pair(self.rec_A, self.real_A).mean() * lambda_A
+ if self.opt.ntrunc_trunc:
+ self.rec_At = self.netG_B(truncate(self.fake_B,self.opt.trunc_a))
+ self.loss_cycle_A2 = self.lpips.forward_pair(self.rec_At, self.real_A).mean() * lambda_A_trunc
+ elif self.opt.perceptual_cycle == 2:
+ ts = self.real_A.shape
+ # weighted sum of channels 0..2 with weights 0.299/0.587/0.114 (luma weights for RGB order)
+ # NOTE(review): unsqueeze(0) adds the new axis in front of the batch axis -- confirm this is intended for batch size > 1
+ rec_A = (self.rec_A[:,0,:,:]*0.299+self.rec_A[:,1,:,:]*0.587+self.rec_A[:,2,:,:]*0.114).unsqueeze(0)
+ real_A = (self.real_A[:,0,:,:]*0.299+self.real_A[:,1,:,:]*0.587+self.real_A[:,2,:,:]*0.114).unsqueeze(0)
+ self.loss_cycle_A = self.lpips.forward_pair(rec_A.expand(ts), real_A.expand(ts)).mean() * lambda_A
+ elif self.opt.perceptual_cycle == 3 and self.opt.use_hed:
+ ts = self.real_A.shape
+ #[-1,1]->[0,1]->[-1,1]
+ rec_A_hed = (self.hed(self.rec_A/2+0.5)-0.5)*2
+ real_A_hed = (self.hed(self.real_A/2+0.5)-0.5)*2
+ self.loss_cycle_A = self.lpips.forward_pair(rec_A_hed.expand(ts), real_A_hed.expand(ts)).mean() * lambda_A
+ self.rec_A_hed = rec_A_hed
+ self.real_A_hed = real_A_hed
+ # NOTE(review): looks like a leftover debug print; it runs on every call in this mode
+ print(lambda_A)
+ elif self.opt.perceptual_cycle == 4:
+ x_a_feature = self.vgg(self.real_A)
+ g_a_feature = self.vgg(self.rec_A)
+ self.loss_cycle_A = self.criterionCycle(g_a_feature, x_a_feature.detach()) * lambda_A
+ elif self.opt.perceptual_cycle == 5 and self.opt.use_hed:
+ ts = self.real_A.shape
+ rec_A_hed = (self.hed(self.rec_A/2+0.5)-0.5)*2
+ real_A_hed = (self.hed(self.real_A/2+0.5)-0.5)*2
+ x_a_feature = self.vgg(real_A_hed.expand(ts))
+ g_a_feature = self.vgg(rec_A_hed.expand(ts))
+ self.loss_cycle_A = self.criterionCycle(g_a_feature, x_a_feature.detach()) * lambda_A
+ self.rec_A_hed = rec_A_hed
+ self.real_A_hed = real_A_hed
+ elif self.opt.perceptual_cycle == 6 and self.opt.use_hed and self.opt.ntrunc_trunc:
+ ts = self.real_A.shape
+ # inputs are moved to the first device in gpu_ids_p for HED; the losses are moved back to the first device in gpu_ids
+ gpu_p = self.opt.gpu_ids_p[0]
+ gpu = self.opt.gpu_ids[0]
+ rec_A_hed = (self.hed(self.rec_A.cuda(gpu_p)/2+0.5)-0.5)*2
+ real_A_hed = (self.hed(self.real_A.cuda(gpu_p)/2+0.5)-0.5)*2
+ self.rec_At = self.netG_B(truncate(self.fake_B,self.opt.trunc_a))
+ rec_At_hed = (self.hed(self.rec_At.cuda(gpu_p)/2+0.5)-0.5)*2
+ self.loss_cycle_A = (self.lpips.forward_pair(rec_A_hed.expand(ts), real_A_hed.expand(ts)).mean()).cuda(gpu) * lambda_A
+ self.loss_cycle_A2 = (self.lpips.forward_pair(rec_At_hed.expand(ts), real_A_hed.expand(ts)).mean()).cuda(gpu) * lambda_A_trunc
+ self.rec_A_hed = rec_A_hed
+ self.real_A_hed = real_A_hed
+ self.rec_At_hed = rec_At_hed
+ elif self.opt.perceptual_cycle == 8 and self.opt.use_hed and self.opt.ntrunc_trunc:
+ ts = self.real_A.shape
+ rec_A_hed = (self.hed(self.rec_A/2+0.5)-0.5)*2
+ real_A_hed = (self.hed(self.real_A/2+0.5)-0.5)*2
+ self.rec_At = self.netG_B(truncate(self.fake_B,self.opt.trunc_a))
+ rec_At_hed = (self.hed(self.rec_At/2+0.5)-0.5)*2
+ x_a_feature = self.vgg(real_A_hed.expand(ts))
+ g_a_feature = self.vgg(rec_A_hed.expand(ts))
+ gt_a_feature = self.vgg(rec_At_hed.expand(ts))
+ self.loss_cycle_A = self.criterionCycle(g_a_feature, x_a_feature.detach()) * lambda_A
+ self.loss_cycle_A2 = self.criterionCycle(gt_a_feature, x_a_feature.detach()) * lambda_A_trunc
+ self.rec_A_hed = rec_A_hed
+ self.real_A_hed = real_A_hed
+ self.rec_At_hed = rec_At_hed
+
+ # Backward cycle loss || G_A(G_B(B)) - B||
+ self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
+
+ # Metric loss, metric higher better
+ if self.opt.metric:
+ self.fake_B2 = self.fake_B.clone()
+ if self.opt.metric_inmask:
+ # background black
+ #self.fake_B2 = (self.fake_B2/2+0.5)*self.A_maskfg*2-1
+ # background white
+ self.fake_B2 = ((self.fake_B2/2+0.5)*self.A_maskfg+1-self.A_maskfg)*2-1
+ if not self.opt.metric_resnext and not self.opt.metric_resnet: # for two version of inception (during training input is [-1,1])
+ self.fake_B2 = torch.nn.functional.interpolate(input=self.fake_B2, size=(299, 299), mode='bilinear', align_corners=False)
+ # repeat the channel dim 3x (presumably fake_B is single-channel -- TODO confirm)
+ self.fake_B2 = self.fake_B2.repeat(1,3,1,1)
+ else: # for resnet and resnext
+ self.fake_B2 = torch.nn.functional.interpolate(input=self.fake_B2, size=(224, 224), mode='bilinear', align_corners=False)
+ x = self.fake_B2.repeat(1,3,1,1)
+ # [-1,1] -> [0,1] -> mean [0.485,0.456,0.406], std [0.229,0.224,0.225]
+ x_ch0 = (torch.unsqueeze(x[:, 0],1)*0.5+0.5-0.485)/0.229
+ x_ch1 = (torch.unsqueeze(x[:, 1],1)*0.5+0.5-0.456)/0.224
+ x_ch2 = (torch.unsqueeze(x[:, 2],1)*0.5+0.5-0.406)/0.225
+ # channels beyond the first three (if any) are passed through unnormalized
+ self.fake_B2 = torch.cat((x_ch0, x_ch1, x_ch2, x[:, 3:]), 1)
+
+
+ if not self.opt.metricvec:
+ pred = self.metric(self.fake_B2)
+ else:
+ pred = self.metric(torch.cat((self.fake_B2, self.vec),1))
+ # minimizing (1 - pred) pushes the metric prediction up
+ self.loss_metric = torch.mean((1-pred)) * self.opt.lambda_metric
+
+ # combined loss and calculate gradients
+ self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
+ # optional terms are added only if the attribute exists; -1 is the "missing" sentinel.
+ # NOTE(review): a loss whose value is exactly -1 would be skipped by this comparison.
+ if getattr(self,'loss_cycle_A2',-1) != -1:
+ self.loss_G = self.loss_G + self.loss_cycle_A2
+ if getattr(self,'loss_G_A_l',-1) != -1:
+ self.loss_G = self.loss_G + self.loss_G_A_l
+ if getattr(self,'loss_G_A_le',-1) != -1:
+ self.loss_G = self.loss_G + self.loss_G_A_le
+ if getattr(self,'loss_G_A_ll',-1) != -1:
+ self.loss_G = self.loss_G + self.loss_G_A_ll
+ if getattr(self,'loss_G_A_cls',-1) != -1:
+ self.loss_G = self.loss_G + self.loss_G_A_cls
+ if getattr(self,'loss_metric',-1) != -1:
+ self.loss_G = self.loss_G + self.loss_metric
+ self.loss_G.backward()
+
+ def optimize_parameters(self):
+ """Calculate losses, gradients, and update network weights; called in every training iteration
+
+ Order: forward pass, generator update with all discriminators frozen, then a
+ single optimizer_D step covering every enabled discriminator.
+
+ NOTE(review): indentation was lost in this patch text; confirm whether the
+ eye/lip mask conditionals are nested under use_mask in the real file.
+ """
+ # forward
+ self.forward() # compute fake images and reconstruction images.
+ # G_A and G_B
+ self.set_requires_grad([self.netD_A, self.netD_B], False) # Ds require no gradients when optimizing Gs
+ if self.opt.use_mask:
+ self.set_requires_grad([self.netD_A_l], False)
+ if self.opt.use_eye_mask:
+ self.set_requires_grad([self.netD_A_le], False)
+ if self.opt.use_lip_mask:
+ self.set_requires_grad([self.netD_A_ll], False)
+ self.optimizer_G.zero_grad() # set G_A and G_B's gradients to zero
+ self.backward_G() # calculate gradients for G_A and G_B
+ self.optimizer_G.step() # update G_A and G_B's weights
+ # D_A and D_B
+ self.set_requires_grad([self.netD_A, self.netD_B], True)
+ if self.opt.use_mask:
+ self.set_requires_grad([self.netD_A_l], True)
+ if self.opt.use_eye_mask:
+ self.set_requires_grad([self.netD_A_le], True)
+ if self.opt.use_lip_mask:
+ self.set_requires_grad([self.netD_A_ll], True)
+ self.optimizer_D.zero_grad() # set all discriminators' gradients to zero
+ self.backward_D_A() # calculate gradients for D_A
+ if self.opt.use_mask:
+ self.backward_D_A_l()# calculate gradients for D_A_l
+ if self.opt.use_eye_mask:
+ self.backward_D_A_le()# calculate gradients for D_A_le
+ if self.opt.use_lip_mask:
+ self.backward_D_A_ll()# calculate gradients for D_A_ll
+ self.backward_D_B() # calculate gradients for D_B
+ self.optimizer_D.step() # update the discriminators' weights
diff --git a/hi-arm/qmupd_vs/models/dist_model.py b/hi-arm/qmupd_vs/models/dist_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..e61d5de0214978ef071cb520dcbed77882c59836
--- /dev/null
+++ b/hi-arm/qmupd_vs/models/dist_model.py
@@ -0,0 +1,323 @@
+
+from __future__ import absolute_import
+
+import sys
+sys.path.append('..')
+sys.path.append('.')
+import numpy as np
+import torch
+from torch import nn
+from collections import OrderedDict
+from torch.autograd import Variable
+from .base_model import BaseModel
+from scipy.ndimage import zoom
+import skimage.transform
+
+from . import networks_basic as networks
+# from PerceptualSimilarity.util import util
+from util import util
+
+class DistModel(BaseModel):
+ def name(self):
+ return self.model_name
+
+ def __init__(self, opt, model='net-lin', net='alex', pnet_rand=False, pnet_tune=False, model_path=None, colorspace='Lab', use_gpu=True, printNet=False, spatial=False, spatial_shape=None, spatial_order=1, spatial_factor=None, is_train=False, lr=.0001, beta1=0.5, version='0.1'):
+ '''
+ Build the distance network selected by `model` and, in training mode, its ranking loss and Adam optimizer.
+
+ INPUTS
+ opt - options object passed to BaseModel.__init__; opt.gpu_ids_p is read to pick the device in the 'net-lin' branch
+ model - ['net-lin'] for linearly calibrated network
+ ['net'] for off-the-shelf network
+ ['L2'] for L2 distance in Lab colorspace
+ ['SSIM'] for ssim in RGB colorspace
+ net - ['squeeze','alex','vgg']
+ pnet_rand, pnet_tune - forwarded unchanged to networks.PNetLin
+ model_path - if None, './checkpoints/weights/v[VERSION]/[NET_NAME].pth' is used
+ colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM
+ use_gpu - bool - whether or not to use a GPU
+ printNet - bool - whether or not to print network architecture out
+ spatial - bool - whether to output an array containing varying distances across spatial dimensions
+ spatial_shape - if given, output spatial shape. if None then spatial shape is determined automatically via spatial_factor (see below).
+ spatial_factor - if given, specifies upsampling factor relative to the largest spatial extent of a convolutional layer. if None then resized to size of input images.
+ spatial_order - spline order of filter for upsampling in spatial mode, by default 1 (bilinear).
+ is_train - bool - [True] for training mode
+ lr - float - initial learning rate
+ beta1 - float - initial momentum term for adam
+ version - 0.1 for latest, 0.0 was original
+
+ RAISES
+ ValueError if `model` is not one of the names above
+ '''
+ BaseModel.__init__(self, opt)
+
+ self.model = model
+ self.net = net
+ self.use_gpu = use_gpu
+ self.is_train = is_train
+ self.spatial = spatial
+ self.spatial_shape = spatial_shape
+ self.spatial_order = spatial_order
+ self.spatial_factor = spatial_factor
+
+ self.model_name = '%s [%s]'%(model,net)
+ if(self.model == 'net-lin'): # pretrained net + linear layer
+ #self.device = torch.device('cuda:{}'.format(opt.gpu_ids[0])) if opt.gpu_ids else torch.device('cpu')
+ # The device comes from opt.gpu_ids_p (first id), falling back to CPU when that list is empty.
+ self.device = torch.device('cuda:{}'.format(opt.gpu_ids_p[0])) if opt.gpu_ids_p else torch.device('cpu')
+ self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net,use_dropout=True,spatial=spatial,version=version,lpips=True).to(self.device)
+ # NOTE(review): kw is filled in below but never passed to torch.load (only the commented-out call used it).
+ kw = {}
+
+ if not use_gpu:
+ kw['map_location'] = 'cpu'
+ if(model_path is None):
+ # NOTE(review): inspect is only needed by the commented-out path computation below.
+ import inspect
+ #model_path = os.path.abspath(os.path.join(inspect.getfile(self.initialize), '..', '..', 'weights/v%s/%s.pth'%(version,net)))
+ model_path = './checkpoints/weights/v%s/%s.pth'%(version,net)
+
+ if(not is_train):
+ print('Loading model from: %s'%model_path)
+ #self.net.load_state_dict(torch.load(model_path, **kw))
+ # Weights are mapped onto self.device; strict=False tolerates key mismatches between checkpoint and network.
+ state_dict = torch.load(model_path, map_location=str(self.device))
+ self.net.load_state_dict(state_dict, strict=False)
+
+ # NOTE(review): the branches below read self.device, which this method assigns only in the
+ # 'net-lin' branch -- presumably BaseModel.__init__ provides it; verify.
+ elif(self.model=='net'): # pretrained network
+ assert not self.spatial, 'spatial argument not supported yet for uncalibrated networks'
+ self.net = networks.PNet(use_gpu=use_gpu,pnet_type=net,device=self.device)
+ self.is_fake_net = True
+ elif(self.model in ['L2','l2']):
+ self.net = networks.L2(use_gpu=use_gpu,colorspace=colorspace,device=self.device) # not really a network, only for testing
+ self.model_name = 'L2'
+ elif(self.model in ['DSSIM','dssim','SSIM','ssim']):
+ self.net = networks.DSSIM(use_gpu=use_gpu,colorspace=colorspace,device=self.device)
+ self.model_name = 'SSIM'
+ else:
+ raise ValueError("Model [%s] not recognized." % self.model)
+
+ # Stored as a plain list so the ranking-loss parameters can be appended below.
+ self.parameters = list(self.net.parameters())
+
+ if self.is_train: # training mode
+ # extra network on top to go from distances (d0,d1) => predicted human judgment (h*)
+ self.rankLoss = networks.BCERankingLoss(use_gpu=use_gpu,device=self.device)
+ self.parameters+=self.rankLoss.parameters
+ self.lr = lr
+ self.old_lr = lr
+ self.optimizer_net = torch.optim.Adam(self.parameters, lr=lr, betas=(beta1, 0.999))
+ else: # test mode
+ self.net.eval()
+
+ if(printNet):
+ print('---------- Networks initialized -------------')
+ networks.print_network(self.net)
+ print('-----------------------------------------------')
+
+ def forward_pair(self,in1,in2,retPerLayer=False):
+ if(retPerLayer):
+ return self.net.forward(in1,in2, retPerLayer=True)
+ else:
+ return self.net.forward(in1,in2)
+
+ def forward(self, in0, in1, retNumpy=False):
+ ''' Function computes the distance between image patches in0 and in1
+ INPUTS
+ in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1]
+ retNumpy - [False] to return as torch.Tensor, [True] to return as numpy array
+ OUTPUT
+ computed distances between in0 and in1; in spatial mode a single map obtained by
+ resizing every entry of the network output to a common shape and combining them
+ SIDE EFFECTS
+ stores the inputs, their Variable wrappers, the raw network output (self.d0)
+ and self.loss_total on the instance
+ '''
+
+ self.input_ref = in0
+ self.input_p0 = in1
+
+ self.var_ref = Variable(self.input_ref,requires_grad=True)
+ self.var_p0 = Variable(self.input_p0,requires_grad=True)
+
+ self.d0 = self.forward_pair(self.var_ref, self.var_p0)
+ self.loss_total = self.d0
+
+ # Helper: optionally turn one network output into a numpy array.
+ def convert_output(d0):
+ if(retNumpy):
+ ans = d0.cpu().data.numpy()
+ if not self.spatial:
+ ans = ans.flatten()
+ else:
+ assert(ans.shape[0] == 1 and len(ans.shape) == 4)
+ return ans[0,...].transpose([1, 2, 0]) # Reshape to usual numpy image format: (height, width, channels)
+ return ans
+ else:
+ return d0
+
+ if self.spatial:
+ # In spatial mode self.d0 is iterated, i.e. treated as a sequence of maps.
+ L = [convert_output(x) for x in self.d0]
+ spatial_shape = self.spatial_shape
+ if spatial_shape is None:
+ if(self.spatial_factor is None):
+ # Default: resize to the spatial size of the input.
+ spatial_shape = (in0.size()[2],in0.size()[3])
+ else:
+ # Otherwise scale the largest map extent by spatial_factor.
+ spatial_shape = (max([x.shape[0] for x in L])*self.spatial_factor, max([x.shape[1] for x in L])*self.spatial_factor)
+
+ L = [skimage.transform.resize(x, spatial_shape, order=self.spatial_order, mode='edge') for x in L]
+
+ # Mean over the concatenated last axis times len(L).
+ # NOTE(review): equals a sum over maps only if each map has one channel -- confirm.
+ L = np.mean(np.concatenate(L, 2) * len(L), 2)
+ return L
+ else:
+ return convert_output(self.d0)
+
+ # ***** TRAINING FUNCTIONS *****
+ def optimize_parameters(self):
+ """Run one training step: forward pass, backward pass, optimizer update, then weight clamping."""
+ self.forward_train()
+ self.optimizer_net.zero_grad() # clear old gradients before the backward pass
+ self.backward_train()
+ self.optimizer_net.step()
+ self.clamp_weights() # re-apply the non-negativity clamp after the update
+
+ def clamp_weights(self):
+ for module in self.net.modules():
+ if(hasattr(module, 'weight') and module.kernel_size==(1,1)):
+ module.weight.data = torch.clamp(module.weight.data,min=0)
+
+ def set_input(self, data):
+ self.input_ref = data['ref']
+ self.input_p0 = data['p0']
+ self.input_p1 = data['p1']
+ self.input_judge = data['judge']
+
+ if(self.use_gpu):
+ self.input_ref = self.input_ref.cuda(self.device)
+ self.input_p0 = self.input_p0.cuda(self.device)
+ self.input_p1 = self.input_p1.cuda(self.device)
+ self.input_judge = self.input_judge.cuda(self.device)
+
+ self.var_ref = Variable(self.input_ref,requires_grad=True)
+ self.var_p0 = Variable(self.input_p0,requires_grad=True)
+ self.var_p1 = Variable(self.input_p1,requires_grad=True)
+
+ def forward_train(self): # run forward pass
+ """Compute both distances, the accuracy against the judgments and the ranking loss.
+
+ Uses the Variables prepared by set_input. Stores d0, d1, acc_r, var_judge
+ and loss_total on the instance and returns loss_total.
+ """
+ self.d0 = self.forward_pair(self.var_ref, self.var_p0)
+ self.d1 = self.forward_pair(self.var_ref, self.var_p1)
+ self.acc_r = self.compute_accuracy(self.d0,self.d1,self.input_judge)
+
+ # var_judge: the judgments reshaped to the size of d0
+ self.var_judge = Variable(1.*self.input_judge).view(self.d0.size())
+
+ # The judgment is rescaled as judge*2-1 before it is handed to the ranking loss.
+ # NOTE(review): presumably maps [0,1] to [-1,1] -- confirm the judge range.
+ self.loss_total = self.rankLoss.forward(self.d0, self.d1, self.var_judge*2.-1.)
+ return self.loss_total
+
+ def backward_train(self):
+ torch.mean(self.loss_total).backward()
+
+ def compute_accuracy(self,d0,d1,judge):
+ ''' d0, d1 are Variables, judge is a Tensor '''
+ d1_lt_d0 = (d1 %f' % (type,self.old_lr, lr))
+ self.old_lr = lr
+
+
+
+def score_2afc_dataset(data_loader,func):
+ ''' Function computes Two Alternative Forced Choice (2AFC) score using
+ distance function 'func' in dataset 'data_loader'
+ INPUTS
+ data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside
+ func - callable distance function - calling d=func(in0,in1) should take 2
+ pytorch tensors with shape Nx3xXxY, and return numpy array of length N
+ OUTPUTS
+ [0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators
+ [1] - dictionary with following elements
+ d0s,d1s - N arrays containing distances between reference patch to perturbed patches
+ gts - N array in [0,1], preferred patch selected by human evaluators
+ (closer to "0" for left patch p0, "1" for right patch p1,
+ "0.6" means 60pct people preferred right patch, 40pct preferred left)
+ scores - N array in [0,1], corresponding to what percentage function agreed with humans
+ CONSTS
+ N - number of test triplets in data_loader
+ '''
+
+ d0s = []
+ d1s = []
+ gts = []
+
+ # bar = pb.ProgressBar(max_value=data_loader.load_data().__len__())
+ for (i,data) in enumerate(data_loader.load_data()):
+ d0s+=func(data['ref'],data['p0']).tolist()
+ d1s+=func(data['ref'],data['p1']).tolist()
+ gts+=data['judge'].cpu().numpy().flatten().tolist()
+ # bar.update(i)
+
+ d0s = np.array(d0s)
+ d1s = np.array(d1s)
+ gts = np.array(gts)
+ scores = (d0s epochs
+ and linearly decay the rate to zero over the next epochs.
+ For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
+ See https://pytorch.org/docs/stable/optim.html for more details.
+ """
+ if opt.lr_policy == 'linear':
+ def lambda_rule(epoch):
+ lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
+ return lr_l
+ scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
+ elif opt.lr_policy == 'step':
+ scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
+ elif opt.lr_policy == 'plateau':
+ scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
+ elif opt.lr_policy == 'cosine':
+ scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0)
+ else:
+ return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy)
+ return scheduler
+
+
+def init_weights(net, init_type='normal', init_gain=0.02):
+ """Initialize network weights.
+
+ Parameters:
+ net (network) -- network to be initialized
+ init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
+ init_gain (float) -- scaling factor for normal, xavier and orthogonal.
+
+ We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
+ work better for some applications. Feel free to try yourself.
+ """
+ def init_func(m): # define the initialization function
+ classname = m.__class__.__name__
+ if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
+ if init_type == 'normal':
+ init.normal_(m.weight.data, 0.0, init_gain)
+ elif init_type == 'xavier':
+ init.xavier_normal_(m.weight.data, gain=init_gain)
+ elif init_type == 'kaiming':
+ init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
+ elif init_type == 'orthogonal':
+ init.orthogonal_(m.weight.data, gain=init_gain)
+ else:
+ raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
+ if hasattr(m, 'bias') and m.bias is not None:
+ init.constant_(m.bias.data, 0.0)
+ elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
+ init.normal_(m.weight.data, 1.0, init_gain)
+ init.constant_(m.bias.data, 0.0)
+
+ print('initialize network with %s' % init_type)
+ net.apply(init_func) # apply the initialization function
+
+
+def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
+ """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
+ Parameters:
+ net (network) -- the network to be initialized
+ init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
+ gain (float) -- scaling factor for normal, xavier and orthogonal.
+ gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
+
+ Return an initialized network.
+ """
+ if len(gpu_ids) > 0:
+ assert(torch.cuda.is_available())
+ net.to(gpu_ids[0])
+ net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs
+ init_weights(net, init_type, init_gain=init_gain)
+ return net
+
+
+def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[], model0_res=0, model1_res=0, extra_channel=3):
+ """Create a generator
+
+ Parameters:
+ input_nc (int) -- the number of channels in input images
+ output_nc (int) -- the number of channels in output images
+ ngf (int) -- the number of filters in the last conv layer
+ netG (str) -- the architecture's name: resnet_9blocks | resnet_6blocks | unet_256 | unet_128
+ norm (str) -- the name of normalization layers used in the network: batch | instance | none
+ use_dropout (bool) -- if use dropout layers.
+ init_type (str) -- the name of our initialization method.
+ init_gain (float) -- scaling factor for normal, xavier and orthogonal.
+ gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
+
+ Returns a generator
+
+ Our current implementation provides two types of generators:
+ U-Net: [unet_128] (for 128x128 input images) and [unet_256] (for 256x256 input images)
+ The original U-Net paper: https://arxiv.org/abs/1505.04597
+
+ Resnet-based generator: [resnet_6blocks] (with 6 Resnet blocks) and [resnet_9blocks] (with 9 Resnet blocks)
+ Resnet-based generator consists of several Resnet blocks between a few downsampling/upsampling operations.
+ We adapt Torch code from Justin Johnson's neural style transfer project (https://github.com/jcjohnson/fast-neural-style).
+
+
+ The generator has been initialized by . It uses RELU for non-linearity.
+ """
+ net = None
+ norm_layer = get_norm_layer(norm_type=norm)
+
+ if netG == 'resnet_9blocks':
+ net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9)
+ elif netG == 'resnet_8blocks':
+ net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=8)
+ elif netG == 'resnet_style_9blocks':
+ net = ResnetStyleGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9, extra_channel=extra_channel)
+ elif netG == 'resnet_style2_9blocks':
+ net = ResnetStyle2Generator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9, model0_res=model0_res, extra_channel=extra_channel)
+ elif netG == 'resnet_style2_8blocks':
+ net = ResnetStyle2Generator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=8, model0_res=model0_res, extra_channel=extra_channel)
+ elif netG == 'resnet_style2_10blocks':
+ net = ResnetStyle2Generator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=10, model0_res=model0_res, extra_channel=extra_channel)
+ elif netG == 'resnet_style3decoder_9blocks':
+ net = ResnetStyle3DecoderGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9, model0_res=model0_res)
+ elif netG == 'resnet_style2mc_9blocks':
+ net = ResnetStyle2MCGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9, model0_res=model0_res, extra_channel=extra_channel)
+ elif netG == 'resnet_style2mc2_9blocks':
+ net = ResnetStyle2MC2Generator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9, model0_res=model0_res, model1_res=model1_res, extra_channel=extra_channel)
+ elif netG == 'resnet_6blocks':
+ net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6)
+ elif netG == 'unet_128':
+ net = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
+ elif netG == 'unet_256':
+ net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
+ else:
+ raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
+ return init_net(net, init_type, init_gain, gpu_ids)
+
+
+def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[], n_class=3):
+ """Create a discriminator
+
+ Parameters:
+ input_nc (int) -- the number of channels in input images
+ ndf (int) -- the number of filters in the first conv layer
+ netD (str) -- the architecture's name: basic | n_layers | pixel
+ n_layers_D (int) -- the number of conv layers in the discriminator; effective when netD=='n_layers'
+ norm (str) -- the type of normalization layers used in the network.
+ init_type (str) -- the name of the initialization method.
+ init_gain (float) -- scaling factor for normal, xavier and orthogonal.
+ gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
+
+ Returns a discriminator
+
+ Our current implementation provides three types of discriminators:
+ [basic]: 'PatchGAN' classifier described in the original pix2pix paper.
+ It can classify whether 70×70 overlapping patches are real or fake.
+ Such a patch-level discriminator architecture has fewer parameters
+ than a full-image discriminator and can work on arbitrarily-sized images
+ in a fully convolutional fashion.
+
+ [n_layers]: With this mode, you cna specify the number of conv layers in the discriminator
+ with the parameter (default=3 as used in [basic] (PatchGAN).)
+
+ [pixel]: 1x1 PixelGAN discriminator can classify whether a pixel is real or not.
+ It encourages greater color diversity but has no effect on spatial statistics.
+
+ The discriminator has been initialized by . It uses Leakly RELU for non-linearity.
+ """
+ net = None
+ norm_layer = get_norm_layer(norm_type=norm)
+
+ if netD == 'basic': # default PatchGAN classifier
+ net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer)
+ elif netD == 'basic_cls':
+ net = NLayerDiscriminatorCls(input_nc, ndf, n_layers=3, n_class=3, norm_layer=norm_layer)
+ elif netD == 'n_layers': # more options
+ net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer)
+ elif netD == 'pixel': # classify if each pixel is real or fake
+ net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer)
+ else:
+ raise NotImplementedError('Discriminator model name [%s] is not recognized' % net)
+ return init_net(net, init_type, init_gain, gpu_ids)
+
+
+def define_HED(init_weights_, gpu_ids_=[]):
+ net = HED()
+
+ if len(gpu_ids_) > 0:
+ assert(torch.cuda.is_available())
+ net.to(gpu_ids_[0])
+ net = torch.nn.DataParallel(net, gpu_ids_) # multi-GPUs
+
+ if not init_weights_ == None:
+ device = torch.device('cuda:{}'.format(gpu_ids_[0])) if gpu_ids_ else torch.device('cpu')
+ print('Loading model from: %s'%init_weights_)
+ state_dict = torch.load(init_weights_, map_location=str(device))
+ if isinstance(net, torch.nn.DataParallel):
+ net.module.load_state_dict(state_dict)
+ else:
+ net.load_state_dict(state_dict)
+ print('load the weights successfully')
+
+ return net
+
+def define_VGG(init_weights_, feature_mode_, batch_norm_=False, num_classes_=1000, gpu_ids_=[]):
+ net = VGG19(init_weights=init_weights_, feature_mode=feature_mode_, batch_norm=batch_norm_, num_classes=num_classes_)
+ # set the GPU
+ if len(gpu_ids_) > 0:
+ assert(torch.cuda.is_available())
+ net.cuda(gpu_ids_[0])
+ net = torch.nn.DataParallel(net, gpu_ids_) # multi-GPUs
+
+ if not init_weights_ == None:
+ device = torch.device('cuda:{}'.format(gpu_ids_[0])) if gpu_ids_ else torch.device('cpu')
+ print('Loading model from: %s'%init_weights_)
+ state_dict = torch.load(init_weights_, map_location=str(device))
+ if isinstance(net, torch.nn.DataParallel):
+ net.module.load_state_dict(state_dict)
+ else:
+ net.load_state_dict(state_dict)
+ print('load the weights successfully')
+ return net
+
+###################################################################################################################
+from torchvision.models import vgg11, vgg11_bn, vgg13, vgg13_bn, vgg16, vgg16_bn, vgg19, vgg19_bn
+def define_vgg11_bn(gpu_ids_=[],vec=0):
+ net = vgg11_bn(pretrained=True)
+ net.classifier[6] = nn.Linear(4096, 1) #LSGAN needs no sigmoid, LSGAN-nn.MSELoss()
+ if len(gpu_ids_) > 0:
+ assert(torch.cuda.is_available())
+ net.cuda(gpu_ids_[0])
+ net = torch.nn.DataParallel(net, gpu_ids_)
+ return net
+def define_vgg19_bn(gpu_ids_=[],vec=0):
+ net = vgg19_bn(pretrained=True)
+ net.classifier[6] = nn.Linear(4096, 1) #LSGAN needs no sigmoid, LSGAN-nn.MSELoss()
+ if len(gpu_ids_) > 0:
+ assert(torch.cuda.is_available())
+ net.cuda(gpu_ids_[0])
+ net = torch.nn.DataParallel(net, gpu_ids_)
+ return net
+def define_vgg19(gpu_ids_=[],vec=0):
+ net = vgg19(pretrained=True)
+ net.classifier[6] = nn.Linear(4096, 1) #LSGAN needs no sigmoid, LSGAN-nn.MSELoss()
+ if len(gpu_ids_) > 0:
+ assert(torch.cuda.is_available())
+ net.cuda(gpu_ids_[0])
+ net = torch.nn.DataParallel(net, gpu_ids_)
+ return net
+###################################################################################################################
+from torchvision.models import resnet18, resnet34, resnet50, resnet101, resnet152
+def define_resnet101(gpu_ids_=[],vec=0):
+ net = resnet101(pretrained=True)
+ num_ftrs = net.fc.in_features
+ net.fc = nn.Linear(num_ftrs, 1) #LSGAN needs no sigmoid, LSGAN-nn.MSELoss()
+ if len(gpu_ids_) > 0:
+ assert(torch.cuda.is_available())
+ net.cuda(gpu_ids_[0])
+ net = torch.nn.DataParallel(net, gpu_ids_)
+ return net
+def define_resnet101a(init_weights_,gpu_ids_=[],vec=0):
+ net = resnet101(pretrained=True)
+ num_ftrs = net.fc.in_features
+ net.fc = nn.Linear(num_ftrs, 1) #LSGAN needs no sigmoid, LSGAN-nn.MSELoss()
+ if not init_weights_ == None:
+ print('Loading model from: %s'%init_weights_)
+ state_dict = torch.load(init_weights_, map_location=str(torch.device('cpu')))
+ if isinstance(net, torch.nn.DataParallel):
+ net.module.load_state_dict(state_dict)
+ else:
+ net.load_state_dict(state_dict)
+ print('load the weights successfully')
+ if len(gpu_ids_) > 0:
+ assert(torch.cuda.is_available())
+ net.cuda(gpu_ids_[0])
+ net = torch.nn.DataParallel(net, gpu_ids_)
+ return net
+###################################################################################################################
+import pretrainedmodels.models.resnext as resnext
+def define_resnext101(gpu_ids_=[],vec=0):
+ net = resnext.resnext101_64x4d(num_classes=1000,pretrained='imagenet')
+ net.last_linear = nn.Linear(2048, 1) #LSGAN needs no sigmoid, LSGAN-nn.MSELoss()
+ if len(gpu_ids_) > 0:
+ assert(torch.cuda.is_available())
+ net.cuda(gpu_ids_[0])
+ net = torch.nn.DataParallel(net, gpu_ids_)
+ return net
+def define_resnext101a(init_weights_,gpu_ids_=[],vec=0):
+ net = resnext.resnext101_64x4d(num_classes=1000,pretrained='imagenet')
+ net.last_linear = nn.Linear(2048, 1) #LSGAN needs no sigmoid, LSGAN-nn.MSELoss()
+ if not init_weights_ == None:
+ print('Loading model from: %s'%init_weights_)
+ state_dict = torch.load(init_weights_, map_location=str(torch.device('cpu')))
+ if isinstance(net, torch.nn.DataParallel):
+ net.module.load_state_dict(state_dict)
+ else:
+ net.load_state_dict(state_dict)
+ print('load the weights successfully')
+ if len(gpu_ids_) > 0:
+ assert(torch.cuda.is_available())
+ net.cuda(gpu_ids_[0])
+ net = torch.nn.DataParallel(net, gpu_ids_)
+ return net
+###################################################################################################################
+from torchvision.models import Inception3, inception_v3
+def define_inception3(gpu_ids_=[],vec=0):
+ net = inception_v3(pretrained=True)
+ net.transform_input = False # assume [-1,1] input
+ net.fc = nn.Linear(2048, 1)
+ net.aux_logits = False
+ if len(gpu_ids_) > 0:
+ assert(torch.cuda.is_available())
+ net.cuda(gpu_ids_[0])
+ net = torch.nn.DataParallel(net, gpu_ids_)
+ return net
+def define_inception3a(init_weights_,gpu_ids_=[],vec=0):
+ net = inception_v3(pretrained=True)
+ net.transform_input = False # assume [-1,1] input
+ net.fc = nn.Linear(2048, 1)
+ net.aux_logits = False
+ if not init_weights_ == None:
+ print('Loading model from: ', init_weights_)
+ state_dict = torch.load(init_weights_, map_location=str(torch.device('cpu')))
+ if isinstance(net, torch.nn.DataParallel):
+ net.module.load_state_dict(state_dict)
+ else:
+ net.load_state_dict(state_dict)
+ print('load the weights successfully')
+ if len(gpu_ids_) > 0:
+ assert(torch.cuda.is_available())
+ net.cuda(gpu_ids_[0])
+ net = torch.nn.DataParallel(net, gpu_ids_)
+ return net
+###################################################################################################################
+from torchvision.models.inception import BasicConv2d
+def define_inception_v3(init_weights_,gpu_ids_=[],vec=0):
+ # Build an Inception3, optionally load a checkpoint into it, then replace fc with a
+ # 2048 -> 1 layer. Here the checkpoint is loaded BEFORE fc (and, for vec == 1, the
+ # first conv) are replaced.
+ # init_weights_: path for torch.load, or None to skip loading.
+ # gpu_ids_: when non-empty the network is moved to the first id and wrapped in DataParallel.
+ # vec: when 1, Conv2d_1a_3x3 is replaced by a BasicConv2d taking 6 input channels.
+
+ ## pretrained = True
+ # kwargs starts empty, so transform_input is always set to True and
+ # original_aux_logits is always True in this function.
+ kwargs = {}
+ if 'transform_input' not in kwargs:
+ kwargs['transform_input'] = True
+ if 'aux_logits' in kwargs:
+ original_aux_logits = kwargs['aux_logits']
+ kwargs['aux_logits'] = True
+ else:
+ original_aux_logits = True
+ print(kwargs)
+ net = Inception3(**kwargs)
+
+ if not init_weights_ == None:
+ print('Loading model from: %s'%init_weights_)
+ state_dict = torch.load(init_weights_, map_location=str(torch.device('cpu')))
+ # NOTE(review): net was just constructed as Inception3, so the DataParallel case looks unreachable here.
+ if isinstance(net, torch.nn.DataParallel):
+ net.module.load_state_dict(state_dict)
+ else:
+ net.load_state_dict(state_dict)
+ print('load the weights successfully')
+
+ # original_aux_logits is always True above, so this check is never satisfied.
+ if not original_aux_logits:
+ net.aux_logits = False
+ del net.AuxLogits
+
+ net.fc = nn.Linear(2048, 1)
+ if vec == 1:
+ net.Conv2d_1a_3x3 = BasicConv2d(6, 32, kernel_size=3, stride=2)
+ # NOTE(review): indentation was lost in this copy; whether the next line is inside
+ # the `vec == 1` block or unconditional cannot be confirmed here.
+ net.aux_logits = False
+
+ if len(gpu_ids_) > 0:
+ assert(torch.cuda.is_available())
+ net.cuda(gpu_ids_[0])
+ net = torch.nn.DataParallel(net, gpu_ids_)
+
+ return net
+
+def define_inception_v3a(init_weights_,gpu_ids_=[],vec=0):
+ # Build an Inception3, replace fc with a 2048 -> 1 layer (and, for vec == 1, the first
+ # conv), and only then optionally load a checkpoint -- i.e. the checkpoint is loaded
+ # AFTER the layers are replaced.
+ # init_weights_: path for torch.load, or None to skip loading.
+ # gpu_ids_: when non-empty the network is moved to the first id and wrapped in DataParallel.
+ # vec: when 1, Conv2d_1a_3x3 is replaced by a BasicConv2d taking 6 input channels.
+
+ # kwargs starts empty, so transform_input is always set to True and
+ # original_aux_logits is always True in this function.
+ kwargs = {}
+ if 'transform_input' not in kwargs:
+ kwargs['transform_input'] = True
+ if 'aux_logits' in kwargs:
+ original_aux_logits = kwargs['aux_logits']
+ kwargs['aux_logits'] = True
+ else:
+ original_aux_logits = True
+ print(kwargs)
+ net = Inception3(**kwargs)
+
+ # original_aux_logits is always True above, so this check is never satisfied.
+ if not original_aux_logits:
+ net.aux_logits = False
+ del net.AuxLogits
+
+ net.fc = nn.Linear(2048, 1)
+ if vec == 1:
+ net.Conv2d_1a_3x3 = BasicConv2d(6, 32, kernel_size=3, stride=2)
+ # NOTE(review): indentation was lost in this copy; whether the next line is inside
+ # the `vec == 1` block or unconditional cannot be confirmed here.
+ net.aux_logits = False
+
+ if not init_weights_ == None:
+ print('Loading model from: %s'%init_weights_)
+ state_dict = torch.load(init_weights_, map_location=str(torch.device('cpu')))
+ # NOTE(review): net is not wrapped in DataParallel before this point, so the first case looks unreachable here.
+ if isinstance(net, torch.nn.DataParallel):
+ net.module.load_state_dict(state_dict)
+ else:
+ net.load_state_dict(state_dict)
+ print('load the weights successfully')
+
+ if len(gpu_ids_) > 0:
+ assert(torch.cuda.is_available())
+ net.cuda(gpu_ids_[0])
+ net = torch.nn.DataParallel(net, gpu_ids_)
+
+ return net
+
+def define_inception_ori(init_weights_,transform_input=False,gpu_ids_=[]):
+
+ ## pretrained = True
+ kwargs = {}
+ kwargs['transform_input'] = transform_input
+
+ if 'aux_logits' in kwargs:
+ original_aux_logits = kwargs['aux_logits']
+ kwargs['aux_logits'] = True
+ else:
+ original_aux_logits = True
+ print(kwargs)
+ net = Inception3(**kwargs)
+
+
+ if not init_weights_ == None:
+ print('Loading model from: %s'%init_weights_)
+ state_dict = torch.load(init_weights_, map_location=str(torch.device('cpu')))
+ if isinstance(net, torch.nn.DataParallel):
+ net.module.load_state_dict(state_dict)
+ else:
+ net.load_state_dict(state_dict)
+ print('load the weights successfully')
+ #for e in list(net.modules()):
+ # print(e)
+
+ if not original_aux_logits:
+ net.aux_logits = False
+ del net.AuxLogits
+
+
+ if len(gpu_ids_) > 0:
+ assert(torch.cuda.is_available())
+ net.cuda(gpu_ids_[0])
+
+ return net
+###################################################################################################################
+
+def define_DT(init_weights_, input_nc_, output_nc_, ngf_, netG_, norm_, use_dropout_, init_type_, init_gain_, gpu_ids_):
+ net = define_G(input_nc_, output_nc_, ngf_, netG_, norm_, use_dropout_, init_type_, init_gain_, gpu_ids_)
+
+ if not init_weights_ == None:
+ device = torch.device('cuda:{}'.format(gpu_ids_[0])) if gpu_ids_ else torch.device('cpu')
+ print('Loading model from: %s'%init_weights_)
+ state_dict = torch.load(init_weights_, map_location=str(device))
+ if isinstance(net, torch.nn.DataParallel):
+ net.module.load_state_dict(state_dict)
+ else:
+ net.load_state_dict(state_dict)
+ print('load the weights successfully')
+ return net
+
+def define_C(input_nc, classes, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[], h=512, w=512, nnG=3, dim=4096):
+ net = None
+ norm_layer = get_norm_layer(norm_type=norm)
+ if netG == 'classifier':
+ net = Classifier(input_nc, classes, ngf, num_downs=nnG, norm_layer=norm_layer, use_dropout=use_dropout, h=h, w=w, dim=dim)
+ elif netG == 'vgg':
+ net = VGG19(init_weights=None, feature_mode=False, batch_norm=True, num_classes=classes)
+ return init_net(net, init_type, init_gain, gpu_ids)
+
+##############################################################################
+# Classes
+##############################################################################
+class GANLoss(nn.Module):
+ """Define different GAN objectives.
+
+ The GANLoss class abstracts away the need to create the target label tensor
+ that has the same size as the input.
+ """
+
+ def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
+ """ Initialize the GANLoss class.
+
+ Parameters:
+ gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
+ target_real_label (bool) - - label for a real image
+ target_fake_label (bool) - - label of a fake image
+
+ Note: Do not use sigmoid as the last layer of Discriminator.
+ LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
+ """
+ super(GANLoss, self).__init__()
+ self.register_buffer('real_label', torch.tensor(target_real_label))
+ self.register_buffer('fake_label', torch.tensor(target_fake_label))
+ self.gan_mode = gan_mode
+ if gan_mode == 'lsgan':#cyclegan
+ self.loss = nn.MSELoss()
+ elif gan_mode == 'vanilla':
+ self.loss = nn.BCEWithLogitsLoss()
+ elif gan_mode in ['wgangp']:
+ self.loss = None
+ else:
+ raise NotImplementedError('gan mode %s not implemented' % gan_mode)
+
+ def get_target_tensor(self, prediction, target_is_real):
+ """Create label tensors with the same size as the input.
+
+ Parameters:
+ prediction (tensor) - - tpyically the prediction from a discriminator
+ target_is_real (bool) - - if the ground truth label is for real images or fake images
+
+ Returns:
+ A label tensor filled with ground truth label, and with the size of the input
+ """
+
+ if target_is_real:
+ target_tensor = self.real_label
+ else:
+ target_tensor = self.fake_label
+ return target_tensor.expand_as(prediction)
+
+ def __call__(self, prediction, target_is_real):
+ """Calculate loss given Discriminator's output and grount truth labels.
+
+ Parameters:
+ prediction (tensor) - - tpyically the prediction output from a discriminator
+ target_is_real (bool) - - if the ground truth label is for real images or fake images
+
+ Returns:
+ the calculated loss.
+ """
+ if self.gan_mode in ['lsgan', 'vanilla']:
+ target_tensor = self.get_target_tensor(prediction, target_is_real)
+ loss = self.loss(prediction, target_tensor)
+ elif self.gan_mode == 'wgangp':
+ if target_is_real:
+ loss = -prediction.mean()
+ else:
+ loss = prediction.mean()
+ return loss
+
+
def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0):
    """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028

    Arguments:
        netD (network)              -- discriminator network
        real_data (tensor array)    -- real images
        fake_data (tensor array)    -- generated images from the generator
        device (str)                -- device on which the random mixing weights and the gradient seed are created
        type (str)                  -- which samples the penalty is evaluated on [real | fake | mixed]
        constant (float)            -- the constant used in formula (||gradient||_2 - constant)^2
        lambda_gp (float)           -- weight for this loss

    Returns:
        (penalty, gradients) where gradients is flattened to (batch, -1),
        or (0.0, None) when lambda_gp is not positive.

    Note: for type 'real' / 'fake' the given tensor itself is switched to requires_grad.
    """
    if not lambda_gp > 0.0:
        return 0.0, None

    if type == 'real':
        samples = real_data
    elif type == 'fake':
        samples = fake_data
    elif type == 'mixed':
        # One random weight per sample, repeated over all of that sample's elements.
        batch = real_data.shape[0]
        per_sample = real_data.nelement() // batch
        mix = torch.rand(batch, 1, device=device)
        mix = mix.expand(batch, per_sample).contiguous().view(*real_data.shape)
        samples = mix * real_data + ((1 - mix) * fake_data)
    else:
        raise NotImplementedError('{} not implemented'.format(type))

    samples.requires_grad_(True)
    critic_out = netD(samples)
    grad_seed = torch.ones(critic_out.size()).to(device)
    # create_graph=True keeps the penalty itself differentiable.
    (grads,) = torch.autograd.grad(outputs=critic_out, inputs=samples,
                                   grad_outputs=grad_seed,
                                   create_graph=True, retain_graph=True, only_inputs=True)
    grads = grads.view(real_data.size(0), -1)  # one row per sample
    per_sample_norm = (grads + 1e-16).norm(2, dim=1)  # eps added before taking the norm
    penalty = ((per_sample_norm - constant) ** 2).mean() * lambda_gp
    return penalty, grads
+
+
class ResnetGenerator(nn.Module):
    """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.

    We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
    """

    # Layout of self.model (direct children), relied on by forward(feature_mode=True):
    #   _STEM_LAYERS : pad/conv/norm/relu followed by two conv/norm/relu downsampling stages
    #   n_blocks     : ResnetBlocks
    #   _TAIL_LAYERS : two convT/norm/relu upsampling stages followed by pad/conv/tanh
    _STEM_LAYERS = 10
    _TAIL_LAYERS = 9
    # 0-based index of the residual block whose output is returned as the
    # intermediate feature (i.e. the 5th block).
    _MID_BLOCK = 4

    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
        """Construct a Resnet-based generator

        Parameters:
            input_nc (int)      -- the number of channels in input images
            output_nc (int)     -- the number of channels in output images
            ngf (int)           -- the number of filters in the last conv layer
            norm_layer          -- normalization layer
            use_dropout (bool)  -- if use dropout layers
            n_blocks (int)      -- the number of ResNet blocks
            padding_type (str)  -- the name of padding layer in conv layers: reflect | replicate | zero
        """
        assert(n_blocks >= 0)
        super(ResnetGenerator, self).__init__()
        # Convolutions get a bias only when the norm layer is InstanceNorm2d.
        norm_cls = norm_layer.func if isinstance(norm_layer, functools.partial) else norm_layer
        use_bias = norm_cls == nn.InstanceNorm2d

        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
                 norm_layer(ngf),
                 nn.ReLU(True)]

        n_downsampling = 2
        for i in range(n_downsampling):  # add downsampling layers
            mult = 2 ** i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
                      norm_layer(ngf * mult * 2),
                      nn.ReLU(True)]

        mult = 2 ** n_downsampling
        for i in range(n_blocks):  # add ResNet blocks
            model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]

        for i in range(n_downsampling):  # add upsampling layers
            mult = 2 ** (n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
                                         kernel_size=3, stride=2,
                                         padding=1, output_padding=1,
                                         bias=use_bias),
                      norm_layer(int(ngf * mult / 2)),
                      nn.ReLU(True)]
        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        model += [nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, input, feature_mode=False):
        """Run the generator.

        Parameters:
            input (tensor)       -- input image batch
            feature_mode (bool)  -- if True, also return intermediate activations

        Returns:
            feature_mode=False: the generator output.
            feature_mode=True: the tuple (out, x1, x2, x3, y4, y3, y2, y7) where
                x1, x2, x3 -- outputs of the three encoder norm layers (before their ReLU)
                y7         -- output of the 5th residual block
                y4         -- output of the last residual block
                y3, y2     -- outputs of the two decoder norm layers (before their ReLU)

        Raises:
            ValueError: feature_mode is requested but the model has fewer than 5 residual blocks.
        """
        if not feature_mode:
            return self.model(input)

        # Walk the direct children instead of indexing into self.model.modules():
        # positions in modules() depend on n_blocks and on how many sub-layers each
        # ResnetBlock contains (padding type, dropout).
        layers = list(self.model.children())
        n_res = len(layers) - self._STEM_LAYERS - self._TAIL_LAYERS
        if n_res <= self._MID_BLOCK:
            raise ValueError('feature_mode needs at least %d ResnetBlocks, but the model has %d'
                             % (self._MID_BLOCK + 1, n_res))

        tail_start = self._STEM_LAYERS + n_res
        mid_idx = self._STEM_LAYERS + self._MID_BLOCK
        last_idx = tail_start - 1
        up1_idx = tail_start + 1  # norm of the first upsampling stage
        up2_idx = tail_start + 4  # norm of the second upsampling stage
        wanted = {2, 5, 8, mid_idx, last_idx, up1_idx, up2_idx}

        taps = {}
        x = input.clone()
        for idx, layer in enumerate(layers):
            x = layer(x)
            if idx in wanted:
                # clone: the following ReLU layers are in-place and would overwrite the tap
                taps[idx] = x.clone()
        return x, taps[2], taps[5], taps[8], taps[last_idx], taps[up1_idx], taps[up2_idx], taps[mid_idx]
+
class ResnetStyleGenerator(nn.Module):
    """Resnet-based generator with a second input branch.

    The first input goes through a downsampling encoder (model0), the second input through a
    single conv/norm/relu stage (model1); both feature maps are concatenated along the channel
    axis and decoded by a merge conv, the Resnet blocks and the upsampling layers (model).

    We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
    """

    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
        """Construct a Resnet-based generator

        Parameters:
            input_nc (int)      -- the number of channels in input images
            output_nc (int)     -- the number of channels in output images
            ngf (int)           -- the number of filters in the last conv layer
            norm_layer          -- normalization layer
            use_dropout (bool)  -- if use dropout layers
            n_blocks (int)      -- the number of ResNet blocks
            padding_type (str)  -- the name of padding layer in conv layers: reflect | replicate | zero
        """
        assert(n_blocks >= 0)
        super(ResnetStyleGenerator, self).__init__()
        norm_cls = norm_layer.func if type(norm_layer) == functools.partial else norm_layer
        use_bias = norm_cls == nn.InstanceNorm2d

        def conv_unit(in_ch, out_ch, kernel_size, stride, padding):
            # conv -> norm -> in-place ReLU
            return [nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, stride=stride, padding=padding, bias=use_bias),
                    norm_layer(out_ch),
                    nn.ReLU(True)]

        n_downsampling = 2
        bottleneck = ngf * 2 ** n_downsampling

        # encoder for the first input
        encoder = [nn.ReflectionPad2d(3)] + conv_unit(input_nc, ngf, 7, 1, 0)
        for level in range(n_downsampling):
            width = ngf * 2 ** level
            encoder += conv_unit(width, width * 2, 3, 2, 1)

        # branch for the second input; it is built for exactly 3 input channels
        side_branch = conv_unit(3, bottleneck, 3, 1, 1)

        # merge conv, residual trunk and upsampling decoder
        decoder = conv_unit(bottleneck * 2, bottleneck, 3, 1, 1)
        decoder += [ResnetBlock(bottleneck, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)
                    for _ in range(n_blocks)]
        for level in range(n_downsampling, 0, -1):
            width = ngf * 2 ** level
            decoder += [nn.ConvTranspose2d(width, int(width / 2),
                                           kernel_size=3, stride=2,
                                           padding=1, output_padding=1,
                                           bias=use_bias),
                        norm_layer(int(width / 2)),
                        nn.ReLU(True)]
        decoder += [nn.ReflectionPad2d(3),
                    nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
                    nn.Tanh()]

        self.model0 = nn.Sequential(*encoder)
        self.model1 = nn.Sequential(*side_branch)
        self.model = nn.Sequential(*decoder)

    def forward(self, input1, input2):
        """Encode both inputs, concatenate their features on the channel axis and decode."""
        encoded = self.model0(input1)
        side = self.model1(input2)
        return self.model(torch.cat((encoded, side), 1))
+
+
class ResnetStyle2Generator(nn.Module):
    """Resnet-based generator that concatenates a second input to the encoded features.

    model0 encodes the first input (downsampling plus ``model0_res`` Resnet blocks); the second
    input is concatenated to that feature map on the channel axis and decoded by ``model``
    (merge conv, the remaining Resnet blocks, upsampling layers and the output conv).

    We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
    """

    # Layout of self.model (direct children), relied on by the feature / ablation paths:
    #   _MERGE_LAYERS : conv/norm/relu applied to the concatenated features
    #   n_blocks - model0_res ResnetBlocks
    #   _TAIL_LAYERS  : two convT/norm/relu upsampling stages followed by pad/conv/tanh
    _MERGE_LAYERS = 3
    _TAIL_LAYERS = 9
    # 0-based index (within self.model's blocks) of the block returned as the
    # intermediate feature, i.e. the 5th block.
    _MID_BLOCK = 4

    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect', extra_channel=3, model0_res=0):
        """Construct a Resnet-based generator

        Parameters:
            input_nc (int)      -- the number of channels in input images
            output_nc (int)     -- the number of channels in output images
            ngf (int)           -- the number of filters in the last conv layer
            norm_layer          -- normalization layer
            use_dropout (bool)  -- if use dropout layers
            n_blocks (int)      -- the total number of ResNet blocks (model0 and model together)
            padding_type (str)  -- the name of padding layer in conv layers: reflect | replicate | zero
            extra_channel (int) -- the number of channels of the second input
            model0_res (int)    -- how many of the ResNet blocks are placed in the encoder (model0)
        """
        assert(n_blocks >= 0)
        super(ResnetStyle2Generator, self).__init__()
        self.n_blocks = n_blocks
        # Convolutions get a bias only when the norm layer is InstanceNorm2d.
        norm_cls = norm_layer.func if isinstance(norm_layer, functools.partial) else norm_layer
        use_bias = norm_cls == nn.InstanceNorm2d

        model0 = [nn.ReflectionPad2d(3),
                  nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
                  norm_layer(ngf),
                  nn.ReLU(True)]

        n_downsampling = 2
        for i in range(n_downsampling):  # add downsampling layers
            mult = 2 ** i
            model0 += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
                       norm_layer(ngf * mult * 2),
                       nn.ReLU(True)]

        mult = 2 ** n_downsampling
        for i in range(model0_res):  # add ResNet blocks to the encoder
            model0 += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]

        model = []
        model += [nn.Conv2d(ngf * mult + extra_channel, ngf * mult, kernel_size=3, stride=1, padding=1, bias=use_bias),
                  norm_layer(ngf * mult),
                  nn.ReLU(True)]

        for i in range(n_blocks - model0_res):  # add the remaining ResNet blocks
            model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]

        for i in range(n_downsampling):  # add upsampling layers
            mult = 2 ** (n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
                                         kernel_size=3, stride=2,
                                         padding=1, output_padding=1,
                                         bias=use_bias),
                      norm_layer(int(ngf * mult / 2)),
                      nn.ReLU(True)]
        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        model += [nn.Tanh()]

        self.model0 = nn.Sequential(*model0)
        self.model = nn.Sequential(*model)

    def forward(self, input1, input2, feature_mode=False, ablate_res=-1):
        """Run the generator.

        Parameters:
            input1 (tensor)      -- image batch fed to the encoder
            input2 (tensor)      -- tensor concatenated (channel axis) to the encoder output
            feature_mode (bool)  -- if True, also return intermediate activations
                                    (takes precedence over ablate_res)
            ablate_res (int)     -- 1-based index of the residual block in self.model to skip;
                                    -1 disables ablation

        Returns:
            feature_mode=False: the generator output.
            feature_mode=True: the tuple (out, x1, x2, x3, y4, y3, y2, y7) where
                x1, x2, x3 -- outputs of the three encoder norm layers (before their ReLU)
                y7         -- output of the 5th residual block of self.model
                y4         -- output of the last residual block of self.model
                y3, y2     -- outputs of the two decoder norm layers (before their ReLU)

        Raises:
            ValueError: feature_mode is requested but self.model has fewer than 5 residual blocks.
        """
        if feature_mode:
            return self._forward_features(input1, input2)
        merged = torch.cat([self.model0(input1), input2], 1)
        if ablate_res == -1:
            return self.model(merged)
        return self._forward_ablated(merged, ablate_res)

    def _forward_ablated(self, merged, ablate_res):
        """Decode ``merged`` while skipping the ``ablate_res``-th (1-based) residual block."""
        layers = list(self.model.children())
        n_res = len(layers) - self._MERGE_LAYERS - self._TAIL_LAYERS
        y = merged
        for idx, layer in enumerate(layers):
            block_no = idx - self._MERGE_LAYERS + 1  # 1-based position among the residual blocks
            if 1 <= block_no <= n_res and block_no == ablate_res:
                print('skip resblock' + str(block_no))
                continue
            y = layer(y)
        return y

    def _forward_features(self, input1, input2):
        """Full forward pass that also collects the intermediate activations."""
        # Walk direct children rather than indexing self.model.modules(): positions in
        # modules() depend on the number of blocks and on each block's internal layers.
        enc_taps = {}
        x = input1.clone()
        for idx, layer in enumerate(self.model0.children()):
            x = layer(x)
            if idx in (2, 5, 8):
                # clone: the following ReLU is in-place and would overwrite the tap
                enc_taps[idx] = x.clone()

        layers = list(self.model.children())
        n_res = len(layers) - self._MERGE_LAYERS - self._TAIL_LAYERS
        if n_res <= self._MID_BLOCK:
            raise ValueError('feature_mode needs at least %d ResnetBlocks in self.model, but it has %d'
                             % (self._MID_BLOCK + 1, n_res))
        tail_start = self._MERGE_LAYERS + n_res
        mid_idx = self._MERGE_LAYERS + self._MID_BLOCK
        last_idx = tail_start - 1
        up1_idx = tail_start + 1  # norm of the first upsampling stage
        up2_idx = tail_start + 4  # norm of the second upsampling stage
        wanted = {mid_idx, last_idx, up1_idx, up2_idx}

        dec_taps = {}
        y = torch.cat([x, input2], 1).clone()
        for idx, layer in enumerate(layers):
            y = layer(y)
            if idx in wanted:
                dec_taps[idx] = y.clone()
        return (y, enc_taps[2], enc_taps[5], enc_taps[8],
                dec_taps[last_idx], dec_taps[up1_idx], dec_taps[up2_idx], dec_taps[mid_idx])
+
class ResnetStyle3DecoderGenerator(nn.Module):
    """Resnet-based generator with one shared encoder and three separate decoders.

    model0 is the shared encoder; model1, model2 and model3 are three decoders with the same
    architecture (Resnet blocks, upsampling layers, output conv) selected by ``domain`` in forward.

    We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
    """

    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect', model0_res=0):
        """Construct a Resnet-based generator

        Parameters:
            input_nc (int)      -- the number of channels in input images
            output_nc (int)     -- the number of channels in output images
            ngf (int)           -- the number of filters in the last conv layer
            norm_layer          -- normalization layer
            use_dropout (bool)  -- if use dropout layers
            n_blocks (int)      -- the number of ResNet blocks (encoder blocks plus blocks of one decoder)
            padding_type (str)  -- the name of padding layer in conv layers: reflect | replicate | zero
            model0_res (int)    -- how many of the ResNet blocks are placed in the shared encoder
        """
        assert(n_blocks >= 0)
        super(ResnetStyle3DecoderGenerator, self).__init__()
        # Convolutions get a bias only when the norm layer is InstanceNorm2d.
        norm_cls = norm_layer.func if isinstance(norm_layer, functools.partial) else norm_layer
        use_bias = norm_cls == nn.InstanceNorm2d

        model0 = [nn.ReflectionPad2d(3),
                  nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
                  norm_layer(ngf),
                  nn.ReLU(True)]

        n_downsampling = 2
        for i in range(n_downsampling):  # add downsampling layers
            mult = 2 ** i
            model0 += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
                       norm_layer(ngf * mult * 2),
                       nn.ReLU(True)]

        mult = 2 ** n_downsampling
        for i in range(model0_res):  # add ResNet blocks to the encoder
            model0 += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]

        # The three decoders are filled stage by stage (decoder 1, 2, 3 within each stage)
        # so that layers are created in the same order as before.
        decoders = ([], [], [])
        for i in range(n_blocks - model0_res):  # add ResNet blocks
            for decoder in decoders:
                decoder.append(ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias))

        for i in range(n_downsampling):  # add upsampling layers
            mult = 2 ** (n_downsampling - i)
            for decoder in decoders:
                decoder += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
                                               kernel_size=3, stride=2,
                                               padding=1, output_padding=1,
                                               bias=use_bias),
                            norm_layer(int(ngf * mult / 2)),
                            nn.ReLU(True)]

        for decoder in decoders:  # output head
            decoder += [nn.ReflectionPad2d(3),
                        nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
                        nn.Tanh()]

        self.model0 = nn.Sequential(*model0)
        self.model1 = nn.Sequential(*decoders[0])
        self.model2 = nn.Sequential(*decoders[1])
        self.model3 = nn.Sequential(*decoders[2])

    def forward(self, input, domain):
        """Encode ``input`` and decode it with the decoder selected by ``domain``.

        Parameters:
            input (tensor) -- input image batch
            domain (int)   -- 0, 1 or 2; selects model1, model2 or model3

        Raises:
            ValueError: ``domain`` is not 0, 1 or 2.
        """
        f1 = self.model0(input)
        if domain == 0:
            return self.model1(f1)
        if domain == 1:
            return self.model2(f1)
        if domain == 2:
            return self.model3(f1)
        raise ValueError('domain must be 0, 1 or 2, got %r' % (domain,))
+
class ResnetStyle2MCGenerator(nn.Module):
    """Multi-column Resnet-based generator.

    A shared stem (model0) feeds two parallel downsampling columns, one with 3x3 kernels
    (model1_3) and one with 5x5 kernels (model1_5). Their outputs and the second input are
    concatenated on the channel axis and decoded by ``model``.
    """

    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect', extra_channel=3, model0_res=0):
        """Construct a Resnet-based generator

        Parameters:
            input_nc (int)      -- the number of channels in input images
            output_nc (int)     -- the number of channels in output images
            ngf (int)           -- the number of filters in the last conv layer
            norm_layer          -- normalization layer
            use_dropout (bool)  -- if use dropout layers
            n_blocks (int)      -- the number of ResNet blocks (per-column blocks plus decoder blocks)
            padding_type (str)  -- the name of padding layer in conv layers: reflect | replicate | zero
            extra_channel (int) -- the number of channels of the second input
            model0_res (int)    -- how many ResNet blocks are placed in each downsampling column
        """
        assert(n_blocks >= 0)
        super(ResnetStyle2MCGenerator, self).__init__()
        # Convolutions get a bias only when the norm layer is InstanceNorm2d.
        norm_cls = norm_layer.func if isinstance(norm_layer, functools.partial) else norm_layer
        use_bias = norm_cls == nn.InstanceNorm2d

        model0 = [nn.ReflectionPad2d(3),
                  nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
                  norm_layer(ngf),
                  nn.ReLU(True)]

        n_downsampling = 2
        model1_3 = []
        model1_5 = []
        for i in range(n_downsampling):  # add downsampling layers to both columns
            mult = 2 ** i
            model1_3 += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
                         norm_layer(ngf * mult * 2),
                         nn.ReLU(True)]
            model1_5 += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=5, stride=2, padding=2, bias=use_bias),
                         norm_layer(ngf * mult * 2),
                         nn.ReLU(True)]

        mult = 2 ** n_downsampling
        for i in range(model0_res):  # add ResNet blocks to both columns
            model1_3 += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
            model1_5 += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias, kernel=5)]

        # merge conv: both columns plus the extra input channels
        model = []
        model += [nn.Conv2d(ngf * mult * 2 + extra_channel, ngf * mult, kernel_size=3, stride=1, padding=1, bias=use_bias),
                  norm_layer(ngf * mult),
                  nn.ReLU(True)]

        for i in range(n_blocks - model0_res):  # add the remaining ResNet blocks
            model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]

        for i in range(n_downsampling):  # add upsampling layers
            mult = 2 ** (n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
                                         kernel_size=3, stride=2,
                                         padding=1, output_padding=1,
                                         bias=use_bias),
                      norm_layer(int(ngf * mult / 2)),
                      nn.ReLU(True)]
        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        model += [nn.Tanh()]

        self.model0 = nn.Sequential(*model0)
        self.model1_3 = nn.Sequential(*model1_3)
        self.model1_5 = nn.Sequential(*model1_5)
        self.model = nn.Sequential(*model)

    def forward(self, input1, input2):
        """Run both columns on the stem features, concatenate them with ``input2`` and decode."""
        f0 = self.model0(input1)
        f1 = self.model1_3(f0)
        f2 = self.model1_5(f0)
        y1 = torch.cat([f1, f2, input2], 1)
        return self.model(y1)
+
class ResnetStyle2MC2Generator(nn.Module):
    """Multi-column Resnet-based generator that injects the second input inside each column.

    A shared stem (model0) feeds a 3x3 column (model1_3 -> model2_3) and a 5x5 column
    (model1_5 -> model2_5). The second input is concatenated to each column's features
    before its model2_* stage; the two column outputs are then concatenated and decoded by ``model``.
    """

    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect', extra_channel=3, model0_res=0, model1_res=0):
        """Construct a Resnet-based generator

        Parameters:
            input_nc (int)      -- the number of channels in input images
            output_nc (int)     -- the number of channels in output images
            ngf (int)           -- the number of filters in the last conv layer
            norm_layer          -- normalization layer
            use_dropout (bool)  -- if use dropout layers
            n_blocks (int)      -- the number of ResNet blocks (column blocks plus decoder blocks)
            padding_type (str)  -- the name of padding layer in conv layers: reflect | replicate | zero
            extra_channel (int) -- the number of channels of the second input
            model0_res (int)    -- ResNet blocks per column before the second input is concatenated
            model1_res (int)    -- ResNet blocks per column after the second input is concatenated
        """
        assert(n_blocks >= 0)
        super(ResnetStyle2MC2Generator, self).__init__()
        # Convolutions get a bias only when the norm layer is InstanceNorm2d.
        norm_cls = norm_layer.func if isinstance(norm_layer, functools.partial) else norm_layer
        use_bias = norm_cls == nn.InstanceNorm2d

        model0 = [nn.ReflectionPad2d(3),
                  nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
                  norm_layer(ngf),
                  nn.ReLU(True)]

        n_downsampling = 2
        model1_3 = []
        model1_5 = []
        for i in range(n_downsampling):  # add downsampling layers to both columns
            mult = 2 ** i
            model1_3 += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
                         norm_layer(ngf * mult * 2),
                         nn.ReLU(True)]
            model1_5 += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=5, stride=2, padding=2, bias=use_bias),
                         norm_layer(ngf * mult * 2),
                         nn.ReLU(True)]

        mult = 2 ** n_downsampling
        for i in range(model0_res):  # add ResNet blocks to both columns
            model1_3 += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
            model1_5 += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias, kernel=5)]

        # per-column merge of the column features with the extra input channels
        model2_3 = []
        model2_5 = []
        model2_3 += [nn.Conv2d(ngf * mult + extra_channel, ngf * mult, kernel_size=3, stride=1, padding=1, bias=use_bias),
                     norm_layer(ngf * mult),
                     nn.ReLU(True)]
        model2_5 += [nn.Conv2d(ngf * mult + extra_channel, ngf * mult, kernel_size=5, stride=1, padding=2, bias=use_bias),
                     norm_layer(ngf * mult),
                     nn.ReLU(True)]

        for i in range(model1_res):  # add ResNet blocks after the per-column merge
            model2_3 += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
            model2_5 += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias, kernel=5)]

        # merge of the two columns
        model = []
        model += [nn.Conv2d(ngf * mult * 2, ngf * mult, kernel_size=3, stride=1, padding=1, bias=use_bias),
                  norm_layer(ngf * mult),
                  nn.ReLU(True)]
        for i in range(n_blocks - model0_res - model1_res):  # add the remaining ResNet blocks
            model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]

        for i in range(n_downsampling):  # add upsampling layers
            mult = 2 ** (n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
                                         kernel_size=3, stride=2,
                                         padding=1, output_padding=1,
                                         bias=use_bias),
                      norm_layer(int(ngf * mult / 2)),
                      nn.ReLU(True)]
        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        model += [nn.Tanh()]

        self.model0 = nn.Sequential(*model0)
        self.model1_3 = nn.Sequential(*model1_3)
        self.model1_5 = nn.Sequential(*model1_5)
        self.model2_3 = nn.Sequential(*model2_3)
        self.model2_5 = nn.Sequential(*model2_5)
        self.model = nn.Sequential(*model)

    def forward(self, input1, input2):
        """Run both columns (each merged with ``input2``), concatenate their outputs and decode."""
        f0 = self.model0(input1)
        f1 = self.model1_3(f0)
        f2 = self.model1_5(f0)
        f3 = self.model2_3(torch.cat([f1, input2], 1))
        f4 = self.model2_5(torch.cat([f2, input2], 1))
        return self.model(torch.cat([f3, f4], 1))
+
class ResnetBlock(nn.Module):
    """Residual block: two padded conv stages whose output is added to the block input."""

    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias, kernel=3):
        """Initialize the Resnet block

        The conv stages are assembled by build_conv_block; the skip connection is
        applied in forward.
        Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
        """
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias, kernel)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias, kernel=3):
        """Construct the convolutional part of the block.

        Parameters:
            dim (int)           -- the number of channels in the conv layer.
            padding_type (str)  -- the name of padding layer: reflect | replicate | zero
            norm_layer          -- normalization layer
            use_dropout (bool)  -- if use dropout layers.
            use_bias (bool)     -- if the conv layer uses bias or not
            kernel (int)        -- kernel size of both conv layers

        Returns an nn.Sequential of: [pad,] conv, norm, ReLU, [dropout,] [pad,] conv, norm

        Raises NotImplementedError for an unknown padding_type.
        """
        pad = int((kernel-1)/2)

        def padded_conv():
            # 'zero' padding is done by the conv itself; the other modes use an explicit padding layer
            layers = []
            conv_padding = 0
            if padding_type == 'reflect':
                layers.append(nn.ReflectionPad2d(pad))
            elif padding_type == 'replicate':
                layers.append(nn.ReplicationPad2d(pad))
            elif padding_type == 'zero':
                conv_padding = pad
            else:
                raise NotImplementedError('padding [%s] is not implemented' % padding_type)
            layers.append(nn.Conv2d(dim, dim, kernel_size=kernel, padding=conv_padding, bias=use_bias))
            return layers

        first_stage = padded_conv() + [norm_layer(dim), nn.ReLU(True)]
        if use_dropout:
            first_stage.append(nn.Dropout(0.5))
        second_stage = padded_conv() + [norm_layer(dim)]

        return nn.Sequential(*(first_stage + second_stage))

    def forward(self, x):
        """Forward function (with skip connections)"""
        return x + self.conv_block(x)
+
+
class UnetGenerator(nn.Module):
    """Create a Unet-based generator"""

    def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Construct a Unet generator
        Parameters:
            input_nc (int)      -- the number of channels in input images
            output_nc (int)     -- the number of channels in output images
            num_downs (int)     -- the number of downsamplings in UNet. For example, if |num_downs| == 7,
                                   image of size 128x128 will become of size 1x1 at the bottleneck
            ngf (int)           -- the number of filters in the last conv layer
            norm_layer          -- normalization layer
            use_dropout (bool)  -- if use dropout in the intermediate ngf * 8 blocks

        The U-Net is assembled inside-out: each new block wraps the previously built one.
        """
        super(UnetGenerator, self).__init__()
        widest = ngf * 8
        # innermost (bottleneck) block
        core = UnetSkipConnectionBlock(widest, widest, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)
        # intermediate blocks that stay at ngf * 8 filters
        for _ in range(num_downs - 5):
            core = UnetSkipConnectionBlock(widest, widest, input_nc=None, submodule=core, norm_layer=norm_layer, use_dropout=use_dropout)
        # shrink the filter count step by step: ngf * 8 -> ngf * 4 -> ngf * 2 -> ngf
        for outer_mult in (4, 2, 1):
            core = UnetSkipConnectionBlock(ngf * outer_mult, ngf * outer_mult * 2, input_nc=None, submodule=core, norm_layer=norm_layer)
        # outermost block maps input_nc -> output_nc
        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=core, outermost=True, norm_layer=norm_layer)

    def forward(self, input):
        """Standard forward"""
        return self.model(input)
+
+
class UnetSkipConnectionBlock(nn.Module):
    """Defines the Unet submodule with skip connection.
        X -------------------identity----------------------
        |-- downsampling -- |submodule| -- upsampling --|
    """

    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Construct a Unet submodule with skip connections.

        Parameters:
            outer_nc (int) -- the number of filters in the outer conv layer
            inner_nc (int) -- the number of filters in the inner conv layer
            input_nc (int) -- the number of channels in input images/features (defaults to outer_nc)
            submodule (UnetSkipConnectionBlock) -- previously defined submodules
            outermost (bool)    -- if this module is the outermost module
            innermost (bool)    -- if this module is the innermost module
            norm_layer          -- normalization layer
            use_dropout (bool)  -- if use dropout layers (only applied to intermediate blocks).
        """
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        norm_cls = norm_layer.func if type(norm_layer) == functools.partial else norm_layer
        use_bias = norm_cls == nn.InstanceNorm2d
        if input_nc is None:
            input_nc = outer_nc

        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
                             stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)
        # all three variants upsample with the same kernel/stride/padding
        make_upconv = functools.partial(nn.ConvTranspose2d, kernel_size=4, stride=2, padding=1)

        if outermost:
            # input of the upconv is the submodule output, which already includes its skip concat
            upconv = make_upconv(inner_nc * 2, outer_nc)
            layers = [downconv, submodule, uprelu, upconv, nn.Tanh()]
        elif innermost:
            # no submodule, hence no concatenated channels
            upconv = make_upconv(inner_nc, outer_nc, bias=use_bias)
            layers = [downrelu, downconv, uprelu, upconv, upnorm]
        else:
            upconv = make_upconv(inner_nc * 2, outer_nc, bias=use_bias)
            layers = [downrelu, downconv, downnorm, submodule, uprelu, upconv, upnorm]
            if use_dropout:
                layers.append(nn.Dropout(0.5))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the block output; every block except the outermost concatenates its input to it."""
        if not self.outermost:
            return torch.cat([x, self.model(x)], 1)
        return self.model(x)
+
+
class NLayerDiscriminator(nn.Module):
    """Defines a PatchGAN discriminator"""

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        """Construct a PatchGAN discriminator

        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the last conv layer
            n_layers (int)  -- the number of conv layers in the discriminator
            norm_layer      -- normalization layer
        """
        super(NLayerDiscriminator, self).__init__()
        # no need to use bias as BatchNorm2d has affine parameters
        norm_cls = norm_layer.func if type(norm_layer) == functools.partial else norm_layer
        use_bias = norm_cls != nn.BatchNorm2d

        kw, padw = 4, 1

        def stage(in_mult, out_mult, stride):
            # conv -> norm -> LeakyReLU, channel counts given as multiples of ndf
            return [
                nn.Conv2d(ndf * in_mult, ndf * out_mult, kernel_size=kw, stride=stride, padding=padw, bias=use_bias),
                norm_layer(ndf * out_mult),
                nn.LeakyReLU(0.2, True)
            ]

        layers = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]

        # filter multipliers double per stride-2 stage and are capped at 8
        mults = [1] + [min(2 ** n, 8) for n in range(1, n_layers)]
        for prev_mult, cur_mult in zip(mults, mults[1:]):
            layers += stage(prev_mult, cur_mult, 2)

        final_mult = min(2 ** n_layers, 8)
        layers += stage(mults[-1], final_mult, 1)

        # output 1 channel prediction map
        layers.append(nn.Conv2d(ndf * final_mult, 1, kernel_size=kw, stride=1, padding=padw))
        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Standard forward."""
        return self.model(input)
+
+
class NLayerDiscriminatorCls(nn.Module):
    """Defines a PatchGAN discriminator with an additional classification head.

    model0 is the shared stride-2 trunk; model1 produces the 1-channel patch prediction map,
    model2 produces ``n_class`` scores per image.
    """

    def __init__(self, input_nc, ndf=64, n_layers=3, n_class=3, norm_layer=nn.BatchNorm2d):
        """Construct a PatchGAN discriminator

        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the last conv layer
            n_layers (int)  -- the number of conv layers in the discriminator
            n_class (int)   -- the number of output channels of the classification head
            norm_layer      -- normalization layer
        """
        super(NLayerDiscriminatorCls, self).__init__()
        # no need to use bias as BatchNorm2d has affine parameters
        norm_cls = norm_layer.func if isinstance(norm_layer, functools.partial) else norm_layer
        use_bias = norm_cls != nn.BatchNorm2d

        kw = 4
        padw = 1
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)

        # patch head: one stride-1 stage, then the 1-channel prediction map
        sequence1 = [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]
        sequence1 += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output 1 channel prediction map

        # classification head: two more stride-2 stages, then a 16x16 conv without padding
        sequence2 = [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]
        sequence2 += [
            nn.Conv2d(ndf * nf_mult, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]
        # The 16x16 kernel collapses a 16x16 feature map to 1x1; other feature sizes yield
        # more than one spatial position, which forward flattens into the class axis.
        sequence2 += [
            nn.Conv2d(ndf * nf_mult, n_class, kernel_size=16, stride=1, padding=0, bias=use_bias)]

        self.model0 = nn.Sequential(*sequence)
        self.model1 = nn.Sequential(*sequence1)
        self.model2 = nn.Sequential(*sequence2)

    def forward(self, input):
        """Return (patch prediction map, class scores flattened to (batch, -1))."""
        feat = self.model0(input)
        patch = self.model1(feat)
        classl = self.model2(feat)
        return patch, classl.view(classl.size(0), -1)
+
+
class PixelDiscriminator(nn.Module):
    """Defines a 1x1 PatchGAN discriminator (pixelGAN)"""

    def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d):
        """Construct a 1x1 PatchGAN discriminator

        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the last conv layer
            norm_layer      -- normalization layer
        """
        super(PixelDiscriminator, self).__init__()
        # no need to use bias as BatchNorm2d has affine parameters
        # NOTE: the test used to be `!= nn.InstanceNorm2d`, which enabled the bias for
        # BatchNorm2d and disabled it for InstanceNorm2d -- the opposite of the intent above
        # and of the other discriminators in this file. Checkpoints saved with the old
        # layout differ in the conv bias entries.
        norm_cls = norm_layer.func if isinstance(norm_layer, functools.partial) else norm_layer
        use_bias = norm_cls != nn.BatchNorm2d

        layers = [
            nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
            norm_layer(ndf * 2),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)]

        self.net = nn.Sequential(*layers)

    def forward(self, input):
        """Standard forward."""
        return self.net(input)
+
+
class HED(nn.Module):
    """Five-stage VGG-style edge network with one side output per stage.

    Each stage feeds a 1x1 "score" convolution; the five score maps are
    resized to the input resolution, concatenated and fused by a final
    1x1 convolution followed by a sigmoid.
    """

    def __init__(self):
        super(HED, self).__init__()

        def vgg_stage(widths, downsample):
            # One conv3x3 + ReLU per consecutive pair of widths, optionally
            # preceded by a 2x2 max-pool (keeps the original layer indices).
            layers = [nn.MaxPool2d(kernel_size=2, stride=2)] if downsample else []
            for c_in, c_out in zip(widths[:-1], widths[1:]):
                layers.append(nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=3, stride=1, padding=1))
                layers.append(nn.ReLU(inplace=False))
            return nn.Sequential(*layers)

        def score_head(c_in):
            # 1x1 convolution collapsing a stage's features to one channel.
            return nn.Conv2d(in_channels=c_in, out_channels=1, kernel_size=1, stride=1, padding=0)

        self.moduleVggOne = vgg_stage((3, 64, 64), downsample=False)
        self.moduleVggTwo = vgg_stage((64, 128, 128), downsample=True)
        self.moduleVggThr = vgg_stage((128, 256, 256, 256), downsample=True)
        self.moduleVggFou = vgg_stage((256, 512, 512, 512), downsample=True)
        self.moduleVggFiv = vgg_stage((512, 512, 512, 512), downsample=True)

        self.moduleScoreOne = score_head(64)
        self.moduleScoreTwo = score_head(128)
        self.moduleScoreThr = score_head(256)
        self.moduleScoreFou = score_head(512)
        self.moduleScoreFiv = score_head(512)

        self.moduleCombine = nn.Sequential(
            nn.Conv2d(in_channels=5, out_channels=1, kernel_size=1, stride=1, padding=0),
            nn.Sigmoid()
        )

    def forward(self, tensorInput):
        """Return the fused single-channel map (sigmoid output) for ``tensorInput``.

        Channels 2, 1, 0 of the input are scaled by 255 and shifted by fixed
        per-channel offsets before entering the network.
        NOTE(review): this presumably expects RGB input in [0, 1] -- confirm with callers.
        """
        # (source channel index, offset), listed in the order they are concatenated.
        channel_offsets = ((2, 104.00698793), (1, 116.66876762), (0, 122.67891434))
        planes = [(tensorInput[:, idx:idx + 1, :, :] * 255.0) - offset for idx, offset in channel_offsets]
        tensorInput = torch.cat(planes, 1)
        target_size = (tensorInput.size(2), tensorInput.size(3))

        stages = (self.moduleVggOne, self.moduleVggTwo, self.moduleVggThr, self.moduleVggFou, self.moduleVggFiv)
        heads = (self.moduleScoreOne, self.moduleScoreTwo, self.moduleScoreThr, self.moduleScoreFou, self.moduleScoreFiv)

        side_outputs = []
        features = tensorInput
        for stage, head in zip(stages, heads):
            features = stage(features)
            side_outputs.append(
                nn.functional.interpolate(input=head(features), size=target_size, mode='bilinear', align_corners=False)
            )

        return self.moduleCombine(torch.cat(side_outputs, 1))
+
# VGG19 model definition.
# Borrows largely from the torchvision VGG implementation.
class VGG19(nn.Module):
    """VGG19 convolutional stack plus the standard three-layer classifier.

    With ``feature_mode=True`` the forward pass stops early inside the
    convolutional stack and returns that feature map instead of class scores.
    """

    def __init__(self, init_weights=None, feature_mode=False, batch_norm=False, num_classes=1000):
        super(VGG19, self).__init__()
        # Integers are conv output widths, 'M' marks a 2x2 max-pool.
        self.cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M']
        # Only stored here; nothing in this class loads weights from it.
        self.init_weights = init_weights
        self.feature_mode = feature_mode
        self.batch_norm = batch_norm
        # Attribute name (sic) kept as-is so existing readers of it keep working.
        self.num_clases = num_classes
        self.features = self.make_layers(self.cfg, batch_norm)
        head = [
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        ]
        self.classifier = nn.Sequential(*head)

    def make_layers(self, cfg, batch_norm=False):
        """Translate ``cfg`` into an ``nn.Sequential`` of conv / (BN) / ReLU / max-pool layers."""
        stack = []
        prev_channels = 3
        for spec in cfg:
            if spec == 'M':
                stack.append(nn.MaxPool2d(kernel_size=2, stride=2))
                continue
            stack.append(nn.Conv2d(prev_channels, spec, kernel_size=3, padding=1))
            if batch_norm:
                stack.append(nn.BatchNorm2d(spec))
            stack.append(nn.ReLU(inplace=True))
            prev_channels = spec
        return nn.Sequential(*stack)

    def forward(self, x):
        """Return the truncated feature map (feature mode) or the class scores."""
        if self.feature_mode:
            # modules() yields the Sequential itself first, so [1:27] walks its
            # first 26 layers; the original note labels the stop point conv4_4.
            # NOTE(review): the fixed slice matches the batch_norm=False layout only.
            for layer in list(self.features.modules())[1:27]:
                x = layer(x)
            return x

        x = self.features(x)
        x = x.view(x.size(0), -1)
        return self.classifier(x)
+
class Classifier(nn.Module):
    """Strided-convolution encoder followed by a three-layer fully connected head.

    ``use_dropout``, ``h`` and ``w`` are accepted but not used by this class.
    The first linear layer is fixed at ``512 * 7 * 7`` inputs, so the encoder
    output must flatten to exactly that size.
    """

    def __init__(self, input_nc, classes, ngf=64, num_downs=3, norm_layer=nn.BatchNorm2d, use_dropout=False, h=512, w=512, dim=4096):
        super(Classifier, self).__init__()
        self.input_nc = input_nc
        self.ngf = ngf
        # Unwrap functools.partial so the comparison sees the underlying layer class.
        norm_cls = norm_layer.func if type(norm_layer) == functools.partial else norm_layer
        # Convolutions carry a bias only when the norm layer is InstanceNorm2d.
        use_bias = norm_cls == nn.InstanceNorm2d

        def conv_block(c_in, c_out, stride):
            # conv4x4 -> norm -> LeakyReLU
            return [
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=stride, padding=1, bias=use_bias),
                norm_layer(c_out),
                nn.LeakyReLU(0.2, True),
            ]

        blocks = [nn.Conv2d(input_nc, ngf, kernel_size=4, stride=2, padding=1, bias=use_bias), nn.LeakyReLU(0.2, True)]
        prev_mult = 1
        for depth in range(1, num_downs):
            # Width doubles per stage, capped at 8 * ngf.
            cur_mult = min(2 ** depth, 8)
            blocks.extend(conv_block(int(ngf * prev_mult), int(ngf * cur_mult), 2))
            prev_mult = cur_mult
        last_mult = min(2 ** num_downs, 8)
        # Final stage keeps the resolution (stride 1).
        blocks.extend(conv_block(ngf * prev_mult, ngf * last_mult, 1))
        self.encoder = nn.Sequential(*blocks)

        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, dim),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(dim, dim),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(dim, classes),
        )

    def forward(self, x):
        """Encode ``x``, flatten per sample and return the class scores."""
        encoded = self.encoder(x)
        # The original debug note reports an encoder output of (8, 512, 7, 7).
        flat = encoded.view(encoded.size(0), -1)
        return self.classifier(flat)
diff --git a/hi-arm/qmupd_vs/models/networks_basic.py b/hi-arm/qmupd_vs/models/networks_basic.py
new file mode 100644
index 0000000000000000000000000000000000000000..d71d6b383b9763bce2c1c19ae703966d87ba8cdf
--- /dev/null
+++ b/hi-arm/qmupd_vs/models/networks_basic.py
@@ -0,0 +1,187 @@
+
+from __future__ import absolute_import
+
+import sys
+import torch
+import torch.nn as nn
+import torch.nn.init as init
+from torch.autograd import Variable
+import numpy as np
+from pdb import set_trace as st
+from skimage import color
+from IPython import embed
+from . import pretrained_networks as pn
+
+from util import util
+
def spatial_average(in_tens, keepdim=True):
    """Average ``in_tens`` over dimensions 2 and 3, keeping them as size-1 dims when ``keepdim``."""
    return torch.mean(in_tens, dim=[2, 3], keepdim=keepdim)
+
def upsample(in_tens, out_H=64): # assumes scale factor is same for H and W
    """Bilinearly resize ``in_tens`` by the factor that maps its height (dim 2) to ``out_H``."""
    ratio = 1. * out_H / in_tens.shape[2]
    resize = nn.Upsample(scale_factor=ratio, mode='bilinear', align_corners=False)
    return resize(in_tens)
+
+# Learned perceptual metric
class PNetLin(nn.Module):
    """Perceptual distance between two images, computed from backbone features.

    Both inputs are passed through the same backbone (``pn.vgg16``,
    ``pn.alexnet`` or ``pn.squeezenet``). For every returned feature level
    the features are normalized with ``util.normalize_tensor``, their squared
    difference is taken, and the per-level maps are reduced and summed.

    Parameters:
        pnet_type (str)    -- 'vgg' / 'vgg16', 'alex' or 'squeeze'
        pnet_rand (bool)   -- if True the backbone is built with ``pretrained=False``
        pnet_tune (bool)   -- forwarded to the backbone as ``requires_grad``
        use_dropout (bool) -- forwarded to every ``NetLinLayer``
        spatial (bool)     -- if True return per-pixel maps upsampled to the input
                              height, otherwise spatially averaged values
        version (str)      -- '0.1' applies ``ScalingLayer`` to the inputs; any
                              other value feeds the inputs unscaled
        lpips (bool)       -- if True use the 1x1 ``NetLinLayer`` weights per level,
                              otherwise sum the squared differences over channels

    Raises:
        ValueError: if ``pnet_type`` is not one of the supported backbones.
    """

    def __init__(self, pnet_type='vgg', pnet_rand=False, pnet_tune=False, use_dropout=True, spatial=False, version='0.1', lpips=True):
        super(PNetLin, self).__init__()

        self.pnet_type = pnet_type
        self.pnet_tune = pnet_tune
        self.pnet_rand = pnet_rand
        self.spatial = spatial
        self.lpips = lpips
        self.version = version
        self.scaling_layer = ScalingLayer()

        # Backbone constructor and channel count of each feature level it returns.
        if(self.pnet_type in ['vgg','vgg16']):
            net_type = pn.vgg16
            self.chns = [64,128,256,512,512]
        elif(self.pnet_type=='alex'):
            net_type = pn.alexnet
            self.chns = [64,192,384,256,256]
        elif(self.pnet_type=='squeeze'):
            net_type = pn.squeezenet
            self.chns = [64,128,256,384,384,512,512]
        else:
            # Previously this fell through and failed later on a missing attribute.
            raise ValueError('Unsupported pnet_type: %r' % (self.pnet_type,))
        self.L = len(self.chns)

        self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)

        if(lpips):
            # Each layer is also assigned as an attribute (lin0, lin1, ...) so it is
            # registered as a submodule; ``self.lins`` is just an indexable view.
            self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
            self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
            self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
            self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
            self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
            self.lins = [self.lin0,self.lin1,self.lin2,self.lin3,self.lin4]
            if(self.pnet_type=='squeeze'): # 7 layers for squeezenet
                self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)
                self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)
                self.lins+=[self.lin5,self.lin6]

    def forward(self, in0, in1, retPerLayer=False):
        """Return the distance between ``in0`` and ``in1``.

        Returns:
            ``val`` -- the sum of the per-level results; or ``(val, res)`` when
            ``retPerLayer`` is True, where ``res`` is the list of per-level
            results (one entry per feature level, each left unmodified).
        """
        # v0.0 - original release had a bug, where input was not scaled
        in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version=='0.1' else (in0, in1)
        outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
        feats0, feats1, diffs = {}, {}, {}

        for kk in range(self.L):
            feats0[kk], feats1[kk] = util.normalize_tensor(outs0[kk]), util.normalize_tensor(outs1[kk])
            diffs[kk] = (feats0[kk]-feats1[kk])**2

        if(self.lpips):
            if(self.spatial):
                res = [upsample(self.lins[kk].model(diffs[kk]), out_H=in0.shape[2]) for kk in range(self.L)]
            else:
                res = [spatial_average(self.lins[kk].model(diffs[kk]), keepdim=True) for kk in range(self.L)]
        else:
            if(self.spatial):
                res = [upsample(diffs[kk].sum(dim=1,keepdim=True), out_H=in0.shape[2]) for kk in range(self.L)]
            else:
                res = [spatial_average(diffs[kk].sum(dim=1,keepdim=True), keepdim=True) for kk in range(self.L)]

        # Accumulate out of place: an in-place ``+=`` on res[0] would overwrite the
        # first per-level entry with the total, corrupting ``res`` for retPerLayer.
        val = res[0]
        for l in range(1,self.L):
            val = val + res[l]

        if(retPerLayer):
            return (val, res)
        else:
            return val
+
class ScalingLayer(nn.Module):
    """Per-channel affine input normalization: ``(inp - shift) / scale`` with fixed 3-channel constants."""

    def __init__(self):
        super(ScalingLayer, self).__init__()
        shift = torch.Tensor([-.030,-.088,-.188])
        scale = torch.Tensor([.458,.448,.450])
        # Stored as (1, 3, 1, 1) buffers so they broadcast over batch and spatial dims.
        self.register_buffer('shift', shift[None,:,None,None])
        self.register_buffer('scale', scale[None,:,None,None])

    def forward(self, inp):
        """Normalize ``inp`` using the buffers moved to ``inp``'s device."""
        device = inp.device
        centered = inp - self.shift.to(device)
        return centered / self.scale.to(device)
+
+
class NetLinLayer(nn.Module):
    ''' Bias-free 1x1 convolution, optionally preceded by dropout; exposed as ``self.model`` '''
    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        super(NetLinLayer, self).__init__()

        stack = []
        if use_dropout:
            stack.append(nn.Dropout())
        stack.append(nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False))
        self.model = nn.Sequential(*stack)
+
+
class Dist2LogitLayer(nn.Module):
    ''' maps a pair of distances through three 1x1 convs to a single value (in [0,1] when use_sigmoid is True) '''
    def __init__(self, chn_mid=32, use_sigmoid=True):
        super(Dist2LogitLayer, self).__init__()

        # 5 input channels: d0, d1, their difference and the two ratios (see forward).
        stack = [
            nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),
            nn.LeakyReLU(0.2,True),
            nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),
            nn.LeakyReLU(0.2,True),
            nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),
        ]
        if use_sigmoid:
            stack.append(nn.Sigmoid())
        self.model = nn.Sequential(*stack)

    def forward(self,d0,d1,eps=0.1):
        ''' stack (d0, d1, d0-d1, d0/(d1+eps), d1/(d0+eps)) along dim 1 and run the conv stack '''
        features = (d0, d1, d0-d1, d0/(d1+eps), d1/(d0+eps))
        return self.model.forward(torch.cat(features, dim=1))
+
class BCERankingLoss(nn.Module):
    """Binary cross-entropy between a ``Dist2LogitLayer`` prediction and a judgement rescaled to [0, 1]."""

    def __init__(self, chn_mid=32):
        super(BCERankingLoss, self).__init__()
        self.net = Dist2LogitLayer(chn_mid=chn_mid)
        self.loss = torch.nn.BCELoss()

    def forward(self, d0, d1, judge):
        """Return the BCE loss; the network output is also kept on ``self.logit``."""
        # Affine map of the judgement: (judge + 1) / 2.
        target = (judge+1.)/2.
        self.logit = self.net.forward(d0,d1)
        return self.loss(self.logit, target)
+
+# L2, DSSIM metrics
class FakeNet(nn.Module):
    """Parameter-free base for the classical metrics below; it only records its settings."""

    def __init__(self, use_gpu=True, colorspace='Lab'):
        super(FakeNet, self).__init__()
        self.colorspace = colorspace
        self.use_gpu = use_gpu
+
class L2(FakeNet):
    """Mean squared error between two single-image batches, in RGB or Lab space."""

    def forward(self, in0, in1, retPerLayer=None):
        """Return the L2 value for ``in0`` vs ``in1``.

        ``retPerLayer`` is accepted but unused. Nothing is returned (None) when
        ``self.colorspace`` is neither 'RGB' nor 'Lab'.
        """
        assert(in0.size()[0]==1) # currently only supports batchSize 1

        space = self.colorspace
        if space == 'RGB':
            (N,C,X,Y) = in0.size()
            # Reduce one dimension at a time: channels, then dim 2, then dim 3.
            squared_error = (in0-in1)**2
            per_pixel = torch.mean(squared_error, dim=1).view(N,1,X,Y)
            per_column = torch.mean(per_pixel, dim=2).view(N,1,1,Y)
            return torch.mean(per_column, dim=3).view(N)
        if space == 'Lab':
            # Conversion and metric are delegated to project helpers in ``util``.
            lab0 = util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False))
            lab1 = util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False))
            value = util.l2(lab0, lab1, range=100.).astype('float')
            ret_var = Variable( torch.Tensor((value,) ) )
            if(self.use_gpu):
                ret_var = ret_var.cuda()
            return ret_var
+
class DSSIM(FakeNet):
    """DSSIM between two single-image batches, computed by ``util.dssim`` in RGB or Lab space."""

    def forward(self, in0, in1, retPerLayer=None):
        """Return the DSSIM value wrapped in a one-element tensor (moved to GPU when ``use_gpu``).

        ``retPerLayer`` is accepted but unused.
        """
        assert(in0.size()[0]==1) # currently only supports batchSize 1

        if(self.colorspace=='RGB'):
            # Multiplying by 1. promotes the image arrays to float before comparison.
            img0 = 1.*util.tensor2im(in0.data)
            img1 = 1.*util.tensor2im(in1.data)
            value = util.dssim(img0, img1, range=255.).astype('float')
        elif(self.colorspace=='Lab'):
            lab0 = util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False))
            lab1 = util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False))
            value = util.dssim(lab0, lab1, range=100.).astype('float')
        ret_var = Variable( torch.Tensor((value,) ) )
        if(self.use_gpu):
            ret_var = ret_var.cuda()
        return ret_var
+
def print_network(net):
    """Print ``net`` followed by the total number of elements across its parameters."""
    total = sum(param.numel() for param in net.parameters())
    print('Network',net)
    print('Total number of parameters: %d' % total)
diff --git a/hi-arm/qmupd_vs/models/pretrained_networks.py b/hi-arm/qmupd_vs/models/pretrained_networks.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1329d64b798229bb16578f5bcaa1dff7d660a8e
--- /dev/null
+++ b/hi-arm/qmupd_vs/models/pretrained_networks.py
@@ -0,0 +1,181 @@
+from collections import namedtuple
+import torch
+from torchvision import models
+from IPython import embed
+
class squeezenet(torch.nn.Module):
    """torchvision SqueezeNet 1.1 feature stack cut into seven consecutive slices.

    ``forward`` returns the output of every slice as a ``SqueezeOutputs`` namedtuple.
    """

    def __init__(self, requires_grad=False, pretrained=True):
        super(squeezenet, self).__init__()
        pretrained_features = models.squeezenet1_1(pretrained=pretrained).features
        self.N_slices = 7
        # Slice k holds feature layers [bounds[k-1], bounds[k]).
        bounds = (0, 2, 5, 8, 10, 11, 12, 13)
        for number in range(1, self.N_slices + 1):
            setattr(self, 'slice%d' % number, torch.nn.Sequential())
        for number, (start, stop) in enumerate(zip(bounds[:-1], bounds[1:]), start=1):
            piece = getattr(self, 'slice%d' % number)
            for x in range(start, stop):
                # Child names keep the layer's index in the torchvision stack.
                piece.add_module(str(x), pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        """Run the slices in order and return all seven intermediate outputs."""
        stages = (self.slice1, self.slice2, self.slice3, self.slice4,
                  self.slice5, self.slice6, self.slice7)
        taps = []
        h = X
        for stage in stages:
            h = stage(h)
            taps.append(h)
        vgg_outputs = namedtuple("SqueezeOutputs", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7'])
        return vgg_outputs(*taps)
+
+
class alexnet(torch.nn.Module):
    """torchvision AlexNet feature stack cut into five consecutive slices.

    ``forward`` returns the output of every slice as an ``AlexnetOutputs`` namedtuple.
    """

    def __init__(self, requires_grad=False, pretrained=True):
        super(alexnet, self).__init__()
        alexnet_pretrained_features = models.alexnet(pretrained=pretrained).features
        self.N_slices = 5
        # Slice k holds feature layers [bounds[k-1], bounds[k]).
        bounds = (0, 2, 5, 8, 10, 12)
        for number in range(1, self.N_slices + 1):
            setattr(self, 'slice%d' % number, torch.nn.Sequential())
        for number, (start, stop) in enumerate(zip(bounds[:-1], bounds[1:]), start=1):
            piece = getattr(self, 'slice%d' % number)
            for x in range(start, stop):
                # Child names keep the layer's index in the torchvision stack.
                piece.add_module(str(x), alexnet_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        """Run the slices in order and return all five intermediate outputs."""
        stages = (self.slice1, self.slice2, self.slice3, self.slice4, self.slice5)
        taps = []
        h = X
        for stage in stages:
            h = stage(h)
            taps.append(h)
        alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5'])
        return alexnet_outputs(*taps)
+
class vgg16(torch.nn.Module):
    """torchvision VGG16 feature stack cut into five consecutive slices.

    ``forward`` returns the output of every slice as a ``VggOutputs`` namedtuple.
    """

    def __init__(self, requires_grad=False, pretrained=True):
        super(vgg16, self).__init__()
        vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
        self.N_slices = 5
        # Slice k holds feature layers [bounds[k-1], bounds[k]).
        bounds = (0, 4, 9, 16, 23, 30)
        for number in range(1, self.N_slices + 1):
            setattr(self, 'slice%d' % number, torch.nn.Sequential())
        for number, (start, stop) in enumerate(zip(bounds[:-1], bounds[1:]), start=1):
            piece = getattr(self, 'slice%d' % number)
            for x in range(start, stop):
                # Child names keep the layer's index in the torchvision stack.
                piece.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        """Run the slices in order and return all five intermediate outputs."""
        stages = (self.slice1, self.slice2, self.slice3, self.slice4, self.slice5)
        taps = []
        h = X
        for stage in stages:
            h = stage(h)
            taps.append(h)
        vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
        return vgg_outputs(*taps)
+
+
+
class resnet(torch.nn.Module):
    """torchvision ResNet (18/34/50/101/152) exposing five intermediate activations.

    NOTE(review): ``requires_grad`` is accepted but never applied here, unlike
    the sibling backbone wrappers -- confirm whether that is intended.
    """

    def __init__(self, requires_grad=False, pretrained=True, num=18):
        super(resnet, self).__init__()
        builders = {
            18: models.resnet18,
            34: models.resnet34,
            50: models.resnet50,
            101: models.resnet101,
            152: models.resnet152,
        }
        build = builders.get(num)
        if build is not None:
            self.net = build(pretrained=pretrained)
        self.N_slices = 5

        # Re-expose the backbone's stages directly on this module.
        backbone = self.net
        self.conv1 = backbone.conv1
        self.bn1 = backbone.bn1
        self.relu = backbone.relu
        self.maxpool = backbone.maxpool
        self.layer1 = backbone.layer1
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3
        self.layer4 = backbone.layer4

    def forward(self, X):
        """Return the stem activation and the outputs of layer1..layer4 as an ``Outputs`` namedtuple."""
        stem = self.relu(self.bn1(self.conv1(X)))
        taps = [stem]
        h = self.maxpool(stem)
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            h = stage(h)
            taps.append(h)

        outputs = namedtuple("Outputs", ['relu1','conv2','conv3','conv4','conv5'])
        return outputs(*taps)
diff --git a/hi-arm/qmupd_vs/models/test_model.py b/hi-arm/qmupd_vs/models/test_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..b86872218cbf60e61e76989649799adb993de3bf
--- /dev/null
+++ b/hi-arm/qmupd_vs/models/test_model.py
@@ -0,0 +1,96 @@
+from .base_model import BaseModel
+from . import networks
+import torch
+import pdb
+
class TestModel(BaseModel):
    """Test-time-only model that runs a generator in a single direction.

    It sets '--dataset_mode single' as the default, so only one image
    collection is loaded. With '--style_control' enabled the generator also
    receives a style input taken from the data dictionary.

    See the test instruction for more details.
    """
    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Register the test-only options and force the 'single' dataset mode.

        Parameters:
            parser          -- original option parser
            is_train (bool) -- must be False; this model cannot be used for training

        Returns:
            the modified parser.

        The generator checkpoint is selected with the option '--model_suffix'.
        """
        assert not is_train, 'TestModel cannot be used during training time'
        parser.set_defaults(dataset_mode='single')
        # (flag, type, default, help) -- registered in this order.
        option_table = (
            ('--model_suffix', str, '', 'In checkpoints_dir, [epoch]_net_G[model_suffix].pth will be loaded as the generator.'),
            ('--style_control', int, 0, 'use style_control'),
            ('--sfeature_mode', str, 'vgg19_softmax', 'vgg19 softmax as feature'),
            ('--sinput', str, 'sind', 'use which one for style input'),
            ('--sind', int, 0, 'one hot for sfeature'),
            ('--svec', str, '1,0,0', '3-dim vec'),
            ('--simg', str, 'Yann_Legendre-053', 'drawing example for style'),
            ('--netga', str, 'resnet_style_9blocks', 'net arch for netG_A'),
            ('--model0_res', int, 0, 'number of resblocks in model0'),
            ('--model1_res', int, 0, 'number of resblocks in model1 (after insert style, before 2 column merge)'),
        )
        for flag, value_type, default, help_text in option_table:
            parser.add_argument(flag, type=value_type, default=default, help=help_text)

        return parser

    def __init__(self, opt):
        """Build the forward generator (and a reverse-direction generator) for testing.

        Parameters:
            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        assert(not opt.isTrain)
        BaseModel.__init__(self, opt)
        # No losses are tracked at test time.
        self.loss_names = []
        # Images exposed to the test script for saving/display.
        self.visual_names = ['real', 'fake']
        # Networks the base class should load from disk.
        self.model_names = ['G' + opt.model_suffix, 'G_B']
        if self.opt.style_control:
            print(opt.netga)
            print('model0_res', opt.model0_res)
            print('model1_res', opt.model1_res)
            self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netga, opt.norm,
                                          not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids,
                                          opt.model0_res, opt.model1_res)
        else:
            self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG,
                                          opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)

        # Reverse-direction generator: input/output channel counts are swapped.
        self.netGB = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.netG,
                                       opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
        # Alias the networks under the names listed in model_names so they can be loaded.
        setattr(self, 'netG' + opt.model_suffix, self.netG)
        self.netG_B = self.netGB

    def set_input(self, input):
        """Unpack one batch from the dataloader.

        Parameters:
            input: a dictionary that contains the data itself and its metadata information.

        Reads 'A' and 'A_paths', plus 'B_style' when style control is enabled.
        """
        self.real = input['A'].to(self.device)
        self.image_paths = input['A_paths']
        if self.opt.style_control:
            # NOTE(review): unlike 'A', the style input is not moved to self.device here.
            self.style = input['B_style']

    def forward(self):
        """Run the generator on the current input and store the result in ``self.fake``."""
        if self.opt.style_control:
            self.fake = self.netG(self.real, self.style)
        else:
            self.fake = self.netG(self.real)  # G(real)

    def optimize_parameters(self):
        """No optimization for test model."""
        pass
diff --git a/hi-arm/qmupd_vs/operator_main.ipynb b/hi-arm/qmupd_vs/operator_main.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..b1fe6961d9f3e756531a67f649983428b10bff46
--- /dev/null
+++ b/hi-arm/qmupd_vs/operator_main.ipynb
@@ -0,0 +1,606 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "138ed57c06ca45c786e09bcf744f4d54",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "CameraStream(constraints={'facing_mode': 'user', 'audio': False, 'video': {'width': 512, 'height': 512, 'facin…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "d0e4ef53014b4bbab34e6ba90336ad52",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "ImageRecorder(image=Image(value=b''), stream=CameraStream(constraints={'facing_mode': 'user', 'audio': False, …"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "from ipywebrtc import CameraStream, ImageRecorder\n",
+ "from IPython.display import display\n",
+ "import PIL.Image\n",
+ "import io\n",
+ "import numpy\n",
+ "import cv2\n",
+ "from ipywebrtc import CameraStream\n",
+ "camera = CameraStream.facing_user(audio=False, constraints={\n",
+ " 'facing_mode': 'user',\n",
+ " 'audio': False,\n",
+ " 'video': { 'width': 512, 'height': 512 }\n",
+ "})\n",
+ "display(camera)\n",
+ "recorder = ImageRecorder(stream=camera)\n",
+ "display(recorder)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/png": "\n",
+ "text/plain": [
+ "
+