# Source code for scico.functional._tvnorm

# -*- coding: utf-8 -*-
# Copyright (C) 2023 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.

"""Anisotropic total variation norm."""

from typing import Optional, Tuple

from scico import numpy as snp
from scico.linop import (
    CircularConvolve,
    FiniteDifference,
    LinearOperator,
    VerticalStack,
)
from scico.numpy import Array

from ._functional import Functional
from ._norm import L1Norm


class AnisotropicTVNorm(Functional):
    r"""The anisotropic total variation (TV) norm.

    The anisotropic total variation (TV) norm computed by

    .. code-block:: python

       ATV = scico.functional.AnisotropicTVNorm()
       x_norm = ATV(x)

    is equivalent to

    .. code-block:: python

       C = linop.FiniteDifference(input_shape=x.shape, circular=True)
       L1 = functional.L1Norm()
       x_norm = L1(C @ x)

    The scaled proximal operator is computed using an approximation that
    holds for small scaling parameters :cite:`kamilov-2016-parallel`.
    This does not imply that it can only be applied to problems requiring
    a small regularization parameter since most proximal algorithms
    include an additional algorithm parameter that also plays a role in
    the parameter of the proximal operator. For example, in :class:`.PGM`
    and :class:`.AcceleratedPGM`, the scaled proximal operator parameter
    is the regularization parameter divided by the `L0` algorithm
    parameter, and for :class:`.ADMM`, the scaled proximal operator
    parameters are the regularization parameters divided by the entries
    in the `rho_list` algorithm parameter.

    The finite difference operator used for evaluation and the filter
    bank operator used for the proximal operator are constructed lazily
    on first use and are rebuilt whenever the shape of the input array
    differs from the shape they were built for.
    """

    has_eval = True
    has_prox = True

    def __init__(self, ndims: Optional[int] = None):
        """
        Args:
            ndims: Number of (trailing) dimensions of the input over
                which to apply the finite difference operator. If
                ``None``, differences are evaluated along all axes.
        """
        self.ndims = ndims
        self.h0 = snp.array([1.0, 1.0]) / snp.sqrt(2.0)  # lowpass filter
        self.h1 = snp.array([1.0, -1.0]) / snp.sqrt(2.0)  # highpass filter
        self.l1norm = L1Norm()
        # Operators are cached between calls; see __call__ and prox.
        self.G: Optional[LinearOperator] = None
        self.W: Optional[LinearOperator] = None

    def __call__(self, x: Array) -> float:
        """Compute the anisotropic TV norm of an array.

        Args:
            x: Array of which to compute the norm.

        Returns:
            Sum of the absolute values of the circular finite
            differences of `x` along the selected axes.
        """
        # (Re)build the cached operator if it does not match the input shape.
        if self.G is None or self.G.shape[1] != x.shape:
            ndims = x.ndim if self.ndims is None else self.ndims
            # Difference along the *trailing* `ndims` axes, as documented
            # for the `ndims` parameter and consistent with the axes used
            # in `prox`. When `ndims` covers all axes this is simply
            # `range(x.ndim)`.
            axes = tuple(range(x.ndim - ndims, x.ndim))
            self.G = FiniteDifference(
                x.shape, input_dtype=x.dtype, axes=axes, circular=True, jit=True
            )
        return snp.sum(snp.abs(self.G @ x))

    @staticmethod
    def _shape(idx: int, ndims: int) -> Tuple:
        """Construct a shape tuple.

        Construct a tuple of size `ndims` with all unit entries except
        for index `idx`, which has a -1 entry.

        Args:
            idx: Index of the -1 entry.
            ndims: Length of the tuple.

        Returns:
            The shape tuple.
        """
        return (1,) * idx + (-1,) + (1,) * (ndims - idx - 1)

    def _filter_stack(self, h: Array, v: Array, ndims: int) -> LinearOperator:
        """Stack circular convolutions with filter `h`, one per axis.

        Args:
            h: One-dimensional filter.
            v: Array whose shape defines the operator input shape.
            ndims: Number of axes over which the filter is applied.

        Returns:
            A vertical stack of `ndims` circular convolution operators,
            the `k`-th of which has `h` oriented along filter axis `k`.
        """
        return VerticalStack(
            [
                CircularConvolve(
                    h.reshape(AnisotropicTVNorm._shape(k, ndims)),
                    v.shape,
                    ndims=self.ndims,
                )
                for k in range(ndims)
            ]
        )

    def prox(self, v: Array, lam: float = 1.0, **kwargs) -> Array:
        r"""Approximate proximal operator of the anisotropic TV norm.

        Approximation of the proximal operator of the anisotropic TV
        norm, computed via the method described in
        :cite:`kamilov-2016-parallel`.

        Args:
            v: Input array :math:`\mb{v}`.
            lam: Proximal parameter :math:`\lam`.
            kwargs: Additional arguments that may be used by derived
                classes.

        Returns:
            Result of evaluating the approximate scaled proximal
            operator at `v`.
        """
        ndims = v.ndim if self.ndims is None else self.ndims
        K = 2 * ndims  # total number of filters (lowpass + highpass per axis)

        # (Re)build the cached transform if it does not match the input shape.
        if self.W is None or self.W.shape[1] != v.shape:
            h0 = self.h0.astype(v.dtype)
            h1 = self.h1.astype(v.dtype)
            C0 = self._filter_stack(h0, v, ndims)  # lowpass, one per axis
            C1 = self._filter_stack(h1, v, ndims)  # highpass, one per axis
            # single-level shift-invariant Haar transform
            self.W = VerticalStack([C0, C1], jit=True)

        Wv = self.W @ v
        # Apply l1 shrinkage to highpass component of shift-invariant
        # Haar transform; the lowpass component (index 0) is left as is.
        Wv = Wv.at[1].set(self.l1norm.prox(Wv[1], snp.sqrt(2) * K * lam))
        return (1.0 / K) * self.W.T @ Wv