# Source code for scico.operator.biconvolve

# -*- coding: utf-8 -*-
# Copyright (C) 2020-2023 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.

"""Biconvolution operator."""


# Needed to annotate a class method that returns the encapsulating class;
# see https://www.python.org/dev/peps/pep-0563/
from __future__ import annotations

from typing import Tuple, cast

import numpy as np

from jax.scipy.signal import convolve

import scico.linop
from scico.numpy import Array, BlockArray
from scico.numpy.util import is_nested
from scico.typing import DType, Shape

from ._operator import Operator


class BiConvolve(Operator):
    """Biconvolution operator.

    A :class:`.BiConvolve` operator accepts a :class:`.BlockArray` input
    with two blocks of equal ndims, and convolves the first block with
    the second.

    If `A` is a :class:`.BiConvolve` operator, then
    `A(snp.blockarray([x, h]))` equals `jax.scipy.signal.convolve(x, h)`.

    """

    def __init__(
        self,
        input_shape: Tuple[Shape, Shape],
        input_dtype: DType = np.float32,
        mode: str = "full",
        jit: bool = True,
    ):
        r"""
        For more details on `mode`, see :func:`jax.scipy.signal.convolve`.

        Args:
            input_shape: Shape of input :class:`.BlockArray`. Must
                correspond to a :class:`.BlockArray` with two blocks of
                equal ndims.
            input_dtype: `dtype` for input argument. Defaults to
                :attr:`~numpy.float32`.
            mode: A string indicating the size of the output. One of
                "full", "valid", "same". Defaults to "full".
            jit: If ``True``, jit the evaluation of this
                :class:`.Operator`.

        Raises:
            ValueError: If `input_shape` is not a nested (block) shape,
                does not have exactly two blocks, or has blocks with
                differing numbers of dimensions; or if `mode` is not one
                of "full", "valid", "same".
        """
        if not is_nested(input_shape):
            # f-prefix required so the offending shape is actually reported.
            raise ValueError(f"A BlockShape is expected; got {input_shape}.")
        if len(input_shape) != 2:
            raise ValueError(f"input_shape must have two blocks; got {len(input_shape)}.")
        if len(input_shape[0]) != len(input_shape[1]):
            raise ValueError(
                f"Both input blocks must have same number of dimensions; got "
                f"{len(input_shape[0]), len(input_shape[1])}."
            )

        if mode not in ["full", "valid", "same"]:
            raise ValueError(f"Invalid mode={mode}; must be one of 'full', 'valid', 'same'.")

        # Stored before the base-class initializer runs, since _eval and
        # freeze both read self.mode.
        self.mode = mode

        super().__init__(input_shape=input_shape, input_dtype=input_dtype, jit=jit)

    def _eval(self, x: BlockArray) -> Array:
        """Convolve block 0 of `x` with block 1 using the configured mode."""
        return convolve(x[0], x[1], mode=self.mode)

    def freeze(self, argnum: int, val: Array) -> scico.linop.LinearOperator:
        """Freeze the `argnum` parameter.

        Return a new :class:`.LinearOperator` with block argument
        `argnum` fixed to value `val`.

        If `argnum == 0`, a :class:`.ConvolveByX` object is returned.
        If `argnum == 1`, a :class:`.Convolve` object is returned.

        Args:
            argnum: Index of block to freeze. Must be 0 or 1.
            val: Value to fix the `argnum`-th input to.

        Returns:
            A :class:`.LinearOperator` acting on the remaining (unfrozen)
            block, with the same output shape and `mode` as this operator.

        Raises:
            ValueError: If `argnum` is not 0 or 1.
        """
        if argnum == 0:
            # Block 0 is fixed, so the resulting operator takes block 1
            # as its input.
            return scico.linop.ConvolveByX(
                x=val,
                input_shape=cast(Shape, self.input_shape[1]),
                input_dtype=self.input_dtype,
                output_shape=self.output_shape,
                mode=self.mode,
            )
        if argnum == 1:
            # Block 1 (the kernel) is fixed, so the resulting operator
            # takes block 0 as its input.
            return scico.linop.Convolve(
                h=val,
                input_shape=cast(Shape, self.input_shape[0]),
                input_dtype=self.input_dtype,
                output_shape=self.output_shape,
                mode=self.mode,
            )
        raise ValueError(f"Parameter argnum must be 0 or 1; got {argnum}.")