-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathconvert_to_fann_train.py
executable file
·160 lines (118 loc) · 5.28 KB
/
convert_to_fann_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
#!/usr/bin/env python
import argparse
import os
import os.path as path
import Image
def parse_args():
    """Parse command-line arguments for the superpixel -> FANN converter.

    Returns the argparse namespace, augmented with two hard-coded fields:
    n_colour_channels (3, RGB inputs per superpixel) and n_classes
    (2, road / not-road outputs per superpixel).
    """
    # Typo fix: description previously read "flan traning".
    parser = argparse.ArgumentParser(description="Convert ascii superpixel outputs to FANN training file")
    parser.add_argument("-s", "--super_pixel_dir",
                        required=True,
                        help="Directory containing ASCII .dat files generated by SLICO")
    parser.add_argument("-d", "--data_dir",
                        required=True,
                        help="Directory containing KITTI image files")
    parser.add_argument("-o", "--out_file",
                        required=True,
                        help="The output training data file for FANN")
    parser.add_argument("-n", "--nexamples",
                        type=int,
                        help="How many training cases to generate. Default: all")
    parser.add_argument('--gt_threshold',
                        type=float,
                        default=0.5,
                        help="Average superpixel value in ground truth to be considered a road. Default: %(default)s")
    args = parser.parse_args()
    # Hardcode for now: 3 colour channels in, 2 classes out, per superpixel.
    args.n_colour_channels = 3
    args.n_classes = 2
    return args
def main():
    """Drive the conversion.

    Loads every SLICO .dat file from --super_pixel_dir, drops files that
    yield no superpixels, then writes a FANN training file with one
    input/output pair per remaining image.
    """
    args = parse_args()
    # Load the .dat files generated by SLICO; sorted for deterministic order.
    dat_files = sorted(path.join(args.super_pixel_dir, f)
                       for f in os.listdir(args.super_pixel_dir))
    # Trim to nexamples if set (was `!= None`).
    if args.nexamples is not None:
        dat_files = dat_files[:args.nexamples]
    dat_super_pixel_dicts = {}
    for dat_file in dat_files:
        super_pixel_dict = load_dat(dat_file)
        if len(super_pixel_dict) > 0:
            dat_super_pixel_dicts[dat_file] = super_pixel_dict
        else:
            # A file with zero superpixels is unusable; report and skip it.
            print("Error with file %s found %d super pixels" % (dat_file, len(super_pixel_dict)))
    # Guard: without at least one parsed file, max()/the average would crash.
    if not dat_super_pixel_dicts:
        raise SystemExit("No usable .dat files found in %s" % args.super_pixel_dir)
    img_sp_counts = [len(sp_dict) for sp_dict in dat_super_pixel_dicts.values()]
    max_sp = max(img_sp_counts)
    avg_sp = sum(img_sp_counts) / float(len(img_sp_counts))
    print("Max super pixels: %s" % max_sp)
    print("Avg super pixels: %s" % avg_sp)
    # Typo fix: message previously read "Writting".
    print("Writing fann training file %s" % args.out_file)
    with open(args.out_file, 'w') as f:
        write_fann_header(f, num_training_pairs=len(dat_super_pixel_dicts),
                          num_inputs=args.n_colour_channels * max_sp,
                          num_outputs=args.n_classes * max_sp)
        for dat_file in dat_files:
            if dat_file not in dat_super_pixel_dicts:
                continue  # Skip files with errors
            # The .dat file shares its basename with the source KITTI image.
            src_img_file = path.splitext(path.basename(dat_file))[0] + '.png'
            write_fann_training_pair(args, f, max_sp, src_img_file, dat_super_pixel_dicts[dat_file])
def write_fann_header(f, num_training_pairs, num_inputs, num_outputs):
    """Write the one-line FANN training-file header: <pairs> <inputs> <outputs>."""
    f.write('{npairs} {ninputs} {noutputs}\n'.format(
        npairs=num_training_pairs, ninputs=num_inputs, noutputs=num_outputs))
def write_fann_training_pair(args, f, num_super_pixels, img_filename, super_pixel_dict):
    """Write one FANN input line and one output line for a single image.

    Inputs: the mean normalized RGB of each superpixel in the training image.
    Outputs: 1.0 / 0.0 per superpixel, road if the ground-truth image's third
    channel averages >= args.gt_threshold over that superpixel.

    NOTE(review): num_super_pixels (max_sp) is accepted but unused — rows are
    NOT padded to the header's num_inputs/num_outputs width, and only one
    output value is written per superpixel despite n_classes=2 in the header.
    Images with fewer superpixels than max_sp therefore produce short rows;
    confirm against the FANN reader before relying on this format.
    """
    # Training image: mean RGB per superpixel.
    train_img_filepath = path.join(args.data_dir, 'training', 'image_2', img_filename)
    train_super_pixel_values = calc_super_pixel_values(args, train_img_filepath, super_pixel_dict)
    # Ground truth image: "<prefix>_<id>.png" -> "<prefix>_road_<id>.png".
    img_base, img_ext = path.splitext(img_filename)
    img_base_split = img_base.split('_')
    gt_filename = '_'.join([img_base_split[0], 'road', img_base_split[1]]) + img_ext
    gt_img_filepath = path.join(args.data_dir, 'training', 'gt_image_2', gt_filename)
    gt_super_pixel_values = calc_super_pixel_values(args, gt_img_filepath, super_pixel_dict)
    # Write out all inputs, sorted by superpixel index so the ordering is
    # deterministic and matches the output line below.
    input_values = []
    for sp_idx in sorted(train_super_pixel_values):
        input_values.extend(str(x) for x in train_super_pixel_values[sp_idx])
    f.write(' '.join(input_values) + '\n')
    # Write out all outputs; third channel of the ground truth marks road.
    output_values = []
    for sp_idx in sorted(gt_super_pixel_values):
        values = gt_super_pixel_values[sp_idx]
        output_values.append("1.0" if values[2] >= args.gt_threshold else "0.0")
    f.write(' '.join(output_values) + '\n')
def calc_super_pixel_values(args, img_filepath, super_pixel_dict):
    """Return {superpixel_idx: [r, g, b]} mean channel values for an image.

    super_pixel_dict maps superpixel index -> list of linear pixel indices
    (row-major, as produced by load_dat). Each channel mean is normalized
    from 8-bit [0, 255] to [0.0, 1.0].

    NOTE(review): assumes the image has at least 3 channels (pixel[0..2]);
    a grayscale ground-truth image would raise here — confirm inputs.
    """
    super_pixel_values = {}
    img = Image.open(img_filepath)
    print("%s %s %s" % (img.format, img.size, img.mode))
    w, h = img.size
    pixels = img.load()
    for sp_idx, pixel_idxs in super_pixel_dict.items():
        totals = [0.0, 0.0, 0.0]
        for pixel_idx in pixel_idxs:
            # Linear index -> (x, y). `//` makes the floor division explicit
            # (the old `/` only worked because both operands are Py2 ints).
            x = pixel_idx % w
            y = pixel_idx // w
            pixel = pixels[x, y]
            totals[0] += pixel[0]
            totals[1] += pixel[1]
            totals[2] += pixel[2]
        # Average over all pixels, then normalize onto [0.0, 1.0].
        n = float(len(pixel_idxs))
        super_pixel_values[sp_idx] = [t / n / 255.0 for t in totals]
    return super_pixel_values
def load_dat(dat_file):
    """Parse a SLICO ASCII .dat file into {superpixel_idx: [pixel_idx, ...]}.

    The file holds one integer superpixel label per line; the line number
    (0-based) is the linear pixel index. Blank lines (e.g. a trailing
    newline) are skipped instead of crashing int().
    """
    super_pixel_dict = {}
    with open(dat_file) as f:
        for pixel_idx, line in enumerate(f):
            line = line.strip()
            if not line:
                continue  # tolerate blank/trailing lines
            super_pixel_idx = int(line)
            super_pixel_dict.setdefault(super_pixel_idx, []).append(pixel_idx)
    return super_pixel_dict
# Entry point: run the conversion only when executed as a script,
# so importing this module has no side effects.
if __name__ == "__main__":
    main()