-
Notifications
You must be signed in to change notification settings - Fork 40
/
Copy pathinference.py
89 lines (70 loc) · 3.88 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
"""
author: Min Seok Lee and Wooseok Shin
"""
import os
import cv2
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import transforms
from tqdm import tqdm
from dataloader import get_test_augmentation, get_loader
from model.TRACER import TRACER
from util.utils import load_pretrained
class Inference():
    """Salient-object inference driver for a pre-trained TRACER model.

    Loads the ``TE-{arch}`` checkpoint, builds the test data loader over
    ``{data_path}/{dataset}``, and — when ``args.save_map`` is set — writes
    predicted saliency masks to ``mask/{dataset}/`` and alpha-matted object
    cut-outs to ``object/{dataset}/``.
    """

    def __init__(self, args, save_path):
        """
        Args:
            args: parsed CLI namespace. Fields read here: ``img_size``,
                ``multi_gpu``, ``arch``, ``data_path``, ``dataset``,
                ``batch_size``, ``num_workers``, ``save_map``.
            save_path: stored on the instance; not otherwise used in this
                class (outputs go to the hard-coded 'mask'/'object' dirs).
        """
        super(Inference, self).__init__()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.test_transform = get_test_augmentation(img_size=args.img_size)
        self.args = args
        self.save_path = save_path

        # Network
        self.model = TRACER(args).to(self.device)
        if args.multi_gpu:
            self.model = nn.DataParallel(self.model).to(self.device)

        # load_pretrained returns a state dict for the TE-{arch} checkpoint
        path = load_pretrained(f'TE-{args.arch}')
        self.model.load_state_dict(path)
        print('###### pre-trained Model restored #####')

        te_img_folder = os.path.join(args.data_path, args.dataset)
        te_gt_folder = None  # test phase: no ground-truth folder
        self.test_loader = get_loader(te_img_folder, te_gt_folder, edge_folder=None, phase='test',
                                      batch_size=args.batch_size, shuffle=False,
                                      num_workers=args.num_workers, transform=self.test_transform)

        if args.save_map is not None:
            os.makedirs(os.path.join('mask', self.args.dataset), exist_ok=True)
            os.makedirs(os.path.join('object', self.args.dataset), exist_ok=True)

    def test(self):
        """Run inference over the whole test loader.

        For every image, resizes the predicted saliency map back to the
        original resolution; when ``args.save_map`` is set, writes the mask
        and the post-processed object cut-out as PNGs.
        """
        self.model.eval()
        start = time.time()
        with torch.no_grad():
            # NOTE: the original enumerate() counter was unused (and was
            # shadowed by the inner loop index) — removed.
            for images, original_size, image_name in tqdm(self.test_loader):
                # FIX: the original called torch.tensor() on an existing
                # tensor, which copies/detaches and emits a UserWarning;
                # Tensor.to() is the documented way to move/cast a tensor.
                images = images.to(self.device, dtype=torch.float32)
                outputs, edge_mask, ds_map = self.model(images)
                H, W = original_size

                # FIX: distinct index name — no longer shadows the batch
                # loop variable.
                for idx in range(images.size(0)):
                    h, w = H[idx].item(), W[idx].item()
                    output = F.interpolate(outputs[idx].unsqueeze(0), size=(h, w), mode='bilinear')

                    # Save prediction map
                    if self.args.save_map is not None:
                        output = (output.squeeze().detach().cpu().numpy() * 255.0).astype(np.uint8)
                        salient_object = self.post_processing(images[idx], output, h, w)
                        cv2.imwrite(os.path.join('mask', self.args.dataset, image_name[idx] + '.png'), output)
                        cv2.imwrite(os.path.join('object', self.args.dataset, image_name[idx] + '.png'), salient_object)

        print(f'time: {time.time() - start:.3f}s')

    def post_processing(self, original_image, output_image, height, width, threshold=200):
        """Cut the salient object out of the (normalized) input image.

        Args:
            original_image: one normalized image tensor (C, H, W) as produced
                by the test transform (ImageNet mean/std normalization).
            output_image: uint8 saliency mask (height, width), 0-255.
            height, width: original image resolution to restore.
            threshold: mask values at or below this become fully transparent.

        Returns:
            A 4-channel uint8 image where non-salient pixels have alpha 0.
        """
        # Invert the ImageNet normalization applied by the test transform
        # (divide by 1/std, then subtract -mean), restoring pixel values.
        invTrans = transforms.Compose([transforms.Normalize(mean=[0., 0., 0.],
                                                            std=[1 / 0.229, 1 / 0.224, 1 / 0.225]),
                                       transforms.Normalize(mean=[-0.485, -0.456, -0.406],
                                                            std=[1., 1., 1.]),
                                       ])

        original_image = invTrans(original_image)
        original_image = F.interpolate(original_image.unsqueeze(0), size=(height, width), mode='bilinear')
        original_image = (original_image.squeeze().permute(1, 2, 0).detach().cpu().numpy() * 255.0).astype(np.uint8)

        rgba_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2BGRA)
        output_rbga_image = cv2.cvtColor(output_image, cv2.COLOR_BGR2BGRA)
        output_rbga_image[:, :, 3] = output_image  # mask drives the alpha channel

        # NOTE(review): np.where scans all four channels, so a pixel whose
        # mask value is <= threshold in ANY channel is made transparent —
        # preserved as-is since the channels are copies of the mask anyway.
        edge_y, edge_x, _ = np.where(output_rbga_image <= threshold)  # low-saliency coordinates
        rgba_image[edge_y, edge_x, 3] = 0  # make background fully transparent

        return cv2.cvtColor(rgba_image, cv2.COLOR_RGBA2BGRA)