Source code for scico.functional._indicator

# -*- coding: utf-8 -*-
# Copyright (C) 2020-2026 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.

"""Functionals that are indicator functions/constraints."""

from typing import Union

import jax

from scico import numpy as snp
from scico.numpy import Array, BlockArray
from scico.numpy.linalg import norm

from ._functional import Functional


class NonNegativeIndicator(Functional):
    r"""Indicator function for the non-negative orthant.

    Evaluates to 0 when every element of the input array-like is
    non-negative, and to `inf` otherwise,

    .. math::
        I(\mb{x}) = \begin{cases}
        0  & \text{ if } x_i \geq 0 \; \forall i \\
        \infty  & \text{ otherwise} \;.
        \end{cases}
    """

    has_eval = True
    has_prox = True

    def __call__(self, x: Union[Array, BlockArray]) -> float:
        r"""Evaluate the non-negative orthant indicator at `x`.

        Args:
            x: Input array :math:`\mb{x}`.

        Returns:
            0.0 if no element of `x` compares less than zero, otherwise
            `inf`.

        Raises:
            ValueError: If `x` has a complex dtype.
        """
        if snp.util.is_complex_dtype(x.dtype):
            raise ValueError("Not defined for complex input.")
        any_negative = snp.any(x < 0)
        # Same result as `snp.inf if any_negative else 0.0`; presumably
        # expressed via lax.cond so the method stays traceable under jit.
        return jax.lax.cond(any_negative, lambda _: snp.inf, lambda _: 0.0, None)

    def prox(
        self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs
    ) -> Union[Array, BlockArray]:
        r"""The scaled proximal operator of the non-negative indicator.

        Evaluate the scaled proximal operator of the indicator over the
        non-negative orthant, :math:`I`, which is the element-wise
        projection

        .. math::
            [\mathrm{prox}_{\lambda I}(\mb{v})]_i =
            \begin{cases}
            v_i\, & \text{ if } v_i \geq 0 \\
            0\, & \text{ otherwise} \;.
            \end{cases}

        Args:
            v: Input array :math:`\mb{v}`.
            lam: Proximal parameter :math:`\lambda` (has no effect).
            **kwargs: Additional arguments that may be used by derived
                classes.

        Returns:
            Result of evaluating the scaled proximal operator at `v`.
        """
        # The prox of an indicator is a projection, so `lam` is unused;
        # negative entries are clamped to zero.
        projected = snp.maximum(v, 0)
        return projected
class L2BallIndicator(Functional):
    r"""Indicator function for :math:`\ell_2` ball of given radius.

    Indicator function for :math:`\ell_2` ball of given radius, :math:`r`

    .. math::
        I(\mb{x}) = \begin{cases}
        0 & \text{ if } \norm{\mb{x}}_2 \leq r \\
        \infty & \text{ otherwise} \;.
        \end{cases}

    Attributes:
        radius: Radius of :math:`\ell_2` ball.
    """

    has_eval = True
    has_prox = True

    def __init__(self, radius: float = 1.0):
        r"""Initialize a :class:`L2BallIndicator` object.

        Args:
            radius: Radius of :math:`\ell_2` ball. Default: 1.0.
        """
        self.radius = radius
        super().__init__()

    def __call__(self, x: Union[Array, BlockArray]) -> float:
        r"""Evaluate the :math:`\ell_2` ball indicator at `x`.

        Args:
            x: Input array :math:`\mb{x}`.

        Returns:
            `inf` if :math:`\norm{\mb{x}}_2 > r`, otherwise 0.0.
        """
        # Equivalent to: snp.inf if norm(x) > self.radius else 0.0
        return jax.lax.cond(
            norm(x) > self.radius, lambda x: snp.inf, lambda x: 0.0, None
        )

    def prox(
        self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs
    ) -> Union[Array, BlockArray]:
        r"""The scaled proximal operator of the :math:`\ell_2` ball indicator.

        Evaluate the scaled proximal operator of the indicator, :math:`I`,
        of the :math:`\ell_2` ball with radius :math:`r`

        .. math::
            \mathrm{prox}_{\lambda I}(\mb{v}) = \begin{cases}
            \mb{v} & \text{ if } \norm{\mb{v}}_2 \leq r \\
            r \frac{\mb{v}}{\norm{\mb{v}}_2} & \text{ otherwise} \;.
            \end{cases}

        Args:
            v: Input array :math:`\mb{v}`.
            lam: Proximal parameter :math:`\lambda` (has no effect).
            **kwargs: Additional arguments that may be used by derived
                classes.

        Returns:
            Result of evaluating the scaled proximal operator at `v`.
        """
        # Compute the norm once and reuse it in the projection branch,
        # rather than recomputing it inside the branch.
        v_norm = norm(v)
        # NOTE(review): lax.cond requires both branches to return identical
        # types; for integer-dtype `v` the scaled branch is floating point,
        # so integer input likely raises a TypeError — confirm if needed.
        return jax.lax.cond(
            v_norm > self.radius,
            lambda v: self.radius * v / v_norm,
            lambda v: v,
            v,
        )
class BoxIndicator(Functional):
    r"""Box indicator function.

    Indicator function of the constraint set :math:`a \leq x \leq b` for
    lower and upper bounds :math:`a` and :math:`b` respectively,

    .. math::
        I(\mb{x}) = \begin{cases}
        0 & \text{ if } a \leq x_i \leq b \; \forall i \\
        \infty & \text{ otherwise} \;.
        \end{cases}

    Attributes:
        lb: Lower bound :math:`a`.
        ub: Upper bound :math:`b`.
    """

    has_eval = True
    has_prox = True

    def __init__(self, lb: float = 0.0, ub: float = 1.0):
        r"""Initialize a :class:`BoxIndicator` object.

        Args:
            lb: Lower bound. Default: 0.0.
            ub: Upper bound. Default: 1.0.
        """
        self.lb = lb
        self.ub = ub
        # Run Functional's initialization so this subclass is set up the
        # same way as the other indicator functionals.
        super().__init__()

    def __call__(self, x: Union[Array, BlockArray]) -> float:
        r"""Evaluate the box indicator at `x`.

        Args:
            x: Input array :math:`\mb{x}`.

        Returns:
            0.0 if every element of `x` satisfies `lb <= x_i <= ub`,
            otherwise `inf`.

        Raises:
            ValueError: If `x` has a complex dtype.
        """
        if snp.util.is_complex_dtype(x.dtype):
            raise ValueError("Not defined for complex input.")
        constr = snp.logical_and(self.lb <= x, x <= self.ub)
        # Equivalent to: 0.0 if snp.all(constr) else snp.inf
        return jax.lax.cond(snp.all(constr), lambda x: 0.0, lambda x: snp.inf, None)

    def prox(
        self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs
    ) -> Union[Array, BlockArray]:
        r"""The scaled proximal operator of the box indicator.

        Evaluate the scaled proximal operator of the constraint set
        :math:`a \leq x \leq b` for lower and upper bounds :math:`a` and
        :math:`b` respectively, i.e. clip each element of `v` to the
        interval :math:`[a, b]`.

        Args:
            v: Input array :math:`\mb{v}`.
            lam: Proximal parameter :math:`\lambda` (has no effect).
            **kwargs: Additional arguments that may be used by derived
                classes.

        Returns:
            Result of evaluating the scaled proximal operator at `v`.
        """
        return snp.clip(v, self.lb, self.ub)