# Source code for scico.random

# -*- coding: utf-8 -*-
# Copyright (C) 2020-2023 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.

"""Random number generation.

This module provides convenient wrappers around several `jax.random
<https://jax.readthedocs.io/en/stable/jax.random.html>`_ routines to
handle the generation and splitting of PRNG keys, as well as the
generation of random :class:`.BlockArray`:

::

   # Calls to scico.random functions always return a PRNG key
   # If no key is passed to the function, a new key is generated
   x, key = scico.random.randn((2,))
   print(x)   # [ 0.19307713 -0.52678305]

   # scico.random functions automatically split the PRNG key and return
   # an updated key
   y, key = scico.random.randn((2,), key=key)
   print(y) # [ 0.00870693 -0.04888531]

The user is responsible for passing the PRNG key to :mod:`scico.random`
functions. If no key is passed, repeated calls to :mod:`scico.random`
functions will return the same random numbers:

::

   x, key = scico.random.randn((2,))
   print(x)   # [ 0.19307713 -0.52678305]

   # No key passed, will return the same random numbers!
   y, key = scico.random.randn((2,))
   print(y)   # [ 0.19307713 -0.52678305]


If the desired shape is a tuple containing tuples, a :class:`.BlockArray`
is returned:

::

   x, key = scico.random.randn( ((1, 1), (2,)), key=key)
   print(x)  # scico.numpy.BlockArray:
             # Array([ 1.1378784 , -1.220955  , -0.59153646], dtype=float32)

"""


import inspect
import sys
from typing import Optional, Tuple, Union

import numpy as np

import jax

from scico.numpy import Array, BlockArray
from scico.numpy._wrappers import map_func_over_tuple_of_tuples
from scico.typing import BlockShape, DType, PRNGKey, Shape


def _add_seed(fun):
    """
    Modify a :mod:`jax.random` function to add a `seed` argument.

    Args:
        fun: function to be modified, e.g., :func:`jax.random.normal`.
        Expects `key` to be the first argument.

    Returns:
        fun_alt: a version of `fun` supporting an optional `seed`
           argument that is used to create a :func:`jax.random.PRNGKey`
           that is passed along as the `key`. The `key` argument may
           still be used, but is moved to be second-to-last. By default,
           `seed=0`. The `seed` argument is added last. Other arguments
           are unchanged.
    """

    # find number of arguments to fun
    num_params = len(inspect.signature(fun).parameters)

    def fun_alt(*args, key=None, seed=None, **kwargs):

        # key and seed may be in *args, look for them
        if len(args) >= num_params:  # they passed all position args including key
            key = args[num_params - 1]
        if len(args) > num_params:  # they passed all position args including key and seed
            seed = args[num_params]

        if key is not None and seed is not None:
            raise ValueError("Key and seed cannot both be specified.")

        if key is None:
            if seed is None:
                seed = 0
            key = jax.random.PRNGKey(seed)

        result = fun(key, *args[: num_params - 1], **kwargs)

        key, subkey = jax.random.split(key, 2)
        return result, key

    lines = fun.__doc__.split("\n\n")
    fun_alt.__doc__ = "\n\n".join(
        lines[0:1]
        + [
            f"  Wrapped version of `jax.random.{fun.__name__} "
            f"<https://jax.readthedocs.io/en/stable/jax.random.html#jax.random.{fun.__name__}>`_. "
            "The SCICO version of this function moves the `key` argument to the end of the "
            "argument list, adds an additional `seed` argument after that, and allows the "
            "`shape` argument to accept a nested list, in which case a `BlockArray` is returned. "
            "Always returns a `(result, key)` tuple. Original docstring below.",
        ]
        + lines[1:]
    )

    return fun_alt


def _wrap(fun):
    """Build the SCICO version of a :mod:`jax.random` function.

    The function is first mapped over tuple-of-tuples shapes, then given
    the trailing `key`/`seed` calling convention by :func:`_add_seed`.

    Args:
        fun: A :mod:`jax.random` function taking `key` as its first
           argument.

    Returns:
        The wrapped function, reporting this module as its `__module__`.
    """
    block_aware = map_func_over_tuple_of_tuples(fun)
    seeded = _add_seed(block_aware)
    # Re-home the wrapper in this module so that it appears in the docs.
    seeded.__module__ = __name__
    return seeded


def _is_wrappable(fun):
    params = inspect.signature(getattr(jax.random, fun)).parameters
    prmkey = list(params.keys())
    return prmkey and (prmkey[0] == "key") and ("shape" in params.keys())


# Names of the jax.random functions whose first argument is `key` and
# which accept a `shape` argument; each of them is wrapped below.
wrappable_func_names = [
    member_name
    for member_name, _member in inspect.getmembers(jax.random, inspect.isfunction)
    if _is_wrappable(member_name)
]

# Install a wrapped version of each such function as an attribute of this
# module. A private loop variable is used so that no stray public `name`
# attribute is left in the module namespace, where star imports and
# documentation tools would pick it up.
for _func_name in wrappable_func_names:
    setattr(sys.modules[__name__], _func_name, _wrap(getattr(jax.random, _func_name)))


def randn(
    shape: Union[Shape, BlockShape],
    dtype: DType = np.float32,
    key: Optional[PRNGKey] = None,
    seed: Optional[int] = None,
) -> Tuple[Union[Array, BlockArray], PRNGKey]:
    """Return an array drawn from the standard normal distribution.

    Alias for :func:`scico.random.normal`.

    Args:
        shape: Shape of output array. If shape is a tuple, a
            :class:`jax.Array` is returned. If shape is a tuple of
            tuples, a :class:`.BlockArray` is returned.
        dtype: dtype for returned value. Defaults to
            :attr:`~numpy.float32`. If a complex dtype such as
            :attr:`~numpy.complex64`, generates an array sampled from
            complex normal distribution.
        key: JAX PRNGKey. Defaults to ``None``, in which case a new key
            is created using the seed arg.
        seed: Seed for new PRNGKey. Default: 0.

    Returns:
        tuple: A tuple (x, key) containing:

           - **x** : (:class:`jax.Array` or :class:`.BlockArray`):
             Generated random array.
           - **key** : Updated random PRNGKey.
    """
    # `normal` is attached to this module at import time by the loop over
    # `wrappable_func_names`, so static type checkers cannot see it. The
    # wrapper accepts `key` and `seed` positionally after shape and dtype.
    return normal(shape, dtype, key, seed)  # type: ignore