Source code for scico.operator._stack

# -*- coding: utf-8 -*-
# Copyright (C) 2023 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.

"""Stack of operators classes."""

from __future__ import annotations

from typing import Optional, Sequence, Tuple, Union

import numpy as np

from typing_extensions import TypeGuard

import scico.numpy as snp
from scico.numpy import Array, BlockArray
from scico.numpy.util import is_nested
from scico.typing import BlockShape, Shape

from ._operator import Operator


def collapse_shapes(
    shapes: Sequence[Union[Shape, BlockShape]], allow_collapse=True
) -> Tuple[Union[Shape, BlockShape], bool]:
    """Compute the collapsed representation of a sequence of shapes.

    Args:
        shapes: Sequence of shapes to be combined.
        allow_collapse: If falsy, never collapse, even when every entry
            of `shapes` is identical.

    Returns:
        A pair ``(shape, collapsed)``. When all entries of `shapes` are
        identical and `allow_collapse` is truthy, `shape` is
        ``(len(shapes), *shapes[0])`` and `collapsed` is ``True``.
        Otherwise `shape` is `shapes` itself and `collapsed` is
        ``False``.

    Raises:
        ValueError: If the shapes are not collapsed and at least one of
            them is nested, so that combining them into a
            :class:`BlockArray` would produce a twice-nested
            :class:`BlockArray`.
    """
    collapsible = is_collapsible(shapes)
    if allow_collapse and collapsible:
        # Prepend the number of shapes as a new leading (stacking) axis.
        first_shape = shapes[0]
        return (len(shapes), *first_shape), True

    if not is_blockable(shapes):
        raise ValueError(
            "Combining these shapes would result in a twice-nested BlockArray, which is not supported."
        )
    return shapes, False


def is_collapsible(shapes: Sequence[Union[Shape, BlockShape]]) -> bool:
    """Determine whether a sequence of shapes can be collapsed.

    Return ``True`` if every shape in `shapes` is equal to the first
    one, i.e., the corresponding arrays could be stacked along a new
    leading axis. An empty sequence is reported as collapsible."""
    remaining = iter(shapes)
    reference = next(remaining, None)
    # The first entry trivially matches itself; compare only the rest.
    return all(candidate == reference for candidate in remaining)


def is_blockable(shapes: Sequence[Union[Shape, BlockShape]]) -> TypeGuard[Union[Shape, BlockShape]]:
    """Determine whether a sequence of shapes could be a :class:`BlockArray` shape.

    The shapes can be combined into a :class:`BlockArray` only if none
    of them is itself nested; return ``True`` in that case and
    ``False`` as soon as a nested shape is encountered."""
    for shape in shapes:
        if is_nested(shape):
            return False
    return True


class VerticalStack(Operator):
    r"""A vertical stack of operators.

    Given operators :math:`A_1, A_2, \dots, A_N`, create the operator
    :math:`H` such that

    .. math::
       H(\mb{x})
       =
       \begin{pmatrix}
            A_1(\mb{x}) \\
            A_2(\mb{x}) \\
            \vdots \\
            A_N(\mb{x}) \\
       \end{pmatrix} \;.
    """

    def __init__(
        self,
        ops: Sequence[Operator],
        collapse_output: Optional[bool] = True,
        jit: bool = True,
        **kwargs,
    ):
        r"""
        Args:
            ops: Operators to stack. They must share the same input
                shape, input dtype, and output dtype, and none may have
                a nested output shape.
            collapse_output: If ``True`` and the output would be a
                :class:`BlockArray` with shape ((m, n, ...), (m, n, ...),
                ...), the output is instead a :class:`jax.Array` with
                shape (S, m, n, ...) where S is the length of `ops`.
            jit: See `jit` in :class:`Operator`.

        Raises:
            TypeError: If `ops` is not a list or tuple.
            ValueError: If `ops` is empty or its operators cannot be
                stacked (see :meth:`check_if_stackable`).
        """
        VerticalStack.check_if_stackable(ops)

        self.ops = ops
        self.collapse_output = collapse_output

        output_shapes = tuple(op.output_shape for op in ops)
        # Collapsing into a single array requires identical output shapes.
        self.output_collapsible = is_collapsible(output_shapes)

        if self.output_collapsible and self.collapse_output:
            output_shape = (len(ops),) + output_shapes[0]  # collapse to jax array
        else:
            output_shape = output_shapes

        super().__init__(
            input_shape=ops[0].input_shape,
            output_shape=output_shape,  # type: ignore
            input_dtype=ops[0].input_dtype,
            output_dtype=ops[0].output_dtype,
            jit=jit,
            **kwargs,
        )

    @staticmethod
    def check_if_stackable(ops: Sequence[Operator]):
        """Check that input ops are suitable for stack creation.

        Args:
            ops: Operators to be vertically stacked.

        Raises:
            TypeError: If `ops` is not a list or tuple.
            ValueError: If `ops` is empty, if the operators differ in
                input shape, input dtype, or output dtype, or if any
                operator has a nested output shape.
        """
        if not isinstance(ops, (list, tuple)):
            raise TypeError("Expected a list of Operator.")
        # The constructor reads ops[0]; report an empty sequence clearly
        # instead of failing later with an IndexError.
        if not ops:
            raise ValueError("Expected at least one Operator.")

        input_shapes = [op.shape[1] for op in ops]
        if not all(input_shapes[0] == s for s in input_shapes):
            raise ValueError(
                "Expected all Operators to have the same input shapes, "
                f"but got {input_shapes}."
            )

        input_dtypes = [op.input_dtype for op in ops]
        if not all(input_dtypes[0] == s for s in input_dtypes):
            raise ValueError(
                "Expected all Operators to have the same input dtype, "
                f"but got {input_dtypes}."
            )

        if any(is_nested(op.shape[0]) for op in ops):
            raise ValueError("Cannot stack Operators with nested output shapes.")

        output_dtypes = [op.output_dtype for op in ops]
        # Builtin all() is required here: np.all wraps a generator as a
        # single truthy object without iterating it, so it never fails.
        if not all(output_dtypes[0] == s for s in output_dtypes):
            raise ValueError("Expected all Operators to have the same output dtype.")

    def _eval(self, x: Array) -> Union[Array, BlockArray]:
        # Every stacked operator is applied to the same input.
        outputs = [op(x) for op in self.ops]
        if self.output_collapsible and self.collapse_output:
            return snp.stack(outputs)
        return BlockArray(outputs)


class DiagonalStack(Operator):
    r"""A diagonal stack of operators.

    Given operators :math:`A_1, A_2, \dots, A_N`, create the operator
    :math:`H` such that

    .. math::
       H \left(
       \begin{pmatrix}
            \mb{x}_1 \\
            \mb{x}_2 \\
            \vdots \\
            \mb{x}_N \\
       \end{pmatrix} \right)
       =
       \begin{pmatrix}
            A_1(\mb{x}_1) \\
            A_2(\mb{x}_2) \\
            \vdots \\
            A_N(\mb{x}_N) \\
       \end{pmatrix} \;.

    By default, if the inputs :math:`\mb{x}_1, \mb{x}_2, \dots,
    \mb{x}_N` all have the same (possibly nested) shape, `S`, this
    operator will work on the stack, i.e., have an input shape of
    `(N, *S)`. If the inputs have distinct shapes, `S1`, `S2`, ...,
    `SN`, this operator will work on the block concatenation, i.e.,
    have an input shape of `(S1, S2, ..., SN)`. The same holds for the
    output shape.
    """

    def __init__(
        self,
        ops: Sequence[Operator],
        collapse_input: Optional[bool] = True,
        collapse_output: Optional[bool] = True,
        jit: bool = True,
        **kwargs,
    ):
        """
        Args:
            ops: Operators to stack.
            collapse_input: If ``True``, inputs are expected to be
                stacked along the first dimension when possible.
            collapse_output: If ``True``, the output will be stacked
                along the first dimension when possible.
            jit: See `jit` in :class:`Operator`.

        Raises:
            TypeError: If `ops` is not a list or tuple.
            ValueError: If `ops` is empty, its operators cannot be
                stacked (see :meth:`check_if_stackable`), or combining
                their shapes would give a twice-nested
                :class:`BlockArray`.
        """
        DiagonalStack.check_if_stackable(ops)

        self.ops = ops

        # collapse_shapes returns whether collapsing actually happened,
        # so these attributes hold the effective (not requested) setting.
        input_shape, self.collapse_input = collapse_shapes(
            tuple(op.input_shape for op in ops),
            collapse_input,
        )
        output_shape, self.collapse_output = collapse_shapes(
            tuple(op.output_shape for op in ops),
            collapse_output,
        )

        super().__init__(
            input_shape=input_shape,
            output_shape=output_shape,
            input_dtype=ops[0].input_dtype,
            output_dtype=ops[0].output_dtype,
            jit=jit,
            **kwargs,
        )

    @staticmethod
    def check_if_stackable(ops: Sequence[Operator]):
        """Check that input ops are suitable for stack creation.

        Args:
            ops: Operators to be diagonally stacked.

        Raises:
            TypeError: If `ops` is not a list or tuple.
            ValueError: If `ops` is empty, if any operator has a nested
                output shape, or if the operators differ in output
                dtype.
        """
        if not isinstance(ops, (list, tuple)):
            raise TypeError("Expected a list of Operator.")
        # The constructor reads ops[0]; report an empty sequence clearly
        # instead of failing later with an IndexError.
        if not ops:
            raise ValueError("Expected at least one Operator.")

        if any(is_nested(op.shape[0]) for op in ops):
            raise ValueError("Cannot stack Operators with nested output shapes.")

        output_dtypes = [op.output_dtype for op in ops]
        # Builtin all() is required here: np.all wraps a generator as a
        # single truthy object without iterating it, so it never fails.
        if not all(output_dtypes[0] == s for s in output_dtypes):
            raise ValueError("Expected all Operators to have the same output dtype.")

    def _eval(self, x: Union[Array, BlockArray]) -> Union[Array, BlockArray]:
        # x is iterated in lockstep with self.ops: the n-th leading-axis
        # slice (collapsed input) or n-th block (BlockArray input) is fed
        # to the n-th operator.
        result = tuple(op(x_n) for op, x_n in zip(self.ops, x))
        if self.collapse_output:
            return snp.stack(result)
        return snp.blockarray(result)