Source code for scico.numpy
# -*- coding: utf-8 -*-
# Copyright (C) 2020-2026 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.
r""":class:`.BlockArray` and compatible functions.
This module consists of :class:`.BlockArray` and functions that support
both instances of this class and jax arrays. This includes all the
functions from :mod:`jax.numpy` and :mod:`numpy.testing`, where many have
been extended to automatically map over block array blocks as described
in :ref:`numpy_functions_blockarray`. Also included are additional
functions unique to SCICO in :mod:`.util`.
"""
import sys
from functools import partial
from typing import Union
import numpy as np
import jax
import jax.numpy as jnp
from jax import Array
from . import _wrappers, fft, linalg, testing, util
from ._blockarray import BlockArray
from ._wrapped_function_lists import (
creation_routines,
mathematical_functions,
reduction_functions,
testing_functions,
)
__all__ = ["fft", "linalg", "testing", "util"]

# Expose the BlockArray factory so that ``snp.blockarray(...)`` builds a
# BlockArray, and tag it with this module's name so the docs can reference
# it from here.
blockarray = BlockArray.blockarray
blockarray.__module__ = __name__

# Present BlockArray as if it were defined in this module.
sys.modules[__name__].BlockArray.__module__ = __name__

# Copy most of jax.numpy into this namespace as-is (no wrapping). At module
# scope ``globals()`` is the same dict that ``vars()`` returns.
_wrappers.add_attributes(to_dict=globals(), from_dict=jnp.__dict__)

# Replace selected jax.numpy functions with wrapped versions. Creation
# routines get extra mapping options for the ``shape`` and ``device``
# arguments.
_wrappers.wrap_recursively(
    globals(),
    creation_routines,
    partial(
        _wrappers.map_func_over_args,
        map_if_nested_args=["shape"],
        map_if_list_args=["device"],
    ),
)
_wrappers.wrap_recursively(
    globals(), mathematical_functions, _wrappers.map_func_over_args
)
_wrappers.wrap_recursively(
    globals(), reduction_functions, _wrappers.add_full_reduction
)
[docs]
def ravel(ba: Union[Array, BlockArray]) -> Array:
    """Completely flatten a :class:`BlockArray` into a single ``Array``.

    When called on an ``Array``, flattens the array.

    Args:
        ba: The :class:`BlockArray` (or ``Array``) to flatten.

    Returns:
        `ba` flattened into a single ``Array``.
    """
    if isinstance(ba, BlockArray):
        # Flatten each block and join them in block order. ``jax.numpy`` is
        # referenced explicitly because the ``jnp`` alias is deleted from
        # this module's namespace at the end of module initialization.
        return jax.numpy.concatenate([arr.flatten() for arr in ba])
    return ba.ravel()
# Wrap the numpy.testing functions in place in this module's namespace.
# NOTE(review): ``is_void=True`` presumably tells the wrapper that these
# functions produce no result to collect (numpy.testing assertions return
# None). Confirm against ``_wrappers.map_func_over_args``.
_wrappers.wrap_recursively(
    globals(),
    testing_functions,
    partial(_wrappers.map_func_over_args, is_void=True),
)

# Drop the helper aliases so they are not left as attributes of this
# module. Code that runs after import (e.g. ``ravel``) must therefore use
# ``jax.numpy`` rather than ``jnp``.
del _wrappers, jnp, np