SFTGAN test

As of now, SFTGAN is the best super-resolution algorithm our research group has found internally, with the caveat that every algorithm targets a particular domain and direction. SFTGAN turned out to be strongest for upscaling artistic paintings, whereas ESRGAN stands out on real-life images and photographs, especially when colors are concentrated. What follows are the SFTGAN test cases and the test code, rewritten for easier use.

SFTGAN on GitHub: SFTGAN

The program is organized as follows:
SFTGAN.png
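Roughly, based on the paths hard-coded in the script below and the upstream SFTGAN repository layout (the placement of images/ and images_4096/ under pytorch_test/ is an assumption, since the command uses relative paths):

SFTGAN/
├── pretrained_models/
│   ├── SFTGAN_torch.pth
│   └── segmentation_OST_bic.pth
└── pytorch_test/
    ├── architectures.py
    ├── util.py
    ├── test_subdir.py      # the script below
    ├── images/             # input images
    └── images_4096/        # upscaled output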

First, the command used to run the test:


Running test_subdir.py (the runnable script, listed below) upscales every image under images/ to 4096 and writes the results to images_4096/. Images larger than 1024 but smaller than 2048 are first resized to 1024 and then super-resolved with SFTGAN; images of 2048 or more are simply resized to 4096. In the command below, the 1 means the images in images/ are first copied, converted to three channels, and placed in images_4096/; the 5 is the number of iterations. Since a single SFTGAN pass only upscales by a factor of 4, going from 256 to 4096 takes two passes and smaller images take more, so 5 is set as a safe upper bound (see the sketch after the command). The remaining arguments 2048, 4096 and 1024 are the SFTGAN size threshold, the final output size, and the pre-SFTGAN resize target, matching sys.argv[5], sys.argv[6] and sys.argv[7] in the script.

python test_subdir.py images/ images_4096/ 1 5 2048 4096 1024
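To make the iteration count concrete, here is a minimal sketch (not part of test_subdir.py; the 4x-per-pass factor and the 4096 target are taken from the description above) that counts how many SFTGAN passes a given starting edge length needs, which is why 5 iterations is a comfortable upper bound:

def passes_needed(start_size, target=4096, factor=4):
    # Each SFTGAN pass multiplies the edge length by `factor`.
    passes = 0
    size = start_size
    while size < target:
        size *= factor
        passes += 1
    return passes

for size in (128, 256, 512, 1024):
    print(size, '->', passes_needed(size))  # 128 -> 3, 256 -> 2, 512 -> 2, 1024 -> 1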

The following code lives in the pytorch_test folder and upscales every file under a directory to the target size.

'''
Segmentation codes for generating segmentation probability maps for SFTGAN
'''

import os
import glob
import numpy as np
import cv2
import sys
import torch
import torchvision.utils
import time
import architectures as arch
import util
from PIL import Image
# Channel conversion: force the image to 3-channel RGB and save it to output_image_path
def change_image_channels(input_image_path, output_image_path):
    image = Image.open(input_image_path)
    if image.mode == 'RGBA':
        # Drop the alpha channel
        r, g, b, a = image.split()
        image = Image.merge("RGB", (r, g, b))
    elif image.mode != 'RGB':
        image = image.convert("RGB")
    # Overwrite any existing file at the destination, then save the 3-channel copy
    try:
        os.remove(output_image_path)
    except OSError:
        pass
    image.save(output_image_path)
    return image


# options
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

times = 3
channel_mark = 1
imgSize = 4096
finalSize = 4096
minImgSize = 1024


imagespath = sys.argv[1]         # must end with "/"
outputdir = sys.argv[2]          # must end with "/"
channel_mark = int(sys.argv[3])  # default 1: convert all images to 3 channels
times = int(sys.argv[4])         # default 3
imgSize = int(sys.argv[5])
finalSize = int(sys.argv[6])
minImgSize = int(sys.argv[7])
if not os.path.exists(outputdir):
    os.makedirs(outputdir)
device = torch.device('cuda') # if you want to run on CPU, change 'cuda' -> 'cpu'
# device = torch.device('cpu')



model_path = '/home/t-huch/SFTGAN/pretrained_models/SFTGAN_torch.pth' # torch version



if 'torch' in model_path:  # torch version
    model = arch.SFT_Net_torch()
else:  # pytorch version
    model = arch.SFT_Net()
model.load_state_dict(torch.load(model_path), strict=True)
model.eval()
model = model.to(device)


# load the segmentation model
seg_model = arch.OutdoorSceneSeg()
model_path = '/home/t-huch/SFTGAN/pretrained_models/segmentation_OST_bic.pth'
seg_model.load_state_dict(torch.load(model_path), strict=True)
seg_model.eval()
seg_model = seg_model.to(device)

print('Testing SFTGAN ...')

print(channel_mark)
if channel_mark == 1:
    # Mirror the input tree into the output dir, converting every image to 3 channels
    for root, dirs, files in os.walk(imagespath):
        for file in files:
            start_time = time.time()
            path = os.path.join(root, file)
            imgname = os.path.basename(path)
            subDir = os.path.join(outputdir, root.replace(imagespath, ""))
            if not os.path.exists(subDir):
                os.makedirs(subDir)
            print(path)
            change_image_channels(path, os.path.join(subDir, imgname))


while times > 0:
    times -= 1
    for root, dirs, files in os.walk(outputdir):
        for file in files:
            start_time = time.time()
            path = os.path.join(root, file)
            imgname = os.path.basename(path)
            # read image
            img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
            print(img.shape, path)
            if img.shape[0] < imgSize or img.shape[1] < imgSize:
                # Images above minImgSize are shrunk to minImgSize before the 4x SFTGAN pass
                if img.shape[0] > minImgSize or img.shape[1] > minImgSize:
                    img = cv2.resize(img, (minImgSize, minImgSize), interpolation=cv2.INTER_CUBIC)
                test_img = util.modcrop(img, 8)
                img = util.modcrop(img, 8)
                if img.ndim == 2:
                    img = np.expand_dims(img, axis=2)
                img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()

                # Segmentation input: bicubic 4x upscale, back to [0, 255], minus the per-channel BGR ImageNet means
                img_LR = util.imresize(img / 255, 1, antialiasing=True)
                img = util.imresize(img_LR, 4, antialiasing=True) * 255

                img[0] -= 103.939
                img[1] -= 116.779
                img[2] -= 123.68
                img = img.unsqueeze(0)
                img = img.to(device)

                # Segmentation probability maps that condition the SFT layers
                with torch.no_grad():
                    output = seg_model(img).detach().float().cpu().squeeze()

                # SR input: normalized to [0, 1], BGR -> RGB
                test_img = test_img * 1.0 / 255
                if test_img.ndim == 2:
                    test_img = np.expand_dims(test_img, axis=2)
                test_img = torch.from_numpy(np.transpose(test_img[:, :, [2, 1, 0]], (2, 0, 1))).float()
                img_LR = util.imresize(test_img, 1, antialiasing=True)
                img_LR = img_LR.unsqueeze(0)
                img_LR = img_LR.to(device)

                seg = output
                seg = seg.unsqueeze(0)
                seg = seg.to(device)
                with torch.no_grad():
                    output = model((img_LR, seg)).data.float().cpu().squeeze()
                output = util.tensor2img(output)
                subDir = os.path.join(outputdir, root.replace(outputdir, ""))
                if not os.path.exists(subDir):
                    os.makedirs(subDir)
                util.save_img(output, os.path.join(subDir, imgname))

                print("time consumption : {}".format(time.time() - start_time))
            elif img.shape[0] == finalSize and img.shape[1] == finalSize:
                # Already at the final size: nothing to do
                pass
                # subDir = os.path.join(outputdir, root.replace(outputdir, ""))
                # if not os.path.exists(subDir):
                #     os.makedirs(subDir)
                # cv2.imwrite(os.path.join(subDir, imgname), img)
            else:
                # Both sides at least imgSize but not yet finalSize: plain bicubic resize to finalSize
                img = cv2.resize(img, (finalSize, finalSize), interpolation=cv2.INTER_CUBIC)
                subDir = os.path.join(outputdir, root.replace(outputdir, ""))
                if not os.path.exists(subDir):
                    os.makedirs(subDir)
                cv2.imwrite(os.path.join(subDir, imgname), img)
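
The script relies on helpers from the SFTGAN repository's util module (modcrop, imresize, tensor2img, save_img). As a reference, here is a minimal sketch of what modcrop is expected to do, assuming the usual SFTGAN/BasicSR behavior of cropping height and width down to the nearest multiple of the scale; it is illustrative, not the upstream implementation:

import numpy as np

def modcrop(img_in, scale):
    # Sketch of util.modcrop: crop H and W so both are divisible by `scale`.
    img = np.copy(img_in)
    if img.ndim == 2:
        h, w = img.shape
        img = img[:h - h % scale, :w - w % scale]
    elif img.ndim == 3:
        h, w, _ = img.shape
        img = img[:h - h % scale, :w - w % scale, :]
    else:
        raise ValueError('Wrong img ndim: {}'.format(img.ndim))
    return img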