# Source code for scico.linop._diag

# Copyright (C) 2020-2024 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.

"""Miscellaneous linear operator definitions."""


# Needed to annotate a class method that returns the encapsulating class;
# see https://www.python.org/dev/peps/pep-0563/
from __future__ import annotations

from typing import Optional, Union

import scico.numpy as snp
from scico.numpy import Array, BlockArray
from scico.numpy.util import broadcast_nested_shapes, is_nested
from scico.operator._operator import _wrap_mul_div_scalar
from scico.typing import BlockShape, DType, Shape

from ._linop import LinearOperator, _wrap_add_sub

# Names exported as the public interface of this module (the three
# operator classes defined below).
__all__ = ["Diagonal", "Identity", "ScaledIdentity"]


class Diagonal(LinearOperator):
    """Diagonal linear operator.

    Evaluation of the operator is elementwise multiplication of the
    input by the array of diagonal elements. The input shape may differ
    from the shape of that array provided the two are
    broadcast-compatible.
    """

    def __init__(
        self,
        diagonal: Union[Array, BlockArray],
        input_shape: Optional[Union[Shape, BlockShape]] = None,
        input_dtype: Optional[DType] = None,
        **kwargs,
    ):
        r"""
        Args:
            diagonal: Diagonal elements of this :class:`LinearOperator`.
            input_shape: Shape of input array. By default, equal to
               `diagonal.shape`, but may also be set to a shape that is
               broadcast-compatible with `diagonal.shape`.
            input_dtype: `dtype` of input argument. The default,
               ``None``, means `diagonal.dtype`.
            **kwargs: Additional keyword arguments, passed unchanged to
               the :class:`LinearOperator` initializer.

        Raises:
            ValueError: If `diagonal` is a :class:`BlockArray` and
               `input_shape` is not nested, or vice versa.
        """
        self._diagonal = diagonal

        if input_shape is None:
            input_shape = self._diagonal.shape

        if input_dtype is None:
            input_dtype = self._diagonal.dtype

        # The output shape is the broadcast of the input shape against
        # the shape of the diagonal; block and non-block cases need
        # different broadcasting helpers and may not be mixed.
        if isinstance(diagonal, BlockArray) and is_nested(input_shape):
            output_shape = broadcast_nested_shapes(input_shape, self._diagonal.shape)
        elif not isinstance(diagonal, BlockArray) and not is_nested(input_shape):
            output_shape = snp.broadcast_shapes(input_shape, self._diagonal.shape)
        elif isinstance(diagonal, BlockArray):
            raise ValueError("Parameter diagonal was a BlockArray but input_shape was not nested.")
        else:
            raise ValueError("Parameter diagonal was not a BlockArray but input_shape was nested.")

        super().__init__(
            input_shape=input_shape,
            input_dtype=input_dtype,
            output_shape=output_shape,
            output_dtype=input_dtype,
            **kwargs,
        )

    def _eval(self, x: Union[Array, BlockArray]) -> Union[Array, BlockArray]:
        """Multiply `x` elementwise by the diagonal."""
        return self._diagonal * x

    def _with_diagonal(self, diagonal: Union[Array, BlockArray]) -> Diagonal:
        """Construct a :class:`Diagonal` with the input shape of this operator.

        Operators derived from this one (conjugate, sum, scalar
        multiple, etc.) are built via this helper so that an
        `input_shape` that differs from `diagonal.shape` is retained
        rather than being reset to the shape of the new diagonal.

        Args:
            diagonal: Diagonal elements of the new operator.

        Returns:
            New :class:`Diagonal` with the same `input_shape` as this
            operator.
        """
        # The input dtype is intentionally left at its default (the
        # dtype of the new diagonal) since operations such as
        # multiplication by a complex scalar can change the dtype.
        return Diagonal(diagonal=diagonal, input_shape=self.input_shape)

    @property
    def diagonal(self) -> Union[Array, BlockArray]:
        """Return an array representing the diagonal component."""
        return self._diagonal

    @property
    def T(self) -> Diagonal:
        """Transpose of this :class:`Diagonal`."""
        return self

    def conj(self) -> Diagonal:
        """Complex conjugate of this :class:`Diagonal`."""
        return self._with_diagonal(self.diagonal.conj())

    @property
    def H(self) -> Diagonal:
        """Hermitian transpose of this :class:`Diagonal`."""
        return self.conj()

    @property
    def gram_op(self) -> Diagonal:
        """Gram operator of this :class:`Diagonal`.

        Return a new :class:`Diagonal` :code:`G` such that
        :code:`G(x) = A.adj(A(x)))`.
        """
        return self._with_diagonal(self.diagonal.conj() * self.diagonal)

    @_wrap_add_sub
    def __add__(self, other):
        """Sum of two :class:`Diagonal` operators with equal diagonal shapes."""
        if self.diagonal.shape == other.diagonal.shape:
            return self._with_diagonal(self.diagonal + other.diagonal)
        raise ValueError(f"Incompatible shapes: {self.shape} != {other.shape}.")

    @_wrap_add_sub
    def __sub__(self, other):
        """Difference of two :class:`Diagonal` operators with equal diagonal shapes."""
        if self.diagonal.shape == other.diagonal.shape:
            return self._with_diagonal(self.diagonal - other.diagonal)
        raise ValueError(f"Incompatible shapes: {self.shape} != {other.shape}.")

    @_wrap_mul_div_scalar
    def __mul__(self, scalar):
        """Product of this :class:`Diagonal` and a scalar."""
        return self._with_diagonal(self.diagonal * scalar)

    @_wrap_mul_div_scalar
    def __truediv__(self, scalar):
        """Quotient of this :class:`Diagonal` and a scalar."""
        return self._with_diagonal(self.diagonal / scalar)

    def __matmul__(self, other):
        """Compose with another :class:`Diagonal` or apply to `other`.

        Raises:
            ValueError: If `other` is a :class:`Diagonal` with an
               operator shape that differs from that of this operator.
        """
        # self @ other
        if not isinstance(other, Diagonal):
            return self(other)
        if self.shape == other.shape:
            # The composition of two diagonal operators is diagonal,
            # with elementwise product of the diagonals.
            return self._with_diagonal(self.diagonal * other.diagonal)
        raise ValueError(f"Shapes {self.shape} and {other.shape} do not match.")

    def norm(self, ord=None):  # pylint: disable=W0622
        """Compute the matrix norm of the diagonal operator.

        Valid values of `ord` and the corresponding norm definition
        are those listed under "norm for matrices" in the
        :func:`scico.numpy.linalg.norm` documentation.

        Args:
            ord: Norm type. ``None`` is equivalent to ``"fro"``; values
               1 and 2 give the same result as ``inf`` (largest
               absolute diagonal entry), and values -1 and -2 give the
               same result as ``-inf`` (smallest absolute diagonal
               entry).

        Returns:
            Norm computed from the stored diagonal array.

        Raises:
            ValueError: If `ord` is not one of the supported values.
        """
        # NOTE(review): the norm is computed from the stored diagonal
        # array without broadcasting it to the input shape -- confirm
        # this is intended when input_shape differs from diagonal.shape.
        ordfunc = {
            "fro": snp.linalg.norm,
            "nuc": lambda x: snp.sum(snp.abs(x)),
            -snp.inf: lambda x: snp.abs(x).min(),
            snp.inf: lambda x: snp.abs(x).max(),
        }
        mord = ord
        if mord is None:
            mord = "fro"
        elif mord in (-1, -2):
            mord = -snp.inf
        elif mord in (1, 2):
            mord = snp.inf
        if mord not in ordfunc:
            raise ValueError(f"Invalid value {ord} for parameter ord.")
        return ordfunc[mord](self._diagonal)
class ScaledIdentity(Diagonal):
    """Scaled identity operator.

    The scaling is stored as a zero-dimensional array (one per block
    for a nested input shape) rather than as a full array of diagonal
    elements.
    """

    def __init__(
        self,
        scalar: float,
        input_shape: Union[Shape, BlockShape],
        input_dtype: DType = snp.float32,
        **kwargs,
    ):
        """
        Args:
            scalar: Scaling of the identity.
            input_shape: Shape of input array.
            input_dtype: `dtype` of input argument.
            **kwargs: Additional keyword arguments, passed unchanged to
               the :class:`Diagonal` initializer.
        """
        # One scalar-shaped entry per block for nested shapes, a single
        # scalar-shaped array otherwise.
        if is_nested(input_shape):
            diagonal = scalar * snp.ones(((),) * len(input_shape), dtype=input_dtype)
        else:
            diagonal = scalar * snp.ones((), dtype=input_dtype)
        super().__init__(
            diagonal=diagonal,
            input_shape=input_shape,
            input_dtype=input_dtype,
            **kwargs,
        )

    def _rescaled(self, scalar) -> ScaledIdentity:
        """Construct a :class:`ScaledIdentity` with a new scaling.

        Args:
            scalar: Scaling of the new operator.

        Returns:
            New :class:`ScaledIdentity` with the same input shape and
            input dtype as this operator.
        """
        return ScaledIdentity(
            scalar=scalar,
            input_shape=self.input_shape,
            input_dtype=self.input_dtype,
        )

    @property
    def diagonal(self) -> Union[Array, BlockArray]:
        """Return the diagonal expanded to the full input shape."""
        return self._diagonal * snp.ones(self.input_shape, dtype=self.input_dtype)

    def conj(self) -> ScaledIdentity:
        """Complex conjugate of this :class:`ScaledIdentity`."""
        return self._rescaled(self._diagonal.conj())

    @property
    def gram_op(self) -> ScaledIdentity:
        """Gram operator of this :class:`ScaledIdentity`."""
        return self._rescaled(self._diagonal * self._diagonal.conj())

    @_wrap_add_sub
    def __add__(self, other):
        """Sum of two :class:`ScaledIdentity` operators with equal input shapes."""
        if self.input_shape == other.input_shape:
            return self._rescaled(self._diagonal + other._diagonal)
        raise ValueError(f"Incompatible shapes: {self.shape} != {other.shape}.")

    @_wrap_add_sub
    def __sub__(self, other):
        """Difference of two :class:`ScaledIdentity` operators with equal input shapes."""
        if self.input_shape == other.input_shape:
            return self._rescaled(self._diagonal - other._diagonal)
        raise ValueError(f"Incompatible shapes: {self.shape} != {other.shape}.")

    @_wrap_mul_div_scalar
    def __mul__(self, scalar):
        """Product of this :class:`ScaledIdentity` and a scalar."""
        return self._rescaled(self._diagonal * scalar)

    @_wrap_mul_div_scalar
    def __truediv__(self, scalar):
        """Quotient of this :class:`ScaledIdentity` and a scalar."""
        return self._rescaled(self._diagonal / scalar)

    def __matmul__(self, other):
        """Compose with a :class:`Diagonal` or apply to `other`.

        Raises:
            ValueError: If `other` is a :class:`Diagonal` with an
               operator shape that differs from that of this operator.
        """
        # self @ other
        if not isinstance(other, Diagonal):
            return self(other)
        if self.shape != other.shape:
            raise ValueError(f"Shapes {self.shape} and {other.shape} do not match.")
        if isinstance(other, ScaledIdentity):
            return self._rescaled(self._diagonal * other._diagonal)
        # Pass the input shape explicitly so that the result has the
        # same operator shape as the operands even when the diagonal of
        # `other` is smaller than (but broadcast-compatible with) its
        # input shape.
        return Diagonal(
            diagonal=self._diagonal * other.diagonal,
            input_shape=self.input_shape,
        )

    def norm(self, ord=None):  # pylint: disable=W0622
        """Compute the matrix norm of the identity operator.

        Valid values of `ord` and the corresponding norm definition
        are those listed under "norm for matrices" in the
        :func:`scico.numpy.linalg.norm` documentation.

        Args:
            ord: Norm type. ``None`` is equivalent to ``"fro"``.

        Returns:
            Absolute value of the scaling, multiplied by the square
            root of the input size for the ``"fro"`` norm and by the
            input size for the ``"nuc"`` norm.

        Raises:
            ValueError: If `ord` is not one of the supported values.
        """
        N = self.input_size
        if ord is None or ord == "fro":
            return snp.abs(self._diagonal) * snp.sqrt(N)
        if ord == "nuc":
            return snp.abs(self._diagonal) * N
        if ord in (-snp.inf, -1, -2, 1, 2, snp.inf):
            return snp.abs(self._diagonal)
        raise ValueError(f"Invalid value {ord} for parameter ord.")
class Identity(ScaledIdentity):
    """Identity operator.

    A :class:`ScaledIdentity` with unit scaling, for which evaluation,
    conjugation, and composition are shortcut to return their operand
    unchanged.
    """

    def __init__(
        self,
        input_shape: Union[Shape, BlockShape],
        input_dtype: DType = snp.float32,
        **kwargs,
    ):
        """
        Args:
            input_shape: Shape of input array.
            input_dtype: `dtype` of input argument.
            **kwargs: Additional keyword arguments, passed unchanged to
               the :class:`ScaledIdentity` initializer.
        """
        unit_scaling = 1.0
        super().__init__(
            scalar=unit_scaling,
            input_shape=input_shape,
            input_dtype=input_dtype,
            **kwargs,
        )

    def _eval(self, x: Union[Array, BlockArray]) -> Union[Array, BlockArray]:
        """Return `x` itself; no multiplication is performed."""
        return x

    @property
    def diagonal(self) -> Union[Array, BlockArray]:
        """Return an array of ones with the input shape and dtype."""
        shape, dtype = self.input_shape, self.input_dtype
        return snp.ones(shape, dtype=dtype)

    def conj(self) -> Identity:
        """Complex conjugate of this :class:`Identity`."""
        return self

    @property
    def gram_op(self) -> Identity:
        """Gram operator of this :class:`Identity`."""
        return self

    def __matmul__(self, other):
        """Return `other` unchanged (no shape check is performed)."""
        # self @ other
        return other

    def __rmatmul__(self, x: Union[Array, BlockArray]) -> Union[Array, BlockArray]:
        """Return `x` unchanged."""
        # x @ self
        return x