# -*- coding: utf-8 -*-
# Copyright (C) 2020-2023 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.
"""Functional base class."""
# Needed to annotate a class method that returns the encapsulating class;
# see https://www.python.org/dev/peps/pep-0563/
from __future__ import annotations
from typing import List, Optional, Union
import scico
from scico import numpy as snp
from scico.numpy import Array, BlockArray
class Functional:
    r"""Base class for functionals.

    A functional maps an :code:`array-like` to a scalar; abstractly, a
    functional is a mapping from :math:`\mathbb{R}^n` or
    :math:`\mathbb{C}^n` to :math:`\mathbb{R}`.

    Instances support multiplication by a scalar (giving a
    :class:`ScaledFunctional`) and addition with another functional
    (giving a :class:`FunctionalSum`).
    """

    #: True if this functional can be evaluated, False otherwise.
    #: This attribute must be overridden and set to True or False in any derived classes.
    has_eval: Optional[bool] = None

    #: True if this functional has the prox method, False otherwise.
    #: This attribute must be overridden and set to True or False in any derived classes.
    has_prox: Optional[bool] = None

    def __init__(self):
        # Gradient function built once from the evaluation method; it is
        # the function called by :meth:`grad`.
        self._grad = scico.grad(self.__call__)

    def __repr__(self):
        return f"{type(self)} (has_eval = {self.has_eval}, has_prox = {self.has_prox})"

    def __mul__(self, other: Union[float, int]) -> ScaledFunctional:
        # Only scalar-equivalent operands are supported; anything else is
        # deferred to Python's binary operator fallback machinery.
        if not snp.util.is_scalar_equiv(other):
            return NotImplemented
        return ScaledFunctional(self, other)

    def __rmul__(self, other: Union[float, int]) -> ScaledFunctional:
        # Scalar multiplication is handled identically on either side.
        return self.__mul__(other)

    def __add__(self, other: Functional) -> FunctionalSum:
        # Only functional + functional is supported.
        if not isinstance(other, Functional):
            return NotImplemented
        return FunctionalSum(self, other)

    def __call__(self, x: Union[Array, BlockArray]) -> float:
        r"""Evaluate this functional at point :math:`\mb{x}`.

        Args:
            x: Point at which to evaluate this functional.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        # Functionals that can be evaluated should override this method.
        raise NotImplementedError(f"Functional {type(self)} cannot be evaluated.")

    def prox(
        self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs
    ) -> Union[Array, BlockArray]:
        r"""Scaled proximal operator of functional.

        Evaluate scaled proximal operator of this functional, with
        scaling :math:`\lambda` = `lam` and evaluated at point
        :math:`\mb{v}` = `v`. The scaled proximal operator is defined as

        .. math::
           \prox_{\lambda f}(\mb{v}) = \argmin_{\mb{x}}
           \lambda f(\mb{x}) +
           \frac{1}{2} \norm{\mb{v} - \mb{x}}_2^2\;,

        where :math:`\lambda f(\mb{x})` represents this functional evaluated at
        :math:`\mb{x}` multiplied by :math:`\lambda`.

        Args:
            v: Point at which to evaluate prox function.
            lam: Proximal parameter :math:`\lambda`.
            kwargs: Additional arguments that may be used by derived
                classes. These include `x0`, an initial guess for the
                minimizer in the definition of :math:`\prox`.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        # Functionals that have a prox should override this method.
        raise NotImplementedError(f"Functional {type(self)} does not have a prox.")

    def conj_prox(
        self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs
    ) -> Union[Array, BlockArray]:
        r"""Scaled proximal operator of convex conjugate of functional.

        Evaluate scaled proximal operator of convex conjugate (Fenchel
        conjugate) of this functional, with scaling
        :math:`\lambda` = `lam`, and evaluated at point
        :math:`\mb{v}` = `v`. Denoting this functional by :math:`f` and
        its convex conjugate by :math:`f^*`, the proximal operator of
        :math:`f^*` is computed as follows by exploiting the extended
        Moreau decomposition (see Sec. 6.6 of :cite:`beck-2017-first`)

        .. math::
           \prox_{\lambda f^*}(\mb{v}) = \mb{v} - \lambda \,
           \prox_{\lambda^{-1} f}(\mb{v / \lambda}) \;.

        Args:
            v: Point at which to evaluate prox function.
            lam: Proximal parameter :math:`\lambda`.
            kwargs: Additional keyword args, passed directly to
                `self.prox`.

        Returns:
            The value :code:`v - lam * self.prox(v / lam, 1.0 / lam)`.
        """
        rescaled_point = v / lam
        inverse_lam = 1.0 / lam
        prox_of_f = self.prox(rescaled_point, inverse_lam, **kwargs)
        return v - lam * prox_of_f

    def grad(self, x: Union[Array, BlockArray]):
        r"""Evaluates the gradient of this functional at :math:`\mb{x}`.

        The gradient function is the one constructed in the initializer
        by applying :func:`scico.grad` to the evaluation method.

        Args:
            x: Point at which to evaluate gradient.
        """
        gradient_fn = self._grad
        return gradient_fn(x)
class ScaledFunctional(Functional):
    r"""A functional multiplied by a scalar."""

    def __init__(self, functional: Functional, scale: float):
        r"""
        Args:
            functional: The functional being scaled.
            scale: Scalar by which `functional` is multiplied.
        """
        # Evaluation and prox capabilities mirror those of the wrapped
        # functional.
        self.has_eval = functional.has_eval
        self.has_prox = functional.has_prox
        self.functional = functional
        self.scale = scale
        super().__init__()

    def __repr__(self):
        wrapped_type = str(type(self.functional))
        return "Scaled functional of type " + wrapped_type + f" (scale = {self.scale})"

    def __call__(self, x: Union[Array, BlockArray]) -> float:
        r"""Evaluate the scaled functional at point :math:`\mb{x}`.

        Args:
            x: Point at which to evaluate this functional.

        Returns:
            The wrapped functional evaluated at `x`, multiplied by the
            scale factor.
        """
        unscaled_value = self.functional(x)
        return self.scale * unscaled_value

    def __mul__(self, other: Union[float, int]) -> ScaledFunctional:
        # Fold the additional factor into the scale instead of nesting
        # one ScaledFunctional inside another.
        if not snp.util.is_scalar_equiv(other):
            return NotImplemented
        return ScaledFunctional(self.functional, other * self.scale)

    def prox(
        self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs
    ) -> Union[Array, BlockArray]:
        r"""Evaluate the scaled proximal operator of the scaled functional.

        Note that, by definition, the scaled proximal operator of a
        functional is the proximal operator of the scaled functional. The
        scaled proximal operator of a scaled functional is the scaled
        proximal operator of the unscaled functional with the proximal
        operator scaling consisting of the product of the two scaling
        factors, i.e., for functional :math:`f` and scaling factors
        :math:`\alpha` and :math:`\beta`, the proximal operator with
        scaling parameter :math:`\alpha` of scaled functional
        :math:`\beta f` is the proximal operator with scaling parameter
        :math:`\alpha \beta` of functional :math:`f`,

        .. math::
           \prox_{\alpha (\beta f)}(\mb{v}) =
           \prox_{(\alpha \beta) f}(\mb{v}) \;.

        Args:
            v: Point at which to evaluate prox function.
            lam: Proximal parameter :math:`\lambda`.
            kwargs: Additional keyword args, passed directly to the
                prox of the wrapped functional.
        """
        combined_scaling = lam * self.scale
        return self.functional.prox(v, combined_scaling, **kwargs)
class SeparableFunctional(Functional):
    r"""A functional that is separable in its arguments.

    A separable functional :math:`f : \mathbb{C}^N \to \mathbb{R}` can
    be written as the sum of functionals :math:`f_i : \mathbb{C}^{N_i}
    \to \mathbb{R}` with :math:`\sum_i N_i = N`. In particular,

    .. math::
       f(\mb{x}) = f(\mb{x}_1, \dots, \mb{x}_N) = f_1(\mb{x}_1) + \dots
       + f_N(\mb{x}_N) \;.

    A :class:`SeparableFunctional` accepts a :class:`.BlockArray` and is
    separable in the block components.
    """

    def __init__(self, functional_list: List[Functional]):
        r"""
        Args:
            functional_list: List of component functionals f_i. This
                functional takes as an input a :class:`.BlockArray` with
                `num_blocks == len(functional_list)`.
        """
        self.functional_list: List[Functional] = functional_list
        # The separable functional can be evaluated / has a prox only if
        # every one of its components does.
        self.has_eval: bool = all(component.has_eval for component in functional_list)
        self.has_prox: bool = all(component.has_prox for component in functional_list)
        super().__init__()

    def __call__(self, x: BlockArray) -> float:
        r"""Evaluate the separable functional at :math:`\mb{x}`.

        Each component functional is evaluated on the corresponding
        block of `x` and the results are summed.

        Args:
            x: Point at which to evaluate this functional.

        Raises:
            ValueError: If the number of blocks in `x` differs from the
                number of component functionals.
        """
        if len(x.shape) != len(self.functional_list):
            raise ValueError(
                f"Number of blocks in x, {len(x.shape)}, and length of functional_list, "
                f"{len(self.functional_list)}, do not match."
            )
        block_values = [fi(xi) for fi, xi in zip(self.functional_list, x)]
        return snp.sum(snp.array(block_values))

    def prox(self, v: BlockArray, lam: float = 1.0, **kwargs) -> BlockArray:
        r"""Evaluate proximal operator of the separable functional.

        Evaluate proximal operator of the separable functional (see
        Theorem 6.6 of :cite:`beck-2017-first`).

        .. math::
           \prox_{\lambda f}(\mb{v})
           =
           \begin{bmatrix}
             \prox_{\lambda f_1}(\mb{v}_1) \\ \vdots \\
             \prox_{\lambda f_N}(\mb{v}_N) \\
           \end{bmatrix} \;.

        Args:
            v: Input array :math:`\mb{v}`.
            lam: Proximal parameter :math:`\lambda`.
            kwargs: Additional arguments that may be used by derived
                classes.

        Raises:
            ValueError: If the number of blocks in `v` differs from the
                number of component functionals.
        """
        if len(v.shape) != len(self.functional_list):
            raise ValueError(
                f"Number of blocks in v, {len(v.shape)}, and length of functional_list, "
                f"{len(self.functional_list)}, do not match."
            )
        # The same `lam` and keyword arguments are passed to every
        # component prox.
        block_proxes = [fi.prox(vi, lam, **kwargs) for fi, vi in zip(self.functional_list, v)]
        return snp.blockarray(block_proxes)
class FunctionalSum(Functional):
    r"""A sum of two functionals."""

    def __init__(self, functional1: Functional, functional2: Functional):
        r"""
        Args:
            functional1: First functional in the sum.
            functional2: Second functional in the sum.
        """
        self.functional1 = functional1
        self.functional2 = functional2
        # The sum can be evaluated only if both terms can be; this class
        # defines no prox of its own, so has_prox is always False.
        self.has_prox = False
        self.has_eval = functional1.has_eval and functional2.has_eval
        super().__init__()

    def __repr__(self):
        first_type = str(type(self.functional1))
        second_type = str(type(self.functional2))
        return "Sum of functionals of types " + first_type + " and " + second_type

    def __call__(self, x: Union[Array, BlockArray]) -> float:
        r"""Evaluate the sum of the two functionals at :math:`\mb{x}`.

        Args:
            x: Point at which to evaluate this functional.

        Returns:
            Sum of the values of the two component functionals at `x`.
        """
        first_value = self.functional1(x)
        second_value = self.functional2(x)
        return first_value + second_value
class ZeroFunctional(Functional):
    r"""Zero functional, :math:`f(\mb{x}) = 0 \in \mbb{R}` for any input."""

    has_eval = True
    has_prox = True

    def __call__(self, x: Union[Array, BlockArray]) -> float:
        r"""Evaluate the zero functional.

        Args:
            x: Point at which to evaluate this functional; its value is
                not used.

        Returns:
            The constant :code:`0.0`.
        """
        return 0.0

    def prox(
        self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs
    ) -> Union[Array, BlockArray]:
        r"""Evaluate the scaled proximal operator of the zero functional.

        Args:
            v: Point at which to evaluate prox function.
            lam: Proximal parameter :math:`\lambda`; not used.
            kwargs: Additional keyword args; not used.

        Returns:
            The input `v`, unchanged.
        """
        return v