# Source code for scico.functional._tvnorm

# -*- coding: utf-8 -*-
# Copyright (C) 2023-2024 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.

"""Total variation norms."""

from typing import Optional, Tuple

from scico import numpy as snp
from scico.linop import (
    CircularConvolve,
    Crop,
    FiniteDifference,
    LinearOperator,
    Pad,
    VerticalStack,
)
from scico.numpy import Array
from scico.typing import DType, Shape

from ._functional import Functional
from ._norm import L1Norm, L21Norm


class TVNorm(Functional):
    r"""Generic total variation (TV) norm.

    Generic total variation (TV) norm with approximation of the scaled
    proximal operator :cite:`kamilov-2016-parallel`
    :cite:`kamilov-2016-minimizing`.
    """

    has_eval = True
    has_prox = True

    def __init__(
        self,
        norm: Functional,
        circular: bool = True,
        ndims: Optional[int] = None,
        input_shape: Optional[Shape] = None,
        input_dtype: DType = snp.float32,
    ):
        """
        While initializers for :class:`.Functional` objects typically do
        not take `input_shape` and `input_dtype` parameters, they are
        included here because methods :meth:`__call__` and :meth:`prox`
        require instantiation of some :class:`.LinearOperator` objects,
        which do take these parameters. If these parameters are not
        provided on initialization of a :class:`TVNorm` object, then
        creation of the required :class:`.LinearOperator` objects is
        deferred until these methods are called, which can result in
        `JAX tracer <https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#jit-mechanics-tracing-and-static-variables>`__
        errors when they are components of a jitted function.

        Args:
            norm: Norm functional from which the TV norm is composed.
            circular: Flag indicating use of circular boundary conditions.
            ndims: Number of (trailing) dimensions of the input over
                which to apply the finite difference operator. If
                ``None``, differences are evaluated along all axes.
            input_shape: Shape of input arrays of :meth:`__call__` and
                :meth:`prox`.
            input_dtype: `dtype` of input arrays of :meth:`__call__` and
                :meth:`prox`.
        """
        self.norm = norm
        self.circular = circular
        self.ndims = ndims
        self.h0 = snp.array([1.0, 1.0]) / snp.sqrt(2.0)  # lowpass filter
        self.h1 = snp.array([1.0, -1.0]) / snp.sqrt(2.0)  # highpass filter
        # Operators are created lazily (in __call__ / prox) unless an
        # input shape is supplied here. All three attributes are always
        # defined so that they can safely be tested against None.
        self.G: Optional[LinearOperator] = None
        self.WP: Optional[LinearOperator] = None
        self.CWT: Optional[LinearOperator] = None

        if input_shape is not None:
            if ndims is None:
                ndims = len(input_shape)
            self.G = self._call_operator(ndims, input_shape, input_dtype)
            self.WP, self.CWT = self._prox_operators(ndims, input_shape, input_dtype)

    def _call_operator(self, ndims: int, input_shape: Shape, input_dtype: DType) -> LinearOperator:
        """Construct operator required by __call__ method.

        Args:
            ndims: Number of trailing axes of the input to be differenced.
            input_shape: Shape of the arrays the operator is applied to.
            input_dtype: `dtype` of the arrays the operator is applied to.

        Returns:
            A :class:`.FiniteDifference` operator acting on the trailing
            `ndims` axes of the input.
        """
        # Differences are taken over the trailing `ndims` axes only.
        axes = tuple(range(len(input_shape) - ndims, len(input_shape)))
        G = FiniteDifference(
            input_shape,
            input_dtype=input_dtype,
            axes=axes,
            circular=self.circular,
            # `append` is only set for non-circular boundary conditions
            append=None if self.circular else 0,
            jit=True,
        )
        return G

    def __call__(self, x: Array) -> float:
        """Compute the TV norm of an array.

        Args:
            x: Array for which the TV norm should be computed.

        Returns:
            TV norm of `x`.
        """
        # (Re)build the finite difference operator if none has been
        # constructed yet or if the cached one is for a different input
        # shape.
        if self.G is None or self.G.shape[1] != x.shape:
            ndims = x.ndim if self.ndims is None else self.ndims
            self.G = self._call_operator(ndims, x.shape, x.dtype)
        return self.norm(self.G @ x)

    @staticmethod
    def _shape(idx: int, ndims: int) -> Tuple:
        """Construct a shape tuple.

        Construct a tuple of size `ndims` with all unit entries except
        for index `idx`, which has a -1 entry.
        """
        return (1,) * idx + (-1,) + (1,) * (ndims - idx - 1)

    @staticmethod
    def _center(idx: int, ndims: int) -> Tuple:
        """Construct a center tuple.

        Construct a tuple of size `ndims` with all zero entries except
        for index `idx`, which has a unit entry.
        """
        return (0,) * idx + (1,) + (0,) * (ndims - idx - 1)

    def _haar_operator(self, ndims: int, input_shape: Shape, input_dtype: DType) -> LinearOperator:
        """Construct single-level shift-invariant Haar transform.

        Args:
            ndims: Number of trailing axes of the input to be filtered.
            input_shape: Shape of the arrays the operator is applied to.
            input_dtype: `dtype` to which the filters are cast.

        Returns:
            A stack of the lowpass operator stack and the highpass
            operator stack (in that order), each of which contains one
            circular convolution per filtered axis.
        """
        h0 = self.h0.astype(input_dtype)
        h1 = self.h1.astype(input_dtype)

        def conv_op(h: Array, k: int) -> LinearOperator:
            # Circular convolution with 1D filter `h` oriented along
            # axis `k` of the `ndims` filtered axes.
            return CircularConvolve(
                h.reshape(TVNorm._shape(k, ndims)),
                input_shape,
                ndims=ndims,
                h_center=TVNorm._center(k, ndims),
            )

        L = VerticalStack(  # stack of lowpass filter operators for each axis
            [conv_op(h0, k) for k in range(ndims)]
        )
        H = VerticalStack(  # stack of highpass filter operators for each axis
            [conv_op(h1, k) for k in range(ndims)]
        )

        # single-level shift-invariant Haar transform
        return VerticalStack([L, H], jit=True)

    def _prox_operators(
        self, ndims: int, input_shape: Shape, input_dtype: DType
    ) -> Tuple[LinearOperator, LinearOperator]:
        """Construct operators required by prox method.

        Args:
            ndims: Number of trailing axes of the input to be differenced.
            input_shape: Shape of the arrays passed to :meth:`prox`.
            input_dtype: `dtype` of the arrays passed to :meth:`prox`.

        Returns:
            Tuple `(WP, CWT)`. For circular boundary conditions these are
            the Haar transform and its transpose. Otherwise `WP` is the
            Haar transform applied after edge-padding the differenced
            axes by one sample, and `CWT` is the transpose of the Haar
            transform followed by cropping of that padding.
        """
        w_input_shape = (
            # circular boundary: shape of input array
            input_shape
            if self.circular
            # non-circular boundary: shape of input array on non-differenced
            # axes and one greater for axes that are differenced
            else input_shape[0 : (len(input_shape) - ndims)]
            + tuple(n + 1 for n in input_shape[-ndims:])
        )
        W = self._haar_operator(ndims, w_input_shape, input_dtype)
        if self.circular:
            WP, CWT = W, W.T
        else:
            # Pad (and later crop) one sample at the end of each
            # differenced axis; leave the other axes untouched.
            pad_width = ((0, 0),) * (len(input_shape) - ndims) + ((0, 1),) * ndims
            P = Pad(input_shape, pad_width=pad_width, mode="edge", jit=True)
            WP = W @ P
            C = Crop(crop_width=pad_width, input_shape=w_input_shape, jit=True)
            CWT = C @ W.T
        return WP, CWT

    def prox(self, v: Array, lam: float = 1.0, **kwargs) -> Array:
        r"""Approximate proximal operator of the TV norm.

        Approximation of the proximal operator of the TV norm, computed
        via the methods described in :cite:`kamilov-2016-parallel`
        :cite:`kamilov-2016-minimizing`.

        Args:
            v: Input array :math:`\mb{v}`.
            lam: Proximal parameter :math:`\lam`.
            kwargs: Additional arguments that may be used by derived
                classes.

        Returns:
            Approximation of the proximal operator evaluated at
            :math:`\mb{v}`.
        """
        ndims = v.ndim if self.ndims is None else self.ndims
        # Total number of filters in the Haar transform: one lowpass and
        # one highpass filter for each differenced axis.
        K = 2 * ndims

        # (Re)build the transform operators if they have not been
        # constructed yet or if the cached ones are for a different
        # input shape.
        if self.WP is None or self.CWT is None or self.WP.shape[1] != v.shape:
            self.WP, self.CWT = self._prox_operators(ndims, v.shape, v.dtype)

        if self.circular:
            # index 1 on the leading axis selects the highpass component
            slce = snp.s_[1]
        else:
            # highpass component, all filter axes, all of the
            # non-differenced axes, and all but the final (padded) sample
            # of each differenced axis
            slce = (
                (
                    1,
                    snp.s_[:],
                )
                + (snp.s_[:],) * (v.ndim - ndims)
                + (snp.s_[:-1],) * ndims
            )

        # Apply shrinkage to highpass component of shift-invariant Haar transform
        # of padded input (or to non-boundary region thereof for non-circular
        # boundary conditions).
        WPv: Array = self.WP(v)
        WPv = WPv.at[slce].set(self.norm.prox(WPv[slce], snp.sqrt(2) * K * lam))
        u = (1.0 / K) * self.CWT(WPv)

        return u


class AnisotropicTVNorm(TVNorm):
    r"""The anisotropic total variation (TV) norm.

    The anisotropic total variation (TV) norm computed by

    .. code-block:: python

       ATV = scico.functional.AnisotropicTVNorm()
       x_norm = ATV(x)

    is equivalent to

    .. code-block:: python

       C = linop.FiniteDifference(input_shape=x.shape, circular=False, append=0)
       L1 = functional.L1Norm()
       x_norm = L1(C @ x)

    The scaled proximal operator is computed using an approximation that
    holds for small scaling parameters :cite:`kamilov-2016-parallel`.
    This does not imply that it can only be applied to problems requiring
    a small regularization parameter since most proximal algorithms
    include an additional algorithm parameter that also plays a role in
    the parameter of the proximal operator. For example, in :class:`.PGM`
    and :class:`.AcceleratedPGM`, the scaled proximal operator parameter
    is the regularization parameter divided by the `L0` algorithm
    parameter, and for :class:`.ADMM`, the scaled proximal operator
    parameters are the regularization parameters divided by the entries
    in the `rho_list` algorithm parameter.
    """

    def __init__(
        self,
        circular: bool = False,
        ndims: Optional[int] = None,
        input_shape: Optional[Shape] = None,
        input_dtype: DType = snp.float32,
    ):
        """
        Args:
            circular: Flag indicating use of circular boundary conditions.
            ndims: Number of (trailing) dimensions of the input over
                which to apply the finite difference operator. If
                ``None``, differences are evaluated along all axes.
            input_shape: Shape of input arrays of
                :meth:`~.TVNorm.__call__` and :meth:`~.TVNorm.prox`.
            input_dtype: `dtype` of input arrays of
                :meth:`~.TVNorm.__call__` and :meth:`~.TVNorm.prox`.
        """
        # The anisotropic TV norm is the generic TV norm composed with
        # the l1 norm.
        super().__init__(
            L1Norm(),
            circular=circular,
            ndims=ndims,
            input_shape=input_shape,
            input_dtype=input_dtype,
        )


class IsotropicTVNorm(TVNorm):
    r"""The isotropic total variation (TV) norm.

    The isotropic total variation (TV) norm computed by

    .. code-block:: python

       ITV = scico.functional.IsotropicTVNorm()
       x_norm = ITV(x)

    is equivalent to

    .. code-block:: python

       C = linop.FiniteDifference(input_shape=x.shape, circular=False, append=0)
       L21 = functional.L21Norm()
       x_norm = L21(C @ x)

    The scaled proximal operator is computed using an approximation that
    holds for small scaling parameters :cite:`kamilov-2016-minimizing`.
    This does not imply that it can only be applied to problems requiring
    a small regularization parameter since most proximal algorithms
    include an additional algorithm parameter that also plays a role in
    the parameter of the proximal operator. For example, in :class:`.PGM`
    and :class:`.AcceleratedPGM`, the scaled proximal operator parameter
    is the regularization parameter divided by the `L0` algorithm
    parameter, and for :class:`.ADMM`, the scaled proximal operator
    parameters are the regularization parameters divided by the entries
    in the `rho_list` algorithm parameter.
    """

    def __init__(
        self,
        circular: bool = False,
        ndims: Optional[int] = None,
        input_shape: Optional[Shape] = None,
        input_dtype: DType = snp.float32,
    ):
        r"""
        Args:
            circular: Flag indicating use of circular boundary conditions.
            ndims: Number of (trailing) dimensions of the input over
                which to apply the finite difference operator. If
                ``None``, differences are evaluated along all axes.
            input_shape: Shape of input arrays of
                :meth:`~.TVNorm.__call__` and :meth:`~.TVNorm.prox`.
            input_dtype: `dtype` of input arrays of
                :meth:`~.TVNorm.__call__` and :meth:`~.TVNorm.prox`.
        """
        # The isotropic TV norm is the generic TV norm composed with the
        # l2,1 norm.
        super().__init__(
            L21Norm(),
            circular=circular,
            ndims=ndims,
            input_shape=input_shape,
            input_dtype=input_dtype,
        )