Split FISTA and APGD to give more options for momentum #2061

Draft
wants to merge 11 commits into base: master
165 changes: 129 additions & 36 deletions Wrappers/Python/cil/optimisation/algorithms/FISTA.py
@@ -18,7 +18,7 @@

from cil.optimisation.algorithms import Algorithm
from cil.optimisation.functions import ZeroFunction
from cil.optimisation.utilities import ConstantStepSize, StepSizeRule
from cil.optimisation.utilities import ConstantStepSize, StepSizeRule, NesterovMomentum, MomentumCoefficient, ConstantMomentum
import numpy
import logging
from numbers import Real, Number
@@ -118,7 +118,7 @@ def step_size(self):
return self.step_size_rule.step_size
else:
warnings.warn(
"Note the step-size is set by a step-size rule and could change wit each iteration")
"Note the step-size is set by a step-size rule and could change with each iteration")
return self.step_size_rule.get_step_size()

# Set default step size
@@ -170,6 +170,9 @@ def set_up(self, initial, f, g, step_size, preconditioner, **kwargs):
self.step_size_rule = ConstantStepSize(step_size)
elif isinstance(step_size, StepSizeRule):
self.step_size_rule = step_size
else:
raise TypeError(
"step_size must be a real number or a child class of :meth:`cil.optimisation.utilities.StepSizeRule`")

self.preconditioner = preconditioner

@@ -229,8 +232,127 @@ def calculate_objective_function_at_point(self, x):

"""
return self.f(x) + self.g(x)


class APGD(ISTA):

r"""Accelerated Proximal Gradient Descent (APGD), is used to solve:

.. math:: \min_{x} f(x) + g(x)

where :math:`f` is differentiable and :math:`g` has a *simple* proximal operator.


In each update the algorithm completes the following steps:

.. math::

\begin{cases}
x_{k} = \mathrm{prox}_{\alpha g}(y_{k} - \alpha\nabla f(y_{k}))\\
y_{k+1} = x_{k} + M(x_{k} - x_{k-1})
\end{cases}

where :math:`\alpha` is the :code:`step_size` and :math:`M` is the momentum coefficient.

Note that the above applies for :math:`k\geq 1`. For :math:`k=0`, :math:`x_{0}` and :math:`y_{0}` are initialised to `initial`; with the default Nesterov momentum, the internal counter starts at :math:`t_{1}=1`.

Parameters
----------
initial : DataContainer
Starting point of the algorithm
f : Function
Differentiable function. If `None` is passed, the algorithm will use the ZeroFunction.
g : Function or `None`
Convex function with *simple* proximal operator. If `None` is passed, the algorithm will use the ZeroFunction.
step_size : positive :obj:`float` or child class of :meth:`cil.optimisation.utilities.StepSizeRule`, default = None
Step size for the gradient step of APGD. If a float is passed, this is used as a constant step size. If a child class of :meth:`cil.optimisation.utilities.StepSizeRule` is passed then its method :meth:`get_step_size` is called for each update.
The default :code:`step_size` is a constant :math:`\frac{1}{L}` or 1 if `f=None`.
preconditioner : class with an `apply` method or a function that takes an initialised CIL function as an argument and modifies a provided `gradient`.
This could be a custom `preconditioner` or one provided in :meth:`~cil.optimisation.utilities.preconditioner`. If None is passed then `self.gradient_update` will remain unmodified.
momentum : float or child class of :meth:`cil.optimisation.utilities.MomentumCoefficient`, default = None
Momentum coefficient. If a float is passed, this is used as a constant momentum coefficient. If a child class of :meth:`cil.optimisation.utilities.MomentumCoefficient` is passed then its method :meth:`__call__` is called for each update. The default momentum coefficient is the Nesterov momentum coefficient.

Note
-----
Running this algorithm with the default step size and the default momentum coefficient is equivalent to running the FISTA algorithm.

"""


def __init__(self, initial, f, g, step_size=None, preconditioner=None, momentum=None, **kwargs):

self.y = initial.copy()

self.set_momentum(momentum)

super(APGD, self).__init__(initial=initial, f=f, g=g,
step_size=step_size, preconditioner=preconditioner, **kwargs)

def _calculate_default_step_size(self):
"""Calculate the default step size if a step size rule or step size is not provided
"""
return 1./self.f.L

def _provable_convergence_condition(self):
if self.preconditioner is not None:
raise NotImplementedError(
"Can't check convergence criterion if a preconditioner is used ")


if isinstance(self.step_size_rule, ConstantStepSize) and isinstance(self.momentum, NesterovMomentum):
return self.step_size_rule.step_size <= 1./self.f.L
else:
raise TypeError(
"Can't check convergence criterion for non-constant step size or non-Nesterov momentum coefficient")


@property
def momentum(self):
return self._momentum

def set_momentum(self, momentum):

if momentum is None:
self._momentum = NesterovMomentum()
else:
if isinstance(momentum, Number):
self._momentum = ConstantMomentum(momentum)
elif isinstance(momentum, MomentumCoefficient):
self._momentum = momentum
else:
raise TypeError("Momentum must be a number or a child class of MomentumCoefficient")

def update(self):
r"""Performs a single iteration of APGD. For :math:`k\geq 1`:

.. math::

\begin{cases}
x_{k} = \mathrm{prox}_{\alpha g}(y_{k} - \alpha\nabla f(y_{k}))\\
y_{k+1} = x_{k} + M(x_{k} - x_{k-1})
\end{cases}

"""

self.f.gradient(self.y, out=self.gradient_update)

if self.preconditioner is not None:
self.preconditioner.apply(
self, self.gradient_update, out=self.gradient_update)

step_size = self.step_size_rule.get_step_size(self)

self.y.sapyb(1., self.gradient_update, -step_size, out=self.y)

self.g.proximal(self.y, step_size, out=self.x)

self.x.subtract(self.x_old, out=self.y)

momentum = self.momentum(self)
self.y.sapyb(momentum, self.x, 1.0, out=self.y)
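
For reviewers, a minimal usage sketch of the new APGD class with a constant momentum coefficient might look like the following; the geometry, operator and data below are placeholder assumptions for illustration and are not part of this PR:

# Sketch: solve 0.5*||A x - b||^2 with APGD and a constant momentum
# coefficient instead of the default Nesterov rule (assumed setup).
from cil.framework import ImageGeometry
from cil.optimisation.operators import IdentityOperator
from cil.optimisation.functions import LeastSquares, ZeroFunction
from cil.optimisation.algorithms import APGD

ig = ImageGeometry(voxel_num_x=16, voxel_num_y=16)
A = IdentityOperator(ig)
b = ig.allocate('random')          # placeholder data
f = LeastSquares(A, b)
g = ZeroFunction()

# A float momentum is wrapped into a ConstantMomentum object by set_momentum
apgd = APGD(initial=ig.allocate(0), f=f, g=g, momentum=0.3)
apgd.run(100)
print(apgd.objective[-1])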


class FISTA(APGD):

r"""Fast Iterative Shrinkage-Thresholding Algorithm (FISTA), see :cite:`BeckTeboulle_b`, :cite:`BeckTeboulle_a`, is used to solve:

@@ -302,6 +424,7 @@ class FISTA(ISTA):

"""


def _calculate_default_step_size(self):
"""Calculate the default step size if a step size rule or step size is not provided
"""
@@ -320,42 +443,12 @@ def _provable_convergence_condition(self):
raise TypeError(
"Can't check convergence criterion for non-constant step size")

def __init__(self, initial, f, g, step_size=None, preconditioner=None, **kwargs):

def __init__(self, initial, f, g, step_size = None, preconditioner=None, momentum=None, **kwargs):

self.y = initial.copy()
self.t = 1
super(FISTA, self).__init__(initial=initial, f=f, g=g,
step_size=step_size, preconditioner=preconditioner, **kwargs)

def update(self):
r"""Performs a single iteration of FISTA. For :math:`k\geq 1`:

.. math::

\begin{cases}
x_{k} = \mathrm{prox}_{\alpha g}(y_{k} - \alpha\nabla f(y_{k}))\\
t_{k+1} = \frac{1+\sqrt{1+ 4t_{k}^{2}}}{2}\\
y_{k+1} = x_{k} + \frac{t_{k}-1}{t_{k+1}}(x_{k} - x_{k-1})
\end{cases}

"""
step_size=step_size, preconditioner=preconditioner, momentum=None, **kwargs)

self.t_old = self.t

self.f.gradient(self.y, out=self.gradient_update)

if self.preconditioner is not None:
self.preconditioner.apply(
self, self.gradient_update, out=self.gradient_update)

step_size = self.step_size_rule.get_step_size(self)

self.y.sapyb(1., self.gradient_update, -step_size, out=self.y)

self.g.proximal(self.y, step_size, out=self.x)

self.t = 0.5*(1 + numpy.sqrt(1 + 4*(self.t_old**2)))

self.x.subtract(self.x_old, out=self.y)
self.y.sapyb(((self.t_old-1)/self.t), self.x, 1.0, out=self.y)
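
Since FISTA now subclasses APGD, a quick sanity check of the refactor (reusing the assumed f, g and ig objects from the earlier sketch) is to confirm that FISTA and APGD with the default NesterovMomentum produce the same solution:

# Sketch: FISTA should coincide with APGD using the default Nesterov momentum.
import numpy as np
from cil.optimisation.algorithms import APGD, FISTA
from cil.optimisation.utilities import NesterovMomentum

fista = FISTA(initial=ig.allocate(0), f=f, g=g)
apgd = APGD(initial=ig.allocate(0), f=f, g=g, momentum=NesterovMomentum())

fista.run(50)
apgd.run(50)

# The two solutions should agree to numerical precision
np.testing.assert_allclose(fista.solution.as_array(),
                           apgd.solution.as_array(), rtol=1e-6)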

2 changes: 1 addition & 1 deletion Wrappers/Python/cil/optimisation/algorithms/__init__.py
@@ -23,7 +23,7 @@
from .FISTA import FISTA
from .FISTA import ISTA
from .FISTA import ISTA as PGD
from .FISTA import FISTA as APGD
from .FISTA import APGD
from .PDHG import PDHG
from .ADMM import LADMM
from .SPDHG import SPDHG
1 change: 1 addition & 0 deletions Wrappers/Python/cil/optimisation/utilities/__init__.py
@@ -21,3 +21,4 @@
from .sampler import SamplerRandom
from .StepSizeMethods import ConstantStepSize, ArmijoStepSizeRule, StepSizeRule, BarzilaiBorweinStepSizeRule
from .preconditioner import Preconditioner, AdaptiveSensitivity, Sensitivity
from .momentum import MomentumCoefficient, ConstantMomentum, NesterovMomentum
19 changes: 19 additions & 0 deletions Wrappers/Python/cil/optimisation/utilities/callbacks.py
@@ -1,3 +1,22 @@
# Copyright 2024 United Kingdom Research and Innovation
# Copyright 2024 The University of Manchester
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Authors:
# - CIL Developers, listed at: https://github.com/TomographicImaging/CIL/blob/master/NOTICE.txt


from abc import ABC, abstractmethod
from functools import partialmethod

76 changes: 76 additions & 0 deletions Wrappers/Python/cil/optimisation/utilities/momentum.py
@@ -0,0 +1,76 @@
# Copyright 2024 United Kingdom Research and Innovation
# Copyright 2024 The University of Manchester
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Authors:
# - CIL Developers, listed at: https://github.com/TomographicImaging/CIL/blob/master/NOTICE.txt


from abc import ABC, abstractmethod
import numpy

class MomentumCoefficient(ABC):
'''Abstract base class for MomentumCoefficient objects. The `__call__` method of this class returns the momentum coefficient for the given iteration.
'''
def __init__(self):
'''Initialises the momentum coefficient object.
'''
pass

@abstractmethod
def __call__(self, algorithm):
'''Returns the momentum coefficient for the given iteration.

Parameters
----------
algorithm: CIL Algorithm
The algorithm object.
'''

pass

class ConstantMomentum(MomentumCoefficient):

'''MomentumCoefficient object that returns a constant momentum coefficient.

Parameters
----------
momentum: float
The momentum coefficient.
'''

def __init__(self, momentum):
self.momentum = momentum

def __call__(self, algorithm):
return self.momentum

class NesterovMomentum(MomentumCoefficient):

'''MomentumCoefficient object that returns the Nesterov momentum coefficient.

Parameters
----------
t: float
The initial value for the momentum coefficient.
'''

def __init__(self, t=1):
Contributor: Do we need t=1?

Contributor: I think self.t=1 fits better in the __init__ as before because it is one of the parameters needed to initialise the algorithm.

self.t = 1

def __call__(self, algorithm):
self.t_old = self.t
self.t = 0.5*(1 + numpy.sqrt(1 + 4*(self.t_old**2)))
return (self.t_old-1)/self.t
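
The motivation for the MomentumCoefficient abstraction is that users can plug in their own momentum schedules. A hypothetical example, not part of this PR, following the same pattern as ConstantMomentum and NesterovMomentum:

# Hypothetical custom momentum rule built on the MomentumCoefficient ABC above:
# the coefficient decays with the algorithm's iteration counter.
from cil.optimisation.utilities import MomentumCoefficient

class DecayingMomentum(MomentumCoefficient):
    '''Returns momentum/(iteration + 1): strong momentum early on,
    fading towards zero as the algorithm progresses.'''
    def __init__(self, momentum=0.9):
        self.momentum = momentum

    def __call__(self, algorithm):
        # algorithm.iteration is the CIL Algorithm iteration counter
        # (clamped to avoid a zero denominator on the first call)
        return self.momentum / (max(algorithm.iteration, 0) + 1)

# Usage (assuming f, g and ig as in the earlier sketches):
# apgd = APGD(initial=ig.allocate(0), f=f, g=g, momentum=DecayingMomentum(0.9))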
