# Source code for scico.functional._indicator

# -*- coding: utf-8 -*-
# Copyright (C) 2020-2023 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.

"""Functionals that are indicator functions/constraints."""

from typing import Union

import jax

from scico import numpy as snp
from scico.numpy import Array, BlockArray
from scico.numpy.linalg import norm

from ._functional import Functional


class NonNegativeIndicator(Functional):
    r"""Indicator function of the non-negative orthant.

    Evaluates to 0 when every element of the input array-like is
    non-negative, and to `inf` as soon as at least one element is
    negative,

    .. math::
        I(\mb{x}) = \begin{cases}
        0  & \text{ if } x_i \geq 0 \; \forall i \\
        \infty  & \text{ otherwise} \;.
        \end{cases}
    """

    has_eval = True
    has_prox = True

    def __call__(self, x: Union[Array, BlockArray]) -> float:
        r"""Evaluate the non-negative orthant indicator at `x`.

        Args:
            x: Input array :math:`\mb{x}`.

        Returns:
            `inf` if any element of `x` is negative, and 0.0 otherwise.

        Raises:
            ValueError: If `x` has a complex dtype, for which the test
                `x < 0` has no meaning.
        """
        if snp.util.is_complex_dtype(x.dtype):
            raise ValueError("Not defined for complex input.")
        any_negative = snp.any(x < 0)
        # Same value as `snp.inf if any_negative else 0.0`, but written with
        # jax.lax.cond, whose predicate may be a traced (non-concrete) value.
        return jax.lax.cond(any_negative, lambda _: snp.inf, lambda _: 0.0, None)

    def prox(
        self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs
    ) -> Union[Array, BlockArray]:
        r"""Scaled proximal operator of the non-negative indicator.

        For the indicator :math:`I` of the non-negative orthant, the
        scaled proximal operator is the projection onto that orthant.
        It is obtained by clipping negative entries to zero,

        .. math::
            [\mathrm{prox}_{\lambda I}(\mb{v})]_i =
            \begin{cases}
            v_i\, & \text{ if } v_i \geq 0 \\
            0\, & \text{ otherwise} \;.
            \end{cases}

        Args:
            v: Input array :math:`\mb{v}`.
            lam: Proximal parameter :math:`\lambda`. It is ignored, since
                :math:`\lambda I = I` for any :math:`\lambda > 0`.
            kwargs: Additional arguments that may be used by derived
                classes.

        Returns:
            Elementwise maximum of `v` and 0.
        """
        clipped = snp.maximum(v, 0)
        return clipped
class L2BallIndicator(Functional):
    r"""Indicator function for :math:`\ell_2` ball of given radius.

    Indicator function for the :math:`\ell_2` ball of given radius
    :math:`r`,

    .. math::
        I(\mb{x}) = \begin{cases}
        0  & \text{ if } \norm{\mb{x}}_2 \leq r \\
        \infty  & \text{ otherwise} \;.
        \end{cases}

    Attributes:
        radius: Radius of :math:`\ell_2` ball.
    """

    has_eval = True
    has_prox = True

    def __init__(self, radius: float = 1):
        r"""Initialize a :class:`L2BallIndicator` object.

        Args:
            radius: Radius of :math:`\ell_2` ball. Default: 1.
        """
        self.radius = radius
        super().__init__()

    def __call__(self, x: Union[Array, BlockArray]) -> float:
        r"""Evaluate the :math:`\ell_2` ball indicator at `x`.

        Args:
            x: Input array :math:`\mb{x}`.

        Returns:
            `inf` if :math:`\norm{\mb{x}}_2 > r`, and 0.0 otherwise.
        """
        # Equivalent to
        #   snp.inf if norm(x) > self.radius else 0.0
        # but jax.lax.cond also accepts a traced (non-concrete) predicate.
        return jax.lax.cond(norm(x) > self.radius, lambda _: snp.inf, lambda _: 0.0, None)

    def prox(
        self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs
    ) -> Union[Array, BlockArray]:
        r"""The scaled proximal operator of the :math:`\ell_2` ball indicator.

        Evaluate the scaled proximal operator of the indicator, :math:`I`,
        of the :math:`\ell_2` ball with radius :math:`r`. This is the
        Euclidean projection onto the ball: points inside the ball are
        returned unchanged, and points outside are scaled radially onto
        its boundary,

        .. math::
            \mathrm{prox}_{\lambda I}(\mb{v}) = \begin{cases}
            \mb{v} & \text{ if } \norm{\mb{v}}_2 \leq r \\
            r \frac{\mb{v}}{\norm{\mb{v}}_2} & \text{ otherwise} \;.
            \end{cases}

        Args:
            v: Input array :math:`\mb{v}`.
            lam: Proximal parameter :math:`\lambda` (has no effect).
            kwargs: Additional arguments that may be used by derived
                classes.

        Returns:
            Projection of `v` onto the :math:`\ell_2` ball of radius
            `self.radius`. NOTE(review): assumes a non-negative radius.
        """
        v_norm = norm(v)
        outside = v_norm > self.radius
        # Scale factor min(1, r / ||v||_2). The inner `where` replaces the
        # denominator by 1 whenever v lies inside the ball (including v = 0).
        # That avoids 0/0 in the unselected branch, which would otherwise
        # produce NaN gradients.
        safe_norm = snp.where(outside, v_norm, 1.0)
        scale = snp.where(outside, self.radius / safe_norm, 1.0)
        # Keep v on the left so that BlockArray arithmetic handles the product.
        return v * scale