-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathtomography_hand.py
executable file
·408 lines (313 loc) · 15.5 KB
/
tomography_hand.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
"""2D tomograpgy example using entropy regularized optimal transport.
This example uses a hand-phantom and needs the images in order to run.
These images do not belong to the authors and are therefore not included.
"""
import numpy as np
import scipy
import odl
import time
import matplotlib.pyplot as plt
import sys
from transport_cost import EntropyRegularizedOptimalTransport, KMatrixFFT2
from utils import Logger, CallbackShowAndSave, CallbackPrintDiff
# Seed randomness for reproducibility
np.random.seed(seed=1)
# =========================================================================== #
# Create a log that is written to disk
# =========================================================================== #
# Zero-padded YYYYMMDD_HHMM timestamp taken from a single clock reading.
# Concatenating the unpadded tm_* fields was ambiguous (1 Nov and 11 Jan both
# gave "...111"), and reading the clock five times could straddle a boundary.
time_str = time.strftime('%Y%m%d_%H%M')
output_filename = 'Output_' + time_str + '.txt'
# From here on, everything printed goes through the project Logger; the last
# lines of the script restore the terminal and close the log file.
sys.stdout = Logger(output_filename)
# =========================================================================== #
# Set up the tomography problem and create phantom
# =========================================================================== #
# Select data type to use
dtype = 'float64'
# Create a discrete space matching the image size
n_image = 256
image_space = odl.uniform_discr(min_pt=[-20, -20], max_pt=[20, 20],
                                shape=[n_image, n_image], dtype=dtype)
# Create a discrete reconstruction space
n = 256
reco_space = odl.uniform_discr(min_pt=[-20, -20], max_pt=[20, 20],
                               shape=[n, n], dtype=dtype)
# Make a parallel beam geometry with flat detector with uniform angle and
# detector partition (15 angles in [0, pi], 350 detector cells on [-30, 30])
angle_partition = odl.uniform_partition(0, np.pi, 15)  # 30
detector_partition = odl.uniform_partition(-30, 30, 350)
geometry = odl.tomo.Parallel2dGeometry(angle_partition, detector_partition)
# Ray transform (= forward projection). We use ASTRA CUDA backend. The
# reconstructions use the operator scaled by sqrt(2); the unscaled operator
# is kept for FBP.
ray_trafo_unscaled = odl.tomo.RayTransform(reco_space, geometry,
                                           impl='astra_cuda')
ray_trafo_scaling = np.sqrt(2)
ray_trafo = ray_trafo_scaling * ray_trafo_unscaled
# Locations of the hand images (not distributed with this example)
phantom_path = '/home/aringh/Downloads/handnew2.png'
prior_path = '/home/aringh/Downloads/handnew1.png'


def _read_image(path):
    """Read an image file into a numpy array with values on a 0-255 scale.

    ``scipy.misc.imread`` was removed in SciPy 1.2 (and needs Pillow even
    before that), so fall back to ``matplotlib.pyplot.imread`` when it cannot
    be imported. matplotlib returns PNG data as floats in [0, 1]; rescale to
    0-255, the range assumed by the ``clim=[0, 255.0]`` plots in this script.
    """
    try:
        from scipy.misc import imread
    except ImportError:
        from matplotlib.pyplot import imread as mpl_imread
        return mpl_imread(path) * 255.0
    return imread(path)


# Read the images as phantom and prior, down-sample them if needed
phantom_full = image_space.element(np.rot90(_read_image(phantom_path), k=-1))
prior_full = image_space.element(np.rot90(_read_image(prior_path), k=-1))
if n_image != n:
    resample_op = odl.Resampling(image_space, reco_space)
    phantom = resample_op(phantom_full)
    prior = resample_op(prior_full)
else:
    phantom = phantom_full
    prior = prior_full
# Shift by a small constant so both are strictly positive (the image
# intensities read above are nonnegative)
prior = prior + 1e-4
phantom = phantom + 1e-4
# Show the phantom and the prior
phantom.show(title='Phantom', saveto='Phantom')
prior.show(title='Prior', saveto='Prior')
no_title = '_no_title'
phantom.show(saveto='Phantom' + no_title)
prior.show(saveto='Prior' + no_title)
# Create projection data by calling the ray transform on the phantom
proj_data = ray_trafo(phantom)
# =========================================================================== #
# Display to show how mass moves
# =========================================================================== #
# Centres and radii of the disks used to track where mass is transported
ball_pos = [[10, 2], [-5, 5], [2, -7]]
ball_rad = [np.sqrt(2)] * 3


def balls(x):
    """Return the sum of the indicator functions of the tracking disks.

    ``x[0]`` and ``x[1]`` are the coordinate arrays supplied when
    ``reco_space.element`` evaluates this function. Disk ``k`` is centred at
    ``ball_pos[k]`` with radius ``ball_rad[k]``.
    """
    mask = reco_space.zero()
    for (cx, cy), radius in zip(ball_pos, ball_rad):
        inside = (x[0] - cx)**2 + (x[1] - cy)**2 <= radius**2
        mask += reco_space.element(inside.astype(int))
    return mask
# Disk mask on the reconstruction grid (deformed further down to show where
# the mass moves)
transport_mask = reco_space.element(balls)
# Draw the disk outlines on top of the prior and save the figure
fig_prior = prior.show()
ax_prior = fig_prior.gca()
for (cx, cy), radius in zip(ball_pos, ball_rad):
    outline = plt.Circle((cx, cy), radius, color='w', fill=False,
                         linewidth=2.5)
    ax_prior.add_artist(outline)
fig_prior.savefig('masked_prior' + no_title)
# =========================================================================== #
# Parameters for the reconstructions
# =========================================================================== #
# Douglas-Rachford primal-dual solver. Every problem uses a single primal
# step tau; the dual steps are multiplied by the reciprocal scaling.
douglas_rachford_iter = 10000
# Step scalings for the TV and L2 + TV solvers
scaling_tau = 0.5
scaling_sigma = 1.0 / scaling_tau
tau = scaling_tau
# Step scalings for the optimal-transport solver
scaling_tau_omt = 500.0
scaling_sigma_omt = 1.0 / scaling_tau_omt
tau_omt = scaling_tau_omt
# Relative noise level; the data-fidelity radius is 1.2 times the norm of
# the noise that is added below
noise_level = 0.03
data_eps = proj_data.norm() * noise_level * 1.2
# Weights of the regularizers
reg_param_TV = 1.0  # 1) TV reconstruction
reg_param_TV_l2_and_tv = 1.0  # 2) L2 + TV regularization with prior
# 3) Optimal transport: Sinkhorn iterations and epsilon passed to
#    EntropyRegularizedOptimalTransport, plus the TV weight
sinkhorn_iter = 200
epsilon = 1.5
reg_param_op_TV = 1.0
# Add noise to data: white noise rescaled so that its norm equals
# noise_level * proj_data.norm().
# NOTE(review): presumably drawn from NumPy's global RNG seeded at the top of
# the script, so the statement order here matters for reproducibility.
noise = odl.phantom.white_noise(proj_data.space)
noise = noise / noise.norm() * proj_data.norm() * noise_level
noisy_data = proj_data + noise
# Constructing data-matching functional: indicator of the L2 unit ball,
# translated by noisy_data / data_eps and divided by data_eps.
# NOTE(review): presumably this encodes the constraint
# ||y - noisy_data|| <= data_eps (depends on ODL's `functional / scalar`
# semantics) - confirm.
data_func = odl.solvers.IndicatorLpUnitBall(proj_data.space, 2).translated(noisy_data / data_eps) / data_eps
# Components common for several methods
gradient = odl.Gradient(reco_space)
# Operator norms estimated with the power method; they set the dual step
# sizes (1 / norm**2) of the solvers below
gradient_norm = odl.power_method_opnorm(gradient, maxiter=1000)
ray_trafo_norm = odl.power_method_opnorm(ray_trafo, maxiter=1000)
# Functionals that are only handed to the callbacks for monitoring: an L2
# data residual of ray_trafo(x) (scaled by 1 / data_eps) and the group-L1
# norm of the gradient (TV) of x
show_func_L2_data = odl.solvers.L2Norm(proj_data.space).translated( noisy_data / data_eps) / data_eps * ray_trafo
show_func_TV = odl.solvers.GroupL1Norm(gradient.range) * gradient
# =========================================================================== #
# Print parameters used
# =========================================================================== #
# stdout is the Logger at this point, so this also lands in the log file.
for described_object in (reco_space, geometry):
    print(described_object)
print('noise_level:', str(noise_level))
print('data_func =', str(data_func))
print('douglas_rachford_iter:' + str(douglas_rachford_iter))
for label, value in (('scaling_tau: ', scaling_tau),
                     ('reg_param_TV_l2_and_tv: ', reg_param_TV_l2_and_tv),
                     ('scaling_tau_omt: ', scaling_tau_omt),
                     ('reg_param_op_TV: ', reg_param_op_TV),
                     ('sinkhorn_iter: ', sinkhorn_iter),
                     ('epsilon: ', epsilon)):
    print(label, str(value))
# =========================================================================== #
# Setting up in order to print appropriate things in each iteration
# =========================================================================== #
# Callback shared by the TV and L2 + TV solves: iteration count, timing,
# show/save of the monitor functionals (display_step=50) and
# CallbackPrintDiff (display_step=2). The parts are built in the original
# left-to-right order and then chained with `&`.
print_iteration = odl.solvers.CallbackPrintIteration()
print_timing = odl.solvers.CallbackPrintTiming()
show_and_save = CallbackShowAndSave(
    show_funcs=[show_func_L2_data, show_func_TV], display_step=50)
print_diff = CallbackPrintDiff(data_func=show_func_L2_data, display_step=2)
callback = print_iteration & print_timing & show_and_save & print_diff
# =========================================================================== #
# Filtered Backprojection
# =========================================================================== #
# FBP with a Hann filter. It is built from the unscaled ray transform, so the
# data are first divided by the sqrt(2) factor that was applied to ray_trafo.
fbp_op = odl.tomo.fbp_op(ray_trafo_unscaled, filter_type='Hann',
                         frequency_scaling=0.5)
fbp_input = noisy_data / ray_trafo_scaling
fbp_reconstruction = fbp_op(fbp_input)
fbp_save_name = 'Filtered Backprojection'
fbp_reconstruction.show('Filtered backprojection', saveto=fbp_save_name)
fbp_reconstruction.show(clim=[0, 255.0], saveto=fbp_save_name + no_title)
# =========================================================================== #
# TV
# =========================================================================== #
print('======================================================================')
print('TV')
print('======================================================================')
# douglas_rachford_pd minimizes f(x) + sum_i g[i](L[i] x). Here f restricts
# x to nonnegative images (an upper bound of 255 is left disabled), and g / L
# pair the data functional with ray_trafo and TV with the gradient.
TV_func = reg_param_TV * odl.solvers.GroupL1Norm(gradient.range)
data_func_tv = data_func
f = odl.solvers.IndicatorBox(reco_space, lower=0)
g = [data_func_tv, TV_func]
L = [ray_trafo, gradient]
# Dual steps: reciprocal squared operator norms, rescaled
sigma_unscaled = [1 / op_norm**2 for op_norm in (ray_trafo_norm, gradient_norm)]
sigma = [s * scaling_sigma for s in sigma_unscaled]
# Solve the problem, starting from the constant-one image (updated in place)
x_tv = reco_space.one()
callback.reset()
odl.solvers.douglas_rachford_pd(x=x_tv, f=f, g=g, L=L, tau=tau, sigma=sigma,
                                niter=douglas_rachford_iter, callback=callback)
tv_save_name = 'TV reconstruction'
x_tv.show('TV reconstruction', saveto=tv_save_name)
x_tv.show(clim=[0, 255.0], saveto=tv_save_name + no_title)
# =========================================================================== #
# L2 + TV regularization with prior
# =========================================================================== #
print('======================================================================')
print('L2 + TV regularization with prior')
print('======================================================================')
# Sweep the weight of the squared L2 distance to the prior
for l2_reg_param_loop in [100.0, 10.0, 1.0, 0.1]:
    print('=================================')
    print('l2_reg_param_loop: ' + str(l2_reg_param_loop))
    print('=================================')
    # Data functional, weighted ||x - prior||^2 and TV, over nonnegative x
    data_func_l2_l2_and_tv = data_func
    prior_distance = odl.solvers.L2NormSquared(reco_space).translated(prior)
    l2_reg_func_l2_and_tv = l2_reg_param_loop * prior_distance
    TV_func_l2_and_tv = (reg_param_TV_l2_and_tv *
                         odl.solvers.GroupL1Norm(gradient.range))
    f = odl.solvers.IndicatorBox(reco_space, lower=0)
    g = [data_func_l2_l2_and_tv, l2_reg_func_l2_and_tv, TV_func_l2_and_tv]
    L = [ray_trafo, odl.IdentityOperator(reco_space), gradient]
    # The identity operator has norm 1, hence its unscaled dual step of 1.0
    sigma_unscaled = [1 / ray_trafo_norm**2, 1.0, 1 / gradient_norm**2]
    sigma = [s * scaling_sigma for s in sigma_unscaled]
    # Solve the problem, starting from the constant-one image
    x_l2 = reco_space.one()
    callback.reset()
    odl.solvers.douglas_rachford_pd(x=x_l2, f=f, g=g, L=L, tau=tau,
                                    sigma=sigma, niter=douglas_rachford_iter,
                                    callback=callback)
    # '.' in the parameter becomes '_' so the file name has no stray dot
    param_tag = str(l2_reg_param_loop).replace('.', '_')
    l2_save_name = 'L2 plus TV regularization, reg param ' + param_tag
    x_l2.show('L2 + TV regularization with prior reg param ' +
              str(l2_reg_param_loop),
              saveto=l2_save_name)
    x_l2.show(clim=[0, 255.0], saveto=l2_save_name + no_title)
# =========================================================================== #
# Optimal transport
# =========================================================================== #
import traceback

print('======================================================================')
print('Optimal transport:')
print('======================================================================')
# Define the transportation cost. Grid coordinates have spacing 40 / n (40 is
# the side length of reco_space), which keeps the cost independent of n.
grid_coord = np.arange(0, n, 1, dtype=dtype) * (1 / n) * 40.0
# x_grid[i, j] = grid_coord[i], y_grid[i, j] = grid_coord[j]
x_grid, y_grid = np.meshgrid(grid_coord, grid_coord, indexing='ij')
# This is the matrix defining the distance: squared Euclidean length of the
# grid coordinates, capped at 20**2.
# NOTE(review): presumably KMatrixFFT2 reads this as a cost indexed by grid
# offset - confirm in transport_cost.
matrix_param = np.minimum(20.0**2, np.abs(x_grid + 1j*y_grid)**2)
for reg_para_loop in [4.0]:
    print('=================================')
    print('reg_para_loop: ', str(reg_para_loop))
    print('=================================')
    try:
        # Creating the optimal transport functional and proximal
        opt_trans_func = EntropyRegularizedOptimalTransport(
            space=reco_space, matrix_param=matrix_param,
            K_class=KMatrixFFT2, epsilon=epsilon, mu0=prior,
            niter=sinkhorn_iter)
        omt_file_prefix = ('omt_dr_reg_' + str(reg_para_loop) +
                           '_iter').replace('.', '_')
        callback_omt = (odl.solvers.CallbackPrintIteration() &
                        odl.solvers.CallbackPrintTiming() &
                        CallbackShowAndSave(
                            file_prefix=omt_file_prefix,
                            display_step=25,
                            show_funcs=[show_func_L2_data, show_func_TV]) &
                        CallbackPrintDiff(data_func=show_func_L2_data,
                                          display_step=2))
        # Assemble data and TV functionals
        data_func_op = data_func
        TV_op_func = reg_param_op_TV * odl.solvers.GroupL1Norm(gradient.range)
        f = reg_para_loop * opt_trans_func
        g = [data_func_op, TV_op_func]
        L = [ray_trafo, gradient]
        sigma_unscaled_omt = [1/ray_trafo_norm**2, 1/gradient_norm**2]
        sigma_omt = [s * scaling_sigma_omt for s in sigma_unscaled_omt]
        # Solve the problem. Start from the TV reconstruction (plus 0.01) to
        # save time.
        x_op = x_tv.copy() + 0.01
        # Reset the callback that is actually passed to this solver (the
        # shared TV/L2 `callback` is not used here).
        callback_omt.reset()
        t = time.time()  # Measure the time
        odl.solvers.douglas_rachford_pd(x=x_op, f=f, g=g, L=L, tau=tau_omt,
                                        lam=1.8, sigma=sigma_omt,
                                        niter=douglas_rachford_iter,
                                        callback=callback_omt)
        t = time.time() - t
        print('Time to solve the problem: {}'.format(int(t)))
        # Show and save reconstruction
        omt_title = ('Optimal transport + TV reconstruction, reg param ' +
                     str(reg_para_loop))
        omt_save_name = omt_title.replace('.', '_')
        x_op.show(title=omt_title, saveto=omt_save_name)
        x_op.show(clim=[0, 255.0], saveto=omt_save_name + no_title)
    except Exception:
        # Best effort: keep going, but record why it failed. stdout is the
        # Logger, so the traceback ends up in the log file. A bare `except:`
        # used to hide the error entirely and also swallowed Ctrl-C.
        traceback.print_exc(file=sys.stdout)
        reco_space.one().show(saveto=('Crashed, reg param ' +
                                      str(reg_para_loop)).replace('.', '_') +
                              no_title)
# =========================================================================== #
# This dumps the omt reconstruction and other things to disk
# =========================================================================== #
import pickle

# Saved as one list, in this order: prior, phantom, proj_data, noise,
# noisy_data, transport_mask, x_op
objects_to_save = [prior, phantom, proj_data, noise, noisy_data,
                   transport_mask, x_op]
with open('omt_recon.pickle', 'wb') as pickle_file:
    pickle.dump(objects_to_save, pickle_file)
# =========================================================================== #
# Post manipulation of the transport to generate figure of moved mass
# =========================================================================== #
# Make the reconstruction strictly positive (minimum at least 1e-4), which
# also makes it slightly better conditioned, before deforming the mask.
x_op_min = np.min(x_op)
if x_op_min < 0:
    x_op = x_op + (1e-4 - x_op_min)
elif x_op_min < 1e-4:
    x_op = x_op + 1e-4
deformed_mask = opt_trans_func.deform_image(x_op, transport_mask)
# One figure with a title and one without. A matplotlib artist belongs to a
# single axes, so every disk outline is created separately for each figure.
fig_defo_mask_text = deformed_mask.show(
    'Mass movement from prior to reconstruction')
fig_defo_mask = deformed_mask.show()
ax_defo_mask_text = fig_defo_mask_text.gca()
ax_defo_mask = fig_defo_mask.gca()
for (cx, cy), radius in zip(ball_pos, ball_rad):
    for ax in (ax_defo_mask_text, ax_defo_mask):
        ax.add_artist(plt.Circle((cx, cy), radius, color='w', fill=False,
                                 linewidth=2.5))
# File name for the mass-movement figures (separator added so the parameter
# is not glued to the description)
save_string = ('Optimal transport + TV reconstruction, ' +
               'reg param ' + str(reg_para_loop).replace('.', '_') +
               ', mass movement_postManipulation')
fig_defo_mask_text.savefig(save_string)
fig_defo_mask.savefig(save_string + no_title)
# Close the logger and only write in terminal again. Restore stdout first so
# nothing can be written to the closed log file.
logger = sys.stdout
sys.stdout = logger.terminal
logger.log.close()