Source code for scico.optimize._admm

# -*- coding: utf-8 -*-
# Copyright (C) 2020-2023 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.

"""ADMM solver."""

# Needed to annotate a class method that returns the encapsulating class;
# see https://www.python.org/dev/peps/pep-0563/
from __future__ import annotations

from typing import List, Optional, Tuple, Union

import scico.numpy as snp
from scico.functional import Functional
from scico.linop import LinearOperator
from scico.numpy import Array, BlockArray
from scico.numpy.linalg import norm

from ._admmaux import (
    FBlockCircularConvolveSolver,
    G0BlockCircularConvolveSolver,
    GenericSubproblemSolver,
    LinearSubproblemSolver,
    MatrixSubproblemSolver,
    SubproblemSolver,
)
from ._common import Optimizer


class ADMM(Optimizer):
    r"""Basic Alternating Direction Method of Multipliers (ADMM) algorithm.

    |

    Solve an optimization problem of the form

    .. math::
        \argmin_{\mb{x}} \; f(\mb{x}) + \sum_{i=1}^N g_i(C_i \mb{x}) \;,

    where :math:`f` and the :math:`g_i` are instances of
    :class:`.Functional`, and the :math:`C_i` are
    :class:`.LinearOperator`.

    The optimization problem is solved by introducing the splitting
    :math:`\mb{z}_i = C_i \mb{x}` and solving

    .. math::
        \argmin_{\mb{x}, \mb{z}_i} \; f(\mb{x}) + \sum_{i=1}^N
        g_i(\mb{z}_i) \; \text{such that}\; C_i \mb{x} = \mb{z}_i \;,

    via an ADMM algorithm :cite:`glowinski-1975-approximation`
    :cite:`gabay-1976-dual` :cite:`boyd-2010-distributed` consisting of
    the iterations (see :meth:`step`)

    .. math::
       \begin{aligned}
       \mb{x}^{(k+1)} &= \argmin_{\mb{x}} \; f(\mb{x}) + \sum_i
       \frac{\rho_i}{2} \norm{\mb{z}^{(k)}_i - \mb{u}^{(k)}_i - C_i
       \mb{x}}_2^2 \\
       \mb{z}_i^{(k+1)} &= \argmin_{\mb{z}_i} \; g_i(\mb{z}_i) +
       \frac{\rho_i}{2}
       \norm{\mb{z}_i - \mb{u}^{(k)}_i - C_i \mb{x}^{(k+1)}}_2^2  \\
       \mb{u}_i^{(k+1)} &=  \mb{u}_i^{(k)} + C_i \mb{x}^{(k+1)} -
       \mb{z}^{(k+1)}_i  \; .
       \end{aligned}


    Attributes:
        f (:class:`.Functional`): Functional :math:`f` (usually a
            :class:`.Loss`)
        g_list (list of :class:`.Functional`): List of :math:`g_i`
            functionals. Must be same length as :code:`C_list` and
            :code:`rho_list`.
        C_list (list of :class:`.LinearOperator`): List of :math:`C_i`
            operators.
        rho_list (list of scalars): List of :math:`\rho_i` penalty
            parameters. Must be same length as :code:`C_list` and
            :code:`g_list`.
        alpha (float): Relaxation parameter.
        u_list (list of array-like): List of scaled Lagrange multipliers
            :math:`\mb{u}_i` at current iteration.
        x (array-like): Solution.
        subproblem_solver (:class:`.SubproblemSolver`): Solver for
            :math:`\mb{x}`-update step.
        z_list (list of array-like): List of auxiliary variables
            :math:`\mb{z}_i` at current iteration.
        z_list_old (list of array-like): List of auxiliary variables
            :math:`\mb{z}_i` at previous iteration.
    """

    def __init__(
        self,
        f: Functional,
        g_list: List[Functional],
        C_list: List[LinearOperator],
        rho_list: List[float],
        alpha: float = 1.0,
        x0: Optional[Union[Array, BlockArray]] = None,
        subproblem_solver: Optional[SubproblemSolver] = None,
        **kwargs,
    ):
        r"""Initialize an :class:`ADMM` object.

        Args:
            f: Functional :math:`f` (usually a loss function).
            g_list: List of :math:`g_i` functionals. Must be same length
                 as :code:`C_list` and :code:`rho_list`.
            C_list: List of :math:`C_i` operators.
            rho_list: List of :math:`\rho_i` penalty parameters.
                Must be same length as :code:`C_list` and :code:`g_list`.
            alpha: Relaxation parameter. No relaxation for default 1.0.
            x0: Initial value for :math:`\mb{x}`. If ``None``, defaults
                to an array of zeros.
            subproblem_solver: Solver for :math:`\mb{x}`-update step.
                Defaults to ``None``, which implies use of an instance of
                :class:`GenericSubproblemSolver`.
            **kwargs: Additional optional parameters handled by
                initializer of base class :class:`.Optimizer`.
        """
        N = len(g_list)
        if len(C_list) != N:
            raise ValueError(f"len(C_list)={len(C_list)} not equal to len(g_list)={N}.")
        if len(rho_list) != N:
            raise ValueError(f"len(rho_list)={len(rho_list)} not equal to len(g_list)={N}.")

        self.f: Functional = f
        self.g_list: List[Functional] = g_list
        self.C_list: List[LinearOperator] = C_list
        self.rho_list: List[float] = rho_list
        self.alpha: float = alpha

        if subproblem_solver is None:
            subproblem_solver = GenericSubproblemSolver()
        self.subproblem_solver: SubproblemSolver = subproblem_solver
        self.subproblem_solver.internal_init(self)

        if x0 is None:
            input_shape = C_list[0].input_shape
            dtype = C_list[0].input_dtype
            x0 = snp.zeros(input_shape, dtype=dtype)
        self.x = x0
        self.z_list, self.z_list_old = self.z_init(self.x)
        self.u_list = self.u_init(self.x)

        super().__init__(**kwargs)

    def _working_vars_finite(self) -> bool:
        """Determine where ``NaN`` of ``Inf`` encountered in solve.

        Return ``False`` if a ``NaN`` or ``Inf`` value is encountered in
        a solver working variable.
        """
        for v in (
            [
                self.x,
            ]
            + self.z_list
            + self.u_list
        ):
            if not snp.all(snp.isfinite(v)):
                return False
        return True

    def _objective_evaluatable(self):
        """Determine whether the objective function can be evaluated."""
        return (not self.f or self.f.has_eval) and all([_.has_eval for _ in self.g_list])

    def _itstat_extra_fields(self):
        """Define ADMM-specific iteration statistics fields."""
        itstat_fields = {"Prml Rsdl": "%9.3e", "Dual Rsdl": "%9.3e"}
        itstat_attrib = ["norm_primal_residual()", "norm_dual_residual()"]

        # subproblem solver info when available
        if isinstance(self.subproblem_solver, GenericSubproblemSolver):
            itstat_fields.update({"Num FEv": "%6d", "Num It": "%6d"})
            itstat_attrib.extend(
                ["subproblem_solver.info['nfev']", "subproblem_solver.info['nit']"]
            )
        elif (
            type(self.subproblem_solver) == LinearSubproblemSolver
            and self.subproblem_solver.cg_function == "scico"
        ):
            itstat_fields.update({"CG It": "%5d", "CG Res": "%9.3e"})
            itstat_attrib.extend(
                ["subproblem_solver.info['num_iter']", "subproblem_solver.info['rel_res']"]
            )
        elif (
            type(self.subproblem_solver)
            in [MatrixSubproblemSolver, FBlockCircularConvolveSolver, G0BlockCircularConvolveSolver]
            and self.subproblem_solver.check_solve
        ):
            itstat_fields.update({"Slv Res": "%9.3e"})
            itstat_attrib.extend(["subproblem_solver.accuracy"])

        return itstat_fields, itstat_attrib

[docs] def minimizer(self): """Return current estimate of the functional mimimizer.""" return self.x
[docs] def objective( self, x: Optional[Union[Array, BlockArray]] = None, z_list: Optional[List[Union[Array, BlockArray]]] = None, ) -> float: r"""Evaluate the objective function. Evaluate the objective function .. math:: f(\mb{x}) + \sum_{i=1}^N g_i(\mb{z}_i) \;. Note that this form is cheaper to compute, but may have very poor accuracy compared with the "true" objective function .. math:: f(\mb{x}) + \sum_{i=1}^N g_i(C_i \mb{x}) \;. when the primal residual is large. Args: x: Point at which to evaluate objective function. If ``None``, the objective is evaluated at the current iterate :code:`self.x`. z_list: Point at which to evaluate objective function. If ``None``, the objective is evaluated at the current iterate :code:`self.z_list`. Returns: Value of the objective function. """ if (x is None) != (z_list is None): raise ValueError("Both or neither of x and z_list must be supplied.") if x is None: x = self.x z_list = self.z_list assert z_list is not None out = 0.0 if self.f: out += self.f(x) for g, z in zip(self.g_list, z_list): out += g(z) return out
[docs] def norm_primal_residual(self, x: Optional[Union[Array, BlockArray]] = None) -> float: r"""Compute the :math:`\ell_2` norm of the primal residual. Compute the :math:`\ell_2` norm of the primal residual .. math:: \left( \sum_{i=1}^N \rho_i \left\| C_i \mb{x} - \mb{z}_i^{(k)} \right\|_2^2\right)^{1/2} \;. Args: x: Point at which to evaluate primal residual. If ``None``, the primal residual is evaluated at the current iterate :code:`self.x`. Returns: Norm of primal residual. """ if x is None: x = self.x sum = 0.0 for rhoi, Ci, zi in zip(self.rho_list, self.C_list, self.z_list): sum += rhoi * norm(Ci(self.x) - zi) ** 2 return snp.sqrt(sum)
[docs] def norm_dual_residual(self) -> float: r"""Compute the :math:`\ell_2` norm of the dual residual. Compute the :math:`\ell_2` norm of the dual residual .. math:: \left\| \sum_{i=1}^N \rho_i C_i^T \left( \mb{z}^{(k)}_i - \mb{z}^{(k-1)}_i \right) \right\|_2 \;. Returns: Norm of dual residual. """ sum = 0.0 for rhoi, zi, ziold, Ci in zip(self.rho_list, self.z_list, self.z_list_old, self.C_list): sum += rhoi * Ci.adj(zi - ziold) return norm(sum)
[docs] def z_init( self, x0: Union[Array, BlockArray] ) -> Tuple[List[Union[Array, BlockArray]], List[Union[Array, BlockArray]]]: r"""Initialize auxiliary variables :math:`\mb{z}_i`. Initialized to .. math:: \mb{z}_i = C_i \mb{x}^{(0)} \;. :code:`z_list` and :code:`z_list_old` are initialized to the same value. Args: x0: Initial value of :math:`\mb{x}`. """ z_list: List[Union[Array, BlockArray]] = [Ci(x0) for Ci in self.C_list] z_list_old = z_list.copy() return z_list, z_list_old
[docs] def u_init(self, x0: Union[Array, BlockArray]) -> List[Union[Array, BlockArray]]: r"""Initialize scaled Lagrange multipliers :math:`\mb{u}_i`. Initialized to .. math:: \mb{u}_i = \mb{0} \;. Note that the parameter `x0` is unused, but is provided for potential use in an overridden method. Args: x0: Initial value of :math:`\mb{x}`. """ u_list = [snp.zeros(Ci.output_shape, dtype=Ci.output_dtype) for Ci in self.C_list] return u_list
[docs] def step(self): r"""Perform a single ADMM iteration. The primary variable :math:`\mb{x}` is updated by solving the the optimization problem .. math:: \mb{x}^{(k+1)} = \argmin_{\mb{x}} \; f(\mb{x}) + \sum_i \frac{\rho_i}{2} \norm{\mb{z}^{(k)}_i - \mb{u}^{(k)}_i - C_i \mb{x}}_2^2 \;. Update auxiliary variables :math:`\mb{z}_i` and scaled Lagrange multipliers :math:`\mb{u}_i`. The auxiliary variables are updated according to .. math:: \begin{aligned} \mb{z}_i^{(k+1)} &= \argmin_{\mb{z}_i} \; g_i(\mb{z}_i) + \frac{\rho_i}{2} \norm{\mb{z}_i - \mb{u}^{(k)}_i - C_i \mb{x}^{(k+1)}}_2^2 \\ &= \mathrm{prox}_{g_i}(C_i \mb{x} + \mb{u}_i, 1 / \rho_i) \;, \end{aligned} and the scaled Lagrange multipliers are updated according to .. math:: \mb{u}_i^{(k+1)} = \mb{u}_i^{(k)} + C_i \mb{x}^{(k+1)} - \mb{z}^{(k+1)}_i \;. """ self.x = self.subproblem_solver.solve(self.x) self.z_list_old = self.z_list.copy() for i, (rhoi, gi, Ci, zi, ui) in enumerate( zip(self.rho_list, self.g_list, self.C_list, self.z_list, self.u_list) ): if self.alpha == 1.0: Cix = Ci(self.x) else: Cix = self.alpha * Ci(self.x) + (1.0 - self.alpha) * zi zi = gi.prox(Cix + ui, 1 / rhoi, v0=zi) ui = ui + Cix - zi self.z_list[i] = zi self.u_list[i] = ui