Source code for scico.functional._tvnorm

# -*- coding: utf-8 -*-
# Copyright (C) 2023-2025 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.

"""Total variation norms."""

from functools import partial
from typing import Optional, Tuple

import jax

from scico import numpy as snp
from scico.linop import (
    Crop,
    FiniteDifference,
    LinearOperator,
    Pad,
    SingleAxisFiniteDifference,
    VerticalStack,
    linop_over_axes,
)
from scico.numpy import Array
from scico.numpy.util import normalize_axes
from scico.typing import Axes, DType, Shape

from ._functional import Functional
from ._norm import L1Norm, L21Norm


class TVNorm(Functional):
    r"""Generic total variation (TV) norm.

    Generic total variation (TV) norm with approximation of the scaled
    proximal operator :cite:`kamilov-2016-parallel`
    :cite:`kamilov-2016-minimizing` :cite:`chandler-2024-closedform`.
    """

    has_eval = True
    has_prox = True

    def __init__(
        self,
        norm: Functional,
        circular: bool = True,
        axes: Optional[Axes] = None,
        input_shape: Optional[Shape] = None,
        input_dtype: DType = snp.float32,
    ):
        """
        While initializers for :class:`.Functional` objects typically do
        not take `input_shape` and `input_dtype` parameters, they are
        included here because methods :meth:`__call__` and :meth:`prox`
        require instantiation of some :class:`.LinearOperator` objects,
        which do take these parameters. If these parameters are not
        provided on initialization of a :class:`TVNorm` object, then
        creation of the required :class:`.LinearOperator` objects is
        deferred until these methods are called, which can result in
        `JAX tracer <https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#jit-mechanics-tracing-and-static-variables>`__
        errors when they are components of a jitted function.

        Args:
            norm: Norm functional from which the TV norm is composed.
            circular: Flag indicating use of circular boundary conditions.
            axes: Axis or axes over which to apply finite difference
                operator. If not specified, or ``None``, differences are
                evaluated along all axes.
            input_shape: Shape of input arrays of :meth:`__call__` and
                :meth:`prox`.
            input_dtype: `dtype` of input arrays of :meth:`__call__` and
                :meth:`prox`.
        """
        self.norm = norm
        self.circular = circular
        self.axes = axes
        # Operators and slice information are created lazily (on first use)
        # unless input_shape is given; all of them start out as None so that
        # the attributes always exist.
        self.G: Optional[LinearOperator] = None
        self.WP: Optional[LinearOperator] = None
        self.CWT: Optional[LinearOperator] = None
        self.prox_ndims: Optional[int] = None
        self.prox_slice: Optional[Tuple] = None

        if input_shape is not None:
            self.G = self._call_operator(input_shape, input_dtype)
            self.WP, self.CWT, self.prox_ndims, self.prox_slice = self._prox_operators(
                input_shape, input_dtype
            )

    def _call_operator(self, input_shape: Shape, input_dtype: DType) -> LinearOperator:
        """Construct operator required by __call__ method.

        Args:
            input_shape: Shape of the arrays to which the operator is applied.
            input_dtype: `dtype` of the arrays to which the operator is applied.

        Returns:
            Finite difference operator over the configured axes.
        """
        G = FiniteDifference(
            input_shape,
            input_dtype=input_dtype,
            axes=self.axes,
            circular=self.circular,
            # For non-circular boundary conditions, zero-pad to the right
            # for equivalence with boundary conditions implemented in the
            # prox calculation.
            append=None if self.circular else 0,
            jit=True,
        )
        return G

    def __call__(self, x: Array) -> float:
        """Compute the TV norm of an array.

        Args:
            x: Array for which the TV norm should be computed.

        Returns:
            TV norm of `x`.
        """
        # (Re)build the finite difference operator if it does not exist yet
        # or was built for a different input shape.
        if self.G is None or self.G.shape[1] != x.shape:
            self.G = self._call_operator(x.shape, x.dtype)
        return self.norm(self.G @ x)

    def _prox_operators(
        self, input_shape: Shape, input_dtype: DType
    ) -> Tuple[LinearOperator, LinearOperator, int, Tuple]:
        """Construct operators required by prox method.

        Args:
            input_shape: Shape of the arrays passed to :meth:`prox`.
            input_dtype: `dtype` of the arrays passed to :meth:`prox`.

        Returns:
            Tuple `(WP, CWT, ndims, slce)` consisting of the fused
            boundary-extension/forward-transform operator, the fused
            adjoint-transform/crop operator, the number of differenced
            axes, and the index tuple selecting the transform
            coefficients to which shrinkage is applied.
        """
        axes = normalize_axes(self.axes, input_shape)
        ndims = len(axes)
        w_input_shape = (
            # circular boundary: shape of input array
            input_shape
            if self.circular
            # non-circular boundary: shape of input array on non-differenced
            # axes and one greater for axes that are differenced
            else tuple(s + 1 if i in axes else s for i, s in enumerate(input_shape))  # type: ignore
        )
        W = HaarTransform(w_input_shape, input_dtype=input_dtype, axes=axes, jit=True)  # type: ignore
        if self.circular:
            # slice selecting highpass component of shift-invariant Haar transform
            slce = snp.s_[:, 1]
            # No boundary extension, so fused extend and forward transform, and fused
            # adjoint transform and crop are just forward and adjoint respectively.
            WP, CWT = W, W.T
        else:
            # slice selecting non-boundary region of highpass component of
            # shift-invariant Haar transform
            slce = (
                snp.s_[:],
                snp.s_[1],
            ) + tuple(
                snp.s_[:-1] if i in axes else snp.s_[:] for i in range(len(input_shape))  # type: ignore
            )
            # Replicate-pad to the right (resulting in a zero after finite differencing)
            # on all axes subject to finite differencing.
            pad_width = [(0, 1) if i in axes else (0, 0) for i in range(len(input_shape))]  # type: ignore
            P = Pad(
                input_shape, input_dtype=input_dtype, pad_width=pad_width, mode="edge", jit=True
            )
            # fused boundary extend and forward transform linop
            WP = W @ P
            # crop operation that is inverse of the padding operation
            C = Crop(
                crop_width=pad_width, input_shape=w_input_shape, input_dtype=input_dtype, jit=True
            )
            # fused adjoint transform and crop linop
            CWT = C @ W.T
        return WP, CWT, ndims, slce

    @staticmethod
    def _slice_tuple_to_tuple(st: Tuple) -> Tuple:
        """Convert a tuple of slice or int to a tuple of tuple or int.

        Required here as a workaround for the unhashability of slices in
        Python < 3.12, since jax.jit requires static arguments to be
        hashable.

        Args:
            st: Tuple whose entries are slices or ints.

        Returns:
            Tuple in which each slice is replaced by its
            `(start, stop, step)` tuple; other entries are unchanged.
        """
        return tuple((s.start, s.stop, s.step) if isinstance(s, slice) else s for s in st)

    @staticmethod
    def _slice_tuple_from_tuple(st: Tuple) -> Tuple:
        """Convert a tuple of tuple or int to a tuple of slice or int.

        Required here as a workaround for the unhashability of slices in
        Python < 3.12, since jax.jit requires static arguments to be
        hashable.

        Args:
            st: Tuple whose entries are `(start, stop, step)` tuples or ints.

        Returns:
            Tuple in which each tuple entry is replaced by the
            corresponding slice; other entries are unchanged.
        """
        return tuple(slice(*s) if isinstance(s, tuple) else s for s in st)

    @staticmethod
    @partial(jax.jit, static_argnums=(0, 1, 2, 4))
    def _prox_core(
        WP: LinearOperator,
        CWT: LinearOperator,
        norm: Functional,
        K: int,
        slce_rep: Tuple,
        v: Array,
        lam: float = 1.0,
    ) -> Array:
        """Core component of prox calculation.

        Args:
            WP: Fused boundary extension and forward transform operator.
            CWT: Fused adjoint transform and crop operator.
            norm: Norm functional whose prox is applied to the selected
                transform coefficients.
            K: Twice the number of differenced axes.
            slce_rep: Hashable representation (see
                :meth:`_slice_tuple_to_tuple`) of the index tuple
                selecting the coefficients to be shrunk.
            v: Input array.
            lam: Proximal parameter.

        Returns:
            Result of the approximate proximal operator evaluated at `v`.
        """
        # Apply boundary extension (when circular==False) and single-level Haar
        # transform to input array.
        WPv: Array = WP(v)
        # Convert the hashable tuple-of-tuples/ints representation back to a
        # tuple of slices/ints (slices themselves cannot be passed as static
        # arguments to jax.jit because they are unhashable in Python < 3.12).
        slce = TVNorm._slice_tuple_from_tuple(slce_rep)
        # Apply shrinkage to highpass component of shift-invariant Haar transform
        # of padded input (or to non-boundary region thereof when circular==False).
        WPv = WPv.at[slce].set(norm.prox(WPv[slce], snp.sqrt(2) * K * lam))
        # Apply adjoint of single-level Haar transform and crop extended
        # part of array (when circular==False).
        return (1.0 / K) * CWT(WPv)

    def prox(self, v: Array, lam: float = 1.0, **kwargs) -> Array:
        r"""Approximate scaled proximal operator of the TV norm.

        Approximation of the scaled proximal operator of the TV norm,
        computed via the methods described in
        :cite:`kamilov-2016-parallel` :cite:`kamilov-2016-minimizing`
        :cite:`chandler-2024-closedform`.

        Args:
            v: Input array :math:`\mb{v}`.
            lam: Proximal parameter :math:`\lam`.
            **kwargs: Additional arguments that may be used by derived
                classes.

        Returns:
            Result of evaluating the scaled proximal operator at `v`.
        """
        # (Re)build the prox operators if they do not exist yet or were
        # built for a different input shape.
        if self.WP is None or self.WP.shape[1] != v.shape:
            self.WP, self.CWT, self.prox_ndims, self.prox_slice = self._prox_operators(
                v.shape, v.dtype
            )
        # Type narrowing only: all of these are set together with self.WP.
        assert self.CWT is not None
        assert self.prox_ndims is not None
        assert self.prox_slice is not None

        K = 2 * self.prox_ndims
        u = TVNorm._prox_core(
            self.WP, self.CWT, self.norm, K, TVNorm._slice_tuple_to_tuple(self.prox_slice), v, lam
        )

        return u
class AnisotropicTVNorm(TVNorm):
    r"""The anisotropic total variation (TV) norm.

    The anisotropic total variation (TV) norm computed by

    .. code-block:: python

       ATV = scico.functional.AnisotropicTVNorm()
       x_norm = ATV(x)

    is equivalent to

    .. code-block:: python

       C = linop.FiniteDifference(
           input_shape=x.shape, input_dtype=x.dtype, circular=False, append=0
       )
       L1 = functional.L1Norm()
       x_norm = L1(C @ x)

    since the default is `circular=False`. When `circular=True` is
    specified, the equivalent finite difference operator is constructed
    with `circular=True` and without the `append` parameter.

    The scaled proximal operator is computed using an approximation that
    holds for small scaling parameters :cite:`kamilov-2016-parallel`.
    This does not imply that it can only be applied to problems requiring
    a small regularization parameter since most proximal algorithms
    include an additional algorithm parameter that also plays a role in
    the parameter of the proximal operator. For example, in :class:`.PGM`
    and :class:`.AcceleratedPGM`, the scaled proximal operator parameter
    is the regularization parameter divided by the `L0` algorithm
    parameter, and for :class:`.ADMM`, the scaled proximal operator
    parameters are the regularization parameters divided by the entries
    in the `rho_list` algorithm parameter.
    """

    def __init__(
        self,
        circular: bool = False,
        axes: Optional[Axes] = None,
        input_shape: Optional[Shape] = None,
        input_dtype: DType = snp.float32,
    ):
        """
        Args:
            circular: Flag indicating use of circular boundary conditions.
            axes: Axis or axes over which to apply finite difference
                operator. If not specified, or ``None``, differences are
                evaluated along all axes.
            input_shape: Shape of input arrays of :meth:`~.TVNorm.__call__` and
                :meth:`~.TVNorm.prox`.
            input_dtype: `dtype` of input arrays of :meth:`~.TVNorm.__call__` and
                :meth:`~.TVNorm.prox`.
        """
        super().__init__(
            L1Norm(),
            circular=circular,
            axes=axes,
            input_shape=input_shape,
            input_dtype=input_dtype,
        )


class IsotropicTVNorm(TVNorm):
    r"""The isotropic total variation (TV) norm.

    The isotropic total variation (TV) norm computed by

    .. code-block:: python

       ITV = scico.functional.IsotropicTVNorm()
       x_norm = ITV(x)

    is equivalent to

    .. code-block:: python

       C = linop.FiniteDifference(
           input_shape=x.shape, input_dtype=x.dtype, circular=False, append=0
       )
       L21 = functional.L21Norm()
       x_norm = L21(C @ x)

    since the default is `circular=False`. When `circular=True` is
    specified, the equivalent finite difference operator is constructed
    with `circular=True` and without the `append` parameter.

    The scaled proximal operator is computed using an approximation that
    holds for small scaling parameters :cite:`kamilov-2016-minimizing`.
    This does not imply that it can only be applied to problems requiring
    a small regularization parameter since most proximal algorithms
    include an additional algorithm parameter that also plays a role in
    the parameter of the proximal operator. For example, in :class:`.PGM`
    and :class:`.AcceleratedPGM`, the scaled proximal operator parameter
    is the regularization parameter divided by the `L0` algorithm
    parameter, and for :class:`.ADMM`, the scaled proximal operator
    parameters are the regularization parameters divided by the entries
    in the `rho_list` algorithm parameter.
    """

    def __init__(
        self,
        circular: bool = False,
        axes: Optional[Axes] = None,
        input_shape: Optional[Shape] = None,
        input_dtype: DType = snp.float32,
    ):
        r"""
        Args:
            circular: Flag indicating use of circular boundary conditions.
            axes: Axis or axes over which to apply finite difference
                operator. If not specified, or ``None``, differences are
                evaluated along all axes.
            input_shape: Shape of input arrays of :meth:`~.TVNorm.__call__` and
                :meth:`~.TVNorm.prox`.
            input_dtype: `dtype` of input arrays of :meth:`~.TVNorm.__call__` and
                :meth:`~.TVNorm.prox`.
        """
        super().__init__(
            L21Norm(),
            circular=circular,
            axes=axes,
            input_shape=input_shape,
            input_dtype=input_dtype,
        )


class SingleAxisFiniteSum(LinearOperator):
    r"""Two-point sum operator acting along a single axis.

    Each output sample is the sum of the input sample at the same index
    and the input sample at the following index along the selected axis.
    Boundary handling is circular (the last sample is summed with the
    first), so that the sum operator corresponds to the matrix

    .. math::

       \left(\begin{array}{rrrrr}
       1 & 1 & 0 & \ldots & 0\\
       0 & 1 & 1 & \ldots & 0\\
       \vdots & \vdots & \ddots & \ddots & \vdots\\
       0 & 0 & \ldots & 1 & 1\\
       1 & 0 & \dots & 0 & 1
       \end{array}\right) \;.
    """

    def __init__(
        self,
        input_shape: Shape,
        input_dtype: DType = snp.float32,
        axis: int = -1,
        jit: bool = True,
        **kwargs,
    ):
        r"""
        Args:
            input_shape: Shape of input array.
            input_dtype: `dtype` for input argument. Defaults to
                :attr:`~numpy.float32`.
            axis: Axis over which to apply sum operator.
            jit: If ``True``, jit the evaluation, adjoint, and gram
                functions of the :class:`LinearOperator`.

        Raises:
            TypeError: If `axis` is not an int.
            ValueError: If `axis` is out of range for `input_shape`.
        """

        if not isinstance(axis, int):
            raise TypeError(
                f"Expected argument 'axis' to be of type int, got {type(axis)} instead."
            )

        ndim = len(input_shape)
        if axis < -ndim:
            # A negative axis that is still negative after normalization would
            # otherwise pass validation here and only fail at evaluation time.
            raise ValueError(
                f"Invalid argument 'axis' specified ({axis}); 'axis' must be no less than "
                f"-len(input_shape)={-ndim}."
            )
        if axis < 0:
            axis = ndim + axis
        if axis >= ndim:
            raise ValueError(
                f"Invalid argument 'axis' specified ({axis}); 'axis' must be less than "
                f"len(input_shape)={len(input_shape)}."
            )

        self.axis = axis

        super().__init__(
            input_shape=input_shape,
            output_shape=input_shape,
            input_dtype=input_dtype,
            output_dtype=input_dtype,
            jit=jit,
            **kwargs,
        )

    def _eval(self, x: snp.Array) -> snp.Array:
        # Sum of each sample and its circularly-wrapped successor along self.axis.
        return x + snp.roll(x, -1, self.axis)


class FiniteSum(VerticalStack):
    """Two-point sum operator.

    Compute two-point sums along the specified axes, returning the
    results stacked on axis 0 of a :class:`jax.Array`. See
    :class:`SingleAxisFiniteSum` for boundary handling details.
    """

    def __init__(
        self,
        input_shape: Shape,
        input_dtype: DType = snp.float32,
        axes: Optional[Axes] = None,
        jit: bool = True,
        **kwargs,
    ):
        r"""
        Args:
            input_shape: Shape of input array.
            input_dtype: `dtype` for input argument. Defaults to
                :attr:`~numpy.float32`.
            axes: Axis or axes over which to apply sum operator. If not
                specified, or ``None``, sums are evaluated along all axes.
            jit: If ``True``, jit the evaluation, adjoint, and gram
                functions of the :class:`LinearOperator`.
        """
        # The per-axis operators are created without jit; the jit flag is
        # applied to the stacked operator as a whole.
        self.axes, ops = linop_over_axes(
            SingleAxisFiniteSum,
            input_shape,
            axes=axes,
            input_dtype=input_dtype,
            jit=False,
        )

        super().__init__(
            ops,  # type: ignore
            jit=jit,
            **kwargs,
        )


class SingleAxisHaarTransform(VerticalStack):
    """Single-level shift-invariant Haar transform along a single axis.

    Compute one level of a shift-invariant Haar transform along the
    specified axis, returning the results in a :class:`jax.Array`
    consisting of sum and difference components (corresponding to lowpass
    and highpass filtered components respectively) stacked on axis 0. See
    :class:`SingleAxisFiniteSum` for boundary handling details.
    """

    def __init__(
        self,
        input_shape: Shape,
        input_dtype: DType = snp.float32,
        axis: int = -1,
        jit: bool = True,
        **kwargs,
    ):
        r"""
        Args:
            input_shape: Shape of input array.
            input_dtype: `dtype` for input argument. Defaults to
                :attr:`~numpy.float32`.
            axis: Axis over which to apply Haar transform.
            jit: If ``True``, jit the evaluation, adjoint, and gram
                functions of the :class:`LinearOperator`.
        """
        self.axis = axis
        # Lowpass (scaled two-point sum) and highpass (scaled circular
        # two-point difference) components, each scaled by 1/sqrt(2).
        self.HaarL = (1.0 / snp.sqrt(2.0)) * SingleAxisFiniteSum(
            input_shape, input_dtype=input_dtype, axis=axis, jit=jit, **kwargs
        )
        self.HaarH = (1.0 / snp.sqrt(2.0)) * SingleAxisFiniteDifference(
            input_shape, input_dtype=input_dtype, axis=axis, circular=True, jit=jit, **kwargs
        )
        super().__init__(
            (self.HaarL, self.HaarH),
            jit=jit,
            **kwargs,
        )


class HaarTransform(VerticalStack):
    """Single-level shift-invariant Haar transform.

    Compute one level of a shift-invariant Haar transform along the
    specified axes, returning the results in a :class:`jax.Array`. See
    :class:`SingleAxisHaarTransform` for details of the transform along
    each axis.
    """

    def __init__(
        self,
        input_shape: Shape,
        input_dtype: DType = snp.float32,
        axes: Optional[Axes] = None,
        jit: bool = True,
        **kwargs,
    ):
        r"""
        Args:
            input_shape: Shape of input array.
            input_dtype: `dtype` for input argument. Defaults to
                :attr:`~numpy.float32`.
            axes: Axis or axes over which to apply Haar transform. If not
                specified, or ``None``, the transform is evaluated along
                all axes.
            jit: If ``True``, jit the evaluation, adjoint, and gram
                functions of the :class:`LinearOperator`.
        """
        # The per-axis transforms are created without jit; the jit flag is
        # applied to the stacked operator as a whole.
        self.axes, ops = linop_over_axes(
            SingleAxisHaarTransform,
            input_shape,
            axes=axes,
            input_dtype=input_dtype,
            jit=False,
        )

        super().__init__(
            ops,  # type: ignore
            jit=jit,
            **kwargs,
        )