# Source code for scico.functional._proxavg

# -*- coding: utf-8 -*-
# Copyright (C) 2023 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.

"""Implementation of the proximal average method."""

from typing import List, Optional, Union

from scico.numpy import Array, BlockArray, isinf

from ._functional import Functional


class ProximalAverage(Functional):
    """Weighted average of functionals.

    A functional that is composed of a weighted average of functionals.
    All of the component functionals are required to have proximal
    operators. The proximal operator of the composite functional is
    approximated via the proximal average method :cite:`yu-2013-better`,
    which holds for small scaling parameters. This does not imply that it
    can only be applied to problems requiring a small regularization
    parameter since most proximal algorithms include an additional
    algorithm parameter that also plays a role in the parameter of the
    proximal operator. For example, in :class:`.PGM` and
    :class:`.AcceleratedPGM`, the scaled proximal operator parameter
    is the regularization parameter divided by the `L0` algorithm
    parameter, and for :class:`.ADMM`, the scaled proximal operator
    parameters are the regularization parameters divided by the entries
    in the `rho_list` algorithm parameter.
    """

    def __init__(
        self,
        func_list: List[Functional],
        alpha_list: Optional[List[float]] = None,
        no_inf_eval: bool = True,
    ):
        """
        Args:
            func_list: Non-empty list of component :class:`.Functional`
                objects, all of which must have a proximal operator.
            alpha_list: List of scalar weights for each
                :class:`.Functional`. If not specified, defaults to equal
                weights. If specified, the list of weights must have the
                same length as the :class:`.Functional` list. If the
                weights do not sum to unity, they are scaled to ensure
                that they do.
            no_inf_eval: If ``True``, exclude infinite values (typically
                associated with a functional that is an indicator
                function) from the evaluation of the sum of component
                functionals.

        Raises:
            ValueError: If `func_list` is empty, if any component
                functional does not have a proximal operator, if
                `alpha_list` does not have the same length as
                `func_list`, or if the weights in `alpha_list` sum to
                zero.
        """
        # Store a private copy so that later mutation of the caller's
        # list cannot silently change this functional.
        func_list = list(func_list)
        N = len(func_list)
        if N == 0:
            # Without this check, equal weights would divide by zero, and
            # an empty average would have no meaningful value or prox.
            raise ValueError("Argument func_list must contain at least one functional.")
        self.has_prox = all(f.has_prox for f in func_list)
        if not self.has_prox:
            raise ValueError("All functionals in func_list must have has_prox == True.")
        self.has_eval = all(f.has_eval for f in func_list)
        self.no_inf_eval = no_inf_eval
        self.func_list = func_list
        if alpha_list is None:
            self.alpha_list = [1.0 / N] * N
        else:
            if len(alpha_list) != N:
                raise ValueError("If specified, alpha_list must have the same length as func_list")
            alpha_sum = sum(alpha_list)
            if alpha_sum == 0:
                # Normalization below would otherwise divide by zero.
                raise ValueError("The weights in alpha_list must not sum to zero.")
            if alpha_sum != 1.0:
                self.alpha_list = [alpha / alpha_sum for alpha in alpha_list]
            else:
                # Copy (rather than alias) the caller's list, matching the
                # normalized branch, which always builds a new list.
                self.alpha_list = list(alpha_list)

    def __repr__(self):
        return (
            Functional.__repr__(self)
            + "\n  Weights: "
            + ", ".join(str(alpha) for alpha in self.alpha_list)
            + "\n  Components:\n"
            + "\n".join("    " + repr(f) for f in self.func_list)
        )

    def __call__(self, x: Union[Array, BlockArray]) -> float:
        """Evaluate the weighted average of component functionals.

        Args:
            x: Point at which to evaluate the functional.

        Returns:
            Sum of the component functional values at `x`, each scaled
            by its weight. If `no_inf_eval` is ``True``, components whose
            value is infinite are omitted from the sum.

        Raises:
            ValueError: If at least one component functional cannot be
                evaluated.
        """
        if not self.has_eval:
            raise ValueError("At least one functional in func_list has has_eval == False.")
        weighted_vals = []
        for alpha, f in zip(self.alpha_list, self.func_list):
            fx = f(x)
            # Test the unweighted value: with a zero weight, alpha * inf is
            # nan rather than inf and would otherwise escape the filter.
            if self.no_inf_eval and isinf(fx):
                continue
            weighted_vals.append(alpha * fx)
        return sum(weighted_vals)

    def prox(
        self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs
    ) -> Union[Array, BlockArray]:
        r"""Approximate proximal operator of the average of functionals.

        Approximation of the proximal operator of a weighted average of
        functionals computed via the proximal average method
        :cite:`yu-2013-better`.

        Args:
            v: Input array :math:`\mb{v}`.
            lam: Proximal parameter :math:`\lam`.
            kwargs: Additional arguments that may be used by derived
                classes.

        Returns:
            Weighted sum of the component proximal operators evaluated at
            `v` with parameter `lam`.
        """
        return sum(
            alpha * f.prox(v, lam, **kwargs) for alpha, f in zip(self.alpha_list, self.func_list)
        )