# -*- coding: utf-8 -*-
# Copyright (C) 2020-2023 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.

"""Proximal Gradient Method classes."""

# Needed to annotate a class method that returns the encapsulating class;
# see https://www.python.org/dev/peps/pep-0563/
from __future__ import annotations

from typing import Optional, Union

import jax

import scico.numpy as snp
from scico.functional import Functional
from scico.loss import Loss
from scico.numpy import Array, BlockArray

from ._common import Optimizer
from ._pgmaux import (
    AdaptiveBBStepSize,
    BBStepSize,
    PGMStepSize,
    RobustLineSearchStepSize,
)


class PGM(Optimizer):
    r"""Proximal Gradient Method (PGM) base class.

    Minimize a function of the form :math:`f(\mb{x}) + g(\mb{x})`, where
    :math:`f` and :math:`g` are instances of :class:`.Functional`.

    Uses a helper :class:`PGMStepSize` object to provide an estimate of the
    Lipschitz constant :math:`L` of :math:`f`. The step size :math:`\alpha`
    is the reciprocal of :math:`L`, i.e. :math:`\alpha = 1 / L`.
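
    A minimal usage sketch (assuming ``f``, ``g``, ``L0``, and ``x0`` have
    been constructed as described above; ``maxiter`` and ``solve`` are
    assumed to be handled by the :class:`.Optimizer` base class)::

        solver = PGM(f=f, g=g, L0=L0, x0=x0, maxiter=100)
        x_hat = solver.solve()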
    """

    def __init__(
        self,
        f: Union[Loss, Functional],
        g: Functional,
        L0: float,
        x0: Union[Array, BlockArray],
        step_size: Optional[PGMStepSize] = None,
        **kwargs,
    ):
        r"""

        Args:
            f: Loss or Functional object with `grad` defined.
            g: Instance of Functional with defined prox method.
            L0: Initial estimate of Lipschitz constant of f.
            x0: Starting point for :math:`\mb{x}`.
            step_size: Helper :class:`PGMStepSize` to estimate the
                Lipschitz constant of f.
            **kwargs: Additional optional parameters handled by
                initializer of base class :class:`.Optimizer`.
        """

        #: Functional or Loss to minimize; must have grad method defined.
        self.f: Union[Loss, Functional] = f

        if g.has_prox is not True:
            raise ValueError(f"The functional g ({type(g)}) must have a prox method.")

        #: Functional to minimize; must have prox defined
        self.g: Functional = g

        if step_size is None:
            step_size = PGMStepSize()
        self.step_size: PGMStepSize = step_size
        self.step_size.internal_init(self)
        self.L: float = L0  # reciprocal of step size (estimate of Lipschitz constant of f)
        self.fixed_point_residual = snp.inf

        def x_step(v: Union[Array, BlockArray], L: float) -> Union[Array, BlockArray]:
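            # Proximal gradient update: a gradient step on f with step size
            # 1 / L, followed by the prox of g (also scaled by 1 / L).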
            return self.g.prox(v - 1.0 / L * self.f.grad(v), 1.0 / L)

        self.x_step = jax.jit(x_step)

        self.x: Union[Array, BlockArray] = x0  # current estimate of solution

        super().__init__(**kwargs)

    def _working_vars_finite(self) -> bool:
        """Determine where ``NaN`` of ``Inf`` encountered in solve.

        Return ``False`` if a ``NaN`` or ``Inf`` value is encountered in
        a solver working variable.
        """
        return snp.all(snp.isfinite(self.x))

    def _objective_evaluatable(self):
        """Determine whether the objective function can be evaluated."""
        return self.f.has_eval and self.g.has_eval

    def _itstat_extra_fields(self):
        """Define linearized ADMM-specific iteration statistics fields."""
        itstat_fields = {"L": "%9.3e", "Residual": "%9.3e"}
        itstat_attrib = ["L", "norm_residual()"]
        return itstat_fields, itstat_attrib

    def minimizer(self):
        """Return the current estimate of the functional minimizer."""
        return self.x

    def objective(self, x: Optional[Union[Array, BlockArray]] = None) -> float:
        r"""Evaluate the objective function :math:`f(\mb{x}) + g(\mb{x})`."""
        if x is None:
            x = self.x
        return self.f(x) + self.g(x)

    def f_quad_approx(
        self, x: Union[Array, BlockArray], y: Union[Array, BlockArray], L: float
    ) -> float:
        r"""Evaluate the quadratic approximation to function :math:`f`.

        Evaluate the quadratic approximation to function :math:`f`,
        corresponding to :math:`\hat{f}_{L}(\mb{x}, \mb{y}) = f(\mb{y}) +
        \nabla f(\mb{y})^H (\mb{x} - \mb{y}) + \frac{L}{2}
        \left\| \mb{x} - \mb{y} \right\|_2^2`.
        """
        diff_xy = x - y
        return (
            self.f(y)
            + snp.sum(snp.real(snp.conj(self.f.grad(y)) * diff_xy))
            + 0.5 * L * snp.linalg.norm(diff_xy) ** 2
        )

    def norm_residual(self) -> float:
        r"""Return the fixed point residual.

        Return the fixed point residual (see Sec. 4.3 of
        :cite:`liu-2018-first`).
        """
        return self.fixed_point_residual

    def step(self):
        """Take a single PGM step."""
        # Update reciprocal of step size using current solution.
        self.L = self.step_size.update(self.x)
        x = self.x_step(self.x, self.L)
        self.fixed_point_residual = snp.linalg.norm(self.x - x)
        self.x = x


class AcceleratedPGM(PGM):
    r"""Accelerated Proximal Gradient Method (AcceleratedPGM) base class.

    Minimize a function of the form :math:`f(\mb{x}) + g(\mb{x})`, where
    :math:`f` and :math:`g` are instances of :class:`.Functional`. The
    accelerated form of PGM is also known as FISTA :cite:`beck-2009-fast`.

    For documentation on inherited attributes, see :class:`.PGM`.
    """

    def __init__(
        self,
        f: Union[Loss, Functional],
        g: Functional,
        L0: float,
        x0: Union[Array, BlockArray],
        step_size: Optional[PGMStepSize] = None,
        **kwargs,
    ):
        r"""

        Args:
            f: Loss or Functional object with `grad` defined.
            g: Instance of Functional with defined prox method.
            L0: Initial estimate of Lipschitz constant of f.
            x0: Starting point for :math:`\mb{x}`.
            step_size: Helper :class:`PGMStepSize` to estimate the
                Lipschitz constant of f.
            **kwargs: Additional optional parameters handled by
                initializer of base class :class:`.Optimizer`.
        """
        super().__init__(f=f, g=g, L0=L0, x0=x0, step_size=step_size, **kwargs)

        self.v = x0
        self.t = 1.0

    def step(self):
        """Take a single AcceleratedPGM step."""
        x_old = self.x
        # Update reciprocal of step size using current extrapolation.
        if isinstance(self.step_size, (AdaptiveBBStepSize, BBStepSize)):
            self.L = self.step_size.update(self.x)
        else:
            self.L = self.step_size.update(self.v)
        if isinstance(self.step_size, RobustLineSearchStepSize):
            # The robust line search step size uses a different extrapolation
            # sequence; the update to the solution is computed while updating
            # the reciprocal of the step size.
            self.x = self.step_size.Z
            self.fixed_point_residual = snp.linalg.norm(self.x - x_old)
        else:
            self.x = self.x_step(self.v, self.L)
            self.fixed_point_residual = snp.linalg.norm(self.x - self.v)

        t_old = self.t
        self.t = 0.5 * (1 + snp.sqrt(1 + 4 * t_old**2))
        self.v = self.x + ((t_old - 1) / self.t) * (self.x - x_old)

    def _working_vars_finite(self) -> bool:
        """Determine whether a ``NaN`` or ``Inf`` was encountered in the solve.

        Return ``False`` if a ``NaN`` or ``Inf`` value is encountered in
        a solver working variable.
        """
        return snp.all(snp.isfinite(self.x)) and snp.all(snp.isfinite(self.v))