# -*- coding: utf-8 -*-# Copyright (C) 2020-2023 by SCICO Developers# All rights reserved. BSD 3-clause License.# This file is part of the SCICO package. Details of the copyright and# user license can be found in the 'LICENSE' file distributed with the# package."""Biconvolution operator."""# Needed to annotate a class method that returns the encapsulating class;# see https://www.python.org/dev/peps/pep-0563/from__future__importannotationsfromtypingimportTuple,castimportnumpyasnpfromjax.scipy.signalimportconvolveimportscico.linopfromscico.numpyimportArray,BlockArrayfromscico.numpy.utilimportis_nestedfromscico.typingimportDType,Shapefrom._operatorimportOperatorclassBiConvolve(Operator):"""Biconvolution operator. A :class:`.BiConvolve` operator accepts a :class:`.BlockArray` input with two blocks of equal ndims, and convolves the first block with the second. If `A` is a :class:`.BiConvolve` operator, then `A(snp.blockarray([x, h]))` equals `jax.scipy.signal.convolve(x, h)`. """def__init__(self,input_shape:Tuple[Shape,Shape],input_dtype:DType=np.float32,mode:str="full",jit:bool=True,):r""" Args: input_shape: Shape of input :class:`.BlockArray`. Must correspond to a :class:`.`BlockArray` with two blocks of equal ndims. input_dtype: `dtype` for input argument. Defaults to :attr:`~numpy.float32`. mode: A string indicating the size of the output. One of "full", "valid", "same". Defaults to "full". jit: If ``True``, jit the evaluation of this :class:`.Operator`. For more details on `mode`, see :func:`jax.scipy.signal.convolve`. """ifnotis_nested(input_shape):raiseValueError("A BlockShape is expected; got {input_shape}.")iflen(input_shape)!=2:raiseValueError(f"input_shape must have two blocks; got {len(input_shape)}.")iflen(input_shape[0])!=len(input_shape[1]):raiseValueError(f"Both input blocks must have same number of dimensions; got "f"{len(input_shape[0]),len(input_shape[1])}.")ifmodenotin["full","valid","same"]:raiseValueError(f"Invalid mode={mode}; must be one of 'full', 'valid', 'same'.")self.mode=modesuper().__init__(input_shape=input_shape,input_dtype=input_dtype,jit=jit)def_eval(self,x:BlockArray)->Array:returnconvolve(x[0],x[1],mode=self.mode)
[docs]deffreeze(self,argnum:int,val:Array)->scico.linop.LinearOperator:"""Freeze the `argnum` parameter. Return a new :class:`.LinearOperator` with block argument `argnum` fixed to value `val`. If `argnum == 0`, a :class:`.ConvolveByX` object is returned. If `argnum == 1`, a :class:`.Convolve` object is returned. Args: argnum: Index of block to freeze. Must be 0 or 1. val: Value to fix the `argnum`-th input to. """ifargnum==0:returnscico.linop.ConvolveByX(x=val,input_shape=cast(Shape,self.input_shape[1]),input_dtype=self.input_dtype,output_shape=self.output_shape,mode=self.mode,)ifargnum==1:returnscico.linop.Convolve(h=val,input_shape=cast(Shape,self.input_shape[0]),input_dtype=self.input_dtype,output_shape=self.output_shape,mode=self.mode,)raiseValueError(f"Parameter argnum must be 0 or 1; got {argnum}.")