From 7f19013561ece9c5b96415d950762094f3949fa3 Mon Sep 17 00:00:00 2001
From: epapoutsellis
Date: Sun, 2 Feb 2025 16:44:05 +0000
Subject: [PATCH 1/3] Add momentum FISTA

---
Review notes (text in this section is ignored by `git am`):
* FISTA.update() obtains the extrapolation weight by calling
  self.momentum(self). Nesterov reads algo.t, overwrites it with the next
  value and returns (t_old - 1) / t_new; ConstantMomentum returns the value
  it was constructed with.
* set_momentum() maps None -> Nesterov(), a Number -> ConstantMomentum(value),
  and stores any other object unchanged.
* update() now writes the gradient of f into self.x instead of
  self.gradient_update.
* NOTE(review): anything that reads algorithm.gradient_update (for example a
  step-size rule) is no longer fed by update() -- confirm no rule relies on it.
* Several -/+ pairs in the __init__/update hunks carry identical visible text;
  presumably the original differed only in whitespace -- confirm.

 .../cil/optimisation/algorithms/FISTA.py      | 78 +++++++++++++------
 1 file changed, 56 insertions(+), 22 deletions(-)

diff --git a/Wrappers/Python/cil/optimisation/algorithms/FISTA.py b/Wrappers/Python/cil/optimisation/algorithms/FISTA.py
index b2f1c510da..8e2c02693c 100644
--- a/Wrappers/Python/cil/optimisation/algorithms/FISTA.py
+++ b/Wrappers/Python/cil/optimisation/algorithms/FISTA.py
@@ -230,6 +230,31 @@ def calculate_objective_function_at_point(self, x):
         """
         return self.f(x) + self.g(x)
 
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+
+@dataclass
+class MomentumCoefficient(ABC):
+    momentum: float = None
+
+    @abstractmethod
+    def __call__(self, algo=None):
+        pass
+
+class ConstantMomentum(MomentumCoefficient):
+
+    def __call__(self, algo):
+        return self.momentum
+
+
+class Nesterov(MomentumCoefficient):
+
+    def __call__(self, algo=None):
+        t_old = algo.t
+        algo.t = 0.5*(1 + numpy.sqrt(1 + 4*(t_old**2)))
+        return (t_old-1)/algo.t
+
 class FISTA(ISTA):
 
     r"""Fast Iterative Shrinkage-Thresholding Algorithm (FISTA), see :cite:`BeckTeboulle_b`, :cite:`BeckTeboulle_a`, is used to solve:
@@ -302,6 +327,20 @@ class FISTA(ISTA):
 
     """
 
+    @property
+    def momentum(self):
+        return self._momentum
+
+    def set_momentum(self, momentum):
+
+        if momentum is None:
+            self._momentum = Nesterov()
+        else:
+            if isinstance(momentum, Number):
+                self._momentum = ConstantMomentum(momentum)
+            else:
+                self._momentum = momentum
+
     def _calculate_default_step_size(self):
         """Calculate the default step size if a step size rule or step size is not provided
         """
@@ -320,14 +359,16 @@ def _provable_convergence_condition(self):
             raise TypeError(
                 "Can't check convergence criterion for non-constant step size")
 
-    def __init__(self, initial, f, g, step_size=None, preconditioner=None, **kwargs):
+    def __init__(self, initial, f, g, step_size = None, preconditioner=None, momentum=None, **kwargs):
 
+
         self.y = initial.copy()
-        self.t = 1
-        super(FISTA, self).__init__(initial=initial, f=f, g=g,
-                                    step_size=step_size, preconditioner=preconditioner, **kwargs)
-
-    def update(self):
+        self.t = 1
+        self._momentum = None
+        self.set_momentum(momentum=momentum)
+        super(FISTA, self).__init__(initial=initial, f=f, g=g, step_size=step_size, preconditioner=preconditioner,**kwargs)
+
+    def update(self):
         r"""Performs a single iteration of FISTA. For :math:`k\geq 1`:
 
         .. math::
@@ -339,23 +380,16 @@ def update(self):
             \end{cases}
 
         """
+
+        self.f.gradient(self.y, out=self.x)
+
+        # update step size
+        step_size = self.step_size_rule.get_step_size(self)
 
-        self.t_old = self.t
-
-        self.f.gradient(self.y, out=self.gradient_update)
-
-        if self.preconditioner is not None:
-            self.preconditioner.apply(
-                self, self.gradient_update, out=self.gradient_update)
-
-        step_size = self.step_size_rule.get_step_size(self)
-
-        self.y.sapyb(1., self.gradient_update, -step_size, out=self.y)
-
+        self.y.sapyb(1., self.x, -step_size, out=self.y)
         self.g.proximal(self.y, step_size, out=self.x)
 
-        self.t = 0.5*(1 + numpy.sqrt(1 + 4*(self.t_old**2)))
-
         self.x.subtract(self.x_old, out=self.y)
-        self.y.sapyb(((self.t_old-1)/self.t), self.x, 1.0, out=self.y)
-
+
+        momentum = self.momentum(self)
+        self.y.sapyb(momentum, self.x, 1.0, out=self.y)
\ No newline at end of file

From e3afbe929986751dd18e3f2c019be9d54af124ff Mon Sep 17 00:00:00 2001
From: epapoutsellis
Date: Mon, 3 Feb 2025 10:24:41 +0000
Subject: [PATCH 2/3] fix precond error

---
Review notes (text in this section is ignored by `git am`):
* The preconditioner is now applied to self.x. In this series update()
  computes the gradient of f into self.x and takes the gradient step with
  self.x, so that is the buffer the preconditioner has to modify.
* The hunk as submitted applied it to self.gradient_update, which update()
  never writes or reads, so the preconditioner had no effect on the iterate.
* The `index` line below predates this correction.

 Wrappers/Python/cil/optimisation/algorithms/FISTA.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/Wrappers/Python/cil/optimisation/algorithms/FISTA.py b/Wrappers/Python/cil/optimisation/algorithms/FISTA.py
index 8e2c02693c..20b065cd94 100644
--- a/Wrappers/Python/cil/optimisation/algorithms/FISTA.py
+++ b/Wrappers/Python/cil/optimisation/algorithms/FISTA.py
@@ -382,6 +382,10 @@ def update(self):
         """
 
         self.f.gradient(self.y, out=self.x)
+
+        if self.preconditioner is not None:
+            self.preconditioner.apply(
+                self, self.x, out=self.x)
 
         # update step size
         step_size = self.step_size_rule.get_step_size(self)

From 12252ceb941cda9c6a0b259551fa2628f2a271f0 Mon Sep 17 00:00:00 2001
From: epapoutsellis
Date: Tue, 4 Feb 2025 10:21:01 +0000
Subject: [PATCH 3/3] add FISTA momentum test

---
Review notes (text in this section is ignored by `git am`):
* The custom-momentum run asserted the solution twice; the first of the two
  assertions now checks the final objective against the cvxpy optimum, which
  mirrors the checks made for the default FISTA run.
* Repaired the mismatched quotation marks in the comment that cites the
  Chambolle-Dossal paper.
* The test exercises the extension point: a MomentumCoefficient subclass whose
  __call__ returns (iteration - 1) / (iteration + 50).

 Wrappers/Python/test/test_algorithms.py | 45 +++++++++++++++++++++++++
 1 file changed, 45 insertions(+)

diff --git a/Wrappers/Python/test/test_algorithms.py b/Wrappers/Python/test/test_algorithms.py
index 2efe206e23..4286be94c2 100644
--- a/Wrappers/Python/test/test_algorithms.py
+++ b/Wrappers/Python/test/test_algorithms.py
@@ -458,6 +458,51 @@ def test_FISTA_APGD_alias(self):
             self.assertNumpyArrayEqual(alg.y.array, initial.array)
             mock_method.assert_called_once_with(initial=initial, f=f, g=g, step_size=4, preconditioner=None)
 
+    def test_FISTA_momentum(self):
+
+        np.random.seed(10)
+        n = 1000
+        m = 500
+        A = np.random.normal(0,1, (m, n)).astype('float32')
+        # A /= np.linalg.norm(A, axis=1, keepdims=True)
+        b = np.random.normal(0,1, m).astype('float32')
+        reg = 0.5
+
+        Aop = MatrixOperator(A)
+        bop = VectorData(b)
+        ig = Aop.domain
+
+        # cvxpy solutions
+        u_cvxpy = cvxpy.Variable(ig.shape[0])
+        objective = cvxpy.Minimize( 0.5 * cvxpy.sum_squares(Aop.A @ u_cvxpy - bop.array) + reg/2 * cvxpy.sum_squares(u_cvxpy))
+        p = cvxpy.Problem(objective)
+        p.solve(verbose=False, solver=cvxpy.SCS, eps=1e-4)
+
+        # default fista
+        f = LeastSquares(A=Aop, b=bop, c=0.5)
+        g = reg/2*L2NormSquared()
+        fista = FISTA(initial=ig.allocate(), f=f, g=g, update_objective_interval=1)
+        fista.run(500)
+        np.testing.assert_allclose(fista.objective[-1], p.value, atol=1e-3)
+        np.testing.assert_allclose(fista.solution.array, u_cvxpy.value, atol=1e-3)
+
+        # fista Dossal Chambolle, "On the convergence of the iterates of FISTA"
+        from cil.optimisation.algorithms.FISTA import MomentumCoefficient
+        class DossalChambolle(MomentumCoefficient):
+            def __call__(self, algo=None):
+                return (algo.iteration-1)/(algo.iteration+50)
+        momentum = DossalChambolle()
+        fista_dc = FISTA(initial=ig.allocate(), f=f, g=g, update_objective_interval=1, momentum=momentum)
+        fista_dc.run(500)
+        np.testing.assert_allclose(fista_dc.objective[-1], p.value, atol=1e-3)
+        np.testing.assert_allclose(fista_dc.solution.array, u_cvxpy.value, atol=1e-3)
+
+
+
+
+
+
+
 
 
 class testISTA(CCPiTestClass):