# -*- coding: utf-8 -*-# Copyright (C) 2021-2023 by SCICO Developers# All rights reserved. BSD 3-clause License.# This file is part of the SCICO package. Details of the copyright and# user license can be found in the 'LICENSE' file distributed with the# package."""Discrete Fourier transform linear operator class."""# Needed to annotate a class method that returns the encapsulating class;# see https://www.python.org/dev/peps/pep-0563/from__future__importannotationsfromtypingimportOptional,Sequenceimportnumpyasnpimportscico.numpyassnpfromscico.typingimportShapefrom._linopimportLinearOperatorclassDFT(LinearOperator):r"""Multi-dimensional discrete Fourier transform."""def__init__(self,input_shape:Shape,axes:Optional[Sequence]=None,axes_shape:Optional[Shape]=None,norm:Optional[str]=None,jit:bool=True,**kwargs,):r""" Args: input_shape: Shape of input array. axes: Axes over which to compute the DFT. If ``None``, the DFT is computed over all axes. axes_shape: Output shape on the subset of array axes selected by `axes`. This parameter has the same behavior as the `s` parameter of :func:`numpy.fft.fftn`. norm: DFT normalization mode. See the `norm` parameter of :func:`numpy.fft.fftn`. jit: If ``True``, jit the evaluation, adjoint, and gram functions of the LinearOperator. """ifaxesisnotNoneandaxes_shapeisnotNoneandlen(axes)!=len(axes_shape):raiseValueError(f"len(axes)={len(axes)} does not equal len(axes_shape)={len(axes_shape)}.")ifaxes_shapeisnotNone:ifaxesisNone:axes=tuple(range(len(input_shape)-len(axes_shape),len(input_shape)))tmp_output_shape=list(input_shape)fori,sinzip(axes,axes_shape):tmp_output_shape[i]=soutput_shape=tuple(tmp_output_shape)else:output_shape=input_shapeifaxesisNoneoraxes_shapeisNone:self.inv_axes_shape=Noneelse:self.inv_axes_shape=[input_shape[i]foriinaxes]self.axes=axesself.axes_shape=axes_shapeself.norm=norm# To satisfy mypy -- DFT shapes must be tuples, not list of tuple# These get set inside of super().__init__ call, but we want to have# more restrictive type than the general LinearOperatorself.input_shape:Shapeself.output_shape:Shapesuper().__init__(input_shape=input_shape,output_shape=output_shape,input_dtype=np.complex64,output_dtype=np.complex64,jit=jit,**kwargs,)def_eval(self,x:snp.Array)->snp.Array:returnsnp.fft.fftn(x,s=self.axes_shape,axes=self.axes,norm=self.norm)
[docs]definv(self,z:snp.Array)->snp.Array:"""Compute the inverse of this LinearOperator. Compute the inverse of this LinearOperator applied to `z`. Args: z: Input array to inverse DFT. """returnsnp.fft.ifftn(z,s=self.inv_axes_shape,axes=self.axes,norm=self.norm)