# -*- coding: utf-8 -*-
# Copyright (C) 2020-2026 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.
"""Functionals that are norms."""
from functools import partial
from typing import Optional, Tuple, Union
from jax import jit, lax
from scico import numpy as snp
from scico.numpy import Array, BlockArray, count_nonzero
from scico.numpy.linalg import norm
from scico.numpy.util import no_nan_divide
from ._functional import Functional
class L0Norm(Functional):
    r"""The :math:`\ell_0` 'norm'.

    The :math:`\ell_0` 'norm' counts the number of non-zero elements in
    an array.
    """

    has_eval = True
    has_prox = True

    @staticmethod
    @jit
    def __call__(x: Union[Array, BlockArray]) -> float:
        r"""Evaluate the :math:`\ell_0` 'norm'.

        Args:
            x: Input array :math:`\mb{x}`.

        Returns:
            Number of non-zero elements of `x`, as computed by
            :func:`count_nonzero`.
        """
        return count_nonzero(x)

    @staticmethod
    @jit
    def prox(v: Union[Array, BlockArray], lam: float = 1.0, **kwargs) -> Union[Array, BlockArray]:
        r"""Evaluate scaled proximal operator of :math:`\ell_0` norm.

        Evaluate scaled proximal operator of :math:`\ell_0` norm using

        .. math::

            \left[ \prox_{\lambda\| \cdot \|_0}(\mb{v}) \right]_i =
            \begin{cases}
            v_i  & \text{ if } \abs{v_i} \geq \lambda \\
            0  & \text{ otherwise } \;.
            \end{cases}

        Args:
            v: Input array :math:`\mb{v}`.
            lam: Thresholding parameter :math:`\lambda`.
            **kwargs: Additional arguments that may be used by derived
                classes.

        Returns:
            Result of evaluating the scaled proximal operator at `v`.
        """
        # Hard thresholding: keep entries whose magnitude reaches the
        # threshold, zero all others.
        return snp.where(snp.abs(v) >= lam, v, 0)
class L1Norm(Functional):
    r"""The :math:`\ell_1` norm.

    Computes

    .. math::

        \norm{\mb{x}}_1 = \sum_i \abs{x_i} \;.
    """

    has_eval = True
    has_prox = True

    @staticmethod
    @jit
    def __call__(x: Union[Array, BlockArray]) -> float:
        r"""Evaluate the :math:`\ell_1` norm.

        Args:
            x: Input array :math:`\mb{x}`.

        Returns:
            Sum of the absolute values of the elements of `x`.
        """
        return snp.sum(snp.abs(x))

    @staticmethod
    @jit
    def prox(v: Union[Array, BlockArray], lam: float = 1.0, **kwargs) -> Union[Array, BlockArray]:
        r"""Evaluate scaled proximal operator of :math:`\ell_1` norm.

        Evaluate scaled proximal operator of :math:`\ell_1` norm using

        .. math::

            \left[ \prox_{\lambda \|\cdot\|_1}(\mb{v}) \right]_i =
            \sign(v_i) (\abs{v_i} - \lambda)_+ \;,

        where

        .. math::

            (x)_+ = \begin{cases}
            x  & \text{ if } x \geq 0 \\
            0  & \text{ otherwise} \;.
            \end{cases}

        For complex input, :math:`\sign(v_i)` is replaced by the unit
        modulus phase factor :math:`e^{j \angle v_i}`.

        Args:
            v: Input array :math:`\mb{v}`.
            lam: Thresholding parameter :math:`\lambda`.
            **kwargs: Additional arguments that may be used by derived
                classes.

        Returns:
            Result of evaluating the scaled proximal operator at `v`.
        """
        tmp = snp.abs(v) - lam
        # (tmp)_+ computed without branching: 0.5 * (t + |t|) == max(t, 0)
        tmp = 0.5 * (tmp + snp.abs(tmp))
        if snp.util.is_complex_dtype(v.dtype):
            # complex input: restore the phase of each element
            out = snp.exp(1j * snp.angle(v)) * tmp
        else:
            out = snp.sign(v) * tmp
        return out
class SquaredL2Norm(Functional):
    r"""The squared :math:`\ell_2` norm.

    Squared :math:`\ell_2` norm

    .. math::

        \norm{\mb{x}}^2_2 = \sum_i \abs{x_i}^2 \;.
    """

    has_eval = True
    has_prox = True

    @staticmethod
    @jit
    def __call__(x: Union[Array, BlockArray]) -> float:
        r"""Evaluate the squared :math:`\ell_2` norm.

        Args:
            x: Input array :math:`\mb{x}`.

        Returns:
            Sum of the squared magnitudes of the elements of `x`.
        """
        # Directly implement the squared l2 norm to avoid nondifferentiable
        # behavior of snp.norm(x) at 0.
        return snp.sum(snp.abs(x) ** 2)

    @staticmethod
    @jit
    def prox(v: Union[Array, BlockArray], lam: float = 1.0, **kwargs) -> Union[Array, BlockArray]:
        r"""Evaluate proximal operator of squared :math:`\ell_2` norm.

        Evaluate proximal operator of squared :math:`\ell_2` norm using

        .. math::

            \prox_{\lambda \| \cdot \|_2^2}(\mb{v})
            = \frac{\mb{v}}{1 + 2 \lambda} \;.

        Args:
            v: Input array :math:`\mb{v}`.
            lam: Proximal parameter :math:`\lambda`.
            **kwargs: Additional arguments that may be used by derived
                classes.

        Returns:
            Result of evaluating the scaled proximal operator at `v`.
        """
        return v / (1.0 + 2.0 * lam)
class L2Norm(Functional):
    r"""The :math:`\ell_2` norm.

    .. math::

        \norm{\mb{x}}_2 = \sqrt{\sum_i \abs{x_i}^2} \;.
    """

    has_eval = True
    has_prox = True

    @staticmethod
    @jit
    def __call__(x: Union[Array, BlockArray]) -> float:
        r"""Evaluate the :math:`\ell_2` norm.

        Args:
            x: Input array :math:`\mb{x}`.

        Returns:
            The value of :func:`norm` applied to `x`.
        """
        return norm(x)

    @staticmethod
    @jit
    def prox(v: Union[Array, BlockArray], lam: float = 1.0, **kwargs) -> Union[Array, BlockArray]:
        r"""Evaluate proximal operator of :math:`\ell_2` norm.

        Evaluate proximal operator of :math:`\ell_2` norm using

        .. math::

            \prox_{\lambda \| \cdot \|_2}(\mb{v}) = \mb{v} \,
            \left(1 - \frac{\lambda}{\norm{\mb{v}}_2} \right)_+ \;,

        where

        .. math::

            (x)_+ = \begin{cases}
            x  & \text{ if } x \geq 0 \\
            0  & \text{ otherwise} \;.
            \end{cases}

        The result is zero when :math:`\norm{\mb{v}}_2 = 0`.

        Args:
            v: Input array :math:`\mb{v}`.
            lam: Proximal parameter :math:`\lambda`.
            **kwargs: Additional arguments that may be used by derived
                classes.

        Returns:
            Result of evaluating the scaled proximal operator at `v`.
        """
        norm_v = norm(v)
        # Explicitly select a zero result when the norm is zero, since the
        # shrinkage factor involves a division by the norm.
        return snp.where(norm_v == 0, 0 * v, snp.maximum(1 - lam / norm_v, 0) * v)
class L21Norm(Functional):
    r"""The :math:`\ell_{2,1}` norm.

    For a :math:`M \times N` matrix, :math:`\mb{A}`, by default,

    .. math::

        \norm{\mb{A}}_{2,1} = \sum_{n=1}^N \sqrt{\sum_{m=1}^M
        \abs{A_{m,n}}^2} \;.

    The norm generalizes to more dimensions by first computing the
    :math:`\ell_2` norm along one or more (user-specified) axes,
    followed by a sum over all remaining axes. :class:`.BlockArray` inputs
    require parameter `l2_axis` to be ``None``, in which case the
    :math:`\ell_2` norm is computed over each block.

    A typical use case is computing the isotropic total variation norm.
    """

    has_eval = True
    has_prox = True

    def __init__(self, l2_axis: Union[None, int, Tuple] = 0):
        r"""
        Args:
            l2_axis: Axis/axes over which to take the l2 norm. Required
                to be ``None`` for :class:`.BlockArray` inputs to be
                supported.
        """
        self.l2_axis = l2_axis

    @staticmethod
    @partial(jit, static_argnames=("axis", "keepdims"))
    def _l2norm(
        x: Union[Array, BlockArray], axis: Union[None, int, Tuple], keepdims: Optional[bool] = False
    ) -> Union[Array, BlockArray]:
        r"""Return the :math:`\ell_2` norm of an array.

        Args:
            x: Input array.
            axis: Axis/axes over which the squared magnitudes are summed.
            keepdims: Passed through to the sum over `axis`.

        Returns:
            Square root of the sum over `axis` of the squared magnitudes
            of the elements of `x`.
        """
        return snp.sqrt((snp.abs(x) ** 2).sum(axis=axis, keepdims=keepdims))

    def __call__(self, x: Union[Array, BlockArray]) -> float:
        r"""Evaluate the :math:`\ell_{2,1}` norm.

        Args:
            x: Input array :math:`\mb{x}`.

        Returns:
            Sum of the :math:`\ell_2` norms computed over `l2_axis`.

        Raises:
            ValueError: If `x` is a :class:`.BlockArray` and the
                initializer argument `l2_axis` is not ``None``.
        """
        if isinstance(x, snp.BlockArray) and self.l2_axis is not None:
            raise ValueError("Initializer argument 'l2_axis' must be None for BlockArray input.")
        l2 = L21Norm._l2norm(x, axis=self.l2_axis)
        return snp.sum(snp.abs(l2))

    @staticmethod
    @partial(jit, static_argnames=("axis",))
    def _prox(
        v: Union[Array, BlockArray], lam: float, axis: Union[None, int, Tuple]
    ) -> Union[Array, BlockArray]:
        r"""Evaluate proximal operator of the :math:`\ell_{2,1}` norm.

        Args:
            v: Input array :math:`\mb{v}`.
            lam: Proximal parameter :math:`\lambda`.
            axis: Axis/axes over which the :math:`\ell_2` norm is taken.

        Returns:
            Input `v` with the :math:`\ell_2` norm over `axis` reduced by
            `lam` and clipped at zero.
        """
        # keepdims=True so that the norms broadcast against v
        length = L21Norm._l2norm(v, axis=axis, keepdims=True)
        direction = no_nan_divide(v, length)

        new_length = length - lam
        # set negative values to zero without `if`
        new_length = 0.5 * (new_length + snp.abs(new_length))

        return new_length * direction

    def prox(
        self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs
    ) -> Union[Array, BlockArray]:
        r"""Evaluate proximal operator of the :math:`\ell_{2,1}` norm.

        In two dimensions,

        .. math::

            \prox_{\lambda \|\cdot\|_{2,1}}(\mb{v}, \lambda)_{:, n} =
            \frac{\mb{v}_{:, n}}{\|\mb{v}_{:, n}\|_2}
            (\|\mb{v}_{:, n}\|_2 - \lambda)_+ \;,

        where

        .. math::

            (x)_+ = \begin{cases}
            x  & \text{ if } x \geq 0 \\
            0  & \text{ otherwise} \;.
            \end{cases}

        Args:
            v: Input array :math:`\mb{v}`.
            lam: Proximal parameter :math:`\lambda`.
            **kwargs: Additional arguments that may be used by derived
                classes.

        Returns:
            Result of evaluating the scaled proximal operator at `v`.

        Raises:
            ValueError: If `v` is a :class:`.BlockArray` and the
                initializer argument `l2_axis` is not ``None``.
        """
        if isinstance(v, snp.BlockArray) and self.l2_axis is not None:
            raise ValueError("Initializer argument 'l2_axis' must be None for BlockArray input.")
        return L21Norm._prox(v, lam=lam, axis=self.l2_axis)
class L1MinusL2Norm(Functional):
    r"""Difference of :math:`\ell_1` and :math:`\ell_2` norms.

    Difference of :math:`\ell_1` and :math:`\ell_2` norms

    .. math::

        \norm{\mb{x}}_1 - \beta \norm{\mb{x}}_2 \;.
    """

    has_eval = True
    has_prox = True

    def __init__(self, beta: float = 1.0):
        r"""
        Args:
            beta: Parameter :math:`\beta` in the norm definition.
        """
        self.beta = beta

    @staticmethod
    @jit
    def _l1minusl2norm(x: Union[Array, BlockArray], beta: float) -> float:
        r"""Return the :math:`\ell_1 - \ell_2` norm of an array."""
        return snp.sum(snp.abs(x)) - beta * norm(x)

    def __call__(self, x: Union[Array, BlockArray]) -> float:
        r"""Evaluate the difference of :math:`\ell_1` and :math:`\ell_2` norms.

        Args:
            x: Input array :math:`\mb{x}`.

        Returns:
            The value :math:`\norm{\mb{x}}_1 - \beta \norm{\mb{x}}_2`.
        """
        return L1MinusL2Norm._l1minusl2norm(x, self.beta)

    @staticmethod
    def _prox_vamx_ge_thresh(v, va, vs, alpha, beta):
        """Return the candidate solution with a single non-zero entry.

        The result is zero everywhere except at the (first) index of the
        largest value of `va`, where it takes the value
        ``(va[idx] + (beta - 1) * alpha) * vs[idx]``.
        """
        u = snp.zeros(v.shape, dtype=v.dtype)
        idx = va.ravel().argmax()
        # functional (out-of-place) update of the single selected entry
        u = (
            u.ravel().at[idx].set((va.ravel()[idx] + (beta - 1.0) * alpha) * vs.ravel()[idx])
        ).reshape(v.shape)
        return u

    @staticmethod
    def _prox_vamx_le_alpha(v, va, vs, vamx, alpha, beta):
        """Return the candidate solution for the branch ``vamx <= alpha``.

        The result is zero when ``vamx < (1 - beta) * alpha`` and the
        single non-zero entry solution otherwise.
        """
        return snp.where(
            vamx < (1.0 - beta) * alpha,
            snp.zeros(v.shape, dtype=v.dtype),
            L1MinusL2Norm._prox_vamx_ge_thresh(v, va, vs, alpha, beta),
        )

    @staticmethod
    def _prox_vamx_gt_alpha(v, va, vs, alpha, beta):
        """Return the candidate solution for the branch ``vamx > alpha``.

        The input is soft thresholded with threshold `alpha` and then
        rescaled by ``(l2u + alpha * beta) / l2u``, where ``l2u`` is the
        norm of the soft thresholded array.
        """
        u = snp.maximum(va - alpha, 0.0) * vs
        l2u = norm(u)
        u *= (l2u + alpha * beta) / l2u
        return u

    @staticmethod
    def _prox_vamx_gt_0(v, va, vs, vamx, alpha, beta):
        """Return the candidate solution for the branch ``vamx > 0``.

        Both sub-branch candidates are computed and `snp.where` selects
        between them according to whether ``vamx > alpha``.
        """
        return snp.where(
            vamx > alpha,
            L1MinusL2Norm._prox_vamx_gt_alpha(v, va, vs, alpha, beta),
            L1MinusL2Norm._prox_vamx_le_alpha(v, va, vs, vamx, alpha, beta),
        )

    @staticmethod
    @jit
    def _prox(v: Union[Array, BlockArray], lam: float, beta: float) -> Union[Array, BlockArray]:
        r"""Proximal operator of :math:`\ell_1 - \ell_2` norm.

        Args:
            v: Input array :math:`\mb{v}`.
            lam: Proximal parameter :math:`\lambda` (referred to as
                ``alpha`` in the helper functions).
            beta: Parameter :math:`\beta` in the norm definition.

        Returns:
            Zero array if the largest magnitude in `v` is zero, otherwise
            the result selected by :meth:`_prox_vamx_gt_0`.
        """
        alpha = lam
        va = snp.abs(v)  # element magnitudes
        vamx = snp.max(va)  # largest magnitude
        # unit modulus sign/phase of each element
        if snp.util.is_complex_dtype(v.dtype):
            vs = snp.exp(1j * snp.angle(v))
        else:
            vs = snp.sign(v)
        return snp.where(
            vamx > 0.0,
            L1MinusL2Norm._prox_vamx_gt_0(v, va, vs, vamx, alpha, beta),
            snp.zeros(v.shape, dtype=v.dtype),
        )

    def prox(
        self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs
    ) -> Union[Array, BlockArray]:
        r"""Proximal operator of difference of :math:`\ell_1` and :math:`\ell_2` norms.

        Evaluate the proximal operator of the difference of :math:`\ell_1`
        and :math:`\ell_2` norms, i.e. :math:`\lambda \left( \| \mb{x}
        \|_1 - \beta \| \mb{x} \|_2 \right)` :cite:`lou-2018-fast`. Note
        that this is not a proximal operator according to the strict
        definition since the loss function is non-convex.

        Args:
            v: Input array :math:`\mb{v}`.
            lam: Proximal parameter :math:`\lambda`.
            **kwargs: Additional arguments that may be used by derived
                classes.

        Returns:
            Result of evaluating the scaled proximal operator at `v`.
        """
        return L1MinusL2Norm._prox(v, lam=lam, beta=self.beta)
class HuberNorm(Functional):
    r"""Huber norm.

    Compute a norm based on the Huber function :cite:`huber-1964-robust`
    :cite:`beck-2017-first` (Sec. 6.7.1). In the non-separable case the
    norm is

    .. math::

        H_{\delta}(\mb{x}) = \begin{cases}
        (1/2) \norm{ \mb{x} }_2^2  & \text{ when } \norm{ \mb{x} }_2
        \leq \delta \\
        \delta \left( \norm{ \mb{x} }_2  - (\delta / 2) \right) &
        \text{ when } \norm{ \mb{x} }_2 > \delta \;,
        \end{cases}

    where :math:`\delta` is a parameter controlling the transitions
    between :math:`\ell_1`-norm like and :math:`\ell_2`-norm like
    behavior. In the separable case the norm is

    .. math::

        H_{\delta}(\mb{x}) = \sum_i h_{\delta}(x_i) \,,

    where

    .. math::

        h_{\delta}(x) = \begin{cases}
        (1/2) \abs{ x }^2  & \text{ when } \abs{ x } \leq \delta \\
        \delta \left( \abs{ x }  - (\delta / 2) \right) &
        \text{ when } \abs{ x } > \delta \;.
        \end{cases}
    """

    has_eval = True
    has_prox = True

    def __init__(self, delta: float = 1.0, separable: bool = True):
        r"""
        Args:
            delta: Huber function parameter :math:`\delta`.
            separable: Flag indicating whether to compute separable or
                non-separable form.
        """
        self.delta = delta
        self.separable = separable

        # Select the evaluation and prox implementations once, at
        # construction time.
        if separable:
            self._call = self._call_sep
            self._prox = self._prox_sep
        else:
            self._call = self._call_nonsep
            self._prox = self._prox_nonsep

        super().__init__()

    @staticmethod
    @jit
    def _call_sep(x: Union[Array, BlockArray], delta: float) -> float:
        """Evaluate the separable form of the Huber norm."""
        xabs = snp.abs(x)
        hx = snp.where(xabs <= delta, 0.5 * xabs**2, delta * (xabs - (delta / 2.0)))
        return snp.sum(hx)

    @staticmethod
    @jit
    def _call_nonsep(x: Union[Array, BlockArray], delta: float) -> float:
        """Evaluate the non-separable form of the Huber norm."""
        xl2 = snp.linalg.norm(x)
        return lax.cond(
            xl2 <= delta, lambda r: 0.5 * r**2, lambda r: delta * (r - delta / 2.0), xl2
        )

    def __call__(self, x: Union[Array, BlockArray]) -> float:
        r"""Evaluate the Huber norm.

        Args:
            x: Input array :math:`\mb{x}`.

        Returns:
            Value of the separable or non-separable form of the norm,
            as selected by the `separable` initializer argument.
        """
        return self._call(x, self.delta)

    @staticmethod
    @jit
    def _prox_sep(
        v: Union[Array, BlockArray], lam: float, delta: float
    ) -> Union[Array, BlockArray]:
        """Evaluate the prox of the separable form of the Huber norm."""
        den = snp.maximum(snp.abs(v), delta * (1.0 + lam))
        return (1.0 - ((delta * lam) / den)) * v

    @staticmethod
    @jit
    def _prox_nonsep(
        v: Union[Array, BlockArray], lam: float, delta: float
    ) -> Union[Array, BlockArray]:
        """Evaluate the prox of the non-separable form of the Huber norm."""
        vl2 = snp.linalg.norm(v)
        den = snp.maximum(vl2, delta * (1.0 + lam))
        return (1.0 - ((delta * lam) / den)) * v

    def prox(
        self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs
    ) -> Union[Array, BlockArray]:
        r"""Evaluate proximal operator of the Huber function.

        Evaluate scaled proximal operator of the Huber function
        :cite:`beck-2017-first` (Sec. 6.7.3). The prox is

        .. math::

            \prox_{\lambda H_{\delta}} (\mb{v}) = \left( 1 -
            \frac{\lambda \delta} {\max\left\{\norm{\mb{v}}_2,
            \delta + \lambda \delta\right\} } \right) \mb{v}

        in the non-separable case, and

        .. math::

            \left[ \prox_{\lambda H_{\delta}} (\mb{v}) \right]_i =
            \left( 1 - \frac{\lambda \delta} {\max\left\{\abs{v_i},
            \delta + \lambda \delta\right\} } \right) v_i

        in the separable case.

        Args:
            v: Input array :math:`\mb{v}`.
            lam: Proximal parameter :math:`\lambda`.
            **kwargs: Additional arguments that may be used by derived
                classes.

        Returns:
            Result of evaluating the scaled proximal operator at `v`.
        """
        return self._prox(v, lam=lam, delta=self.delta)
class NuclearNorm(Functional):
    r"""Nuclear norm.

    Compute the nuclear norm

    .. math::

        \| X \|_* = \sum_i \sigma_i

    where :math:`\sigma_i` are the singular values of matrix :math:`X`.
    """

    has_eval = True
    has_prox = True

    @staticmethod
    @jit
    def __call__(x: Union[Array, BlockArray]) -> float:
        r"""Evaluate the nuclear norm.

        Args:
            x: Input array :math:`\mb{x}`. Required to be two-dimensional.

        Returns:
            Sum of the singular values of `x`.

        Raises:
            ValueError: If `x` is not two-dimensional.
        """
        if x.ndim != 2:
            raise ValueError("Input array must be two dimensional.")
        return snp.sum(snp.linalg.svd(x, full_matrices=False, compute_uv=False))

    @staticmethod
    @jit
    def prox(v: Union[Array, BlockArray], lam: float = 1.0, **kwargs) -> Union[Array, BlockArray]:
        r"""Evaluate proximal operator of the nuclear norm.

        Evaluate proximal operator of the nuclear norm
        :cite:`cai-2010-singular`.

        Args:
            v: Input array :math:`\mb{v}`. Required to be two-dimensional.
            lam: Proximal parameter :math:`\lambda`.
            **kwargs: Additional arguments that may be used by derived
                classes.

        Returns:
            Result of evaluating the scaled proximal operator at `v`.

        Raises:
            ValueError: If `v` is not two-dimensional.
        """
        if v.ndim != 2:
            raise ValueError("Input array must be two dimensional.")
        svd_u, svd_s, svd_vh = snp.linalg.svd(v, full_matrices=False)
        # shrink the singular values by lam, clipping at zero, and
        # reassemble the matrix from the SVD factors
        svd_s = snp.maximum(0, svd_s - lam)
        return svd_u @ snp.diag(svd_s) @ svd_vh