# -*- coding: utf-8 -*-
# Copyright (C) 2022-2026 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright
# and user license can be found in the 'LICENSE.txt' file distributed
# with the package.
"""Utility functions for working with jax arrays and BlockArrays."""
from __future__ import annotations
import collections
from math import prod
from typing import Any, List, NamedTuple, Optional, Sequence, Tuple, Union
import numpy as np
import jax
from typing_extensions import TypeGuard
import scico.numpy as snp
from scico.typing import ArrayIndex, Axes, AxisIndex, BlockShape, DType, Shape
[docs]
def transpose_ntpl_of_list(ntpl: NamedTuple) -> List[NamedTuple]:
    """Turn a namedtuple of lists/arrays into a list of namedtuples.

    The number of output entries is taken from the first field: its
    length if it is a list, otherwise the size of its leading axis.

    Args:
        ntpl: Named tuple whose fields are lists or arrays.

    Returns:
        List of namedtuples of the same class as `ntpl`, one per entry.
    """
    leading = ntpl[0]
    if isinstance(leading, list):
        num_rows = len(leading)
    else:
        num_rows = leading.shape[0]
    make_row = type(ntpl)
    rows = []
    for row in range(num_rows):
        # Gather entry `row` of every field, in field order.
        rows.append(make_row(*(column[row] for column in ntpl)))
    return rows
[docs]
def transpose_list_of_ntpl(ntlist: List[NamedTuple]) -> NamedTuple:
    """Turn a list of namedtuples into a namedtuple of lists.

    The class and field count of the result are taken from the first
    entry of `ntlist`.

    Args:
        ntlist: List of namedtuple objects to be transposed.

    Returns:
        Named tuple of lists.
    """
    template = ntlist[0]
    columns = []
    for field in range(len(template)):
        # Collect this field from every entry, preserving list order.
        columns.append([entry[field] for entry in ntlist])
    return type(template)(*columns)  # type: ignore
[docs]
def namedtuple_to_array(ntpl: NamedTuple) -> snp.Array:
    """Pack a namedtuple into an array.

    Convert a :func:`collections.namedtuple` object to a
    :class:`numpy.ndarray` object that can be saved using
    :func:`numpy.savez`. The array wraps a single dict recording the
    class name, the field names, and a field-name-to-value mapping.

    Args:
        ntpl: Named tuple object to be converted to ndarray.

    Returns:
        Array representation of input named tuple.
    """
    field_names = ntpl._fields
    record = {
        "name": type(ntpl).__name__,
        "fields": field_names,
        "data": dict(zip(field_names, ntpl)),
    }
    return np.asarray(record)
[docs]
def array_to_namedtuple(array: snp.Array) -> NamedTuple:
    """Rebuild a namedtuple from its array representation.

    Convert a :class:`numpy.ndarray` object constructed by
    :func:`namedtuple_to_array` back to the original
    :func:`collections.namedtuple` representation.

    Args:
        array: Array representation of named tuple constructed by
            :func:`namedtuple_to_array`.

    Returns:
        Named tuple object with the same name and fields as the original
        named tuple object provided to :func:`namedtuple_to_array`.
    """
    # The array wraps a single dict with "name", "fields" and "data" keys.
    record = array.item()
    ntpl_cls = collections.namedtuple(record["name"], record["fields"])  # type: ignore
    return ntpl_cls(**record["data"])
[docs]
def normalize_axes(
    axes: Optional[Axes],
    shape: Optional[Shape] = None,
    default: Optional[List[int]] = None,
    sort: bool = False,
) -> Sequence[int]:
    """Normalize `axes` to a sequence and optionally ensure correctness.

    Normalize `axes` to a tuple or list and (optionally) ensure that
    entries refer to axes that exist in `shape`.

    Args:
        axes: User specification of one or more axes: int, list, tuple,
            or ``None``. Negative values count from the last to the first
            axis.
        shape: The shape of the array of which axes are being specified.
            If not ``None``, negative entries of `axes` are converted to
            the equivalent non-negative indices and all entries are
            checked to make sure that they refer to axes that exist in
            `shape`.
        default: Default value to return if `axes` is ``None``. By
            default, `tuple(range(len(shape)))`.
        sort: If ``True``, sort the returned axis indices.

    Returns:
        Tuple or list of axes (never an int, never ``None``). The output
        will only be a list if the input is a list or if the input is
        ``None`` and `default` is a list. An empty `axes` sequence is
        returned unchanged.

    Raises:
        ValueError: If `axes` is ``None`` and neither `default` nor
            `shape` is specified, if `axes` is not an int, list, or
            tuple, if `shape` is specified and an entry of `axes` lies
            outside the range `-len(shape)` to `len(shape) - 1`, or if
            `axes` contains duplicate entries.
    """
    if axes is None:
        if default is not None:
            axes = default
        elif shape is not None:
            axes = tuple(range(len(shape)))
        else:
            raise ValueError(
                "Argument 'axes' cannot be None without a default or shape specified."
            )
    elif isinstance(axes, int):
        axes = (axes,)
    elif not isinstance(axes, (list, tuple)):
        raise ValueError(f"Could not understand argument 'axes' {axes} as a list of axes.")

    if shape is not None:
        ndim = len(shape)
        # Reject indices that would remain negative after wrapping;
        # otherwise they would pass the upper-bound check below and be
        # returned as (invalid) negative axes.
        if any(a < -ndim for a in axes):
            raise ValueError(
                f"Invalid axes {axes} specified; each axis must be at least "
                f"`-len(shape)`={-ndim}."
            )
        if any(a < 0 for a in axes):
            axes = tuple(ndim + a if a < 0 else a for a in axes)
        if any(a >= ndim for a in axes):
            raise ValueError(
                f"Invalid axes {axes} specified; each axis must be less than `len(shape)`={ndim}."
            )

    # Duplicates are detected after wrapping so that, e.g., (-1, 1) is
    # rejected for a two-dimensional shape.
    if len(set(axes)) != len(axes):
        raise ValueError(f"Duplicate value in axes {axes}; each axis must be unique.")
    if sort:
        axes = tuple(sorted(axes))
    return axes
[docs]
def slice_length(length: int, idx: AxisIndex) -> Optional[int]:
    """Determine the length of an array axis after indexing.

    Determine the length of an array axis after slicing. An exception is
    raised if the indexing expression is an integer that is out of bounds
    for the specified axis length. A value of ``None`` is returned for
    valid integer indexing expressions as an indication that the
    corresponding axis shape is an empty tuple; this value should be
    converted to a unit integer if the axis size is required.

    Slices with positive or negative steps are both supported.

    Args:
        length: Length of axis being sliced.
        idx: Indexing/slice to be applied to axis.

    Returns:
        Length of indexed/sliced axis.

    Raises:
        ValueError: If `idx` is an integer index that is out bounds for
            the axis length or if the type of `idx` is not one of
            `Ellipsis`, `int`, or `slice`.
    """
    if idx is Ellipsis:
        return length
    if isinstance(idx, int):
        if not -length <= idx < length:
            raise ValueError(f"Index {idx} out of bounds for axis of length {length}.")
        return None
    if not isinstance(idx, slice):
        raise ValueError(f"Index expression {idx} is of an unrecognized type.")
    # A range built from the normalized (start, stop, step) triple has
    # exactly as many elements as the slice selects, for either sign of
    # the step.
    return len(range(*idx.indices(length)))
[docs]
def indexed_shape(shape: Shape, idx: ArrayIndex) -> Tuple[int, ...]:
    """Determine the shape of an array after indexing/slicing.

    The indexed shape is determined by replicating the observed effects
    of NumPy/JAX array indexing/slicing syntax. It is significantly
    faster than :func:`.jax_indexed_shape`, and has a minimal memory
    footprint in all circumstances.

    Args:
        shape: Shape of array.
        idx: Indexing expression (singleton or tuple of `Ellipsis`,
            `int`, `slice`, or ``None`` (`np.newaxis`)).

    Returns:
        Shape of indexed/sliced array.

    Raises:
        ValueError: If any element of `idx` is not one of `Ellipsis`,
            `int`, `slice`, or ``None`` (`np.newaxis`), or if an integer
            index is out bounds for the corresponding axis length.
        IndexError: If `idx` contains more than one `Ellipsis` or
            indexes more axes than are present in `shape`.
    """
    if not isinstance(idx, tuple):
        idx = (idx,)

    if sum(1 for ax_idx in idx if ax_idx is Ellipsis) > 1:
        raise IndexError("An index can only have a single Ellipsis.")
    # Only int and slice entries consume an axis of the array; None
    # inserts a new axis and Ellipsis stands in for all remaining axes.
    num_consumed = sum(1 for ax_idx in idx if ax_idx is not None and ax_idx is not Ellipsis)
    if num_consumed > len(shape):
        raise IndexError(
            f"Too many indices for array: array has {len(shape)} dimensions, "
            f"but {num_consumed} were indexed."
        )
    num_ellipsis_axes = len(shape) - num_consumed

    out_shape: List[int] = []
    axis = 0  # next axis of `shape` to be consumed
    for ax_idx in idx:
        if ax_idx is None:
            out_shape.append(1)
        elif ax_idx is Ellipsis:
            out_shape.extend(shape[axis : axis + num_ellipsis_axes])
            axis += num_ellipsis_axes
        else:
            ax_len = slice_length(shape[axis], ax_idx)
            # A None length denotes integer indexing, which removes the axis.
            if ax_len is not None:
                out_shape.append(ax_len)
            axis += 1
    # Axes not addressed by the indexing expression are left unchanged.
    out_shape.extend(shape[axis:])
    return tuple(out_shape)
[docs]
def jax_indexed_shape(shape: Shape, idx: ArrayIndex) -> Tuple[int, ...]:
    """Determine the shape of an array after indexing/slicing.

    The indexed shape is determined by constructing and indexing an array
    of the appropriate shape, relying on :func:`jax.jit` to avoid memory
    allocation. It is potentially more reliable than
    :func:`.indexed_shape` because the indexing/slicing calculations are
    referred to JAX, but is significantly slower, and will involve
    potentially significant memory allocations if JIT is disabled, e.g.
    for debugging purposes.

    Args:
        shape: Shape of array.
        idx: Indexing expression (singleton or tuple of `Ellipsis`,
            `int`, `slice`, or ``None`` (`np.newaxis`)).

    Returns:
        Shape of indexed/sliced array.
    """

    def _encode(expr):
        # Represent a slice as (slice, (start, stop, step)) so that the
        # indexing expression can be hashed, as required for static
        # arguments of jax.jit.
        return expr.__reduce__() if isinstance(expr, slice) else expr

    def _decode(expr):
        # Invert _encode, leaving all other entries untouched.
        is_encoded_slice = isinstance(expr, tuple) and len(expr) > 0 and expr[0] == slice
        return slice(*expr[1]) if is_encoded_slice else expr

    def _shape_after_indexing(in_shape, encoded_idx):
        restored_idx = tuple(_decode(expr) for expr in encoded_idx)
        return jax.numpy.empty(in_shape)[restored_idx].shape

    idx_tuple = idx if isinstance(idx, tuple) else (idx,)
    hashable_idx = tuple(_encode(expr) for expr in idx_tuple)  # type: ignore
    # Both arguments are static, so this compiles anew for each distinct
    # combination of arguments.
    jitted = jax.jit(_shape_after_indexing, static_argnums=(0, 1))
    return tuple(dim.item() for dim in jitted(shape, hashable_idx))  # type: ignore
[docs]
def no_nan_divide(
    x: Union[snp.BlockArray, snp.Array], y: Union[snp.BlockArray, snp.Array]
) -> Union[snp.BlockArray, snp.Array]:
    """Return `x/y`, with 0 instead of :data:`~numpy.NaN` where `y` is 0.

    Args:
        x: Numerator.
        y: Denominator.

    Returns:
        `x / y` with 0 wherever `y == 0`.
    """
    nonzero = y != 0
    # Substitute 1 for zero denominators so that the division itself is
    # well defined everywhere; those entries are zeroed out afterwards.
    safe_denominator = snp.where(nonzero, y, 1)
    quotient = snp.divide(x, safe_denominator)
    return snp.where(nonzero, quotient, 0)
def _readable_size(size: int) -> str:
    """Return a human-readable representation of an array size.

    The size is expressed as a whole number (floor division) of the
    largest unit, from B up to TB, for which that number is non-zero.

    Args:
        size: A positive integer array size.

    Returns:
        A string representation of the size.
    """
    scales = (
        ("TB", 1024**4),
        ("GB", 1024**3),
        ("MB", 1024**2),
        ("KB", 1024),
        ("B", 1),
    )
    # Scan from the largest unit down and stop at the first one that
    # yields a non-zero count.
    for unit, scale in scales:
        count = size // scale
        if count != 0:
            return f"{count} {unit}"
    # No unit gave a non-zero count: report in bytes.
    return f"{size // 1} B"
[docs]
def array_info(x: Union[snp.BlockArray, snp.Array]) -> str:
    """Return a string providing information about an array.

    The string reports the array type, shape, size, memory footprint,
    device (when the array has a `device` attribute), dtype, object id,
    and minimum and maximum values.

    Args:
        x: A numpy or jax array or scico :class:`BlockArray`.

    Returns:
        A string containing information on the array.

    Raises:
        TypeError: If the array is not of a recognized type.
    """
    if isinstance(x, np.ndarray):
        array_type = "numpy.ndarray"
    elif isinstance(x, jax.Array):
        array_type = "jax.Array"
    elif isinstance(x, snp.BlockArray):
        array_type = "scico.numpy.BlockArray"
    else:
        # The f prefix is required for the offending type to be
        # interpolated into the message.
        raise TypeError(f"Unrecognized array type {type(x)}.")
    # np.sum reduces `nbytes` to a single total whether it is a scalar or
    # a collection (presumably per-block sizes for a BlockArray; confirm
    # against the BlockArray implementation).
    totalbytes = np.sum(x.nbytes).item()  # type: ignore
    return (
        f"""{array_type}
shape: {x.shape}
size: {x.size}
bytes: {totalbytes} ({_readable_size(totalbytes)})
"""
        + (f" device: {x.device}\n" if hasattr(x, "device") else "")
        + f""" dtype: {dtype_name(x.dtype)}
id: {id(x)}
min, max: {snp.ravel(x).min()}, {snp.ravel(x).max()}
"""
    )
[docs]
def shape_to_size(shape: Union[Shape, BlockShape]) -> int:
    r"""Compute array size corresponding to a specified shape.

    Compute array size corresponding to a specified shape, which may be
    nested, i.e. corresponding to a :class:`BlockArray`. For a nested
    shape the size is the sum of the sizes of the individual blocks.

    Args:
        shape: A shape tuple.

    Returns:
        The number of elements in an array or :class:`BlockArray` with
        shape `shape`.
    """
    if not is_nested(shape):
        return prod(shape)  # type: ignore
    block_sizes = (prod(block_shape) for block_shape in shape)  # type: ignore
    return sum(block_sizes)
[docs]
def is_array(x: Any) -> bool:
    """Check if input is of type :class:`jax.Array` or :class:`numpy.ndarray`.

    Args:
        x: Object to be tested.

    Returns:
        ``True`` if `x` is an instance of :class:`jax.Array` or
        :class:`numpy.ndarray`, ``False`` otherwise.
    """
    array_types = (np.ndarray, jax.Array)
    return isinstance(x, array_types)
[docs]
def is_arraylike(x: Any) -> bool:
    """Check if input is of type :class:`jax.typing.ArrayLike`.

    `isinstance(x, jax.typing.ArrayLike)` does not work in Python < 3.10,
    see https://jax.readthedocs.io/en/latest/jax.typing.html#jax-typing-best-practices.
    The test applied here is that `x` is a numpy or jax array, or a
    scalar according to :func:`numpy.isscalar`.

    Args:
        x: Object to be tested.

    Returns:
        ``True`` if `x` is an ArrayLike, ``False`` otherwise.
    """
    if isinstance(x, (np.ndarray, jax.Array)):
        return True
    return np.isscalar(x)
[docs]
def is_nested(x: Any) -> bool:
    """Check if input is a list/tuple containing at least one list/tuple.

    Args:
        x: Object to be tested.

    Returns:
        ``True`` if `x` is a list/tuple containing at least one
        list/tuple, ``False`` otherwise.

    Example:
        >>> is_nested([1, 2, 3])
        False
        >>> is_nested([(1,2), (3,)])
        True
        >>> is_nested([[1, 2], 3])
        True
    """
    containers = (list, tuple)
    if not isinstance(x, containers):
        return False
    for item in x:
        if isinstance(item, containers):
            return True
    return False
[docs]
def is_collapsible(shapes: Sequence[Union[Shape, BlockShape]]) -> bool:
    """Determine whether a sequence of shapes can be collapsed.

    Return ``True`` if a sequence of shapes represents arrays that can
    be stacked, i.e., they are all the same. An empty sequence is
    considered collapsible.

    Args:
        shapes: A sequence of shapes.

    Returns:
        A boolean value indicating whether the shapes are all the same.
    """
    for shape in shapes:
        # Every shape is compared with the first entry of the sequence.
        if not shape == shapes[0]:
            return False
    return True
[docs]
def is_blockable(shapes: Sequence[Union[Shape, BlockShape]]) -> TypeGuard[Union[Shape, BlockShape]]:
    """Determine whether a sequence of shapes could be a :class:`BlockArray` shape.

    Return ``True`` if the sequence of shapes represent arrays that can
    be combined into a :class:`BlockArray`, i.e., none are nested.

    Args:
        shapes: A sequence of shapes.

    Returns:
        ``True`` if none of the shapes are nested, otherwise ``False``.
    """
    for shape in shapes:
        if is_nested(shape):
            return False
    return True
[docs]
def shape_dtype_rep(
    shape: Union[Shape, BlockShape], dtype: DType
) -> Union[jax.ShapeDtypeStruct, snp.BlockArray]:
    """Construct a representation of array or blockarray shape and dtype.

    Construct a representation of array or block array shape and dtype
    that is suitable for both jax arrays and scico blockarrays.

    Args:
        shape: Array or blockarray shape.
        dtype: Array or blockarray dtype.

    Returns:
        A :class:`jax.ShapeDtypeStruct` or a :class:`.BlockArray`
        containing objects of type :class:`jax.ShapeDtypeStruct`.
    """
    if not is_nested(shape):
        # Standard array: a single shape/dtype descriptor suffices.
        return jax.ShapeDtypeStruct(shape, dtype=dtype)
    # Block array: one descriptor per block, all sharing the same dtype.
    block_reps = [jax.ShapeDtypeStruct(block_shape, dtype=dtype) for block_shape in shape]
    return snp.BlockArray(block_reps)
[docs]
def broadcast_nested_shapes(
    shape_a: Union[Shape, BlockShape], shape_b: Union[Shape, BlockShape]
) -> Union[Shape, BlockShape]:
    r"""Compute the result of broadcasting on array shapes.

    Compute the result of applying a broadcasting binary operator to
    (block) arrays with (possibly nested) shapes `shape_a` and `shape_b`.
    Extends :func:`numpy.broadcast_shapes` to also support the nested
    tuple shapes of :class:`BlockArray`\ s.

    When only one of the shapes is nested, the other is broadcast against
    each of its blocks. When both are nested, they are broadcast block by
    block and must therefore have the same number of blocks.

    Args:
        shape_a: First array shape.
        shape_b: Second array shape.

    Returns:
        A (possibly nested) shape tuple.

    Raises:
        ValueError: If both shapes are nested and they have different
            numbers of blocks.

    Example:
        >>> broadcast_nested_shapes(((1, 1, 3), (2, 3, 1)), ((2, 3,), (2, 1, 4)))
        ((1, 2, 3), (2, 3, 4))
    """
    nested_a = is_nested(shape_a)
    nested_b = is_nested(shape_b)

    if nested_a and nested_b:
        # zip would silently drop the surplus blocks of the longer shape,
        # so a mismatch in block count is rejected explicitly.
        if len(shape_a) != len(shape_b):
            raise ValueError(
                f"Cannot broadcast nested shapes {shape_a} and {shape_b} with differing "
                f"numbers of blocks ({len(shape_a)} and {len(shape_b)})."
            )
        return tuple(snp.broadcast_shapes(s_a, s_b) for s_a, s_b in zip(shape_a, shape_b))
    if nested_a:
        return tuple(snp.broadcast_shapes(s, shape_b) for s in shape_a)
    if nested_b:
        return tuple(snp.broadcast_shapes(shape_a, s) for s in shape_b)
    return snp.broadcast_shapes(shape_a, shape_b)
[docs]
def is_real_dtype(dtype: DType) -> bool:
    """Determine whether a dtype is real.

    Any dtype that is not of complex kind is considered real.

    Args:
        dtype: A :mod:`numpy` or :mod:`scico.numpy` dtype (e.g.
            :attr:`~numpy.float32`, :attr:`~numpy.complex64`).

    Returns:
        ``False`` if the dtype is complex, otherwise ``True``.
    """
    # Kind code "c" denotes a complex floating point dtype.
    kind = snp.dtype(dtype).kind
    return kind != "c"
[docs]
def is_complex_dtype(dtype: DType) -> bool:
    """Determine whether a dtype is complex.

    Args:
        dtype: A :mod:`numpy` or :mod:`scico.numpy` dtype (e.g.
            :attr:`~numpy.float32`, :attr:`~numpy.complex64`).

    Returns:
        ``True`` if the dtype is complex, otherwise ``False``.
    """
    # Kind code "c" denotes a complex floating point dtype.
    kind = snp.dtype(dtype).kind
    return kind == "c"
[docs]
def real_dtype(dtype: DType) -> DType:
    """Construct the corresponding real dtype for a given complex dtype.

    Construct the corresponding real dtype for a given complex dtype,
    e.g. the real dtype corresponding to :attr:`~numpy.complex64` is
    :attr:`~numpy.float32`.

    Args:
        dtype: A complex numpy or scico.numpy dtype (e.g.
            :attr:`~numpy.complex64`, :attr:`~numpy.complex128`).

    Returns:
        The real dtype corresponding to the input dtype.
    """
    # The dtype is read off the real part of a one-element array of the
    # requested dtype.
    probe = snp.zeros(1, dtype)
    real_part = probe.real
    return real_part.dtype
[docs]
def complex_dtype(dtype: DType) -> DType:
    """Construct the corresponding complex dtype for a given real dtype.

    Construct the corresponding complex dtype for a given real dtype,
    e.g. the complex dtype corresponding to :attr:`~numpy.float32` is
    :attr:`~numpy.complex64`.

    Args:
        dtype: A real numpy or scico.numpy dtype (e.g. :attr:`~numpy.float32`,
            :attr:`~numpy.float64`).

    Returns:
        The complex dtype corresponding to the input dtype.
    """
    # The dtype is read off the result of adding an imaginary number to
    # a one-element array of the requested dtype.
    probe = snp.zeros(1, dtype)
    promoted = probe + 1j
    return promoted.dtype
[docs]
def dtype_name(dtype: DType) -> str:
"""Return the name of a dtype.
Construct a string representation of a dtype name.
Args:
dtype: The dtype for which the name is required.
Returns:
The name of the dtype.
"""
if type(dtype).__module__ == "numpy.dtypes":
return f"""numpy.{dtype.name}""" # type: ignore
return f"""{dtype.__module__}.{dtype.__qualname__}""" # type: ignore
[docs]
def is_scalar_equiv(s: Any) -> bool:
    """Determine whether an object is a scalar or is scalar-equivalent.

    An object is considered scalar-equivalent if it is a scalar
    according to `snp.isscalar` or if it is a zero-dimensional
    :class:`jax.Array`.

    Args:
        s: Object to be tested.

    Returns:
        ``True`` if the object is a scalar or a zero-dimensional jax
        array, otherwise ``False``.
    """
    scalar = snp.isscalar(s)
    if scalar:
        return scalar
    return isinstance(s, jax.Array) and s.ndim == 0