Skip to content

Commit

Permalink
EWS + GD with BB
Browse files Browse the repository at this point in the history
  • Loading branch information
alexdenker committed Sep 27, 2024
1 parent c671234 commit ea06712
Show file tree
Hide file tree
Showing 21 changed files with 70 additions and 3,113 deletions.
163 changes: 0 additions & 163 deletions bsrem.py

This file was deleted.

81 changes: 64 additions & 17 deletions bsrem_bb.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,47 @@
import numpy as np
import sirf.STIR as STIR
from sirf.Utilities import examples_data_path

import torch

from cil.optimisation.algorithms import Algorithm
from utils.herman_meyer import herman_meyer_order
import time

class RDPDiagHessTorch:
def __init__(self, rdp_diag_hess, prior):
self.epsilon = prior.get_epsilon()
self.gamma = prior.get_gamma()
self.penalty_strength = prior.get_penalisation_factor()

self.weights = torch.zeros([3,3,3]).cuda()
self.kappa = torch.tensor(prior.get_kappa().as_array()).cuda()
self.kappa_padded = torch.nn.functional.pad(self.kappa[None], pad=(1, 1, 1, 1, 1, 1), mode='replicate')[0]
voxel_sizes = rdp_diag_hess.voxel_sizes()
z_dim, y_dim, x_dim = rdp_diag_hess.shape
for i in range(3):
for j in range(3):
for k in range(3):
self.weights[i,j,k] = voxel_sizes[2]/np.sqrt(((i-1)*voxel_sizes[0])**2 + ((j-1)*voxel_sizes[1])**2 + ((k-1)*voxel_sizes[2])**2)
self.weights[1,1,1] = 0
self.z_dim = z_dim
self.y_dim = y_dim
self.x_dim = x_dim


def compute(self, x, precond):
x = torch.tensor(x.as_array(), dtype=torch.float32).cuda()
x_padded = torch.nn.functional.pad(x[None], pad=(1, 1, 1, 1, 1, 1), mode='replicate')[0]
x_rdp_diag_hess = torch.zeros_like(x)
for dz in range(3):
for dy in range(3):
for dx in range(3):
x_neighbour = x_padded[dz:dz+self.z_dim, dy:dy+self.y_dim, dx:dx+self.x_dim]
kappa_neighbour = self.kappa_padded[dz:dz+self.z_dim, dy:dy+self.y_dim, dx:dx+self.x_dim]
kappa_val = self.kappa * kappa_neighbour
numerator = 4 * (2 * x_neighbour + self.epsilon) ** 2
denominator = (x + x_neighbour + self.gamma * torch.abs(x - x_neighbour) + self.epsilon) ** 3
x_rdp_diag_hess += self.weights[dz, dy, dx] * self.penalty_strength * kappa_val * numerator / denominator

precond.fill(x_rdp_diag_hess.cpu().numpy())


class BSREMSkeleton(Algorithm):
Expand Down Expand Up @@ -53,8 +89,6 @@ def __init__(self, data, initial,
# add a small number to avoid division by zero in the preconditioner
self.average_sensitivity += self.average_sensitivity.max()/1e4



self.precond = initial.get_uniform_copy(0)

self.subset = 0
Expand All @@ -68,6 +102,8 @@ def __init__(self, data, initial,

self.x_update = initial.get_uniform_copy(0)

self.rdp_hessian_freq = 4

def subset_sensitivity(self, subset_num):
raise NotImplementedError

Expand All @@ -83,26 +119,37 @@ def step_size(self):
def update(self):

g = self.subset_gradient(self.x, self.subset_order[self.subset])

g.multiply(self.x + self.eps, out=self.x_update)

self.x_update.divide(self.average_sensitivity, out=self.x_update)

if self.iteration == 0:
step_size = min(max(1/(self.x_update.norm() + 1e-3), 0.005), 3.0)
prior_grad = self.dataset.prior.gradient(self.x)
if prior_grad.norm()/g.norm() > 0.5:
self.rdp_diag_hess_obj = RDPDiagHessTorch(self.dataset.OSEM_image.copy(), self.dataset.prior)
self.lkhd_precond = self.dataset.kappa.power(2)
self.compute_rdp_diag_hess = True
self.eps = self.lkhd_precond.max()/1e4
else:
self.compute_rdp_diag_hess = False

if self.compute_rdp_diag_hess:
if self.iteration % self.rdp_hessian_freq == 0:
self.rdp_diag_hess_obj.compute(self.x, self.precond)

g.divide(self.lkhd_precond + self.precond + self.eps, out=self.x_update)
else:
g.multiply(self.x + self.eps, out=self.x_update)
self.x_update.divide(self.average_sensitivity, out=self.x_update)


if self.iteration == 0:
step_size = min(max(1/(self.x_update.norm() + 1e-3), 0.001), 3.0)
else:
delta_x = self.x - self.x_prev
delta_g = self.x_update_prev - self.x_update

dot_product = delta_g.dot(delta_x) # (deltag * deltax).sum()
dot_product = delta_g.dot(delta_x)
alpha_long = delta_x.norm()**2 / np.abs(dot_product)
#dot_product = delta_x.dot(delta_g)
#alpha_short = np.abs((dot_product).sum()) / delta_g.norm()**2
#print("short / long: ", alpha_short, alpha_long)


step_size = max(alpha_long, 0.01) #np.sqrt(alpha_long*alpha_short)
#print("step size: ", step_size)
#print("step size: ", step_size)
step_size = max(alpha_long, 0.001)

self.x_prev = self.x.copy()
self.x_update_prev = self.x_update.copy()
Expand Down
Loading

0 comments on commit ea06712

Please sign in to comment.