# -*- coding: utf-8 -*-
# Copyright (C) 2020-2023 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.

"""Functionals that are indicator functions/constraints."""

from typing import Union

import jax

from scico import numpy as snp
from scico.numpy import Array, BlockArray
from scico.numpy.linalg import norm

from ._functional import Functional


class NonNegativeIndicator(Functional):
    r"""Indicator function for the non-negative orthant.

    Evaluates to 0 when every element of the input is non-negative,
    and to `inf` otherwise

    .. math::
        I(\mb{x}) = \begin{cases}
        0  & \text{ if } x_i \geq 0 \; \forall i \\
        \infty  & \text{ otherwise} \;.
        \end{cases}
    """

    has_eval = True
    has_prox = True

    def __call__(self, x: Union[Array, BlockArray]) -> float:
        r"""Evaluate the indicator function at `x`.

        Args:
            x: Input array :math:`\mb{x}`.

        Returns:
            0.0 if no element of `x` is negative, `inf` otherwise.

        Raises:
            ValueError: If `x` has a complex dtype.
        """
        if snp.util.is_complex_dtype(x.dtype):
            raise ValueError("Not defined for complex input.")

        any_negative = snp.any(x < 0)
        # Expressed via jax.lax.cond; equivalent to
        #   snp.inf if any_negative else 0.0
        # The branch functions ignore their (None) operand.
        return jax.lax.cond(any_negative, lambda _: snp.inf, lambda _: 0.0, None)

    def prox(
        self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs
    ) -> Union[Array, BlockArray]:
        r"""The scaled proximal operator of the non-negative indicator.

        Evaluate the scaled proximal operator of the indicator over
        the non-negative orthant, :math:`I`,

        .. math::
            [\mathrm{prox}_{\lambda I}(\mb{v})]_i = \begin{cases}
            v_i\, & \text{ if } v_i \geq 0 \\
            0\, & \text{ otherwise} \;.
            \end{cases}

        Args:
            v: Input array :math:`\mb{v}`.
            lam: Proximal parameter :math:`\lambda` (has no effect).
            kwargs: Additional arguments that may be used by derived
                classes.

        Returns:
            Copy of `v` with every negative element replaced by zero.
        """
        clipped = snp.maximum(v, 0)
        return clipped
class L2BallIndicator(Functional):
    r"""Indicator function for :math:`\ell_2` ball of given radius.

    Indicator function for :math:`\ell_2` ball of given radius, :math:`r`

    .. math::
        I(\mb{x}) = \begin{cases}
        0  & \text{ if } \norm{\mb{x}}_2 \leq r \\
        \infty  & \text{ otherwise} \;.
        \end{cases}

    Attributes:
        radius: Radius of :math:`\ell_2` ball.
    """

    has_eval = True
    has_prox = True

    def __init__(self, radius: float = 1):
        r"""Initialize a :class:`L2BallIndicator` object.

        Args:
            radius: Radius of :math:`\ell_2` ball. Default: 1.
        """
        self.radius = radius
        super().__init__()

    def __call__(self, x: Union[Array, BlockArray]) -> float:
        r"""Evaluate the indicator function at `x`.

        Args:
            x: Input array :math:`\mb{x}`.

        Returns:
            0.0 if :math:`\norm{\mb{x}}_2 \leq r`, `inf` otherwise.
        """
        # Equivalent to
        #   snp.inf if norm(x) > self.radius else 0.0
        return jax.lax.cond(norm(x) > self.radius, lambda _: snp.inf, lambda _: 0.0, None)

    def prox(
        self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs
    ) -> Union[Array, BlockArray]:
        r"""The scaled proximal operator of the :math:`\ell_2` ball indicator.

        Evaluate the scaled proximal operator of the indicator, :math:`I`,
        of the :math:`\ell_2` ball with radius :math:`r`, i.e. the
        Euclidean projection onto that ball

        .. math::
            \mathrm{prox}_{\lambda I}(\mb{v}) = \begin{cases}
            \mb{v} & \text{ if } \norm{\mb{v}}_2 \leq r \\
            r \frac{\mb{v}}{\norm{\mb{v}}_2} & \text{ otherwise} \;.
            \end{cases}

        Args:
            v: Input array :math:`\mb{v}`.
            lam: Proximal parameter :math:`\lambda` (has no effect).
            kwargs: Additional arguments that may be used by derived
                classes.

        Returns:
            `v` itself when it lies inside the ball, otherwise `v`
            rescaled to have :math:`\ell_2` norm equal to the radius.
        """
        v_norm = norm(v)
        outside = v_norm > self.radius
        # Points inside the ball (including v = 0) must be left unchanged.
        # The denominator is replaced by 1 for those points so that the
        # unselected branch never evaluates a division by zero.
        safe_norm = snp.where(outside, v_norm, 1.0)
        scale = snp.where(outside, self.radius / safe_norm, 1.0)
        # v is the left operand so that its own multiplication operator
        # handles the scalar factor.
        return v * scale