This module provides convenient wrappers around several jax.random routines to
handle the generation and splitting of PRNG keys, as well as the
generation of random BlockArrays:
```python
import scico.random

# Calls to scico.random functions always return a PRNG key.
# If no key is passed to the function, a new key is generated.
x, key = scico.random.randn((2,))
print(x)
# [ 0.19307713 -0.52678305]

# scico.random functions automatically split the PRNG key and return
# an updated key.
y, key = scico.random.randn((2,), key=key)
print(y)
# [ 0.00870693 -0.04888531]
```
The user is responsible for passing the PRNG key to scico.random
functions. If no key is passed, each call creates a new key from the
same default seed, so repeated calls to scico.random functions return
the same random numbers:
```python
import scico.random

x, key = scico.random.randn((2,))
print(x)
# [ 0.19307713 -0.52678305]

# No key passed, so the same random numbers are returned!
y, key = scico.random.randn((2,))
print(y)
# [ 0.19307713 -0.52678305]
```
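The repeated numbers come from how JAX keys work: a key built from the same seed always produces the same draw, and only a split key produces fresh values. The sketch below shows this explicit key threading with plain jax.random (it does not go through the SCICO wrappers, which would take the key as a trailing argument and return a (result, key) tuple).

```python
import jax
import jax.numpy as jnp

# Two keys built from the same seed are identical, so they give identical draws.
x = jax.random.normal(jax.random.PRNGKey(0), (2,))
y = jax.random.normal(jax.random.PRNGKey(0), (2,))

# Splitting yields a key to carry forward and a subkey to consume now;
# drawing with the subkey gives new values.
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
z = jax.random.normal(subkey, (2,))

print(bool(jnp.array_equal(x, y)))  # same seed, same numbers
print(bool(jnp.array_equal(x, z)))  # split key, different numbers
```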
If the desired shape is a tuple containing tuples, a BlockArray
is returned:
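As a rough illustration of what a nested shape means, the sketch below draws one array per inner shape using plain jax.random, splitting off one subkey per block. The helper randn_nested is hypothetical and a Python list stands in for BlockArray; this is not SCICO's implementation.

```python
import jax


def randn_nested(shape, key):
    """Hypothetical helper, not part of SCICO.

    If ``shape`` is a tuple of tuples, return a list with one standard-normal
    array per inner shape (a stand-in for a BlockArray); otherwise return a
    single array. Always returns a (result, key) tuple.
    """
    if len(shape) > 0 and isinstance(shape[0], tuple):
        # One subkey per block, plus one key to hand back to the caller.
        key, *subkeys = jax.random.split(key, len(shape) + 1)
        blocks = [jax.random.normal(k, s) for k, s in zip(subkeys, shape)]
        return blocks, key
    key, subkey = jax.random.split(key)
    return jax.random.normal(subkey, shape), key


key = jax.random.PRNGKey(0)
x, key = randn_nested(((1, 1), (2,)), key)  # nested shape: list of blocks
y, key = randn_nested((3,), key)            # flat shape: single array
print([b.shape for b in x], y.shape)
```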
Parameters:
key (Optional[Array]) – JAX PRNGKey. Defaults to None, in which case a new key
is created using the seed arg.
seed (Optional[int]) – Seed for new PRNGKey. Default: 0.
dtype (DType) – dtype for returned value. Defaults to float32.
If a complex dtype such as complex64 is given, the returned array is
sampled from a complex normal distribution.
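The complex case can be sketched with plain jax.random.normal, which also accepts complex dtypes and draws the real and imaginary parts independently with variance 1/2 each. That the SCICO function follows exactly this convention is an assumption here; the check below only verifies that the mean power E[|z|^2] is close to 1.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
# For a complex dtype, jax.random.normal draws the real and imaginary parts
# independently, each with variance 1/2, so that E[|z|^2] = 1.
z = jax.random.normal(key, (100_000,), dtype=jnp.complex64)
power = float(jnp.mean(jnp.abs(z) ** 2))
print(z.dtype, power)  # power should be close to 1
```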
Wrapped version of jax.random.ball. The SCICO version of this function moves the key argument to the end of the argument list, adds an additional seed argument after that, and allows the shape argument to accept a nested list, in which case a BlockArray is returned. Always returns a (result, key) tuple. Original docstring below.
d – a nonnegative int representing the dimensionality of the ball.
p – a float representing the p parameter of the Lp norm.
shape – optional, the batch dimensions of the result. Default ().
dtype – optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).
out_sharding – optional, specifies how the output array should be sharded
across devices in multi-device computation. Can be a
NamedSharding, a PartitionSpec
(P), or None (default). When specified, the output will be sharded
according to the given sharding specification. Primarily used in explicit
sharding mode.
See the explicit sharding tutorial
for more details.
Returns:
A random array of shape (*shape, d) and specified dtype.
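The wrapper convention repeated above each function (key moved to the end, an extra seed argument, a (result, key) return) can be sketched as a small decorator over any jax.random routine. The name scico_style is hypothetical, key and seed are keyword-only in this sketch, and nested (BlockArray) shapes are not handled; it illustrates the calling convention, not SCICO's actual code.

```python
import jax
import jax.numpy as jnp


def scico_style(fn):
    """Hypothetical decorator illustrating the wrapper convention: the key
    moves to the end (keyword-only here), a seed argument follows it, and a
    (result, key) tuple is always returned. Nested shapes are not handled."""

    def wrapper(*args, key=None, seed=0, **kwargs):
        if key is None:
            key = jax.random.PRNGKey(seed)  # fresh key from the seed
        key, subkey = jax.random.split(key)  # use subkey, return key
        return fn(subkey, *args, **kwargs), key

    return wrapper


uniform = scico_style(jax.random.uniform)
u1, key = uniform((3,))           # no key: one is built from seed=0
u2, key = uniform((3,), key=key)  # returned key threaded into the next call
u3, _ = uniform((3,))             # no key again: repeats u1
```

Because a missing key is always rebuilt from the same seed, u3 repeats u1, which mirrors the repeated-numbers behaviour described at the top of this module.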
Sample Bernoulli random values with given shape and mean.
Wrapped version of jax.random.bernoulli. The SCICO version of this function moves the key argument to the end of the argument list, adds an additional seed argument after that, and allows the shape argument to accept a nested list, in which case a BlockArray is returned. Always returns a (result, key) tuple. Original docstring below.
The values are distributed according to the probability mass function:
\[f(k; p) = p^k(1 - p)^{1 - k}\]
where \(k \in \{0, 1\}\) and \(0 \le p \le 1\).
Parameters:
key – a PRNG key used as the random key.
p – optional, a float or array of floats for the mean of the random
variables. Must be broadcast-compatible with shape. Default 0.5.
shape – optional, a tuple of nonnegative integers representing the result
shape. Must be broadcast-compatible with p.shape. The default (None)
produces a result shape equal to p.shape.
mode – optional, “high” or “low” for how many bits to use when sampling.
default=“low”. Set to “high” for correct sampling at small values of
p. When sampling in float32, bernoulli samples with mode=“low” produce
incorrect results for p < ~1E-7. mode=“high” approximately doubles the
cost of sampling.
out_sharding –
Optional. Specifies how the output array should be sharded
across devices in multi-device computation. Can be a
NamedSharding, a PartitionSpec
(P), or None (default). When specified, the output will be sharded
according to the given sharding specification. Primarily used in explicit
sharding mode.
See the explicit sharding tutorial
for more details.
Returns:
A random array with boolean dtype and shape given by shape if shape
is not None, or else p.shape.
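A short check of the shape rules above, using plain jax.random.bernoulli with the key as first argument (the SCICO wrapper would instead take the key at the end and return a (result, key) tuple): a scalar p with an explicit shape gives boolean samples whose mean approaches p, while an array p with the default shape=None gives one sample per entry of p.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
# Scalar p with an explicit shape: boolean samples whose mean approaches p.
b = jax.random.bernoulli(key, p=0.3, shape=(100_000,))
print(b.dtype, float(jnp.mean(b)))

# Array p with the default shape=None: one sample per entry of p.
c = jax.random.bernoulli(key, p=jnp.array([0.1, 0.5, 0.9]))
print(c.shape)
```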
Sample Beta random values with given shape and float dtype.
Wrapped version of jax.random.beta. The SCICO version of this function moves the key argument to the end of the argument list, adds an additional seed argument after that, and allows the shape argument to accept a nested list, in which case a BlockArray is returned. Always returns a (result, key) tuple. Original docstring below.
The values are distributed according to the probability density function:
\[f(x;a,b) \propto x^{a - 1}(1 - x)^{b - 1}\]
on the domain \(0 \le x \le 1\).
Parameters:
key – a PRNG key used as the random key.
a – a float or array of floats broadcast-compatible with shape
representing the first parameter “alpha”.
b – a float or array of floats broadcast-compatible with shape
representing the second parameter “beta”.
shape – optional, a tuple of nonnegative integers specifying the result
shape. Must be broadcast-compatible with a and b. The default
(None) produces a result shape by broadcasting a and b.
dtype – optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).
out_sharding –
Optional. Specifies how the output array should be sharded
across devices in multi-device computation. Can be a
NamedSharding, a PartitionSpec
(P), or None (default). When specified, the output will be sharded
according to the given sharding specification. Primarily used in explicit
sharding mode.
See the explicit sharding tutorial
for more details.
Returns:
A random array with the specified dtype and shape given by shape if
shape is not None, or else by broadcasting a and b.
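A quick sanity check of the Beta density above with plain jax.random.beta: samples stay on the domain [0, 1], and for a = 2, b = 5 the sample mean should approach the known mean a / (a + b) = 2/7. The parameter values are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(1)
a, b = 2.0, 5.0
x = jax.random.beta(key, a, b, shape=(100_000,))
# Samples lie in [0, 1]; the sample mean approaches a / (a + b).
print(float(x.min()), float(x.max()), float(x.mean()), a / (a + b))
```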
Sample Binomial random values with given shape and float dtype.
Wrapped version of jax.random.binomial. The SCICO version of this function moves the key argument to the end of the argument list, adds an additional seed argument after that, and allows the shape argument to accept a nested list, in which case a BlockArray is returned. Always returns a (result, key) tuple. Original docstring below.
The values are returned according to the probability mass function:
\[f(k;n,p) = \binom{n}{k}p^k(1-p)^{n-k}\]
on the domain \(0 < p < 1\), and where \(n\) is a nonnegative integer
representing the number of trials and \(p\) is a float representing the
probability of success of an individual trial.
Parameters:
key – a PRNG key used as the random key.
n – a float or array of floats broadcast-compatible with shape
representing the number of trials.
p – a float or array of floats broadcast-compatible with shape
representing the probability of success of an individual trial.
shape – optional, a tuple of nonnegative integers specifying the result
shape. Must be broadcast-compatible with n and p.
The default (None) produces a result shape equal to np.broadcast(n,p).shape.
dtype – optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).
Returns:
A random array with the specified dtype and with shape given by
np.broadcast(n,p).shape.
Sample uniform bits in the form of unsigned integers.
Wrapped version of jax.random.bits. The SCICO version of this function moves the key argument to the end of the argument list, adds an additional seed argument after that, and allows the shape argument to accept a nested list, in which case a BlockArray is returned. Always returns a (result, key) tuple. Original docstring below.
Parameters:
key – a PRNG key used as the random key.
shape – optional, a tuple of nonnegative integers representing the result
shape. Default ().
dtype – optional, an unsigned integer dtype for the returned values (default
uint64 if jax_enable_x64 is true, otherwise uint32).
out_sharding –
Optional. Specifies how the output array should be sharded
across devices in multi-device computation. Can be a
NamedSharding, a PartitionSpec
(P), or None (default). When specified, the output will be sharded
according to the given sharding specification. Primarily used in explicit
sharding mode.
See the explicit sharding tutorial
for more details.
Returns:
A random array with the specified shape and dtype.
Sample random values from categorical distributions.
Wrapped version of jax.random.categorical. The SCICO version of this function moves the key argument to the end of the argument list, adds an additional seed argument after that, and allows the shape argument to accept a nested list, in which case a BlockArray is returned. Always returns a (result, key) tuple. Original docstring below.
Sampling with replacement uses the Gumbel max trick. Sampling without replacement uses
the Gumbel top-k trick. See [1] for reference.
Parameters:
key – a PRNG key used as the random key.
logits – Unnormalized log probabilities of the categorical distribution(s) to sample from,
so that softmax(logits, axis) gives the corresponding probabilities.
axis – Axis along which logits belong to the same categorical distribution.
shape – Optional, a tuple of nonnegative integers representing the result shape.
Must be broadcast-compatible with np.delete(logits.shape,axis).
The default (None) produces a result shape equal to np.delete(logits.shape,axis).
replace – If True (default), perform sampling with replacement. If False, perform
sampling without replacement.
mode – optional, “high” or “low” for how many bits to use in the Gumbel sampler.
The default is determined by the use_high_dynamic_range_gumbel config,
which defaults to “low”. With mode=“low”, float32 sampling will be biased
for events with probability less than about 1E-7; with mode=“high” this limit
is pushed down to about 1E-14. mode=“high” approximately doubles the cost of
sampling.
out_sharding –
Optional. Specifies how the output array should be sharded
across devices in multi-device computation. Can be a
NamedSharding, a PartitionSpec
(P), or None (default). When specified, the output will be sharded
according to the given sharding specification. Primarily used in explicit
sharding mode.
See the explicit sharding tutorial
for more details.
Returns:
A random array with int dtype and shape given by shape if shape
is not None, or else np.delete(logits.shape,axis).
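The Gumbel-max trick mentioned above can be sketched by hand: adding i.i.d. Gumbel noise to the logits and taking the argmax yields samples distributed according to softmax(logits). The sketch compares empirical frequencies from this manual version and from jax.random.categorical against the target probabilities; it does not claim the two produce identical samples for a given key.

```python
import jax
import jax.numpy as jnp

probs = jnp.array([0.2, 0.5, 0.3])
logits = jnp.log(probs)
n = 100_000

# Gumbel-max trick: the argmax of logits plus i.i.d. Gumbel noise is
# distributed according to softmax(logits).
g = jax.random.gumbel(jax.random.PRNGKey(0), (n, 3))
manual = jnp.argmax(logits + g, axis=-1)

# The library routine, for comparison.
builtin = jax.random.categorical(jax.random.PRNGKey(1), logits, shape=(n,))

freq_manual = jnp.bincount(manual, length=3) / n
freq_builtin = jnp.bincount(builtin, length=3) / n
print(freq_manual, freq_builtin)  # both should be close to probs
```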
Sample Cauchy random values with given shape and float dtype.
Wrapped version of jax.random.cauchy. The SCICO version of this function moves the key argument to the end of the argument list, adds an additional seed argument after that, and allows the shape argument to accept a nested list, in which case a BlockArray is returned. Always returns a (result, key) tuple. Original docstring below.
The values are distributed according to the probability density function:
\[f(x) \propto \frac{1}{x^2 + 1}\]
on the domain \(-\infty < x < \infty\).
Parameters:
key – a PRNG key used as the random key.
shape – optional, a tuple of nonnegative integers representing the result
shape. Default ().
dtype – optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).
out_sharding –
Optional. Specifies how the output array should be sharded
across devices in multi-device computation. Can be a
NamedSharding, a PartitionSpec
(P), or None (default). When specified, the output will be sharded
according to the given sharding specification. Primarily used in explicit
sharding mode.
See the explicit sharding tutorial
for more details.
Returns:
A random array with the specified shape and dtype.
Sample Chisquare random values with given shape and float dtype.
Wrapped version of jax.random.chisquare. The SCICO version of this function moves the key argument to the end of the argument list, adds an additional seed argument after that, and allows the shape argument to accept a nested list, in which case a BlockArray is returned. Always returns a (result, key) tuple. Original docstring below.
The values are distributed according to the probability density function:
\[f(x; \nu) \propto x^{\nu/2 - 1}e^{-x/2}\]
on the domain \(0 < x < \infty\), where \(\nu > 0\) represents the
degrees of freedom, given by the parameter df.
Parameters:
key – a PRNG key used as the random key.
df – a float or array of floats broadcast-compatible with shape
representing the parameter of the distribution.
shape – optional, a tuple of nonnegative integers specifying the result
shape. Must be broadcast-compatible with df. The default (None)
produces a result shape equal to df.shape.
dtype – optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).
out_sharding –
Optional. Specifies how the output array should be sharded
across devices in multi-device computation. Can be a
NamedSharding, a PartitionSpec
(P), or None (default). When specified, the output will be sharded
according to the given sharding specification. Primarily used in explicit
sharding mode.
See the explicit sharding tutorial
for more details.
Returns:
A random array with the specified dtype and with shape given by shape if
shape is not None, or else by df.shape.
Wrapped version of jax.random.choice. The SCICO version of this function moves the key argument to the end of the argument list, adds an additional seed argument after that, and allows the shape argument to accept a nested list, in which case a BlockArray is returned. Always returns a (result, key) tuple. Original docstring below.
Warning
If p has fewer non-zero elements than the requested number of samples,
as specified in shape, and replace=False, the output of this
function is ill-defined. Please make sure to use appropriate inputs.
Parameters:
key – a PRNG key used as the random key.
a – array or int. If an ndarray, a random sample is generated from
its elements. If an int, the random sample is generated as if a were
arange(a).
shape – tuple of ints, optional. Output shape. If the given shape is,
e.g., (m,n), then m*n samples are drawn. Default is (),
in which case a single value is returned.
replace – boolean. Whether the sample is with or without replacement.
Default is True.
p – 1-D array-like. The probabilities associated with each entry in a.
If not given, the sample assumes a uniform distribution over all
entries in a.
axis – int, optional. The axis along which the selection is performed.
The default, 0, selects by row.
mode – optional, “high” or “low” for how many bits to use in the Gumbel sampler
when p is None and replace=False. The default is determined by the
use_high_dynamic_range_gumbel config, which defaults to “low”. With mode=“low”,
float32 sampling will be biased for choices with probability less than about
1E-7; with mode=“high” this limit is pushed down to about 1E-14. mode=“high”
approximately doubles the cost of sampling.
Returns:
An array of shape shape containing samples from a.
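Two common uses of the parameters above, sketched with plain jax.random.choice: an int a is treated as arange(a), replace=False yields distinct values, and with explicit probabilities p an entry whose probability is zero is never drawn. The specific values and probabilities are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
# Without replacement: 5 distinct values drawn from arange(10).
s = jax.random.choice(key, 10, shape=(5,), replace=False)

# With replacement and explicit probabilities: the entry with p=0 never appears.
w = jax.random.choice(key, jnp.array([10, 20, 30]), shape=(1000,),
                      p=jnp.array([0.5, 0.5, 0.0]))
print(s, jnp.unique(w))
```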
Sample Dirichlet random values with given shape and float dtype.
Wrapped version of jax.random.dirichlet. The SCICO version of this function moves the key argument to the end of the argument list, adds an additional seed argument after that, and allows the shape argument to accept a nested list, in which case a BlockArray is returned. Always returns a (result, key) tuple. Original docstring below.
The values are distributed according to the probability density function:
\[f(\{x_i\}; \{\alpha_i\}) \propto \prod_{i=1}^k x_i^{\alpha_i - 1}\]
where \(k\) is the dimension, and \(\{x_i\}\) satisfies
\[\sum_{i=1}^k x_i = 1\]
and \(0 \le x_i \le 1\) for all \(x_i\).
Parameters:
key – a PRNG key used as the random key.
alpha – an array of shape (...,n) used as the concentration
parameter of the random variables.
shape – optional, a tuple of nonnegative integers specifying the result
batch shape; that is, the prefix of the result shape excluding the last
element of value n. Must be broadcast-compatible with
alpha.shape[:-1]. The default (None) produces a result shape equal to
alpha.shape.
dtype – optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).
out_sharding –
Optional. Specifies how the output array should be sharded
across devices in multi-device computation. Can be a
NamedSharding, a PartitionSpec
(P), or None (default). When specified, the output will be sharded
according to the given sharding specification. Primarily used in explicit
sharding mode.
See the explicit sharding tutorial
for more details.
Returns:
A random array with the specified dtype and shape given by
shape+(alpha.shape[-1],) if shape is not None, or else
alpha.shape.
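The shape rule above (batch shape followed by alpha.shape[-1]) and the simplex constraint can be checked with plain jax.random.dirichlet: with a batch shape of (4,) and three concentration parameters, the result has shape (4, 3), and each row is nonnegative and sums to 1 up to rounding.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
alpha = jnp.array([1.0, 2.0, 3.0])
# Batch shape (4,) is prepended to alpha.shape[-1], giving shape (4, 3).
x = jax.random.dirichlet(key, alpha, shape=(4,))
print(x.shape)
print(x.sum(axis=-1))  # each row lies on the simplex, so sums to 1
```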
Wrapped version of jax.random.double_sided_maxwell. The SCICO version of this function moves the key argument to the end of the argument list, adds an additional seed argument after that, and allows the shape argument to accept a nested list, in which case a BlockArray is returned. Always returns a (result, key) tuple. Original docstring below.
The values are distributed according to the probability density function:
\[f(x;\mu,\sigma) \propto z^2 e^{-z^2 / 2}\]
where \(z = (x - \mu) / \sigma\), with the center \(\mu\) specified by
loc and the scale \(\sigma\) specified by scale.
Parameters:
key – a PRNG key.
loc – The location parameter of the distribution.
scale – The scale parameter of the distribution.
shape – The shape added to the broadcast shape of the parameters loc and scale.
Sample Exponential random values with given shape and float dtype.
Wrapped version of jax.random.exponential. The SCICO version of this function moves the key argument to the end of the argument list, adds an additional seed argument after that, and allows the shape argument to accept a nested list, in which case a BlockArray is returned. Always returns a (result, key) tuple. Original docstring below.
The values are distributed according to the probability density function:
\[f(x) = e^{-x}\]
on the domain \(0 \le x < \infty\).
Parameters:
key – a PRNG key used as the random key.
shape – optional, a tuple of nonnegative integers representing the result
shape. Default ().
dtype – optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).
out_sharding –
Optional. Specifies how the output array should be sharded
across devices in multi-device computation. Can be a
NamedSharding, a PartitionSpec
(P), or None (default). When specified, the output will be sharded
according to the given sharding specification. Primarily used in explicit
sharding mode.
See the explicit sharding tutorial
for more details.
Returns:
A random array with the specified shape and dtype.
Sample F-distribution random values with given shape and float dtype.
Wrapped version of jax.random.f. The SCICO version of this function moves the key argument to the end of the argument list, adds an additional seed argument after that, and allows the shape argument to accept a nested list, in which case a BlockArray is returned. Always returns a (result, key) tuple. Original docstring below.
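The values are distributed according to the probability density function:
\[f(x; \nu_1, \nu_2) \propto \frac{x^{\nu_1/2 - 1}}{\left(1 + \frac{\nu_1}{\nu_2}x\right)^{\frac{1}{2}(\nu_1 + \nu_2)}}\]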
on the domain \(0 < x < \infty\). Here \(\nu_1\) is the degrees of
freedom of the numerator (dfnum), and \(\nu_2\) is the degrees of
freedom of the denominator (dfden).
Parameters:
key – a PRNG key used as the random key.
dfnum – a float or array of floats broadcast-compatible with shape
representing the numerator’s df of the distribution.
dfden – a float or array of floats broadcast-compatible with shape
representing the denominator’s df of the distribution.
shape – optional, a tuple of nonnegative integers specifying the result
shape. Must be broadcast-compatible with dfnum and dfden.
The default (None) produces a result shape equal to the broadcast of
dfnum.shape and dfden.shape.
dtype – optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).
out_sharding –
optional, specifies how the output array should be sharded
across devices in multi-device computation. Can be a
NamedSharding, a PartitionSpec
(P), or None (default). When specified, the output will be sharded
according to the given sharding specification. Primarily used in explicit
sharding mode.
See the explicit sharding tutorial
for more details.
Returns:
A random array with the specified dtype and with shape given by shape if
shape is not None, or else by broadcasting dfnum and dfden.
Sample Gamma random values with given shape and float dtype.
Wrapped version of jax.random.gamma. The SCICO version of this function moves the key argument to the end of the argument list, adds an additional seed argument after that, and allows the shape argument to accept a nested list, in which case a BlockArray is returned. Always returns a (result, key) tuple. Original docstring below.
The values are distributed according to the probability density function:
\[f(x;a) \propto x^{a - 1} e^{-x}\]
on the domain \(0 \le x < \infty\), with \(a > 0\).
This is the standard gamma density, with a unit scale/rate parameter.
Dividing the sample output by the rate is equivalent to sampling from
gamma(a, rate), and multiplying the sample output by the scale is equivalent
to sampling from gamma(a, scale).
Parameters:
key – a PRNG key used as the random key.
a – a float or array of floats broadcast-compatible with shape
representing the parameter of the distribution.
shape – optional, a tuple of nonnegative integers specifying the result
shape. Must be broadcast-compatible with a. The default (None)
produces a result shape equal to a.shape.
dtype – optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).
out_sharding –
Optional. Specifies how the output array should be sharded
across devices in multi-device computation. Can be a
NamedSharding, a PartitionSpec
(P), or None (default). When specified, the output will be sharded
according to the given sharding specification. Primarily used in explicit
sharding mode.
See the explicit sharding tutorial
for more details.
Returns:
A random array with the specified dtype and with shape given by shape if
shape is not None, or else by a.shape.
See also
loggamma : sample gamma values in log-space, which can provide improved
accuracy for small values of a.
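The rate/scale remark above can be sketched with plain jax.random.gamma: the standard (unit-scale) samples are multiplied by the scale, or equivalently divided by the rate = 1/scale, and the sample mean should approach a * scale. The values a = 3 and scale = 2 are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
a, scale = 3.0, 2.0
rate = 1.0 / scale
std_samples = jax.random.gamma(key, a, shape=(100_000,))  # unit scale

x_scale = std_samples * scale  # Gamma(a, scale) via multiplication
x_rate = std_samples / rate    # the same distribution via division by the rate
print(float(x_scale.mean()), a * scale)  # Gamma(a, scale) has mean a * scale
```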
Wrapped version of jax.random.generalized_normal. The SCICO version of this function moves the key argument to the end of the argument list, adds an additional seed argument after that, and allows the shape argument to accept a nested list, in which case a BlockArray is returned. Always returns a (result, key) tuple. Original docstring below.
The values are returned according to the probability density function:
\[f(x;p) \propto e^{-|x|^p}\]
on the domain \(-\infty < x < \infty\), where \(p > 0\) is the
shape parameter.
Parameters:
key – a PRNG key used as the random key.
p – a float representing the shape parameter.
shape – optional, the batch dimensions of the result. Default ().
dtype – optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).
Returns:
A random array with the specified shape and dtype.
Sample Geometric random values with given shape and integer dtype.
Wrapped version of jax.random.geometric. The SCICO version of this function moves the key argument to the end of the argument list, adds an additional seed argument after that, and allows the shape argument to accept a nested list, in which case a BlockArray is returned. Always returns a (result, key) tuple. Original docstring below.
The values are returned according to the probability mass function:
\[f(k;p) = p(1-p)^{k-1}\]
on the domain \(0 < p < 1\).
Parameters:
key – a PRNG key used as the random key.
p – a float or array of floats broadcast-compatible with shape
representing the probability of success of an individual trial.
shape – optional, a tuple of nonnegative integers specifying the result
shape. Must be broadcast-compatible with p. The default
(None) produces a result shape equal to np.shape(p).
dtype – optional, an int dtype for the returned values (default int64 if
jax_enable_x64 is true, otherwise int32).
Returns:
A random array with the specified dtype and with shape given by shape if
shape is not None, or else by p.shape.
Sample Gumbel random values with given shape and float dtype.
Wrapped version of jax.random.gumbel. The SCICO version of this function moves the key argument to the end of the argument list, adds an additional seed argument after that, and allows the shape argument to accept a nested list, in which case a BlockArray is returned. Always returns a (result, key) tuple. Original docstring below.
The values are distributed according to the probability density function:
\[f(x) = e^{-(x + e^{-x})}\]
Parameters:
key – a PRNG key used as the random key.
shape – optional, a tuple of nonnegative integers representing the result
shape. Default ().
dtype – optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).
mode – optional, “high” or “low” for how many bits to use when sampling.
The default is determined by the use_high_dynamic_range_gumbel config,
which defaults to “low”. When drawing float32 samples with mode=“low”, the
uniform resolution is such that the largest possible Gumbel logit is ~16;
with mode=“high” this is increased to ~32, at approximately double the
computational cost.
out_sharding –
Optional. Specifies how the output array should be sharded
across devices in multi-device computation. Can be a
NamedSharding, a PartitionSpec
(P), or None (default). When specified, the output will be sharded
according to the given sharding specification. Primarily used in explicit
sharding mode.
See the explicit sharding tutorial
for more details.
Returns:
A random array with the specified shape and dtype.
Sample Laplace random values with given shape and float dtype.
Wrapped version of jax.random.laplace. The SCICO version of this function moves the key argument to the end of the argument list, adds an additional seed argument after that, and allows the shape argument to accept a nested list, in which case a BlockArray is returned. Always returns a (result, key) tuple. Original docstring below.
The values are distributed according to the probability density function:
\[f(x) = \frac{1}{2}e^{-|x|}\]
Parameters:
key – a PRNG key used as the random key.
shape – optional, a tuple of nonnegative integers representing the result
shape. Default ().
dtype – optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).
out_sharding –
Optional. Specifies how the output array should be sharded
across devices in multi-device computation. Can be a
NamedSharding, a PartitionSpec
(P), or None (default). When specified, the output will be sharded
according to the given sharding specification. Primarily used in explicit
sharding mode.
See the explicit sharding tutorial
for more details.
Returns:
A random array with the specified shape and dtype.
Sample log-gamma random values with given shape and float dtype.
Wrapped version of jax.random.loggamma. The SCICO version of this function moves the key argument to the end of the argument list, adds an additional seed argument after that, and allows the shape argument to accept a nested list, in which case a BlockArray is returned. Always returns a (result, key) tuple. Original docstring below.
This function is implemented such that the following will hold for a
dtype-appropriate tolerance rtol:
np.testing.assert_allclose(jnp.exp(loggamma(*args)), gamma(*args), rtol=rtol)
The benefit of log-gamma is that for samples very close to zero (which occur frequently
when a << 1) sampling in log space provides better precision.
Parameters:
key – a PRNG key used as the random key.
a – a float or array of floats broadcast-compatible with shape
representing the parameter of the distribution.
shape – optional, a tuple of nonnegative integers specifying the result
shape. Must be broadcast-compatible with a. The default (None)
produces a result shape equal to a.shape.
dtype – optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).
out_sharding –
Optional. Specifies how the output array should be sharded
across devices in multi-device computation. Can be a
NamedSharding, a PartitionSpec
(P), or None (default). When specified, the output will be sharded
according to the given sharding specification. Primarily used in explicit
sharding mode.
See the explicit sharding tutorial
for more details.
Returns:
A random array with the specified dtype and with shape given by shape if
shape is not None, or else by a.shape.
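The consistency property stated above can be checked directly with plain jax.random: drawing loggamma and gamma with the same key and the same a, the exponentiated log-space samples should match the gamma samples to within a float32-appropriate tolerance. The tolerance 1e-3 is an assumption chosen to be loose for float32.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
a = jnp.array([0.5, 1.0, 5.0])
# With the same key, exponentiated log-gamma samples match gamma samples.
log_x = jax.random.loggamma(key, a)
x = jax.random.gamma(key, a)
print(jnp.exp(log_x), x)
```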
Sample logistic random values with given shape and float dtype.
Wrapped version of jax.random.logistic. The SCICO version of this function moves the key argument to the end of the argument list, adds an additional seed argument after that, and allows the shape argument to accept a nested list, in which case a BlockArray is returned. Always returns a (result, key) tuple. Original docstring below.
The values are distributed according to the probability density function:
\[f(x) = \frac{e^{-x}}{(1 + e^{-x})^2}\]
Parameters:
key – a PRNG key used as the random key.
shape – optional, a tuple of nonnegative integers representing the result
shape. Default ().
dtype – optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).
out_sharding –
Optional. Specifies how the output array should be sharded
across devices in multi-device computation. Can be a
NamedSharding, a PartitionSpec
(P), or None (default). When specified, the output will be sharded
according to the given sharding specification. Primarily used in explicit
sharding mode.
See the explicit sharding tutorial
for more details.
Returns:
A random array with the specified shape and dtype.
Sample lognormal random values with given shape and float dtype.
Wrapped version of jax.random.lognormal. The SCICO version of this function moves the key argument to the end of the argument list, adds an additional seed argument after that, and allows the shape argument to accept a nested list, in which case a BlockArray is returned. Always returns a (result, key) tuple. Original docstring below.
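The values are distributed according to the probability density function:
\[f(x) = \frac{1}{x\sqrt{2\pi\sigma^2}}\exp\left(-\frac{(\log x)^2}{2\sigma^2}\right)\]
on the domain \(x > 0\).
Parameters:
key – a PRNG key used as the random key.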
sigma – a float or array of floats broadcast-compatible with shape representing
the standard deviation of the underlying normal distribution. Default 1.
shape – optional, a tuple of nonnegative integers specifying the result
shape. The default (None) produces a result shape equal to ().
dtype – optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).
out_sharding –
optional, specifies how the output array should be sharded
across devices in multi-device computation. Can be a
NamedSharding, a PartitionSpec
(P), or None (default). When specified, the output will be sharded
according to the given sharding specification. Primarily used in explicit
sharding mode.
See the explicit sharding tutorial
for more details.
Returns:
A random array with the specified dtype and with shape given by shape.
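Since sigma is the standard deviation of the underlying normal distribution, taking the log of lognormal samples should recover a normal sample with mean 0 and standard deviation sigma. A quick check with plain jax.random.lognormal, using the arbitrary value sigma = 0.5:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
sigma = 0.5
x = jax.random.lognormal(key, sigma=sigma, shape=(100_000,))
# log(x) should be normal with mean 0 and standard deviation sigma.
log_x = jnp.log(x)
print(float(x.min()), float(log_x.mean()), float(log_x.std()))
```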
Wrapped version of jax.random.maxwell. The SCICO version of this function moves the key argument to the end of the argument list, adds an additional seed argument after that, and allows the shape argument to accept a nested list, in which case a BlockArray is returned. Always returns a (result, key) tuple. Original docstring below.
The values are distributed according to the probability density function:
\[f(x) \propto x^2 e^{-x^2 / 2}\]
on the domain \(0 \le x < \infty\).
Wrapped version of jax.random.multinomial. The SCICO version of this function moves the key argument to the end of the argument list, adds an additional seed argument after that, and allows the shape argument to accept a nested list, in which case a BlockArray is returned. Always returns a (result, key) tuple. Original docstring below.
Parameters:
key – a PRNG key used as the random key.
n – number of trials. Should have shape broadcastable to p.shape[:-1].
p – probability of each outcome, with outcomes along the last axis.
shape – optional, a tuple of nonnegative integers specifying the result batch
shape, that is, the prefix of the result shape excluding the last axis.
Must be broadcast-compatible with p.shape[:-1]. The default (None)
produces a result shape equal to p.shape.
dtype – optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).
unroll – optional, unroll parameter passed to jax.lax.scan inside the
implementation of this function.
Returns:
An array of counts for each outcome with the specified dtype and with shape
p.shape if shape is None, otherwise shape + (p.shape[-1],).
Sample multivariate normal random values with given mean and covariance.
Wrapped version of jax.random.multivariate_normal. The SCICO version of this function moves the key argument to the end of the argument list, adds an additional seed argument after that, and allows the shape argument to accept a nested list, in which case a BlockArray is returned. Always returns a (result, key) tuple. Original docstring below.
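The values are returned according to the probability density function:
\[f(x;\mu, \Sigma) = (2\pi)^{-k/2} \det(\Sigma)^{-1/2} e^{-\frac{1}{2}(x - \mu)^T \Sigma^{-1} (x - \mu)}\]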
where \(k\) is the dimension, \(\mu\) is the mean (given by mean) and
\(\Sigma\) is the covariance matrix (given by cov).
Parameters:
key – a PRNG key used as the random key.
mean – a mean vector of shape (...,n).
cov – a positive definite covariance matrix of shape (...,n,n). The
batch shape ... must be broadcast-compatible with that of mean.
shape – optional, a tuple of nonnegative integers specifying the result
batch shape; that is, the prefix of the result shape excluding the last
axis. Must be broadcast-compatible with mean.shape[:-1] and
cov.shape[:-2]. The default (None) produces a result batch shape by
broadcasting together the batch shapes of mean and cov.
dtype – optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).
method – optional, a method to compute the factor of cov.
Must be one of ‘svd’, ‘eigh’, or ‘cholesky’. Default ‘cholesky’. For
singular covariance matrices, use ‘svd’ or ‘eigh’.
out_sharding – optional, specifies how the output array should be sharded
across devices in multi-device computation. Can be a
NamedSharding, a PartitionSpec
(P), or None (default). When specified, the output will be sharded
according to the given sharding specification. Primarily used in explicit
sharding mode.
See the explicit sharding tutorial
for more details.
Returns:
A random array with the specified dtype and shape given by
shape + mean.shape[-1:] if shape is not None, or else
broadcast_shapes(mean.shape[:-1], cov.shape[:-2]) + mean.shape[-1:].
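The shape conventions above can be checked with a short sketch that calls jax.random.multivariate_normal directly. The mean, covariance and batch size are illustrative. With shape=(20_000,) the result has shape shape + mean.shape[-1:], and the empirical mean and covariance should land close to the inputs.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(1)
mean = jnp.array([1.0, -2.0])
cov = jnp.array([[2.0, 0.6],
                 [0.6, 1.0]])

# shape is the batch shape; the trailing axis of length 2 comes from mean.
samples = jax.random.multivariate_normal(key, mean, cov, shape=(20_000,))
print(samples.shape)  # (20000, 2)

# Empirical moments should be close to the requested mean and covariance.
emp_mean = samples.mean(axis=0)
emp_cov = jnp.cov(samples, rowvar=False)
```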
Sample standard normal random values with given shape and float dtype.
Wrapped version of jax.random.normal. The SCICO version of this function moves the key argument to the end of the argument list, adds an additional seed argument after that, and allows the shape argument to accept a nested list, in which case a BlockArray is returned. Always returns a (result, key) tuple. Original docstring below.
The values are returned according to the probability density function:
\[f(x) = \frac{1}{\sqrt{2\pi}}e^{-x^2/2}\]
on the domain \(-\infty < x < \infty\)
Parameters:
key – a PRNG key used as the random key.
shape – optional, a tuple of nonnegative integers representing the result
shape. Default ().
dtype – optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).
out_sharding – optional, specifies how the output array should be sharded
across devices in multi-device computation. Can be a
NamedSharding, a PartitionSpec
(P), or None (default). When specified, the output will be sharded
according to the given sharding specification. Primarily used in explicit
sharding mode.
See the explicit sharding tutorial
for more details.
Returns:
A random array with the specified shape and dtype.
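SCICO's wrappers split the key internally and hand back an updated key. The sketch below does the same bookkeeping by hand with jax.random.normal and jax.random.split, to show why reusing a key reproduces the same values while a fresh subkey does not. The seed and shapes are arbitrary.

```python
import jax

key = jax.random.PRNGKey(0)

# Reusing the same key reproduces exactly the same values.
a = jax.random.normal(key, shape=(3,))
b = jax.random.normal(key, shape=(3,))

# Splitting yields a new key to carry forward and a subkey to consume now.
key, subkey = jax.random.split(key)
c = jax.random.normal(subkey, shape=(3,))
print(a, b, c)
```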
Sample uniformly from the orthogonal group O(n).
Wrapped version of jax.random.orthogonal. The SCICO version of this function moves the key argument to the end of the argument list, adds an additional seed argument after that, and allows the shape argument to accept a nested list, in which case a BlockArray is returned. Always returns a (result, key) tuple. Original docstring below.
If the dtype is complex, sample uniformly from the unitary group U(n).
For unequal rows and columns, this samples a semi-orthogonal matrix instead.
That is, if \(A\) is the resulting matrix and \(A^*\) is its conjugate
transpose, then:
If \(n \leq m\), the rows are mutually orthonormal: \(A A^* = I_n\).
If \(m \leq n\), the columns are mutually orthonormal: \(A^* A = I_m\).
Parameters:
key – a PRNG key used as the random key.
n – an integer indicating the number of rows.
shape – optional, the batch dimensions of the result. Default ().
dtype – optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).
m – an integer indicating the number of columns. Defaults to n.
out_sharding – optional, specifies how the output array should be sharded
across devices in multi-device computation. Can be a
NamedSharding, a PartitionSpec
(P), or None (default). When specified, the output will be sharded
according to the given sharding specification. Primarily used in explicit
sharding mode.
See the explicit sharding tutorial
for more details.
Returns:
A random array of shape (*shape, n, m) and specified dtype.
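A quick check of the orthogonality property, assuming a real dtype and square matrices (m left at its default). This calls jax.random.orthogonal directly; the batch shape (2,) is arbitrary and yields a result of shape (2, 4, 4).

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(2)

# Batch of two 4x4 real orthogonal matrices: result shape is (*shape, n, n).
q = jax.random.orthogonal(key, 4, shape=(2,))

# For each matrix Q in the batch, compute Q^T Q, which should be the identity.
gram = jnp.einsum("bji,bjk->bik", q, q)
```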
Sample Pareto random values with given shape and float dtype.
Wrapped version of jax.random.pareto. The SCICO version of this function moves the key argument to the end of the argument list, adds an additional seed argument after that, and allows the shape argument to accept a nested list, in which case a BlockArray is returned. Always returns a (result, key) tuple. Original docstring below.
The values are distributed according to the probability density function:
\[f(x; b) = b / x^{b + 1}\]
on the domain \(1 \le x < \infty\) with \(b > 0\)
Parameters:
key – a PRNG key used as the random key.
b – a float or array of floats broadcast-compatible with shape
representing the parameter of the distribution.
shape – optional, a tuple of nonnegative integers specifying the result
shape. Must be broadcast-compatible with b. The default (None)
produces a result shape equal to b.shape.
dtype – optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).
out_sharding – optional, specifies how the output array should be sharded
across devices in multi-device computation. Can be a
NamedSharding, a PartitionSpec
(P), or None (default). When specified, the output will be sharded
according to the given sharding specification. Primarily used in explicit
sharding mode.
See the explicit sharding tutorial
for more details.
Returns:
A random array with the specified dtype and with shape given by shape if
shape is not None, or else by b.shape.
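For the Pareto density above, every sample is at least 1 and, when b > 1, the mean is b / (b - 1). A sketch verifying both with jax.random.pareto (the parameter value and sample size are chosen for illustration):

```python
import jax

key = jax.random.PRNGKey(3)
b = 3.0

# Pareto samples live on [1, inf); for b > 1 the mean is b / (b - 1).
x = jax.random.pareto(key, b, shape=(50_000,))
print(float(x.min()), float(x.mean()), b / (b - 1))
```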
Sample Poisson random values with given shape and integer dtype.
Wrapped version of jax.random.poisson. The SCICO version of this function moves the key argument to the end of the argument list, adds an additional seed argument after that, and allows the shape argument to accept a nested list, in which case a BlockArray is returned. Always returns a (result, key) tuple. Original docstring below.
The values are distributed according to the probability mass function:
\[f(k; \lambda) = \frac{\lambda^k e^{-\lambda}}{k!}\]
where \(k\) is a non-negative integer and \(\lambda > 0\).
Parameters:
key – a PRNG key used as the random key.
lam – rate parameter (mean of the distribution), must be >= 0. Must be broadcast-compatible with shape.
shape – optional, a tuple of nonnegative integers representing the result
shape. Default (None) produces a result shape equal to lam.shape.
dtype – optional, an integer dtype for the returned values (default int64 if
jax_enable_x64 is true, otherwise int32).
out_sharding – optional, specifies how the output array should be sharded
across devices in multi-device computation. Can be a
NamedSharding, a PartitionSpec
(P), or None (default). When specified, the output will be sharded
according to the given sharding specification. Primarily used in explicit
sharding mode.
See the explicit sharding tutorial
for more details.
Returns:
A random array with the specified dtype and with shape given by shape if
shape is not None, or else by lam.shape.
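The broadcasting rule for lam can be seen in a small sketch: a rate vector of length 3 combined with shape=(50_000, 3) gives one rate per column, and each column mean should approach its rate. It calls jax.random.poisson directly, and the rates are arbitrary.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(4)
lam = jnp.array([0.5, 4.0, 20.0])

# lam (shape (3,)) broadcasts against shape=(50_000, 3): one rate per column.
k = jax.random.poisson(key, lam, shape=(50_000, 3))

# Counts are integers; each column mean should approach its rate.
col_means = k.mean(axis=0)
print(k.dtype, col_means)
```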
Wrapped version of jax.random.rademacher. The SCICO version of this function moves the key argument to the end of the argument list, adds an additional seed argument after that, and allows the shape argument to accept a nested list, in which case a BlockArray is returned. Always returns a (result, key) tuple. Original docstring below.
The values are distributed according to the probability mass function:
\[f(k) = \frac{1}{2}(\delta(k - 1) + \delta(k + 1))\]
on the domain \(k \in \{-1, 1\}\), where \(\delta(x)\) is the Dirac delta function.
Parameters:
key – a PRNG key.
shape – The shape of the returned samples. Default ().
dtype – The type used for samples.
out_sharding – optional, specifies how the output array should be sharded
across devices in multi-device computation. Can be a
NamedSharding, a PartitionSpec
(P), or None (default). When specified, the output will be sharded
according to the given sharding specification. Primarily used in explicit
sharding mode.
See the explicit sharding tutorial
for more details.
Returns:
A jnp.array of samples, of shape shape. Each element in the output has
a 50% chance of being 1 or -1.
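A sketch confirming that jax.random.rademacher only produces -1 and +1, in roughly equal proportion (the sample size is arbitrary):

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(5)
r = jax.random.rademacher(key, (40_000,))

# Every entry is -1 or +1, each with probability 1/2.
frac_pos = float((r == 1).mean())
print(jnp.unique(r), frac_pos)
```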
Sample uniform random values in [minval, maxval) with given shape/dtype.
Wrapped version of jax.random.randint. The SCICO version of this function moves the key argument to the end of the argument list, adds an additional seed argument after that, and allows the shape argument to accept a nested list, in which case a BlockArray is returned. Always returns a (result, key) tuple. Original docstring below.
Parameters:
key – a PRNG key used as the random key.
shape – a tuple of nonnegative integers representing the shape.
minval – int or array of ints broadcast-compatible with shape, a minimum
(inclusive) value for the range.
maxval – int or array of ints broadcast-compatible with shape, a maximum
(exclusive) value for the range.
dtype – optional, an int dtype for the returned values (default int64 if
jax_enable_x64 is true, otherwise int32).
out_sharding – optional, specifies how the output array should be sharded
across devices in multi-device computation. Can be a
NamedSharding, a PartitionSpec
(P), or None (default). When specified, the output will be sharded
according to the given sharding specification. Primarily used in explicit
sharding mode.
See the explicit sharding tutorial
for more details.
Returns:
A random array with the specified shape and dtype.
Note
randint uses a modulus-based computation that is known to produce
slightly biased values in some cases. The magnitude of the bias scales as
(maxval-minval)*((2**nbits)%(maxval-minval))/2**nbits:
in words, the bias goes to zero when (maxval-minval) is a power of 2,
and otherwise the bias will be small whenever (maxval-minval) is
small compared to the range of the sampled type.
To reduce this bias, 8-bit and 16-bit values will always be sampled at 32-bit and
then cast to the requested type. If you find yourself sampling values for which
this bias may be problematic, a possible alternative is to sample via uniform.
But keep in mind this method has its own biases due to floating point rounding
errors, and in particular there may be some integers in the range
[minval, maxval) that are impossible to produce with this approach.
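One way to follow the uniform-based suggestion is sketched below. The helper randint_via_uniform is hypothetical (it is not part of JAX or SCICO): it draws floats in [minval, maxval) with jax.random.uniform, rounds them down, and clips the result to guard against float32 rounding reaching maxval. As the note warns, this trades modulus bias for floating point effects rather than removing bias altogether.

```python
import jax
import jax.numpy as jnp

def randint_via_uniform(key, shape, minval, maxval):
    """Hypothetical helper: integers in [minval, maxval) drawn via uniform floats."""
    u = jax.random.uniform(key, shape, minval=minval, maxval=maxval)
    # floor maps [minval, maxval) onto minval, ..., maxval - 1; the clip guards
    # against float32 rounding pushing a value up to maxval.
    return jnp.clip(jnp.floor(u), minval, maxval - 1).astype(jnp.int32)

key = jax.random.PRNGKey(6)
k = randint_via_uniform(key, (20_000,), 3, 10)
print(jnp.unique(k))
```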
Sample Rayleigh random values with given shape and float dtype.
Wrapped version of jax.random.rayleigh. The SCICO version of this function moves the key argument to the end of the argument list, adds an additional seed argument after that, and allows the shape argument to accept a nested list, in which case a BlockArray is returned. Always returns a (result, key) tuple. Original docstring below.
The values are returned according to the probability density function:
\[f(x;\sigma) \propto xe^{-x^2/(2\sigma^2)}\]
on the domain \(0 \le x < \infty\), and where \(\sigma > 0\) is the scale
parameter of the distribution (given by scale).
Parameters:
key – a PRNG key used as the random key.
scale – a float or array of floats broadcast-compatible with shape
representing the parameter of the distribution.
shape – optional, a tuple of nonnegative integers specifying the result
shape. Must be broadcast-compatible with scale. The default (None)
produces a result shape equal to scale.shape.
dtype – optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).
out_sharding – optional, specifies how the output array should be sharded
across devices in multi-device computation. Can be a
NamedSharding, a PartitionSpec
(P), or None (default). When specified, the output will be sharded
according to the given sharding specification. Primarily used in explicit
sharding mode.
See the explicit sharding tutorial
for more details.
Returns:
A random array with the specified dtype and with shape given by shape if
shape is not None, or else by scale.shape.
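A sketch using jax.random.rayleigh directly: Rayleigh samples are non-negative and their mean is scale * sqrt(pi / 2), which the empirical mean should approach. The scale and sample size are illustrative.

```python
import math

import jax

key = jax.random.PRNGKey(7)
scale = 2.0

# Rayleigh samples are non-negative with mean scale * sqrt(pi / 2).
x = jax.random.rayleigh(key, scale, shape=(50_000,))
expected_mean = scale * math.sqrt(math.pi / 2)
print(float(x.mean()), expected_mean)
```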
Sample Student’s t random values with given shape and float dtype.
Wrapped version of jax.random.t. The SCICO version of this function moves the key argument to the end of the argument list, adds an additional seed argument after that, and allows the shape argument to accept a nested list, in which case a BlockArray is returned. Always returns a (result, key) tuple. Original docstring below.
The values are distributed according to the probability density function:
\[f(t; \nu) \propto \left(1 + \frac{t^2}{\nu}\right)^{-(\nu + 1)/2}\]
where \(\nu > 0\) is the degrees of freedom, given by the parameter df.
Parameters:
key – a PRNG key used as the random key.
df – a float or array of floats broadcast-compatible with shape
representing the degrees of freedom parameter of the distribution.
shape – optional, a tuple of nonnegative integers specifying the result
shape. Must be broadcast-compatible with df. The default (None)
produces a result shape equal to df.shape.
dtype – optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).
out_sharding – optional, specifies how the output array should be sharded
across devices in multi-device computation. Can be a
NamedSharding, a PartitionSpec
(P), or None (default). When specified, the output will be sharded
according to the given sharding specification. Primarily used in explicit
sharding mode.
See the explicit sharding tutorial
for more details.
Returns:
A random array with the specified dtype and with shape given by shape if
shape is not None, or else by df.shape.
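The degrees-of-freedom parameter controls the tails: for df > 2 the variance is df / (df - 2), larger than the unit variance of a standard normal. A sketch checking this with jax.random.t (df = 10 is an arbitrary choice):

```python
import jax

key = jax.random.PRNGKey(8)
df = 10.0

# Student's t with df > 2 has mean 0 and variance df / (df - 2) > 1.
x = jax.random.t(key, df, shape=(100_000,))
print(float(x.mean()), float(x.var()), df / (df - 2))
```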
Sample Triangular random values with given shape and float dtype.
Wrapped version of jax.random.triangular. The SCICO version of this function moves the key argument to the end of the argument list, adds an additional seed argument after that, and allows the shape argument to accept a nested list, in which case a BlockArray is returned. Always returns a (result, key) tuple. Original docstring below.
The values are returned according to the probability density function:
\[\begin{split}f(x; a, b, c) = \frac{2}{c-a} \left\{ \begin{array}{ll} \frac{x-a}{b-a} & a \leq x \leq b \\ \frac{c-x}{c-b} & b \leq x \leq c \end{array} \right.\end{split}\]
on the domain \(a \leq x \leq c\), where \(a\), \(b\) and \(c\) are given by left, mode and right.
Parameters:
key – a PRNG key used as the random key.
left – a float or array of floats broadcast-compatible with shape
representing the lower limit parameter of the distribution.
mode – a float or array of floats broadcast-compatible with shape
representing the peak value parameter of the distribution, value must
fulfill the condition left <= mode <= right.
right – a float or array of floats broadcast-compatible with shape
representing the upper limit parameter of the distribution, must be
larger than left.
shape – optional, a tuple of nonnegative integers specifying the result
shape. Must be broadcast-compatible with left, mode and right.
The default (None) produces a result shape equal to left.shape, mode.shape
and right.shape.
dtype – optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).
Returns:
A random array with the specified dtype and with shape given by shape if
shape is not None, or else by left.shape, mode.shape and right.shape.
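A sketch with jax.random.triangular, assuming left = 0, mode = 1 and right = 4 for illustration: all samples fall in [left, right], and the mean of a triangular distribution is (left + mode + right) / 3.

```python
import jax

key = jax.random.PRNGKey(9)
left, mode, right = 0.0, 1.0, 4.0

# Samples lie in [left, right]; the mean is (left + mode + right) / 3.
x = jax.random.triangular(key, left, mode, right, shape=(50_000,))
print(float(x.min()), float(x.max()), float(x.mean()))
```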
Sample truncated standard normal random values with given shape and dtype.
Wrapped version of jax.random.truncated_normal. The SCICO version of this function moves the key argument to the end of the argument list, adds an additional seed argument after that, and allows the shape argument to accept a nested list, in which case a BlockArray is returned. Always returns a (result, key) tuple. Original docstring below.
The values are returned according to the probability density function:
\[f(x) \propto e^{-x^2/2}\]
on the domain \(\rm{lower} < x < \rm{upper}\).
Parameters:
key – a PRNG key used as the random key.
lower – a float or array of floats representing the lower bound for
truncation. Must be broadcast-compatible with upper.
upper – a float or array of floats representing the upper bound for
truncation. Must be broadcast-compatible with lower.
shape – optional, a tuple of nonnegative integers specifying the result
shape. Must be broadcast-compatible with lower and upper. The
default (None) produces a result shape by broadcasting lower and
upper.
dtype – optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).
out_sharding – optional, specifies how the output array should be sharded
across devices in multi-device computation. Can be a
NamedSharding, a PartitionSpec
(P), or None (default). When specified, the output will be sharded
according to the given sharding specification. Primarily used in explicit
sharding mode.
See the explicit sharding tutorial
for more details.
Returns:
A random array with the specified dtype and shape given by shape if
shape is not None, or else by broadcasting lower and upper.
Returns values in the open interval (lower, upper).
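The truncation bounds can be checked with jax.random.truncated_normal directly. The sketch confirms that samples stay strictly inside (lower, upper) and compares the sample mean with the standard truncated-normal mean (pdf(lower) - pdf(upper)) / (cdf(upper) - cdf(lower)), where pdf and cdf are the standard normal density and CDF. The bounds -1 and 2 are arbitrary.

```python
import math

import jax

key = jax.random.PRNGKey(10)
lower, upper = -1.0, 2.0
x = jax.random.truncated_normal(key, lower, upper, shape=(50_000,))

def pdf(z):
    # Standard normal density.
    return math.exp(-z * z / 2) / math.sqrt(2 * math.pi)

def cdf(z):
    # Standard normal cumulative distribution function.
    return 0.5 * (1 + math.erf(z / math.sqrt(2)))

# Mean of a standard normal truncated to (lower, upper).
expected_mean = (pdf(lower) - pdf(upper)) / (cdf(upper) - cdf(lower))
print(float(x.mean()), expected_mean)
```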
Sample uniform random values in [minval, maxval) with given shape/dtype.
Wrapped version of jax.random.uniform. The SCICO version of this function moves the key argument to the end of the argument list, adds an additional seed argument after that, and allows the shape argument to accept a nested list, in which case a BlockArray is returned. Always returns a (result, key) tuple. Original docstring below.
Parameters:
key – a PRNG key used as the random key.
shape – optional, a tuple of nonnegative integers representing the result
shape. Default ().
dtype – optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).
minval – optional, a minimum (inclusive) value broadcast-compatible with shape for the range (default 0).
maxval – optional, a maximum (exclusive) value broadcast-compatible with shape for the range (default 1).
out_sharding – optional, specifies how the output array should be sharded
across devices in multi-device computation. Can be a
NamedSharding, a PartitionSpec
(P), or None (default). When specified, the output will be sharded
according to the given sharding specification. Primarily used in explicit
sharding mode.
See the explicit sharding tutorial
for more details.
Returns:
A random array with the specified shape and dtype.
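A sketch of the minval/maxval arguments with jax.random.uniform: samples land in the half-open range [2, 5) and average close to the midpoint 3.5 (the range and sample size are chosen for illustration).

```python
import jax

key = jax.random.PRNGKey(11)

# minval is inclusive and maxval exclusive: samples lie in [2, 5).
u = jax.random.uniform(key, shape=(40_000,), minval=2.0, maxval=5.0)
print(float(u.min()), float(u.max()), float(u.mean()))
```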
Sample Wald random values with given shape and float dtype.
Wrapped version of jax.random.wald. The SCICO version of this function moves the key argument to the end of the argument list, adds an additional seed argument after that, and allows the shape argument to accept a nested list, in which case a BlockArray is returned. Always returns a (result, key) tuple. Original docstring below.
The values are returned according to the probability density function:
\[f(x;\mu) = \frac{1}{\sqrt{2\pi x^3}} \exp\left(-\frac{(x - \mu)^2}{2\mu^2 x}\right)\]
on the domain \(0 < x < \infty\), and where \(\mu > 0\) is the mean of the
distribution (given by mean).
Parameters:
key – a PRNG key used as the random key.
mean – a float or array of floats broadcast-compatible with shape
representing the mean parameter of the distribution.
shape – optional, a tuple of nonnegative integers specifying the result
shape. Must be broadcast-compatible with mean. The default
(None) produces a result shape equal to np.shape(mean).
dtype – optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).
out_sharding – optional, specifies how the output array should be sharded
across devices in multi-device computation. Can be a
NamedSharding, a PartitionSpec
(P), or None (default). When specified, the output will be sharded
according to the given sharding specification. Primarily used in explicit
sharding mode.
See the explicit sharding tutorial
for more details.
Returns:
A random array with the specified dtype and with shape given by shape if
shape is not None, or else by mean.shape.
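A sketch with jax.random.wald, whose density above has a unit shape parameter: the samples lie on the positive half-line and their mean approaches mean. The value mu = 1.5 and the sample size are illustrative.

```python
import jax

key = jax.random.PRNGKey(12)
mu = 1.5

# With the unit shape parameter used in the density above, samples are
# positive and their mean approaches mu.
x = jax.random.wald(key, mu, shape=(100_000,))
print(float(x.min()), float(x.mean()))
```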
Wrapped version of jax.random.weibull_min. The SCICO version of this function moves the key argument to the end of the argument list, adds an additional seed argument after that, and allows the shape argument to accept a nested list, in which case a BlockArray is returned. Always returns a (result, key) tuple. Original docstring below.
The values are distributed according to the probability density function: