Source code for scico.linop._dft

# -*- coding: utf-8 -*-
# Copyright (C) 2021-2023 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.

"""Discrete Fourier transform linear operator class."""


# Needed to annotate a class method that returns the encapsulating class;
# see https://www.python.org/dev/peps/pep-0563/
from __future__ import annotations

from typing import Optional, Sequence

import numpy as np

import scico.numpy as snp
from scico.typing import Shape

from ._linop import LinearOperator


class DFT(LinearOperator):
    r"""Multi-dimensional discrete Fourier transform."""

    def __init__(
        self,
        input_shape: Shape,
        axes: Optional[Sequence] = None,
        axes_shape: Optional[Shape] = None,
        norm: Optional[str] = None,
        jit: bool = True,
        **kwargs,
    ):
        r"""
        Args:
            input_shape: Shape of input array.
            axes: Axes over which to compute the DFT. If ``None``, the
                DFT is computed over all axes.
            axes_shape: Output shape on the subset of array axes selected
                by `axes`. This parameter has the same behavior as the
                `s` parameter of :func:`numpy.fft.fftn`.
            norm: DFT normalization mode. See the `norm` parameter of
                :func:`numpy.fft.fftn`.
            jit: If ``True``, jit the evaluation, adjoint, and gram
                functions of the LinearOperator.
        """
        if axes is not None and axes_shape is not None and len(axes) != len(axes_shape):
            raise ValueError(
                f"len(axes)={len(axes)} does not equal len(axes_shape)={len(axes_shape)}."
            )

        if axes_shape is not None:
            if axes is None:
                axes = tuple(range(len(input_shape) - len(axes_shape), len(input_shape)))
            tmp_output_shape = list(input_shape)
            for i, s in zip(axes, axes_shape):
                tmp_output_shape[i] = s
            output_shape = tuple(tmp_output_shape)
        else:
            output_shape = input_shape

        if axes is None or axes_shape is None:
            self.inv_axes_shape = None
        else:
            self.inv_axes_shape = [input_shape[i] for i in axes]

        self.axes = axes
        self.axes_shape = axes_shape
        self.norm = norm

        # To satisfy mypy -- DFT shapes must be tuples, not list of tuple
        # These get set inside of super().__init__ call, but we want to have
        # more restrictive type than the general LinearOperator
        self.input_shape: Shape
        self.output_shape: Shape

        super().__init__(
            input_shape=input_shape,
            output_shape=output_shape,
            input_dtype=np.complex64,
            output_dtype=np.complex64,
            jit=jit,
            **kwargs,
        )

    def _eval(self, x: snp.Array) -> snp.Array:
        return snp.fft.fftn(x, s=self.axes_shape, axes=self.axes, norm=self.norm)

[docs] def inv(self, z: snp.Array) -> snp.Array: """Compute the inverse of this LinearOperator. Compute the inverse of this LinearOperator applied to `z`. Args: z: Input array to inverse DFT. """ return snp.fft.ifftn(z, s=self.inv_axes_shape, axes=self.axes, norm=self.norm)