# -*- coding: utf-8 -*-
# Copyright (C) 2020-2023 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.
"""Functionals that are norms."""
from typing import Optional, Tuple, Union
from jax import jit, lax
from scico import numpy as snp
from scico.numpy import Array, BlockArray, count_nonzero
from scico.numpy.linalg import norm
from scico.numpy.util import no_nan_divide
from ._functional import Functional
class L0Norm(Functional):
r"""The :math:`\ell_0` 'norm'.
The :math:`\ell_0` 'norm' counts the number of non-zero elements in
an array.
"""
has_eval = True
has_prox = True
    def __call__(self, x: Union[Array, BlockArray]) -> float:
return count_nonzero(x)
    @staticmethod
@jit
def prox(v: Union[Array, BlockArray], lam: float = 1.0, **kwargs) -> Union[Array, BlockArray]:
r"""Evaluate scaled proximal operator of :math:`\ell_0` norm.
Evaluate scaled proximal operator of :math:`\ell_0` norm using
.. math::
\left[ \prox_{\lambda\| \cdot \|_0}(\mb{v}) \right]_i =
\begin{cases}
v_i & \text{ if } \abs{v_i} \geq \lambda \\
0 & \text{ otherwise } \;.
\end{cases}
Args:
v: Input array :math:`\mb{v}`.
lam: Thresholding parameter :math:`\lambda`.
kwargs: Additional arguments that may be used by derived
classes.
"""
return snp.where(snp.abs(v) >= lam, v, 0)
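
# Example usage of L0Norm (a minimal sketch, not executed at import;
# values chosen purely for illustration):
#
#     import scico.numpy as snp
#     from scico.functional import L0Norm
#
#     f = L0Norm()
#     v = snp.array([0.5, -2.0, 1.5])
#     f(v)                # 3 non-zero entries
#     f.prox(v, lam=1.0)  # hard thresholding: [0., -2., 1.5]
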
class L1Norm(Functional):
r"""The :math:`\ell_1` norm.
Computes
.. math::
        \norm{\mb{x}}_1 = \sum_i \abs{x_i} \;.
"""
has_eval = True
has_prox = True
    def __call__(self, x: Union[Array, BlockArray]) -> float:
return snp.sum(snp.abs(x))
    @staticmethod
    def prox(
        v: Union[Array, BlockArray], lam: float = 1.0, **kwargs
    ) -> Union[Array, BlockArray]:
r"""Evaluate scaled proximal operator of :math:`\ell_1` norm.
Evaluate scaled proximal operator of :math:`\ell_1` norm using
.. math::
\left[ \prox_{\lambda \|\cdot\|_1}(\mb{v}) \right]_i =
\sign(v_i) (\abs{v_i} - \lambda)_+ \;,
where
.. math::
(x)_+ = \begin{cases}
x & \text{ if } x \geq 0 \\
0 & \text{ otherwise} \;.
\end{cases}
Args:
v: Input array :math:`\mb{v}`.
lam: Thresholding parameter :math:`\lambda`.
kwargs: Additional arguments that may be used by derived
classes.
"""
tmp = snp.abs(v) - lam
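        # 0.5 * (x + |x|) evaluates the positive part (x)_+ branch-free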
tmp = 0.5 * (tmp + snp.abs(tmp))
if snp.util.is_complex_dtype(v.dtype):
out = snp.exp(1j * snp.angle(v)) * tmp
else:
out = snp.sign(v) * tmp
return out
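
# Example usage of L1Norm (a minimal sketch; values chosen for
# illustration):
#
#     import scico.numpy as snp
#     from scico.functional import L1Norm
#
#     f = L1Norm()
#     v = snp.array([0.5, -2.0, 1.5])
#     f(v)                # 4.0 = 0.5 + 2.0 + 1.5
#     f.prox(v, lam=1.0)  # soft thresholding: [0., -1., 0.5]
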
class SquaredL2Norm(Functional):
r"""The squared :math:`\ell_2` norm.
Squared :math:`\ell_2` norm
.. math::
\norm{\mb{x}}^2_2 = \sum_i \abs{x_i}^2 \;.
"""
has_eval = True
has_prox = True
    def __call__(self, x: Union[Array, BlockArray]) -> float:
# Directly implement the squared l2 norm to avoid nondifferentiable
# behavior of snp.norm(x) at 0.
return snp.sum(snp.abs(x) ** 2)
    def prox(
self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs
) -> Union[Array, BlockArray]:
r"""Evaluate proximal operator of squared :math:`\ell_2` norm.
Evaluate proximal operator of squared :math:`\ell_2` norm using
.. math::
\prox_{\lambda \| \cdot \|_2^2}(\mb{v})
= \frac{\mb{v}}{1 + 2 \lambda} \;.
Args:
v: Input array :math:`\mb{v}`.
lam: Proximal parameter :math:`\lambda`.
kwargs: Additional arguments that may be used by derived
classes.
"""
return v / (1.0 + 2.0 * lam)
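
# Example usage of SquaredL2Norm (a minimal sketch; values chosen for
# illustration):
#
#     import scico.numpy as snp
#     from scico.functional import SquaredL2Norm
#
#     f = SquaredL2Norm()
#     v = snp.array([2.0, -4.0])
#     f(v)                # 20.0 = 4.0 + 16.0
#     f.prox(v, lam=0.5)  # v / (1 + 2 * 0.5) = [1., -2.]
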
class L2Norm(Functional):
r"""The :math:`\ell_2` norm.
.. math::
\norm{\mb{x}}_2 = \sqrt{\sum_i \abs{x_i}^2} \;.
"""
has_eval = True
has_prox = True
    def __call__(self, x: Union[Array, BlockArray]) -> float:
return norm(x)
    def prox(
self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs
) -> Union[Array, BlockArray]:
r"""Evaluate proximal operator of :math:`\ell_2` norm.
Evaluate proximal operator of :math:`\ell_2` norm using
.. math::
\prox_{\lambda \| \cdot \|_2}(\mb{v}) = \mb{v} \,
\left(1 - \frac{\lambda}{\norm{\mb{v}}_2} \right)_+ \;,
where
.. math::
(x)_+ = \begin{cases}
x & \text{ if } x \geq 0 \\
0 & \text{ otherwise} \;.
\end{cases}
Args:
v: Input array :math:`\mb{v}`.
lam: Proximal parameter :math:`\lambda`.
kwargs: Additional arguments that may be used by derived
classes.
"""
norm_v = norm(v)
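        # Guard against division by zero: the prox at v = 0 is 0.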
return snp.where(norm_v == 0, 0 * v, snp.maximum(1 - lam / norm_v, 0) * v)
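
# Example usage of L2Norm (a minimal sketch; values chosen for
# illustration):
#
#     import scico.numpy as snp
#     from scico.functional import L2Norm
#
#     f = L2Norm()
#     v = snp.array([3.0, 4.0])
#     f(v)                # 5.0
#     f.prox(v, lam=1.0)  # (1 - 1/5) * v = [2.4, 3.2]
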
class L21Norm(Functional):
r"""The :math:`\ell_{2,1}` norm.
    For an :math:`M \times N` matrix :math:`\mb{A}`, by default,
.. math::
\norm{\mb{A}}_{2,1} = \sum_{n=1}^N \sqrt{\sum_{m=1}^M
\abs{A_{m,n}}^2} \;.
The norm generalizes to more dimensions by first computing the
:math:`\ell_2` norm along one or more (user-specified) axes,
followed by a sum over all remaining axes.
For `BlockArray` inputs, the :math:`\ell_2` norm follows the
reduction rules described in :class:`BlockArray`.
A typical use case is computing the isotropic total variation norm.
"""
has_eval = True
has_prox = True
def __init__(self, l2_axis: Union[int, Tuple] = 0):
r"""
Args:
l2_axis: Axis/axes over which to take the l2 norm. Default: 0.
"""
self.l2_axis = l2_axis
@staticmethod
def _l2norm(
x: Union[Array, BlockArray], axis: Union[int, Tuple], keepdims: Optional[bool] = False
):
r"""Return the :math:`\ell_2` norm of an array."""
return snp.sqrt(snp.sum(snp.abs(x) ** 2, axis=axis, keepdims=keepdims))
    def __call__(self, x: Union[Array, BlockArray]) -> float:
l2 = L21Norm._l2norm(x, axis=self.l2_axis)
return snp.abs(l2).sum()
    def prox(
self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs
) -> Union[Array, BlockArray]:
r"""Evaluate proximal operator of the :math:`\ell_{2,1}` norm.
In two dimensions,
.. math::
            \prox_{\lambda \|\cdot\|_{2,1}}(\mb{v})_{:, n} =
            \frac{\mb{v}_{:, n}}{\|\mb{v}_{:, n}\|_2}
            (\|\mb{v}_{:, n}\|_2 - \lambda)_+ \;,
where
.. math::
(x)_+ = \begin{cases}
x & \text{ if } x \geq 0 \\
0 & \text{ otherwise} \;.
\end{cases}
Args:
v: Input array :math:`\mb{v}`.
lam: Proximal parameter :math:`\lambda`.
kwargs: Additional arguments that may be used by derived
classes.
"""
length = L21Norm._l2norm(v, axis=self.l2_axis, keepdims=True)
direction = no_nan_divide(v, length)
new_length = length - lam
# set negative values to zero without `if`
new_length = 0.5 * (new_length + snp.abs(new_length))
return new_length * direction
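
# Example usage of L21Norm (a minimal sketch; with the default
# l2_axis=0 the l2 norm is taken down each column):
#
#     import scico.numpy as snp
#     from scico.functional import L21Norm
#
#     f = L21Norm()
#     v = snp.array([[3.0, 0.0], [4.0, 0.5]])
#     f(v)                # 5.5 = ||(3, 4)||_2 + ||(0, 0.5)||_2
#     f.prox(v, lam=1.0)  # column norms soft thresholded:
#                         # [[2.4, 0.], [3.2, 0.]]
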
class L1MinusL2Norm(Functional):
r"""Difference of :math:`\ell_1` and :math:`\ell_2` norms.
Difference of :math:`\ell_1` and :math:`\ell_2` norms
.. math::
        \norm{\mb{x}}_1 - \beta \norm{\mb{x}}_2 \;.
"""
has_eval = True
has_prox = True
def __init__(self, beta: float = 1.0):
r"""
Args:
beta: Parameter :math:`\beta` in the norm definition.
"""
self.beta = beta
    def __call__(self, x: Union[Array, BlockArray]) -> float:
return snp.sum(snp.abs(x)) - self.beta * norm(x)
@staticmethod
def _prox_vamx_ge_thresh(v, va, vs, alpha, beta):
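        # Case (1 - beta) * alpha <= max|v| <= alpha: the solution has a
        # single non-zero entry, at the index of the largest-magnitude
        # component of v (see :cite:`lou-2018-fast`).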
u = snp.zeros(v.shape, dtype=v.dtype)
idx = va.ravel().argmax()
u = (
u.ravel().at[idx].set((va.ravel()[idx] + (beta - 1.0) * alpha) * vs.ravel()[idx])
).reshape(v.shape)
return u
@staticmethod
def _prox_vamx_le_alpha(v, va, vs, vamx, alpha, beta):
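        # Case max|v| <= alpha: the solution is zero when max|v| is below
        # (1 - beta) * alpha, and has a single non-zero entry otherwise.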
return snp.where(
vamx < (1.0 - beta) * alpha,
snp.zeros(v.shape, dtype=v.dtype),
L1MinusL2Norm._prox_vamx_ge_thresh(v, va, vs, alpha, beta),
)
@staticmethod
def _prox_vamx_gt_alpha(v, va, vs, alpha, beta):
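        # Case max|v| > alpha: soft threshold v, then rescale to account
        # for the subtracted beta * ||.||_2 term.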
u = snp.maximum(va - alpha, 0.0) * vs
l2u = norm(u)
u *= (l2u + alpha * beta) / l2u
return u
@staticmethod
def _prox_vamx_gt_0(v, va, vs, vamx, alpha, beta):
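        # Dispatch on the value of max|v| relative to alpha. Note that
        # snp.where evaluates both branches; the unselected one is
        # discarded elementwise.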
return snp.where(
vamx > alpha,
L1MinusL2Norm._prox_vamx_gt_alpha(v, va, vs, alpha, beta),
L1MinusL2Norm._prox_vamx_le_alpha(v, va, vs, vamx, alpha, beta),
)
    def prox(
self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs
) -> Union[Array, BlockArray]:
r"""Proximal operator of difference of :math:`\ell_1` and :math:`\ell_2` norms.
Evaluate the proximal operator of the difference of :math:`\ell_1`
and :math:`\ell_2` norms, i.e. :math:`\alpha \left( \| \mb{x}
\|_1 - \beta \| \mb{x} \|_2 \right)` :cite:`lou-2018-fast`. Note
that this is not a proximal operator according to the strict
definition since the loss function is non-convex.
Args:
v: Input array :math:`\mb{v}`.
lam: Proximal parameter :math:`\lambda`.
kwargs: Additional arguments that may be used by derived
classes.
"""
alpha = lam
beta = self.beta
va = snp.abs(v)
vamx = snp.max(va)
if snp.util.is_complex_dtype(v.dtype):
vs = snp.exp(1j * snp.angle(v))
else:
vs = snp.sign(v)
return snp.where(
vamx > 0.0,
L1MinusL2Norm._prox_vamx_gt_0(v, va, vs, vamx, alpha, beta),
snp.zeros(v.shape, dtype=v.dtype),
)
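
# Example usage of L1MinusL2Norm (a minimal sketch; values chosen for
# illustration):
#
#     import scico.numpy as snp
#     from scico.functional import L1MinusL2Norm
#
#     f = L1MinusL2Norm()  # beta = 1.0
#     v = snp.array([3.0, 4.0])
#     f(v)                 # 2.0 = 7.0 - 5.0
#     f.prox(v, lam=1.0)   # approx [2.555, 3.832]: soft thresholding
#                          # followed by rescaling
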
class HuberNorm(Functional):
r"""Huber norm.
Compute a norm based on the Huber function :cite:`huber-1964-robust`
:cite:`beck-2017-first` (Sec. 6.7.1). In the non-separable case the
norm is
.. math::
H_{\delta}(\mb{x}) = \begin{cases}
(1/2) \norm{ \mb{x} }_2^2 & \text{ when } \norm{ \mb{x} }_2
\leq \delta \\
\delta \left( \norm{ \mb{x} }_2 - (\delta / 2) \right) &
\text{ when } \norm{ \mb{x} }_2 > \delta \;,
\end{cases}
where :math:`\delta` is a parameter controlling the transitions
between :math:`\ell_1`-norm like and :math:`\ell_2`-norm like
behavior. In the separable case the norm is
.. math::
H_{\delta}(\mb{x}) = \sum_i h_{\delta}(x_i) \,,
where
.. math::
h_{\delta}(x) = \begin{cases}
(1/2) \abs{ x }^2 & \text{ when } \abs{ x } \leq \delta \\
\delta \left( \abs{ x } - (\delta / 2) \right) &
\text{ when } \abs{ x } > \delta \;.
\end{cases}
"""
has_eval = True
has_prox = True
def __init__(self, delta: float = 1.0, separable: bool = True):
r"""
Args:
delta: Huber function parameter :math:`\delta`.
separable: Flag indicating whether to compute separable or
non-separable form.
"""
self.delta = delta
self.separable = separable
if separable:
self._call = self._call_sep
self._prox = self._prox_sep
else:
self._call_lt_branch = lambda xl2: 0.5 * xl2**2
self._call_gt_branch = lambda xl2: self.delta * (xl2 - self.delta / 2.0)
self._call = self._call_nonsep
self._prox = self._prox_nonsep
super().__init__()
def _call_sep(self, x: Union[Array, BlockArray]) -> float:
xabs = snp.abs(x)
hx = snp.where(
xabs <= self.delta, 0.5 * xabs**2, self.delta * (xabs - (self.delta / 2.0))
)
return snp.sum(hx)
def _call_nonsep(self, x: Union[Array, BlockArray]) -> float:
xl2 = snp.linalg.norm(x)
return lax.cond(xl2 <= self.delta, self._call_lt_branch, self._call_gt_branch, xl2)
    def __call__(self, x: Union[Array, BlockArray]) -> float:
return self._call(x)
def _prox_sep(
self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs
) -> Union[Array, BlockArray]:
den = snp.maximum(snp.abs(v), self.delta * (1.0 + lam))
return (1 - ((self.delta * lam) / den)) * v
def _prox_nonsep(
self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs
) -> Union[Array, BlockArray]:
vl2 = snp.linalg.norm(v)
den = snp.maximum(vl2, self.delta * (1.0 + lam))
return (1 - ((self.delta * lam) / den)) * v
    def prox(
self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs
) -> Union[Array, BlockArray]:
r"""Evaluate proximal operator of the Huber function.
Evaluate scaled proximal operator of the Huber function
:cite:`beck-2017-first` (Sec. 6.7.3). The prox is
.. math::
\prox_{\lambda H_{\delta}} (\mb{v}) = \left( 1 -
\frac{\lambda \delta} {\max\left\{\norm{\mb{v}}_2,
\delta + \lambda \delta\right\} } \right) \mb{v}
in the non-separable case, and
.. math::
\left[ \prox_{\lambda H_{\delta}} (\mb{v}) \right]_i =
\left( 1 - \frac{\lambda \delta} {\max\left\{\abs{v_i},
\delta + \lambda \delta\right\} } \right) v_i
in the separable case.
Args:
v: Input array :math:`\mb{v}`.
lam: Proximal parameter :math:`\lambda`.
kwargs: Additional arguments that may be used by derived
classes.
"""
return self._prox(v, lam=lam, **kwargs)
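
# Example usage of HuberNorm (a minimal sketch of the default separable
# form; values chosen for illustration):
#
#     import scico.numpy as snp
#     from scico.functional import HuberNorm
#
#     f = HuberNorm(delta=1.0)
#     v = snp.array([0.5, 2.0])
#     f(v)                # 1.625 = 0.5 * 0.5**2 + (2.0 - 0.5)
#     f.prox(v, lam=1.0)  # [0.25, 1.0]
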
class NuclearNorm(Functional):
r"""Nuclear norm.
Compute the nuclear norm
.. math::
        \norm{\mb{X}}_* = \sum_i \sigma_i \;,
    where :math:`\sigma_i` are the singular values of matrix
    :math:`\mb{X}`.
"""
has_eval = True
has_prox = True
    def __call__(self, x: Union[Array, BlockArray]) -> float:
return snp.sum(snp.linalg.svd(x, full_matrices=False, compute_uv=False))
    def prox(
self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs
) -> Union[Array, BlockArray]:
r"""Evaluate proximal operator of the nuclear norm.
Evaluate proximal operator of the nuclear norm
:cite:`cai-2010-singular`.
Args:
v: Input array :math:`\mb{v}`.
lam: Proximal parameter :math:`\lambda`.
kwargs: Additional arguments that may be used by derived
classes.
"""
svdU, svdS, svdV = snp.linalg.svd(v, full_matrices=False)
svdS = snp.maximum(0, svdS - lam)
return svdU @ snp.diag(svdS) @ svdV
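
# Example usage of NuclearNorm (a minimal sketch; a diagonal matrix
# makes the singular values explicit):
#
#     import scico.numpy as snp
#     from scico.functional import NuclearNorm
#
#     f = NuclearNorm()
#     x = snp.diag(snp.array([3.0, 1.0]))  # singular values 3 and 1
#     f(x)                # 4.0
#     f.prox(x, lam=1.0)  # singular value soft thresholding:
#                         # [[2., 0.], [0., 0.]]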