-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy pathtest.py
More file actions
61 lines (48 loc) · 1.82 KB
/
Copy pathtest.py
File metadata and controls
61 lines (48 loc) · 1.82 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from __future__ import division
from __future__ import print_function
import os, time, scipy.io
import argparse
import tensorflow as tf
from tensorflow.contrib.layers.python.layers import initializers
import numpy as np
import glob
import re
import cv2
from model import *
# Command-line interface.
#   --ckpt : name of the checkpoint sub-directory to load (all / real / synthetic).
#   --cpu  : optional flag; when given (with or without a value) inference
#            is forced onto the CPU.
parser = argparse.ArgumentParser(description='Testing on DND dataset')
parser.add_argument(
    '--ckpt',
    type=str,
    default='all',
    choices=['all', 'real', 'synthetic'],
    help='checkpoint type'
)
parser.add_argument(
    '--cpu',
    nargs='?',
    const=1,
    help='Use CPU'
)
args = parser.parse_args()
# Pick the compute device. Hiding every CUDA device from the process
# (CUDA_VISIBLE_DEVICES=-1) forces inference to run on the CPU.
if args.cpu:
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
    print('Using CPU!')
else:
    print('Using GPU!')
# Locations used by the script, all relative to the current working directory.
input_dir = './dataset/test/'                 # noisy input images (*.bmp)
checkpoint_dir = './checkpoint/' + args.ckpt  # sub-directory chosen via --ckpt
result_dir = './result/'                      # root directory for saved outputs

# glob.glob returns matches in arbitrary (filesystem-dependent) order; sort so
# that the index used in the output file names is reproducible across runs.
test_fns = sorted(glob.glob(input_dir + '*.bmp'))
if not test_fns:
    # Without this notice the script would finish without producing anything.
    print('No .bmp images found in', input_dir)
# Build the inference graph.
# The placeholder accepts a float32 batch whose last axis has 3 channels;
# the three leading dimensions are left dynamic so any size can be fed.
in_image = tf.placeholder(dtype=tf.float32, shape=[None, None, None, 3])
# CBDNet (from model.py) returns two values; only the second is used below.
_, out_image = CBDNet(in_image)
# Create the session and restore trained weights.
sess = tf.Session()
# Initialise every variable first so the graph is runnable even if the
# restore below does not happen.
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
if ckpt and ckpt.model_checkpoint_path:
    saver.restore(sess, ckpt.model_checkpoint_path)
    # Report success only after the restore has actually completed.
    print('loaded', checkpoint_dir)
else:
    # Previously this case was silent and the script produced output from
    # randomly initialised weights with no indication that anything was wrong.
    print('WARNING: no checkpoint found in', checkpoint_dir,
          '- running with randomly initialized weights')
# Make sure the folder that receives the comparison images exists.
output_dir = result_dir + 'test/'
if not os.path.isdir(output_dir):
    os.makedirs(output_dir)
# Run the network on every test image and save a side-by-side comparison
# (input on the left, network output on the right) as result/test/test_<ind>.jpg.
for ind, test_fn in enumerate(test_fns):
    print(test_fn)
    noisy_img = cv2.imread(test_fn)
    if noisy_img is None:
        # cv2.imread reports an unreadable/corrupt file by returning None
        # rather than raising; skip it instead of crashing on the slice below.
        print('Skipping unreadable image:', test_fn)
        continue

    # OpenCV loads 8-bit BGR: reverse the channel axis to get RGB and
    # rescale to float32 in [0, 1].
    noisy_img = (noisy_img[:, :, ::-1] / 255.0).astype('float32')
    temp_noisy_img = np.expand_dims(noisy_img, axis=0)  # add batch axis

    output = sess.run(out_image, feed_dict={in_image: temp_noisy_img})
    output = np.clip(output, 0, 1)

    # Concatenate along the width axis: input | output.
    temp = np.concatenate((temp_noisy_img[0, :, :, :], output[0, :, :, :]), axis=1)

    # scipy.misc.toimage was removed in SciPy 1.2, so write with OpenCV instead.
    # Round to nearest 8-bit value (same rounding toimage applied), then flip
    # RGB back to BGR because cv2.imwrite expects BGR channel order.
    temp_uint8 = np.clip(temp * 255 + 0.5, 0, 255).astype(np.uint8)
    cv2.imwrite(result_dir + 'test/test_%d.jpg' % ind,
                np.ascontiguousarray(temp_uint8[:, :, ::-1]))