scico.numpy

BlockArray and compatible functions.

This module provides BlockArray together with functions that accept both instances of this class and jax arrays. It includes all functions from jax.numpy and numpy.testing, many of which have been extended to automatically map over the blocks of a block array, as described in NumPy and SciPy Functions. Additional functions specific to SCICO are provided in util.
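The blockwise mapping can be pictured with a minimal pure-NumPy sketch. It does not use scico: a tuple of NumPy arrays stands in for a BlockArray, and the helpers `map_blocks` and `block_sum` are hypothetical names chosen for illustration. Element-wise functions act on each block separately and preserve the block structure, while a full reduction such as a sum combines all blocks into one value, matching the `snp.sum` example in the BlockArray documentation.

```python
import numpy as np

# A tuple of arrays with different shapes stands in for a BlockArray.
blocks = (
    np.array([[1, 3, 7], [2, 2, 1]]),
    np.array([2, 4, 8]),
)


def map_blocks(func, blocks):
    """Hypothetical helper: apply an element-wise function to every block."""
    return tuple(func(b) for b in blocks)


def block_sum(blocks):
    """Hypothetical helper: reduce over all elements of all blocks."""
    return sum(np.sum(b) for b in blocks)


shapes = tuple(b.shape for b in blocks)  # ((2, 3), (3,))
squared = map_blocks(np.square, blocks)  # block structure is preserved
total = block_sum(blocks)                # a single scalar: 30
```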

Modules

scico.numpy.fft

Discrete Fourier Transform functions.

scico.numpy.linalg

Linear algebra functions.

scico.numpy.testing

Test support functions.

scico.numpy.util

Utility functions for working with jax arrays and BlockArrays.

Functions

abs(x, /)

Alias of jax.numpy.absolute.

absolute(x, /)

Calculate the absolute value element-wise.

add(*args[, out, where])

Add two arrays element-wise.

all(a[, axis, out, keepdims, where])

Test whether all array elements along a given axis evaluate to True.

allclose(a, b[, rtol, atol, equal_nan])

Check if two arrays are element-wise approximately equal within a tolerance.

amax(a[, axis, out, keepdims, initial, where])

Alias of jax.numpy.max.

amin(a[, axis, out, keepdims, initial, where])

Alias of jax.numpy.min.

angle(z[, deg])

Return the angle of a complex valued number or array.

any(a[, axis, out, keepdims, where])

Test whether any of the array elements along a given axis evaluate to True.

append(arr, values[, axis])

Return a new array with values appended to the end of the original array.

apply_along_axis(func1d, axis, arr, *args, ...)

Apply a function to 1D array slices along an axis.

apply_over_axes(func, a, axes)

Apply a function repeatedly over specified axes.

arange(start[, stop, step, dtype, device, ...])

Create an array of evenly-spaced values.

arccos(x, /)

Compute element-wise inverse of trigonometric cosine of input.

arccosh(x, /)

Calculate element-wise inverse of hyperbolic cosine of input.

arcsin(x, /)

Compute element-wise inverse of trigonometric sine of input.

arcsinh(x, /)

Calculate element-wise inverse of hyperbolic sine of input.

arctan(x, /)

Compute element-wise inverse of trigonometric tangent of input.

arctan2(x1, x2, /)

Compute the arctangent of x1/x2, choosing the correct quadrant.

arctanh(x, /)

Calculate element-wise inverse of hyperbolic tangent of input.

argmax(a[, axis, out, keepdims])

Return the index of the maximum value of an array.

argmin(a[, axis, out, keepdims])

Return the index of the minimum value of an array.

argsort(a[, axis, kind, order, stable, ...])

Return indices that sort an array.

argwhere(a, *[, size, fill_value])

Find the indices of nonzero array elements.

around(a[, decimals, out])

Alias of jax.numpy.round.

array(object[, dtype, copy, order, ndmin, ...])

Convert an object to a JAX array.

array_equal(a1, a2[, equal_nan])

Check if two arrays are element-wise equal.

array_equiv(a1, a2)

Check if two arrays are element-wise equal after broadcasting to a common shape.
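The difference between `array_equal` and `array_equiv` shows up with inputs of different but broadcast-compatible shapes. This sketch calls NumPy directly, on the assumption that jax.numpy mirrors NumPy's semantics for these two functions.

```python
import numpy as np

a = np.array([1, 2])
b = np.array([[1, 2], [1, 2]])

# array_equal requires identical shapes, so these compare unequal.
strict = np.array_equal(a, b)
# array_equiv broadcasts a to b's shape (2, 2) before comparing.
broadcast = np.array_equiv(a, b)
```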

array_split(ary, indices_or_sections[, axis])

Split an array into sub-arrays.

asarray(a[, dtype, order, copy, device, ...])

Convert an object to a JAX array.

astype(x, dtype, /, *[, copy, device])

Convert an array to a specified dtype.

atleast_1d(*arys)

Convert inputs to arrays with at least 1 dimension.

atleast_2d(*arys)

Convert inputs to arrays with at least 2 dimensions.

atleast_3d(*arys)

Convert inputs to arrays with at least 3 dimensions.

average(a[, axis, weights, returned, keepdims])

Compute the weighted average.
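As a worked example, each value is multiplied by its weight and the total is divided by the sum of the weights: (1·3 + 2·1 + 3·0) / (3 + 1 + 0) = 5/4. The sketch uses NumPy, whose `average` signature jax.numpy follows.

```python
import numpy as np

values = np.array([1.0, 2.0, 3.0])
weights = np.array([3.0, 1.0, 0.0])

# sum(values * weights) / sum(weights) = 5.0 / 4.0
avg = np.average(values, weights=weights)

# returned=True additionally gives the sum of the weights.
avg2, wsum = np.average(values, weights=weights, returned=True)
```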

bartlett(M)

Return a Bartlett window of size M.

bincount(x[, weights, minlength, length])

Count the number of occurrences of each value in an integer array.

blackman(M)

Return a Blackman window of size M.

block(arrays)

Create an array from a list of blocks.

blockarray(iterable)

Construct a BlockArray from a list or tuple of existing array-like objects.

broadcast_arrays(*args)

Broadcast arrays to a common shape.

broadcast_shapes(*shapes)

Broadcast input shapes to a common output shape.

broadcast_to(array, shape, *[, out_sharding])

Broadcast an array to a specified shape.

cbrt(x, /)

Calculates element-wise cube root of the input array.

ceil(x, /)

Round input to the nearest integer upwards.

choose(a, choices[, out, mode])

Construct an array by stacking slices of choice arrays.

clip([arr, min, max])

Clip array values to a specified range.

column_stack(tup)

Stack arrays column-wise.

compress(condition, a[, axis, size, ...])

Compress an array along a given axis using a boolean condition.

concat(arrays, /, *[, axis])

Join arrays along an existing axis.

concatenate(arrays[, axis, dtype])

Join arrays along an existing axis.

conj(x, /)

Alias of jax.numpy.conjugate.

conjugate(x, /)

Return element-wise complex-conjugate of the input.

convolve(a, v[, mode, precision, ...])

Convolution of two one dimensional arrays.

copy(a[, order])

Return a copy of the array.

copysign(x1, x2, /)

Copies the sign of each element in x2 to the corresponding element in x1.

cos(x, /)

Compute a trigonometric cosine of each element of input.

cosh(x, /)

Calculate element-wise hyperbolic cosine of input.

count_nonzero(a[, axis, keepdims])

Return the number of nonzero elements along a given axis.

cross(a, b[, axisa, axisb, axisc, axis])

Compute the (batched) cross product of two arrays.

cumprod(a[, axis, dtype, out])

Cumulative product of elements along an axis.

cumsum(a[, axis, dtype, out])

Cumulative sum of elements along an axis.

cumulative_prod(x, /, *[, axis, dtype, ...])

Cumulative product along the axis of an array.

cumulative_sum(x, /, *[, axis, dtype, ...])

Cumulative sum along the axis of an array.

deg2rad(x, /)

Convert angles from degrees to radians.

degrees(x, /)

Alias of jax.numpy.rad2deg.

delete(arr, obj[, axis, assume_unique_indices])

Delete entry or entries from an array.

diag(v[, k])

Returns the specified diagonal or constructs a diagonal array.

diag_indices(n[, ndim])

Return indices for accessing the main diagonal of a multidimensional array.

diag_indices_from(arr)

Return indices for accessing the main diagonal of a given array.

diagflat(v[, k])

Return a 2-D array with the flattened input array laid out on the diagonal.

diff(a[, n, axis, prepend, append])

Calculate n-th order difference between array elements along a given axis.
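For instance, the first-order difference subtracts each element from its successor, and `n=2` applies the operation twice. The sketch uses NumPy and assumes jax.numpy.diff behaves identically for these inputs.

```python
import numpy as np

x = np.array([1, 2, 4, 7, 0])

d1 = np.diff(x)       # x[1:] - x[:-1] -> [1, 2, 3, -7]
d2 = np.diff(x, n=2)  # diff of d1     -> [1, 1, -10]
```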

divide(x1, x2, /)

Alias of jax.numpy.true_divide.

divmod(x1, x2, /)

Calculates the integer quotient and remainder of x1 by x2 element-wise.

dot(a, b, *[, precision, ...])

Compute the dot product of two arrays.

dsplit(ary, indices_or_sections)

Split an array into sub-arrays depth-wise.

dstack(tup[, dtype])

Stack arrays depth-wise.

ediff1d(ary[, to_end, to_begin])

Compute the differences of the elements of the flattened array.

einsum(subscripts, /, *operands[, out, ...])

Einstein summation.

einsum_path(subscripts, /, *operands[, optimize])

Evaluates the optimal contraction path without evaluating the einsum.

empty(shape[, dtype, device, out_sharding])

Create an empty array.

empty_like(prototype[, dtype, shape, device])

Create an empty array with the same shape and dtype as an array.

equal(x, y, /)

Returns element-wise truth value of x == y.

exp(x, /)

Calculate element-wise exponential of the input.

exp2(x, /)

Calculate element-wise base-2 exponential of input.

expand_dims(a, axis)

Insert dimensions of length 1 into an array.

expm1(x, /)

Calculate exp(x)-1 of each element of the input.

extract(condition, arr, *[, size, fill_value])

Return the elements of an array that satisfy a condition.

eye(N[, M, k, dtype, device])

Create a square or rectangular identity matrix.

fabs(x, /)

Compute the element-wise absolute values of the real-valued input.

fill_diagonal(a, val[, wrap, inplace])

Return a copy of the array with the diagonal overwritten.

flatnonzero(a, *[, size, fill_value])

Return indices of nonzero elements in a flattened array.

flip(m[, axis])

Reverse the order of elements of an array along the given axis.

fliplr(m)

Reverse the order of elements of an array along axis 1.

flipud(m)

Reverse the order of elements of an array along axis 0.

float_power(x, y, /)

Calculate element-wise base x exponential of y.

floor(x, /)

Round input to the nearest integer downwards.

floor_divide(x1, x2, /)

Calculates the floor division of x1 by x2 element-wise.

fmax(x1, x2)

Return element-wise maximum of the input arrays.

fmin(x1, x2)

Return element-wise minimum of the input arrays.

fmod(x1, x2, /)

Calculate element-wise floating-point modulo operation.
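The two modulo functions differ in the sign convention for negative operands: `fmod` takes the sign of the dividend (truncated division, as in C), while `remainder` (and its alias `mod`) takes the sign of the divisor (floored division, like Python's `%`). NumPy is used here, on the assumption that jax.numpy follows the same conventions.

```python
import numpy as np

x = np.array([-7, 7])

truncated = np.fmod(x, 2)     # sign of dividend: [-1, 1]
floored = np.remainder(x, 2)  # sign of divisor:  [1, 1]
```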

frexp(x, /)

Split floating point values into mantissa and base-2 exponent.
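Each nonzero finite input is decomposed as x = m · 2**e with 0.5 <= |m| < 1, so `ldexp` (which computes x1 * 2 ** x2) inverts `frexp`. The sketch uses NumPy's functions of the same names.

```python
import numpy as np

x = np.array([8.0, 3.0, -5.0])

mantissa, exponent = np.frexp(x)  # x == mantissa * 2**exponent
# mantissa -> [0.5, 0.75, -0.625], exponent -> [4, 2, 3]

restored = np.ldexp(mantissa, exponent)  # ldexp undoes frexp
```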

from_dlpack(x, /, *[, device, copy])

Construct a JAX array via DLPack.

frombuffer(buffer[, dtype, count, offset])

Convert a buffer into a 1-D JAX array.

fromfile(*args, **kwargs)

Unimplemented JAX wrapper for jnp.fromfile.

fromfunction(function, shape, *[, dtype])

Create an array from a function applied over indices.

fromiter(*args, **kwargs)

Unimplemented JAX wrapper for jnp.fromiter.

frompyfunc(func, /, nin, nout, *[, identity])

Create a JAX ufunc from an arbitrary JAX-compatible scalar function.

fromstring(string[, dtype, count])

Convert a string of text into a 1-D JAX array.

full(shape, fill_value[, dtype, device])

Create an array full of a specified value.

full_like(a, fill_value[, dtype, shape, device])

Create an array full of a specified value with the same shape and dtype as an array.

gcd(x1, x2)

Compute the greatest common divisor of two arrays.

geomspace(start, stop[, num, endpoint, ...])

Generate geometrically-spaced values.

get_printoptions()

Alias of numpy.get_printoptions.

gradient(f, *varargs[, axis, edge_order])

Compute the numerical gradient of a sampled function.

greater(x, y, /)

Return element-wise truth value of x > y.

greater_equal(x, y, /)

Return element-wise truth value of x >= y.

hamming(M)

Return a Hamming window of size M.

hanning(M)

Return a Hanning window of size M.

heaviside(x1, x2, /)

Compute the Heaviside step function.

histogram(a[, bins, range, weights, density])

Compute a 1-dimensional histogram.

histogram2d(x, y[, bins, range, weights, ...])

Compute a 2-dimensional histogram.

histogram_bin_edges(a[, bins, range, weights])

Compute the bin edges for a histogram.

histogramdd(sample[, bins, range, weights, ...])

Compute an N-dimensional histogram.

hsplit(ary, indices_or_sections)

Split an array into sub-arrays horizontally.

hstack(tup[, dtype])

Horizontally stack arrays.

hypot(x1, x2, /)

Return element-wise hypotenuse for the given legs of a right triangle.

i0(x)

Calculate modified Bessel function of first kind, zeroth order.

identity(n[, dtype])

Create a square identity matrix.

imag(val, /)

Return element-wise imaginary part of the complex argument.

indices(dimensions[, dtype, sparse])

Generate arrays of grid indices.

inner(a, b, *[, precision, ...])

Compute the inner product of two arrays.

insert(arr, obj, values[, axis])

Insert entries into an array at specified indices.

interp(x, xp, fp[, left, right, period])

One-dimensional linear interpolation.
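A small worked example of linear interpolation, using NumPy's `interp` (same parameter names as listed above). A query point halfway between two sample points gets the value halfway between their sample values; outside the sample range the end values are used unless `left`/`right` are given.

```python
import numpy as np

xp = np.array([1.0, 2.0, 3.0])  # sample points (must be increasing)
fp = np.array([3.0, 2.0, 0.0])  # sample values

# 2.5 lies halfway between xp[1] and xp[2], so the result is
# halfway between fp[1] = 2.0 and fp[2] = 0.0.
mid = np.interp(2.5, xp, fp)

# Outside [xp[0], xp[-1]] the end values fp[0] and fp[-1] are returned ...
outside = np.interp([0.0, 4.0], xp, fp)
# ... unless explicit fill values are supplied.
outside_custom = np.interp([0.0, 4.0], xp, fp, left=-1.0, right=99.0)
```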

intersect1d(ar1, ar2[, assume_unique, ...])

Compute the set intersection of two 1D arrays.

isclose(a, b[, rtol, atol, equal_nan])

Check if the elements of two arrays are approximately equal within a tolerance.

iscomplex(x)

Return boolean array showing where the input is complex.

iscomplexobj(x)

Check if the input is a complex number or an array containing complex elements.

isdtype(dtype, kind)

Returns a boolean indicating whether a provided dtype is of a specified kind.

isfinite(x, /)

Return a boolean array indicating whether each element of input is finite.

isin(element, test_elements[, ...])

Determine whether elements in element appear in test_elements.

isinf(x, /)

Return a boolean array indicating whether each element of input is infinite.

isnan(x, /)

Returns a boolean array indicating whether each element of input is NaN.

isneginf(x, /[, out])

Return boolean array indicating whether each element of input is negative infinite.

isposinf(x, /[, out])

Return boolean array indicating whether each element of input is positive infinite.

isreal(x)

Return boolean array showing where the input is real.

isrealobj(x)

Check if the input is not a complex number or an array containing complex elements.

isscalar(element)

Return True if the input is a scalar.

issubdtype(arg1, arg2)

Return True if arg1 is equal or lower than arg2 in the type hierarchy.

iterable(y)

Check whether or not an object can be iterated over.

ix_(*args)

Return a multi-dimensional grid (open mesh) from N one-dimensional sequences.

kaiser(M, beta)

Return a Kaiser window of size M.

kron(a, b)

Compute the Kronecker product of two input arrays.

lcm(x1, x2)

Compute the least common multiple of two arrays.

ldexp(x1, x2, /)

Compute x1 * 2 ** x2.

less(x, y, /)

Return element-wise truth value of x < y.

less_equal(x, y, /)

Return element-wise truth value of x <= y.

lexsort(keys[, axis])

Sort a sequence of keys in lexicographic order.

linspace(start, stop[, num, endpoint, ...])

Return evenly-spaced numbers within an interval.

load(file, *args, **kwargs)

Load JAX arrays from npy files.

log(x, /)

Calculate element-wise natural logarithm of the input.

log10(x, /)

Calculates the base-10 logarithm of x element-wise.

log1p(x, /)

Calculates element-wise logarithm of one plus input, log(x+1).

log2(x, /)

Calculates the base-2 logarithm of x element-wise.

logaddexp(*args[, out, where])

Compute log(exp(x1) + exp(x2)) avoiding overflow.
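To see why the overflow-avoiding formulation matters, compare it with the naive expression for large inputs. This NumPy sketch assumes float64 inputs; exp(1000) exceeds the float64 range, so the naive result is infinite while `logaddexp` returns 1000 + log(2).

```python
import math

import numpy as np

a = np.float64(1000.0)

# The naive formula overflows: exp(1000) is not representable in float64.
with np.errstate(over="ignore"):
    naive = np.log(np.exp(a) + np.exp(a))

# logaddexp avoids forming exp(1000) explicitly.
stable = np.logaddexp(a, a)  # 1000 + log(2)
```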

logaddexp2(*args[, out, where])

Logarithm of the sum of exponentials of inputs in base-2 avoiding overflow.

logical_and(*args[, out, where])

Compute the logical AND operation elementwise.

logical_not(x, /)

Compute NOT bool(x) element-wise.

logical_or(*args[, out, where])

Compute the logical OR operation elementwise.

logical_xor(*args[, out, where])

Compute the logical XOR operation elementwise.

logspace(start, stop[, num, endpoint, base, ...])

Generate logarithmically-spaced values.

mask_indices(n, mask_func[, k, size])

Return indices of a mask of an (n, n) array.

matmul(a, b, *[, precision, ...])

Perform a matrix multiplication.

matrix_transpose(x, /)

Transpose the last two dimensions of an array.

max(a[, axis, out, keepdims, initial, where])

Return the maximum of the array elements along a given axis.

maximum(*args[, out, where])

Return element-wise maximum of the input arrays.

mean(a[, axis, dtype, out, keepdims, where])

Return the mean of array elements along a given axis.

meshgrid(*xi[, copy, sparse, indexing])

Construct N-dimensional grid arrays from N 1-dimensional vectors.

min(a[, axis, out, keepdims, initial, where])

Return the minimum of array elements along a given axis.

minimum(*args[, out, where])

Return element-wise minimum of the input arrays.

mod(x1, x2, /)

Alias of jax.numpy.remainder.

modf(x, /[, out])

Return element-wise fractional and integral parts of the input array.

moveaxis(a, source, destination)

Move an array axis to a new position.

multiply(*args[, out, where])

Multiply two arrays element-wise.

nan_to_num(x[, copy, nan, posinf, neginf])

Replace NaN and infinite entries in an array.

nanargmax(a[, axis, out, keepdims])

Return the index of the maximum value of an array, ignoring NaNs.

nanargmin(a[, axis, out, keepdims])

Return the index of the minimum value of an array, ignoring NaNs.

nancumprod(a[, axis, dtype, out])

Cumulative product of elements along an axis, ignoring NaN values.

nancumsum(a[, axis, dtype, out])

Cumulative sum of elements along an axis, ignoring NaN values.

nanmax(a[, axis, out, keepdims, initial, where])

Return the maximum of the array elements along a given axis, ignoring NaNs.

nanmin(a[, axis, out, keepdims, initial, where])

Return the minimum of the array elements along a given axis, ignoring NaNs.

nanprod(a[, axis, dtype, out, keepdims, ...])

Return the product of the array elements along a given axis, ignoring NaNs.

nansum(a[, axis, dtype, out, keepdims, ...])

Return the sum of the array elements along a given axis, ignoring NaNs.

ndim(a)

Return the number of dimensions of an array.

negative(*args[, out, where])

Return element-wise negative values of the input.

nextafter(x, y, /)

Return element-wise next floating point value after x towards y.

nonzero(a, *[, size, fill_value])

Return indices of nonzero elements of an array.

not_equal(x, y, /)

Returns element-wise truth value of x != y.

ones(shape[, dtype, device, out_sharding])

Create an array full of ones.

ones_like(a[, dtype, shape, device, ...])

Create an array of ones with the same shape and dtype as an array.

outer(a, b[, out])

Compute the outer product of two arrays.

pad(array, pad_width[, mode])

Add padding to an array.

partition(a, kth[, axis])

Returns a partially-sorted copy of an array.

permute_dims(a, /, axes)

Permute the axes/dimensions of an array.

piecewise(x, condlist, funclist, *args, **kw)

Evaluate a function defined piecewise across the domain.

place(arr, mask, vals, *[, inplace])

Update array elements based on a mask.

polydiv(u, v, *[, trim_leading_zeros])

Returns the quotient and remainder of polynomial division.

polymul(a1, a2, *[, trim_leading_zeros])

Returns the product of two polynomials.

positive(x, /)

Return element-wise positive values of the input.

pow(x1, x2, /)

Alias of jax.numpy.power.

power(x1, x2, /)

Calculate element-wise base x1 exponential of x2.

printoptions(*args, **kwargs)

Alias of numpy.printoptions.

prod(a[, axis, dtype, out, keepdims, ...])

Return product of the array elements over a given axis.

promote_types(a, b)

Returns the type to which a binary operation should cast its arguments.

ptp(a[, axis, out, keepdims])

Return the peak-to-peak range along a given axis.

put(a, ind, v[, mode, inplace])

Put elements into an array at given indices.

rad2deg(x, /)

Convert angles from radians to degrees.

radians(x, /)

Alias of jax.numpy.deg2rad.

ravel(ba)

Completely flatten a BlockArray into a single Array.
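A rough emulation of what flattening a BlockArray produces, assuming each block is raveled in row-major (C) order and the results are concatenated in block order. A tuple of NumPy arrays stands in for a BlockArray, and `ravel_blocks` is a hypothetical helper, not the scico implementation.

```python
import numpy as np

blocks = (
    np.array([[1, 3, 7], [2, 2, 1]]),
    np.array([2, 4, 8]),
)


def ravel_blocks(blocks):
    """Hypothetical helper: flatten each block, then join them end to end."""
    return np.concatenate([np.ravel(b) for b in blocks])


flat = ravel_blocks(blocks)  # [1, 3, 7, 2, 2, 1, 2, 4, 8]
```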

ravel_multi_index(multi_index, dims[, mode, ...])

Convert multi-dimensional indices into flat indices.

real(val, /)

Return element-wise real part of the complex argument.

reciprocal(x, /)

Calculate element-wise reciprocal of the input.

remainder(x1, x2, /)

Returns element-wise remainder of the division.

repeat(a, repeats[, axis, ...])

Construct an array from repeated elements.

reshape(a, shape[, order, copy, out_sharding])

Return a reshaped copy of an array.

resize(a, new_shape)

Return a new array with specified shape.

result_type(*args)

Return the result of applying JAX promotion rules to the inputs.

rint(x, /)

Rounds the elements of x to the nearest integer.

roll(a, shift[, axis])

Roll the elements of an array along a specified axis.

rollaxis(a, axis[, start])

Roll the specified axis to a given position.

roots(p, *[, strip_zeros])

Returns the roots of a polynomial given the coefficients p.

rot90(m[, k, axes])

Rotate an array by 90 degrees counterclockwise in the plane specified by axes.
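The direction convention is easiest to check on a 2×2 matrix. The sketch uses NumPy's `rot90`, which takes the same `k` and `axes` parameters; positive `k` rotates counterclockwise and negative `k` clockwise.

```python
import numpy as np

m = np.array([[1, 2],
              [3, 4]])

r1 = np.rot90(m)           # one counterclockwise quarter turn: [[2, 4], [1, 3]]
r4 = np.rot90(m, k=4)      # four quarter turns restore the input
r_neg = np.rot90(m, k=-1)  # negative k rotates clockwise: [[3, 1], [4, 2]]
```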

round(a[, decimals, out])

Round input evenly to the given number of decimals.
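"Evenly" refers to the tie-breaking rule: values exactly halfway between two candidates are rounded to the nearest even value (round-half-to-even). The sketch uses NumPy, assuming jax.numpy.round applies the same rule.

```python
import numpy as np

# Halfway cases go to the nearest even integer, not always up.
ties = np.round(np.array([0.5, 1.5, 2.5, 3.5]))  # [0., 2., 2., 4.]

# decimals selects the number of digits kept after the decimal point.
pi_2 = np.round(3.14159, 2)  # 3.14
```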

searchsorted(a, v[, side, sorter, method])

Perform a binary search within a sorted array.

select(condlist, choicelist[, default])

Select values based on a series of conditions.

set_printoptions(*args, **kwargs)

Alias of numpy.set_printoptions.

setdiff1d(ar1, ar2[, assume_unique, size, ...])

Compute the set difference of two 1D arrays.

setxor1d(ar1, ar2[, assume_unique, size, ...])

Compute the set-wise xor of elements in two arrays.

shape(a)

Return the shape of an array.

sign(x, /)

Return an element-wise indication of sign of the input.

signbit(x, /)

Return the sign bit of array elements.

sin(x, /)

Compute a trigonometric sine of each element of input.

sinc(x, /)

Calculate the normalized sinc function.

sinh(x, /)

Calculate element-wise hyperbolic sine of input.

size(a[, axis])

Return number of elements along a given axis.

sort(a[, axis, kind, order, stable, descending])

Return a sorted copy of an array.

sort_complex(a)

Return a sorted copy of a complex array.

split(ary, indices_or_sections[, axis])

Split an array into sub-arrays.

sqrt(x, /)

Calculates element-wise non-negative square root of the input array.

square(x, /)

Calculate element-wise square of the input array.

squeeze(a[, axis])

Remove one or more length-1 axes from an array.

stack(arrays[, axis, out, dtype])

Join arrays along a new axis.

std(a[, axis, dtype, out, ddof, keepdims, ...])

Compute the standard deviation along a given axis.

subtract(*args[, out, where])

Subtract two arrays element-wise.

sum(a[, axis, dtype, out, keepdims, ...])

Sum of the elements of the array over a given axis.

swapaxes(a, axis1, axis2)

Swap two axes of an array.

take(a, indices[, axis, out, mode, ...])

Take elements from an array.

tan(x, /)

Compute a trigonometric tangent of each element of input.

tanh(x, /)

Calculate element-wise hyperbolic tangent of input.

tensordot(a, b[, axes, precision, ...])

Compute the tensor dot product of two N-dimensional arrays.

tile(A, reps)

Construct an array by repeating A along specified dimensions.

trace(a[, offset, axis1, axis2, dtype, out])

Calculate sum of the diagonal of input along the given axes.

transpose(a[, axes])

Return a transposed version of an N-dimensional array.

tri(N[, M, k, dtype])

Return an array with ones on and below the diagonal and zeros elsewhere.

tril_indices(n[, k, m])

Return the indices of the lower triangle of an array of size (n, m).
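The returned row and column index arrays can be used directly for fancy indexing. This NumPy sketch extracts the elements on and below the main diagonal of a 3×3 array (the default `k=0` includes the diagonal).

```python
import numpy as np

rows, cols = np.tril_indices(3)
# rows -> [0, 1, 1, 2, 2, 2], cols -> [0, 0, 1, 0, 1, 2]

a = np.arange(9).reshape(3, 3)  # [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
lower = a[rows, cols]           # elements on and below the main diagonal
```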

tril_indices_from(arr[, k])

Return the indices of the lower triangle of a given array.

trim_zeros(filt[, trim, axis])

Trim leading and/or trailing zeros of the input array.

triu_indices(n[, k, m])

Return the indices of the upper triangle of an array of size (n, m).

triu_indices_from(arr[, k])

Return the indices of the upper triangle of a given array.

true_divide(x1, x2, /)

Calculates the division of x1 by x2 element-wise.

trunc(x)

Round input to the nearest integer towards zero.

union1d(ar1, ar2, *[, size, fill_value])

Compute the set union of two 1D arrays.

unique(ar[, return_index, return_inverse, ...])

Return the unique values from an array.

unique_all(x, /, *[, size, fill_value])

Return unique values from x, along with indices, inverse indices, and counts.

unique_counts(x, /, *[, size, fill_value])

Return unique values from x, along with counts.

unique_inverse(x, /, *[, size, fill_value])

Return unique values from x, along with inverse indices.

unique_values(x, /, *[, size, fill_value])

Return unique values from x.
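The outputs of the `unique_*` variants correspond to optional outputs of `unique`. This sketch obtains all of them from a single call to NumPy's `unique` (the NumPy function is used so the example runs without JAX); indexing the unique values with the inverse indices reconstructs the input.

```python
import numpy as np

x = np.array([3, 1, 3, 2])

values, index, inverse, counts = np.unique(
    x, return_index=True, return_inverse=True, return_counts=True
)
# values  -> [1, 2, 3]      sorted unique values
# index   -> [1, 3, 0]      first position of each value in x
# inverse -> [2, 0, 2, 1]   position in values of each element of x
# counts  -> [1, 1, 2]      number of occurrences of each value

reconstructed = values[inverse]  # inverse indices rebuild x
```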

unravel_index(indices, shape)

Convert flat indices into multi-dimensional indices.

unwrap(p[, discont, axis, period])

Unwrap a periodic signal.

var(a[, axis, dtype, out, ddof, keepdims, ...])

Compute the variance along a given axis.

vdot(a, b, *[, precision, ...])

Perform a conjugate multiplication of two 1D vectors.

vecdot(x1, x2, /, *[, axis, precision, ...])

Perform a conjugate multiplication of two batched vectors.

vectorize(pyfunc, *[, excluded, signature])

Define a vectorized function with broadcasting.

vsplit(ary, indices_or_sections)

Split an array into sub-arrays vertically.

vstack(tup[, dtype])

Vertically stack arrays.

where(condition[, x, y, size, fill_value])

Select elements from two arrays based on a condition.

zeros(shape[, dtype, device, out_sharding])

Create an array full of zeros.

zeros_like(a[, dtype, shape, device, ...])

Create an array full of zeros with the same shape and dtype as an array.

Classes

BlockArray(inputs)

Block array class.

class scico.numpy.BlockArray(inputs)

Bases: object

Block array class.

A block array provides a way to combine arrays of different shapes into a single object for use with other SCICO classes. For further information, see the detailed BlockArray documentation.

Example

>>> import scico.numpy as snp
>>> x = snp.blockarray((
...     [[1, 3, 7],
...      [2, 2, 1]],
...     [2, 4, 8]
... ))
>>> x.shape
((2, 3), (3,))
>>> snp.sum(x)
Array(30, dtype=int32)

property T

Compute the all-axis array transpose.

Refer to jax.numpy.transpose for details.

all(axis=None, out=None, keepdims=False, *, where=None)

Test whether all array elements along a given axis evaluate to True.

Refer to jax.numpy.all for the full documentation.

Return type:

Array

any(axis=None, out=None, keepdims=False, *, where=None)

Test whether any array elements along a given axis evaluate to True.

Refer to jax.numpy.any for the full documentation.

Return type:

Array

argmax(axis=None, out=None, keepdims=None)

Return the index of the maximum value.

Refer to jax.numpy.argmax for the full documentation.

Return type:

Array

argmin(axis=None, out=None, keepdims=None)

Return the index of the minimum value.

Refer to jax.numpy.argmin for the full documentation.

Return type:

Array

argpartition(kth, axis=-1)

Return the indices that partially sort the array.

Refer to jax.numpy.argpartition for the full documentation.

Return type:

Array

argsort(axis=-1, *, kind=None, order=None, stable=True, descending=False)

Return the indices that sort the array.

Refer to jax.numpy.argsort for the full documentation.

Return type:

Array

astype(dtype, copy=False, device=None)

Copy the array and cast to a specified dtype.

This is implemented via jax.lax.convert_element_type, which may have slightly different behavior than numpy.ndarray.astype in some cases. In particular, the details of float-to-int and int-to-float casts are implementation dependent.

Return type:

Array

block_until_ready

(self) -> object

static blockarray(iterable)

Construct a BlockArray from a list or tuple of existing array-like objects.

byteswap()

Swap the bytes of the array elements.

This switches between a little-endian and big-endian data representation.

Return type:

Array

Returns:

An array with the same dtype as self, with underlying bytes of each entry reversed.

Examples

>>> import jax.numpy as jnp
>>> x = jnp.arange(5, dtype='int32')
>>> x
Array([0, 1, 2, 3, 4], dtype=int32)
>>> x.byteswap()
Array([       0, 16777216, 33554432, 50331648, 67108864], dtype=int32)

When the resulting bytes are viewed as a big-endian dtype (possible in NumPy, but not in JAX) they represent the original values:

>>> import numpy as np
>>> np.array(x.byteswap()).view('>i4')  # view as big-endian
array([0, 1, 2, 3, 4], dtype='>i4')

Calling byteswap twice will return the original array:

>>> x.byteswap().byteswap()
Array([0, 1, 2, 3, 4], dtype=int32)

choose(choices, out=None, mode='raise')

Construct an array choosing from elements of multiple arrays.

Refer to jax.numpy.choose for the full documentation.

Return type:

Array

clip(min=None, max=None)

Return an array whose values are limited to a specified range.

Refer to jax.numpy.clip for full documentation.

Return type:

Array

clone

(self) -> Array

compress(condition, axis=None, *, out=None, size=None, fill_value=0)

Return selected slices of this array along given axis.

Refer to jax.numpy.compress for full documentation.

Return type:

Array

conj()

Return the complex conjugate of the array.

Refer to jax.numpy.conj for the full documentation.

Return type:

Array

conjugate()

Return the complex conjugate of the array.

Refer to jax.numpy.conjugate for the full documentation.

Return type:

Array

copy()

Return a copy of the array.

Refer to jax.numpy.copy for the full documentation.

Return type:

Array

cumprod(axis=None, dtype=None, out=None)

Return the cumulative product of the array.

Refer to jax.numpy.cumprod for the full documentation.

Return type:

Array

cumsum(axis=None, dtype=None, out=None)

Return the cumulative sum of the array.

Refer to jax.numpy.cumsum for the full documentation.

Return type:

Array

delete

(self) -> None

diagonal(offset=0, axis1=0, axis2=1)

Return the specified diagonal from the array.

Refer to jax.numpy.diagonal for the full documentation.

Return type:

Array

dot(b, *, precision=None, preferred_element_type=None)

Compute the dot product of two arrays.

Refer to jax.numpy.dot for the full documentation.

Return type:

Array

property dtype

Return the dtype of the blocks, which must currently be homogeneous.

This allows snp.zeros(x.shape, x.dtype) to work without a mechanism to handle lists of dtypes.
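The motivation for a single dtype can be sketched in plain NumPy: a shape given as a tuple of per-block shapes, plus one dtype, is enough to allocate a matching zero-filled block structure. The helper `zeros_blocks` is hypothetical and only illustrates the idea; it is not how snp.zeros is implemented.

```python
import numpy as np

blocks = (np.ones((2, 3), dtype=np.float32), np.ones(3, dtype=np.float32))

shape = tuple(b.shape for b in blocks)  # ((2, 3), (3,))
dtypes = {b.dtype for b in blocks}
assert len(dtypes) == 1, "blocks must share a single dtype"
(dtype,) = dtypes


def zeros_blocks(shape, dtype):
    """Hypothetical helper: one zero-filled array per block shape."""
    return tuple(np.zeros(s, dtype=dtype) for s in shape)


z = zeros_blocks(shape, dtype)
```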

property flat

Use flatten instead.

Type:

Not implemented

flatten(order='C', *, out_sharding=None)

Flatten array into a 1-dimensional shape.

Refer to jax.numpy.ravel for the full documentation.

Return type:

Array

property global_shards

Return a list of all Shards of the Array across all devices.

The result includes shards that are not addressable by the current process. If a Shard is not addressable, then its data will be None.

property imag

Return the imaginary part of the array.

is_deleted

(self) -> bool

property is_fully_addressable

Is this Array fully addressable?

A jax.Array is fully addressable if the current process can address all of the devices named in the Sharding. is_fully_addressable is equivalent to “is_local” in multi-process JAX.

Note that being fully replicated is not the same as being fully addressable: a fully replicated jax.Array can span multiple hosts, in which case it is not fully addressable.

is_ready

(self) -> bool

item(*args)

Copy an element of an array to a standard Python scalar and return it.

Return type:

bool | int | float | complex

property itemsize

Length of one array element in bytes.

property mT

Compute the (batched) matrix transpose.

Refer to jax.numpy.matrix_transpose for details.

max(axis=None, out=None, keepdims=False, initial=None, where=None)

Return the maximum of array elements along a given axis.

Refer to jax.numpy.max for the full documentation.

Return type:

Array

mean(axis=None, dtype=None, out=None, keepdims=False, *, where=None)

Return the mean of array elements along a given axis.

Refer to jax.numpy.mean for the full documentation.

Return type:

Array

min(axis=None, out=None, keepdims=False, initial=None, where=None)

Return the minimum of array elements along a given axis.

Refer to jax.numpy.min for the full documentation.

Return type:

Array

property nbytes

Total bytes consumed by the elements of the array.

nonzero(*, fill_value=None, size=None)

Return indices of nonzero elements of an array.

Refer to jax.numpy.nonzero for the full documentation.

Return type:

tuple[Array, ...]

on_device_size_in_bytes

(self) -> int

platform

(self) -> str

prod(axis=None, dtype=None, out=None, keepdims=False, initial=None, where=None, promote_integers=True)

Return product of the array elements over a given axis.

Refer to jax.numpy.prod for the full documentation.

Return type:

Array

ptp(axis=None, out=None, keepdims=False)

Return the peak-to-peak range along a given axis.

Refer to jax.numpy.ptp for the full documentation.

Return type:

Array

ravel(order='C', *, out_sharding=None)

Flatten array into a 1-dimensional shape.

Refer to jax.numpy.ravel for the full documentation.

Return type:

Array

property real

Return the real part of the array.

repeat(repeats, axis=None, *, total_repeat_length=None, out_sharding=None)

Construct an array from repeated elements.

Refer to jax.numpy.repeat for the full documentation.

Return type:

Array

reshape(*args, order='C', out_sharding=None)

Returns an array containing the same data with a new shape.

Refer to jax.numpy.reshape for full documentation.

Return type:

Array

round(decimals=0, out=None)

Round array elements to a given decimal.

Refer to jax.numpy.round for full documentation.

Return type:

Array

searchsorted(v, side='left', sorter=None, *, method='scan')

Perform a binary search within a sorted array.

Refer to jax.numpy.searchsorted for full documentation.

Return type:

Array

sort(axis=-1, *, kind=None, order=None, stable=True, descending=False)

Return a sorted copy of an array.

Refer to jax.numpy.sort for full documentation.

Return type:

Array

squeeze(axis=None)

Remove one or more length-1 axes from array.

Refer to jax.numpy.squeeze for full documentation.

Return type:

Array

stack(axis=0)

Collapse a BlockArray to jax.Array.

Collapse a BlockArray to a jax.Array by stacking the blocks along the axis given by axis.

Parameters:

axis – Index of new axis on which blocks are to be stacked.

Returns:

A jax.Array obtained by stacking.

Raises:

ValueError – When called on a BlockArray that is not stackable.

std(axis=None, dtype=None, out=None, ddof=0, keepdims=False, *, where=None, correction=None)

Compute the standard deviation along a given axis.

Refer to jax.numpy.std for full documentation.

Return type:

Array

sum(axis=None, dtype=None, out=None, keepdims=False, initial=None, where=None, promote_integers=True)

Sum of the elements of the array over a given axis.

Refer to jax.numpy.sum for full documentation.

Return type:

Array

swapaxes(axis1, axis2)

Swap two axes of an array.

Refer to jax.numpy.swapaxes for full documentation.

Return type:

Array

take(indices, axis=None, out=None, mode=None, unique_indices=False, indices_are_sorted=False, fill_value=None)

Take elements from an array.

Refer to jax.numpy.take for full documentation.

Return type:

Array

to_device(device, *, stream=None)

Return a copy of the array on the specified device.

Parameters:
  • device (Device | Sharding) – Device or Sharding to which the created array will be committed.

  • stream (int | Any | None) – not implemented; passing a non-None value raises an error.

Returns:

copy of array placed on the specified device or devices.

trace(offset=0, axis1=0, axis2=1, dtype=None, out=None)

Return the sum along the diagonal.

Refer to jax.numpy.trace for full documentation.

Return type:

Array

transpose(*args)

Returns a copy of the array with axes transposed.

Refer to jax.numpy.transpose for full documentation.

Return type:

Array

unsafe_buffer_pointer

(self) -> int

var(axis=None, dtype=None, out=None, ddof=0, keepdims=False, *, where=None, correction=None)

Compute the variance along a given axis.

Refer to jax.numpy.var for full documentation.

Return type:

Array

view(dtype=None, type=None)

Return a bitwise copy of the array, viewed as a new dtype.

This is a fuller-featured wrapper around jax.lax.bitcast_convert_type.

If the source and target dtype have the same bitwidth, the result has the same shape as the input array. If the bitwidth of the target dtype is different from the source, the size of the last axis of the result is adjusted accordingly.

>>> jnp.zeros([1,2,3], dtype=jnp.int16).view(jnp.int8).shape
(1, 2, 6)
>>> jnp.zeros([1,2,4], dtype=jnp.int8).view(jnp.int16).shape
(1, 2, 2)

Conversions involving booleans are not well-defined in all situations. With regard to the shape of the result as explained above, booleans are treated as having a bitwidth of 8. However, when converting to a boolean array, the input should only contain 0 or 1 bytes; otherwise, results may be unpredictable or may change depending on how the result is used.

This conversion is guaranteed and safe:

>>> jnp.array([1, 0, 1], dtype=jnp.int8).view(jnp.bool_)
Array([ True, False,  True], dtype=bool)

However, there are no guarantees about the results of any expression involving a view such as this: jnp.array([1, 2, 3], dtype=jnp.int8).view(jnp.bool_). In particular, the results may change between JAX releases and depending on the platform. To safely convert such an array to a boolean array, compare it with 0:

>>> jnp.array([1, 2, 0], dtype=jnp.int8) != 0
Array([ True,  True, False], dtype=bool)

Parameters:
  • dtype (Union[str, type[Any], dtype, SupportsDType, None]) – An optional output dtype. If not specified, the output dtype is the same as the input dtype.

  • type (None) – Not implemented; accepted for NumPy compatibility.

Return type:

Array

Returns:

The array, viewed as the new dtype. Unlike NumPy, the array may or may not be a copy of the input array.

scico.numpy.blockarray(iterable)

Construct a BlockArray from a list or tuple of existing array-like objects.

scico.numpy.ravel(ba)

Completely flatten a BlockArray into a single Array.

When called on an Array, flattens the array.

Parameters:

ba (Union[Array, BlockArray]) – The BlockArray (or Array) to flatten.

Return type:

Array

Returns:

ba flattened into a single Array.

scico.numpy.abs(x, /)

Alias of jax.numpy.absolute.

Return type:

Array

scico.numpy.absolute(x, /)

Calculate the absolute value element-wise.

JAX implementation of numpy.absolute.

This is the same function as jax.numpy.abs.

Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – Input array

Return type:

Array

Returns:

An array-like object containing the absolute value of each element in x, with the same shape as x. For complex valued input, \(a + ib\), the absolute value is \(\sqrt{a^2+b^2}\).

Examples

>>> x1 = jnp.array([5, -2, 0, 12])
>>> jnp.absolute(x1)
Array([ 5,  2,  0, 12], dtype=int32)
>>> x2 = jnp.array([[ 8, -3, 1],[ 0, 9, -6]])
>>> jnp.absolute(x2)
Array([[8, 3, 1],
       [0, 9, 6]], dtype=int32)
>>> x3 = jnp.array([8 + 15j, 3 - 4j, -5 + 0j])
>>> jnp.absolute(x3)
Array([17.,  5.,  5.], dtype=float32)

scico.numpy.add(*args: ArrayLike, out: None = None, where: None = None) → Any

Add two arrays element-wise.

JAX implementation of numpy.add. This is a universal function, and supports the additional APIs described at jax.numpy.ufunc. This function provides the implementation of the + operator for JAX arrays.

Parameters:
  • x – arrays to add. Must be broadcastable to a common shape.

  • y – arrays to add. Must be broadcastable to a common shape.

Returns:

Array containing the result of the element-wise addition.

Examples

Calling add explicitly:

>>> x = jnp.arange(4)
>>> jnp.add(x, 10)
Array([10, 11, 12, 13], dtype=int32)

Calling add via the + operator:

>>> x + 10
Array([10, 11, 12, 13], dtype=int32)

scico.numpy.all(a, axis=None, out=None, keepdims=False, *, where=None)

Test whether all array elements along a given axis evaluate to True.

JAX implementation of numpy.all.

Parameters:
  • a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – Input array.

  • axis (Union[int, Sequence[int], None]) – int or sequence of ints, default=None. Axis or axes along which to test. If None, tests along all the axes.

  • keepdims (bool) – bool, default=False. If true, reduced axes are left in the result with size 1.

  • where (Union[Array, ndarray, bool, number, bool, int, float, complex, None]) – int or array of boolean dtype, default=None. The elements to be used in the test. Array should be broadcast compatible to the input.

  • out (None) – Unused by JAX.

Return type:

Array

Returns:

An array of boolean values.

Examples

By default, jnp.all tests for True values along all the axes.

>>> x = jnp.array([[True, True, True, False],
...                [True, False, True, False],
...                [True, True, False, False]])
>>> jnp.all(x)
Array(False, dtype=bool)

If axis=0, tests for True values along axis 0.

>>> jnp.all(x, axis=0)
Array([ True, False, False, False], dtype=bool)

If keepdims=True, the output has the same number of dimensions (ndim) as the input.

>>> jnp.all(x, axis=0, keepdims=True)
Array([[ True, False, False, False]], dtype=bool)

To include only specific elements in testing for True values, you can use where.

>>> where=jnp.array([[1, 0, 1, 0],
...                  [0, 0, 1, 1],
...                  [1, 1, 1, 0]], dtype=bool)
>>> jnp.all(x, axis=0, keepdims=True, where=where)
Array([[ True,  True, False, False]], dtype=bool)

scico.numpy.allclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False)

Check if two arrays are element-wise approximately equal within a tolerance.

JAX implementation of numpy.allclose.

Essentially this function evaluates the following condition:

\[|a - b| \le \mathtt{atol} + \mathtt{rtol} * |b|\]

jnp.inf in a will be considered equal to jnp.inf in b.
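The condition above can be evaluated directly, which also makes it clear that the relative tolerance is scaled by |b| only, so swapping the arguments can change the result. A minimal sketch; allclose_by_formula is a hypothetical helper written for illustration, and it ignores the special handling of inf and nan that jnp.allclose performs.

```python
import jax.numpy as jnp

def allclose_by_formula(a, b, rtol=1e-05, atol=1e-08):
    # Hypothetical helper: |a - b| <= atol + rtol * |b|, required for every element.
    a, b = jnp.asarray(a), jnp.asarray(b)
    return jnp.all(jnp.abs(a - b) <= atol + rtol * jnp.abs(b))

a = jnp.array([1.0, 100.0, 1e6])
b = jnp.array([1.0, 100.0005, 1e6 + 5.0])
print(allclose_by_formula(a, b), jnp.allclose(a, b))  # both True

# Asymmetry: the tolerance term uses |b|, not |a|.
print(jnp.allclose(1.0, 2.0, rtol=0.5, atol=0.0))  # True:  1 <= 0.5 * 2
print(jnp.allclose(2.0, 1.0, rtol=0.5, atol=0.0))  # False: 1 >  0.5 * 1
```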

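Parameters:
  • a – first input array to compare.

  • b – second input array to compare.

  • rtol – relative tolerance. Default is 1e-05.

  • atol – absolute tolerance. Default is 1e-08.

  • equal_nan – if True, NaN values in a are considered equal to NaN values in b. Default is False.

Return type: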

Array

Returns:

Boolean scalar array indicating whether the input arrays are element-wise approximately equal within the specified tolerances.

Examples

>>> jnp.allclose(jnp.array([1e6, 2e6, 3e6]), jnp.array([1e6, 2e6, 3e7]))
Array(False, dtype=bool)
>>> jnp.allclose(jnp.array([1e6, 2e6, 3e6]),
...              jnp.array([1.00008e6, 2.00008e7, 3.00008e8]), rtol=1e3)
Array(True, dtype=bool)
>>> jnp.allclose(jnp.array([1e6, 2e6, 3e6]),
...              jnp.array([1.00001e6, 2.00002e6, 3.00009e6]), atol=1e3)
Array(True, dtype=bool)
>>> jnp.allclose(jnp.array([jnp.nan, 1, 2]),
...              jnp.array([jnp.nan, 1, 2]), equal_nan=True)
Array(True, dtype=bool)

scico.numpy.amax(a, axis=None, out=None, keepdims=False, initial=None, where=None)

Alias of jax.numpy.max.

Return type:

Array

scico.numpy.amin(a, axis=None, out=None, keepdims=False, initial=None, where=None)

Alias of jax.numpy.min.

Return type:

Array
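Since amax and amin are plain aliases, they accept the same arguments and return the same results as max and min. A short check, written against jax.numpy, whose functions scico.numpy includes:

```python
import jax.numpy as jnp

x = jnp.array([[3, 7, 1],
               [4, 2, 9]])

# Column-wise maxima and row-wise minima via the aliases and the original names.
print(jnp.amax(x, axis=0), jnp.max(x, axis=0))  # [4 7 9] [4 7 9]
print(jnp.amin(x, axis=1), jnp.min(x, axis=1))  # [1 2] [1 2]
```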

scico.numpy.angle(z, deg=False)

Return the angle of a complex valued number or array.

JAX implementation of numpy.angle.

Parameters:
  • z (Union[Array, ndarray, bool, number, bool, int, float, complex]) – A complex number or an array of complex numbers.

  • deg (bool) – Boolean. If True, returns the result in degrees else returns in radians. Default is False.

Return type:

Array

Returns:

An array of the counterclockwise angles of the elements of z, with the same shape as z and a floating-point dtype.

Examples

If z is a number

>>> z1 = 2+3j
>>> jnp.angle(z1)
Array(0.98279375, dtype=float32, weak_type=True)

If z is an array

>>> z2 = jnp.array([[1+3j, 2-5j],
...                 [4-3j, 3+2j]])
>>> with jnp.printoptions(precision=2, suppress=True):
...     print(jnp.angle(z2))
[[ 1.25 -1.19]
 [-0.64  0.59]]

If deg=True.

>>> with jnp.printoptions(precision=2, suppress=True):
...     print(jnp.angle(z2, deg=True))
[[ 71.57 -68.2 ]
 [-36.87  33.69]]

scico.numpy.any(a, axis=None, out=None, keepdims=False, *, where=None)

Test whether any of the array elements along a given axis evaluate to True.

JAX implementation of numpy.any.

Parameters:
  • a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – Input array.

  • axis (Union[int, Sequence[int], None]) – int or sequence of ints, default=None. Axis or axes along which to test. If None, tests along all the axes.

  • keepdims (bool) – bool, default=False. If true, reduced axes are left in the result with size 1.

  • where (Union[Array, ndarray, bool, number, bool, int, float, complex, None]) – int or array of boolean dtype, default=None. The elements to be used in the test. Array should be broadcast compatible to the input.

  • out (None) – Unused by JAX.

Return type:

Array

Returns:

An array of boolean values.

Examples

By default, jnp.any tests along all the axes.

>>> x = jnp.array([[True, True, True, False],
...                [True, False, True, False],
...                [True, True, False, False]])
>>> jnp.any(x)
Array(True, dtype=bool)

If axis=0, tests along axis 0.

>>> jnp.any(x, axis=0)
Array([ True,  True,  True, False], dtype=bool)

If keepdims=True, the output has the same number of dimensions (ndim) as the input.

>>> jnp.any(x, axis=0, keepdims=True)
Array([[ True,  True,  True, False]], dtype=bool)

To include only specific elements in testing for True values, you can use where.

>>> where=jnp.array([[1, 0, 1, 0],
...                  [0, 1, 0, 1],
...                  [1, 0, 1, 0]], dtype=bool)
>>> jnp.any(x, axis=0, keepdims=True, where=where)
Array([[ True, False,  True, False]], dtype=bool)

scico.numpy.append(arr, values, axis=None)

Return a new array with values appended to the end of the original array.

JAX implementation of numpy.append.

Parameters:
  • arr (Union[Array, ndarray, bool, number, bool, int, float, complex]) – original array.

  • values (Union[Array, ndarray, bool, number, bool, int, float, complex]) – values to be appended to the array. The values must have the same number of dimensions as arr, and all dimensions must match except in the specified axis.

  • axis (int | None) – axis along which to append values. If None (default), both arr and values will be flattened before appending.

Return type:

Array

Returns:

A new array with values appended to arr.

Examples

>>> a = jnp.array([1, 2, 3])
>>> b = jnp.array([4, 5, 6])
>>> jnp.append(a, b)
Array([1, 2, 3, 4, 5, 6], dtype=int32)

Appending along a specific axis:

>>> a = jnp.array([[1, 2],
...                [3, 4]])
>>> b = jnp.array([[5, 6]])
>>> jnp.append(a, b, axis=0)
Array([[1, 2],
       [3, 4],
       [5, 6]], dtype=int32)

Appending along a trailing axis:

>>> a = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> b = jnp.array([[7], [8]])
>>> jnp.append(a, b, axis=1)
Array([[1, 2, 3, 7],
       [4, 5, 6, 8]], dtype=int32)

scico.numpy.apply_along_axis(func1d, axis, arr, *args, **kwargs)

Apply a function to 1D array slices along an axis.

JAX implementation of numpy.apply_along_axis. While NumPy implements this iteratively, JAX implements this via jax.vmap, and so func1d must be compatible with vmap.

Parameters:
  • func1d (Callable) – a callable function with signature func1d(arr, /, *args, **kwargs) where *args and **kwargs are the additional positional and keyword arguments passed to apply_along_axis.

  • axis (int) – integer axis along which to apply the function.

  • arr (Union[Array, ndarray, bool, number, bool, int, float, complex]) – the array over which to apply the function.

  • args – additional positional arguments passed through to func1d.

  • kwargs – additional keyword arguments passed through to func1d.

Return type:

Array

Returns:

The result of func1d applied along the specified axis.

Examples

A simple example in two dimensions, where the function is applied either row-wise or column-wise:

>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> def func1d(x):
...   return jnp.sum(x ** 2)
>>> jnp.apply_along_axis(func1d, 0, x)
Array([17, 29, 45], dtype=int32)
>>> jnp.apply_along_axis(func1d, 1, x)
Array([14, 77], dtype=int32)

For 2D inputs, this can be equivalently expressed using jax.vmap, though note that vmap specifies the mapped axis rather than the applied axis:

>>> jax.vmap(func1d, in_axes=1)(x)  # same as applying along axis 0
Array([17, 29, 45], dtype=int32)
>>> jax.vmap(func1d, in_axes=0)(x)  # same as applying along axis 1
Array([14, 77], dtype=int32)

For 3D inputs, apply_along_axis is equivalent to mapping over two dimensions:

>>> x_3d = jnp.arange(24).reshape(2, 3, 4)
>>> jnp.apply_along_axis(func1d, 2, x_3d)
Array([[  14,  126,  366],
       [ 734, 1230, 1854]], dtype=int32)
>>> jax.vmap(jax.vmap(func1d))(x_3d)
Array([[  14,  126,  366],
       [ 734, 1230, 1854]], dtype=int32)

The applied function may also take arbitrary positional or keyword arguments, which should be passed directly as additional arguments to apply_along_axis:

>>> def func1d(x, exponent):
...   return jnp.sum(x ** exponent)
>>> jnp.apply_along_axis(func1d, 0, x, exponent=3)
Array([ 65, 133, 243], dtype=int32)

scico.numpy.apply_over_axes(func, a, axes)

Apply a function repeatedly over specified axes.

JAX implementation of numpy.apply_over_axes.

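Parameters:
  • func – function with signature func(a, axis) that applies a reduction along the given axis.

  • a – N-dimensional input array.

  • axes – sequence of axes over which func is applied.

Return type: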

Array

Returns:

An N-dimensional array containing the result of the repeated function application.

Examples

This function is designed to have similar semantics to typical associative jax.numpy reductions over one or more axes with keepdims=True. For example:

>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> jnp.apply_over_axes(jnp.sum, x, [0])
Array([[5, 7, 9]], dtype=int32)
>>> jnp.sum(x, [0], keepdims=True)
Array([[5, 7, 9]], dtype=int32)
>>> jnp.apply_over_axes(jnp.min, x, [1])
Array([[1],
       [4]], dtype=int32)
>>> jnp.min(x, [1], keepdims=True)
Array([[1],
       [4]], dtype=int32)
>>> jnp.apply_over_axes(jnp.prod, x, [0, 1])
Array([[720]], dtype=int32)
>>> jnp.prod(x, [0, 1], keepdims=True)
Array([[720]], dtype=int32)

scico.numpy.arange(start, stop=None, step=None, dtype=None, *, device=None, out_sharding=None)

Create an array of evenly-spaced values.

JAX implementation of numpy.arange, implemented in terms of jax.lax.iota.

Similar to Python’s range function, this can be called with a few different positional signatures:

  • jnp.arange(stop): generate values from 0 to stop, stepping by 1.

  • jnp.arange(start, stop): generate values from start to stop, stepping by 1.

  • jnp.arange(start, stop, step): generate values from start to stop, stepping by step.

As with Python’s range function, the start value is inclusive and the stop value is exclusive.

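Parameters:
  • start – start of the interval, inclusive. If stop is not given, start is instead interpreted as the stop value and the interval starts at 0.

  • stop – optional end of the interval, exclusive.

  • step – optional spacing between values. Default is 1.

  • dtype – optional dtype of the output array. If not specified, it is inferred from start, stop, and step.

  • device – optional Device or Sharding to which the created array will be committed.

  • out_sharding – optional PartitionSpec or NamedSharding representing the sharding of the created array.

Return type: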

Array

Returns:

Array of evenly-spaced values from start to stop, separated by step.

Note

Using arange with a floating-point step argument can lead to unexpected results due to accumulation of floating-point errors, especially with lower-precision data types like float8_* and bfloat16. To avoid precision errors, consider generating a range of integers, and scaling it to the desired range. For example, instead of this:

jnp.arange(-1, 1, 0.01, dtype='bfloat16')

it can be more accurate to generate a sequence of integers, and scale them:

(jnp.arange(-100, 100) * 0.01).astype('bfloat16')

Examples

Single-argument version specifies only the stop value:

>>> jnp.arange(4)
Array([0, 1, 2, 3], dtype=int32)

Passing a floating-point stop value leads to a floating-point result:

>>> jnp.arange(4.0)
Array([0., 1., 2., 3.], dtype=float32)

Two-argument version specifies start and stop, with step=1:

>>> jnp.arange(1, 6)
Array([1, 2, 3, 4, 5], dtype=int32)

Three-argument version specifies start, stop, and step:

>>> jnp.arange(0, 2, 0.5)
Array([0. , 0.5, 1. , 1.5], dtype=float32)

scico.numpy.arccos(x, /)

Compute element-wise inverse of trigonometric cosine of input.

JAX implementation of numpy.arccos.

Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array or scalar.

Return type:

Array

Returns:

An array containing the inverse trigonometric cosine of each element of x in radians in the range [0, pi], promoting to inexact dtype.

Note

  • jnp.arccos returns nan when x is real-valued and not in the closed interval [-1, 1].

  • jnp.arccos follows the branch cut convention of numpy.arccos for complex inputs.

Examples

>>> x = jnp.array([-2, -1, -0.5, 0, 0.5, 1, 2])
>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.arccos(x)
Array([  nan, 3.142, 2.094, 1.571, 1.047, 0.   ,   nan], dtype=float32)

For complex inputs:

>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.arccos(4-1j)
Array(0.252+2.097j, dtype=complex64, weak_type=True)

scico.numpy.arccosh(x, /)

Calculate element-wise inverse of hyperbolic cosine of input.

JAX implementation of numpy.arccosh.

The inverse of hyperbolic cosine is defined by:

\[\operatorname{arccosh}(x) = \ln\left(x + \sqrt{x^2 - 1}\right)\]

Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array or scalar.

Return type:

Array

Returns:

An array of same shape as x containing the inverse of hyperbolic cosine of each element of x, promoting to inexact dtype.

Note

  • jnp.arccosh returns nan for real values in the range [-inf, 1).

  • jnp.arccosh follows the branch cut convention of numpy.arccosh for complex inputs.
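The defining formula \(\operatorname{arccosh}(x) = \ln(x + \sqrt{x^2 - 1})\) can be checked numerically against jnp.arccosh. A minimal sketch that stays inside the real domain x >= 1; the tolerance atol=1e-6 is an arbitrary choice suited to float32.

```python
import jax.numpy as jnp

# arccosh is real-valued only for x >= 1.
x = jnp.array([1.0, 1.5, 3.0, 10.0])

# Literal transcription of ln(x + sqrt(x^2 - 1)).
by_formula = jnp.log(x + jnp.sqrt(x**2 - 1))
print(jnp.allclose(jnp.arccosh(x), by_formula, atol=1e-6))  # True
```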

See also

  • jax.numpy.cosh: Computes the element-wise hyperbolic cosine of the input.

  • jax.numpy.arcsinh: Computes the element-wise inverse of hyperbolic sine of the input.

  • jax.numpy.arctanh: Computes the element-wise inverse of hyperbolic tangent of the input.

Examples

>>> x = jnp.array([[1, 3, -4],
...                [-5, 2, 7]])
>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.arccosh(x)
Array([[0.   , 1.763,   nan],
       [  nan, 1.317, 2.634]], dtype=float32)

For complex-valued input:

>>> x1 = jnp.array([-jnp.inf+0j, 1+2j, -5+0j])
>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.arccosh(x1)
Array([  inf+3.142j, 1.529+1.144j, 2.292+3.142j], dtype=complex64)

scico.numpy.arcsin(x, /)

Compute element-wise inverse of trigonometric sine of input.

JAX implementation of numpy.arcsin.

Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array or scalar.

Return type:

Array

Returns:

An array containing the inverse trigonometric sine of each element of x in radians in the range [-pi/2, pi/2], promoting to inexact dtype.

Note

  • jnp.arcsin returns nan when x is real-valued and not in the closed interval [-1, 1].

  • jnp.arcsin follows the branch cut convention of numpy.arcsin for complex inputs.

Examples

>>> x = jnp.array([-2, -1, -0.5, 0, 0.5, 1, 2])
>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.arcsin(x)
Array([   nan, -1.571, -0.524,  0.   ,  0.524,  1.571,    nan], dtype=float32)

For complex-valued inputs:

>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.arcsin(3+4j)
Array(0.634+2.306j, dtype=complex64, weak_type=True)

scico.numpy.arcsinh(x, /)

Calculate element-wise inverse of hyperbolic sine of input.

JAX implementation of numpy.arcsinh.

The inverse of hyperbolic sine is defined by:

\[\operatorname{arcsinh}(x) = \ln\left(x + \sqrt{1 + x^2}\right)\]

Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array or scalar.

Return type:

Array

Returns:

An array of same shape as x containing the inverse of hyperbolic sine of each element of x, promoting to inexact dtype.

Note

  • jnp.arcsinh is defined for all real inputs, including ±inf; it returns nan only for nan input.

  • jnp.arcsinh follows the branch cut convention of numpy.arcsinh for complex inputs.
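The definition \(\operatorname{arcsinh}(x) = \ln(x + \sqrt{1 + x^2})\) is exact mathematically, but evaluating it literally in float32 can fail for large negative x, where x and sqrt(1 + x^2) cancel. A minimal sketch comparing the literal formula with jnp.arcsinh; arcsinh_naive is a hypothetical helper, and the inputs are pinned to float32.

```python
import jax.numpy as jnp

def arcsinh_naive(x):
    # Hypothetical helper: literal transcription of ln(x + sqrt(1 + x^2)).
    return jnp.log(x + jnp.sqrt(1 + x**2))

x = jnp.array([-2.0, 0.5, 3.0, -1e4], dtype=jnp.float32)
naive = arcsinh_naive(x)
library = jnp.arcsinh(x)

# For x = -1e4, 1 + x**2 rounds to 1e8 in float32, so x + sqrt(...) is exactly 0
# and the naive result is -inf; jnp.arcsinh returns about -9.9035.
print(naive)
print(library)
```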

See also

  • jax.numpy.sinh: Computes the element-wise hyperbolic sine of the input.

  • jax.numpy.arccosh: Computes the element-wise inverse of hyperbolic cosine of the input.

  • jax.numpy.arctanh: Computes the element-wise inverse of hyperbolic tangent of the input.

Examples

>>> x = jnp.array([[-2, 3, 1],
...                [4, 9, -5]])
>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.arcsinh(x)
Array([[-1.444,  1.818,  0.881],
       [ 2.095,  2.893, -2.312]], dtype=float32)

For complex-valued inputs:

>>> x1 = jnp.array([4-3j, 2j])
>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.arcsinh(x1)
Array([2.306-0.634j, 1.317+1.571j], dtype=complex64)

scico.numpy.arctan(x, /)

Compute element-wise inverse of trigonometric tangent of input.

JAX implementation of numpy.arctan.

Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array or scalar.

Return type:

Array

Returns:

An array containing the inverse trigonometric tangent of each element of x in radians in the range [-pi/2, pi/2], promoting to inexact dtype.

Note

jnp.arctan follows the branch cut convention of numpy.arctan for complex inputs.

Examples

>>> x = jnp.array([-jnp.inf, -20, -1, 0, 1, 20, jnp.inf])
>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.arctan(x)
Array([-1.571, -1.521, -0.785,  0.   ,  0.785,  1.521,  1.571], dtype=float32)

For complex-valued inputs:

>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.arctan(2+7j)
Array(1.532+0.133j, dtype=complex64, weak_type=True)

scico.numpy.arctan2(x1, x2, /)

Compute the arctangent of x1/x2, choosing the correct quadrant.

JAX implementation of numpy.arctan2.

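Parameters:
  • x1 – numerator array (the y-coordinates).

  • x2 – denominator array (the x-coordinates). Must be broadcast-compatible with x1.

Return type: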

Array

Returns:

The elementwise arctangent of x1 / x2, tracking the correct quadrant.

Examples

Consider a sequence of angles in radians between \(-\pi\) and \(\pi\):

>>> theta = jnp.linspace(-jnp.pi, jnp.pi, 9)
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(theta)
[-3.14 -2.36 -1.57 -0.79  0.    0.79  1.57  2.36  3.14]

These angles can equivalently be represented by (x, y) coordinates on a unit circle:

>>> x, y = jnp.cos(theta), jnp.sin(theta)

To reconstruct the input angle, we might be tempted to use the identity \(\tan(\theta) = y / x\), and compute \(\theta = \tan^{-1}(y/x)\). Unfortunately, this does not recover the input angle:

>>> with jnp.printoptions(precision=2, suppress=True):  
...   print(jnp.arctan(y / x))
[-0.    0.79  1.57 -0.79  0.    0.79  1.57 -0.79  0.  ]

The problem is that \(y/x\) contains some ambiguity: although \((y, x) = (-1, -1)\) and \((y, x) = (1, 1)\) represent different points in Cartesian space, in both cases \(y / x = 1\), and so the simple arctan approach loses information about which quadrant the angle lies in. arctan2 is built to address this:

>>> with jnp.printoptions(precision=2, suppress=True):
...  print(jnp.arctan2(y, x))
[ 3.14 -2.36 -1.57 -0.79  0.    0.79  1.57  2.36 -3.14]

The results match the input theta, except at the endpoints where \(+\pi\) and \(-\pi\) represent indistinguishable points on the unit circle. By convention, arctan2 always returns values between \(-\pi\) and \(+\pi\) inclusive.

scico.numpy.arctanh(x, /)

Calculate element-wise inverse of hyperbolic tangent of input.

JAX implementation of numpy.arctanh.

The inverse of hyperbolic tangent is defined by:

\[\operatorname{arctanh}(x) = \frac{1}{2} \left[\ln(1 + x) - \ln(1 - x)\right]\]

Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array or scalar.

Return type:

Array

Returns:

An array of same shape as x containing the inverse of hyperbolic tangent of each element of x, promoting to inexact dtype.

Note

  • jnp.arctanh returns nan for real values outside the range [-1, 1].

  • jnp.arctanh follows the branch cut convention of numpy.arctanh for complex inputs.
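The defining formula \(\operatorname{arctanh}(x) = \frac{1}{2}[\ln(1 + x) - \ln(1 - x)]\) can be checked numerically inside (-1, 1), and it also explains the infinite value at x = 1. A minimal sketch; the tolerance atol=1e-6 is an arbitrary choice suited to float32.

```python
import jax.numpy as jnp

x = jnp.array([-0.9, -0.5, 0.0, 0.5, 0.9])

# Literal transcription of (1/2) [ln(1 + x) - ln(1 - x)].
by_formula = 0.5 * (jnp.log(1 + x) - jnp.log(1 - x))
print(jnp.allclose(jnp.arctanh(x), by_formula, atol=1e-6))  # True

# At x = 1, ln(1 - x) = ln(0) = -inf, so both sides are +inf.
one = jnp.array(1.0)
print(jnp.arctanh(one), 0.5 * (jnp.log(1 + one) - jnp.log(1 - one)))  # inf inf
```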

See also

  • jax.numpy.tanh: Computes the element-wise hyperbolic tangent of the input.

  • jax.numpy.arcsinh: Computes the element-wise inverse of hyperbolic sine of the input.

  • jax.numpy.arccosh: Computes the element-wise inverse of hyperbolic cosine of the input.

Examples

>>> x = jnp.array([-2, -1, -0.5, 0, 0.5, 1, 2])
>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.arctanh(x)
Array([   nan,   -inf, -0.549,  0.   ,  0.549,    inf,    nan], dtype=float32)

For complex-valued input:

>>> x1 = jnp.array([-2+0j, 3+0j, 4-1j])
>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.arctanh(x1)
Array([-0.549+1.571j,  0.347+1.571j,  0.239-1.509j], dtype=complex64)

scico.numpy.argmax(a, axis=None, out=None, keepdims=None)

Return the index of the maximum value of an array.

JAX implementation of numpy.argmax.

Parameters:
  • a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array

  • axis (int | None) – optional integer specifying the axis along which to find the maximum value. If axis is not specified, a will be flattened.

  • out (None) – unused by JAX

  • keepdims (bool | None) – if True, then return an array with the same number of dimensions as a.

Return type:

Array

Returns:

an array containing the index of the maximum value along the specified axis.

Note

When the maximum value occurs more than once along a particular axis, the smallest index is returned.

Examples

>>> x = jnp.array([1, 3, 5, 4, 2])
>>> jnp.argmax(x)
Array(2, dtype=int32)
>>> x = jnp.array([[1, 3, 2],
...                [5, 4, 1]])
>>> jnp.argmax(x, axis=1)
Array([1, 0], dtype=int32)
>>> jnp.argmax(x, axis=1, keepdims=True)
Array([[1],
       [0]], dtype=int32)

scico.numpy.argmin(a, axis=None, out=None, keepdims=None)

Return the index of the minimum value of an array.

JAX implementation of numpy.argmin.

Parameters:
  • a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array

  • axis (int | None) – optional integer specifying the axis along which to find the minimum value. If axis is not specified, a will be flattened.

  • out (None) – unused by JAX

  • keepdims (bool | None) – if True, then return an array with the same number of dimensions as a.

Return type:

Array

Returns:

an array containing the index of the minimum value along the specified axis.

Note

When the minimum value occurs more than once along a particular axis, the smallest index is returned.

Examples

>>> x = jnp.array([1, 3, 5, 4, 2])
>>> jnp.argmin(x)
Array(0, dtype=int32)
>>> x = jnp.array([[1, 3, 2],
...                [5, 4, 1]])
>>> jnp.argmin(x, axis=1)
Array([0, 2], dtype=int32)
>>> jnp.argmin(x, axis=1, keepdims=True)
Array([[0],
       [2]], dtype=int32)

scico.numpy.argsort(a, axis=-1, *, kind=None, order=None, stable=True, descending=False, dtype=None)

Return indices that sort an array.

JAX implementation of numpy.argsort.

Parameters:
  • a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – array to sort

  • axis (int | None) – integer axis along which to sort. Defaults to -1, i.e. the last axis. If None, then a is flattened before being sorted.

  • stable (bool) – boolean specifying whether a stable sort should be used. Default=True.

  • descending (bool) – boolean specifying whether to sort in descending order. Default=False.

  • kind (None) – deprecated; instead specify sort algorithm using stable=True or stable=False.

  • order (None) – not supported by JAX

  • dtype (Union[str, type[Any], dtype, SupportsDType, None]) – optionally specify the dtype of the resulting indices. If not specified, the default integer dtype will be used.

Return type:

Array

Returns:

Array of indices that sort an array. Returned array will be of shape a.shape (if axis is an integer) or of shape (a.size,) (if axis is None).

Examples

Simple one-dimensional sort:

>>> x = jnp.array([1, 3, 5, 4, 2, 1])
>>> indices = jnp.argsort(x)
>>> indices
Array([0, 5, 4, 1, 3, 2], dtype=int32)
>>> x[indices]
Array([1, 1, 2, 3, 4, 5], dtype=int32)

Sort along the last axis of an array:

>>> x = jnp.array([[2, 1, 3],
...                [6, 4, 3]])
>>> indices = jnp.argsort(x, axis=1)
>>> indices
Array([[1, 0, 2],
       [2, 1, 0]], dtype=int32)
>>> jnp.take_along_axis(x, indices, axis=1)
Array([[1, 2, 3],
       [3, 4, 6]], dtype=int32)
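The descending and axis=None options described in the parameter list can be sketched as follows. With axis=None the input is flattened first, so the returned indices refer to the flattened array.

```python
import jax.numpy as jnp

x = jnp.array([1, 3, 5, 4, 2, 1])
desc = jnp.argsort(x, descending=True)
print(x[desc])  # [5 4 3 2 1 1]

x2 = jnp.array([[2, 1, 3],
                [6, 4, 3]])
flat_idx = jnp.argsort(x2, axis=None)  # indices into x2.ravel()
print(flat_idx.shape)        # (6,)
print(x2.ravel()[flat_idx])  # [1 2 3 3 4 6]
```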

scico.numpy.argwhere(a, *, size=None, fill_value=None)

Find the indices of nonzero array elements.

JAX implementation of numpy.argwhere.

jnp.argwhere(x) is essentially equivalent to jnp.column_stack(jnp.nonzero(x)) with special handling for zero-dimensional (i.e. scalar) inputs.

Because the size of the output of argwhere is data-dependent, the function is not typically compatible with JIT. The JAX version adds the optional size argument, which specifies the size of the leading dimension of the output; it must be specified statically for jnp.argwhere to be compiled with non-static operands. See jax.numpy.nonzero for a full discussion of size and its semantics.

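Parameters:
  • a – array in which to find the indices of nonzero elements.

  • size – optional integer specifying the number of nonzero entries to return. It must be specified statically for argwhere to be used under jax.jit. If there are more nonzero elements than size, the result is truncated; if there are fewer, it is padded with fill_value.

  • fill_value – optional padding value used when size is specified. Defaults to 0.

Return type: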

Array

Returns:

A two-dimensional array of shape (size, a.ndim). If size is not specified as an argument, it is equal to the number of nonzero elements in a.

Examples

Two-dimensional array:

>>> x = jnp.array([[1, 0, 2],
...                [0, 3, 0]])
>>> jnp.argwhere(x)
Array([[0, 0],
       [0, 2],
       [1, 1]], dtype=int32)

Equivalent computation using jax.numpy.column_stack and jax.numpy.nonzero:

>>> jnp.column_stack(jnp.nonzero(x))
Array([[0, 0],
       [0, 2],
       [1, 1]], dtype=int32)
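Under jax.jit, a static size makes the output shape known at trace time. A minimal sketch; nonzero_coords is a hypothetical helper, and fill_value=-1 is chosen only so the padding row is easy to spot.

```python
import jax
import jax.numpy as jnp

x = jnp.array([[1, 0, 2],
               [0, 3, 0]])

@jax.jit
def nonzero_coords(a):
    # Hypothetical helper: size is a static Python int, so the output shape is
    # fixed at (4, a.ndim); the unused fourth row is padded with fill_value.
    return jnp.argwhere(a, size=4, fill_value=-1)

coords = nonzero_coords(x)
print(coords)
# [[ 0  0]
#  [ 0  2]
#  [ 1  1]
#  [-1 -1]]
```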

Special case for zero-dimensional (i.e. scalar) inputs:

>>> jnp.argwhere(1)
Array([], shape=(1, 0), dtype=int32)
>>> jnp.argwhere(0)
Array([], shape=(0, 0), dtype=int32)

scico.numpy.around(a, decimals=0, out=None)

Alias of jax.numpy.round.

Return type:

Array

scico.numpy.array(object, dtype=None, *args, copy=True, order='K', ndmin=0, device=None, out_sharding=None)

Convert an object to a JAX array.

JAX implementation of numpy.array.

Parameters:
  • object (Any) – an object that is convertible to an array. This includes JAX arrays, NumPy arrays, Python scalars, Python collections like lists and tuples, objects with a __jax_array__ method, and objects supporting the Python buffer protocol.

  • dtype (Union[str, type[Any], dtype, SupportsDType, None]) – optionally specify the dtype of the output array. If not specified it will be inferred from the input.

  • copy (bool) – specify whether to force a copy of the input. Default: True.

  • order (str | None) – not implemented in JAX

  • ndmin (int) – integer specifying the minimum number of dimensions in the output array.

  • device (Device | Sharding | None) – optional Device or Sharding to which the created array will be committed.

  • out_sharding (NamedSharding | P | None) – (optional) PartitionSpec or NamedSharding representing the sharding of the created array (see explicit sharding for more details). This argument exists for consistency with other array creation routines across JAX. Specifying both out_sharding and device will result in an error.

Return type:

Array

Returns:

A JAX array constructed from the input.

Examples

Constructing JAX arrays from Python scalars:

>>> jnp.array(True)
Array(True, dtype=bool)
>>> jnp.array(42)
Array(42, dtype=int32, weak_type=True)
>>> jnp.array(3.5)
Array(3.5, dtype=float32, weak_type=True)
>>> jnp.array(1 + 1j)
Array(1.+1.j, dtype=complex64, weak_type=True)

Constructing JAX arrays from Python collections:

>>> jnp.array([1, 2, 3])  # list of ints -> 1D array
Array([1, 2, 3], dtype=int32)
>>> jnp.array([(1, 2, 3), (4, 5, 6)])  # list of tuples of ints -> 2D array
Array([[1, 2, 3],
       [4, 5, 6]], dtype=int32)
>>> jnp.array(range(5))
Array([0, 1, 2, 3, 4], dtype=int32)

Constructing JAX arrays from NumPy arrays:

>>> jnp.array(np.linspace(0, 2, 5))
Array([0. , 0.5, 1. , 1.5, 2. ], dtype=float32)

Constructing a JAX array via the Python buffer interface, using Python’s built-in array module.

>>> from array import array
>>> pybuffer = array('i', [2, 3, 5, 7])
>>> jnp.array(pybuffer)
Array([2, 3, 5, 7], dtype=int32)
scico.numpy.array_equal(a1, a2, equal_nan=False)

Check if two arrays are element-wise equal.

JAX implementation of numpy.array_equal.

Parameters:
Return type:

Array

Returns:

Boolean scalar array indicating whether the input arrays are element-wise equal.

Examples

>>> jnp.array_equal(jnp.array([1, 2, 3]), jnp.array([1, 2, 3]))
Array(True, dtype=bool)
>>> jnp.array_equal(jnp.array([1, 2, 3]), jnp.array([1, 2]))
Array(False, dtype=bool)
>>> jnp.array_equal(jnp.array([1, 2, 3]), jnp.array([1, 2, 4]))
Array(False, dtype=bool)
>>> jnp.array_equal(jnp.array([1, 2, float('nan')]),
...                 jnp.array([1, 2, float('nan')]))
Array(False, dtype=bool)
>>> jnp.array_equal(jnp.array([1, 2, float('nan')]),
...                 jnp.array([1, 2, float('nan')]), equal_nan=True)
Array(True, dtype=bool)
scico.numpy.array_equiv(a1, a2)

Check if two arrays are element-wise equal.

JAX implementation of numpy.array_equiv.

This function will return False if the input arrays cannot be broadcasted to the same shape.

Parameters:
Return type:

Array

Returns:

Boolean scalar array indicating whether the input arrays are element-wise equal after broadcasting.

Examples

>>> jnp.array_equiv(jnp.array([1, 2, 3]), jnp.array([1, 2, 3]))
Array(True, dtype=bool)
>>> jnp.array_equiv(jnp.array([1, 2, 3]), jnp.array([1, 2, 4]))
Array(False, dtype=bool)
>>> jnp.array_equiv(jnp.array([[1, 2, 3], [1, 2, 3]]),
...                 jnp.array([1, 2, 3]))
Array(True, dtype=bool)
scico.numpy.array_split(ary, indices_or_sections, axis=0)

Split an array into sub-arrays.

JAX implementation of numpy.array_split.

Refer to the documentation of jax.numpy.split for details; array_split is equivalent to split, but allows an integer indices_or_sections that does not evenly divide the split axis.

Return type:

list[Array]

Examples

>>> x = jnp.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
>>> chunks = jnp.array_split(x, 4)
>>> print(*chunks)
[1 2 3] [4 5] [6 7] [8 9]
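The uneven chunk sizes above follow NumPy's convention, in which the first len(x) % k chunks receive one extra element. A short check of that rule against jnp.array_split and numpy.array_split (the variable names are illustrative only):

```python
import numpy as np
import jax.numpy as jnp

x = jnp.arange(1, 10)            # 9 elements: [1, ..., 9]
k = 4                            # 4 does not evenly divide 9
chunks = jnp.array_split(x, k)
sizes = [int(c.shape[0]) for c in chunks]

# The first n % k chunks get n // k + 1 elements, the rest get n // k.
n = x.shape[0]
expected = [n // k + 1] * (n % k) + [n // k] * (k - n % k)
print(sizes, expected)  # [3, 2, 2, 2] [3, 2, 2, 2]
```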

scico.numpy.asarray(a, dtype=None, order=None, *, copy=None, device=None, out_sharding=None)

Convert an object to a JAX array.

JAX implementation of numpy.asarray.

Parameters:
  • a (Any) – an object that is convertible to an array. This includes JAX arrays, NumPy arrays, Python scalars, Python collections like lists and tuples, objects with a __jax_array__ method, and objects supporting the Python buffer protocol.

  • dtype (Union[str, type[Any], dtype, SupportsDType, None]) – optionally specify the dtype of the output array. If not specified it will be inferred from the input.

  • order (str | None) – not implemented in JAX

  • copy (bool | None) – optional boolean specifying the copy mode. If True, then always return a copy. If False, then error if a copy is necessary. Default is None, which will only copy when necessary.

  • device (Device | Sharding | None) – optional Device or Sharding to which the created array will be committed.

Return type:

Array

Returns:

A JAX array constructed from the input.

Examples

Constructing JAX arrays from Python scalars:

>>> jnp.asarray(True)
Array(True, dtype=bool)
>>> jnp.asarray(42)
Array(42, dtype=int32, weak_type=True)
>>> jnp.asarray(3.5)
Array(3.5, dtype=float32, weak_type=True)
>>> jnp.asarray(1 + 1j)
Array(1.+1.j, dtype=complex64, weak_type=True)

Constructing JAX arrays from Python collections:

>>> jnp.asarray([1, 2, 3])  # list of ints -> 1D array
Array([1, 2, 3], dtype=int32)
>>> jnp.asarray([(1, 2, 3), (4, 5, 6)])  # list of tuples of ints -> 2D array
Array([[1, 2, 3],
       [4, 5, 6]], dtype=int32)
>>> jnp.asarray(range(5))
Array([0, 1, 2, 3, 4], dtype=int32)

Constructing JAX arrays from NumPy arrays:

>>> jnp.asarray(np.linspace(0, 2, 5))
Array([0. , 0.5, 1. , 1.5, 2. ], dtype=float32)

Constructing a JAX array via the Python buffer interface, using Python’s built-in array module.

>>> from array import array
>>> pybuffer = array('i', [2, 3, 5, 7])
>>> jnp.asarray(pybuffer)
Array([2, 3, 5, 7], dtype=int32)
scico.numpy.astype(x, dtype, /, *, copy=False, device=None)

Convert an array to a specified dtype.

JAX implementation of numpy.astype.

This is implemented via jax.lax.convert_element_type, which may have slightly different behavior than numpy.astype in some cases. In particular, the details of float-to-int and int-to-float casts are implementation dependent.
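As a hedged illustration of float-to-int behavior: for values that fit in the target type, conversion on common backends discards the fractional part (rounds toward zero), so code that needs nearest-integer semantics can round explicitly first. Out-of-range and NaN inputs are deliberately not covered, since those casts are implementation dependent:

```python
import jax.numpy as jnp

y = jnp.array([-1.7, -0.5, 0.5, 1.7])

# In-range float-to-int conversion drops the fractional part (toward zero).
truncated = jnp.astype(y, jnp.int32)

# For nearest-integer semantics, round before casting
# (jnp.round rounds halves to the nearest even value).
rounded = jnp.astype(jnp.round(y), jnp.int32)

print(truncated)  # [-1  0  0  1]
print(rounded)    # [-2  0  0  2]
```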

Parameters:
Return type:

Array

Returns:

An array with the same shape as x, containing values of the specified dtype.

Examples

>>> x = jnp.array([0, 1, 2, 3])
>>> x
Array([0, 1, 2, 3], dtype=int32)
>>> x.astype('float32')
Array([0., 1., 2., 3.], dtype=float32)
>>> y = jnp.array([0.0, 0.5, 1.0])
>>> y.astype(int)  # truncates fractional values
Array([0, 0, 1], dtype=int32)
scico.numpy.atleast_1d(*arys)

Convert inputs to arrays with at least 1 dimension.

JAX implementation of numpy.atleast_1d.

Parameters:

arys – zero or more arraylike arguments.

Return type:

Array | list[Array]

Returns:

an array or list of arrays corresponding to the input values. Arrays of shape () are converted to shape (1,), and arrays with other shapes are returned unchanged.

Examples

Scalar arguments are converted to 1D, length-1 arrays:

>>> x = jnp.float32(1.0)
>>> jnp.atleast_1d(x)
Array([1.], dtype=float32)

Higher dimensional inputs are returned unchanged:

>>> y = jnp.arange(4)
>>> jnp.atleast_1d(y)
Array([0, 1, 2, 3], dtype=int32)

Multiple arguments can be passed to the function at once, in which case a list of results is returned:

>>> jnp.atleast_1d(x, y)
[Array([1.], dtype=float32), Array([0, 1, 2, 3], dtype=int32)]
scico.numpy.atleast_2d(*arys)

Convert inputs to arrays with at least 2 dimensions.

JAX implementation of numpy.atleast_2d.

Parameters:

arys – zero or more arraylike arguments.

Return type:

Array | list[Array]

Returns:

an array or list of arrays corresponding to the input values. Arrays of shape () are converted to shape (1, 1), 1D arrays of shape (N,) are converted to shape (1, N), and arrays of all other shapes are returned unchanged.

Examples

Scalar arguments are converted to 2D, size-1 arrays:

>>> x = jnp.float32(1.0)
>>> jnp.atleast_2d(x)
Array([[1.]], dtype=float32)

One-dimensional arguments have a unit dimension prepended to the shape:

>>> y = jnp.arange(4)
>>> jnp.atleast_2d(y)
Array([[0, 1, 2, 3]], dtype=int32)

Higher dimensional inputs are returned unchanged:

>>> z = jnp.ones((2, 3))
>>> jnp.atleast_2d(z)
Array([[1., 1., 1.],
       [1., 1., 1.]], dtype=float32)

Multiple arguments can be passed to the function at once, in which case a list of results is returned:

>>> jnp.atleast_2d(x, y)
[Array([[1.]], dtype=float32), Array([[0, 1, 2, 3]], dtype=int32)]
scico.numpy.atleast_3d(*arys)

Convert inputs to arrays with at least 3 dimensions.

JAX implementation of numpy.atleast_3d.

Parameters:

arys – zero or more arraylike arguments.

Return type:

Array | list[Array]

Returns:

an array or list of arrays corresponding to the input values. Arrays of shape () are converted to shape (1, 1, 1), 1D arrays of shape (N,) are converted to shape (1, N, 1), 2D arrays of shape (M, N) are converted to shape (M, N, 1), and arrays of all other shapes are returned unchanged.

Examples

Scalar arguments are converted to 3D, size-1 arrays:

>>> x = jnp.float32(1.0)
>>> jnp.atleast_3d(x)
Array([[[1.]]], dtype=float32)

1D arrays have a unit dimension prepended and appended:

>>> y = jnp.arange(4)
>>> jnp.atleast_3d(y).shape
(1, 4, 1)

2D arrays have a unit dimension appended:

>>> z = jnp.ones((2, 3))
>>> jnp.atleast_3d(z).shape
(2, 3, 1)

Multiple arguments can be passed to the function at once, in which case a list of results is returned:

>>> x3, y3 = jnp.atleast_3d(x, y)
>>> print(x3)
[[[1.]]]
>>> print(y3)
[[[0]
  [1]
  [2]
  [3]]]
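The shape rules listed under Returns can be written as a small pure-Python function and checked against jnp.atleast_3d; atleast_3d_shape is a hypothetical helper, not part of SCICO or JAX:

```python
import jax.numpy as jnp

def atleast_3d_shape(shape):
    """Hypothetical helper: the output shape described for jnp.atleast_3d."""
    shape = tuple(shape)
    if len(shape) == 0:
        return (1, 1, 1)
    if len(shape) == 1:
        return (1, shape[0], 1)   # unit axis prepended and appended
    if len(shape) == 2:
        return shape + (1,)       # unit axis appended
    return shape                  # 3 or more dimensions: unchanged

for shape in [(), (4,), (2, 3), (2, 3, 4), (2, 3, 4, 5)]:
    out = jnp.atleast_3d(jnp.zeros(shape))
    print(shape, "->", out.shape)
    assert out.shape == atleast_3d_shape(shape)
```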
scico.numpy.average(a, axis=None, weights=None, returned=False, keepdims=False)

Compute the weighted average.

JAX implementation of numpy.average.

Parameters:
  • a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – array to be averaged

  • axis (Union[int, Sequence[int], None]) – an optional integer or sequence of integers specifying the axis or axes along which the average is computed. If not specified, the average is computed over all axes.

  • weights (Union[Array, ndarray, bool, number, bool, int, float, complex, None]) – an optional array of weights for a weighted average. This must either exactly match the shape of a, or if axis is specified, it must have shape a.shape[axis] for a single axis, or shape tuple(a.shape[ax] for ax in axis) for multiple axes.

  • returned (bool) – If False (default) then return only the average. If True then return both the average and the normalization factor (i.e. the sum of weights).

  • keepdims (bool) – If True, reduced axes are left in the result with size 1. If False (default) then reduced axes are squeezed out.

Return type:

Array | tuple[Array, Array]

Returns:

An array average or tuple of arrays (average, normalization) if returned is True.

Examples

Simple average:

>>> x = jnp.array([1, 2, 3, 2, 4])
>>> jnp.average(x)
Array(2.4, dtype=float32)

Weighted average:

>>> weights = jnp.array([2, 1, 3, 2, 2])
>>> jnp.average(x, weights=weights)
Array(2.5, dtype=float32)

Use returned=True to optionally return the normalization, i.e. the sum of weights:

>>> jnp.average(x, returned=True)
(Array(2.4, dtype=float32), Array(5., dtype=float32))
>>> jnp.average(x, weights=weights, returned=True)
(Array(2.5, dtype=float32), Array(10., dtype=float32))

Weighted average along a specified axis:

>>> x = jnp.array([[8, 2, 7],
...                [3, 6, 4]])
>>> weights = jnp.array([1, 2, 3])
>>> jnp.average(x, weights=weights, axis=1)
Array([5.5, 4.5], dtype=float32)
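The weighted average along an axis is sum(weights * a) / sum(weights) over that axis. A sketch reproducing the last example by hand, which also requests the normalization factor with returned=True:

```python
import jax.numpy as jnp

x = jnp.array([[8., 2., 7.],
               [3., 6., 4.]])
w = jnp.array([1., 2., 3.])

avg, norm = jnp.average(x, axis=1, weights=w, returned=True)

# `w` has shape x.shape[1], so it broadcasts against each row.
manual = (x * w).sum(axis=1) / w.sum()
print(avg, manual)  # [5.5 4.5] [5.5 4.5]
print(norm)         # the sum of weights, 6.0
```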
scico.numpy.bartlett(M)

Return a Bartlett window of size M.

JAX implementation of numpy.bartlett.

Parameters:

M (int) – The window size.

Return type:

Array

Returns:

An array of size M containing the Bartlett window.

Examples

>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jnp.bartlett(4))
[0.   0.67 0.67 0.  ]
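The Bartlett window is the triangular window w[n] = 1 - |2n/(M - 1) - 1| for n = 0, ..., M - 1. A sketch of that formula (bartlett_sketch is a hypothetical name, and it assumes M > 1) that can be compared with jnp.bartlett:

```python
import jax.numpy as jnp

def bartlett_sketch(M):
    # Triangular window rising linearly from 0 to 1 and back to 0.
    # Assumes M > 1; M == 1 would divide by zero here.
    n = jnp.arange(M)
    return 1 - jnp.abs(2 * n / (M - 1) - 1)

print(bartlett_sketch(4))
print(jnp.bartlett(4))
```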

scico.numpy.bincount(x, weights=None, minlength=0, *, length=None)

Count the number of occurrences of each value in an integer array.

JAX implementation of numpy.bincount.

For an array of non-negative integers x, this function returns an array counts of size x.max() + 1, such that counts[i] contains the number of occurrences of the value i in x.

The JAX version has a few differences from the NumPy version:

  • In NumPy, passing an array x with negative entries will result in an error. In JAX, negative values are clipped to zero.

  • JAX adds an optional length parameter which can be used to statically specify the length of the output array so that this function can be used with transformations like jax.jit. In this case, values greater than or equal to length (i.e. larger than length - 1) will be dropped.

Parameters:
  • x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – 1-dimensional array of non-negative integers

  • weights (Union[Array, ndarray, bool, number, bool, int, float, complex, None]) – optional array of weights associated with x. If not specified, the weight for each entry will be 1.

  • minlength (int) – the minimum length of the output counts array.

  • length (int | None) – the length of the output counts array. Must be specified statically for bincount to be used with jax.jit and other JAX transformations.

Return type:

Array

Returns:

An array of counts or summed weights reflecting the number of occurrences of values in x.

Examples

Basic bincount:

>>> x = jnp.array([1, 1, 2, 3, 3, 3])
>>> jnp.bincount(x)
Array([0, 2, 1, 3], dtype=int32)

Weighted bincount:

>>> weights = jnp.array([1, 2, 3, 4, 5, 6])
>>> jnp.bincount(x, weights)
Array([ 0,  3,  3, 15], dtype=int32)

Specifying a static length makes this jit-compatible:

>>> jit_bincount = jax.jit(jnp.bincount, static_argnames=['length'])
>>> jit_bincount(x, length=5)
Array([0, 2, 1, 3, 0], dtype=int32)

Any negative numbers are clipped to the first bin, and numbers beyond the specified length are dropped:

>>> x = jnp.array([-1, -1, 1, 3, 10])
>>> jnp.bincount(x, length=5)
Array([2, 1, 0, 1, 0], dtype=int32)
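A rough model of the JAX semantics described above, written as a clipped scatter-add with out-of-bound indices dropped. This illustrates the behavior and is not necessarily how jnp.bincount is implemented. Closing over a Python int for length keeps the output shape static under jax.jit:

```python
import jax
import jax.numpy as jnp

x = jnp.array([-1, -1, 1, 3, 10])
length = 5

# Clip negatives into bin 0, then add 1 per entry; indices >= length are dropped.
manual = jnp.zeros(length, dtype=jnp.int32).at[jnp.clip(x, 0)].add(1, mode='drop')

# `length` is a Python int captured by the lambda, so the output shape is static.
jitted = jax.jit(lambda v: jnp.bincount(v, length=length))
print(manual, jitted(x))  # [2 1 0 1 0] [2 1 0 1 0]
```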
scico.numpy.blackman(M)

Return a Blackman window of size M.

JAX implementation of numpy.blackman.

Parameters:

M (int) – The window size.

Return type:

Array

Returns:

An array of size M containing the Blackman window.

Examples

>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jnp.blackman(4))
[-0.    0.63  0.63 -0.  ]
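The Blackman window is w[n] = 0.42 - 0.5 cos(2πn/(M - 1)) + 0.08 cos(4πn/(M - 1)) for n = 0, ..., M - 1. A sketch of that formula (blackman_sketch is a hypothetical name, assuming M > 1); the -0. endpoints printed above are floating-point round-off of values that are zero in exact arithmetic:

```python
import jax.numpy as jnp

def blackman_sketch(M):
    # Three-term cosine window; assumes M > 1.
    n = jnp.arange(M)
    return (0.42 - 0.5 * jnp.cos(2 * jnp.pi * n / (M - 1))
            + 0.08 * jnp.cos(4 * jnp.pi * n / (M - 1)))

print(blackman_sketch(4))
print(jnp.blackman(4))
```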

scico.numpy.block(arrays)

Create an array from a list of blocks.

JAX implementation of numpy.block.

Parameters:

arrays (Union[Array, ndarray, bool, number, bool, int, float, complex, list[Union[Array, ndarray, bool, number, bool, int, float, complex]]]) – an array, or nested list of arrays which will be concatenated together to form the final array.

Return type:

Array

Returns:

a single array constructed from the inputs.

Examples

Consider these blocks:

>>> zeros = jnp.zeros((2, 2))
>>> ones = jnp.ones((2, 2))
>>> twos = jnp.full((2, 2), 2)
>>> threes = jnp.full((2, 2), 3)

Passing a single array to block returns the array:

>>> jnp.block(zeros)
Array([[0., 0.],
       [0., 0.]], dtype=float32)

Passing a simple list of arrays concatenates them along the last axis:

>>> jnp.block([zeros, ones])
Array([[0., 0., 1., 1.],
       [0., 0., 1., 1.]], dtype=float32)

Passing a doubly-nested list of arrays concatenates the inner list along the last axis, and the outer list along the second-to-last axis:

>>> jnp.block([[zeros, ones],
...            [twos, threes]])
Array([[0., 0., 1., 1.],
       [0., 0., 1., 1.],
       [2., 2., 3., 3.],
       [2., 2., 3., 3.]], dtype=float32)

Note that blocks need not align in all dimensions, though the size along the axis of concatenation must match. For example, this is valid because after the inner, horizontal concatenation, the resulting blocks have a valid shape for the outer, vertical concatenation.

>>> a = jnp.zeros((2, 1))
>>> b = jnp.ones((2, 3))
>>> c = jnp.full((1, 2), 2)
>>> d = jnp.full((1, 2), 3)
>>> jnp.block([[a, b], [c, d]])
Array([[0., 1., 1., 1.],
       [0., 1., 1., 1.],
       [2., 2., 3., 3.]], dtype=float32)

Note also that this logic generalizes to blocks in 3 or more dimensions. Here’s a 3-dimensional block-wise array:

>>> x = jnp.arange(6).reshape((1, 2, 3))
>>> blocks = [[[x for i in range(3)] for j in range(4)] for k in range(5)]
>>> jnp.block(blocks).shape
(5, 8, 9)
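For a doubly-nested list, jnp.block behaves like two rounds of concatenation: each inner list along the last axis, then the resulting rows along the second-to-last axis. A sketch checking that equivalence on the unaligned blocks from the example above:

```python
import jax.numpy as jnp

a = jnp.zeros((2, 1))
b = jnp.ones((2, 3))
c = jnp.full((1, 2), 2.0)
d = jnp.full((1, 2), 3.0)

# Inner lists: concatenate along the last axis -> shapes (2, 4) and (1, 4).
rows = [jnp.concatenate([a, b], axis=-1), jnp.concatenate([c, d], axis=-1)]
# Outer list: concatenate along the second-to-last axis -> shape (3, 4).
manual = jnp.concatenate(rows, axis=-2)
print(jnp.array_equal(manual, jnp.block([[a, b], [c, d]])))  # True
```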
scico.numpy.broadcast_arrays(*args)

Broadcast arrays to a common shape.

JAX implementation of numpy.broadcast_arrays. JAX uses NumPy-style broadcasting rules, which you can read more about at NumPy broadcasting.

Parameters:

args (Union[Array, ndarray, bool, number, bool, int, float, complex]) – zero or more array-like objects to be broadcasted.

Return type:

list[Array]

Returns:

a list of arrays containing broadcasted copies of the inputs.

Examples

>>> x = jnp.arange(3)
>>> y = jnp.int32(1)
>>> jnp.broadcast_arrays(x, y)
[Array([0, 1, 2], dtype=int32), Array([1, 1, 1], dtype=int32)]
>>> x = jnp.array([[1, 2, 3]])
>>> y = jnp.array([[10],
...                [20]])
>>> x2, y2 = jnp.broadcast_arrays(x, y)
>>> x2
Array([[1, 2, 3],
       [1, 2, 3]], dtype=int32)
>>> y2
Array([[10, 10, 10],
       [20, 20, 20]], dtype=int32)
scico.numpy.broadcast_shapes(*shapes)

Broadcast input shapes to a common output shape.

JAX implementation of numpy.broadcast_shapes. JAX uses NumPy-style broadcasting rules, which you can read more about at NumPy broadcasting.

Parameters:

shapes – 0 or more shapes specified as sequences of integers

Returns:

The broadcasted shape as a tuple of integers.

Examples

Some compatible shapes:

>>> jnp.broadcast_shapes((1,), (4,))
(4,)
>>> jnp.broadcast_shapes((3, 1), (4,))
(3, 4)
>>> jnp.broadcast_shapes((3, 1), (1, 4), (5, 1, 1))
(5, 3, 4)

Incompatible shapes:

>>> jnp.broadcast_shapes((3, 1), (4, 1))  
Traceback (most recent call last):
ValueError: Incompatible shapes for broadcasting: shapes=[(3, 1), (4, 1)]
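The NumPy broadcasting rule for shapes compares dimensions from the trailing end: at each position the sizes must agree or be 1, and missing leading dimensions count as 1. A pure-Python sketch of that rule (broadcast_shape_sketch is hypothetical) that can be compared with jnp.broadcast_shapes:

```python
import jax.numpy as jnp

def broadcast_shape_sketch(*shapes):
    """Hypothetical helper mirroring the NumPy broadcasting rule for shapes."""
    ndim = max((len(s) for s in shapes), default=0)
    result = []
    for i in range(1, ndim + 1):  # walk axes starting from the trailing end
        # Sizes present at this axis, ignoring shapes that are too short and size-1 axes.
        sizes = {s[-i] for s in shapes if len(s) >= i} - {1}
        if len(sizes) > 1:
            raise ValueError(f"Incompatible shapes for broadcasting: {shapes}")
        result.append(sizes.pop() if sizes else 1)
    return tuple(reversed(result))

print(broadcast_shape_sketch((3, 1), (1, 4), (5, 1, 1)))  # (5, 3, 4)
print(jnp.broadcast_shapes((3, 1), (1, 4), (5, 1, 1)))    # (5, 3, 4)
```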
scico.numpy.broadcast_to(array, shape, *, out_sharding=None)

Broadcast an array to a specified shape.

JAX implementation of numpy.broadcast_to. JAX uses NumPy-style broadcasting rules, which you can read more about at NumPy broadcasting.

Parameters:
Return type:

Array

Returns:

a copy of array broadcast to the specified shape.

Examples

>>> x = jnp.int32(1)
>>> jnp.broadcast_to(x, (1, 4))
Array([[1, 1, 1, 1]], dtype=int32)
>>> x = jnp.array([1, 2, 3])
>>> jnp.broadcast_to(x, (2, 3))
Array([[1, 2, 3],
       [1, 2, 3]], dtype=int32)
>>> x = jnp.array([[2], [4]])
>>> jnp.broadcast_to(x, (2, 4))
Array([[2, 2, 2, 2],
       [4, 4, 4, 4]], dtype=int32)
scico.numpy.cbrt(x, /)

Calculates element-wise cube root of the input array.

JAX implementation of numpy.cbrt.

Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array or scalar. Complex dtypes are not supported.

Return type:

Array

Returns:

An array containing the cube root of the elements of x.

See also

  • jax.numpy.sqrt: Calculates the element-wise non-negative square root of the input.

  • jax.numpy.square: Calculates the element-wise square of the input.

Examples

>>> x = jnp.array([[216, 125, 64],
...                [-27, -8, -1]])
>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.cbrt(x)
Array([[ 6.,  5.,  4.],
       [-3., -2., -1.]], dtype=float32)
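One reason to prefer jnp.cbrt over a fractional power: raising a negative real number to the power 1/3 is undefined over the reals and produces NaN, while jnp.cbrt returns the real cube root. A short comparison:

```python
import jax.numpy as jnp

x = jnp.array([-27.0, -8.0, 8.0])

# Real power with a non-integer exponent: NaN for the negative entries.
fractional = x ** (1 / 3)
print(fractional)

# cbrt handles negative inputs and returns real cube roots.
print(jnp.cbrt(x))
```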
scico.numpy.ceil(x, /)

Round input to the nearest integer upwards.

JAX implementation of numpy.ceil.

Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array or scalar. Must not have complex dtype.

Return type:

Array

Returns:

An array with same shape and dtype as x containing the values rounded to the nearest integer that is greater than or equal to the value itself.

See also

  • jax.numpy.fix: Rounds the input to the nearest integer towards zero.

  • jax.numpy.trunc: Rounds the input to the nearest integer towards zero.

  • jax.numpy.floor: Rounds the input down to the nearest integer.

Examples

>>> key = jax.random.key(1)
>>> x = jax.random.uniform(key, (3, 3), minval=-5, maxval=5)
>>> with jnp.printoptions(precision=2, suppress=True):
...     print(x)
[[-0.61  0.34 -0.54]
 [-0.62  3.97  0.59]
 [ 4.84  3.42 -1.14]]
>>> jnp.ceil(x)
Array([[-0.,  1., -0.],
       [-0.,  4.,  1.],
       [ 5.,  4., -1.]], dtype=float32)
scico.numpy.choose(a, choices, out=None, mode='raise')

Construct an array by stacking slices of choice arrays.

JAX implementation of numpy.choose.

The semantics of this function can be confusing, but in the simplest case where a is a one-dimensional array, choices is a two-dimensional array, and all entries of a are in-bounds (i.e. 0 <= a_i < len(choices)), then the function is equivalent to the following:

def choose(a, choices):
  return jnp.array([choices[a_i, i] for i, a_i in enumerate(a)])

In the more general case, a may have any number of dimensions and choices may be an arbitrary sequence of broadcast-compatible arrays. In this case, again for in-bound indices, the logic is equivalent to:

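def choose(a, choices):
  # broadcast the index array and all choice arrays to a common shape
  a, *choices = jnp.broadcast_arrays(a, *choices)
  choices = jnp.array(choices)
  # at each position idx, take element idx of the choice array selected by a[idx]
  return jnp.array([choices[(a[idx], *idx)] for idx in np.ndindex(a.shape)]).reshape(a.shape)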

The only additional complexity comes from the mode argument, which controls the behavior for out-of-bound indices in a as described below.

Parameters:
Return type:

Array

Returns:

an array containing stacked slices from choices at the indices specified by a. The shape of the result is broadcast_shapes(a.shape, *(c.shape for c in choices)).

Examples

Here is the simplest case of a 1D index array with a 2D choice array, in which case this chooses the indexed value from each column:

>>> choices = jnp.array([[ 1,  2,  3,  4],
...                      [ 5,  6,  7,  8],
...                      [ 9, 10, 11, 12]])
>>> a = jnp.array([2, 0, 1, 0])
>>> jnp.choose(a, choices)
Array([9, 2, 7, 4], dtype=int32)

The mode argument specifies what to do with out-of-bound indices; options are to either wrap or clip:

>>> a2 = jnp.array([2, 0, 1, 4])  # last index out-of-bound
>>> jnp.choose(a2, choices, mode='clip')
Array([ 9,  2,  7, 12], dtype=int32)
>>> jnp.choose(a2, choices, mode='wrap')
Array([9, 2, 7, 8], dtype=int32)

In the more general case, choices may be a sequence of array-like objects with any broadcast-compatible shapes.

>>> choice_1 = jnp.array([1, 2, 3, 4])
>>> choice_2 = 99
>>> choice_3 = jnp.array([[10],
...                       [20],
...                       [30]])
>>> a = jnp.array([[0, 1, 2, 0],
...                [1, 2, 0, 1],
...                [2, 0, 1, 2]])
>>> jnp.choose(a, [choice_1, choice_2, choice_3], mode='wrap')
Array([[ 1, 99, 10,  4],
       [99, 20,  3, 99],
       [30,  2, 99, 30]], dtype=int32)
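For out-of-bound indices, the two non-raising modes can be read as first normalizing the indices and then choosing as usual: mode='wrap' like a % len(choices), and mode='clip' like clipping into [0, len(choices) - 1]. A sketch of that reading, checked against the earlier example:

```python
import jax.numpy as jnp

choices = jnp.array([[ 1,  2,  3,  4],
                     [ 5,  6,  7,  8],
                     [ 9, 10, 11, 12]])
a2 = jnp.array([2, 0, 1, 4])  # last index is out of bounds
n = len(choices)

# Normalize the indices by hand, then choose with in-bound indices only.
wrapped = jnp.choose(a2 % n, choices)
clipped = jnp.choose(jnp.clip(a2, 0, n - 1), choices)
print(wrapped, clipped)  # [9 2 7 8] [ 9  2  7 12]
```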
scico.numpy.clip(arr=None, /, min=None, max=None)

Clip array values to a specified range.

JAX implementation of numpy.clip.

Parameters:
Return type:

Array

Returns:

An array containing values from arr, with values smaller than min set to min, and values larger than max set to max. Wherever min is larger than max, the value of max is returned.

See also

  • jax.numpy.minimum: Compute the element-wise minimum value of two arrays.

  • jax.numpy.maximum: Compute the element-wise maximum value of two arrays.

Examples

>>> arr = jnp.array([0, 1, 2, 3, 4, 5, 6, 7])
>>> jnp.clip(arr, 2, 5)
Array([2, 2, 2, 3, 4, 5, 5, 5], dtype=int32)
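Element-wise, clipping amounts to a maximum against min followed by a minimum against max, which also explains why max wins wherever min is larger than max. A short sketch:

```python
import jax.numpy as jnp

arr = jnp.array([0, 1, 2, 3, 4, 5, 6, 7])

# Raise values below 2 up to 2, then lower values above 5 down to 5.
manual = jnp.minimum(jnp.maximum(arr, 2), 5)
print(jnp.array_equal(manual, jnp.clip(arr, 2, 5)))  # True

# With min > max, the final minimum against max determines every value.
print(jnp.clip(arr, 5, 2))  # [2 2 2 2 2 2 2 2]
```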
scico.numpy.column_stack(tup)

Stack arrays column-wise.

JAX implementation of numpy.column_stack.

For arrays of two or more dimensions, this is equivalent to jax.numpy.concatenate with axis=1.

Parameters:
  • tup (ndarray | Array | Sequence[Union[Array, ndarray, bool, number, bool, int, float, complex]]) – a sequence of arrays to stack; each must have the same leading dimension. Input arrays will be promoted to at least rank 2. If a single array is given it will be treated equivalently to tup = unstack(tup), but the implementation will avoid explicit unstacking.

  • dtype – optional dtype of the resulting array. If not specified, the dtype will be determined via type promotion rules described in Type promotion semantics.

Return type:

Array

Returns:

the stacked result.

Examples

Scalar values:

>>> jnp.column_stack([1, 2, 3])
Array([[1, 2, 3]], dtype=int32, weak_type=True)

1D arrays:

>>> x = jnp.arange(3)
>>> y = jnp.ones(3)
>>> jnp.column_stack([x, y])
Array([[0., 1.],
       [1., 1.],
       [2., 1.]], dtype=float32)

2D arrays:

>>> x = x.reshape(3, 1)
>>> y = y.reshape(3, 1)
>>> jnp.column_stack([x, y])
Array([[0., 1.],
       [1., 1.],
       [2., 1.]], dtype=float32)
scico.numpy.compress(condition, a, axis=None, *, size=None, fill_value=0, out=None)

Compress an array along a given axis using a boolean condition.

JAX implementation of numpy.compress.

Parameters:
Return type:

Array

Returns:

An array of dimension a.ndim, compressed along the specified axis.

Notes

This function does not require strict shape agreement between condition and a. If condition.size > a.shape[axis], then condition will be truncated, and if a.shape[axis] > condition.size, then a will be truncated.

Examples

Compressing along the rows of a 2D array:

>>> a = jnp.array([[1,  2,  3,  4],
...                [5,  6,  7,  8],
...                [9,  10, 11, 12]])
>>> condition = jnp.array([True, False, True])
>>> jnp.compress(condition, a, axis=0)
Array([[ 1,  2,  3,  4],
       [ 9, 10, 11, 12]], dtype=int32)

For convenience, you can equivalently use the compress method of JAX arrays:

>>> a.compress(condition, axis=0)
Array([[ 1,  2,  3,  4],
       [ 9, 10, 11, 12]], dtype=int32)

Note that the condition need not match the shape of the specified axis; here we compress the columns with the length-3 condition. Values beyond the size of the condition are ignored:

>>> jnp.compress(condition, a, axis=1)
Array([[ 1,  3],
       [ 5,  7],
       [ 9, 11]], dtype=int32)

The optional size argument lets you specify a static output size so that the output is statically-shaped, and so this function can be used with transformations like jit and vmap:

>>> f = lambda c, a: jnp.compress(c, a, size=len(a), fill_value=0)
>>> mask = (a % 3 == 0)
>>> jax.vmap(f)(mask, a)
Array([[ 3,  0,  0,  0],
       [ 6,  0,  0,  0],
       [ 9, 12,  0,  0]], dtype=int32)
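When the condition length matches the chosen axis, compress reduces to ordinary boolean masking; when the axis is longer, the extra entries are ignored as described in the Notes. A sketch comparing both cases with explicit indexing:

```python
import jax.numpy as jnp

a = jnp.array([[1,  2,  3,  4],
               [5,  6,  7,  8],
               [9,  10, 11, 12]])
condition = jnp.array([True, False, True])

# Condition length equals a.shape[0]: same as boolean masking of the rows.
rows = jnp.compress(condition, a, axis=0)
assert jnp.array_equal(rows, a[condition])

# a.shape[1] == 4 exceeds the condition length 3, so the 4th column is ignored.
cols = jnp.compress(condition, a, axis=1)
assert jnp.array_equal(cols, a[:, :3][:, condition])
print(cols)
```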
scico.numpy.concat(arrays, /, *, axis=0)

Join arrays along an existing axis.

JAX implementation of array_api.concat.

Parameters:
  • arrays (Sequence[Union[Array, ndarray, bool, number, bool, int, float, complex]]) – a sequence of arrays to concatenate; each must have the same shape except along the specified axis. If a single array is given it will be treated equivalently to arrays = unstack(arrays), but the implementation will avoid explicit unstacking.

  • axis (int | None) – specify the axis along which to concatenate.

Return type:

Array

Returns:

the concatenated result.

Examples

One-dimensional concatenation:

>>> x = jnp.arange(3)
>>> y = jnp.zeros(3, dtype=int)
>>> jnp.concat([x, y])
Array([0, 1, 2, 0, 0, 0], dtype=int32)

Two-dimensional concatenation:

>>> x = jnp.ones((2, 3))
>>> y = jnp.zeros((2, 1))
>>> jnp.concat([x, y], axis=1)
Array([[1., 1., 1., 0.],
       [1., 1., 1., 0.]], dtype=float32)
scico.numpy.concatenate(arrays, axis=0, dtype=None)

Join arrays along an existing axis.

JAX implementation of numpy.concatenate.

Parameters:
  • arrays (ndarray | Array | Sequence[Union[Array, ndarray, bool, number, bool, int, float, complex]]) – a sequence of arrays to concatenate; each must have the same shape except along the specified axis. If a single array is given it will be treated equivalently to arrays = unstack(arrays), but the implementation will avoid explicit unstacking.

  • axis (int | None) – specify the axis along which to concatenate. If None, the arrays are flattened before concatenation.

  • dtype (Union[str, type[Any], dtype, SupportsDType, None]) – optional dtype of the resulting array. If not specified, the dtype will be determined via type promotion rules described in Type promotion semantics.

Return type:

Array

Returns:

the concatenated result.

Examples

One-dimensional concatenation:

>>> x = jnp.arange(3)
>>> y = jnp.zeros(3, dtype=int)
>>> jnp.concatenate([x, y])
Array([0, 1, 2, 0, 0, 0], dtype=int32)

Two-dimensional concatenation:

>>> x = jnp.ones((2, 3))
>>> y = jnp.zeros((2, 1))
>>> jnp.concatenate([x, y], axis=1)
Array([[1., 1., 1., 0.],
       [1., 1., 1., 0.]], dtype=float32)
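Two options not exercised in the examples above: axis=None flattens every input before joining, so the input shapes need not agree, and dtype overrides the promoted result dtype. A brief sketch:

```python
import jax.numpy as jnp

x = jnp.ones((2, 3))
y = jnp.zeros((2, 1))

# axis=None: ravel each input (6 + 2 elements), then join them end to end.
flat = jnp.concatenate([x, y], axis=None)
print(flat)  # [1. 1. 1. 1. 1. 1. 0. 0.]

# dtype: request the output dtype instead of the promoted float32.
half = jnp.concatenate([x, y], axis=1, dtype=jnp.float16)
print(half.dtype)  # float16
```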
scico.numpy.conj(x, /)

Alias of jax.numpy.conjugate.

Return type:

Array

scico.numpy.conjugate(x, /)

Return element-wise complex-conjugate of the input.

JAX implementation of numpy.conjugate.

Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array or scalar.

Return type:

Array

Returns:

An array containing the complex-conjugate of x.

See also

  • jax.numpy.real: Returns the element-wise real part of the complex argument.

  • jax.numpy.imag: Returns the element-wise imaginary part of the complex argument.

Examples

>>> jnp.conjugate(3)
Array(3, dtype=int32, weak_type=True)
>>> x = jnp.array([2-1j, 3+5j, 7])
>>> jnp.conjugate(x)
Array([2.+1.j, 3.-5.j, 7.-0.j], dtype=complex64)
scico.numpy.convolve(a, v, mode='full', *, precision=None, preferred_element_type=None)

Convolution of two one-dimensional arrays.

JAX implementation of numpy.convolve.

Convolution of one-dimensional arrays is defined as:

\[c_k = \sum_j a_{k - j} v_j\]

Parameters:
  • a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – left-hand input to the convolution. Must have a.ndim == 1.

  • v (Union[Array, ndarray, bool, number, bool, int, float, complex]) – right-hand input to the convolution. Must have v.ndim == 1.

  • mode (str) –

    controls the size of the output. Available operations are:

    • "full": (default) output the full convolution of the inputs.

    • "same": return a centered portion of the "full" output which is the same size as a.

    • "valid": return the portion of the "full" output that does not depend on padding at the array edges.

  • precision (Union[None, str, Precision, tuple[str, str], tuple[Precision, Precision], DotAlgorithm, DotAlgorithmPreset]) – Specify the precision of the computation. Refer to jax.lax.Precision for a description of available values.

  • preferred_element_type (Union[str, type[Any], dtype, SupportsDType, None]) – the dtype in which to accumulate results and return the output. Default is None, which means the default accumulation type for the input types.

Return type:

Array

Returns:

Array containing the convolved result.

Examples

A few 1D convolution examples:

>>> x = jnp.array([1, 2, 3, 2, 1])
>>> y = jnp.array([4, 1, 2])

By default, jax.numpy.convolve returns the full convolution, using implicit zero-padding at the edges:

>>> jnp.convolve(x, y)
Array([ 4.,  9., 16., 15., 12.,  5.,  2.], dtype=float32)

Specifying mode = 'same' returns a centered convolution the same size as the first input:

>>> jnp.convolve(x, y, mode='same')
Array([ 9., 16., 15., 12.,  5.], dtype=float32)

Specifying mode = 'valid' returns only the portion where the two arrays fully overlap:

>>> jnp.convolve(x, y, mode='valid')
Array([16., 15., 12.], dtype=float32)

For complex-valued inputs:

>>> x1 = jnp.array([3+1j, 2, 4-3j])
>>> y1 = jnp.array([1, 2-3j, 4+5j])
>>> jnp.convolve(x1, y1)
Array([ 3. +1.j, 11. -7.j, 15.+10.j,  7. -8.j, 31. +8.j], dtype=complex64)
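The defining sum c_k = Σ_j a_{k-j} v_j can be evaluated directly, which makes the zero-padding behind mode='full' explicit. A pure-Python sketch (convolve_full_sketch is a hypothetical helper); for these input lengths, the 'same' and 'valid' outputs shown above are slices of the full result:

```python
import jax.numpy as jnp

def convolve_full_sketch(a, v):
    """Direct evaluation of c_k = sum_j a[k - j] * v[j] with implicit zero-padding."""
    n, m = len(a), len(v)
    out = []
    for k in range(n + m - 1):
        # Terms with k - j outside [0, n) correspond to padded zeros and are skipped.
        out.append(sum(a[k - j] * v[j] for j in range(m) if 0 <= k - j < n))
    return out

x = [1, 2, 3, 2, 1]
y = [4, 1, 2]
full = convolve_full_sketch(x, y)
print(full)  # [4, 9, 16, 15, 12, 5, 2]

same = full[1:6]   # centered slice, same length as x (for these lengths)
valid = full[2:5]  # entries where y fully overlaps x
print(jnp.convolve(jnp.array(x), jnp.array(y), mode='same'), same)
```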
scico.numpy.copy(a, order=None)

Return a copy of the array.

JAX implementation of numpy.copy.

Parameters:
Return type:

Array

Returns:

a copy of the input array a.

Examples

Since JAX arrays are immutable, in most cases explicit array copies are not necessary. One exception is when using a function with donated arguments (see the donate_argnums argument to jax.jit).

>>> f = jax.jit(lambda x: 2 * x, donate_argnums=0)
>>> x = jnp.arange(4)
>>> y = f(x)
>>> print(y)
[0 2 4 6]

Because we marked x as being donated, the original array is no longer available:

>>> print(x)  
Traceback (most recent call last):
RuntimeError: Array has been deleted with shape=int32[4].

In situations like this, an explicit copy will let you keep access to the original buffer:

>>> x = jnp.arange(4)
>>> y = f(x.copy())
>>> print(y)
[0 2 4 6]
>>> print(x)
[0 1 2 3]
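The immutability mentioned above can be observed directly: an indexed update through the .at property returns a new array and leaves the original untouched, which is why explicit copies are rarely needed outside of buffer donation. A minimal illustration:

```python
import jax.numpy as jnp

x = jnp.arange(4)

# Indexed updates go through .at and return a new array; x itself is untouched.
y = x.at[0].set(100)
print(x.tolist())  # [0, 1, 2, 3]
print(y.tolist())  # [100, 1, 2, 3]

# jnp.copy returns a new array with the same contents.
z = jnp.copy(x)
print(bool(jnp.array_equal(z, x)))  # True
```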
scico.numpy.copysign(x1, x2, /)

Copies the sign of each element in x2 to the corresponding element in x1.

JAX implementation of numpy.copysign.

Parameters:
Return type:

Array

Returns:

An array containing the potentially changed elements of x1. The result always promotes to an inexact dtype and has shape jnp.broadcast_shapes(x1.shape, x2.shape).

Examples

>>> x1 = jnp.array([5, 2, 0])
>>> x2 = -1
>>> jnp.copysign(x1, x2)
Array([-5., -2., -0.], dtype=float32)
>>> x1 = jnp.array([6, 8, 0])
>>> x2 = 2
>>> jnp.copysign(x1, x2)
Array([6., 8., 0.], dtype=float32)
>>> x1 = jnp.array([2, -3])
>>> x2 = jnp.array([[1],[-4], [5]])
>>> jnp.copysign(x1, x2)
Array([[ 2.,  3.],
       [-2., -3.],
       [ 2.,  3.]], dtype=float32)
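Conceptually, copysign takes the magnitude from x1 and the sign bit from x2. The sketch below expresses that rule with jnp.signbit and jnp.where; copysign_sketch is a hypothetical reformulation for illustration, not necessarily how JAX implements the function. It reproduces the broadcast example above, and because signbit is also set for -0.0, a negative-zero x2 still flips the sign.

```python
import jax.numpy as jnp


def copysign_sketch(x1, x2):
    """Illustrative only: |x1| carrying the sign bit of x2."""
    magnitude = jnp.abs(jnp.asarray(x1, dtype=jnp.float32))
    negative = jnp.signbit(jnp.asarray(x2, dtype=jnp.float32))  # True also for -0.0
    return jnp.where(negative, -magnitude, magnitude)


x1 = jnp.array([2, -3])
x2 = jnp.array([[1], [-4], [5]])
print(copysign_sketch(x1, x2).tolist())  # [[2.0, 3.0], [-2.0, -3.0], [2.0, 3.0]]
print(bool(jnp.array_equal(copysign_sketch(x1, x2), jnp.copysign(x1, x2))))  # True
```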
scico.numpy.cos(x, /)

Compute the trigonometric cosine of each element of the input.

JAX implementation of numpy.cos.

Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – scalar or array. Angle in radians.

Return type:

Array

Returns:

An array containing the cosine of each element in x, promotes to inexact dtype.

Examples

>>> pi = jnp.pi
>>> x = jnp.array([pi/4, pi/2, 3*pi/4, 5*pi/6])
>>> with jnp.printoptions(precision=3, suppress=True):
...   print(jnp.cos(x))
[ 0.707 -0.    -0.707 -0.866]
scico.numpy.cosh(x, /)

Calculate element-wise hyperbolic cosine of input.

JAX implementation of numpy.cosh.

The hyperbolic cosine is defined by:

\[\cosh(x) = \frac{e^x + e^{-x}}{2}\]
Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array or scalar.

Return type:

Array

Returns:

An array containing the hyperbolic cosine of each element of x, promoting to inexact dtype.

Note

jnp.cosh is equivalent to computing jnp.cos(1j * x).

See also

  • jax.numpy.sinh: Computes the element-wise hyperbolic sine of the input.

  • jax.numpy.tanh: Computes the element-wise hyperbolic tangent of the input.

  • jax.numpy.arccosh: Computes the element-wise inverse of hyperbolic cosine of the input.

Examples

>>> x = jnp.array([[3, -1, 0],
...                [4, 7, -5]])
>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.cosh(x)
Array([[ 10.068,   1.543,   1.   ],
       [ 27.308, 548.317,  74.21 ]], dtype=float32)
>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.cos(1j * x)
Array([[ 10.068+0.j,   1.543+0.j,   1.   +0.j],
       [ 27.308+0.j, 548.317+0.j,  74.21 +0.j]],      dtype=complex64, weak_type=True)

For complex-valued input:

>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.cosh(5+1j)
Array(40.096+62.44j, dtype=complex64, weak_type=True)
>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.cos(1j * (5+1j))
Array(40.096+62.44j, dtype=complex64, weak_type=True)
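The defining formula and the identity from the Note can both be checked numerically: evaluate (e^x + e^{-x}) / 2 with jnp.exp, and take the real part of jnp.cos(1j * x), then compare each against jnp.cosh. The tolerance rtol=1e-5 is an assumption chosen for float32 inputs.

```python
import jax.numpy as jnp

x = jnp.array([[3., -1., 0.],
               [4., 7., -5.]])

# cosh(x) = (e^x + e^-x) / 2
by_formula = (jnp.exp(x) + jnp.exp(-x)) / 2
print(bool(jnp.allclose(by_formula, jnp.cosh(x), rtol=1e-5)))  # True

# cosh(x) == cos(1j * x); for real x the result is real-valued.
via_cos = jnp.cos(1j * x).real
print(bool(jnp.allclose(via_cos, jnp.cosh(x), rtol=1e-5)))  # True
```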
scico.numpy.count_nonzero(a, axis=None, keepdims=False)

Return the number of nonzero elements along a given axis.

JAX implementation of numpy.count_nonzero.

Parameters:
  • a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array.

  • axis (Union[int, Sequence[int], None]) – optional, int or sequence of ints, default=None. Axis along which the number of nonzero elements is counted. If None, counts within the flattened array.

  • keepdims (bool) – bool, default=False. If True, reduced axes are left in the result with size 1.

Return type:

Array

Returns:

An array containing the number of nonzero elements along the specified axis of the input.

Examples

By default, jnp.count_nonzero counts the nonzero values along all axes.

>>> x = jnp.array([[1, 0, 0, 0],
...                [0, 0, 1, 0],
...                [1, 1, 1, 0]])
>>> jnp.count_nonzero(x)
Array(5, dtype=int32)

If axis=1, counts along axis 1.

>>> jnp.count_nonzero(x, axis=1)
Array([1, 1, 3], dtype=int32)

To preserve the dimensions of input, you can set keepdims=True.

>>> jnp.count_nonzero(x, axis=1, keepdims=True)
Array([[1],
       [1],
       [3]], dtype=int32)
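Counting nonzeros amounts to summing the boolean mask x != 0, which is a convenient mental model and a simple way to cross-check the results above:

```python
import jax.numpy as jnp

x = jnp.array([[1, 0, 0, 0],
               [0, 0, 1, 0],
               [1, 1, 1, 0]])

mask = x != 0                     # True where an element is nonzero
print(int(mask.sum()))            # 5, same as jnp.count_nonzero(x)
print(mask.sum(axis=1).tolist())  # [1, 1, 3], same as axis=1
```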
scico.numpy.cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None)

Compute the (batched) cross product of two arrays.

JAX implementation of numpy.cross.

This computes the 2-dimensional or 3-dimensional cross product,

\[c = a \times b\]

In 3 dimensions, c is a length-3 array. In 2 dimensions, c is a scalar.

Parameters:
  • a – N-dimensional array. a.shape[axisa] indicates the dimension of the cross product, and must be 2 or 3.

  • b – N-dimensional array. Must have b.shape[axisb] == a.shape[axisa], and other dimensions of a and b must be broadcast compatible.

  • axisa (int) – specify the axis of a along which to compute the cross product.

  • axisb (int) – specify the axis of b along which to compute the cross product.

  • axisc (int) – specify the axis of c along which the cross product result will be stored.

  • axis (int | None) – if specified, this overrides axisa, axisb, and axisc with a single value.

Returns:

The array c containing the (batched) cross product of a and b along the specified axes.

Examples

A 2-dimensional cross product returns a scalar:

>>> a = jnp.array([1, 2])
>>> b = jnp.array([3, 4])
>>> jnp.cross(a, b)
Array(-2, dtype=int32)

A 3-dimensional cross product returns a length-3 vector:

>>> a = jnp.array([1, 2, 3])
>>> b = jnp.array([4, 5, 6])
>>> jnp.cross(a, b)
Array([-3,  6, -3], dtype=int32)

With multi-dimensional inputs, the cross-product is computed along the last axis by default. Here’s a batched 3-dimensional cross product, operating on the rows of the inputs:

>>> a = jnp.array([[1, 2, 3],
...                [3, 4, 3]])
>>> b = jnp.array([[2, 3, 2],
...                [4, 5, 6]])
>>> jnp.cross(a, b)
Array([[-5,  4, -1],
       [ 9, -6, -1]], dtype=int32)

Specifying axis=0 makes this a batched 2-dimensional cross product, operating on the columns of the inputs:

>>> jnp.cross(a, b, axis=0)
Array([-2, -2, 12], dtype=int32)

Equivalently, we can independently specify the axis of the inputs a and b and the output c:

>>> jnp.cross(a, b, axisa=0, axisb=0, axisc=0)
Array([-2, -2, 12], dtype=int32)
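For reference, the component formulas behind these results are sketched below. The helpers cross3 and cross2 are hypothetical and assume the vectors lie along the last axis (length 3, or length 2 for the scalar case). The 3-D result is orthogonal to both inputs, which the dot-product check confirms for the batched example above.

```python
import jax.numpy as jnp


def cross3(a, b):
    """Hypothetical helper: 3-D cross product along the last axis."""
    a0, a1, a2 = a[..., 0], a[..., 1], a[..., 2]
    b0, b1, b2 = b[..., 0], b[..., 1], b[..., 2]
    return jnp.stack([a1 * b2 - a2 * b1,
                      a2 * b0 - a0 * b2,
                      a0 * b1 - a1 * b0], axis=-1)


def cross2(a, b):
    """Hypothetical helper: 2-D cross product, i.e. the z-component of cross3."""
    return a[..., 0] * b[..., 1] - a[..., 1] * b[..., 0]


a = jnp.array([[1, 2, 3], [3, 4, 3]])
b = jnp.array([[2, 3, 2], [4, 5, 6]])
c = cross3(a, b)
print(c.tolist())                        # [[-5, 4, -1], [9, -6, -1]]
print(jnp.sum(c * a, axis=-1).tolist())  # [0, 0]  (c is orthogonal to a)
print(cross2(jnp.array([1, 2]), jnp.array([3, 4])).tolist())  # -2
```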
scico.numpy.cumprod(a, axis=None, dtype=None, out=None)

Cumulative product of elements along an axis.

JAX implementation of numpy.cumprod.

Parameters:
  • a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – N-dimensional array to be accumulated.

  • axis (int | None) – integer axis along which to accumulate. If None (default), then array will be flattened and accumulated along the flattened axis.

  • dtype (Union[str, type[Any], dtype, SupportsDType, None]) – optionally specify the dtype of the output. If not specified, then the output dtype will match the input dtype.

  • out (None) – unused by JAX

Return type:

Array

Returns:

An array containing the accumulated product along the given axis.

Examples

>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> jnp.cumprod(x)  # flattened cumulative product
Array([  1,   2,   6,  24, 120, 720], dtype=int32)
>>> jnp.cumprod(x, axis=1)  # cumulative product along axis 1
Array([[  1,   2,   6],
       [  4,  20, 120]], dtype=int32)
scico.numpy.cumsum(a, axis=None, dtype=None, out=None)

Cumulative sum of elements along an axis.

JAX implementation of numpy.cumsum.

Parameters:
  • a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – N-dimensional array to be accumulated.

  • axis (int | None) – integer axis along which to accumulate. If None (default), then array will be flattened and accumulated along the flattened axis.

  • dtype (Union[str, type[Any], dtype, SupportsDType, None]) – optionally specify the dtype of the output. If not specified, then the output dtype will match the input dtype.

  • out (None) – unused by JAX

Return type:

Array

Returns:

An array containing the accumulated sum along the given axis.

Examples

>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> jnp.cumsum(x)  # flattened cumulative sum
Array([ 1,  3,  6, 10, 15, 21], dtype=int32)
>>> jnp.cumsum(x, axis=1)  # cumulative sum along axis 1
Array([[ 1,  3,  6],
       [ 4,  9, 15]], dtype=int32)
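Cumulative sums and first differences are inverse operations along an axis: differencing a cumulative sum, with a zero prepended, recovers the original array. A brief check on the example array:

```python
import jax.numpy as jnp

x = jnp.array([[1, 2, 3],
               [4, 5, 6]])
s = jnp.cumsum(x, axis=1)  # [[1, 3, 6], [4, 9, 15]]

# Differencing with a prepended zero undoes the cumulative sum.
recovered = jnp.diff(s, axis=1, prepend=0)
print(recovered.tolist())  # [[1, 2, 3], [4, 5, 6]]
```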
scico.numpy.cumulative_prod(x, /, *, axis=None, dtype=None, include_initial=False)

Cumulative product along the axis of an array.

JAX implementation of numpy.cumulative_prod.

Parameters:
  • x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – N-dimensional array

  • axis (int | None) – integer axis along which to accumulate. If x is one-dimensional, this argument is optional and defaults to zero.

  • dtype (Union[str, type[Any], dtype, SupportsDType, None]) – optional dtype of the output.

  • include_initial (bool) – if True, then include the initial value in the cumulative product. Default is False.

Return type:

Array

Returns:

An array containing the accumulated values.

See also

  • jax.numpy.cumprod: alternative API for cumulative product.

  • jax.numpy.nancumprod: cumulative product while ignoring NaN values.

  • jax.numpy.multiply.accumulate: cumulative product via the ufunc API.

Examples

>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> jnp.cumulative_prod(x, axis=1)
Array([[  1,   2,   6],
       [  4,  20, 120]], dtype=int32)
>>> jnp.cumulative_prod(x, axis=1, include_initial=True)
Array([[  1,   1,   2,   6],
       [  1,   4,  20, 120]], dtype=int32)
scico.numpy.cumulative_sum(x, /, *, axis=None, dtype=None, include_initial=False)

Cumulative sum along the axis of an array.

JAX implementation of numpy.cumulative_sum.

Parameters:
  • x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – N-dimensional array

  • axis (int | None) – integer axis along which to accumulate. If x is one-dimensional, this argument is optional and defaults to zero.

  • dtype (Union[str, type[Any], dtype, SupportsDType, None]) – optional dtype of the output.

  • include_initial (bool) – if True, then include the initial value in the cumulative sum. Default is False.

Return type:

Array

Returns:

An array containing the accumulated values.

See also

  • jax.numpy.cumsum: alternative API for cumulative sum.

  • jax.numpy.nancumsum: cumulative sum while ignoring NaN values.

  • jax.numpy.add.accumulate: cumulative sum via the ufunc API.

Examples

>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> jnp.cumulative_sum(x, axis=1)
Array([[ 1,  3,  6],
       [ 4,  9, 15]], dtype=int32)
>>> jnp.cumulative_sum(x, axis=1, include_initial=True)
Array([[ 0,  1,  3,  6],
       [ 0,  4,  9, 15]], dtype=int32)
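The include_initial=True output is the ordinary cumulative sum with the additive identity (0) prepended along the accumulation axis. The sketch below reproduces the last example using only jnp.cumsum and jnp.concatenate; for cumulative_prod, the prepended value would be the multiplicative identity 1 instead.

```python
import jax.numpy as jnp

x = jnp.array([[1, 2, 3],
               [4, 5, 6]])

# One zero per row: the sum over an empty prefix.
zeros = jnp.zeros((x.shape[0], 1), dtype=x.dtype)
with_initial = jnp.concatenate([zeros, jnp.cumsum(x, axis=1)], axis=1)
print(with_initial.tolist())  # [[0, 1, 3, 6], [0, 4, 9, 15]]
```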
scico.numpy.deg2rad(x, /)

Convert angles from degrees to radians.

JAX implementation of numpy.deg2rad.

The angle in degrees is converted to radians by:

\[\mathrm{deg2rad}(x) = x \cdot \frac{\pi}{180}\]
Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – scalar or array. Specifies the angle in degrees.

Return type:

Array

Returns:

An array containing the angles in radians.

Examples

>>> x = jnp.array([60, 90, 120, 180])
>>> jnp.deg2rad(x)
Array([1.0471976, 1.5707964, 2.0943952, 3.1415927], dtype=float32)
>>> x * jnp.pi / 180
Array([1.0471976, 1.5707964, 2.0943952, 3.1415927],      dtype=float32, weak_type=True)
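Because deg2rad multiplies by pi/180 and rad2deg multiplies by 180/pi, the two functions undo each other up to float32 rounding. A quick round-trip check on the angles above, using the default allclose tolerances:

```python
import jax.numpy as jnp

deg = jnp.array([60., 90., 120., 180.])
rad = jnp.deg2rad(deg)

print(bool(jnp.allclose(rad, deg * jnp.pi / 180)))  # True
print(bool(jnp.allclose(jnp.rad2deg(rad), deg)))    # True
```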
scico.numpy.degrees(x, /)

Alias of jax.numpy.rad2deg.

Return type:

Array

scico.numpy.delete(arr, obj, axis=None, *, assume_unique_indices=False)

Delete entry or entries from an array.

JAX implementation of numpy.delete.

Parameters:
  • arr (Union[Array, ndarray, bool, number, bool, int, float, complex]) – array from which entries will be deleted.

  • obj (Union[Array, ndarray, bool, number, bool, int, float, complex, slice]) – index, indices, or slice to be deleted.

  • axis (int | None) – axis along which entries will be deleted.

  • assume_unique_indices (bool) – In case of array-like integer (not boolean) indices, assume the indices are unique, and perform the deletion in a way that is compatible with JIT and other JAX transformations.

Return type:

Array

Returns:

Copy of arr with specified indices deleted.

Note

delete() usually requires the index specification to be static. If the index is an integer array that is guaranteed to contain unique entries, you may specify assume_unique_indices=True to perform the operation in a manner that does not require static indices.

Examples

Delete entries from a 1D array:

>>> a = jnp.array([4, 5, 6, 7, 8, 9])
>>> jnp.delete(a, 2)
Array([4, 5, 7, 8, 9], dtype=int32)
>>> jnp.delete(a, slice(1, 4))  # delete a[1:4]
Array([4, 8, 9], dtype=int32)
>>> jnp.delete(a, slice(None, None, 2))  # delete a[::2]
Array([5, 7, 9], dtype=int32)

Delete entries from a 2D array along a specified axis:

>>> a2 = jnp.array([[4, 5, 6],
...                 [7, 8, 9]])
>>> jnp.delete(a2, 1, axis=1)
Array([[4, 6],
       [7, 9]], dtype=int32)

Delete multiple entries via a sequence of indices:

>>> indices = jnp.array([0, 1, 3])
>>> jnp.delete(a, indices)
Array([6, 8, 9], dtype=int32)

This will fail under jit and other transformations, because when the indices might contain duplicates, the output shape cannot be determined at trace time:

>>> jax.jit(jnp.delete)(a, indices)  
Traceback (most recent call last):
  ...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[3].

If you can ensure that the indices are unique, pass assume_unique_indices=True to allow this to be executed under JIT:

>>> jit_delete = jax.jit(jnp.delete, static_argnames=['assume_unique_indices'])
>>> jit_delete(a, indices, assume_unique_indices=True)
Array([6, 8, 9], dtype=int32)
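One way to picture what delete does with an index array is boolean-mask selection: mark the positions to delete as False and keep the rest. The sketch below is an illustration, not the library's implementation, and reproduces the multi-index example above in eager mode. Like jnp.delete without assume_unique_indices, boolean-mask indexing produces a data-dependent output shape, so this form also cannot be traced under jax.jit.

```python
import jax.numpy as jnp

a = jnp.array([4, 5, 6, 7, 8, 9])
indices = jnp.array([0, 1, 3])

# Start from an all-True mask and switch off the positions to delete.
keep = jnp.ones(a.shape[0], dtype=bool).at[indices].set(False)
print(a[keep].tolist())  # [6, 8, 9]
```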
scico.numpy.diag(v, k=0)

Returns the specified diagonal or constructs a diagonal array.

JAX implementation of numpy.diag.

The JAX version always returns a copy of the input, although if this is used within a JIT compilation, the compiler may avoid the copy.

Parameters:
  • v (Union[Array, ndarray, bool, number, bool, int, float, complex]) – Input array. Can be a 1-D array to create a diagonal matrix or a 2-D array to extract a diagonal.

  • k (int) – optional, default=0. Diagonal offset. Positive values place the diagonal above the main diagonal, negative values place it below the main diagonal.

Return type:

Array

Returns:

If v is a 2-D array, a 1-D array containing the diagonal elements. If v is a 1-D array, a 2-D array with the input elements placed along the specified diagonal.

Examples

Creating a diagonal matrix from a 1-D array:

>>> jnp.diag(jnp.array([1, 2, 3]))
Array([[1, 0, 0],
       [0, 2, 0],
       [0, 0, 3]], dtype=int32)

Specifying a diagonal offset:

>>> jnp.diag(jnp.array([1, 2, 3]), k=1)
Array([[0, 1, 0, 0],
       [0, 0, 2, 0],
       [0, 0, 0, 3],
       [0, 0, 0, 0]], dtype=int32)

Extracting a diagonal from a 2-D array:

>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6],
...                [7, 8, 9]])
>>> jnp.diag(x)
Array([1, 5, 9], dtype=int32)
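Because jnp.diag builds a matrix from a 1-D input and extracts a diagonal from a 2-D input, applying it twice round-trips a vector (for the same k). With an offset, extraction reads the k-th superdiagonal (k > 0) or subdiagonal (k < 0):

```python
import jax.numpy as jnp

v = jnp.array([1, 2, 3])
print(jnp.diag(jnp.diag(v)).tolist())            # [1, 2, 3]
print(jnp.diag(jnp.diag(v, k=1), k=1).tolist())  # [1, 2, 3]

x = jnp.array([[1, 2, 3],
               [4, 5, 6],
               [7, 8, 9]])
print(jnp.diag(x, k=1).tolist())   # [2, 6]
print(jnp.diag(x, k=-1).tolist())  # [4, 8]
```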
scico.numpy.diag_indices(n, ndim=2)

Return indices for accessing the main diagonal of a multidimensional array.

JAX implementation of numpy.diag_indices.

Parameters:
  • n (int) – int. The size of each dimension of the square array.

  • ndim (int) – optional, int, default=2. The number of dimensions of the array.

Return type:

tuple[Array, ...]

Returns:

A tuple of arrays, each of length n, containing the indices to access the main diagonal.

Examples

>>> jnp.diag_indices(3)
(Array([0, 1, 2], dtype=int32), Array([0, 1, 2], dtype=int32))
>>> jnp.diag_indices(4, ndim=3)
(Array([0, 1, 2, 3], dtype=int32),
Array([0, 1, 2, 3], dtype=int32),
Array([0, 1, 2, 3], dtype=int32))
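The returned tuple can be used directly as an index, for example with the functional .at update to overwrite the main diagonal of a square array:

```python
import jax.numpy as jnp

x = jnp.zeros((3, 3), dtype=jnp.int32)
idx = jnp.diag_indices(3)  # (rows, cols) for positions (0, 0), (1, 1), (2, 2)

y = x.at[idx].set(jnp.array([7, 8, 9]))
print(y.tolist())  # [[7, 0, 0], [0, 8, 0], [0, 0, 9]]
```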
scico.numpy.diag_indices_from(arr)

Return indices for accessing the main diagonal of a given array.

JAX implementation of numpy.diag_indices_from.

Parameters:

arr (Union[Array, ndarray, bool, number, bool, int, float, complex]) – Input array. Must be at least 2-dimensional and have equal length along all dimensions.

Return type:

tuple[Array, ...]

Returns:

A tuple of arrays containing the indices to access the main diagonal of the input array.

Examples

>>> arr = jnp.array([[1, 2, 3],
...                  [4, 5, 6],
...                  [7, 8, 9]])
>>> jnp.diag_indices_from(arr)
(Array([0, 1, 2], dtype=int32), Array([0, 1, 2], dtype=int32))
>>> arr = jnp.array([[[1, 2], [3, 4]],
...                  [[5, 6], [7, 8]]])
>>> jnp.diag_indices_from(arr)
(Array([0, 1], dtype=int32),
Array([0, 1], dtype=int32),
Array([0, 1], dtype=int32))
scico.numpy.diagflat(v, k=0)

Return a 2-D array with the flattened input array laid out on the diagonal.

JAX implementation of numpy.diagflat.

This differs from np.diagflat for some scalar values of v. JAX always returns a two-dimensional array, whereas NumPy may return a scalar depending on the type of v.

Parameters:
  • v (Union[Array, ndarray, bool, number, bool, int, float, complex]) – Input array. Can be N-dimensional but is flattened to 1D.

  • k (int) – optional, default=0. Diagonal offset. Positive values place the diagonal above the main diagonal, negative values place it below the main diagonal.

Return type:

Array

Returns:

A 2D array with the input elements placed along the diagonal with the specified offset (k). The remaining entries are filled with zeros.

Examples

>>> jnp.diagflat(jnp.array([1, 2, 3]))
Array([[1, 0, 0],
       [0, 2, 0],
       [0, 0, 3]], dtype=int32)
>>> jnp.diagflat(jnp.array([1, 2, 3]), k=1)
Array([[0, 1, 0, 0],
       [0, 0, 2, 0],
       [0, 0, 0, 3],
       [0, 0, 0, 0]], dtype=int32)
>>> a = jnp.array([[1, 2],
...                [3, 4]])
>>> jnp.diagflat(a)
Array([[1, 0, 0, 0],
       [0, 2, 0, 0],
       [0, 0, 3, 0],
       [0, 0, 0, 4]], dtype=int32)
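For array inputs, diagflat behaves like flattening followed by jnp.diag. The check below compares the two on the 2x2 example above, with and without an offset:

```python
import jax.numpy as jnp

a = jnp.array([[1, 2],
               [3, 4]])

print(bool(jnp.array_equal(jnp.diagflat(a), jnp.diag(jnp.ravel(a)))))              # True
print(bool(jnp.array_equal(jnp.diagflat(a, k=-1), jnp.diag(jnp.ravel(a), k=-1))))  # True
```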
scico.numpy.diff(a, n=1, axis=-1, prepend=None, append=None)

Calculate n-th order difference between array elements along a given axis.

JAX implementation of numpy.diff.

The first-order difference is computed as a[i+1] - a[i], and the n-th order difference is obtained by applying this operation n times recursively.

Parameters:
  • a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array. Must have a.ndim >= 1.

  • n (int) – int, optional, default=1. Order of the difference. Specifies the number of times the difference is computed. If n=0, no difference is computed and input is returned as is.

  • axis (int) – int, optional, default=-1. Specifies the axis along which the difference is computed. The difference is computed along axis -1 by default.

  • prepend (Union[Array, ndarray, bool, number, bool, int, float, complex, None]) – scalar or array, optional, default=None. Specifies the values to be prepended along axis before computing the difference.

  • append (Union[Array, ndarray, bool, number, bool, int, float, complex, None]) – scalar or array, optional, default=None. Specifies the values to be appended along axis before computing the difference.

Return type:

Array

Returns:

An array containing the n-th order difference between the elements of a.

Examples

By default, jnp.diff computes the first-order difference along the last axis.

>>> a = jnp.array([[1, 5, 2, 9],
...                [3, 8, 7, 4]])
>>> jnp.diff(a)
Array([[ 4, -3,  7],
       [ 5, -1, -3]], dtype=int32)

With n=2, the second-order difference is computed along axis.

>>> jnp.diff(a, n=2)
Array([[-7, 10],
       [-6, -2]], dtype=int32)

With prepend=2, the value 2 is prepended to a along axis before computing the difference.

>>> jnp.diff(a, prepend=2)
Array([[-1,  4, -3,  7],
       [ 1,  5, -1, -3]], dtype=int32)

With append=jnp.array([[3],[1]]), this array is appended to a along axis before computing the difference.

>>> jnp.diff(a, append=jnp.array([[3],[1]]))
Array([[ 4, -3,  7, -6],
       [ 5, -1, -3, -3]], dtype=int32)
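The recursive definition implies that diff with n=2 equals two successive first-order differences, and prepend acts like concatenating along axis before differencing (a scalar is broadcast to a single slice along that axis). Both equivalences, checked on the example array:

```python
import jax.numpy as jnp

a = jnp.array([[1, 5, 2, 9],
               [3, 8, 7, 4]])

# n=2 is diff applied twice.
print(bool(jnp.array_equal(jnp.diff(a, n=2), jnp.diff(jnp.diff(a)))))  # True

# prepend=2 behaves like concatenating a column of 2s along the last axis.
col = jnp.full((a.shape[0], 1), 2, dtype=a.dtype)
print(bool(jnp.array_equal(jnp.diff(a, prepend=2),
                           jnp.diff(jnp.concatenate([col, a], axis=-1)))))  # True
```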
scico.numpy.divide(x1, x2, /)

Alias of jax.numpy.true_divide.

Return type:

Array

scico.numpy.divmod(x1, x2, /)

Calculate the integer quotient and remainder of x1 divided by x2, element-wise.

JAX implementation of numpy.divmod.

Parameters:
Return type:

tuple[Array, Array]

Returns:

A tuple of arrays (x1 // x2, x1 % x2).

Examples

>>> x1 = jnp.array([10, 20, 30])
>>> x2 = jnp.array([3, 4, 7])
>>> jnp.divmod(x1, x2)
(Array([3, 5, 4], dtype=int32), Array([1, 0, 2], dtype=int32))
>>> x1 = jnp.array([-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5])
>>> x2 = 3
>>> jnp.divmod(x1, x2)
(Array([-2, -2, -1, -1, -1,  0,  0,  0,  1,  1,  1], dtype=int32),
 Array([1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2], dtype=int32))
>>> x1 = jnp.array([6, 6, 6], dtype=jnp.int32)
>>> x2 = jnp.array([1.9, 2.5, 3.1], dtype=jnp.float32)
>>> jnp.divmod(x1, x2)
(Array([3., 2., 1.], dtype=float32),
 Array([0.30000007, 1.        , 2.9       ], dtype=float32))
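The quotient and remainder satisfy x1 == (x1 // x2) * x2 + x1 % x2, with floor division, so the remainder takes the sign of the divisor x2. This matches Python's built-in divmod on integers, as the check with a negative divisor shows:

```python
import jax.numpy as jnp

x1 = jnp.array([-5, -1, 0, 4, 7])
x2 = -3
q, r = jnp.divmod(x1, x2)

print(q.tolist())  # [1, 0, 0, -2, -3]
print(r.tolist())  # [-2, -1, 0, -2, -2]
print(bool(jnp.array_equal(q * x2 + r, x1)))  # True
```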
scico.numpy.dot(a, b, *, precision=None, preferred_element_type=None, out_sharding=None)

Compute the dot product of two arrays.

JAX implementation of numpy.dot.

This differs from jax.numpy.matmul in two respects:

  • if either a or b is a scalar, the result of dot is equivalent to jax.numpy.multiply, while the result of matmul is an error.

  • if a and b have more than 2 dimensions, the batch indices are stacked rather than broadcast.

Parameters:
Return type:

Array

Returns:

array containing the dot product of the inputs, with batch dimensions of a and b stacked rather than broadcast.

Examples

For scalar inputs, dot computes the element-wise product:

>>> x = jnp.array([1, 2, 3])
>>> jnp.dot(x, 2)
Array([2, 4, 6], dtype=int32)

For vector or matrix inputs, dot computes the vector or matrix product:

>>> M = jnp.array([[2, 3, 4],
...                [5, 6, 7],
...                [8, 9, 0]])
>>> jnp.dot(M, x)
Array([20, 38, 26], dtype=int32)
>>> jnp.dot(M, M)
Array([[ 51,  60,  29],
       [ 96, 114,  62],
       [ 61,  78,  95]], dtype=int32)

For higher-dimensional matrix products, batch dimensions are stacked, whereas in matmul they are broadcast. For example:

>>> a = jnp.zeros((3, 2, 4))
>>> b = jnp.zeros((3, 4, 1))
>>> jnp.dot(a, b).shape
(3, 2, 3, 1)
>>> jnp.matmul(a, b).shape
(3, 2, 1)
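The stacked versus broadcast batch semantics can be spelled out with einsum subscripts. For 3-D inputs, dot contracts the last axis of a with the second-to-last axis of b and keeps every remaining axis of both, whereas matmul pairs up the leading batch axis. Integer inputs with the shapes from the example above are used here so the comparisons are exact:

```python
import jax.numpy as jnp

a = jnp.arange(24).reshape(3, 2, 4)
b = jnp.arange(12).reshape(3, 4, 1)

# dot: result[i, j, l, m] = sum_k a[i, j, k] * b[l, k, m]  (batch axes stacked)
stacked = jnp.einsum('ijk,lkm->ijlm', a, b)
print(stacked.shape)                                   # (3, 2, 3, 1)
print(bool(jnp.array_equal(jnp.dot(a, b), stacked)))   # True

# matmul: result[i, j, m] = sum_k a[i, j, k] * b[i, k, m]  (batch axis shared)
broadcast = jnp.einsum('ijk,ikm->ijm', a, b)
print(broadcast.shape)                                     # (3, 2, 1)
print(bool(jnp.array_equal(jnp.matmul(a, b), broadcast)))  # True
```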
scico.numpy.dsplit(ary, indices_or_sections)

Split an array into sub-arrays depth-wise.

JAX implementation of numpy.dsplit.

Refer to the documentation of jax.numpy.split for details. dsplit is equivalent to split with axis=2.

Return type:

list[Array]

Examples

>>> x = jnp.arange(12).reshape(3, 1, 4)
>>> print(x)
[[[ 0  1  2  3]]

 [[ 4  5  6  7]]

 [[ 8  9 10 11]]]
>>> x1, x2 = jnp.dsplit(x, 2)
>>> print(x1)
[[[0 1]]

 [[4 5]]

 [[8 9]]]
>>> print(x2)
[[[ 2  3]]

 [[ 6  7]]

 [[10 11]]]

scico.numpy.dstack(tup, dtype=None)

Stack arrays depth-wise.

JAX implementation of numpy.dstack.

For arrays of three or more dimensions, this is equivalent to jax.numpy.concatenate with axis=2.

Parameters:
  • tup (ndarray | Array | Sequence[Union[Array, ndarray, bool, number, bool, int, float, complex]]) – a sequence of arrays to stack; each must have the same shape along all but the third axis. Input arrays will be promoted to at least rank 3. If a single array is given it will be treated equivalently to tup = unstack(tup), but the implementation will avoid explicit unstacking.

  • dtype (Union[str, type[Any], dtype, SupportsDType, None]) – optional dtype of the resulting array. If not specified, the dtype will be determined via type promotion rules described in Type promotion semantics.

Return type:

Array

Returns:

the stacked result.

Examples

Scalar values:

>>> jnp.dstack([1, 2, 3])
Array([[[1, 2, 3]]], dtype=int32, weak_type=True)

1D arrays:

>>> x = jnp.arange(3)
>>> y = jnp.ones(3)
>>> jnp.dstack([x, y])
Array([[[0., 1.],
        [1., 1.],
        [2., 1.]]], dtype=float32)

2D arrays:

>>> x = x.reshape(1, 3)
>>> y = y.reshape(1, 3)
>>> jnp.dstack([x, y])
Array([[[0., 1.],
        [1., 1.],
        [2., 1.]]], dtype=float32)
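The 1D example above can be reproduced by promoting each input to at least three dimensions and concatenating along axis 2. jnp.atleast_3d turns a length-3 vector into shape (1, 3, 1), consistent with the rank-3 promotion described for tup:

```python
import jax.numpy as jnp

x = jnp.arange(3)
y = jnp.ones(3)

# Promote to rank 3, then join along the third (depth) axis.
manual = jnp.concatenate([jnp.atleast_3d(x), jnp.atleast_3d(y)], axis=2)
print(manual.shape)                                       # (1, 3, 2)
print(bool(jnp.array_equal(manual, jnp.dstack([x, y]))))  # True
```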
scico.numpy.ediff1d(ary, to_end=None, to_begin=None)

Compute the differences of the elements of the flattened array.

JAX implementation of numpy.ediff1d.

Parameters:
Return type:

Array

Returns:

An array containing the differences between the elements of the input array.

Note

Unlike NumPy’s implementation of ediff1d, jax.numpy.ediff1d will not issue an error if casting to_end or to_begin to the type of ary loses precision.

See also

  • jax.numpy.diff: Computes the n-th order difference between elements of the array along a given axis.

  • jax.numpy.cumsum: Computes the cumulative sum of the elements of the array along a given axis.

  • jax.numpy.gradient: Computes the gradient of an N-dimensional array.

Examples

>>> a = jnp.array([2, 3, 5, 9, 1, 4])
>>> jnp.ediff1d(a)
Array([ 1,  2,  4, -8,  3], dtype=int32)
>>> jnp.ediff1d(a, to_begin=-10)
Array([-10,   1,   2,   4,  -8,   3], dtype=int32)
>>> jnp.ediff1d(a, to_end=jnp.array([20, 30]))
Array([ 1,  2,  4, -8,  3, 20, 30], dtype=int32)
>>> jnp.ediff1d(a, to_begin=-10, to_end=jnp.array([20, 30]))
Array([-10,   1,   2,   4,  -8,   3,  20,  30], dtype=int32)

For arrays with ndim > 1, the differences are computed after flattening the input array.

>>> a1 = jnp.array([[2, -1, 4, 7],
...                 [3, 5, -6, 9]])
>>> jnp.ediff1d(a1)
Array([ -3,   5,   3,  -4,   2, -11,  15], dtype=int32)
>>> a2 = jnp.array([2, -1, 4, 7, 3, 5, -6, 9])
>>> jnp.ediff1d(a2)
Array([ -3,   5,   3,  -4,   2, -11,  15], dtype=int32)
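ediff1d can be expressed as a first-order difference of the flattened input, with to_begin and to_end concatenated on either side. The sketch below illustrates that equivalence (it is not the library's code) and reproduces the combined example above:

```python
import jax.numpy as jnp

a = jnp.array([2, 3, 5, 9, 1, 4])
to_begin = jnp.array([-10])
to_end = jnp.array([20, 30])

manual = jnp.concatenate([to_begin, jnp.diff(jnp.ravel(a)), to_end])
print(manual.tolist())  # [-10, 1, 2, 4, -8, 3, 20, 30]
```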
scico.numpy.einsum(subscripts, /, *operands, out=None, optimize='auto', precision=None, preferred_element_type=None, _dot_general=<function dot_general>, out_sharding=None)

Einstein summation

JAX implementation of numpy.einsum.

einsum is a powerful and generic API for computing various reductions, inner products, outer products, axis reorderings, and combinations thereof across one or more input arrays. It has a somewhat complicated overloaded API; the arguments below reflect the most common calling convention. The Examples section below demonstrates some of the alternative calling conventions.

Parameters:
  • subscripts – string containing axes names separated by commas.

  • *operands – sequence of one or more arrays corresponding to the subscripts.

  • optimize (str | bool | list[tuple[int, ...]]) – specify how to optimize the order of computation. In JAX this defaults to "auto" which produces optimized expressions via the opt_einsum package. Other options are True (same as "optimal"), False (unoptimized), or any string supported by opt_einsum, which includes "optimal", "greedy", "eager", and others. It may also be a pre-computed path (see einsum_path).

  • precision (Union[None, str, Precision, tuple[str, str], tuple[Precision, Precision], DotAlgorithm, DotAlgorithmPreset]) – either None (default), which means the default precision for the backend, or a Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST).

  • preferred_element_type (Union[str, type[Any], dtype, SupportsDType, None]) – either None (default), which means the default accumulation type for the input types, or a datatype in which to accumulate results and return the result.

  • out (None) – unsupported by JAX

  • _dot_general (Callable[..., Array]) – optionally override the dot_general callable used by einsum. This parameter is experimental, and may be removed without warning at any time.

Return type:

Array

Returns:

array containing the result of the Einstein summation.

Examples

The mechanics of einsum are perhaps best demonstrated by example. Here we show how to use einsum to compute a number of quantities from one or more arrays. For more discussion and examples of einsum, see the documentation of numpy.einsum.

>>> M = jnp.arange(16).reshape(4, 4)
>>> x = jnp.arange(4)
>>> y = jnp.array([5, 4, 3, 2])

Vector inner product

>>> jnp.einsum('i,i', x, y)
Array(16, dtype=int32)
>>> jnp.vecdot(x, y)
Array(16, dtype=int32)

Here are some alternative einsum calling conventions to compute the same result:

>>> jnp.einsum('i,i->', x, y)  # explicit form
Array(16, dtype=int32)
>>> jnp.einsum(x, (0,), y, (0,))  # implicit form via indices
Array(16, dtype=int32)
>>> jnp.einsum(x, (0,), y, (0,), ())  # explicit form via indices
Array(16, dtype=int32)

Matrix-vector product

>>> jnp.einsum('ij,j->i', M, x)  # explicit form
Array([14, 38, 62, 86], dtype=int32)
>>> jnp.matmul(M, x)
Array([14, 38, 62, 86], dtype=int32)

Here are some alternative einsum calling conventions to compute the same result:

>>> jnp.einsum('ij,j', M, x) # implicit form
Array([14, 38, 62, 86], dtype=int32)
>>> jnp.einsum(M, (0, 1), x, (1,), (0,)) # explicit form via indices
Array([14, 38, 62, 86], dtype=int32)
>>> jnp.einsum(M, (0, 1), x, (1,))  # implicit form via indices
Array([14, 38, 62, 86], dtype=int32)

Outer product

>>> jnp.einsum("i,j->ij", x, y)
Array([[ 0,  0,  0,  0],
       [ 5,  4,  3,  2],
       [10,  8,  6,  4],
       [15, 12,  9,  6]], dtype=int32)
>>> jnp.outer(x, y)
Array([[ 0,  0,  0,  0],
       [ 5,  4,  3,  2],
       [10,  8,  6,  4],
       [15, 12,  9,  6]], dtype=int32)

Some other ways of computing outer products:

>>> jnp.einsum("i,j", x, y)  # implicit form
Array([[ 0,  0,  0,  0],
       [ 5,  4,  3,  2],
       [10,  8,  6,  4],
       [15, 12,  9,  6]], dtype=int32)
>>> jnp.einsum(x, (0,), y, (1,), (0, 1))  # explicit form via indices
Array([[ 0,  0,  0,  0],
       [ 5,  4,  3,  2],
       [10,  8,  6,  4],
       [15, 12,  9,  6]], dtype=int32)
>>> jnp.einsum(x, (0,), y, (1,))  # implicit form via indices
Array([[ 0,  0,  0,  0],
       [ 5,  4,  3,  2],
       [10,  8,  6,  4],
       [15, 12,  9,  6]], dtype=int32)

1D array sum

>>> jnp.einsum("i->", x)  # requires explicit form
Array(6, dtype=int32)
>>> jnp.einsum(x, (0,), ())  # explicit form via indices
Array(6, dtype=int32)
>>> jnp.sum(x)
Array(6, dtype=int32)

Sum along an axis

>>> jnp.einsum("...j->...", M)  # requires explicit form
Array([ 6, 22, 38, 54], dtype=int32)
>>> jnp.einsum(M, (..., 0), (...,))  # explicit form via indices
Array([ 6, 22, 38, 54], dtype=int32)
>>> M.sum(-1)
Array([ 6, 22, 38, 54], dtype=int32)

Matrix transpose

>>> y = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> jnp.einsum("ij->ji", y)  # explicit form
Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)
>>> jnp.einsum("ji", y)  # implicit form
Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)
>>> jnp.einsum(y, (1, 0))  # implicit form via indices
Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)
>>> jnp.einsum(y, (0, 1), (1, 0))  # explicit form via indices
Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)
>>> jnp.transpose(y)
Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)

Matrix diagonal

>>> jnp.einsum("ii->i", M)
Array([ 0,  5, 10, 15], dtype=int32)
>>> jnp.diagonal(M)
Array([ 0,  5, 10, 15], dtype=int32)

Matrix trace

>>> jnp.einsum("ii", M)
Array(30, dtype=int32)
>>> jnp.trace(M)
Array(30, dtype=int32)

Tensor products

>>> x = jnp.arange(30).reshape(2, 3, 5)
>>> y = jnp.arange(60).reshape(3, 4, 5)
>>> jnp.einsum('ijk,jlk->il', x, y)  # explicit form
Array([[ 3340,  3865,  4390,  4915],
       [ 8290,  9940, 11590, 13240]], dtype=int32)
>>> jnp.tensordot(x, y, axes=[(1, 2), (0, 2)])
Array([[ 3340,  3865,  4390,  4915],
       [ 8290,  9940, 11590, 13240]], dtype=int32)
>>> jnp.einsum('ijk,jlk', x, y)  # implicit form
Array([[ 3340,  3865,  4390,  4915],
       [ 8290,  9940, 11590, 13240]], dtype=int32)
>>> jnp.einsum(x, (0, 1, 2), y, (1, 3, 2), (0, 3))  # explicit form via indices
Array([[ 3340,  3865,  4390,  4915],
       [ 8290,  9940, 11590, 13240]], dtype=int32)
>>> jnp.einsum(x, (0, 1, 2), y, (1, 3, 2))  # implicit form via indices
Array([[ 3340,  3865,  4390,  4915],
       [ 8290,  9940, 11590, 13240]], dtype=int32)

Chained dot products

>>> w = jnp.arange(5, 9).reshape(2, 2)
>>> x = jnp.arange(6).reshape(2, 3)
>>> y = jnp.arange(-2, 4).reshape(3, 2)
>>> z = jnp.array([[2, 4, 6], [3, 5, 7]])
>>> jnp.einsum('ij,jk,kl,lm->im', w, x, y, z)
Array([[ 481,  831, 1181],
       [ 651, 1125, 1599]], dtype=int32)
>>> jnp.einsum(w, (0, 1), x, (1, 2), y, (2, 3), z, (3, 4))  # implicit, via indices
Array([[ 481,  831, 1181],
       [ 651, 1125, 1599]], dtype=int32)
>>> w @ x @ y @ z  # direct chain of matmuls
Array([[ 481,  831, 1181],
       [ 651, 1125, 1599]], dtype=int32)
>>> jnp.linalg.multi_dot([w, x, y, z])
Array([[ 481,  831, 1181],
       [ 651, 1125, 1599]], dtype=int32)
scico.numpy.einsum_path(subscripts, /, *operands, optimize='auto')

Evaluates the optimal contraction path without evaluating the einsum.

JAX implementation of numpy.einsum_path. This function calls into the opt_einsum package, and makes use of its optimization routines.

Parameters:
  • subscripts – string containing axes names separated by commas.

  • *operands – sequence of one or more arrays corresponding to the subscripts.

  • optimize (bool | str | list[tuple[int, ...]]) – specify how to optimize the order of computation. In JAX this defaults to "auto". Other options are True (same as "optimal"), False (unoptimized), or any string supported by opt_einsum, which includes "optimal", "greedy", "eager", and others.

Return type:

tuple[list[tuple[int, ...]], Any]

Returns:

A tuple containing the path that may be passed to einsum, and a printable object representing this optimal path.

Examples

>>> key1, key2, key3 = jax.random.split(jax.random.key(0), 3)
>>> x = jax.random.randint(key1, minval=-5, maxval=5, shape=(2, 3))
>>> y = jax.random.randint(key2, minval=-5, maxval=5, shape=(3, 100))
>>> z = jax.random.randint(key3, minval=-5, maxval=5, shape=(100, 5))
>>> path, path_info = jnp.einsum_path("ij,jk,kl", x, y, z, optimize="optimal")
>>> print(path)
[(1, 2), (0, 1)]
>>> print(path_info)
      Complete contraction:  ij,jk,kl->il
            Naive scaling:  4
        Optimized scaling:  3
          Naive FLOP count:  9.000e+3
      Optimized FLOP count:  3.060e+3
      Theoretical speedup:  2.941e+0
      Largest intermediate:  1.500e+1 elements
    --------------------------------------------------------------------------------
    scaling        BLAS                current                             remaining
    --------------------------------------------------------------------------------
      3           GEMM              kl,jk->lj                             ij,lj->il
      3           GEMM              lj,ij->il                                il->il

Use the computed path in einsum:

>>> jnp.einsum("ij,jk,kl", x, y, z, optimize=path)
Array([[-754,  324, -142,   82,   50],
       [ 408,  -50,   87,  -29,    7]], dtype=int32)
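Because the returned path is an ordinary list of index tuples, it can be computed once and reused across calls. The sketch below calls jax.numpy directly and assumes opt_einsum's "greedy" strategy and all-ones operands, both chosen only for illustration; it checks that the precomputed path gives the same result as a plain chain of matrix products.

```python
import jax.numpy as jnp

# Illustrative operands: all-ones matrices make the expected result easy to check.
x = jnp.ones((2, 3))
y = jnp.ones((3, 100))
z = jnp.ones((100, 5))

# Compute the contraction order once, using opt_einsum's greedy strategy.
path, path_info = jnp.einsum_path("ij,jk,kl", x, y, z, optimize="greedy")

# Reuse the precomputed path; the contraction order does not change the result.
result = jnp.einsum("ij,jk,kl", x, y, z, optimize=path)
reference = x @ y @ z  # each entry equals 3 * 100 = 300
```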
scico.numpy.empty(shape, dtype=None, *, device=None, out_sharding=None)

Create an empty array.

JAX implementation of numpy.empty.

Note

For historical reasons, jax.numpy.empty is currently equivalent to jax.numpy.zeros: i.e. it returns a buffer initialized with zeros. To create a buffer of uninitialized values, please use jax.lax.empty.

Parameters:
  • shape (Any) – int or sequence of ints specifying the shape of the created array.

  • dtype (Union[str, type[Any], dtype, SupportsDType, None]) – optional dtype for the created array; defaults to float32 or float64 depending on the X64 configuration (see Default dtypes and the X64 flag).

  • device (Device | Sharding | None) – (optional) Device or Sharding to which the created array will be committed. This argument exists for compatibility with the Python Array API standard.

  • out_sharding (NamedSharding | P | None) – (optional) PartitionSpec or NamedSharding representing the sharding of the created array (see explicit sharding for more details). This argument exists for consistency with other array creation routines across JAX. Specifying both out_sharding and device will result in an error.

Return type:

Array

Returns:

Array of the specified shape and dtype, with the given device/sharding if specified.

Examples

>>> jnp.empty(4)  
Array([0., 0., 0., 0.], dtype=float32)
>>> jnp.empty((2, 3), dtype=bool)  
Array([[False, False, False],
       [False, False, False]], dtype=bool)
scico.numpy.empty_like(prototype, dtype=None, shape=None, *, device=None)

Create an empty array with the same shape and dtype as an array.

JAX implementation of numpy.empty_like. Because XLA cannot create an uninitialized array, jax.numpy.empty_like will always return an array full of zeros.

Parameters:
  • prototype – Array-like object with shape and dtype attributes.

  • shape (Any) – optionally override the shape of the created array.

  • dtype (Union[str, type[Any], dtype, SupportsDType, None]) – optionally override the dtype of the created array.

  • device (Device | Sharding | None) – (optional) Device or Sharding to which the created array will be committed.

Return type:

Array

Returns:

Array of the specified shape and dtype, committed to the given device if one is specified.

Examples

>>> x = jnp.arange(4)
>>> jnp.empty_like(x)  
Array([0, 0, 0, 0], dtype=int32)
>>> jnp.empty_like(x, dtype=bool)  
Array([False, False, False, False], dtype=bool)
>>> jnp.empty_like(x, shape=(2, 3))  
Array([[0, 0, 0],
       [0, 0, 0]], dtype=int32)
scico.numpy.equal(x, y, /)

Returns element-wise truth value of x == y.

JAX implementation of numpy.equal. This function provides the implementation of the == operator for JAX arrays.

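Parameters:
  • x – input array or scalar.

  • y – input array or scalar. x and y must either have the same shape or be broadcast compatible.
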
Return type:

Array

Returns:

A boolean array containing True where the elements of x == y and False otherwise.

Examples

>>> jnp.equal(0., -0.)
Array(True, dtype=bool, weak_type=True)
>>> jnp.equal(1, 1.)
Array(True, dtype=bool, weak_type=True)
>>> jnp.equal(5, jnp.array(5))
Array(True, dtype=bool, weak_type=True)
>>> jnp.equal(2, -2)
Array(False, dtype=bool, weak_type=True)
>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6],
...                [7, 8, 9]])
>>> y = jnp.array([1, 5, 9])
>>> jnp.equal(x, y)
Array([[ True, False, False],
       [False,  True, False],
       [False, False,  True]], dtype=bool)
>>> x == y
Array([[ True, False, False],
       [False,  True, False],
       [False, False,  True]], dtype=bool)
scico.numpy.exp(x, /)

Calculate element-wise exponential of the input.

JAX implementation of numpy.exp.

Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array or scalar

Return type:

Array

Returns:

An array containing the exponential of each element in x, promoted to an inexact dtype.

Examples

jnp.exp follows the properties of the exponential function, such as \(e^{a+b} = e^a \cdot e^b\).

>>> x1 = jnp.array([2, 4, 3, 1])
>>> x2 = jnp.array([1, 3, 2, 3])
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jnp.exp(x1+x2))
[  20.09 1096.63  148.41   54.6 ]
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jnp.exp(x1)*jnp.exp(x2))
[  20.09 1096.63  148.41   54.6 ]

This property holds for complex input also:

>>> jnp.allclose(jnp.exp(3-4j), jnp.exp(3)*jnp.exp(-4j))
Array(True, dtype=bool)
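The Returns entry states that integer input is promoted to an inexact dtype. A minimal sketch, calling jax.numpy directly, that checks this promotion and uses jnp.log as the inverse of jnp.exp:

```python
import math

import jax.numpy as jnp

x = jnp.array([0, 1, 2])   # integer input
y = jnp.exp(x)             # promoted to a floating-point dtype
expected = jnp.array([1.0, math.e, math.e ** 2])

# jnp.log inverts jnp.exp up to floating-point rounding.
recovered = jnp.log(y)
```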
scico.numpy.exp2(x, /)

Calculate element-wise base-2 exponential of input.

JAX implementation of numpy.exp2.

Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array or scalar

Return type:

Array

Returns:

An array containing the base-2 exponential of each element in x, promoted to an inexact dtype.

Examples

jnp.exp2 follows the properties of the exponential, such as \(2^{a+b} = 2^a \cdot 2^b\).

>>> x1 = jnp.array([2, -4, 3, -1])
>>> x2 = jnp.array([-1, 3, -2, 3])
>>> jnp.exp2(x1+x2)
Array([2. , 0.5, 2. , 4. ], dtype=float32)
>>> jnp.exp2(x1)*jnp.exp2(x2)
Array([2. , 0.5, 2. , 4. ], dtype=float32)
scico.numpy.expand_dims(a, axis)

Insert dimensions of length 1 into an array.

JAX implementation of numpy.expand_dims, implemented via jax.lax.expand_dims.

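Parameters:
  • a – input array.

  • axis – integer or sequence of integers specifying the positions of the new length-1 axes in the output array.
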
Return type:

Array

Returns:

Copy of a with added dimensions.

Notes

Unlike numpy.expand_dims, jax.numpy.expand_dims will return a copy rather than a view of the input array. However, under JIT, the compiler will optimize away such copies when possible, so this doesn’t have performance impacts in practice.

Examples

>>> x = jnp.array([1, 2, 3])
>>> x.shape
(3,)

Expand the leading dimension:

>>> jnp.expand_dims(x, 0)
Array([[1, 2, 3]], dtype=int32)
>>> _.shape
(1, 3)

Expand the trailing dimension:

>>> jnp.expand_dims(x, 1)
Array([[1],
       [2],
       [3]], dtype=int32)
>>> _.shape
(3, 1)

Expand multiple dimensions:

>>> jnp.expand_dims(x, (0, 1, 3))
Array([[[[1],
         [2],
         [3]]]], dtype=int32)
>>> _.shape
(1, 1, 3, 1)

Dimensions can also be expanded more succinctly by indexing with None:

>>> x[None]  # equivalent to jnp.expand_dims(x, 0)
Array([[1, 2, 3]], dtype=int32)
>>> x[:, None]  # equivalent to jnp.expand_dims(x, 1)
Array([[1],
       [2],
       [3]], dtype=int32)
>>> x[None, None, :, None]  # equivalent to jnp.expand_dims(x, (0, 1, 3))
Array([[[[1],
         [2],
         [3]]]], dtype=int32)
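Negative axis values are also accepted; following the NumPy convention, they count from the end of the output array. A brief sketch, calling jax.numpy directly:

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3])

# axis=-1 appends a trailing axis, like x[:, None]
col = jnp.expand_dims(x, -1)

# axis=(0, -1) adds both a leading and a trailing axis
both = jnp.expand_dims(x, (0, -1))
```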
scico.numpy.expm1(x, /)

Calculate exp(x)-1 of each element of the input.

JAX implementation of numpy.expm1.

Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array or scalar.

Return type:

Array

Returns:

An array containing exp(x)-1 of each element in x, promoted to an inexact dtype.

Note

jnp.expm1 has much higher precision than the naive computation of exp(x)-1 for small values of x.

See also

  • jax.numpy.log1p: Calculates element-wise logarithm of one plus input.

  • jax.numpy.exp: Calculates element-wise exponential of the input.

  • jax.numpy.exp2: Calculates base-2 exponential of each element of the input.

Examples

>>> x = jnp.array([2, -4, 3, -1])
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jnp.expm1(x))
[ 6.39 -0.98 19.09 -0.63]
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jnp.exp(x)-1)
[ 6.39 -0.98 19.09 -0.63]

For values very close to 0, jnp.expm1(x) is much more accurate than jnp.exp(x)-1:

>>> x1 = jnp.array([1e-4, 1e-6, 2e-10])
>>> jnp.expm1(x1)
Array([1.0000500e-04, 1.0000005e-06, 2.0000000e-10], dtype=float32)
>>> jnp.exp(x1)-1
Array([1.00016594e-04, 9.53674316e-07, 0.00000000e+00], dtype=float32)
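jax.numpy.log1p, listed under See also, is the inverse of expm1 and is similarly accurate near zero. A minimal round-trip sketch on the same small inputs, calling jax.numpy directly and assuming the default float32 precision:

```python
import jax.numpy as jnp

x = jnp.array([1e-4, 1e-6, 2e-10])

y = jnp.expm1(x)          # accurate even where exp(x) rounds to (nearly) 1.0
x_back = jnp.log1p(y)     # log1p undoes expm1
naive = jnp.exp(x) - 1    # much less accurate for tiny x in float32
```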
scico.numpy.extract(condition, arr, *, size=None, fill_value=0)

Return the elements of an array that satisfy a condition.

JAX implementation of numpy.extract.

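Parameters:
  • condition – array of conditions; it is converted to boolean and flattened to 1D.

  • arr – array of values to extract; it is flattened to 1D.

  • size – optional static integer specifying the size of the output. It must be specified for extract to be compatible with JAX transformations such as jit or vmap.

  • fill_value – if size is specified, the value used to pad the output (default: 0).
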
Return type:

Array

Returns:

1D array of extracted entries. If size is specified, the result will have shape (size,) and be right-padded with fill_value. If size is not specified, the output shape will depend on the number of True entries in condition.

Notes

This function does not require strict shape agreement between condition and arr. If condition.size > arr.size, then condition will be truncated, and if arr.size > condition.size, then arr will be truncated.

See also

jax.numpy.compress: multi-dimensional version of extract.

Examples

Extract values from a 1D array:

>>> x = jnp.array([1, 2, 3, 4, 5, 6])
>>> mask = (x % 2 == 0)
>>> jnp.extract(mask, x)
Array([2, 4, 6], dtype=int32)

In the simplest case, this is equivalent to boolean indexing:

>>> x[mask]
Array([2, 4, 6], dtype=int32)

For use with JAX transformations, you can pass the size argument to specify a static shape for the output, along with an optional fill_value that defaults to zero:

>>> jnp.extract(mask, x, size=len(x), fill_value=0)
Array([2, 4, 6, 0, 0, 0], dtype=int32)

Notice that unlike with boolean indexing, extract does not require strict agreement between the sizes of the array and condition, and will effectively truncate both to the minimum size:

>>> short_mask = jnp.array([False, True])
>>> jnp.extract(short_mask, x)
Array([2], dtype=int32)
>>> long_mask = jnp.array([True, False, True, False, False, False, False, False])
>>> jnp.extract(long_mask, x)
Array([1, 3], dtype=int32)
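The size argument is what makes extract usable under jax.jit, since the output shape must be known at trace time. A sketch, where the wrapper name even_entries is hypothetical and the padding value -1 is arbitrary:

```python
import jax
import jax.numpy as jnp


@jax.jit
def even_entries(x):
    # x.shape[0] is a static Python int under jit, so it can serve as `size`;
    # unused output slots are padded with -1.
    return jnp.extract(x % 2 == 0, x, size=x.shape[0], fill_value=-1)


out = even_entries(jnp.array([1, 2, 3, 4, 5, 6]))
```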
scico.numpy.eye(N, M=None, k=0, dtype=None, *, device=None)

Create a square or rectangular identity matrix.

JAX implementation of numpy.eye.

Parameters:
  • N (Union[int, Any]) – integer specifying the first dimension of the array.

  • M (Union[int, Any, None]) – optional integer specifying the second dimension of the array; defaults to the same value as N.

  • k (Union[int, Array, ndarray, bool, number, bool, float, complex]) – optional integer specifying the offset of the diagonal. Use positive values for upper diagonals, and negative values for lower diagonals. Default is zero.

  • dtype (Union[str, type[Any], dtype, SupportsDType, None]) – optional dtype; defaults to floating point.

  • device (Device | Sharding | None) – optional Device or Sharding to which the created array will be committed.

Return type:

Array

Returns:

Identity array of shape (N, M), or (N, N) if M is not specified.

See also

jax.numpy.identity: Simpler API for generating square identity matrices.

Examples

A simple 3x3 identity matrix:

>>> jnp.eye(3)
Array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]], dtype=float32)

Integer identity matrices with offset diagonals:

>>> jnp.eye(3, k=1, dtype=int)
Array([[0, 1, 0],
       [0, 0, 1],
       [0, 0, 0]], dtype=int32)
>>> jnp.eye(3, k=-1, dtype=int)
Array([[0, 0, 0],
       [1, 0, 0],
       [0, 1, 0]], dtype=int32)

Non-square identity matrix:

>>> jnp.eye(3, 5, k=1)
Array([[0., 1., 0., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 0., 0., 1., 0.]], dtype=float32)
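Because k selects an offset diagonal, sums of shifted identity matrices give banded matrices. As an illustration only (D is a hypothetical name, not a SCICO API), a tridiagonal second-difference matrix with 2 on the diagonal and -1 on the first off-diagonals can be assembled as follows:

```python
import jax.numpy as jnp

n = 4
# 2 on the main diagonal, -1 on the first super- and sub-diagonals
D = 2 * jnp.eye(n) - jnp.eye(n, k=1) - jnp.eye(n, k=-1)

# Applied to a linear ramp, the interior rows give 0; the boundary rows do not.
ramp = jnp.arange(n, dtype=jnp.float32)
r = D @ ramp
```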
scico.numpy.fabs(x, /)

Compute the element-wise absolute values of the real-valued input.

JAX implementation of numpy.fabs.

Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array or scalar. Must not have a complex dtype.

Return type:

Array

Returns:

An array with the same shape as x and a float dtype, containing the element-wise absolute values.

See also

  • jax.numpy.absolute: Computes the absolute values of the input including complex dtypes.

  • jax.numpy.abs: Computes the absolute values of the input including complex dtypes.

Examples

For integer inputs:

>>> x = jnp.array([-5, -9, 1, 10, 15])
>>> jnp.fabs(x)
Array([ 5.,  9.,  1., 10., 15.], dtype=float32)

For float type inputs:

>>> x1 = jnp.array([-1.342, 5.649, 3.927])
>>> jnp.fabs(x1)
Array([1.342, 5.649, 3.927], dtype=float32)

For boolean inputs:

>>> x2 = jnp.array([True, False])
>>> jnp.fabs(x2)
Array([1., 0.], dtype=float32)
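Unlike jnp.abs, which keeps an integer dtype, jnp.fabs always returns a floating-point array, as stated under Returns. A quick sketch comparing the two, calling jax.numpy directly:

```python
import jax.numpy as jnp

x = jnp.array([-5, -9, 1, 10, 15])

a = jnp.abs(x)    # integer dtype is preserved
f = jnp.fabs(x)   # result is floating point
```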
scico.numpy.fill_diagonal(a, val, wrap=False, *, inplace=True)

Return a copy of the array with the diagonal overwritten.

JAX implementation of numpy.fill_diagonal.

The semantics of numpy.fill_diagonal are to modify arrays in-place, which is not possible for JAX’s immutable arrays. The JAX version returns a modified copy of the input, and adds the inplace parameter, which must be set to False by the user as a reminder of this API difference.

Parameters:
  • a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array. Must have a.ndim >= 2. If a.ndim >= 3, then all dimensions must be the same size.

  • val (Union[Array, ndarray, bool, number, bool, int, float, complex]) – scalar or array with which to fill the diagonal. If an array, it will be flattened and repeated to fill the diagonal entries.

  • wrap (bool) – Not implemented by JAX. Only the default value of False is supported.

  • inplace (bool) – must be set to False to indicate that the input is not modified in-place, but rather a modified copy is returned.

Return type:

Array

Returns:

A copy of a with the diagonal set to val.

Examples

>>> x = jnp.zeros((3, 3), dtype=int)
>>> jnp.fill_diagonal(x, jnp.array([1, 2, 3]), inplace=False)
Array([[1, 0, 0],
       [0, 2, 0],
       [0, 0, 3]], dtype=int32)

Unlike numpy.fill_diagonal, the input x is not modified.

If the diagonal value has too many entries, it will be truncated:

>>> jnp.fill_diagonal(x, jnp.arange(100, 200), inplace=False)
Array([[100,   0,   0],
       [  0, 101,   0],
       [  0,   0, 102]], dtype=int32)

If the diagonal has too few entries, it will be repeated:

>>> x = jnp.zeros((4, 4), dtype=int)
>>> jnp.fill_diagonal(x, jnp.array([3, 4]), inplace=False)
Array([[3, 0, 0, 0],
       [0, 4, 0, 0],
       [0, 0, 3, 0],
       [0, 0, 0, 4]], dtype=int32)

For non-square arrays, the diagonal of the leading square slice is filled:

>>> x = jnp.zeros((3, 5), dtype=int)
>>> jnp.fill_diagonal(x, 1, inplace=False)
Array([[1, 0, 0, 0, 0],
       [0, 1, 0, 0, 0],
       [0, 0, 1, 0, 0]], dtype=int32)

And for square N-dimensional arrays, the N-dimensional diagonal is filled:

>>> y = jnp.zeros((2, 2, 2))
>>> jnp.fill_diagonal(y, 1, inplace=False)
Array([[[1., 0.],
        [0., 0.]],

       [[0., 0.],
        [0., 1.]]], dtype=float32)
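For a square 2D array, the same result can also be written with JAX's functional .at[...].set(...) update syntax together with jnp.diag_indices. A sketch of that equivalent formulation, calling jax.numpy directly:

```python
import jax.numpy as jnp

x = jnp.zeros((3, 3), dtype=int)
vals = jnp.array([1, 2, 3])

filled = jnp.fill_diagonal(x, vals, inplace=False)

# Equivalent functional update: diag_indices(3) returns the (row, col) index arrays.
rows, cols = jnp.diag_indices(3)
via_at = x.at[rows, cols].set(vals)
```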
scico.numpy.flatnonzero(a, *, size=None, fill_value=None)

Return indices of nonzero elements in a flattened array.

JAX implementation of numpy.flatnonzero.

jnp.flatnonzero(a) is equivalent to jnp.nonzero(jnp.ravel(a))[0]. For a full discussion of the parameters to this function, refer to jax.numpy.nonzero.

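Parameters:
  • a – N-dimensional input array.

  • size – optional static integer specifying the number of indices to return; see jax.numpy.nonzero.

  • fill_value – optional padding value used when size is specified; see jax.numpy.nonzero.
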
Return type:

Array

Returns:

Array containing the indices of each nonzero value in the flattened array.

Examples

>>> x = jnp.array([[0, 5, 0],
...                [6, 0, 8]])
>>> jnp.flatnonzero(x)
Array([1, 3, 5], dtype=int32)

This is equivalent to calling nonzero on the flattened array, and extracting the first entry in the resulting tuple:

>>> jnp.nonzero(x.ravel())[0]
Array([1, 3, 5], dtype=int32)

The returned indices can be used to extract nonzero entries from the flattened array:

>>> indices = jnp.flatnonzero(x)
>>> x.ravel()[indices]
Array([5, 6, 8], dtype=int32)
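As with jax.numpy.nonzero, passing size gives a statically-shaped result that can be used inside jax.jit, with missing entries padded by fill_value. A sketch requesting more indices than there are nonzero entries (the size 4 and padding -1 are arbitrary):

```python
import jax
import jax.numpy as jnp

x = jnp.array([[0, 5, 0],
               [6, 0, 8]])

# Request 4 indices although only 3 entries are nonzero; pad with -1.
padded = jax.jit(lambda a: jnp.flatnonzero(a, size=4, fill_value=-1))(x)
```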
scico.numpy.flip(m, axis=None)

Reverse the order of elements of an array along the given axis.

JAX implementation of numpy.flip.

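Parameters:
  • m – input array.

  • axis – integer or sequence of integers specifying the axes along which to flip. If None (default), the array is flipped along all axes.
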
Return type:

Array

Returns:

An array with the elements in reverse order along axis.

Examples

>>> x1 = jnp.array([[1, 2],
...                 [3, 4]])
>>> jnp.flip(x1)
Array([[4, 3],
       [2, 1]], dtype=int32)

If axis is specified with an integer, then jax.numpy.flip reverses the array along that particular axis only.

>>> jnp.flip(x1, axis=1)
Array([[2, 1],
       [4, 3]], dtype=int32)
>>> x2 = jnp.arange(1, 9).reshape(2, 2, 2)
>>> x2
Array([[[1, 2],
        [3, 4]],

       [[5, 6],
        [7, 8]]], dtype=int32)
>>> jnp.flip(x2)
Array([[[8, 7],
        [6, 5]],

       [[4, 3],
        [2, 1]]], dtype=int32)

When axis is specified with a sequence of integers, then jax.numpy.flip reverses the array along the specified axes.

>>> jnp.flip(x2, axis=[1, 2])
Array([[[4, 3],
        [2, 1]],

       [[8, 7],
        [6, 5]]], dtype=int32)
scico.numpy.fliplr(m)

Reverse the order of elements of an array along axis 1.

JAX implementation of numpy.fliplr.

Parameters:

m (Union[Array, ndarray, bool, number, bool, int, float, complex]) – Array with at least two dimensions.

Return type:

Array

Returns:

An array with the elements in reverse order along axis 1.

Examples

>>> x = jnp.array([[1, 2],
...                [3, 4]])
>>> jnp.fliplr(x)
Array([[2, 1],
       [4, 3]], dtype=int32)
scico.numpy.flipud(m)

Reverse the order of elements of an array along axis 0.

JAX implementation of numpy.flipud.

Parameters:

m (Union[Array, ndarray, bool, number, bool, int, float, complex]) – Array with at least one dimension.

Return type:

Array

Returns:

An array with the elements in reverse order along axis 0.

Examples

>>> x = jnp.array([[1, 2],
...                [3, 4]])
>>> jnp.flipud(x)
Array([[3, 4],
       [1, 2]], dtype=int32)
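fliplr and flipud behave like flip with a fixed axis, and all three can also be written with reversed slicing. A sketch checking these equivalences on the 2x2 example, calling jax.numpy directly:

```python
import jax.numpy as jnp

x = jnp.array([[1, 2],
               [3, 4]])

lr_matches = bool((jnp.fliplr(x) == jnp.flip(x, axis=1)).all())   # reverse columns
ud_matches = bool((jnp.flipud(x) == jnp.flip(x, axis=0)).all())   # reverse rows
slice_lr = bool((jnp.fliplr(x) == x[:, ::-1]).all())              # slicing equivalents
slice_ud = bool((jnp.flipud(x) == x[::-1]).all())
all_axes = bool((jnp.flip(x) == x[::-1, ::-1]).all())             # axis=None flips every axis
```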
scico.numpy.float_power(x, y, /)

Calculate x raised to the power y element-wise, in floating point.

JAX implementation of numpy.float_power.

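Parameters:
  • x – scalar or array specifying the bases.

  • y – scalar or array specifying the exponents. x and y must either have the same shape or be broadcast compatible.
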
Return type:

Array

Returns:

An array containing x raised to the power y element-wise, promoted to an inexact dtype.

See also

  • jax.numpy.exp: Calculates element-wise exponential of the input.

  • jax.numpy.exp2: Calculates base-2 exponential of each element of the input.

Examples

Inputs with same shape:

>>> x = jnp.array([3, 1, -5])
>>> y = jnp.array([2, 4, -1])
>>> jnp.float_power(x, y)
Array([ 9. ,  1. , -0.2], dtype=float32)

Inputs with broadcast compatibility:

>>> x1 = jnp.array([[2, -4, 1],
...                 [-1, 2, 3]])
>>> y1 = jnp.array([-2, 1, 4])
>>> jnp.float_power(x1, y1)
Array([[ 0.25, -4.  ,  1.  ],
       [ 1.  ,  2.  , 81.  ]], dtype=float32)

jnp.float_power produces nan for negative values raised to non-integer powers.

>>> jnp.float_power(-3, 1.7)
Array(nan, dtype=float32, weak_type=True)
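float_power returns a floating-point result even for integer input, and agrees with jnp.power once the bases are explicitly converted to a floating dtype. A sketch using the first example's values and assuming the default float32 precision:

```python
import jax.numpy as jnp

x = jnp.array([3, 1, -5])
y = jnp.array([2, 4, -1])

fp = jnp.float_power(x, y)                  # floating-point result from integer input
ref = jnp.power(x.astype(jnp.float32), y)   # same computation with explicit float bases
```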
scico.numpy.floor(x, /)

Round input to the nearest integer downwards.

JAX implementation of numpy.floor.

Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array or scalar. Must not have complex dtype.

Return type:

Array

Returns:

An array with same shape and dtype as x containing the values rounded to the nearest integer that is less than or equal to the value itself.

See also

  • jax.numpy.fix: Rounds the input to the nearest integer towards zero.

  • jax.numpy.trunc: Rounds the input to the nearest integer towards zero.

  • jax.numpy.ceil: Rounds the input up to the nearest integer.

Examples

>>> key = jax.random.key(42)
>>> x = jax.random.uniform(key, (3, 3), minval=-5, maxval=5)
>>> with jnp.printoptions(precision=2, suppress=True):
...     print(x)
[[-0.11  1.8   1.16]
 [ 0.61 -0.49  0.86]
 [-4.25  2.75  1.99]]
>>> jnp.floor(x)
Array([[-1.,  1.,  1.],
       [ 0., -1.,  0.],
       [-5.,  2.,  1.]], dtype=float32)
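floor, trunc, and ceil (see the related functions listed above) differ only for non-integer values, and most visibly for negative ones. A small sketch making the three rounding directions explicit:

```python
import jax.numpy as jnp

v = jnp.array([-1.5, -0.5, 0.5, 1.5])

down = jnp.floor(v)   # toward -inf
zero = jnp.trunc(v)   # toward zero (trunc(-0.5) is -0.0)
up = jnp.ceil(v)      # toward +inf
```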
scico.numpy.floor_divide(x1, x2, /)

Calculates the floor division of x1 by x2 element-wise.

JAX implementation of numpy.floor_divide.

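Parameters:
  • x1 – numerator array or scalar.

  • x2 – denominator array or scalar; must be broadcast compatible with x1.
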
Return type:

Array

Returns:

An array containing each quotient rounded towards negative infinity. This is equivalent to x1 // x2 in Python.

Note

x1 // x2 is equivalent to jnp.floor_divide(x1, x2) for arrays x1 and x2.

See also

jax.numpy.divide and jax.numpy.true_divide for floating point division.

Examples

>>> x1 = jnp.array([10, 20, 30])
>>> x2 = jnp.array([3, 4, 7])
>>> jnp.floor_divide(x1, x2)
Array([3, 5, 4], dtype=int32)
>>> x1 = jnp.array([-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5])
>>> x2 = 3
>>> jnp.floor_divide(x1, x2)
Array([-2, -2, -1, -1, -1,  0,  0,  0,  1,  1,  1], dtype=int32)
>>> x1 = jnp.array([6, 6, 6], dtype=jnp.int32)
>>> x2 = jnp.array([2.0, 2.5, 3.0], dtype=jnp.float32)
>>> jnp.floor_divide(x1, x2)
Array([3., 2., 2.], dtype=float32)
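floor_divide pairs with jnp.remainder (the % operator), whose result takes the sign of the divisor, so the identity x1 == x2 * (x1 // x2) + x1 % x2 holds for integer inputs. A sketch checking it on the second example's values:

```python
import jax.numpy as jnp

x1 = jnp.array([-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5])
x2 = 3

q = jnp.floor_divide(x1, x2)   # same as x1 // x2
r = jnp.remainder(x1, x2)      # same as x1 % x2; sign follows x2
```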
scico.numpy.fmax(x1, x2)

Return element-wise maximum of the input arrays.

JAX implementation of numpy.fmax.

Parameters:
  • x1 – input array or scalar.

  • x2 – input array or scalar; must be broadcast compatible with x1.

Return type:

Array

Returns:

An array containing the element-wise maximum of x1 and x2.

Note

For each pair of elements, jnp.fmax returns:
  • the larger of the two if both elements are finite numbers.

  • finite number if one element is nan.

  • nan if both elements are nan.

  • inf if one element is inf and the other is finite or nan.

  • -inf if one element is -inf and the other is nan.

Examples

>>> jnp.fmax(3, 7)
Array(7, dtype=int32, weak_type=True)
>>> jnp.fmax(5, jnp.array([1, 7, 9, 4]))
Array([5, 7, 9, 5], dtype=int32)
>>> x1 = jnp.array([1, 3, 7, 8])
>>> x2 = jnp.array([-1, 4, 6, 9])
>>> jnp.fmax(x1, x2)
Array([1, 4, 7, 9], dtype=int32)
>>> x3 = jnp.array([[2, 3, 5, 10],
...                 [11, 9, 7, 5]])
>>> jnp.fmax(x1, x3)
Array([[ 2,  3,  7, 10],
       [11,  9,  7,  8]], dtype=int32)
>>> nan = jnp.nan
>>> x4 = jnp.array([jnp.inf, 6, -jnp.inf, nan])
>>> x5 = jnp.array([[3, 5, 7, nan],
...                 [nan, 9, nan, -1]])
>>> jnp.fmax(x4, x5)
Array([[ inf,   6.,   7.,  nan],
       [ inf,   9., -inf,  -1.]], dtype=float32)
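The practical difference from jnp.maximum is NaN handling: jnp.maximum propagates NaN, while jnp.fmax returns the non-NaN operand, as listed in the note above. A sketch, calling jax.numpy directly:

```python
import jax.numpy as jnp

a = jnp.array([1.0, jnp.nan, jnp.nan])
b = jnp.array([jnp.nan, 2.0, jnp.nan])

m = jnp.maximum(a, b)   # NaN wherever either input is NaN
f = jnp.fmax(a, b)      # ignores a single NaN; NaN only where both are NaN
```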
scico.numpy.fmin(x1, x2)

Return element-wise minimum of the input arrays.

JAX implementation of numpy.fmin.

Parameters:
  • x1 – input array or scalar.

  • x2 – input array or scalar; must be broadcast compatible with x1.

Return type:

Array

Returns:

An array containing the element-wise minimum of x1 and x2.

Note

For each pair of elements, jnp.fmin returns:
  • the smaller of the two if both elements are finite numbers.

  • finite number if one element is nan.

  • -inf if one element is -inf and the other is finite or nan.

  • inf if one element is inf and the other is nan.

  • nan if both elements are nan.

Examples

>>> jnp.fmin(2, 3)
Array(2, dtype=int32, weak_type=True)
>>> jnp.fmin(2, jnp.array([1, 4, 2, -1]))
Array([ 1,  2,  2, -1], dtype=int32)
>>> x1 = jnp.array([1, 3, 2])
>>> x2 = jnp.array([2, 1, 4])
>>> jnp.fmin(x1, x2)
Array([1, 1, 2], dtype=int32)
>>> x3 = jnp.array([1, 5, 3])
>>> x4 = jnp.array([[2, 3, 1],
...                 [5, 6, 7]])
>>> jnp.fmin(x3, x4)
Array([[1, 3, 1],
       [1, 5, 3]], dtype=int32)
>>> nan = jnp.nan
>>> x5 = jnp.array([jnp.inf, 5, nan])
>>> x6 = jnp.array([[2, 3, nan],
...                 [nan, 6, 7]])
>>> jnp.fmin(x5, x6)
Array([[ 2.,  3., nan],
       [inf,  5.,  7.]], dtype=float32)
scico.numpy.fmod(x1, x2, /)

Calculate element-wise floating-point modulo operation.

JAX implementation of numpy.fmod.

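Parameters:
  • x1 – first input array or scalar (the dividend).

  • x2 – second input array or scalar (the divisor); must be broadcast compatible with x1.
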
Return type:

Array

Returns:

An array containing the result of the element-wise floating-point modulo operation of x1 and x2 with same sign as the elements of x1.

Note

The result of jnp.fmod is equivalent to x1 - x2 * jnp.trunc(x1 / x2).

Examples

>>> x1 = jnp.array([[3, -1, 4],
...                 [8, 5, -2]])
>>> x2 = jnp.array([2, 3, -5])
>>> jnp.fmod(x1, x2)
Array([[ 1, -1,  4],
       [ 0,  2, -2]], dtype=int32)
>>> x1 - x2 * jnp.trunc(x1 / x2)
Array([[ 1., -1.,  4.],
       [ 0.,  2., -2.]], dtype=float32)
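Since the result of fmod takes the sign of x1, it differs from jnp.remainder (the % operator), whose result takes the sign of x2; the two agree when the signs match. A sketch over all four sign combinations:

```python
import jax.numpy as jnp

x1 = jnp.array([7, -7, 7, -7])
x2 = jnp.array([3, 3, -3, -3])

f = jnp.fmod(x1, x2)        # sign follows x1
r = jnp.remainder(x1, x2)   # sign follows x2
```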
scico.numpy.frexp(x, /)

Split floating point values into mantissa and base-2 exponent.

JAX implementation of numpy.frexp.

Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – real-valued array

Return type:

tuple[Array, Array]

Returns:

A tuple (mantissa, exponent) where mantissa is a floating point value between -1 and 1 (for nonzero finite x, its magnitude lies in the interval [0.5, 1)), and exponent is an integer such that x == mantissa * 2 ** exponent.

Examples

Split values into mantissa and exponent:

>>> x = jnp.array([1., 2., 3., 4., 5.])
>>> m, e = jnp.frexp(x)
>>> m
Array([0.5  , 0.5  , 0.75 , 0.5  , 0.625], dtype=float32)
>>> e
Array([1, 2, 2, 3, 3], dtype=int32)

Reconstruct the original array:

>>> m * 2 ** e
Array([1., 2., 3., 4., 5.], dtype=float32)
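jnp.ldexp computes mantissa * 2 ** exponent directly, so it inverts frexp. A sketch checking the round trip and the mantissa range on a few arbitrary values, including a negative one:

```python
import jax.numpy as jnp

x = jnp.array([-6.0, 0.1, 3.0, 1024.0])
m, e = jnp.frexp(x)

x_back = jnp.ldexp(m, e)   # reconstructs m * 2**e exactly
in_range = (jnp.abs(m) >= 0.5) & (jnp.abs(m) < 1.0)
```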
scico.numpy.from_dlpack(x, /, *, device=None, copy=None)

Construct a JAX array via DLPack.

JAX implementation of numpy.from_dlpack.

Parameters:
  • x (Any) – An object that implements the DLPack protocol via the __dlpack__ and __dlpack_device__ methods, or a legacy DLPack tensor on either CPU or GPU.

  • device (Device | Sharding | None) – An optional Device or Sharding, representing the single device onto which the returned array should be placed. If given, then the result is committed to the device. If unspecified, the resulting array will be unpacked onto the same device it originated from. Setting device to a device different from the source of x will require a copy, meaning copy must be set to either True or None.

  • copy (bool | None) – An optional boolean, controlling whether or not a copy is performed. If copy=True then a copy is always performed, even if unpacked onto the same device. If copy=False then the copy is never performed and will raise an error if necessary. When copy=None (default) then a copy may be performed if needed for a device transfer.

Return type:

Array

Returns:

A JAX array of the input buffer.

Note

While JAX arrays are always immutable, dlpack buffers cannot be marked as immutable, and it is possible for processes external to JAX to mutate them in-place. If a JAX Array is constructed from a dlpack buffer without copying and the source buffer is later modified in-place, it may lead to undefined behavior when using the associated JAX array.

Examples

Passing data between NumPy and JAX via DLPack:

>>> import numpy as np
>>> rng = np.random.default_rng(42)
>>> x_numpy = rng.random(4, dtype='float32')
>>> print(x_numpy)
[0.08925092 0.773956   0.6545715  0.43887842]
>>> hasattr(x_numpy, "__dlpack__")  # NumPy supports the DLPack interface
True
>>> import jax.numpy as jnp
>>> x_jax = jnp.from_dlpack(x_numpy)
>>> print(x_jax)
[0.08925092 0.773956   0.6545715  0.43887842]
>>> hasattr(x_jax, "__dlpack__")  # JAX supports the DLPack interface
True
>>> x_numpy_round_trip = np.from_dlpack(x_jax)
>>> print(x_numpy_round_trip)
[0.08925092 0.773956   0.6545715  0.43887842]
scico.numpy.frombuffer(buffer, dtype=<class 'float'>, count=-1, offset=0)

Convert a buffer into a 1-D JAX array.

JAX implementation of numpy.frombuffer.

Parameters:
  • buffer (bytes | Any) – an object containing the data. It must be either a bytes object with a length that is an integer multiple of the dtype element size, or it must be an object exporting the Python buffer interface.

  • dtype (Union[str, type[Any], dtype, SupportsDType]) – optional. Desired data type for the array. Default is float64. This specifies the dtype used to parse the buffer, but note that after parsing, 64-bit values will be cast to 32-bit JAX arrays if the jax_enable_x64 flag is set to False.

  • count (int) – optional integer specifying the number of items to read from the buffer. If -1 (default), all items from the buffer are read.

  • offset (int) – optional integer specifying the number of bytes to skip at the beginning of the buffer. Default is 0.

Return type:

Array

Returns:

A 1-D JAX array representing the interpreted data from the buffer.

Examples

Using a bytes buffer:

>>> buf = b"\x00\x01\x02\x03\x04"
>>> jnp.frombuffer(buf, dtype=jnp.uint8)
Array([0, 1, 2, 3, 4], dtype=uint8)
>>> jnp.frombuffer(buf, dtype=jnp.uint8, offset=1)
Array([1, 2, 3, 4], dtype=uint8)

Constructing a JAX array via the Python buffer interface, using Python’s built-in array module:

>>> from array import array
>>> pybuffer = array('i', [0, 1, 2, 3, 4])
>>> jnp.frombuffer(pybuffer, dtype=jnp.int32)
Array([0, 1, 2, 3, 4], dtype=int32)
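Note that offset is measured in bytes while count is measured in elements. A sketch that reads float32 values from bytes produced by NumPy's tobytes, skipping the first element:

```python
import numpy as np
import jax.numpy as jnp

raw = np.arange(6, dtype=np.float32).tobytes()   # 6 values * 4 bytes each

# Skip one float32 (4 bytes), then read 3 elements.
out = jnp.frombuffer(raw, dtype=jnp.float32, count=3, offset=4)
```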
scico.numpy.fromfile(*args, **kwargs)

Unimplemented JAX wrapper for jnp.fromfile.

This function is left deliberately unimplemented because it may be non-pure and thus unsafe for use with JIT and other JAX transformations. Consider using jnp.asarray(np.fromfile(...)) instead, although care should be taken if np.fromfile is used within jax transformations because of its potential side-effect of consuming the file object; for more information see Common Gotchas: Pure Functions.
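Following the suggestion above, the file can be read with NumPy outside of any JAX transformation and the result converted with jnp.asarray. A sketch using a temporary directory; the file name data.bin is hypothetical:

```python
import os
import tempfile

import numpy as np
import jax.numpy as jnp

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "data.bin")          # hypothetical file name
    np.arange(5, dtype=np.float32).tofile(path)   # write raw float32 values

    # Read with NumPy (outside jit), then hand the result to JAX.
    data = jnp.asarray(np.fromfile(path, dtype=np.float32))
```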

scico.numpy.fromfunction(function, shape, *, dtype=<class 'float'>, **kwargs)

Create an array from a function applied over indices.

JAX implementation of numpy.fromfunction. The JAX implementation differs in that it dispatches via jax.vmap, and so unlike in NumPy the function logically operates on scalar inputs, and need not explicitly handle broadcasted inputs (see the examples below).

Parameters:
  • function (Callable[..., Array]) – a function that takes N dynamic scalars and outputs a scalar.

  • shape (Any) – a length-N tuple of integers specifying the output shape.

  • dtype (Union[str, type[Any], dtype, SupportsDType]) – optionally specify the dtype of the inputs. Defaults to floating-point.

  • kwargs – additional keyword arguments are passed statically to function.

Return type:

Array

Returns:

An array of shape shape if function returns a scalar, or in general a pytree of arrays with leading dimensions shape, as determined by the output of function.

Examples

Generate a multiplication table of a given shape:

>>> jnp.fromfunction(jnp.multiply, shape=(3, 6), dtype=int)
Array([[ 0,  0,  0,  0,  0,  0],
       [ 0,  1,  2,  3,  4,  5],
       [ 0,  2,  4,  6,  8, 10]], dtype=int32)

When function returns a non-scalar, the output will have leading dimensions given by shape:

>>> def f(x):
...   return (x + 1) * jnp.arange(3)
>>> jnp.fromfunction(f, shape=(2,))
Array([[0., 1., 2.],
       [0., 2., 4.]], dtype=float32)

function may return multiple results, in which case each is mapped independently:

>>> def f(x, y):
...   return x + y, x * y
>>> x_plus_y, x_times_y = jnp.fromfunction(f, shape=(3, 5))
>>> print(x_plus_y)
[[0. 1. 2. 3. 4.]
 [1. 2. 3. 4. 5.]
 [2. 3. 4. 5. 6.]]
>>> print(x_times_y)
[[0. 0. 0. 0. 0.]
 [0. 1. 2. 3. 4.]
 [0. 2. 4. 6. 8.]]

The JAX implementation differs slightly from NumPy’s implementation. In numpy.fromfunction, the function is expected to explicitly operate element-wise on the full grid of input values:

>>> def f(x, y):
...   print(f"{x.shape = }\n{y.shape = }")
...   return x + y
...
>>> np.fromfunction(f, (2, 3))
x.shape = (2, 3)
y.shape = (2, 3)
array([[0., 1., 2.],
       [1., 2., 3.]])

In jax.numpy.fromfunction, the function is vectorized via jax.vmap, and so is expected to operate on scalar values:

>>> jnp.fromfunction(f, (2, 3))
x.shape = ()
y.shape = ()
Array([[0., 1., 2.],
       [1., 2., 3.]], dtype=float32)

scico.numpy.fromiter(*args, **kwargs)

Unimplemented JAX wrapper for jnp.fromiter.

This function is left deliberately unimplemented because it may be non-pure and thus unsafe for use with JIT and other JAX transformations. Consider using jnp.asarray(np.fromiter(...)) instead, although care should be taken if np.fromiter is used within JAX transformations because of its potential side-effect of consuming the iterable object; for more information see Common Gotchas: Pure Functions.

scico.numpy.frompyfunc(func, /, nin, nout, *, identity=None)

Create a JAX ufunc from an arbitrary JAX-compatible scalar function.

Parameters:
  • func (Callable[..., Any]) – a callable that takes nin scalar arguments and returns nout outputs.

  • nin (int) – integer specifying the number of scalar inputs.

  • nout (int) – integer specifying the number of scalar outputs.

  • identity (Any) – (optional) a scalar specifying the identity of the operation, if any.

Return type:

ufunc

Returns:

wrapped – jax.numpy.ufunc wrapper of func.

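As a further illustration of the returned ufunc, the sketch below wraps a scalar maximum into a ufunc and uses its reduce and accumulate methods. The function name pairwise_max and the identity value -jnp.inf are choices made for this example (they assume floating-point inputs); jax.numpy.maximum already provides this operation directly.

```python
import jax.numpy as jnp


def pairwise_max(a, b):
    # Scalar function (illustrative name): return the larger of two values.
    return jnp.where(a > b, a, b)


# Two inputs, one output; -inf is assumed as the identity for max on floats.
max_ufunc = jnp.frompyfunc(pairwise_max, nin=2, nout=1, identity=-jnp.inf)

x = jnp.array([3.0, 1.0, 4.0, 1.0, 5.0])
total_max = max_ufunc.reduce(x)        # maximum over the whole array
running_max = max_ufunc.accumulate(x)  # cumulative maximum
```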
Examples

Here is an example of creating a ufunc similar to jax.numpy.add:

>>> import operator
>>> add = jnp.frompyfunc(operator.add, nin=2, nout=1, identity=0)

Now all the standard jax.numpy.ufunc methods are available:

>>> x = jnp.arange(4)
>>> add(x, 10)
Array([10, 11, 12, 13], dtype=int32)
>>> add.outer(x, x)
Array([[0, 1, 2, 3],
       [1, 2, 3, 4],
       [2, 3, 4, 5],
       [3, 4, 5, 6]], dtype=int32)
>>> add.reduce(x)
Array(6, dtype=int32)
>>> add.accumulate(x)
Array([0, 1, 3, 6], dtype=int32)
>>> add.at(x, 1, 10, inplace=False)
Array([ 0, 11,  2,  3], dtype=int32)

scico.numpy.fromstring(string, dtype=<class 'float'>, count=-1, *, sep)

Convert a string of text into a 1-D JAX array.

JAX implementation of numpy.fromstring.

Parameters:
  • string (str) – input string containing the data.

  • dtype (Union[str, type[Any], dtype, SupportsDType]) – optional. Desired data type for the array. Default is float.

  • count (int) – optional integer specifying the number of items to read from the string. If -1 (default), all items are read.

  • sep (str) – the string used to separate values in the input string.

Return type:

Array

Returns:

A 1-D JAX array containing the parsed data from the input string.

Examples

>>> jnp.fromstring("1 2 3", dtype=int, sep=" ")
Array([1, 2, 3], dtype=int32)
>>> jnp.fromstring("0.1, 0.2, 0.3", dtype=float, count=2, sep=",")
Array([0.1, 0.2], dtype=float32)

scico.numpy.full(shape, fill_value, dtype=None, *, device=None)

Create an array full of a specified value.

JAX implementation of numpy.full.

Return type:

Array

Returns:

An array of the specified shape and dtype, placed on the specified device if one is given.

Examples

>>> jnp.full(4, 2, dtype=float)
Array([2., 2., 2., 2.], dtype=float32)
>>> jnp.full((2, 3), 0, dtype=bool)
Array([[False, False, False],
       [False, False, False]], dtype=bool)

fill_value may also be an array that is broadcast to the specified shape:

>>> jnp.full((2, 3), fill_value=jnp.arange(3))
Array([[0, 1, 2],
       [0, 1, 2]], dtype=int32)

scico.numpy.full_like(a, fill_value, dtype=None, shape=None, *, device=None)

Create an array full of a specified value with the same shape and dtype as an array.

JAX implementation of numpy.full_like.

Return type:

Array

Returns:

An array of the specified shape and dtype (by default those of a), placed on the specified device if one is given.

Examples

>>> x = jnp.arange(4.0)
>>> jnp.full_like(x, 2)
Array([2., 2., 2., 2.], dtype=float32)
>>> jnp.full_like(x, 0, shape=(2, 3))
Array([[0., 0., 0.],
       [0., 0., 0.]], dtype=float32)

fill_value may also be an array that is broadcast to the specified shape:

>>> x = jnp.arange(6).reshape(2, 3)
>>> jnp.full_like(x, fill_value=jnp.array([[1], [2]]))
Array([[1, 1, 1],
       [2, 2, 2]], dtype=int32)

scico.numpy.gcd(x1, x2)

Compute the greatest common divisor of two arrays.

JAX implementation of numpy.gcd.

Return type:

Array

Returns:

An array containing the greatest common divisors of the corresponding elements from the absolute values of x1 and x2.

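To make the returned values concrete, here is a sketch of Euclid's algorithm vectorized with jax.lax.while_loop. gcd_euclid is a hypothetical helper written for illustration, assuming integer inputs; it is not necessarily how jax.numpy.gcd is implemented.

```python
import jax
import jax.numpy as jnp


def gcd_euclid(x1, x2):
    """Hypothetical helper: element-wise GCD of integer arrays via Euclid's algorithm."""
    a, b = jnp.broadcast_arrays(jnp.abs(jnp.asarray(x1)), jnp.abs(jnp.asarray(x2)))

    def cond(state):
        _, b = state
        # Keep iterating while any element still has a nonzero remainder.
        return jnp.any(b != 0)

    def body(state):
        a, b = state
        done = b == 0
        safe_b = jnp.where(done, 1, b)  # avoid division by zero where finished
        # Euclid step (a, b) -> (b, a mod b); finished elements stay (a, 0).
        return jnp.where(done, a, b), jnp.where(done, 0, a % safe_b)

    a, _ = jax.lax.while_loop(cond, body, (a, b))
    return a


result = gcd_euclid(jnp.array([12, 18, 24]), jnp.array([5, 10, 15]))
```

Taking absolute values first matches the description above (the GCD of the absolute values), and gcd(0, 0) comes out as 0 because such elements are finished from the start.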
See also

  • jax.numpy.lcm: compute the least common multiple of two arrays.

Examples

Scalar inputs:

>>> jnp.gcd(12, 18)
Array(6, dtype=int32, weak_type=True)

Array inputs:

>>> x1 = jnp.array([12, 18, 24])
>>> x2 = jnp.array([5, 10, 15])
>>> jnp.gcd(x1, x2)
Array([1, 2, 3], dtype=int32)

Broadcasting:

>>> x1 = jnp.array([12])
>>> x2 = jnp.array([6, 9, 12])
>>> jnp.gcd(x1, x2)
Array([ 6,  3, 12], dtype=int32)

scico.numpy.geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0)

Generate geometrically-spaced values.

JAX implementation of numpy.geomspace.

Parameters:
  • start (Union[Array, ndarray, bool, number, bool, int, float, complex]) – scalar or array. Specifies the starting values.

  • stop (Union[Array, ndarray, bool, number, bool, int, float, complex]) – scalar or array. Specifies the stop values.

  • num (int) – int, optional, default=50. Number of values to generate.

  • endpoint (bool) – bool, optional, default=True. If True, then include the stop value in the result. If False, then exclude the stop value.

  • dtype (Union[str, type[Any], dtype, SupportsDType, None]) – optional. Specifies the dtype of the output.

  • axis (int) – int, optional, default=0. Axis along which to generate the geomspace.

Return type:

Array

Returns:

An array containing the geometrically-spaced values.

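For positive start and stop, geometric spacing is linear spacing in log space, so jnp.geomspace(start, stop, num) should agree, up to floating-point rounding, with exponentiating jnp.linspace over the logarithms. A small check of that identity (negative or complex endpoints are not covered by this sketch):

```python
import jax.numpy as jnp

start, stop, num = 1.0, 16.0, 5
# Space the logarithms evenly, then map back with exp.
via_logs = jnp.exp(jnp.linspace(jnp.log(start), jnp.log(stop), num))
direct = jnp.geomspace(start, stop, num)
```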
Examples

List 5 geometrically-spaced values between 1 and 16:

>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.geomspace(1, 16, 5)
Array([ 1.,  2.,  4.,  8., 16.], dtype=float32)

List 4 geometrically-spaced values between 1 and 16, with endpoint=False:

>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.geomspace(1, 16, 4, endpoint=False)
Array([1., 2., 4., 8.], dtype=float32)

Multi-dimensional geomspace:

>>> start = jnp.array([1, 1000])
>>> stop = jnp.array([27, 1])
>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.geomspace(start, stop, 4)
Array([[   1., 1000.],
       [   3.,  100.],
       [   9.,   10.],
       [  27.,    1.]], dtype=float32)

scico.numpy.get_printoptions()

Alias of numpy.get_printoptions.

JAX arrays are printed via NumPy, so NumPy’s printoptions configurations will apply to printed JAX arrays.

See the numpy.set_printoptions documentation for details on the available options and their meanings.

scico.numpy.gradient(f, *varargs, axis=None, edge_order=None)

Compute the numerical gradient of a sampled function.

JAX implementation of numpy.gradient.

The gradient in jnp.gradient is computed using second-order finite differences across the array of sampled function values. This should not be confused with jax.grad, which computes a precise gradient of a callable function via automatic differentiation.

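For unit spacing, the output of jnp.gradient can be checked by hand: interior points use the central difference (f[i+1] - f[i-1]) / 2, and the two edges use one-sided differences. A small sketch assuming the default edge handling (edge_order is not implemented in JAX, as noted below):

```python
import jax.numpy as jnp

f = jnp.array([1.0, 2.0, 4.0, 7.0, 11.0])
g = jnp.gradient(f)  # unit spacing along the only axis

# Interior points: central differences (f[i+1] - f[i-1]) / 2.
interior = (f[2:] - f[:-2]) / 2
# Edges: one-sided first differences.
left_edge = f[1] - f[0]
right_edge = f[-1] - f[-2]
by_hand = jnp.concatenate([left_edge[None], interior, right_edge[None]])
```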
Parameters:
  • f (Union[Array, ndarray, bool, number, bool, int, float, complex]) – N-dimensional array of function values.

  • varargs (Union[Array, ndarray, bool, number, bool, int, float, complex]) –

    optional list of scalars or arrays specifying spacing of function evaluations. Options are:

    • not specified: unit spacing in all dimensions.

    • a single scalar: constant spacing in all dimensions.

    • N values: specify different spacing in each dimension:

      • scalar values indicate constant spacing in that dimension.

      • array values must match the length of the corresponding dimension, and specify the coordinates at which f is evaluated.

  • edge_order (int | None) – not implemented in JAX.

  • axis (int | Sequence[int] | None) – integer or tuple of integers specifying the axis along which to compute the gradient. If None (default), the gradient is computed along all axes.

Return type:

Array | list[Array]

Returns:

An array, or a list of arrays with one entry per axis, containing the numerical gradient along each specified axis.

See also

  • jax.grad: automatic differentiation of a function with a single output.

Examples

Comparing numerical and automatic differentiation of a simple function:

>>> def f(x):
...   return jnp.sin(x) * jnp.exp(-x / 4)
...
>>> def gradf_exact(x):
...   # exact analytical gradient of f(x)
...   return -f(x) / 4 + jnp.cos(x) * jnp.exp(-x / 4)
...
>>> x = jnp.linspace(0, 5, 10)
>>> with jnp.printoptions(precision=2, suppress=True):
...   print("numerical gradient:", jnp.gradient(f(x), x))
...   print("automatic gradient:", jax.vmap(jax.grad(f))(x))
...   print("exact gradient:    ", gradf_exact(x))
...
numerical gradient: [ 0.83  0.61  0.18 -0.2  -0.43 -0.49 -0.39 -0.21 -0.02  0.08]
automatic gradient: [ 1.    0.62  0.17 -0.23 -0.46 -0.51 -0.41 -0.21 -0.01  0.15]
exact gradient:     [ 1.    0.62  0.17 -0.23 -0.46 -0.51 -0.41 -0.21 -0.01  0.15]

Notice that, as expected, the numerical gradient has some approximation error compared to the automatic gradient computed via jax.grad.

scico.numpy.greater(x, y, /)

Return element-wise truth value of x > y.

JAX implementation of numpy.greater.

Return type:

Array

Returns:

A boolean array that is True where x > y (element-wise) and False otherwise.

Examples

Scalar inputs:

>>> jnp.greater(5, 2)
Array(True, dtype=bool, weak_type=True)

Inputs with same shape:

>>> x = jnp.array([5, 9, -2])
>>> y = jnp.array([4, -1, 6])
>>> jnp.greater(x, y)
Array([ True,  True, False], dtype=bool)

Inputs with broadcast compatibility:

>>> x1 = jnp.array([[5, -6, 7],
...                 [-2, 5, 9]])
>>> y1 = jnp.array([-4, 3, 10])
>>> jnp.greater(x1, y1)
Array([[ True, False, False],
       [ True,  True, False]], dtype=bool)

scico.numpy.greater_equal(x, y, /)

Return element-wise truth value of x >= y.

JAX implementation of numpy.greater_equal.

Return type:

Array

Returns:

A boolean array that is True where x >= y (element-wise) and False otherwise.

Examples

Scalar inputs:

>>> jnp.greater_equal(4, 7)
Array(False, dtype=bool, weak_type=True)

Inputs with same shape:

>>> x = jnp.array([2, 5, -1])
>>> y = jnp.array([-6, 4, 3])
>>> jnp.greater_equal(x, y)
Array([ True,  True, False], dtype=bool)

Inputs with broadcast compatibility:

>>> x1 = jnp.array([[3, -1, 4],
...                 [5, 9, -6]])
>>> y1 = jnp.array([-1, 4, 2])
>>> jnp.greater_equal(x1, y1)
Array([[ True, False,  True],
       [ True,  True, False]], dtype=bool)

scico.numpy.hamming(M)

Return a Hamming window of size M.

JAX implementation of numpy.hamming.

Parameters:

M (int) – The window size.

Return type:

Array

Returns:

An array of size M containing the Hamming window.

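The returned values follow the standard Hamming formula w[n] = 0.54 - 0.46 cos(2πn / (M - 1)) for n = 0, ..., M - 1 (with M > 1). A quick check of that formula against jnp.hamming:

```python
import jax.numpy as jnp

M = 8
n = jnp.arange(M)
# Standard Hamming window coefficients (valid for M > 1).
w = 0.54 - 0.46 * jnp.cos(2 * jnp.pi * n / (M - 1))
```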
Examples

>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jnp.hamming(4))
[0.08 0.77 0.77 0.08]

scico.numpy.hanning(M)

Return a Hanning window of size M.

JAX implementation of numpy.hanning.

Parameters:

M (int) – The window size.

Return type:

Array

Returns:

An array of size M containing the Hanning window.

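The Hann ("Hanning") window is w[n] = 0.5 - 0.5 cos(2πn / (M - 1)) for n = 0, ..., M - 1 (with M > 1); it falls to zero at both ends. A quick check of that formula against jnp.hanning:

```python
import jax.numpy as jnp

M = 8
n = jnp.arange(M)
# Hann window coefficients (valid for M > 1); zero at n = 0 and n = M - 1.
w = 0.5 - 0.5 * jnp.cos(2 * jnp.pi * n / (M - 1))
```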
Examples

>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jnp.hanning(4))
[0.   0.75 0.75 0.  ]

scico.numpy.heaviside(x1, x2, /)

Compute the Heaviside step function.

JAX implementation of numpy.heaviside.

The Heaviside step function is defined by:

\[\begin{split}\mathrm{heaviside}(x1, x2) = \begin{cases} 0, & x1 < 0\\ x2, & x1 = 0\\ 1, & x1 > 0. \end{cases}\end{split}\]

Return type:

Array

Returns:

An array containing the Heaviside step function of x1, promoting to inexact dtype.

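The case definition above maps directly onto nested jnp.where calls. A minimal sketch, with heaviside_where as a hypothetical helper name; it is only checked here on finite inputs and is not claimed to be how jax.numpy.heaviside is written.

```python
import jax.numpy as jnp


def heaviside_where(x1, x2):
    """Hypothetical helper: 0 where x1 < 0, x2 where x1 == 0, 1 where x1 > 0."""
    x1 = jnp.asarray(x1)
    return jnp.where(x1 < 0, 0.0, jnp.where(x1 > 0, 1.0, x2))


x1 = jnp.array([[-2, 0, 3], [5, -1, 0], [0, 7, -3]])
x2 = jnp.array([2, 0.5, 1])
```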
Examples

>>> x1 = jnp.array([[-2, 0, 3],
...                 [5, -1, 0],
...                 [0, 7, -3]])
>>> x2 = jnp.array([2, 0.5, 1])
>>> jnp.heaviside(x1, x2)
Array([[0. , 0.5, 1. ],
       [1. , 0. , 1. ],
       [2. , 1. , 0. ]], dtype=float32)
>>> jnp.heaviside(x1, 0.5)
Array([[0. , 0.5, 1. ],
       [1. , 0. , 0.5],
       [0.5, 1. , 0. ]], dtype=float32)
>>> jnp.heaviside(-3, x2)
Array([0., 0., 0.], dtype=float32)

scico.numpy.histogram(a, bins=10, range=None, weights=None, density=None)

Compute a 1-dimensional histogram.

JAX implementation of numpy.histogram.

Return type:

tuple[Array, Array]

Returns:

A tuple of arrays (histogram, bin_edges), where histogram contains the aggregated data, and bin_edges specifies the boundaries of the bins.

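To illustrate how values are assigned to bins, the sketch below counts values into half-open bins [e_i, e_{i+1}), with the last bin also closed on the right as in NumPy. histogram_counts is a hypothetical helper built on jnp.searchsorted and a scatter-add; it ignores weights and density.

```python
import jax.numpy as jnp


def histogram_counts(a, bin_edges):
    """Hypothetical helper: unweighted bin counts for NumPy-style bin edges."""
    dtype = jnp.result_type(a, bin_edges)
    a = jnp.ravel(jnp.asarray(a, dtype=dtype))
    bin_edges = jnp.asarray(bin_edges, dtype=dtype)
    nbins = bin_edges.shape[0] - 1
    # Index i such that bin_edges[i] <= value < bin_edges[i + 1].
    idx = jnp.searchsorted(bin_edges, a, side='right') - 1
    # The last bin is closed on the right, so the final edge belongs to it.
    idx = jnp.where(a == bin_edges[-1], nbins - 1, idx)
    in_range = (idx >= 0) & (idx < nbins)
    # Scatter-add 1 per in-range value; out-of-range values add 0 at index 0.
    weights = in_range.astype(jnp.float32)
    return jnp.zeros(nbins, dtype=jnp.float32).at[jnp.where(in_range, idx, 0)].add(weights)


a = jnp.array([1, 2, 3, 10, 11, 15, 19, 25])
edges = jnp.array([0, 10, 20, 30])
counts = histogram_counts(a, edges)
```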
Examples

>>> a = jnp.array([1, 2, 3, 10, 11, 15, 19, 25])
>>> counts, bin_edges = jnp.histogram(a, bins=8)
>>> print(counts)
[3. 0. 0. 2. 1. 0. 1. 1.]
>>> print(bin_edges)
[ 1.  4.  7. 10. 13. 16. 19. 22. 25.]

Specifying the bin range:

>>> counts, bin_edges = jnp.histogram(a, range=(0, 25), bins=5)
>>> print(counts)
[3. 0. 2. 2. 1.]
>>> print(bin_edges)
[ 0.  5. 10. 15. 20. 25.]

Specifying the bin edges explicitly:

>>> bin_edges = jnp.array([0, 10, 20, 30])
>>> counts, _ = jnp.histogram(a, bins=bin_edges)
>>> print(counts)
[3. 4. 1.]

Using density=True returns a normalized histogram:

>>> density, bin_edges = jnp.histogram(a, density=True)
>>> dx = jnp.diff(bin_edges)
>>> normed_sum = jnp.sum(density * dx)
>>> jnp.allclose(normed_sum, 1.0)
Array(True, dtype=bool)

scico.numpy.histogram2d(x, y, bins=10, range=None, weights=None, density=None)

Compute a 2-dimensional histogram.

JAX implementation of numpy.histogram2d.

Return type:

tuple[Array, Array, Array]

Returns:

A tuple of arrays (histogram, x_edges, y_edges), where histogram contains the aggregated data, and x_edges and y_edges specify the boundaries of the bins.

Examples

>>> x = jnp.array([1, 2, 3, 10, 11, 15, 19, 25])
>>> y = jnp.array([2, 5, 6, 8, 13, 16, 17, 18])
>>> counts, x_edges, y_edges = jnp.histogram2d(x, y, bins=8)
>>> counts.shape
(8, 8)
>>> x_edges
Array([ 1.,  4.,  7., 10., 13., 16., 19., 22., 25.], dtype=float32)
>>> y_edges
Array([ 2.,  4.,  6.,  8., 10., 12., 14., 16., 18.], dtype=float32)

Specifying the bin range:

>>> counts, x_edges, y_edges = jnp.histogram2d(x, y, range=[(0, 25), (0, 25)], bins=5)
>>> counts.shape
(5, 5)
>>> x_edges
Array([ 0.,  5., 10., 15., 20., 25.], dtype=float32)
>>> y_edges
Array([ 0.,  5., 10., 15., 20., 25.], dtype=float32)

Specifying the bin edges explicitly:

>>> x_edges = jnp.array([0, 10, 20, 30])
>>> y_edges = jnp.array([0, 10, 20, 30])
>>> counts, _, _ = jnp.histogram2d(x, y, bins=[x_edges, y_edges])
>>> counts
Array([[3, 0, 0],
       [1, 3, 0],
       [0, 1, 0]], dtype=int32)

Using density=True returns a normalized histogram:

>>> density, x_edges, y_edges = jnp.histogram2d(x, y, density=True)
>>> dx = jnp.diff(x_edges)
>>> dy = jnp.diff(y_edges)
>>> normed_sum = jnp.sum(density * dx[:, None] * dy[None, :])
>>> jnp.allclose(normed_sum, 1.0)
Array(True, dtype=bool)

scico.numpy.histogram_bin_edges(a, bins=10, range=None, weights=None)

Compute the bin edges for a histogram.

JAX implementation of numpy.histogram_bin_edges.

Return type:

Array

Returns:

An array of bin edges for the histogram.

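With an integer bins and no explicit range, the edges are bins + 1 evenly spaced points from the minimum to the maximum of a. A check of that rule on the data used in the examples (the degenerate case where all values are equal is not covered by this sketch):

```python
import jax.numpy as jnp

a = jnp.array([2, 5, 3, 6, 4, 1])
bins = 5
# bins + 1 equally spaced edges spanning [a.min(), a.max()].
expected = jnp.linspace(a.min(), a.max(), bins + 1)
edges = jnp.histogram_bin_edges(a, bins=bins)
```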
Examples

>>> a = jnp.array([2, 5, 3, 6, 4, 1])
>>> jnp.histogram_bin_edges(a, bins=5)
Array([1., 2., 3., 4., 5., 6.], dtype=float32)
>>> jnp.histogram_bin_edges(a, bins=5, range=(-10, 10))  
Array([-10.,  -6.,  -2.,   2.,   6.,  10.], dtype=float32)

scico.numpy.histogramdd(sample, bins=10, range=None, weights=None, density=None)

Compute an N-dimensional histogram.

JAX implementation of numpy.histogramdd.

Return type:

tuple[Array, list[Array]]

Returns:

A tuple of arrays (histogram, bin_edges), where histogram contains the aggregated data, and bin_edges specifies the boundaries of the bins.

Examples

A histogram over 100 points in three dimensions:

>>> key = jax.random.key(42)
>>> a = jax.random.normal(key, (100, 3))
>>> counts, bin_edges = jnp.histogramdd(a, bins=6,
...                                     range=[(-3, 3), (-3, 3), (-3, 3)])
>>> counts.shape
(6, 6, 6)
>>> bin_edges  
[Array([-3., -2., -1.,  0.,  1.,  2.,  3.], dtype=float32),
 Array([-3., -2., -1.,  0.,  1.,  2.,  3.], dtype=float32),
 Array([-3., -2., -1.,  0.,  1.,  2.,  3.], dtype=float32)]

Using density=True returns a normalized histogram:

>>> density, bin_edges = jnp.histogramdd(a, density=True)
>>> bin_widths = map(jnp.diff, bin_edges)
>>> dx, dy, dz = jnp.meshgrid(*bin_widths, indexing='ij')
>>> normed = jnp.sum(density * dx * dy * dz)
>>> jnp.allclose(normed, 1.0)
Array(True, dtype=bool)

scico.numpy.hsplit(ary, indices_or_sections)

Split an array into sub-arrays horizontally.

JAX implementation of numpy.hsplit.

Refer to the documentation of jax.numpy.split for details. hsplit is equivalent to split with axis=1, or axis=0 for one-dimensional arrays.

Return type:

list[Array]

Examples

1D array:

>>> x = jnp.array([1, 2, 3, 4, 5, 6])
>>> x1, x2 = jnp.hsplit(x, 2)
>>> print(x1, x2)
[1 2 3] [4 5 6]

2D array:

>>> x = jnp.array([[1, 2, 3, 4],
...                [5, 6, 7, 8]])
>>> x1, x2 = jnp.hsplit(x, 2)
>>> print(x1)
[[1 2]
 [5 6]]
>>> print(x2)
[[3 4]
 [7 8]]

scico.numpy.hstack(tup, dtype=None)

Horizontally stack arrays.

JAX implementation of numpy.hstack.

For arrays of two or more dimensions, this is equivalent to jax.numpy.concatenate with axis=1; one-dimensional arrays are concatenated along axis 0.

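The axis rule can be checked directly against jax.numpy.concatenate, using the same 1D and 2D arrays as the examples:

```python
import jax.numpy as jnp

x1d, y1d = jnp.arange(3), jnp.ones(3)
x2d, y2d = x1d.reshape(3, 1), y1d.reshape(3, 1)
# 1D inputs are joined along axis 0; 2D (and higher) inputs along axis 1.
same_1d = jnp.array_equal(jnp.hstack([x1d, y1d]), jnp.concatenate([x1d, y1d], axis=0))
same_2d = jnp.array_equal(jnp.hstack([x2d, y2d]), jnp.concatenate([x2d, y2d], axis=1))
```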
Parameters:
  • tup (ndarray | Array | Sequence[Union[Array, ndarray, bool, number, bool, int, float, complex]]) – a sequence of arrays to stack; each must have the same shape along all but the second axis. Input arrays will be promoted to at least rank 1. If a single array is given it will be treated equivalently to tup = unstack(tup), but the implementation will avoid explicit unstacking.

  • dtype (Union[str, type[Any], dtype, SupportsDType, None]) – optional dtype of the resulting array. If not specified, the dtype will be determined via type promotion rules described in Type promotion semantics.

Return type:

Array

Returns:

the stacked result.

Examples

Scalar values:

>>> jnp.hstack([1, 2, 3])
Array([1, 2, 3], dtype=int32, weak_type=True)

1D arrays:

>>> x = jnp.arange(3)
>>> y = jnp.ones(3)
>>> jnp.hstack([x, y])
Array([0., 1., 2., 1., 1., 1.], dtype=float32)

2D arrays:

>>> x = x.reshape(3, 1)
>>> y = y.reshape(3, 1)
>>> jnp.hstack([x, y])
Array([[0., 1.],
       [1., 1.],
       [2., 1.]], dtype=float32)

scico.numpy.hypot(x1, x2, /)

Return the element-wise hypotenuse for the given legs of a right triangle.

JAX implementation of numpy.hypot.

Parameters:
  • x1 (Union[Array, ndarray, bool, number, bool, int, float, complex]) – scalar or array. Specifies one of the legs of a right triangle. Complex dtypes are not supported.

  • x2 (Union[Array, ndarray, bool, number, bool, int, float, complex]) – scalar or array. Specifies the other leg of the right triangle. Complex dtypes are not supported. x1 and x2 must either have the same shape or be broadcast-compatible.

Return type:

Array

Returns:

An array containing the hypotenuse for the given legs x1 and x2 of a right triangle, promoting to inexact dtype.

Note

jnp.hypot is a more numerically stable way of computing jnp.sqrt(x1 ** 2 + x2 ** 2).

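One concrete effect of that stability is overflow behaviour: in float32 the squares of large legs exceed the largest representable value, so the naive formula returns inf, while jnp.hypot is expected to stay finite. The leg lengths below are arbitrary values chosen for illustration.

```python
import jax.numpy as jnp

x1 = jnp.float32(3e30)
x2 = jnp.float32(4e30)
# Squaring overflows float32 (max about 3.4e38), so the naive formula gives inf.
naive = jnp.sqrt(x1 ** 2 + x2 ** 2)
# hypot is expected to return a finite value close to 5e30.
stable = jnp.hypot(x1, x2)
```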
Examples

>>> jnp.hypot(3, 4)
Array(5., dtype=float32, weak_type=True)
>>> x1 = jnp.array([[3, -2, 5],
...                 [9, 1, -4]])
>>> x2 = jnp.array([-5, 6, 8])
>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.hypot(x1, x2)
Array([[ 5.831,  6.325,  9.434],
       [10.296,  6.083,  8.944]], dtype=float32)

scico.numpy.i0(x)

Calculate the modified Bessel function of the first kind, zeroth order.

JAX implementation of numpy.i0.

The modified Bessel function of the first kind, zeroth order, is defined by:

\[\mathrm{i0}(x) = I_0(x) = \sum_{k=0}^{\infty} \frac{(x^2/4)^k}{(k!)^2}\]

Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – scalar or array. Specifies the argument of the Bessel function. Complex inputs are not supported.

Return type:

Array

Returns:

An array containing the corresponding values of the modified Bessel function of x.

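The series definition above converges quickly for moderate x, so a truncated sum is a useful sanity check. i0_series is a hypothetical helper, and the cut-off of 20 terms is an arbitrary choice that is ample for |x| <= 2.

```python
import math

import jax.numpy as jnp


def i0_series(x, terms=20):
    """Hypothetical helper: truncated series sum_{k<terms} (x^2/4)^k / (k!)^2."""
    q = jnp.asarray(x, dtype=jnp.float32) ** 2 / 4
    total = jnp.zeros_like(q)
    for k in range(terms):
        # Convert (k!)^2 to float so JAX never sees a Python int too large for int32.
        total = total + q ** k / float(math.factorial(k) ** 2)
    return total


x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
approx = i0_series(x)
```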
Examples

>>> x = jnp.array([-2, -1, 0, 1, 2])
>>> jnp.i0(x)
Array([2.2795851, 1.266066 , 1.0000001, 1.266066 , 2.2795851], dtype=float32)

scico.numpy.identity(n, dtype=None)

Create a square identity matrix.

JAX implementation of numpy.identity.

Parameters:
  • n (Union[int, Any]) – integer specifying the size of each array dimension.

  • dtype (Union[str, type[Any], dtype, SupportsDType, None]) – optional dtype; defaults to floating point.

Return type:

Array

Returns:

Identity array of shape (n, n).

See also

jax.numpy.eye: non-square and/or offset identity matrices.

Examples

A simple 3x3 identity matrix:

>>> jnp.identity(3)
Array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]], dtype=float32)

A 2x2 integer identity matrix:

>>> jnp.identity(2, dtype=int)
Array([[1, 0],
       [0, 1]], dtype=int32)

scico.numpy.imag(val, /)

Return the element-wise imaginary part of the complex argument.

JAX implementation of numpy.imag.

Parameters:

val (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array or scalar.

Return type:

Array

Returns:

An array containing the imaginary part of the elements of val.

Examples

>>> jnp.imag(4)
Array(0, dtype=int32, weak_type=True)
>>> jnp.imag(5j)
Array(5., dtype=float32, weak_type=True)
>>> x = jnp.array([2+3j, 5-1j, -3])
>>> jnp.imag(x)
Array([ 3., -1.,  0.], dtype=float32)

scico.numpy.indices(dimensions, dtype=None, sparse=False)

Generate arrays of grid indices.

JAX implementation of numpy.indices.

Parameters:
  • dimensions (Sequence[int]) – the shape of the grid.

  • dtype (Union[str, type[Any], dtype, SupportsDType, None]) – the dtype of the indices (defaults to integer).

  • sparse (bool) – if True, then return sparse indices. Default is False, which returns dense indices.

Return type:

Array | tuple[Array, ...]

Returns:

An array of shape (len(dimensions), *dimensions) if sparse is False, or a tuple of arrays of the same length as dimensions if sparse is True.

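The two return forms correspond to familiar grid constructions: the dense result matches stacked 'ij'-indexed jnp.meshgrid outputs, and the sparse result matches the open grids from jnp.ix_. A small check of both correspondences:

```python
import jax.numpy as jnp

shape = (2, 3)
dense = jnp.indices(shape)
# Dense form: stacked 'ij'-indexed coordinate grids, one per axis.
grid = jnp.stack(jnp.meshgrid(*(jnp.arange(n) for n in shape), indexing='ij'))
# Sparse form: open grids with singleton dimensions.
sparse = jnp.indices(shape, sparse=True)
open_grid = jnp.ix_(*(jnp.arange(n) for n in shape))
```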
Examples

>>> jnp.indices((2, 3))
Array([[[0, 0, 0],
        [1, 1, 1]],

       [[0, 1, 2],
        [0, 1, 2]]], dtype=int32)
>>> jnp.indices((2, 3), sparse=True)
(Array([[0],
       [1]], dtype=int32), Array([[0, 1, 2]], dtype=int32))

scico.numpy.inner(a, b, *, precision=None, preferred_element_type=None)

Compute the inner product of two arrays.

JAX implementation of numpy.inner.

Unlike jax.numpy.matmul or jax.numpy.dot, this always performs a contraction along the last dimension of each input.

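Because the contraction is always over the last axis of both inputs, jnp.inner can be expressed with jnp.tensordot or jnp.einsum. A sketch of that equivalence for 2D inputs (the arrays are arbitrary test data):

```python
import jax.numpy as jnp

a = jnp.arange(6.0).reshape(2, 3)
b = jnp.arange(15.0).reshape(5, 3)
# Contract the last axis of a with the last axis of b; other axes are stacked.
via_tensordot = jnp.tensordot(a, b, axes=([a.ndim - 1], [b.ndim - 1]))
via_einsum = jnp.einsum('ik,jk->ij', a, b)
result = jnp.inner(a, b)
```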
Return type:

Array

Returns:

array of shape (*a.shape[:-1], *b.shape[:-1]) containing the batched vector product of the inputs.

Examples

For 1D inputs, this implements the standard (non-conjugate) vector inner product:

>>> a = jnp.array([1j, 3j, 4j])
>>> b = jnp.array([4., 2., 5.])
>>> jnp.inner(a, b)
Array(0.+30.j, dtype=complex64)

For multi-dimensional inputs, batch dimensions are stacked rather than broadcast:

>>> a = jnp.ones((2, 3))
>>> b = jnp.ones((5, 3))
>>> jnp.inner(a, b).shape
(2, 5)

scico.numpy.insert(arr, obj, values, axis=None)

Insert entries into an array at specified indices.

JAX implementation of numpy.insert.

Return type:

Array

Returns:

A copy of arr with values inserted at the specified locations.

Examples

Inserting a single value:

>>> x = jnp.arange(5)
>>> jnp.insert(x, 2, 99)
Array([ 0,  1, 99,  2,  3,  4], dtype=int32)

Inserting multiple identical values using a slice:

>>> jnp.insert(x, slice(None, None, 2), -1)
Array([-1,  0,  1, -1,  2,  3, -1,  4], dtype=int32)

Inserting multiple values using an array of indices:

>>> indices = jnp.array([4, 2, 5])
>>> values = jnp.array([10, 11, 12])
>>> jnp.insert(x, indices, values)
Array([ 0,  1, 11,  2,  3, 10,  4, 12], dtype=int32)

Inserting columns into a 2D array:

>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> indices = jnp.array([1, 3])
>>> values = jnp.array([[10, 11],
...                     [12, 13]])
>>> jnp.insert(x, indices, values, axis=1)
Array([[ 1, 10,  2,  3, 11],
       [ 4, 12,  5,  6, 13]], dtype=int32)

scico.numpy.interp(x, xp, fp, left=None, right=None, period=None)

One-dimensional linear interpolation.

JAX implementation of numpy.interp.

Return type:

Array

Returns:

An array of shape x.shape containing the interpolated function evaluated at the values in x.

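Inside the range of xp, each output is a straight-line blend of the two neighbouring samples: f(x) = fp[i] + (fp[i+1] - fp[i]) * (x - xp[i]) / (xp[i+1] - xp[i]) where xp[i] <= x <= xp[i+1]. A sketch of that formula using jnp.searchsorted, valid only for x within [xp[0], xp[-1]] (the sample points are arbitrary):

```python
import jax.numpy as jnp

xp = jnp.array([0.0, 1.0, 3.0, 6.0])
fp = jnp.array([0.0, 2.0, 3.0, 0.0])
x = jnp.array([0.5, 2.0, 4.5])
# Index i of the left neighbour, clipped so that i + 1 stays in range.
i = jnp.clip(jnp.searchsorted(xp, x, side='right') - 1, 0, xp.shape[0] - 2)
t = (x - xp[i]) / (xp[i + 1] - xp[i])  # fractional position inside [xp[i], xp[i+1]]
by_hand = fp[i] + t * (fp[i + 1] - fp[i])
```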
Examples

>>> xp = jnp.arange(10)
>>> fp = 2 * xp
>>> x = jnp.array([0.5, 2.0, 3.5])
>>> jnp.interp(x, xp, fp)
Array([1., 4., 7.], dtype=float32)

Unless otherwise specified, extrapolation will be constant:

>>> x = jnp.array([-10., 10.])
>>> jnp.interp(x, xp, fp)
Array([ 0., 18.], dtype=float32)

Use "extrapolate" mode for linear extrapolation:

>>> jnp.interp(x, xp, fp, left='extrapolate', right='extrapolate')
Array([-20.,  20.], dtype=float32)

For periodic interpolation, specify the period:

>>> xp = jnp.array([0, jnp.pi / 2, jnp.pi, 3 * jnp.pi / 2])
>>> fp = jnp.sin(xp)
>>> x = 2 * jnp.pi  # note: not in input array
>>> jnp.interp(x, xp, fp, period=2 * jnp.pi)
Array(0., dtype=float32)

scico.numpy.intersect1d(ar1, ar2, assume_unique=False, return_indices=False, *, size=None, fill_value=None)

Compute the set intersection of two 1D arrays.

JAX implementation of numpy.intersect1d.

Because the size of the output of intersect1d is data-dependent, the function is not typically compatible with jit and other JAX transformations. The JAX version adds the optional size argument which must be specified statically for jnp.intersect1d to be used in such contexts.

Parameters:
  • ar1 (Union[Array, ndarray, bool, number, bool, int, float, complex]) – first array of values to intersect.

  • ar2 (Union[Array, ndarray, bool, number, bool, int, float, complex]) – second array of values to intersect.

  • assume_unique (bool) – if True, assume the input arrays contain unique values. This allows a more efficient implementation, but if assume_unique is True and the input arrays contain duplicates, the behavior is undefined. default: False.

  • return_indices (bool) – If True, return arrays of indices specifying where the intersected values first appear in the input arrays.

  • size (int | None) – if specified, return only the first size sorted elements. If there are fewer elements than size indicates, the return value will be padded with fill_value, and returned indices will be padded with an out-of-bound index.

  • fill_value (Union[Array, ndarray, bool, number, bool, int, float, complex, None]) – when size is specified and there are fewer than the indicated number of elements, fill the remaining entries with fill_value. Defaults to the smallest value in the intersection.

Return type:

Array | tuple[Array, Array, Array]

Returns:

An array intersection or, if return_indices=True, a tuple of arrays (intersection, ar1_indices, ar2_indices). Returned values are:

  • intersection: A 1D array containing each value that appears in both ar1 and ar2.

  • ar1_indices: (returned if return_indices=True) an array of shape intersection.shape containing the indices in flattened ar1 of values in intersection. For 1D inputs, intersection is equivalent to ar1[ar1_indices].

  • ar2_indices: (returned if return_indices=True) an array of shape intersection.shape containing the indices in flattened ar2 of values in intersection. For 1D inputs, intersection is equivalent to ar2[ar2_indices].

Examples

>>> ar1 = jnp.array([1, 2, 3, 4])
>>> ar2 = jnp.array([3, 4, 5, 6])
>>> jnp.intersect1d(ar1, ar2)
Array([3, 4], dtype=int32)

Computing intersection with indices:

>>> intersection, ar1_indices, ar2_indices = jnp.intersect1d(ar1, ar2, return_indices=True)
>>> intersection
Array([3, 4], dtype=int32)

ar1_indices gives the indices of the intersected values within ar1:

>>> ar1_indices
Array([2, 3], dtype=int32)
>>> jnp.all(intersection == ar1[ar1_indices])
Array(True, dtype=bool)

ar2_indices gives the indices of the intersected values within ar2:

>>> ar2_indices
Array([0, 1], dtype=int32)
>>> jnp.all(intersection == ar2[ar2_indices])
Array(True, dtype=bool)

scico.numpy.isclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False)

Check if the elements of two arrays are approximately equal within a tolerance.

JAX implementation of numpy.isclose.

Essentially this function evaluates the following condition:

\[|a - b| \le \mathtt{atol} + \mathtt{rtol} * |b|\]

jnp.inf in a will be considered equal to jnp.inf in b.

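The condition above can be written out directly for finite inputs. Note that the tolerance scales with |b| only, so the test is not symmetric: swapping a and b can change the result, as the last two lines show (the values are chosen for illustration).

```python
import jax.numpy as jnp

a = jnp.array([1e6, 2e6, 3e6])
b = jnp.array([1e6, 2e7, 3.00001e6])
rtol, atol = 1e-05, 1e-08
# |a - b| <= atol + rtol * |b|, evaluated element-wise.
by_hand = jnp.abs(a - b) <= atol + rtol * jnp.abs(b)

# Asymmetry: the tolerance is relative to the second argument.
ab = jnp.isclose(100.0, 95.0, rtol=0.0526, atol=0.0)  # 5 <= 0.0526 * 95  -> False
ba = jnp.isclose(95.0, 100.0, rtol=0.0526, atol=0.0)  # 5 <= 0.0526 * 100 -> True
```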
Return type:

Array

Returns:

A new array containing boolean values indicating whether the input arrays are element-wise approximately equal within the specified tolerances.

Examples

>>> jnp.isclose(jnp.array([1e6, 2e6, jnp.inf]), jnp.array([1e6, 2e7, jnp.inf]))
Array([ True, False,  True], dtype=bool)
>>> jnp.isclose(jnp.array([1e6, 2e6, 3e6]),
...              jnp.array([1.00008e6, 2.00008e7, 3.00008e8]), rtol=1e3)
Array([ True,  True,  True], dtype=bool)
>>> jnp.isclose(jnp.array([1e6, 2e6, 3e6]),
...              jnp.array([1.00001e6, 2.00002e6, 3.00009e6]), atol=1e3)
Array([ True,  True,  True], dtype=bool)
>>> jnp.isclose(jnp.array([jnp.nan, 1, 2]),
...              jnp.array([jnp.nan, 1, 2]), equal_nan=True)
Array([ True,  True,  True], dtype=bool)

scico.numpy.iscomplex(x)

Return a boolean array showing where the input elements have a nonzero imaginary part.

JAX implementation of numpy.iscomplex.

Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – Input array to check.

Return type:

Array

Returns:

A new array containing boolean values indicating complex elements.

Examples

>>> jnp.iscomplex(jnp.array([True, 0, 1, 2j, 1+2j]))
Array([False, False, False,  True,  True], dtype=bool)

scico.numpy.iscomplexobj(x)

Check if the input is a complex number or an array containing complex elements.

JAX implementation of numpy.iscomplexobj.

The function evaluates based on input type rather than value. Inputs with zero imaginary parts are still considered complex.

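The type-versus-value distinction is easiest to see next to jnp.iscomplex: an array with a complex dtype is reported as complex by iscomplexobj even where individual elements have a zero imaginary part.

```python
import jax.numpy as jnp

x = jnp.array([1 + 0j, 2 + 3j])  # complex dtype; first element has zero imaginary part
by_type = jnp.iscomplexobj(x)    # checks the dtype only
by_value = jnp.iscomplex(x)      # checks each element's imaginary part
```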
Parameters:

x (Any) – input object to check.

Return type:

bool

Returns:

True if x is a complex number or an array with a complex dtype, False otherwise.

Examples

>>> jnp.iscomplexobj(True)
False
>>> jnp.iscomplexobj(0)
False
>>> jnp.iscomplexobj(jnp.array([1, 2]))
False
>>> jnp.iscomplexobj(1+2j)
True
>>> jnp.iscomplexobj(jnp.array([0, 1+2j]))
True

scico.numpy.isdtype(dtype, kind)

Returns a boolean indicating whether a provided dtype is of a specified kind.

Parameters:
  • dtype (Union[str, type[Any], dtype, SupportsDType]) – the input dtype

  • kind (Union[str, type[Any], dtype, SupportsDType, tuple[Union[str, type[Any], dtype, SupportsDType], ...]]) –

    the data type kind. If kind is dtype-like, return dtype == kind. If kind is a string, return True if dtype is in the specified category:

    • 'bool': {bool}

    • 'signed integer': {int4, int8, int16, int32, int64}

    • 'unsigned integer': {uint4, uint8, uint16, uint32, uint64}

    • 'integral': shorthand for ('signed integer', 'unsigned integer')

    • 'real floating': {float8_*, float16, bfloat16, float32, float64}

    • 'complex floating': {complex64, complex128}

    • 'numeric': shorthand for ('integral', 'real floating', 'complex floating')

    If kind is a tuple, then return True if dtype matches any entry of the tuple.

Return type:

bool

Returns:

True or False

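A few checks against the categories listed above; note that, per that list, 'numeric' does not include bool, and a tuple of kinds matches if any entry matches.

```python
import jax.numpy as jnp

# Each call checks one dtype against one kind (or a tuple of kinds).
float_is_real = jnp.isdtype(jnp.dtype('float32'), 'real floating')      # True
int8_is_integral = jnp.isdtype(jnp.dtype('int8'), 'integral')           # True
complex_is_real = jnp.isdtype(jnp.dtype('complex64'), 'real floating')  # False
complex_in_tuple = jnp.isdtype(jnp.dtype('complex64'),
                               ('real floating', 'complex floating'))   # True
bool_is_numeric = jnp.isdtype(jnp.dtype('bool'), 'numeric')             # False
```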
scico.numpy.isfinite(x, /)

Return a boolean array indicating whether each element of input is finite.

JAX implementation of numpy.isfinite.

Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array or scalar.

Return type:

Array

Returns:

A boolean array of same shape as x containing True where x is not inf, -inf, or NaN, and False otherwise.

See also

  • jax.numpy.isinf: Returns a boolean array indicating whether each element of input is either positive or negative infinity.

  • jax.numpy.isposinf: Returns a boolean array indicating whether each element of input is positive infinity.

  • jax.numpy.isneginf: Returns a boolean array indicating whether each element of input is negative infinity.

  • jax.numpy.isnan: Returns a boolean array indicating whether each element of input is not a number (NaN).

Examples

>>> x = jnp.array([-1, 3, jnp.inf, jnp.nan])
>>> jnp.isfinite(x)
Array([ True,  True, False, False], dtype=bool)
>>> jnp.isfinite(3-4j)
Array(True, dtype=bool, weak_type=True)
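
For real floating-point input, isfinite, isinf and isnan split the elements into three disjoint groups, so exactly one of them holds for every element. A small sketch checking that property on arbitrary sample values:

```python
import jax.numpy as jnp

x = jnp.array([-2.0, 0.0, 3.5, jnp.inf, -jnp.inf, jnp.nan])

finite = jnp.isfinite(x)
infinite = jnp.isinf(x)
not_a_number = jnp.isnan(x)

# Count how many of the three predicates hold for each element
counts = (finite.astype(jnp.int32)
          + infinite.astype(jnp.int32)
          + not_a_number.astype(jnp.int32))
# counts -> [1, 1, 1, 1, 1, 1]
```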
scico.numpy.isin(element, test_elements, assume_unique=False, invert=False, *, method='auto')

Determine whether elements in element appear in test_elements.

JAX implementation of numpy.isin.

Parameters:
  • element (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array of elements for which membership will be checked.

  • test_elements (Union[Array, ndarray, bool, number, bool, int, float, complex]) – N-dimensional array of test values to check for the presence of each element.

  • invert (bool) – If True, return ~isin(element, test_elements). Default is False.

  • assume_unique (bool) – if True, the input arrays are assumed to be unique, which can lead to more efficient computation. If the input arrays are not unique and assume_unique is set to True, the results are undefined.

  • method – string specifying the method used to compute the result. Supported options are ‘compare_all’, ‘binary_search’, ‘sort’, and ‘auto’ (default).

Return type:

Array

Returns:

A boolean array of shape element.shape that specifies whether each element appears in test_elements.

Examples

>>> elements = jnp.array([1, 2, 3, 4])
>>> test_elements = jnp.array([[1, 5, 6, 3, 7, 1]])
>>> jnp.isin(elements, test_elements)
Array([ True, False,  True, False], dtype=bool)
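
The result always has the shape of element, regardless of the shape of test_elements. A short sketch with a 2D input, also showing invert and, outside of jax.jit, using the mask for boolean indexing (the sample values are arbitrary):

```python
import jax.numpy as jnp

elements = jnp.array([[1, 2],
                      [3, 4]])
test_elements = jnp.array([1, 4, 9])

# The mask has the shape of `elements`, not of `test_elements`
mask = jnp.isin(elements, test_elements)
# mask -> [[ True, False],
#          [False,  True]]

# invert=True gives the element-wise complement
inverted = jnp.isin(elements, test_elements, invert=True)

# Eagerly (not under jit), the mask can be used for boolean indexing
selected = elements[mask]   # -> [1, 4]
```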
scico.numpy.isinf(x, /)

Return a boolean array indicating whether each element of input is infinite.

JAX implementation of numpy.isinf.

Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array or scalar.

Return type:

Array

Returns:

A boolean array of the same shape as x containing True where x is inf or -inf, and False otherwise.

See also

  • jax.numpy.isposinf: Returns a boolean array indicating whether each element of input is positive infinity.

  • jax.numpy.isneginf: Returns a boolean array indicating whether each element of input is negative infinity.

  • jax.numpy.isfinite: Returns a boolean array indicating whether each element of input is finite.

  • jax.numpy.isnan: Returns a boolean array indicating whether each element of input is not a number (NaN).

Examples

>>> jnp.isinf(jnp.inf)
Array(True, dtype=bool)
>>> x = jnp.array([2+3j, -jnp.inf, 6, jnp.inf, jnp.nan])
>>> jnp.isinf(x)
Array([False,  True, False,  True, False], dtype=bool)
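
isinf is the union of isposinf and isneginf. The sketch below checks this on sample data and uses the mask to recover the sign of the infinite entries with jnp.where and jnp.sign (the values are arbitrary):

```python
import jax.numpy as jnp

x = jnp.array([-jnp.inf, -1.0, 0.0, jnp.inf, jnp.nan])

# Positive and negative infinities together make up isinf
combined = jnp.isposinf(x) | jnp.isneginf(x)
same = bool(jnp.all(combined == jnp.isinf(x)))   # True

# Sign of the infinite entries, with 0 for every other element (NaN included)
inf_sign = jnp.where(jnp.isinf(x), jnp.sign(x), 0.0)   # -> [-1., 0., 0., 1., 0.]
```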
scico.numpy.isnan(x, /)

Returns a boolean array indicating whether each element of input is NaN.

JAX implementation of numpy.isnan.

Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array or scalar.

Return type:

Array

Returns:

A boolean array of the same shape as x containing True where x is not a number (i.e. NaN), and False otherwise.

See also

  • jax.numpy.isfinite: Returns a boolean array indicating whether each element of input is finite.

  • jax.numpy.isinf: Returns a boolean array indicating whether each element of input is either positive or negative infinity.

  • jax.numpy.isposinf: Returns a boolean array indicating whether each element of input is positive infinity.

  • jax.numpy.isneginf: Returns a boolean array indicating whether each element of input is negative infinity.

Examples

>>> jnp.isnan(6)
Array(False, dtype=bool, weak_type=True)
>>> x = jnp.array([2, 1+4j, jnp.inf, jnp.nan])
>>> jnp.isnan(x)
Array([False, False, False,  True], dtype=bool)
scico.numpy.isneginf(x, /, out=None)

Return a boolean array indicating whether each element of input is negative infinity.

JAX implementation of numpy.isneginf.

Parameters:

x – input array or scalar. Complex dtypes are not supported.

Returns:

A boolean array of the same shape as x containing True where x is -inf, and False otherwise.

See also

  • jax.numpy.isinf: Returns a boolean array indicating whether each element of input is either positive or negative infinity.

  • jax.numpy.isposinf: Returns a boolean array indicating whether each element of input is positive infinity.

  • jax.numpy.isfinite: Returns a boolean array indicating whether each element of input is finite.

  • jax.numpy.isnan: Returns a boolean array indicating whether each element of input is not a number (NaN).

Examples

>>> jnp.isneginf(jnp.inf)
Array(False, dtype=bool)
>>> x = jnp.array([-jnp.inf, 5, jnp.inf, jnp.nan, 1])
>>> jnp.isneginf(x)
Array([ True, False, False, False, False], dtype=bool)
scico.numpy.isposinf(x, /, out=None)

Return a boolean array indicating whether each element of input is positive infinity.

JAX implementation of numpy.isposinf.

Parameters:

x – input array or scalar. Complex dtypes are not supported.

Returns:

A boolean array of the same shape as x containing True where x is inf, and False otherwise.

See also

  • jax.numpy.isinf: Returns a boolean array indicating whether each element of input is either positive or negative infinity.

  • jax.numpy.isneginf: Returns a boolean array indicating whether each element of input is negative infinity.

  • jax.numpy.isfinite: Returns a boolean array indicating whether each element of input is finite.

  • jax.numpy.isnan: Returns a boolean array indicating whether each element of input is not a number (NaN).

Examples

>>> jnp.isposinf(5)
Array(False, dtype=bool)
>>> x = jnp.array([-jnp.inf, 5, jnp.inf, jnp.nan, 1])
>>> jnp.isposinf(x)
Array([False, False,  True, False, False], dtype=bool)
scico.numpy.isreal(x)

Return boolean array showing where the input is real.

JAX implementation of numpy.isreal.

Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array to check.

Return type:

Array

Returns:

A new array containing boolean values indicating real elements.

Examples

>>> jnp.isreal(jnp.array([False, 0j, 1, 2.1, 1+2j]))
Array([ True,  True,  True,  True, False], dtype=bool)
scico.numpy.isrealobj(x)

Check if the input is not a complex number or an array containing complex elements.

JAX implementation of numpy.isrealobj.

The function evaluates based on input type rather than value. Inputs with zero imaginary parts are still considered complex.

Parameters:

x (Any) – input object to check.

Return type:

bool

Returns:

False if x is a complex number or an array containing at least one complex element, True otherwise.

Examples

>>> jnp.isrealobj(0)
True
>>> jnp.isrealobj(1.2)
True
>>> jnp.isrealobj(jnp.array([1, 2]))
True
>>> jnp.isrealobj(1+2j)
False
>>> jnp.isrealobj(jnp.array([0, 1+2j]))
False
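
isreal is value-based, while isrealobj and iscomplexobj are type-based, as noted above. A sketch contrasting them on a complex64 array whose first element has a zero imaginary part:

```python
import jax.numpy as jnp

x = jnp.array([1 + 0j, 2 + 3j])   # complex64 array

# Value-based: checks each element's imaginary part
per_element = jnp.isreal(x)        # -> [ True, False]

# Type-based: looks only at the dtype, so x[0] having zero imaginary part does not matter
real_obj = jnp.isrealobj(x)        # -> False
complex_obj = jnp.iscomplexobj(x)  # -> True

# Dropping the imaginary part gives a real-dtype array that passes both checks
y = jnp.real(x)
```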
scico.numpy.isscalar(element)

Return True if the input is a scalar.

JAX implementation of numpy.isscalar. JAX’s implementation differs from NumPy’s in that it considers zero-dimensional arrays to be scalars; see the Note below for more details.

Parameters:

element (Any) – input object to check; any type is valid input.

Return type:

bool

Returns:

True if element is a scalar value or an array-like object with zero dimensions, False otherwise.

Note

JAX and NumPy differ in their representation of scalar values. NumPy has special scalar objects (e.g. np.int32(0)) which are distinct from zero-dimensional arrays (e.g. np.array(0)), and numpy.isscalar returns True for the former and False for the latter.

JAX does not define special scalar objects, but rather represents scalars as zero-dimensional arrays. As such, jax.numpy.isscalar returns True for both scalar objects (e.g. 0.0 or np.float32(0.0)) and array-like objects with zero dimensions (e.g. jnp.array(0.0), np.array(0.0)).

One reason for the different conventions in isscalar is to maintain JIT-invariance: i.e. the property that the result of a function should not change when it is JIT-compiled. Because scalar inputs are cast to zero-dimensional JAX arrays at JIT boundaries, the semantics of numpy.isscalar are such that the result changes under JIT:

>>> np.isscalar(1.0)
True
>>> jax.jit(np.isscalar)(1.0)
Array(False, dtype=bool)

By treating zero-dimensional arrays as scalars, jax.numpy.isscalar avoids this issue:

>>> jnp.isscalar(1.0)
True
>>> jax.jit(jnp.isscalar)(1.0)
Array(True, dtype=bool)

Examples

In JAX, both scalars and zero-dimensional array-like objects are considered scalars:

>>> jnp.isscalar(1.0)
True
>>> jnp.isscalar(1 + 1j)
True
>>> jnp.isscalar(jnp.array(1))  # zero-dimensional JAX array
True
>>> jnp.isscalar(jnp.int32(1))  # JAX scalar constructor
True
>>> jnp.isscalar(np.array(1.0))  # zero-dimensional NumPy array
True
>>> jnp.isscalar(np.int32(1))  # NumPy scalar type
True

Arrays with one or more dimensions are not considered scalars:

>>> jnp.isscalar(jnp.array([1]))
False
>>> jnp.isscalar(np.array([1]))
False

Compare this to numpy.isscalar, which returns True for scalar-typed objects, and False for all arrays, even those with zero dimensions:

>>> np.isscalar(np.int32(1))  # scalar object
True
>>> np.isscalar(np.array(1))  # zero-dimensional array
False

In JAX, as in NumPy, objects which are not array-like are not considered scalars:

>>> jnp.isscalar(None)
False
>>> jnp.isscalar([1])
False
>>> jnp.isscalar(())
False
>>> jnp.isscalar(slice(10))
False
scico.numpy.issubdtype(arg1, arg2)

Return True if arg1 is equal or lower than arg2 in the type hierarchy.

JAX implementation of numpy.issubdtype.

The main difference in JAX’s implementation is that it properly handles dtype extensions such as bfloat16.

Parameters:
  • arg1 (Union[str, type[Any], dtype, SupportsDType]) – dtype-like object. In typical usage, this will be a dtype specifier, such as "float32" (i.e. a string), np.dtype('int32') (i.e. an instance of numpy.dtype), jnp.complex64 (i.e. a JAX scalar constructor), or np.uint8 (i.e. a NumPy scalar type).

  • arg2 (Union[str, type[Any], dtype, SupportsDType]) – dtype-like object. In typical usage, this will be a generic scalar type, such as jnp.integer, jnp.floating, or jnp.complexfloating.

Return type:

bool

Returns:

True if arg1 represents a dtype that is equal or lower in the type hierarchy than arg2.

Examples

>>> jnp.issubdtype('uint32', jnp.unsignedinteger)
True
>>> jnp.issubdtype(np.int32, jnp.integer)
True
>>> jnp.issubdtype(jnp.bfloat16, jnp.floating)
True
>>> jnp.issubdtype(np.dtype('complex64'), jnp.complexfloating)
True
>>> jnp.issubdtype('complex64', jnp.integer)
False

Be aware that while this is very similar to numpy.issubdtype, the results of these differ in the case of JAX’s custom floating point types:

>>> np.issubdtype('bfloat16', np.floating)
False
>>> jnp.issubdtype('bfloat16', jnp.floating)
True
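
Abstract scalar types such as jnp.inexact and jnp.number group several concrete dtypes, which makes issubdtype convenient for dtype dispatch. In the sketch below, is_inexact is a hypothetical helper written for illustration only, not part of JAX or SCICO:

```python
import jax.numpy as jnp

def is_inexact(x):
    """Hypothetical helper: True for floating-point or complex arrays."""
    return jnp.issubdtype(jnp.asarray(x).dtype, jnp.inexact)

checks = {
    "float32": is_inexact(jnp.ones(3, dtype=jnp.float32)),      # True
    "complex64": is_inexact(jnp.ones(3, dtype=jnp.complex64)),  # True
    "int32": is_inexact(jnp.arange(3)),                         # False
}

# Every concrete numeric dtype is a subtype of jnp.number
all_numbers = all(
    jnp.issubdtype(t, jnp.number) for t in ("int8", "uint16", "float16", "complex64")
)
```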
scico.numpy.iterable(y)

Check whether or not an object can be iterated over.

Parameters:

y (object) – Input object.

Returns:

b (bool) – Return True if the object has an iterator method or is a sequence and False otherwise.

Examples

>>> import numpy as np
>>> np.iterable([1, 2, 3])
True
>>> np.iterable(2)
False

Notes

In most cases, the results of np.iterable(obj) are consistent with isinstance(obj, collections.abc.Iterable). One notable exception is the treatment of 0-dimensional arrays:

>>> from collections.abc import Iterable
>>> a = np.array(1.0)  # 0-dimensional numpy array
>>> isinstance(a, Iterable)
True
>>> np.iterable(a)
False
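
np.iterable behaves like attempting iter(y): anything that iter accepts counts as iterable, including strings and generators, while a 0-dimensional array does not. A few more cases, as a sketch:

```python
import numpy as np

cases = {
    "list": np.iterable([1, 2, 3]),                 # True
    "string": np.iterable("abc"),                   # True: strings are sequences
    "generator": np.iterable(i for i in range(3)),  # True
    "1-d array": np.iterable(np.arange(3)),         # True
    "float": np.iterable(3.0),                      # False
    "None": np.iterable(None),                      # False
    "0-d array": np.iterable(np.array(1.0)),        # False
}
```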
scico.numpy.ix_(*args)

Return a multi-dimensional grid (open mesh) from N one-dimensional sequences.

JAX implementation of numpy.ix_.

Parameters:

*args (Union[Array, ndarray, bool, number, bool, int, float, complex]) – N one-dimensional arrays

Return type:

tuple[Array, ...]

Returns:

Tuple of JAX arrays forming an open mesh, each with N dimensions.

Examples

>>> rows = jnp.array([0, 2])
>>> cols = jnp.array([1, 3])
>>> open_mesh = jnp.ix_(rows, cols)
>>> open_mesh
(Array([[0],
       [2]], dtype=int32), Array([[1, 3]], dtype=int32))
>>> [grid.shape for grid in open_mesh]
[(2, 1), (1, 2)]
>>> x = jnp.array([[10, 20, 30, 40],
...                [50, 60, 70, 80],
...                [90, 100, 110, 120],
...                [130, 140, 150, 160]])
>>> x[open_mesh]
Array([[ 20,  40],
       [100, 120]], dtype=int32)
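
Indexing with the open mesh selects every combination of the given rows and columns, which differs from passing the two index arrays directly. The sketch below contrasts the two and uses the same mesh with JAX's .at[] update syntax (sample data chosen for readability):

```python
import jax.numpy as jnp

x = jnp.arange(1, 17).reshape(4, 4)
rows = jnp.array([0, 2])
cols = jnp.array([1, 3])

# Open-mesh indexing selects the full rows x cols sub-block
block = x[jnp.ix_(rows, cols)]   # -> [[ 2,  4], [10, 12]]

# Plain integer-array indexing pairs the index arrays element by element instead
pairs = x[rows, cols]            # -> [ 2, 12]

# The same open mesh works with the functional .at[] update syntax
zeroed = x.at[jnp.ix_(rows, cols)].set(0)
```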
scico.numpy.kaiser(M, beta)

Return a Kaiser window of size M.

JAX implementation of numpy.kaiser.

Parameters:
  • M – integer size of the output window.

  • beta – shape parameter of the window.

Return type:

Array

Returns:

An array of size M containing the Kaiser window.

Examples

>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jnp.kaiser(4, 1.5))
[0.61 0.95 0.95 0.61]
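
The Kaiser window has the closed form w[n] = I0(beta * sqrt(1 - (2n/(M-1) - 1)**2)) / I0(beta), where I0 is the zeroth-order modified Bessel function of the first kind. A sketch checking jnp.kaiser against this formula (via jnp.i0) and against numpy.kaiser; M and beta are arbitrary:

```python
import numpy as np
import jax.numpy as jnp

M, beta = 8, 5.0
n = jnp.arange(M)

# Map n = 0 .. M-1 onto [-1, 1] and evaluate the closed-form window
ratio = 2.0 * n / (M - 1) - 1.0
manual = jnp.i0(beta * jnp.sqrt(1.0 - ratio**2)) / jnp.i0(beta)

window = jnp.kaiser(M, beta)
```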

scico.numpy.kron(a, b)

Compute the Kronecker product of two input arrays.

JAX implementation of numpy.kron.

The Kronecker product is an operation on two matrices of arbitrary size that produces a block matrix. Each element of the first matrix a is multiplied by the entire second matrix b. If a has shape (m, n) and b has shape (p, q), the resulting matrix will have shape (m * p, n * q).

Parameters:
  • a – first input array.

  • b – second input array.

Return type:

Array

Returns:

A new array representing the Kronecker product of the inputs a and b. The shape of the output is the element-wise product of the input shapes.

Examples

>>> a = jnp.array([[1, 2],
...                [3, 4]])
>>> b = jnp.array([[5, 6],
...                [7, 8]])
>>> jnp.kron(a, b)
Array([[ 5,  6, 10, 12],
       [ 7,  8, 14, 16],
       [15, 18, 20, 24],
       [21, 24, 28, 32]], dtype=int32)
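
The block structure described above can be checked directly: block (i, j) of the result, of shape b.shape, equals a[i, j] * b. A sketch with non-square inputs so that the output shape (m * p, n * q) is easy to see:

```python
import jax.numpy as jnp

a = jnp.array([[1, 2, 3],
               [4, 5, 6]])                  # shape (2, 3)
b = jnp.array([[1], [10], [100], [1000]])   # shape (4, 1)

k = jnp.kron(a, b)                          # shape (2*4, 3*1) = (8, 3)

# Block (i, j) of the result occupies rows i*p:(i+1)*p and columns j*q:(j+1)*q
p, q = b.shape
block_11 = k[1 * p:2 * p, 1 * q:2 * q]      # block for a[1, 1] == 5
```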
scico.numpy.lcm(x1, x2)

Compute the least common multiple of two arrays.

JAX implementation of numpy.lcm.

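Parameters:
  • x1 – input array or scalar. Must have an integer dtype.

  • x2 – input array or scalar. Must have an integer dtype and be broadcast-compatible with x1.
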
Return type:

Array

Returns:

An array containing the least common multiple of the corresponding elements from the absolute values of x1 and x2.

See also

  • jax.numpy.gcd: compute the greatest common divisor of two arrays.

Examples

Scalar inputs:

>>> jnp.lcm(12, 18)
Array(36, dtype=int32, weak_type=True)

Array inputs:

>>> x1 = jnp.array([12, 18, 24])
>>> x2 = jnp.array([5, 10, 15])
>>> jnp.lcm(x1, x2)
Array([ 60,  90, 120], dtype=int32)

Broadcasting:

>>> x1 = jnp.array([12])
>>> x2 = jnp.array([6, 9, 12])
>>> jnp.lcm(x1, x2)
Array([12, 36, 12], dtype=int32)
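
For nonzero integers, the least common multiple and greatest common divisor satisfy lcm(a, b) * gcd(a, b) == |a * b|. A quick check with jnp.gcd on arbitrary values, including a negative input (the results are based on absolute values):

```python
import jax.numpy as jnp

x1 = jnp.array([12, 18, -24, 7])
x2 = jnp.array([5, 10, 15, 7])

lcm_vals = jnp.lcm(x1, x2)   # -> [ 60,  90, 120,   7]
gcd_vals = jnp.gcd(x1, x2)   # -> [1, 2, 3, 7]

# For nonzero integers, lcm(a, b) * gcd(a, b) == |a * b|
identity_holds = bool(jnp.all(lcm_vals * gcd_vals == jnp.abs(x1 * x2)))
```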
scico.numpy.ldexp(x1, x2, /)

Compute x1 * 2 ** x2.

JAX implementation of numpy.ldexp.

Note that XLA does not provide an ldexp operation, so this is implemented in JAX via a standard multiplication and exponentiation.

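Parameters:
  • x1 – real-valued input array or scalar of values to be scaled.

  • x2 – integer input array or scalar of base-2 exponents. Must be broadcast-compatible with x1.
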
Return type:

Array

Returns:

x1 * 2 ** x2 computed element-wise.

Examples

>>> x1 = jnp.arange(5.0)
>>> x2 = 10
>>> jnp.ldexp(x1, x2)
Array([   0., 1024., 2048., 3072., 4096.], dtype=float32)

ldexp can be used to reconstruct the input to frexp:

>>> x = jnp.array([2., 3., 5., 11.])
>>> m, e = jnp.frexp(x)
>>> m
Array([0.5   , 0.75  , 0.625 , 0.6875], dtype=float32)
>>> e
Array([2, 2, 3, 4], dtype=int32)
>>> jnp.ldexp(m, e)
Array([ 2.,  3.,  5., 11.], dtype=float32)
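
Negative exponents scale values down, and the result matches the explicit formula x1 * 2 ** x2. A sketch with arbitrary sample values; x2 is kept as an integer array:

```python
import jax.numpy as jnp

x1 = jnp.array([3.0, 3.0, 1.5, -0.5])
x2 = jnp.array([-2, 0, 3, 4])          # integer exponents

scaled = jnp.ldexp(x1, x2)             # -> [ 0.75,  3.  , 12.  , -8.  ]

# Same values via the explicit formula x1 * 2 ** x2
reference = x1 * 2.0 ** x2
```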
scico.numpy.less(x, y, /)

Return element-wise truth value of x < y.

JAX implementation of numpy.less.

Parameters:
  • x – input array or scalar.

  • y – input array or scalar. Must be broadcast-compatible with x.

Return type:

Array

Returns:

An array of boolean values: True where x < y, and False otherwise.

Examples

Scalar inputs:

>>> jnp.less(3, 7)
Array(True, dtype=bool, weak_type=True)

Inputs with same shape:

>>> x = jnp.array([5, 9, -3])
>>> y = jnp.array([1, 6, 4])
>>> jnp.less(x, y)
Array([False, False,  True], dtype=bool)

Inputs with broadcast compatibility:

>>> x1 = jnp.array([[2, -4, 6, -8],
...                 [-1, 5, -3, 7]])
>>> y1 = jnp.array([0, 3, -5, 9])
>>> jnp.less(x1, y1)
Array([[False,  True, False,  True],
       [ True, False, False,  True]], dtype=bool)
scico.numpy.less_equal(x, y, /)

Return element-wise truth value of x <= y.

JAX implementation of numpy.less_equal.

Parameters:
  • x – input array or scalar.

  • y – input array or scalar. Must be broadcast-compatible with x.

Return type:

Array

Returns:

An array of boolean values: True where x <= y, and False otherwise.

Examples

Scalar inputs:

>>> jnp.less_equal(6, -2)
Array(False, dtype=bool, weak_type=True)

Inputs with same shape:

>>> x = jnp.array([-4, 1, 7])
>>> y = jnp.array([2, -3, 8])
>>> jnp.less_equal(x, y)
Array([ True, False,  True], dtype=bool)

Inputs with broadcast compatibility:

>>> x1 = jnp.array([2, -5, 9])
>>> y1 = jnp.array([[1, -6, 5],
...                 [-2, 4, -6]])
>>> jnp.less_equal(x1, y1)
Array([[False, False, False],
       [False,  True, False]], dtype=bool)
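
Both functions agree with the < and <= operators on arrays. One subtlety: every ordered comparison involving NaN is False, so less_equal(x, y) is not the same as logical_not(greater(x, y)) when NaNs are present. A sketch with arbitrary values:

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, jnp.nan])
y = jnp.array([2.0, 2.0, 1.0])

lt = jnp.less(x, y)          # -> [ True, False, False]
le = jnp.less_equal(x, y)    # -> [ True,  True, False]

# The functions agree with the corresponding operators
ops_agree = bool(jnp.all(lt == (x < y))) and bool(jnp.all(le == (x <= y)))

# NaN compares False in every direction, so "not greater" differs from "less or equal"
not_greater = ~jnp.greater(x, y)   # -> [ True,  True,  True]
```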
scico.numpy.lexsort(keys, axis=-1)

Sort a sequence of keys in lexicographic order.

JAX implementation of numpy.lexsort.

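Parameters:
  • keys – sequence of arrays of the same shape to sort by. The last key in the sequence is the primary sort key.

  • axis – integer axis along which to sort. Default: -1 (the last axis).
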
Return type:

Array

Returns:

An array of integers of shape keys[0].shape giving the indices of the entries in lexicographically-sorted order.

Examples

lexsort with a single key is equivalent to argsort:

>>> key1 = jnp.array([4, 2, 3, 2, 5])
>>> jnp.lexsort([key1])
Array([1, 3, 2, 0, 4], dtype=int32)
>>> jnp.argsort(key1)
Array([1, 3, 2, 0, 4], dtype=int32)

With multiple keys, lexsort uses the last key as the primary key:

>>> key2 = jnp.array([2, 1, 1, 2, 2])
>>> jnp.lexsort([key1, key2])
Array([1, 2, 3, 0, 4], dtype=int32)

The meaning of the indices becomes clearer when printing the sorted keys:

>>> indices = jnp.lexsort([key1, key2])
>>> print(f"{key1[indices]}\n{key2[indices]}")
[2 3 2 4 5]
[1 1 2 2 2]

Notice that the elements of key2 appear in order, and within the sequences of duplicated values the corresponding elements of key1 appear in order.

For multi-dimensional inputs, lexsort defaults to sorting along the last axis:

>>> key1 = jnp.array([[2, 4, 2, 3],
...                   [3, 1, 2, 2]])
>>> key2 = jnp.array([[1, 2, 1, 3],
...                   [2, 1, 2, 1]])
>>> jnp.lexsort([key1, key2])
Array([[0, 2, 1, 3],
       [1, 3, 2, 0]], dtype=int32)

A different sort axis can be chosen using the axis keyword; here we sort along the leading axis:

>>> jnp.lexsort([key1, key2], axis=0)
Array([[0, 1, 0, 1],
       [1, 0, 1, 0]], dtype=int32)
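
For small non-negative integer keys, the lexicographic order can be reproduced by packing the keys into a single combined key in which the primary key dominates, then taking an argsort. This illustrates what lexsort computes under that assumption; it is not how JAX implements it, and it relies on jnp.argsort being stable, which is its default:

```python
import jax.numpy as jnp

key1 = jnp.array([4, 2, 3, 2, 5])   # secondary key
key2 = jnp.array([2, 1, 1, 2, 2])   # primary key (last in the sequence)

# Pack the keys so that key2 dominates: combined = key2 * base + key1,
# where base exceeds every value of key1 (valid for non-negative integer keys)
base = int(key1.max()) + 1
combined = key2 * base + key1
packed_order = jnp.argsort(combined)

lex_order = jnp.lexsort([key1, key2])   # -> [1, 2, 3, 0, 4]
```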
scico.numpy.linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0, *, device=None)

Return evenly-spaced numbers within an interval.

JAX implementation of numpy.linspace.

Parameters:
  • start (Union[Array, ndarray, bool, number, bool, int, float, complex]) – scalar or array of starting values.

  • stop (Union[Array, ndarray, bool, number, bool, int, float, complex]) – scalar or array of stop values.

  • num (int) – number of values to generate. Default: 50.

  • endpoint (bool) – if True (default) then include the stop value in the result. If False, then exclude the stop value.

  • retstep (bool) – If True, then return a (result, step) tuple, where step is the interval between adjacent values in result.

  • axis (int) – integer axis along which to generate the linspace. Defaults to zero.

  • device (Device | Sharding | None) – optional Device or Sharding to which the created array will be committed.

Return type:

Array | tuple[Array, Array]

Returns:

An array of values, or a tuple (values, step) if retstep is True, where:

  • values is an array of evenly-spaced values from start to stop

  • step is the interval between adjacent values.

Examples

List of 5 values between 0 and 10:

>>> jnp.linspace(0, 10, 5)
Array([ 0. ,  2.5,  5. ,  7.5, 10. ], dtype=float32)

List of 8 values between 0 and 10, excluding the endpoint:

>>> jnp.linspace(0, 10, 8, endpoint=False)
Array([0.  , 1.25, 2.5 , 3.75, 5.  , 6.25, 7.5 , 8.75], dtype=float32)

List of values and the step size between them:

>>> vals, step = jnp.linspace(0, 10, 9, retstep=True)
>>> vals
Array([ 0.  ,  1.25,  2.5 ,  3.75,  5.  ,  6.25,  7.5 ,  8.75, 10.  ],
      dtype=float32)
>>> step
Array(1.25, dtype=float32)

Multi-dimensional linspace:

>>> start = jnp.array([0, 5])
>>> stop = jnp.array([5, 10])
>>> jnp.linspace(start, stop, 5)
Array([[ 0.  ,  5.  ],
       [ 1.25,  6.25],
       [ 2.5 ,  7.5 ],
       [ 3.75,  8.75],
       [ 5.  , 10.  ]], dtype=float32)
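
The step returned with retstep=True is (stop - start) / (num - 1) when endpoint is True and (stop - start) / num otherwise, and axis controls where the sample axis goes for array-valued inputs. A sketch with arbitrary values:

```python
import jax.numpy as jnp

# With endpoint=True the step is (stop - start) / (num - 1) ...
vals, step = jnp.linspace(0.0, 1.0, 5, retstep=True)
# ... and with endpoint=False it is (stop - start) / num
vals_open, step_open = jnp.linspace(0.0, 1.0, 5, endpoint=False, retstep=True)

# axis=1 puts the samples along the last axis: one row per start/stop pair
start = jnp.array([0.0, 10.0])
stop = jnp.array([1.0, 20.0])
rows = jnp.linspace(start, stop, 5, axis=1)   # shape (2, 5)
```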
scico.numpy.load(file, *args, **kwargs)

Load JAX arrays from npy files.

JAX wrapper of numpy.load.

This function is a simple wrapper of numpy.load, but in the case of .npy files created with numpy.save or jax.numpy.save, the output will be returned as a jax.Array, and bfloat16 data types will be restored. For .npz files, results will be returned as normal NumPy arrays.

This function requires concrete array inputs, and is not compatible with transformations like jax.jit or jax.vmap.

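Parameters:
  • file – a file path or file-like object, as accepted by numpy.load.

  • *args, **kwargs – additional arguments passed to numpy.load.
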
Return type:

Array

Returns:

the array stored in the file.

Examples

>>> import io
>>> f = io.BytesIO()  # use an in-memory file-like object.
>>> x = jnp.array([2, 4, 6, 8], dtype='bfloat16')
>>> jnp.save(f, x)
>>> f.seek(0)
0
>>> jnp.load(f)
Array([2, 4, 6, 8], dtype=bfloat16)
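
As described above, .npy data is returned as a jax.Array, while entries of an .npz archive stay NumPy arrays. A sketch of both cases using in-memory buffers, with numpy.save and numpy.savez writing the data:

```python
import io

import numpy as np
import jax
import jax.numpy as jnp

# .npy: the result comes back as a jax.Array
npy_buf = io.BytesIO()
np.save(npy_buf, np.arange(4, dtype=np.float32))
npy_buf.seek(0)
loaded = jnp.load(npy_buf)

# .npz: the archive entries are plain NumPy arrays
npz_buf = io.BytesIO()
np.savez(npz_buf, a=np.arange(3), b=np.ones(2))
npz_buf.seek(0)
archive = jnp.load(npz_buf)
entry = archive["a"]
```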
scico.numpy.log(x, /)

Calculate element-wise natural logarithm of the input.

JAX implementation of numpy.log.

Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array or scalar.

Return type:

Array

Returns:

An array containing the logarithm of each element in x, promotes to inexact dtype.

Examples

jnp.log and jnp.exp are inverse functions of each other. Applying jnp.log on the result of jnp.exp(x) yields the original input x.

>>> x = jnp.array([2, 3, 4, 5])
>>> jnp.log(jnp.exp(x))
Array([2., 3., 4., 5.], dtype=float32)

Using jnp.log we can demonstrate well-known properties of logarithms, such as \(\log(ab) = \log(a) + \log(b)\):

>>> x1 = jnp.array([2, 1, 3, 1])
>>> x2 = jnp.array([1, 3, 2, 4])
>>> jnp.allclose(jnp.log(x1*x2), jnp.log(x1)+jnp.log(x2))
Array(True, dtype=bool)
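
A few edge cases: the real logarithm gives -inf at zero and nan for negative real input, while complex input follows the principal branch, so log(-1 + 0j) = iπ. A sketch:

```python
import jax.numpy as jnp

# Edge cases of the real logarithm
log_zero = jnp.log(0.0)        # -> -inf
log_negative = jnp.log(-1.0)   # -> nan for real input

# Complex input uses the principal branch: log(-1 + 0j) = i*pi
log_complex = jnp.log(jnp.array(-1.0 + 0j))
```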
scico.numpy.log10(x, /)

Calculates the base-10 logarithm of x element-wise.

JAX implementation of numpy.log10.

Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – Input array

Return type:

Array

Returns:

An array containing the base-10 logarithm of each element in x, promotes to inexact dtype.

Examples

>>> x1 = jnp.array([0.01, 0.1, 1, 10, 100, 1000])
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jnp.log10(x1))
[-2. -1.  0.  1.  2.  3.]
scico.numpy.log1p(x, /)

Calculates element-wise logarithm of one plus input, log(x+1).

JAX implementation of numpy.log1p.

Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array or scalar.

Return type:

Array

Returns:

An array containing the logarithm of one plus each element in x, promotes to inexact dtype.

Note

jnp.log1p is more accurate than the naive computation log(x+1) for small values of x.

Examples

>>> x = jnp.array([2, 5, 9, 4])
>>> jnp.allclose(jnp.log1p(x), jnp.log(x+1))
Array(True, dtype=bool)

For values very close to 0, jnp.log1p(x) is more accurate than jnp.log(x+1):

>>> x1 = jnp.array([1e-4, 1e-6, 2e-10])
>>> jnp.expm1(jnp.log1p(x1))  
Array([1.00000005e-04, 9.99999997e-07, 2.00000003e-10], dtype=float32)
>>> jnp.expm1(jnp.log(x1+1))  
Array([1.000166e-04, 9.536743e-07, 0.000000e+00], dtype=float32)
scico.numpy.log2(x, /)

Calculates the base-2 logarithm of x element-wise.

JAX implementation of numpy.log2.

Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – Input array

Return type:

Array

Returns:

An array containing the base-2 logarithm of each element in x, promotes to inexact dtype.

Examples

>>> x1 = jnp.array([0.25, 0.5, 1, 2, 4, 8])
>>> jnp.log2(x1)
Array([-2., -1.,  0.,  1.,  2.,  3.], dtype=float32)
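
log2 and log10 are consistent with the change-of-base rule log_b(x) = log(x) / log(b). A quick numerical check on arbitrary positive values:

```python
import jax.numpy as jnp

x = jnp.array([0.5, 3.0, 10.0, 1000.0])

# Change of base: log_b(x) = log(x) / log(b)
via_log2 = jnp.log(x) / jnp.log(2.0)
via_log10 = jnp.log(x) / jnp.log(10.0)
```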
scico.numpy.logaddexp(*args: ArrayLike, out: None = None, where: None = None) → Any

Compute log(exp(x1) + exp(x2)) avoiding overflow.

JAX implementation of numpy.logaddexp.

Parameters:
  • x1 – input array

  • x2 – input array

Returns:

array containing the result.

Examples

>>> x1 = jnp.array([1, 2, 3])
>>> x2 = jnp.array([4, 5, 6])
>>> result1 = jnp.logaddexp(x1, x2)
>>> result2 = jnp.log(jnp.exp(x1) + jnp.exp(x2))
>>> print(jnp.allclose(result1, result2))
True
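
The benefit of logaddexp shows up when exp overflows or underflows in float32. The sketch compares the naive formula with jnp.logaddexp and with a common stable rewrite, max(a, b) + log1p(exp(-|a - b|)); that rewrite is shown for illustration and is not necessarily how JAX implements the function:

```python
import jax.numpy as jnp

a = jnp.array([1000.0, -1000.0, 0.0])
b = jnp.array([1000.0, -1000.0, 5.0])

# exp(1000) overflows to inf and exp(-1000) underflows to 0: [inf, -inf, ~5.0067]
naive = jnp.log(jnp.exp(a) + jnp.exp(b))
stable = jnp.logaddexp(a, b)

# A common stable formulation (illustrative sketch)
m = jnp.maximum(a, b)
manual = m + jnp.log1p(jnp.exp(-jnp.abs(a - b)))
```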
scico.numpy.logaddexp2(*args: ArrayLike, out: None = None, where: None = None) → Any

Logarithm of the sum of exponentials of inputs in base-2 avoiding overflow.

JAX implementation of numpy.logaddexp2.

Parameters:
  • x1 – input array or scalar.

  • x2 – input array or scalar. x1 and x2 should either have same shape or be broadcast compatible.

Returns:

An array containing the result, \(\log_2(2^{x_1} + 2^{x_2})\), element-wise.

See also

  • jax.numpy.logaddexp: Computes log(exp(x1) + exp(x2)), element-wise.

  • jax.numpy.log2: Calculates the base-2 logarithm of x element-wise.

Examples

>>> x1 = jnp.array([[3, -1, 4],
...                 [8, 5, -2]])
>>> x2 = jnp.array([2, 3, -5])
>>> result1 = jnp.logaddexp2(x1, x2)
>>> result2 = jnp.log2(jnp.exp2(x1) + jnp.exp2(x2))
>>> jnp.allclose(result1, result2)
Array(True, dtype=bool)
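
Because 2**x = exp(x * ln 2), logaddexp2 can be expressed through the natural-log version: logaddexp2(x1, x2) = logaddexp(x1 * ln 2, x2 * ln 2) / ln 2. A numerical check on arbitrary values:

```python
import jax.numpy as jnp

x1 = jnp.array([3.0, -1.0, 10.0])
x2 = jnp.array([2.0, 3.0, 10.0])

# log2(2**x1 + 2**x2) rewritten with the natural-log logaddexp
ln2 = jnp.log(2.0)
via_logaddexp = jnp.logaddexp(x1 * ln2, x2 * ln2) / ln2
```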
scico.numpy.logical_and(*args: ArrayLike, out: None = None, where: None = None) → Any

Compute the logical AND operation elementwise.

JAX implementation of numpy.logical_and. This is a universal function, and supports the additional APIs described at jax.numpy.ufunc.

Parameters:
  • x – input arrays. Must be broadcastable to a common shape.

  • y – input arrays. Must be broadcastable to a common shape.

Returns:

Array containing the result of the element-wise logical AND.

Examples

>>> x = jnp.arange(4)
>>> jnp.logical_and(x, 1)
Array([False,  True,  True,  True], dtype=bool)
scico.numpy.logical_not(x, /)

Compute NOT bool(x) element-wise.

JAX implementation of numpy.logical_not.

Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array of any dtype.

Return type:

Array

Returns:

A boolean array containing NOT bool(x), computed element-wise.

Examples

Compute NOT x element-wise on a boolean array:

>>> x = jnp.array([True, False, True])
>>> jnp.logical_not(x)
Array([False,  True, False], dtype=bool)

For boolean input, this is equivalent to invert, which implements the unary ~ operator:

>>> ~x
Array([False,  True, False], dtype=bool)

For non-boolean input, the input of logical_not is implicitly cast to boolean:

>>> x = jnp.array([-1, 0, 1])
>>> jnp.logical_not(x)
Array([False,  True, False], dtype=bool)
scico.numpy.logical_or(*args: ArrayLike, out: None = None, where: None = None) → Any

Compute the logical OR operation elementwise.

JAX implementation of numpy.logical_or. This is a universal function, and supports the additional APIs described at jax.numpy.ufunc.

Parameters:
  • x – input arrays. Must be broadcastable to a common shape.

  • y – input arrays. Must be broadcastable to a common shape.

Returns:

Array containing the result of the element-wise logical OR.

Examples

>>> x = jnp.arange(4)
>>> jnp.logical_or(x, 1)
Array([ True,  True,  True,  True], dtype=bool)
scico.numpy.logical_xor(*args: ArrayLike, out: None = None, where: None = None) → Any

Compute the logical XOR operation elementwise.

JAX implementation of numpy.logical_xor. This is a universal function, and supports the additional APIs described at jax.numpy.ufunc.

Parameters:
  • x – input arrays. Must be broadcastable to a common shape.

  • y – input arrays. Must be broadcastable to a common shape.

Returns:

Array containing the result of the element-wise logical XOR.

Examples

>>> x = jnp.arange(4)
>>> jnp.logical_xor(x, 1)
Array([ True, False, False, False], dtype=bool)
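
A truth-table sketch for the three binary logical functions. On boolean arrays they match the &, | and ^ operators; on integer arrays the logical functions first cast to bool, so they can differ from the bitwise operators:

```python
import jax.numpy as jnp

p = jnp.array([False, False, True, True])
q = jnp.array([False, True, False, True])

and_ = jnp.logical_and(p, q)   # -> [False, False, False,  True]
or_ = jnp.logical_or(p, q)     # -> [False,  True,  True,  True]
xor = jnp.logical_xor(p, q)    # -> [False,  True,  True, False]

# For boolean arrays these match the operators &, |, ^
ops_agree = (bool(jnp.all(and_ == (p & q)))
             and bool(jnp.all(or_ == (p | q)))
             and bool(jnp.all(xor == (p ^ q))))

# For integer input, values are cast to bool first (nonzero -> True)
ints = jnp.array([1, 2])
logical = jnp.logical_and(ints, 2)   # -> [ True,  True]
bitwise = ints & 2                   # -> [0, 2]
```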
scico.numpy.logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None, axis=0)

Generate logarithmically-spaced values.

JAX implementation of numpy.logspace.

Parameters:
  • start (Union[Array, ndarray, bool, number, bool, int, float, complex]) – scalar or array. Used to specify the start value. The start value is base ** start.

  • stop (Union[Array, ndarray, bool, number, bool, int, float, complex]) – scalar or array. Used to specify the stop value. The end value is base ** stop.

  • num (int) – int, optional, default=50. Number of values to generate.

  • endpoint (bool) – bool, optional, default=True. If True, then include the stop value in the result. If False, then exclude the stop value.

  • base (Union[Array, ndarray, bool, number, bool, int, float, complex]) – scalar or array, optional, default=10. Specifies the base of the logarithm.

  • dtype (Union[str, type[Any], dtype, SupportsDType, None]) – optional. Specifies the dtype of the output.

  • axis (int) – int, optional, default=0. Axis along which to generate the logspace.

Return type:

Array

Returns:

An array of logarithmically-spaced values.

Examples

List 5 logarithmically spaced values between 1 (10 ** 0) and 100 (10 ** 2):

>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.logspace(0, 2, 5)
Array([  1.   ,   3.162,  10.   ,  31.623, 100.   ], dtype=float32)

List 5 logarithmically-spaced values between 1 (10 ** 0) and 100 (10 ** 2), excluding the endpoint:

>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.logspace(0, 2, 5, endpoint=False)
Array([ 1.   ,  2.512,  6.31 , 15.849, 39.811], dtype=float32)

List 7 logarithmically-spaced values between 1 (2 ** 0) and 4 (2 ** 2) with base 2:

>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.logspace(0, 2, 7, base=2)
Array([1.   , 1.26 , 1.587, 2.   , 2.52 , 3.175, 4.   ], dtype=float32)

Multi-dimensional logspace:

>>> start = jnp.array([0, 5])
>>> stop = jnp.array([5, 0])
>>> base = jnp.array([2, 3])
>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.logspace(start, stop, 5, base=base)
Array([[  1.   , 243.   ],
       [  2.378,  61.547],
       [  5.657,  15.588],
       [ 13.454,   3.948],
       [ 32.   ,   1.   ]], dtype=float32)
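
logspace(start, stop, num, base=b) is equivalent to b ** linspace(start, stop, num), so its endpoints are base ** start and base ** stop. A sketch checking both on arbitrary values:

```python
import jax.numpy as jnp

start, stop, num = 0.0, 3.0, 7

# logspace is base ** linspace(start, stop, num)
via_linspace = 10.0 ** jnp.linspace(start, stop, num)
via_logspace = jnp.logspace(start, stop, num)

# With base 2, exponents 0..4 give approximately [1., 2., 4., 8., 16.]
base2 = jnp.logspace(0, 4, 5, base=2.0)
```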
scico.numpy.mask_indices(n, mask_func, k=0, *, size=None)

Return indices of a mask of an (n, n) array.

Parameters:
  • n (int) – static integer array dimension.

  • mask_func (Callable[[Union[Array, ndarray, bool, number, bool, int, float, complex], int], Array]) – a function that takes a shape (n, n) array and an optional offset k, and returns a shape (n, n) mask. Examples of functions with this signature are triu and tril.

  • k (int) – a scalar value passed to mask_func.

  • size (int | None) – optional argument specifying the static size of the output arrays. This is passed to nonzero when generating the indices from the mask.

Return type:

tuple[Array, Array]

Returns:

a tuple of indices where mask_func is nonzero.

Examples

Calling mask_indices on built-in masking functions:

>>> jnp.mask_indices(3, jnp.triu)
(Array([0, 0, 0, 1, 1, 2], dtype=int32), Array([0, 1, 2, 1, 2, 2], dtype=int32))
>>> jnp.mask_indices(3, jnp.tril)
(Array([0, 1, 1, 2, 2, 2], dtype=int32), Array([0, 0, 1, 0, 1, 2], dtype=int32))

Calling mask_indices on a custom masking function:

>>> def mask_func(x, k=0):
...   i = jnp.arange(x.shape[0])[:, None]
...   j = jnp.arange(x.shape[1])
...   return (i + 1) % (j + 1 + k) == 0
>>> mask_func(jnp.ones((3, 3)))
Array([[ True, False, False],
       [ True,  True, False],
       [ True, False,  True]], dtype=bool)
>>> jnp.mask_indices(3, mask_func)
(Array([0, 1, 1, 2, 2], dtype=int32), Array([0, 0, 1, 0, 2], dtype=int32))
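
With jnp.triu as mask_func, mask_indices produces the same indices as jnp.triu_indices, and the resulting index tuple can gather the selected entries. A sketch using k=1 to exclude the main diagonal:

```python
import jax.numpy as jnp

# k is forwarded to mask_func; with jnp.triu, k=1 excludes the main diagonal
rows, cols = jnp.mask_indices(4, jnp.triu, 1)

# For triu this matches the dedicated helper jnp.triu_indices
ref_rows, ref_cols = jnp.triu_indices(4, k=1)

# Gather the strictly upper-triangular entries of a 4x4 array
a = jnp.arange(16).reshape(4, 4)
upper = a[rows, cols]   # -> [ 1,  2,  3,  6,  7, 11]
```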
scico.numpy.matmul(a, b, *, precision=None, preferred_element_type=None, out_sharding=None)

Perform a matrix multiplication.

JAX implementation of numpy.matmul.

Parameters:
  • a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – first input array, of shape (N,) or (..., K, N).

  • b (Union[Array, ndarray, bool, number, bool, int, float, complex]) – second input array. Must have shape (N,) or (..., N, M). In the multi-dimensional case, leading dimensions must be broadcast-compatible with the leading dimensions of a.

  • precision (Union[None, str, Precision, tuple[str, str], tuple[Precision, Precision], DotAlgorithm, DotAlgorithmPreset]) – either None (default), which means the default precision for the backend, a Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST) or a tuple of two such values indicating precision of a and b.

  • preferred_element_type (Union[str, type[Any], dtype, SupportsDType, None]) – either None (default), which means the default accumulation type for the input types, or a datatype, indicating to accumulate results to and return a result with that datatype.

Return type:

Array

Returns:

array containing the matrix product of the inputs. Shape is a.shape[:-1] if b.ndim == 1, otherwise the shape is (..., K, M), where leading dimensions of a and b are broadcast together.

Examples

Vector dot products:

>>> a = jnp.array([1, 2, 3])
>>> b = jnp.array([4, 5, 6])
>>> jnp.matmul(a, b)
Array(32, dtype=int32)

Matrix dot product:

>>> a = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> b = jnp.array([[1, 2],
...                [3, 4],
...                [5, 6]])
>>> jnp.matmul(a, b)
Array([[22, 28],
       [49, 64]], dtype=int32)

For convenience, in all cases you can do the same computation using the @ operator:

>>> a @ b
Array([[22, 28],
       [49, 64]], dtype=int32)
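The shape rules in the Returns entry can be checked directly on batched and vector operands. A short sketch using jax.numpy, as the doctests above do:

```python
import jax.numpy as jnp

# A batch of two 3x4 matrices times a single 4x5 matrix:
# the leading batch dimension of `a` broadcasts against `b`.
a = jnp.ones((2, 3, 4))
b = jnp.ones((4, 5))
batched = jnp.matmul(a, b)            # shape (2, 3, 5)

# A 1-D first operand acts as a row vector; that axis is dropped afterwards.
v = jnp.arange(4.0)                   # [0., 1., 2., 3.]
row_times_matrix = jnp.matmul(v, b)   # shape (5,)

# A 1-D second operand acts as a column vector: result shape is a.shape[:-1].
matrix_times_col = jnp.matmul(a, v)   # shape (2, 3)
```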
scico.numpy.matrix_transpose(x, /)

Transpose the last two dimensions of an array.

JAX implementation of numpy.matrix_transpose, implemented in terms of jax.lax.transpose.

Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array. Must have x.ndim >= 2.

Return type:

Array

Returns:

matrix-transposed copy of the array.

Note

Unlike numpy.matrix_transpose, jax.numpy.matrix_transpose will return a copy rather than a view of the input array. However, under JIT, the compiler will optimize away such copies when possible, so this doesn’t have performance impacts in practice.

Examples

Here is a 2x2x2 array representing a batch of two 2x2 matrices:

>>> x = jnp.array([[[1, 2],
...                 [3, 4]],
...                [[5, 6],
...                 [7, 8]]])
>>> jnp.matrix_transpose(x)
Array([[[1, 3],
        [2, 4]],

       [[5, 7],
        [6, 8]]], dtype=int32)

For convenience, you can perform the same transpose via the mT property of jax.Array:

>>> x.mT
Array([[[1, 3],
        [2, 4]],

       [[5, 7],
        [6, 8]]], dtype=int32)
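Because matrix_transpose only touches the last two axes, it should match an explicit swap of those axes, and differ from a full transpose once there are more than two dimensions. A quick check, assuming jnp.swapaxes (not documented in this section) behaves as in NumPy:

```python
import jax.numpy as jnp

# A stack of three 2x4 matrices.
x = jnp.arange(24).reshape(3, 2, 4)

mt = jnp.matrix_transpose(x)        # shape (3, 4, 2)
swapped = jnp.swapaxes(x, -1, -2)   # same result via an explicit axis swap
full = jnp.transpose(x)             # reverses *all* axes: shape (4, 2, 3)
```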
scico.numpy.max(a, axis=None, out=None, keepdims=False, initial=None, where=None)

Return the maximum of the array elements along a given axis.

JAX implementation of numpy.max.

Parameters:
  • a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – Input array.

  • axis (Union[int, Sequence[int], None]) – int or sequence of ints, default=None. Axis along which the maximum is computed. If None, the maximum is computed along all the axes.

  • keepdims (bool) – bool, default=False. If true, reduced axes are left in the result with size 1.

  • initial (Union[Array, ndarray, bool, number, bool, int, float, complex, None]) – int or array, default=None. Initial value for the maximum.

  • where (Union[Array, ndarray, bool, number, bool, int, float, complex, None]) – array of boolean dtype, default=None. The elements to be used in the maximum. Array should be broadcast-compatible with the input. initial must be specified when where is used.

  • out (None) – Unused by JAX.

Return type:

Array

Returns:

An array of maximum values along the given axis.

See also

  • jax.numpy.min: Compute the minimum of array elements along a given axis.

  • jax.numpy.sum: Compute the sum of array elements along a given axis.

  • jax.numpy.prod: Compute the product of array elements along a given axis.

Examples

By default, jnp.max computes the maximum of elements along all the axes.

>>> x = jnp.array([[9, 3, 4, 5],
...                [5, 2, 7, 4],
...                [8, 1, 3, 6]])
>>> jnp.max(x)
Array(9, dtype=int32)

If axis=1, the maximum will be computed along axis 1.

>>> jnp.max(x, axis=1)
Array([9, 7, 8], dtype=int32)

If keepdims=True, ndim of the output will be the same as that of the input.

>>> jnp.max(x, axis=1, keepdims=True)
Array([[9],
       [7],
       [8]], dtype=int32)

To include only specific elements in computing the maximum, you can use where. It can either have the same shape as the input

>>> where=jnp.array([[0, 0, 1, 0],
...                  [0, 0, 1, 1],
...                  [1, 1, 1, 0]], dtype=bool)
>>> jnp.max(x, axis=1, keepdims=True, initial=0, where=where)
Array([[4],
       [7],
       [8]], dtype=int32)

or be broadcast-compatible with the input.

>>> where = jnp.array([[False],
...                    [False],
...                    [False]])
>>> jnp.max(x, axis=0, keepdims=True, initial=0, where=where)
Array([[0, 0, 0, 0]], dtype=int32)
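Two illustrations of the where/initial rules above. A masked maximum equals the ordinary maximum over the selected entries together with initial, and omitting initial while passing where is rejected because max has no identity element. The sketch assumes the rejection surfaces as a ValueError, which is how current jax.numpy reports it:

```python
import jax.numpy as jnp

x = jnp.array([[9, 3, 4, 5],
               [5, 2, 7, 4],
               [8, 1, 3, 6]])
mask = jnp.array([[0, 0, 1, 0],
                  [0, 0, 1, 1],
                  [1, 1, 1, 0]], dtype=bool)
initial = 0

masked = jnp.max(x, axis=1, initial=initial, where=mask)

# Equivalent formulation: replace unselected entries by `initial`, reduce,
# then fold `initial` in once more so fully selected rows also see it.
manual = jnp.maximum(jnp.where(mask, x, initial).max(axis=1), initial)

# Without `initial`, a masked max is rejected (assumed ValueError).
try:
    jnp.max(x, axis=1, where=mask)
    raised = False
except ValueError:
    raised = True
```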
scico.numpy.maximum(*args: ArrayLike, out: None = None, where: None = None) → Any

Return element-wise maximum of the input arrays.

JAX implementation of numpy.maximum.

Parameters:
  • x – input array or scalar.

  • y – input array or scalar. Both x and y should either have same shape or be broadcast compatible.

Returns:

An array containing the element-wise maximum of x and y.

Note

For each pair of elements, jnp.maximum returns:
  • the larger of the two if both elements are finite numbers.

  • nan if either element is nan.

See also

  • jax.numpy.minimum: Returns element-wise minimum of the input arrays.

  • jax.numpy.fmax: Returns element-wise maximum of the input arrays, ignoring NaNs.

  • jax.numpy.amax: Returns the maximum of array elements along a given axis.

  • jax.numpy.nanmax: Returns the maximum of the array elements along a given axis, ignoring NaNs.

Examples

Inputs with x.shape == y.shape:

>>> x = jnp.array([1, -5, 3, 2])
>>> y = jnp.array([-2, 4, 7, -6])
>>> jnp.maximum(x, y)
Array([1, 4, 7, 2], dtype=int32)

Inputs with broadcast compatibility:

>>> x1 = jnp.array([[-2, 5, 7, 4],
...                 [1, -6, 3, 8]])
>>> y1 = jnp.array([-5, 3, 6, 9])
>>> jnp.maximum(x1, y1)
Array([[-2,  5,  7,  9],
       [ 1,  3,  6,  9]], dtype=int32)

Inputs having nan:

>>> nan = jnp.nan
>>> x2 = jnp.array([nan, -3, 9])
>>> y2 = jnp.array([[4, -2, nan],
...                 [-3, -5, 10]])
>>> jnp.maximum(x2, y2)
Array([[nan, -2., nan],
      [nan, -3., 10.]], dtype=float32)
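The NaN rule in the Note above is the main difference from jax.numpy.fmax (listed under See also), which returns the non-NaN operand instead. A side-by-side sketch:

```python
import jax.numpy as jnp

x = jnp.array([jnp.nan, 1.0, jnp.nan, 3.0])
y = jnp.array([2.0, jnp.nan, jnp.nan, 0.5])

propagated = jnp.maximum(x, y)  # NaN wins whenever either operand is NaN
ignored = jnp.fmax(x, y)        # NaN is skipped unless both operands are NaN
```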
scico.numpy.mean(a, axis=None, dtype=None, out=None, keepdims=False, *, where=None)

Return the mean of array elements along a given axis.

JAX implementation of numpy.mean.

Parameters:
  • a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array.

  • axis (Union[int, Sequence[int], None]) – optional, int or sequence of ints, default=None. Axis along which the mean is computed. If None, the mean is computed along all the axes.

  • dtype (Union[str, type[Any], dtype, SupportsDType, None]) – The type of the output array. If None (default), the output dtype matches the input dtype for floating-point inputs, or is set to float32 or float64 for non-floating-point inputs.

  • keepdims (bool) – bool, default=False. If true, reduced axes are left in the result with size 1.

  • where (Union[Array, ndarray, bool, number, bool, int, float, complex, None]) – optional, boolean array, default=None. The elements to be used in the mean. Array should be broadcast-compatible with the input.

  • out (None) – Unused by JAX.

Return type:

Array

Returns:

An array of the mean along the given axis.

Notes

For inputs of type float16 or bfloat16, the reductions will be performed at float32 precision.

Examples

By default, the mean is computed along all the axes.

>>> x = jnp.array([[1, 3, 4, 2],
...                [5, 2, 6, 3],
...                [8, 1, 2, 9]])
>>> jnp.mean(x)
Array(3.8333335, dtype=float32)

If axis=1, the mean is computed along axis 1.

>>> jnp.mean(x, axis=1)
Array([2.5, 4. , 5. ], dtype=float32)

If keepdims=True, ndim of the output is equal to that of the input.

>>> jnp.mean(x, axis=1, keepdims=True)
Array([[2.5],
       [4. ],
       [5. ]], dtype=float32)

To use only specific elements of x to compute the mean, you can use where.

>>> where = jnp.array([[1, 0, 1, 0],
...                    [0, 1, 0, 1],
...                    [1, 1, 0, 1]], dtype=bool)
>>> jnp.mean(x, axis=1, keepdims=True, where=where)
Array([[2.5],
       [2.5],
       [6. ]], dtype=float32)
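The masked mean is the masked sum divided by the number of selected entries. A sketch that reproduces the last result above and shows the integer-to-float promotion described under dtype (the float32 result assumes JAX's default 32-bit mode):

```python
import jax.numpy as jnp

x = jnp.array([[1, 3, 4, 2],
               [5, 2, 6, 3],
               [8, 1, 2, 9]])
mask = jnp.array([[1, 0, 1, 0],
                  [0, 1, 0, 1],
                  [1, 1, 0, 1]], dtype=bool)

masked_mean = jnp.mean(x, axis=1, where=mask)
# Sum of the selected entries divided by how many were selected in each row.
manual = jnp.sum(x, axis=1, where=mask) / jnp.sum(mask, axis=1)
```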
scico.numpy.meshgrid(*xi, copy=True, sparse=False, indexing='xy')

Construct N-dimensional grid arrays from N 1-dimensional vectors.

JAX implementation of numpy.meshgrid.

Parameters:
  • xi (Union[Array, ndarray, bool, number, bool, int, float, complex]) – N arrays to convert to a grid.

  • copy (bool) – whether to copy the input arrays. JAX supports only copy=True, though under JIT compilation the compiler may opt to avoid copies.

  • sparse (bool) – if False (default), then each returned array will be of shape [len(x1), len(x2), ..., len(xN)] (with indexing='xy', the first two dimensions are swapped). If True, then returned arrays will be of shape [1, 1, ..., len(xi), ..., 1, 1].

  • indexing (str) – options are 'xy' for cartesian indexing (default) or 'ij' for matrix indexing.

Return type:

list[Array]

Returns:

A length-N list of grid arrays.

Examples

For the following examples, we’ll use these 1D arrays as inputs:

>>> x = jnp.array([1, 2])
>>> y = jnp.array([10, 20, 30])

2D cartesian mesh grid:

>>> x_grid, y_grid = jnp.meshgrid(x, y)
>>> print(x_grid)
[[1 2]
 [1 2]
 [1 2]]
>>> print(y_grid)
[[10 10]
 [20 20]
 [30 30]]

2D sparse cartesian mesh grid:

>>> x_grid, y_grid = jnp.meshgrid(x, y, sparse=True)
>>> print(x_grid)
[[1 2]]
>>> print(y_grid)
[[10]
 [20]
 [30]]

2D matrix-index mesh grid:

>>> x_grid, y_grid = jnp.meshgrid(x, y, indexing='ij')
>>> print(x_grid)
[[1 1 1]
 [2 2 2]]
>>> print(y_grid)
[[10 20 30]
 [10 20 30]]
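With three inputs, the difference between the two indexing modes is easiest to see in the output shapes: 'ij' keeps the input order, while 'xy' swaps the first two axes (the NumPy convention, which jax.numpy follows). The sketch also checks that sparse grids broadcast back to the dense ones:

```python
import jax.numpy as jnp

x = jnp.arange(2)   # length 2
y = jnp.arange(3)   # length 3
z = jnp.arange(4)   # length 4

xy_grids = jnp.meshgrid(x, y, z)                   # default indexing='xy'
ij_grids = jnp.meshgrid(x, y, z, indexing='ij')
sparse_ij = jnp.meshgrid(x, y, z, indexing='ij', sparse=True)

# Sparse grids have singleton axes that broadcast to the dense shape.
dense_from_sparse = [jnp.broadcast_to(g, (2, 3, 4)) for g in sparse_ij]
```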
scico.numpy.min(a, axis=None, out=None, keepdims=False, initial=None, where=None)

Return the minimum of array elements along a given axis.

JAX implementation of numpy.min.

Parameters:
Return type:

Array

Returns:

An array of minimum values along the given axis.

See also

  • jax.numpy.max: Compute the maximum of array elements along a given axis.

  • jax.numpy.sum: Compute the sum of array elements along a given axis.

  • jax.numpy.prod: Compute the product of array elements along a given axis.

Examples

By default, the minimum is computed along all the axes.

>>> x = jnp.array([[2, 5, 1, 6],
...                [3, -7, -2, 4],
...                [8, -4, 1, -3]])
>>> jnp.min(x)
Array(-7, dtype=int32)

If axis=1, the minimum is computed along axis 1.

>>> jnp.min(x, axis=1)
Array([ 1, -7, -4], dtype=int32)

If keepdims=True, ndim of the output will be the same as that of the input.

>>> jnp.min(x, axis=1, keepdims=True)
Array([[ 1],
       [-7],
       [-4]], dtype=int32)

To include only specific elements in computing the minimum, you can use where. where can either have the same shape as the input

>>> where=jnp.array([[1, 0, 1, 0],
...                  [0, 0, 1, 1],
...                  [1, 1, 1, 0]], dtype=bool)
>>> jnp.min(x, axis=1, keepdims=True, initial=0, where=where)
Array([[ 0],
       [-2],
       [-4]], dtype=int32)

or be broadcast-compatible with the input.

>>> where = jnp.array([[False],
...                    [False],
...                    [False]])
>>> jnp.min(x, axis=0, keepdims=True, initial=0, where=where)
Array([[0, 0, 0, 0]], dtype=int32)
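keepdims=True is mostly useful because the reduced result still broadcasts against the input. For example, shifting each row so that its smallest entry becomes zero:

```python
import jax.numpy as jnp

x = jnp.array([[2, 5, 1, 6],
               [3, -7, -2, 4],
               [8, -4, 1, -3]])

row_min = jnp.min(x, axis=1, keepdims=True)  # shape (3, 1)
shifted = x - row_min                        # broadcasts across each row
```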
scico.numpy.minimum(*args: ArrayLike, out: None = None, where: None = None) → Any

Return element-wise minimum of the input arrays.

JAX implementation of numpy.minimum.

Parameters:
  • x – input array or scalar.

  • y – input array or scalar. Both x and y should either have same shape or be broadcast compatible.

Returns:

An array containing the element-wise minimum of x and y.

Note

For each pair of elements, jnp.minimum returns:
  • the smaller of the two if both elements are finite numbers.

  • nan if either element is nan.

See also

  • jax.numpy.maximum: Returns element-wise maximum of the input arrays.

  • jax.numpy.fmin: Returns element-wise minimum of the input arrays, ignoring NaNs.

  • jax.numpy.amin: Returns the minimum of array elements along a given axis.

  • jax.numpy.nanmin: Returns the minimum of the array elements along a given axis, ignoring NaNs.

Examples

Inputs with x.shape == y.shape:

>>> x = jnp.array([2, 3, 5, 1])
>>> y = jnp.array([-3, 6, -4, 7])
>>> jnp.minimum(x, y)
Array([-3,  3, -4,  1], dtype=int32)

Inputs having broadcast compatibility:

>>> x1 = jnp.array([[1, 5, 2],
...                 [-3, 4, 7]])
>>> y1 = jnp.array([-2, 3, 6])
>>> jnp.minimum(x1, y1)
Array([[-2,  3,  2],
       [-3,  3,  6]], dtype=int32)

Inputs with nan:

>>> nan = jnp.nan
>>> x2 = jnp.array([[2.5, nan, -2],
...                 [nan, 5, 6],
...                 [-4, 3, 7]])
>>> y2 = jnp.array([1, nan, 5])
>>> jnp.minimum(x2, y2)
Array([[ 1., nan, -2.],
       [nan, nan,  5.],
       [-4., nan,  5.]], dtype=float32)
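Chaining maximum and minimum gives element-wise clipping. A sketch, compared against jax.numpy.clip (not documented in this section) called with positional bounds:

```python
import jax.numpy as jnp

x = jnp.array([-5, -1, 0, 2, 4, 9])
lo, hi = 0, 3

clipped = jnp.minimum(jnp.maximum(x, lo), hi)  # raise to lo, then cap at hi
reference = jnp.clip(x, lo, hi)
```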
scico.numpy.mod(x1, x2, /)

Alias of jax.numpy.remainder.

Return type:

Array
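As an alias of jax.numpy.remainder, mod follows Python's % convention: the result takes the sign of the divisor. The sketch contrasts this with jax.numpy.fmod (not documented in this section), which takes the sign of the dividend as in C:

```python
import jax.numpy as jnp

a = jnp.array([7, -7, 7, -7])
b = jnp.array([3, 3, -3, -3])

m = jnp.mod(a, b)    # sign follows the divisor b
f = jnp.fmod(a, b)   # sign follows the dividend a

# Python's built-in % uses the same convention as jnp.mod.
py = [int(ai) % int(bi) for ai, bi in zip(a, b)]
```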

scico.numpy.modf(x, /, out=None)

Return element-wise fractional and integral parts of the input array.

JAX implementation of numpy.modf.

Parameters:
Return type:

tuple[Array, Array]

Returns:

A tuple of arrays containing the fractional and integral parts of the elements of x, with dtypes promoted to inexact.

See also

  • jax.numpy.divmod: Calculates the integer quotient and remainder of x1 by x2 element-wise.

Examples

>>> jnp.modf(4.8)
(Array(0.8000002, dtype=float32, weak_type=True), Array(4., dtype=float32, weak_type=True))
>>> x = jnp.array([-3.4, -5.7, 0.6, 1.5, 2.3])
>>> jnp.modf(x)
(Array([-0.4000001 , -0.6999998 ,  0.6       ,  0.5       ,  0.29999995],      dtype=float32), Array([-3., -5.,  0.,  1.,  2.], dtype=float32))
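The two parts returned by modf add back up to the input, and the integral part is truncation toward zero (which is why -3.4 splits into about -0.4 and -3.0 above). A quick check, using jax.numpy.trunc (not documented in this section):

```python
import jax.numpy as jnp

x = jnp.array([-3.4, -5.7, 0.6, 1.5, 2.3])
frac, whole = jnp.modf(x)
```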
scico.numpy.moveaxis(a, source, destination)

Move an array axis to a new position.

JAX implementation of numpy.moveaxis, implemented in terms of jax.lax.transpose.

Parameters:
Return type:

Array

Returns:

Copy of a with axes moved from source to destination.

Notes

Unlike numpy.moveaxis, jax.numpy.moveaxis will return a copy rather than a view of the input array. However, under JIT, the compiler will optimize away such copies when possible, so this doesn’t have performance impacts in practice.

Examples

>>> a = jnp.ones((2, 3, 4, 5))

Move axis 1 to the end of the array:

>>> jnp.moveaxis(a, 1, -1).shape
(2, 4, 5, 3)

Move the last axis to position 1:

>>> jnp.moveaxis(a, -1, 1).shape
(2, 5, 3, 4)

Move multiple axes:

>>> jnp.moveaxis(a, (0, 1), (-1, -2)).shape
(4, 5, 3, 2)

This can also be accomplished via transpose:

>>> a.transpose(2, 3, 1, 0).shape
(4, 5, 3, 2)
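The transpose order above can be derived mechanically: keep the axes that are not moved in their original order, then insert each moved axis at its destination. A minimal sketch of that bookkeeping; moveaxis_perm is a hypothetical helper written for illustration, not part of scico or JAX:

```python
import jax.numpy as jnp

def moveaxis_perm(ndim, source, destination):
    """Hypothetical helper: axis order for jnp.transpose that mimics moveaxis."""
    source = [s % ndim for s in source]
    destination = [d % ndim for d in destination]
    # Axes that are not moved keep their relative order.
    order = [ax for ax in range(ndim) if ax not in source]
    # Insert moved axes by increasing destination so earlier inserts stay valid.
    for dest, src in sorted(zip(destination, source)):
        order.insert(dest, src)
    return tuple(order)

a = jnp.ones((2, 3, 4, 5))
perm = moveaxis_perm(a.ndim, (0, 1), (-1, -2))
```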
scico.numpy.multiply(*args: ArrayLike, out: None = None, where: None = None) → Any

Multiply two arrays element-wise.

JAX implementation of numpy.multiply. This is a universal function, and supports the additional APIs described at jax.numpy.ufunc. This function provides the implementation of the * operator for JAX arrays.

Parameters:
  • x – arrays to multiply. Must be broadcastable to a common shape.

  • y – arrays to multiply. Must be broadcastable to a common shape.

Returns:

Array containing the result of the element-wise multiplication.

Examples

Calling multiply explicitly:

>>> x = jnp.arange(4)
>>> jnp.multiply(x, 10)
Array([ 0, 10, 20, 30], dtype=int32)

Calling multiply via the * operator:

>>> x * 10
Array([ 0, 10, 20, 30], dtype=int32)
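Because multiply broadcasts its arguments, an outer product is just the product of a column and a row. A sketch, checked against jax.numpy.outer (not documented in this section):

```python
import jax.numpy as jnp

x = jnp.arange(1, 4)   # [1, 2, 3]
y = jnp.arange(1, 5)   # [1, 2, 3, 4]

# Shapes (3, 1) and (4,) broadcast to (3, 4).
table = jnp.multiply(x[:, None], y)
```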
scico.numpy.nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None)

Replace NaN and infinite entries in an array.

JAX implementation of numpy.nan_to_num.

Parameters:
Return type:

Array

Returns:

A copy of x with the requested substitutions.

Examples

>>> x = jnp.array([0, jnp.nan, 1, jnp.inf, 2, -jnp.inf])

Default substitution values:

>>> jnp.nan_to_num(x)
Array([ 0.0000000e+00,  0.0000000e+00,  1.0000000e+00,  3.4028235e+38,
        2.0000000e+00, -3.4028235e+38], dtype=float32)

Overriding substitutions for -inf and +inf:

>>> jnp.nan_to_num(x, posinf=999, neginf=-999)
Array([   0.,    0.,    1.,  999.,    2., -999.], dtype=float32)

If you only wish to substitute for NaN values while leaving inf values untouched, using where with jax.numpy.isnan is a better option:

>>> jnp.where(jnp.isnan(x), 0, x)
Array([  0.,   0.,   1.,  inf,   2., -inf], dtype=float32)
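The default replacements for infinities in the first example are the largest finite values of the array's dtype, which can be confirmed with jax.numpy.finfo (not documented in this section):

```python
import jax.numpy as jnp

x = jnp.array([jnp.nan, jnp.inf, -jnp.inf], dtype=jnp.float32)
out = jnp.nan_to_num(x)            # nan -> 0.0, +/-inf -> +/- largest finite float32
info = jnp.finfo(jnp.float32)
```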
scico.numpy.nanargmax(a, axis=None, out=None, keepdims=None)

Return the index of the maximum value of an array, ignoring NaNs.

JAX implementation of numpy.nanargmax.

Parameters:
  • a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array

  • axis (int | None) – optional integer specifying the axis along which to find the maximum value. If axis is not specified, a will be flattened.

  • out (None) – unused by JAX

  • keepdims (bool | None) – if True, then return an array with the same number of dimensions as a.

Return type:

Array

Returns:

an array containing the index of the maximum value along the specified axis.

Note

In the case of an axis with all-NaN values, the returned index will be -1. This differs from the behavior of numpy.nanargmax, which raises an error.

Examples

>>> x = jnp.array([1, 3, 5, 4, jnp.nan])

Using a standard argmax leads to potentially unexpected results:

>>> jnp.argmax(x)
Array(4, dtype=int32)

Using nanargmax returns the index of the maximum non-NaN value.

>>> jnp.nanargmax(x)
Array(2, dtype=int32)
>>> x = jnp.array([[1, 3, jnp.nan],
...                [5, 4, jnp.nan]])
>>> jnp.nanargmax(x, axis=1)
Array([1, 0], dtype=int32)
>>> jnp.nanargmax(x, axis=1, keepdims=True)
Array([[1],
       [0]], dtype=int32)
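Per the Note above, an all-NaN slice is reported as -1 rather than raising. Since -1 is also a valid negative index in Python, callers may want to test for it before indexing. A sketch:

```python
import jax.numpy as jnp

x = jnp.array([[1.0, jnp.nan, 3.0],
               [jnp.nan, jnp.nan, jnp.nan]])

idx = jnp.nanargmax(x, axis=1)   # second row is all NaN
valid = idx >= 0                 # False where no non-NaN maximum exists
```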
scico.numpy.nanargmin(a, axis=None, out=None, keepdims=None)

Return the index of the minimum value of an array, ignoring NaNs.

JAX implementation of numpy.nanargmin.

Parameters:
  • a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array

  • axis (int | None) – optional integer specifying the axis along which to find the minimum value. If axis is not specified, a will be flattened.

  • out (None) – unused by JAX

  • keepdims (bool | None) – if True, then return an array with the same number of dimensions as a.

Return type:

Array

Returns:

an array containing the index of the minimum value along the specified axis.

Note

In the case of an axis with all-NaN values, the returned index will be -1. This differs from the behavior of numpy.nanargmin, which raises an error.

Examples

>>> x = jnp.array([jnp.nan, 3, 5, 4, 2])
>>> jnp.nanargmin(x)
Array(4, dtype=int32)
>>> x = jnp.array([[1, 3, jnp.nan],
...                [5, 4, jnp.nan]])
>>> jnp.nanargmin(x, axis=1)
Array([0, 1], dtype=int32)
>>> jnp.nanargmin(x, axis=1, keepdims=True)
Array([[0],
       [1]], dtype=int32)
scico.numpy.nancumprod(a, axis=None, dtype=None, out=None)

Cumulative product of elements along an axis, ignoring NaN values.

JAX implementation of numpy.nancumprod.

Parameters:
  • a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – N-dimensional array to be accumulated.

  • axis (int | None) – integer axis along which to accumulate. If None (default), then array will be flattened and accumulated along the flattened axis.

  • dtype (Union[str, type[Any], dtype, SupportsDType, None]) – optionally specify the dtype of the output. If not specified, then the output dtype will match the input dtype.

  • out (None) – unused by JAX

Return type:

Array

Returns:

An array containing the accumulated product along the given axis.

See also

  • jax.numpy.cumprod: cumulative product without ignoring NaN values.

  • jax.numpy.multiply.accumulate: cumulative product via ufunc methods.

  • jax.numpy.prod: product along axis

Examples

>>> x = jnp.array([[1., 2., jnp.nan],
...                [4., jnp.nan, 6.]])

The standard cumulative product will propagate NaN values:

>>> jnp.cumprod(x)
Array([ 1.,  2., nan, nan, nan, nan], dtype=float32)

nancumprod will ignore NaN values, effectively replacing them with ones:

>>> jnp.nancumprod(x)
Array([ 1.,  2.,  2.,  8.,  8., 48.], dtype=float32)

Cumulative product along axis 1:

>>> jnp.nancumprod(x, axis=1)
Array([[ 1.,  2.,  2.],
       [ 4.,  4., 24.]], dtype=float32)
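The statement that NaNs are effectively replaced with ones can be checked directly against an ordinary cumprod:

```python
import jax.numpy as jnp

x = jnp.array([[1., 2., jnp.nan],
               [4., jnp.nan, 6.]])

direct = jnp.nancumprod(x, axis=1)
# Replace NaN with the multiplicative identity, then take an ordinary cumprod.
manual = jnp.cumprod(jnp.where(jnp.isnan(x), 1.0, x), axis=1)
```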
scico.numpy.nancumsum(a, axis=None, dtype=None, out=None)

Cumulative sum of elements along an axis, ignoring NaN values.

JAX implementation of numpy.nancumsum.

Parameters:
  • a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – N-dimensional array to be accumulated.

  • axis (int | None) – integer axis along which to accumulate. If None (default), then array will be flattened and accumulated along the flattened axis.

  • dtype (Union[str, type[Any], dtype, SupportsDType, None]) – optionally specify the dtype of the output. If not specified, then the output dtype will match the input dtype.

  • out (None) – unused by JAX

Return type:

Array

Returns:

An array containing the accumulated sum along the given axis.

Examples

>>> x = jnp.array([[1., 2., jnp.nan],
...                [4., jnp.nan, 6.]])

The standard cumulative sum will propagate NaN values:

>>> jnp.cumsum(x)
Array([ 1.,  3., nan, nan, nan, nan], dtype=float32)

nancumsum will ignore NaN values, effectively replacing them with zeros:

>>> jnp.nancumsum(x)
Array([ 1.,  3.,  3.,  7.,  7., 13.], dtype=float32)

Cumulative sum along axis 1:

>>> jnp.nancumsum(x, axis=1)
Array([[ 1.,  3.,  3.],
       [ 4.,  4., 10.]], dtype=float32)
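The final entry of the flattened running sum equals the total from jax.numpy.nansum, and the whole sequence matches an ordinary cumsum once NaNs are zeroed. A sketch:

```python
import jax.numpy as jnp

x = jnp.array([[1., 2., jnp.nan],
               [4., jnp.nan, 6.]])

running = jnp.nancumsum(x)   # flattened, NaNs treated as zero
total = jnp.nansum(x)
manual = jnp.cumsum(jnp.where(jnp.isnan(x), 0.0, x))
```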
scico.numpy.nanmax(a, axis=None, out=None, keepdims=False, initial=None, where=None)

Return the maximum of the array elements along a given axis, ignoring NaNs.

JAX implementation of numpy.nanmax.

Parameters:
  • a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – Input array.

  • axis (Union[int, Sequence[int], None]) – int or sequence of ints, default=None. Axis along which the maximum is computed. If None, the maximum is computed along the flattened array.

  • keepdims (bool) – bool, default=False. If True, reduced axes are left in the result with size 1.

  • initial (Union[Array, ndarray, bool, number, bool, int, float, complex, None]) – int or array, default=None. Initial value for the maximum.

  • where (Union[Array, ndarray, bool, number, bool, int, float, complex, None]) – array of boolean dtype, default=None. The elements to be used in the maximum. Array should be broadcast compatible to the input. initial must be specified when where is used.

  • out (None) – Unused by JAX.

Return type:

Array

Returns:

An array of maximum values along the given axis, ignoring NaNs. If all values are NaNs along the given axis, returns nan.

See also

  • jax.numpy.nanmin: Compute the minimum of array elements along a given axis, ignoring NaNs.

  • jax.numpy.nansum: Compute the sum of array elements along a given axis, ignoring NaNs.

  • jax.numpy.nanprod: Compute the product of array elements along a given axis, ignoring NaNs.

  • jax.numpy.nanmean: Compute the mean of array elements along a given axis, ignoring NaNs.

Examples

By default, jnp.nanmax computes the maximum of elements along the flattened array.

>>> nan = jnp.nan
>>> x = jnp.array([[8, nan, 4, 6],
...                [nan, -2, nan, -4],
...                [-2, 1, 7, nan]])
>>> jnp.nanmax(x)
Array(8., dtype=float32)

If axis=1, the maximum will be computed along axis 1.

>>> jnp.nanmax(x, axis=1)
Array([ 8., -2.,  7.], dtype=float32)

If keepdims=True, ndim of the output will be the same as that of the input.

>>> jnp.nanmax(x, axis=1, keepdims=True)
Array([[ 8.],
       [-2.],
       [ 7.]], dtype=float32)

To include only specific elements in computing the maximum, you can use where. It can either have the same shape as the input

>>> where=jnp.array([[0, 0, 1, 0],
...                  [0, 0, 1, 1],
...                  [1, 1, 1, 0]], dtype=bool)
>>> jnp.nanmax(x, axis=1, keepdims=True, initial=0, where=where)
Array([[4.],
       [0.],
       [7.]], dtype=float32)

or be broadcast-compatible with the input.

>>> where = jnp.array([[True],
...                    [False],
...                    [False]])
>>> jnp.nanmax(x, axis=0, keepdims=True, initial=0, where=where)
Array([[8., 0., 4., 6.]], dtype=float32)
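Two edge cases from the Returns entry: an all-NaN slice yields nan, while an ordinary jnp.max after replacing NaNs with -inf agrees with nanmax everywhere except on such slices, where it gives -inf instead. A sketch:

```python
import jax.numpy as jnp

nan = jnp.nan
x = jnp.array([[8., nan, 4., 6.],
               [nan, nan, nan, nan]])

result = jnp.nanmax(x, axis=1)   # second row has no non-NaN values
manual = jnp.max(jnp.where(jnp.isnan(x), -jnp.inf, x), axis=1)
```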
scico.numpy.nanmin(a, axis=None, out=None, keepdims=False, initial=None, where=None)

Return the minimum of the array elements along a given axis, ignoring NaNs.

JAX implementation of numpy.nanmin.

Parameters:
  • a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – Input array.

  • axis (Union[int, Sequence[int], None]) – int or sequence of ints, default=None. Axis along which the minimum is computed. If None, the minimum is computed along the flattened array.

  • keepdims (bool) – bool, default=False. If True, reduced axes are left in the result with size 1.

  • initial (Union[Array, ndarray, bool, number, bool, int, float, complex, None]) – int or array, default=None. Initial value for the minimum.

  • where (Union[Array, ndarray, bool, number, bool, int, float, complex, None]) – array of boolean dtype, default=None. The elements to be used in the minimum. Array should be broadcast compatible to the input. initial must be specified when where is used.

  • out (None) – Unused by JAX.

Return type:

Array

Returns:

An array of minimum values along the given axis, ignoring NaNs. If all values are NaNs along the given axis, returns nan.

See also

  • jax.numpy.nanmax: Compute the maximum of array elements along a given axis, ignoring NaNs.

  • jax.numpy.nansum: Compute the sum of array elements along a given axis, ignoring NaNs.

  • jax.numpy.nanprod: Compute the product of array elements along a given axis, ignoring NaNs.

  • jax.numpy.nanmean: Compute the mean of array elements along a given axis, ignoring NaNs.

Examples

By default, jnp.nanmin computes the minimum of elements along the flattened array.

>>> nan = jnp.nan
>>> x = jnp.array([[1, nan, 4, 5],
...                [nan, -2, nan, -4],
...                [2, 1, 3, nan]])
>>> jnp.nanmin(x)
Array(-4., dtype=float32)

If axis=1, the minimum will be computed along axis 1.

>>> jnp.nanmin(x, axis=1)
Array([ 1., -4.,  1.], dtype=float32)

If keepdims=True, ndim of the output will be the same as that of the input.

>>> jnp.nanmin(x, axis=1, keepdims=True)
Array([[ 1.],
       [-4.],
       [ 1.]], dtype=float32)

To include only specific elements in computing the minimum, you can use where. It can either have the same shape as the input

>>> where=jnp.array([[0, 0, 1, 0],
...                  [0, 0, 1, 1],
...                  [1, 1, 1, 0]], dtype=bool)
>>> jnp.nanmin(x, axis=1, keepdims=True, initial=0, where=where)
Array([[ 0.],
       [-4.],
       [ 0.]], dtype=float32)

or be broadcast-compatible with the input.

>>> where = jnp.array([[False],
...                    [True],
...                    [False]])
>>> jnp.nanmin(x, axis=0, keepdims=True, initial=0, where=where)
Array([[ 0., -2.,  0., -4.]], dtype=float32)
scico.numpy.nanprod(a, axis=None, dtype=None, out=None, keepdims=False, initial=None, where=None)

Return the product of the array elements along a given axis, ignoring NaNs.

JAX implementation of numpy.nanprod.

Parameters:
Return type:

Array

Returns:

An array containing the product of array elements along the given axis, ignoring NaNs. If all elements along the given axis are NaNs, returns 1.

See also

  • jax.numpy.nanmin: Compute the minimum of array elements along a given axis, ignoring NaNs.

  • jax.numpy.nanmax: Compute the maximum of array elements along a given axis, ignoring NaNs.

  • jax.numpy.nansum: Compute the sum of array elements along a given axis, ignoring NaNs.

  • jax.numpy.nanmean: Compute the mean of array elements along a given axis, ignoring NaNs.

Examples

By default, jnp.nanprod computes the product of elements along the flattened array.

>>> nan = jnp.nan
>>> x = jnp.array([[nan, 3, 4, nan],
...                [5, nan, 1, 3],
...                [2, 1, nan, 1]])
>>> jnp.nanprod(x)
Array(360., dtype=float32)

If axis=1, the product will be computed along axis 1.

>>> jnp.nanprod(x, axis=1)
Array([12., 15.,  2.], dtype=float32)

If keepdims=True, ndim of the output will be the same as that of the input.

>>> jnp.nanprod(x, axis=1, keepdims=True)
Array([[12.],
       [15.],
       [ 2.]], dtype=float32)

To include only specific elements in computing the product, you can use where.

>>> where=jnp.array([[1, 0, 1, 0],
...                  [0, 0, 1, 1],
...                  [1, 1, 1, 0]], dtype=bool)
>>> jnp.nanprod(x, axis=1, keepdims=True, where=where)
Array([[4.],
       [3.],
       [2.]], dtype=float32)

If where is False at all elements, jnp.nanprod returns 1 along the given axis.

>>> where = jnp.array([[False],
...                    [False],
...                    [False]])
>>> jnp.nanprod(x, axis=0, keepdims=True, where=where)
Array([[1., 1., 1., 1.]], dtype=float32)
scico.numpy.nansum(a, axis=None, dtype=None, out=None, keepdims=False, initial=None, where=None)

Return the sum of the array elements along a given axis, ignoring NaNs.

JAX implementation of numpy.nansum.

Parameters:
Return type:

Array

Returns:

An array containing the sum of array elements along the given axis, ignoring NaNs. If all elements along the given axis are NaNs, returns 0.

See also

  • jax.numpy.nanmin: Compute the minimum of array elements along a given axis, ignoring NaNs.

  • jax.numpy.nanmax: Compute the maximum of array elements along a given axis, ignoring NaNs.

  • jax.numpy.nanprod: Compute the product of array elements along a given axis, ignoring NaNs.

  • jax.numpy.nanmean: Compute the mean of array elements along a given axis, ignoring NaNs.

Examples

By default, jnp.nansum computes the sum of elements along the flattened array.

>>> nan = jnp.nan
>>> x = jnp.array([[3, nan, 4, 5],
...                [nan, -2, nan, 7],
...                [2, 1, 6, nan]])
>>> jnp.nansum(x)
Array(26., dtype=float32)

If axis=1, the sum will be computed along axis 1.

>>> jnp.nansum(x, axis=1)
Array([12.,  5.,  9.], dtype=float32)

If keepdims=True, ndim of the output will be the same as that of the input.

>>> jnp.nansum(x, axis=1, keepdims=True)
Array([[12.],
       [ 5.],
       [ 9.]], dtype=float32)

To include only specific elements in computing the sum, you can use where.

>>> where=jnp.array([[1, 0, 1, 0],
...                  [0, 0, 1, 1],
...                  [1, 1, 1, 0]], dtype=bool)
>>> jnp.nansum(x, axis=1, keepdims=True, where=where)
Array([[7.],
       [7.],
       [9.]], dtype=float32)

If where is False at all elements, jnp.nansum returns 0 along the given axis.

>>> where = jnp.array([[False],
...                    [False],
...                    [False]])
>>> jnp.nansum(x, axis=0, keepdims=True, where=where)
Array([[0., 0., 0., 0.]], dtype=float32)
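nansum behaves like sum with NaNs replaced by zero, and its where argument composes with that replacement. A sketch reproducing the masked result above:

```python
import jax.numpy as jnp

nan = jnp.nan
x = jnp.array([[3, nan, 4, 5],
               [nan, -2, nan, 7],
               [2, 1, 6, nan]])
mask = jnp.array([[1, 0, 1, 0],
                  [0, 0, 1, 1],
                  [1, 1, 1, 0]], dtype=bool)

direct = jnp.nansum(x, axis=1, where=mask)
# Zero out NaNs, then apply the same mask to an ordinary sum.
manual = jnp.sum(jnp.where(jnp.isnan(x), 0.0, x), axis=1, where=mask)
```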
scico.numpy.ndim(a)

Return the number of dimensions of an array.

JAX implementation of numpy.ndim. Unlike np.ndim, this function raises a TypeError if the input is a collection such as a list or tuple.

Parameters:

a (Union[Array, ndarray, bool, number, bool, int, float, complex, SupportsNdim]) – array-like object, or any object with an ndim attribute.

Return type:

int

Returns:

An integer specifying the number of dimensions of a.

Examples

Number of dimensions for arrays:

>>> x = jnp.arange(10)
>>> jnp.ndim(x)
1
>>> y = jnp.ones((2, 3))
>>> jnp.ndim(y)
2

This also works for scalars:

>>> jnp.ndim(3.14)
0

For arrays, this can also be accessed via the jax.Array.ndim property:

>>> x.ndim
1
scico.numpy.negative(*args: ArrayLike, out: None = None, where: None = None) → Any

Return element-wise negative values of the input.

JAX implementation of numpy.negative.

Parameters:

x – input array or scalar.

Returns:

An array with the same shape and dtype as x, containing -x.

Note

When applied to unsigned integers, jnp.negative produces their two’s complement negation, which typically results in unexpectedly large positive values due to integer wraparound.

Examples

For real-valued inputs:

>>> x = jnp.array([0., -3., 7])
>>> jnp.negative(x)
Array([-0.,  3., -7.], dtype=float32)

For complex inputs:

>>> x1 = jnp.array([1-2j, -3+4j, 5-6j])
>>> jnp.negative(x1)
Array([-1.+2.j,  3.-4.j, -5.+6.j], dtype=complex64)

For uint32:

>>> x2 = jnp.array([5, 0, -7]).astype(jnp.uint32)
>>> x2
Array([         5,          0, 4294967289], dtype=uint32)
>>> jnp.negative(x2)
Array([4294967291,          0,          7], dtype=uint32)
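The wrapped values above follow modular arithmetic: for an unsigned n-bit integer u, the negation is (2**n - u) mod 2**n. A check for uint32 against Python's arbitrary-precision integers:

```python
import jax.numpy as jnp

u = jnp.array([5, 0, 7, 9], dtype=jnp.uint32)
neg = jnp.negative(u)

# Reference computed with Python ints, reduced modulo 2**32.
expected = [(-int(v)) % 2**32 for v in u]
```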
scico.numpy.nextafter(x, y, /)

Return element-wise next floating point value after x towards y.

JAX implementation of numpy.nextafter.

Parameters:
Return type:

Array

Returns:

An array containing the next representable number of x in the direction of y.

Examples

>>> jnp.nextafter(2, 1)  
Array(1.9999999, dtype=float32, weak_type=True)
>>> x = jnp.array([3, -2, 1])
>>> y = jnp.array([2, -1, 2])
>>> jnp.nextafter(x, y)  
Array([ 2.9999998, -1.9999999,  1.0000001], dtype=float32)
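
One way to see what "next representable value" means: stepping from 1.0 towards 2.0 in float32 moves by exactly the float32 machine epsilon. The sketch below assumes float32 inputs (the default without 64-bit mode) and uses NumPy's nextafter only as a cross-check.

```python
import numpy as np
import jax.numpy as jnp

one = jnp.float32(1.0)
two = jnp.float32(2.0)

# Step from 1.0 by one representable float32 value towards 2.0
up = jnp.nextafter(one, two)

# The gap between 1.0 and the next larger float32 is the machine epsilon
gap = up - one
eps = jnp.finfo(jnp.float32).eps

# Stepping towards a smaller target moves downwards instead
down = jnp.nextafter(one, jnp.float32(0.0))
```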
scico.numpy.nonzero(a, *, size=None, fill_value=None)

Return indices of nonzero elements of an array.

JAX implementation of numpy.nonzero.

Because the size of the output of nonzero is data-dependent, the function is not compatible with JIT and other transformations. The JAX version adds the optional size argument which must be specified statically for jnp.nonzero to be used within JAX’s transformations.

Return type:

tuple[Array, ...]

Returns:

Tuple of JAX Arrays of length a.ndim, containing the indices of each nonzero value.

Examples

One-dimensional array returns a length-1 tuple of indices:

>>> x = jnp.array([0, 5, 0, 6, 0, 7])
>>> jnp.nonzero(x)
(Array([1, 3, 5], dtype=int32),)

Two-dimensional array returns a length-2 tuple of indices:

>>> x = jnp.array([[0, 5, 0],
...                [6, 0, 7]])
>>> jnp.nonzero(x)
(Array([0, 1, 1], dtype=int32), Array([1, 0, 2], dtype=int32))

In either case, the resulting tuple of indices can be used directly to extract the nonzero values:

>>> indices = jnp.nonzero(x)
>>> x[indices]
Array([5, 6, 7], dtype=int32)

The output of nonzero has a dynamic shape, because the number of returned indices depends on the contents of the input array. As such, it is incompatible with JIT and other JAX transformations:

>>> x = jnp.array([0, 5, 0, 6, 0, 7])
>>> jax.jit(jnp.nonzero)(x)  
Traceback (most recent call last):
  ...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[].
The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations.

This can be addressed by passing a static size parameter to specify the desired output shape:

>>> nonzero_jit = jax.jit(jnp.nonzero, static_argnames='size')
>>> nonzero_jit(x, size=3)
(Array([1, 3, 5], dtype=int32),)

If size does not match the true size, the result will be either truncated or padded:

>>> nonzero_jit(x, size=2)  # size < 3: indices are truncated
(Array([1, 3], dtype=int32),)
>>> nonzero_jit(x, size=5)  # size > 3: indices are padded with zeros.
(Array([1, 3, 5, 0, 0], dtype=int32),)

You can specify a custom fill value for the padding using the fill_value argument:

>>> nonzero_jit(x, size=5, fill_value=len(x))
(Array([1, 3, 5, 6, 6], dtype=int32),)
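
A common pattern for using nonzero inside jit is to pick a static size, pad with a sentinel fill_value, and mask out the padded entries afterwards. The sketch below does this; padded_nonzero_values is a hypothetical helper, and the sentinel -1 is an assumption that works here because valid indices are never negative.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical helper: first `size` nonzero values of a 1D array, padded with 0.
# `size` is static, so the output shape is known when the function is traced.
@partial(jax.jit, static_argnames="size")
def padded_nonzero_values(x, size):
    # Missing indices are filled with -1 so they can be told apart from real ones
    (idx,) = jnp.nonzero(x, size=size, fill_value=-1)
    valid = idx >= 0
    # x[-1] would read the last element, so padded slots are replaced by 0
    return jnp.where(valid, x[idx], 0), valid


x = jnp.array([0, 5, 0, 6, 0, 7])
vals, valid = padded_nonzero_values(x, size=5)
```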
scico.numpy.not_equal(x, y, /)

Returns element-wise truth value of x != y.

JAX implementation of numpy.not_equal. This function provides the implementation of the != operator for JAX arrays.

Return type:

Array

Returns:

A boolean array containing True where the elements of x != y and False otherwise.

Examples

>>> jnp.not_equal(0., -0.)
Array(False, dtype=bool, weak_type=True)
>>> jnp.not_equal(-2, 2)
Array(True, dtype=bool, weak_type=True)
>>> jnp.not_equal(1, 1.)
Array(False, dtype=bool, weak_type=True)
>>> jnp.not_equal(5, jnp.array(5))
Array(False, dtype=bool, weak_type=True)
>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6],
...                [7, 8, 9]])
>>> y = jnp.array([1, 5, 9])
>>> jnp.not_equal(x, y)
Array([[False,  True,  True],
       [ True, False,  True],
       [ True,  True, False]], dtype=bool)
>>> x != y
Array([[False,  True,  True],
       [ True, False,  True],
       [ True,  True, False]], dtype=bool)
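
Because the comparison follows IEEE floating point semantics, NaN compares unequal to everything, including itself, while 0. and -0. compare equal. The sketch below, an illustration rather than part of the API, checks that not_equal agrees with the != operator and with the logical complement of jnp.equal.

```python
import jax.numpy as jnp

x = jnp.array([1.0, jnp.nan, 0.0])
y = jnp.array([1.0, jnp.nan, -0.0])

ne = jnp.not_equal(x, y)
# The != operator on JAX arrays dispatches to jnp.not_equal
ne_op = x != y
# not_equal is the element-wise logical complement of equal
complement = ~jnp.equal(x, y)
```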
scico.numpy.ones(shape, dtype=None, *, device=None, out_sharding=None)

Create an array full of ones.

JAX implementation of numpy.ones.

Parameters:
  • shape (Any) – int or sequence of ints specifying the shape of the created array.

  • dtype (Union[str, type[Any], dtype, SupportsDType, None]) – optional dtype for the created array; defaults to float32 or float64 depending on the X64 configuration (see Default dtypes and the X64 flag).

  • device (Device | Sharding | None) – (optional) Device or Sharding to which the created array will be committed. This argument exists for compatibility with the Python Array API standard.

  • out_sharding (NamedSharding | P | None) – (optional) PartitionSpec or NamedSharding representing the sharding of the created array (see explicit sharding for more details). This argument exists for consistency with other array creation routines across JAX. Specifying both out_sharding and device will result in an error.

Return type:

Array

Returns:

Array of the specified shape and dtype, with the given device/sharding if specified.

Examples

>>> jnp.ones(4)
Array([1., 1., 1., 1.], dtype=float32)
>>> jnp.ones((2, 3), dtype=bool)
Array([[ True,  True,  True],
       [ True,  True,  True]], dtype=bool)
scico.numpy.ones_like(a, dtype=None, shape=None, *, device=None, out_sharding=None)

Create an array of ones with the same shape and dtype as an array.

JAX implementation of numpy.ones_like.

Return type:

Array

Returns:

Array of the specified shape and dtype, on the given device if one is specified.

Examples

>>> x = jnp.arange(4)
>>> jnp.ones_like(x)
Array([1, 1, 1, 1], dtype=int32)
>>> jnp.ones_like(x, dtype=bool)
Array([ True,  True,  True,  True], dtype=bool)
>>> jnp.ones_like(x, shape=(2, 3))
Array([[1, 1, 1],
       [1, 1, 1]], dtype=int32)
scico.numpy.outer(a, b, out=None)

Compute the outer product of two arrays.

JAX implementation of numpy.outer.

Return type:

Array

Returns:

The outer product of the inputs a and b. Returned array will be of shape (a.size, b.size).

Examples

>>> a = jnp.array([1, 2, 3])
>>> b = jnp.array([4, 5, 6])
>>> jnp.outer(a, b)
Array([[ 4,  5,  6],
       [ 8, 10, 12],
       [12, 15, 18]], dtype=int32)
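
The stated output shape (a.size, b.size) reflects that the inputs are flattened before the product is formed. A minimal sketch, assuming the NumPy-compatible flattening behavior described above, compares jnp.outer on a 2-dimensional input against explicit broadcasting of the raveled arrays:

```python
import jax.numpy as jnp

a = jnp.array([[1, 2], [3, 4]])  # shape (2, 2), size 4
b = jnp.array([10, 20, 30])      # shape (3,), size 3

# Inputs are flattened first, so the result has shape (a.size, b.size)
res = jnp.outer(a, b)

# Equivalent broadcasting formulation on the flattened inputs
ref = a.ravel()[:, None] * b.ravel()[None, :]
```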
scico.numpy.pad(array, pad_width, mode='constant', **kwargs)

Add padding to an array.

JAX implementation of numpy.pad.

Parameters:
  • array (Union[Array, ndarray, bool, number, bool, int, float, complex]) – array to pad.

  • pad_width (Union[int, Array, ndarray, Sequence[int | Array | ndarray], Sequence[Sequence[int | Array | ndarray]]]) –

    specify the pad width for each dimension of an array. Padding widths may be separately specified for before and after the array. Options are:

    • int or (int,): pad each array dimension with the same number of values both before and after.

    • (before, after): pad each array dimension with before elements before and after elements after.

    • ((before_1, after_1), (before_2, after_2), ... (before_N, after_N)): specify distinct before and after values for each array dimension.

  • mode (str | Callable[..., Any]) –

    a string or callable. Supported pad modes are:

    • 'constant' (default): pad with a constant value, which defaults to zero.

    • 'empty': pad with empty values (i.e. zero)

    • 'edge': pad with the edge values of the array.

    • 'wrap': pad by wrapping the array.

    • 'linear_ramp': pad with a linear ramp to specified end_values.

    • 'maximum': pad with the maximum value.

    • 'mean': pad with the mean value.

    • 'median': pad with the median value.

    • 'minimum': pad with the minimum value.

    • 'reflect': pad by reflection.

    • 'symmetric': pad by symmetric reflection.

    • <callable>: a callable function. See Notes below.

  • constant_values – referenced for mode = 'constant'. Specify the constant value to pad with.

  • stat_length – referenced for mode in ['maximum', 'mean', 'median', 'minimum']. An integer or tuple specifying the number of edge values to use when calculating the statistic.

  • end_values – referenced for mode = 'linear_ramp'. Specify the end values to ramp the padding values to.

  • reflect_type – referenced for mode in ['reflect', 'symmetric']. Specify whether to use even or odd reflection.

Return type:

Array

Returns:

A padded copy of array.

Notes

When mode is callable, it should have the following signature:

def pad_func(row: Array, pad_width: tuple[int, int],
             iaxis: int, kwargs: dict) -> Array:
  ...

Here row is a 1D slice of the padded array along axis iaxis, with the pad values filled with zeros. pad_width is a tuple specifying the (before, after) padding sizes, and kwargs are any additional keyword arguments passed to the jax.numpy.pad function.

Note that whereas in NumPy the function should modify row in-place, in JAX it should return the modified row. In JAX, the custom padding function will be mapped across the padded axis using the jax.vmap transformation.

Examples

Pad a 1-dimensional array with zeros:

>>> x = jnp.array([10, 20, 30, 40])
>>> jnp.pad(x, 2)
Array([ 0,  0, 10, 20, 30, 40,  0,  0], dtype=int32)
>>> jnp.pad(x, (2, 4))
Array([ 0,  0, 10, 20, 30, 40,  0,  0,  0,  0], dtype=int32)

Pad a 1-dimensional array with specified values:

>>> jnp.pad(x, 2, constant_values=99)
Array([99, 99, 10, 20, 30, 40, 99, 99], dtype=int32)

Pad a 1-dimensional array with the mean array value:

>>> jnp.pad(x, 2, mode='mean')
Array([25, 25, 10, 20, 30, 40, 25, 25], dtype=int32)

Pad a 1-dimensional array with reflected values:

>>> jnp.pad(x, 2, mode='reflect')
Array([30, 20, 10, 20, 30, 40, 30, 20], dtype=int32)

Pad a 2-dimensional array with different paddings in each dimension:

>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> jnp.pad(x, ((1, 2), (3, 0)))
Array([[0, 0, 0, 0, 0, 0],
       [0, 0, 0, 1, 2, 3],
       [0, 0, 0, 4, 5, 6],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0]], dtype=int32)

Pad a 1-dimensional array with a custom padding function:

>>> def custom_pad(row, pad_width, iaxis, kwargs):
...   # row represents a 1D slice of the zero-padded array.
...   before, after = pad_width
...   before_value = kwargs.get('before_value', 0)
...   after_value = kwargs.get('after_value', 0)
...   row = row.at[:before].set(before_value)
...   return row.at[len(row) - after:].set(after_value)
>>> x = jnp.array([2, 3, 4])
>>> jnp.pad(x, 2, custom_pad, before_value=-10, after_value=10)
Array([-10, -10,   2,   3,   4,  10,  10], dtype=int32)
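
Since jnp.pad implements numpy.pad, a quick way to check the behavior of the different modes is to compare both on the same input. The sketch below does this for modes that need no extra keyword arguments; the mode list is illustrative, and 'empty' is omitted because NumPy leaves those values uninitialized.

```python
import numpy as np
import jax.numpy as jnp

x_np = np.array([10, 20, 30, 40], dtype=np.int32)
x = jnp.asarray(x_np)

# Modes that need no extra keyword arguments
modes = ["constant", "edge", "wrap", "reflect", "symmetric", "maximum", "minimum"]

results = {}
for mode in modes:
    padded = jnp.pad(x, 2, mode=mode)
    expected = np.pad(x_np, 2, mode=mode)
    # Record whether JAX and NumPy agree for this mode
    results[mode] = bool(np.array_equal(np.asarray(padded), expected))
```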
scico.numpy.partition(a, kth, axis=-1)

Returns a partially-sorted copy of an array.

JAX implementation of numpy.partition. The JAX version differs from NumPy in the treatment of NaN entries: NaNs which have the sign bit set are sorted to the beginning of the array.

Parameters:
  • a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – array to be partitioned.

  • kth (int) – static integer index about which to partition the array.

  • axis (int) – static integer axis along which to partition the array; default is -1.

Return type:

Array

Returns:

A copy of a partitioned at the kth value along axis. The entries before kth are values no greater than take(a, kth, axis), and the entries after kth are values no smaller than take(a, kth, axis).

Note

The JAX version requires the kth argument to be a static integer rather than a general array. This is implemented via two calls to jax.lax.top_k. If you’re only accessing the top or bottom k values of the output, it may be more efficient to call jax.lax.top_k directly.

Examples

>>> x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3])
>>> kth = 4
>>> x_partitioned = jnp.partition(x, kth)
>>> x_partitioned
Array([1, 2, 3, 3, 4, 9, 8, 7, 6, 5], dtype=int32)

The result is a partially-sorted copy of the input. All values before kth are smaller than the pivot value, and all values after kth are larger than the pivot value:

>>> smallest_values = x_partitioned[:kth]
>>> pivot_value = x_partitioned[kth]
>>> largest_values = x_partitioned[kth + 1:]
>>> print(smallest_values, pivot_value, largest_values)
[1 2 3 3] 4 [9 8 7 6 5]

Notice that among smallest_values and largest_values, the returned order is arbitrary and implementation-dependent.
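
The sketch below checks the partition property directly and then, following the note on jax.lax.top_k, extracts the k smallest values without a full partition. Negating the input so that top_k's "largest" becomes "smallest" is an assumption of this sketch; it suits floating point and signed integer inputs (barring overflow at the minimum signed value), not unsigned types.

```python
import jax
import jax.numpy as jnp

x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3])
kth = 4

part = jnp.partition(x, kth)
pivot = part[kth]

# Partition property: nothing before kth exceeds the pivot,
# and nothing after kth is below it
left_ok = bool((part[:kth] <= pivot).all())
right_ok = bool((part[kth + 1:] >= pivot).all())

# k smallest values via top_k on the negated array (assumption: signed/float input);
# top_k returns values in descending order, so negating gives ascending order
k = 4
neg_top, _ = jax.lax.top_k(-x, k)
smallest_k = -neg_top
```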

scico.numpy.permute_dims(a, /, axes)

Permute the axes/dimensions of an array.

JAX implementation of array_api.permute_dims.

Return type:

Array

Returns:

A copy of a with axes permuted.

Examples

>>> a = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> jnp.permute_dims(a, (1, 0))
Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)
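
permute_dims follows the same convention as jnp.transpose with an explicit axes argument: axis i of the result is axis axes[i] of the input. A short sketch on a 3-dimensional array:

```python
import jax.numpy as jnp

a = jnp.arange(24).reshape(2, 3, 4)

# Move the last axis to the front: result axis i is input axis axes[i]
b = jnp.permute_dims(a, (2, 0, 1))

# jnp.transpose with the same axes gives the same array
c = jnp.transpose(a, (2, 0, 1))
```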
scico.numpy.piecewise(x, condlist, funclist, *args, **kw)

Evaluate a function defined piecewise across the domain.

JAX implementation of numpy.piecewise, in terms of jax.lax.switch.

Note

Unlike numpy.piecewise, jax.numpy.piecewise requires functions in funclist to be traceable by JAX, as it is implemented via jax.lax.switch.

Parameters:
  • x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – array of input values.

  • condlist (Array | Sequence[Union[Array, ndarray, bool, number, bool, int, float, complex]]) – boolean array or sequence of boolean arrays corresponding to the functions in funclist. If a sequence of arrays, the length of each array must match the length of x.

  • funclist (list[Union[Array, ndarray, bool, number, bool, int, float, complex, Callable[..., Array]]]) – list of arrays or functions; must either be the same length as condlist, or have length len(condlist) + 1, in which case the last entry is the default applied when none of the conditions are True. Alternatively, entries of funclist may be numerical values, in which case they indicate a constant function.

  • args – additional arguments are passed to each function in funclist.

  • kwargs – additional arguments are passed to each function in funclist.

Return type:

Array

Returns:

An array which is the result of evaluating the functions on x at the specified conditions.

Examples

Here’s an example of a function which is zero for negative values, and linear for positive values:

>>> x = jnp.array([-4, -3, -2, -1, 0, 1, 2, 3, 4])
>>> condlist = [x < 0, x >= 0]
>>> funclist = [lambda x: 0 * x, lambda x: x]
>>> jnp.piecewise(x, condlist, funclist)
Array([0, 0, 0, 0, 0, 1, 2, 3, 4], dtype=int32)

funclist can also contain a simple scalar value for constant functions:

>>> condlist = [x < 0, x >= 0]
>>> funclist = [0, lambda x: x]
>>> jnp.piecewise(x, condlist, funclist)
Array([0, 0, 0, 0, 0, 1, 2, 3, 4], dtype=int32)

You can specify a default value by appending an extra entry to funclist:

>>> condlist = [x < -1, x > 1]
>>> funclist = [lambda x: 1 + x, lambda x: x - 1, 0]
>>> jnp.piecewise(x, condlist, funclist)
Array([-3, -2, -1,  0,  0,  0,  1,  2,  3], dtype=int32)

condlist may also be a simple array of scalar conditions, in which case the associated function applies to the whole range:

>>> condlist = jnp.array([False, True, False])
>>> funclist = [lambda x: x * 0, lambda x: x * 10, lambda x: x * 100]
>>> jnp.piecewise(x, condlist, funclist)
Array([-40, -30, -20, -10,   0,  10,  20,  30,  40], dtype=int32)
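
Because the functions in funclist must be traceable, a piecewise definition can be compiled with jax.jit. The sketch below jit-compiles the default-value example above and compares it with jnp.select, which expresses the same function by evaluating every branch on all of x; using jnp.select as the comparison is an illustrative choice.

```python
import jax
import jax.numpy as jnp

x = jnp.array([-4, -3, -2, -1, 0, 1, 2, 3, 4])


def f(x):
    # Two conditions plus a trailing default value of 0
    return jnp.piecewise(x, [x < -1, x > 1], [lambda v: 1 + v, lambda v: v - 1, 0])


# piecewise is traceable, so the whole function can be jit-compiled
f_jit = jax.jit(f)

# Same function via jnp.select, which evaluates both branches everywhere
g = jnp.select([x < -1, x > 1], [1 + x, x - 1], default=0)
```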
scico.numpy.place(arr, mask, vals, *, inplace=True)

Update array elements based on a mask.

JAX implementation of numpy.place.

The semantics of numpy.place are to modify arrays in-place, which is not possible for JAX's immutable arrays. The JAX version returns a modified copy of the input, and adds the inplace parameter, which must be set to False by the user as a reminder of this API difference.

Return type:

Array

Returns:

A copy of arr with masked values set to entries from vals.

See also

  • jax.numpy.put: put elements into an array at numerical indices.

  • jax.numpy.ndarray.at: array updates using NumPy-style indexing

Examples

>>> x = jnp.zeros((3, 5), dtype=int)
>>> mask = (jnp.arange(x.size) % 3 == 0).reshape(x.shape)
>>> mask
Array([[ True, False, False,  True, False],
       [False,  True, False, False,  True],
       [False, False,  True, False, False]], dtype=bool)

Placing a scalar value:

>>> jnp.place(x, mask, 1, inplace=False)
Array([[1, 0, 0, 1, 0],
       [0, 1, 0, 0, 1],
       [0, 0, 1, 0, 0]], dtype=int32)

In this case, jnp.place is similar to the masked array update syntax:

>>> x.at[mask].set(1)
Array([[1, 0, 0, 1, 0],
       [0, 1, 0, 0, 1],
       [0, 0, 1, 0, 0]], dtype=int32)

place differs when placing values from an array. The array is repeated to fill the masked entries:

>>> vals = jnp.array([1, 3, 5])
>>> jnp.place(x, mask, vals, inplace=False)
Array([[1, 0, 0, 3, 0],
       [0, 5, 0, 0, 1],
       [0, 0, 3, 0, 0]], dtype=int32)
scico.numpy.polydiv(u, v, *, trim_leading_zeros=False)

Returns the quotient and remainder of polynomial division.

JAX implementation of numpy.polydiv.

Parameters:
  • u (Union[Array, ndarray, bool, number, bool, int, float, complex]) – Array of dividend polynomial coefficients.

  • v (Union[Array, ndarray, bool, number, bool, int, float, complex]) – Array of divisor polynomial coefficients.

  • trim_leading_zeros (bool) – Default is False. If True, leading zeros are removed from the return value to match the result of NumPy, but this prevents the function from being used in compiled code. Due to differences in accumulation of floating point arithmetic errors, the cutoff for values to be considered zero may lead to inconsistent results between NumPy and JAX, and even between different JAX backends. The result may have inconsistent output shapes when trim_leading_zeros=True.

Return type:

tuple[Array, Array]

Returns:

A tuple of quotient and remainder arrays. The dtype of the output is always promoted to inexact.

Note

jax.numpy.polydiv only accepts arrays as input, unlike numpy.polydiv, which also accepts scalar inputs.

Examples

>>> x1 = jnp.array([5, 7, 9])
>>> x2 = jnp.array([4, 1])
>>> np.polydiv(x1, x2)
(array([1.25  , 1.4375]), array([7.5625]))
>>> jnp.polydiv(x1, x2)
(Array([1.25  , 1.4375], dtype=float32), Array([0.    , 0.    , 7.5625], dtype=float32))

If trim_leading_zeros=True, the result matches that of np.polydiv:

>>> jnp.polydiv(x1, x2, trim_leading_zeros=True)
(Array([1.25  , 1.4375], dtype=float32), Array([7.5625], dtype=float32))
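
The quotient and remainder satisfy u = q * v + r as polynomials. A minimal check of this identity with jnp.polymul and jnp.polyadd, using float versions of the inputs above:

```python
import jax.numpy as jnp

u = jnp.array([5.0, 7.0, 9.0])  # dividend: 5x^2 + 7x + 9
v = jnp.array([4.0, 1.0])       # divisor:  4x + 1

q, r = jnp.polydiv(u, v)

# Reconstruct the dividend from quotient, divisor and (untrimmed) remainder
reconstructed = jnp.polyadd(jnp.polymul(q, v), r)
```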
scico.numpy.polymul(a1, a2, *, trim_leading_zeros=False)

Returns the product of two polynomials.

JAX implementation of numpy.polymul.

Parameters:
  • a1 (Union[Array, ndarray, bool, number, bool, int, float, complex]) – 1D array of polynomial coefficients.

  • a2 (Union[Array, ndarray, bool, number, bool, int, float, complex]) – 1D array of polynomial coefficients.

  • trim_leading_zeros (bool) – Default is False. If True, leading zeros are removed from the return value to match the result of NumPy, but this prevents the function from being used in compiled code. Due to differences in accumulation of floating point arithmetic errors, the cutoff for values to be considered zero may lead to inconsistent results between NumPy and JAX, and even between different JAX backends. The result may have inconsistent output shapes when trim_leading_zeros=True.

Return type:

Array

Returns:

An array of the coefficients of the product of the two polynomials. The dtype of the output is always promoted to inexact.

Note

jax.numpy.polymul only accepts arrays as input, unlike numpy.polymul, which also accepts scalar inputs.

Examples

>>> x1 = np.array([2, 1, 0])
>>> x2 = np.array([0, 5, 0, 3])
>>> np.polymul(x1, x2)
array([10,  5,  6,  3,  0])
>>> jnp.polymul(x1, x2)
Array([ 0., 10.,  5.,  6.,  3.,  0.], dtype=float32)

If trim_leading_zeros=True, the result matches that of np.polymul:

>>> jnp.polymul(x1, x2, trim_leading_zeros=True)
Array([10.,  5.,  6.,  3.,  0.], dtype=float32)

For input arrays of dtype complex:

>>> x3 = np.array([2., 1+2j, 1-2j])
>>> x4 = np.array([0, 5, 0, 3])
>>> np.polymul(x3, x4)
array([10. +0.j,  5.+10.j, 11.-10.j,  3. +6.j,  3. -6.j])
>>> jnp.polymul(x3, x4)
Array([ 0. +0.j, 10. +0.j,  5.+10.j, 11.-10.j,  3. +6.j,  3. -6.j],
      dtype=complex64)
>>> jnp.polymul(x3, x4, trim_leading_zeros=True)
Array([10. +0.j,  5.+10.j, 11.-10.j,  3. +6.j,  3. -6.j], dtype=complex64)
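
Multiplying polynomials amounts to a full discrete convolution of their coefficient arrays, which also explains the length len(a1) + len(a2) - 1 of the untrimmed output. The sketch below compares jnp.polymul with jnp.convolve; it is a consistency check, not a statement about how jnp.polymul is implemented.

```python
import jax.numpy as jnp

a1 = jnp.array([2.0, 1.0, 0.0])
a2 = jnp.array([0.0, 5.0, 0.0, 3.0])

product = jnp.polymul(a1, a2)

# Polynomial product coefficients equal the full convolution of the inputs
conv = jnp.convolve(a1, a2, mode="full")
```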
scico.numpy.positive(x, /)

Return element-wise positive values of the input.

JAX implementation of numpy.positive.

Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array or scalar.

Return type:

Array

Returns:

An array of same shape and dtype as x containing +x.

Note

jnp.positive is equivalent to x.copy() and is defined only for the types that support arithmetic operations.

See also

  • jax.numpy.negative: Returns element-wise negative values of the input.

  • jax.numpy.sign: Returns element-wise indication of sign of the input.

Examples

For real-valued inputs:

>>> x = jnp.array([-5, 4, 7., -9.5])
>>> jnp.positive(x)
Array([-5. ,  4. ,  7. , -9.5], dtype=float32)
>>> x.copy()
Array([-5. ,  4. ,  7. , -9.5], dtype=float32)

For complex inputs:

>>> x1 = jnp.array([1-2j, -3+4j, 5-6j])
>>> jnp.positive(x1)
Array([ 1.-2.j, -3.+4.j,  5.-6.j], dtype=complex64)
>>> x1.copy()
Array([ 1.-2.j, -3.+4.j,  5.-6.j], dtype=complex64)

For uint32:

>>> x2 = jnp.array([6, 0, -4]).astype(jnp.uint32)
>>> x2
Array([         6,          0, 4294967292], dtype=uint32)
>>> jnp.positive(x2)
Array([         6,          0, 4294967292], dtype=uint32)
scico.numpy.pow(x1, x2, /)

Alias of jax.numpy.power.

Return type:

Array

scico.numpy.power(x1, x2, /)

Calculate element-wise base x1 exponential of x2, i.e. x1 raised to the power x2.

JAX implementation of numpy.power.

Return type:

Array

Returns:

An array containing the base x1 exponentials of x2, with dtype determined by type promotion of the inputs.

Note

  • When x2 is a concrete integer scalar, jnp.power lowers to jax.lax.integer_pow.

  • When x2 is a traced scalar or an array, jnp.power lowers to jax.lax.pow.

  • jnp.power raises a TypeError for integer type raised to a concrete negative integer power. For a non-concrete power, the operation is invalid and the returned value is implementation-defined.

  • jnp.power returns nan for negative values raised to non-integer powers.

See also

  • jax.lax.pow: Computes element-wise power, \(x^y\).

  • jax.lax.integer_pow: Computes element-wise power \(x^y\), where \(y\) is a fixed integer.

  • jax.numpy.float_power: Computes the first array raised to the power of second array, element-wise, by promoting to the inexact dtype.

  • jax.numpy.pow: Computes the first array raised to the power of second array, element-wise.

Examples

Inputs with scalar integers:

>>> jnp.power(4, 3)
Array(64, dtype=int32, weak_type=True)

Inputs with same shape:

>>> x1 = jnp.array([2, 4, 5])
>>> x2 = jnp.array([3, 0.5, 2])
>>> jnp.power(x1, x2)
Array([ 8.,  2., 25.], dtype=float32)

Inputs with broadcast compatibility:

>>> x3 = jnp.array([-2, 3, 1])
>>> x4 = jnp.array([[4, 1, 6],
...                 [1.3, 3, 5]])
>>> jnp.power(x3, x4)
Array([[16.,  3.,  1.],
       [nan, 27.,  1.]], dtype=float32)
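
The notes above describe which lax primitive jnp.power lowers to. The sketch below checks that the numerical results agree with calling jax.lax.integer_pow and jax.lax.pow directly, and illustrates the NaN returned for a negative base with a non-integer exponent.

```python
import jax
import jax.numpy as jnp

x = jnp.array([1.5, -2.0, 3.0])

# Concrete integer exponent: same values as lax.integer_pow
p_int = jnp.power(x, 3)
ref_int = jax.lax.integer_pow(x, 3)

# Array exponent: same values as lax.pow
y = jnp.array([2.0, 2.0, 0.5])
p_arr = jnp.power(x, y)
ref_arr = jax.lax.pow(x, y)

# Negative base with a non-integer exponent gives nan
nan_case = jnp.power(-2.0, 0.5)
```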
scico.numpy.printoptions(*args, **kwargs)

Alias of numpy.printoptions.

JAX arrays are printed via NumPy, so NumPy’s printoptions configurations will apply to printed JAX arrays.

See the numpy.set_printoptions documentation for details on the available options and their meanings.
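
Because printoptions is a context manager, it can be used to change the formatting of printed JAX arrays temporarily. A minimal sketch, assuming NumPy's default precision of 8 is in effect outside the with block:

```python
import jax.numpy as jnp

x = jnp.array([1.23456789, 2.5])

# Reduce printed precision only inside the with block
with jnp.printoptions(precision=2):
    short = str(x)

# Outside the block the previous settings apply again
default = str(x)
```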

scico.numpy.prod(a, axis=None, dtype=None, out=None, keepdims=False, initial=None, where=None, promote_integers=True)

Return product of the array elements over a given axis.

JAX implementation of numpy.prod.

Parameters:
  • a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – Input array.

  • axis (Union[int, Sequence[int], None]) – int or sequence of ints, default=None. Axis or axes along which the product is computed. If None, the product is computed along all the axes.

  • dtype (Union[str, type[Any], dtype, SupportsDType, None]) – The type of the output array. default=None.

  • keepdims (bool) – bool, default=False. If True, reduced axes are left in the result with size 1.

  • initial (Union[Array, ndarray, bool, number, bool, int, float, complex, None]) – int or array, default=None. Initial value for the product.

  • where (Union[Array, ndarray, bool, number, bool, int, float, complex, None]) – bool array, default=None. The elements to include in the product. The array must be broadcast-compatible with the input.

  • promote_integers (bool) – bool, default=True. If True, then integer inputs will be promoted to the widest available integer dtype, following NumPy's behavior. If False, the result will have the same dtype as the input. promote_integers is ignored if dtype is specified.

  • out (None) – Unused by JAX.

Return type:

Array

Returns:

An array of the product along the given axis.

See also

  • jax.numpy.sum: Compute the sum of array elements over a given axis.

  • jax.numpy.max: Compute the maximum of array elements over given axis.

  • jax.numpy.min: Compute the minimum of array elements over given axis.

Examples

By default, jnp.prod computes along all the axes.

>>> x = jnp.array([[1, 3, 4, 2],
...                [5, 2, 1, 3],
...                [2, 1, 3, 1]])
>>> jnp.prod(x)
Array(4320, dtype=int32)

If axis=1, product is computed along axis 1.

>>> jnp.prod(x, axis=1)
Array([24, 30,  6], dtype=int32)

If keepdims=True, ndim of the output is equal to that of the input.

>>> jnp.prod(x, axis=1, keepdims=True)
Array([[24],
       [30],
       [ 6]], dtype=int32)

To include only specific elements in the product, you can use the where argument.

>>> where=jnp.array([[1, 0, 1, 0],
...                  [0, 0, 1, 1],
...                  [1, 1, 1, 0]], dtype=bool)
>>> jnp.prod(x, axis=1, keepdims=True, where=where)
Array([[4],
       [3],
       [6]], dtype=int32)
>>> where = jnp.array([[False],
...                    [False],
...                    [False]])
>>> jnp.prod(x, axis=1, keepdims=True, where=where)
Array([[1],
       [1],
       [1]], dtype=int32)
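
The effect of promote_integers is easiest to see with a narrow integer type. In the sketch below, the product 100 * 2 * 3 = 600 does not fit in int8; the exact promoted dtype (int32 or int64) depends on whether 64-bit mode is enabled.

```python
import jax.numpy as jnp

x = jnp.array([100, 2, 3], dtype=jnp.int8)

# Default: integer input is promoted to the widest available integer dtype
wide = jnp.prod(x)

# promote_integers=False keeps int8, so the product can overflow
narrow = jnp.prod(x, promote_integers=False)
```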
scico.numpy.promote_types(a, b)

Returns the type to which a binary operation should cast its arguments.

JAX implementation of numpy.promote_types. For details of JAX’s type promotion semantics, see Type promotion semantics.

Return type:

dtype

Returns:

A numpy.dtype object.

Examples

Type specifiers may be strings, dtypes, or scalar types, and the return value is always a dtype:

>>> jnp.promote_types('int32', 'float32')  # strings
dtype('float32')
>>> jnp.promote_types(jnp.dtype('int32'), jnp.dtype('float32'))  # dtypes
dtype('float32')
>>> jnp.promote_types(jnp.int32, jnp.float32)  # scalar types
dtype('float32')

Built-in scalar types (int, float, or complex) are treated as weakly-typed and will not change the bit width of a strongly-typed counterpart (see discussion in Type promotion semantics):

>>> jnp.promote_types('uint8', int)
dtype('uint8')
>>> jnp.promote_types('float16', float)
dtype('float16')

This differs from the NumPy version of this function, which treats built-in scalar types as equivalent to 64-bit types:

>>> import numpy
>>> numpy.promote_types('uint8', int)
dtype('int64')
>>> numpy.promote_types('float16', float)
dtype('float64')
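
The weak-typing rule also governs ordinary arithmetic between JAX arrays and Python scalars. The sketch below shows that a float16 array multiplied by a Python float keeps the dtype that promote_types reports for ('float16', float):

```python
import jax.numpy as jnp

# A Python float is weakly typed, so the float16 array keeps its dtype
h = jnp.ones(3, dtype=jnp.float16) * 2.0

# promote_types reports the same result dtype
expected = jnp.promote_types("float16", float)

# Two strongly-typed dtypes promote as usual
mixed = jnp.promote_types(jnp.int32, jnp.float32)
```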
scico.numpy.ptp(a, axis=None, out=None, keepdims=False)

Return the peak-to-peak range along a given axis.

JAX implementation of numpy.ptp.

Parameters:
  • a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array.

  • axis (Union[int, Sequence[int], None]) – optional, int or sequence of ints, default=None. Axis along which the range is computed. If None, the range is computed on the flattened array.

  • keepdims (bool) – bool, default=False. If True, reduced axes are left in the result with size 1.

  • out (None) – Unused by JAX.

Return type:

Array

Returns:

An array with the range of elements along specified axis of input.

Examples

By default, jnp.ptp computes the range along all axes.

>>> x = jnp.array([[1, 3, 5, 2],
...                [4, 6, 8, 1],
...                [7, 9, 3, 4]])
>>> jnp.ptp(x)
Array(8, dtype=int32)

If axis=1, computes the range along axis 1.

>>> jnp.ptp(x, axis=1)
Array([4, 7, 6], dtype=int32)

To preserve the dimensions of input, you can set keepdims=True.

>>> jnp.ptp(x, axis=1, keepdims=True)
Array([[4],
       [7],
       [6]], dtype=int32)
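
The peak-to-peak range is the maximum minus the minimum along the chosen axis. A short sketch comparing the two formulations along axis 0 of the example array:

```python
import jax.numpy as jnp

x = jnp.array([[1, 3, 5, 2],
               [4, 6, 8, 1],
               [7, 9, 3, 4]])

# Range of each column
r = jnp.ptp(x, axis=0)

# Same quantity written as max minus min
ref = jnp.max(x, axis=0) - jnp.min(x, axis=0)
```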
scico.numpy.put(a, ind, v, mode=None, *, inplace=True)

Put elements into an array at given indices.

JAX implementation of numpy.put.

The semantics of numpy.put are to modify arrays in-place, which is not possible for JAX's immutable arrays. The JAX version returns a modified copy of the input, and adds the inplace parameter, which must be set to False by the user as a reminder of this API difference.

Parameters:
  • a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – array into which values will be placed.

  • ind (Union[Array, ndarray, bool, number, bool, int, float, complex]) – array of indices over the flattened array at which to put values.

  • v (Union[Array, ndarray, bool, number, bool, int, float, complex]) – array of values to put into the array.

  • mode (str | None) –

    string specifying how to handle out-of-bound indices. Supported values:

    • "clip" (default): clip out-of-bound indices to the final index.

    • "wrap": wrap out-of-bound indices to the beginning of the array.

  • inplace (bool) – must be set to False to indicate that the input is not modified in-place, but rather a modified copy is returned.

Return type:

Array

Returns:

A copy of a with specified entries updated.

See also

  • jax.numpy.place: place elements into an array via boolean mask.

  • jax.numpy.ndarray.at: array updates using NumPy-style indexing.

  • jax.numpy.take: extract values from an array at given indices.

Examples

>>> x = jnp.zeros(5, dtype=int)
>>> indices = jnp.array([0, 2, 4])
>>> values = jnp.array([10, 20, 30])
>>> jnp.put(x, indices, values, inplace=False)
Array([10,  0, 20,  0, 30], dtype=int32)

This is equivalent to the following jax.numpy.ndarray.at indexing syntax:

>>> x.at[indices].set(values)
Array([10,  0, 20,  0, 30], dtype=int32)

There are two modes for handling out-of-bound indices. By default they are clipped:

>>> indices = jnp.array([0, 2, 6])
>>> jnp.put(x, indices, values, inplace=False, mode='clip')
Array([10,  0, 20,  0, 30], dtype=int32)

Alternatively, they can be wrapped to the beginning of the array:

>>> jnp.put(x, indices, values, inplace=False, mode='wrap')
Array([10, 30, 20,  0,  0], dtype=int32)

For N-dimensional inputs, the indices refer to the flattened array:

>>> x = jnp.zeros((3, 5), dtype=int)
>>> indices = jnp.array([0, 7, 14])
>>> jnp.put(x, indices, values, inplace=False)
Array([[10,  0,  0,  0,  0],
       [ 0,  0, 20,  0,  0],
       [ 0,  0,  0,  0, 30]], dtype=int32)
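
The two modes can be expressed as transformations of the index array followed by an ordinary .at[].set() update. The sketch below assumes non-negative indices, as in the examples above, and a 1-dimensional input.

```python
import jax.numpy as jnp

x = jnp.zeros(5, dtype=jnp.int32)
indices = jnp.array([0, 2, 6])
values = jnp.array([10, 20, 30])

# mode='clip': out-of-bound indices are clamped to the last valid index
clipped = jnp.put(x, indices, values, mode="clip", inplace=False)
clip_ref = x.at[jnp.clip(indices, 0, x.size - 1)].set(values)

# mode='wrap': out-of-bound indices wrap around modulo the array size
wrapped = jnp.put(x, indices, values, mode="wrap", inplace=False)
wrap_ref = x.at[indices % x.size].set(values)
```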
scico.numpy.rad2deg(x, /)

Convert angles from radians to degrees.

JAX implementation of numpy.rad2deg.

The angle in radians is converted to degrees by:

\[\mathrm{rad2deg}(x) = x \cdot \frac{180}{\pi}\]

Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – scalar or array. Specifies the angle in radians.

Return type:

Array

Returns:

An array containing the angles in degrees.

Examples

>>> pi = jnp.pi
>>> x = jnp.array([pi/4, pi/2, 2*pi/3])
>>> jnp.rad2deg(x)
Array([ 45.     ,  90.     , 120.00001], dtype=float32)
>>> x * 180 / pi
Array([ 45.,  90., 120.], dtype=float32)
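
jnp.deg2rad, also available under the alias jnp.radians, is the inverse conversion. The sketch below checks the round trip, which holds up to float32 rounding:

```python
import jax.numpy as jnp

x = jnp.array([jnp.pi / 4, jnp.pi / 2, 2 * jnp.pi / 3])

deg = jnp.rad2deg(x)

# deg2rad inverts the conversion up to rounding; radians is an alias
back = jnp.deg2rad(deg)
back_alias = jnp.radians(deg)
```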
scico.numpy.radians(x, /)

Alias of jax.numpy.deg2rad.

Return type:

Array

scico.numpy.ravel_multi_index(multi_index, dims, mode='raise', order='C', *, dtype=None)

Convert multi-dimensional indices into flat indices.

JAX implementation of numpy.ravel_multi_index.

Parameters:
  • multi_index (Sequence[Union[Array, ndarray, bool, number, bool, int, float, complex]]) – sequence of integer arrays containing indices in each dimension.

  • dims (Sequence[int]) – sequence of integer sizes; must have len(dims) == len(multi_index)

  • mode (str) –

    how to handle out-of-bound indices. Options are:

    • "raise" (default): raise a ValueError. This mode is incompatible with jit or other JAX transformations.

    • "clip": clip out-of-bound indices to valid range.

    • "wrap": wrap out-of-bound indices to valid range.

    • "ignore": do not coerce or check input indices. Behavior is undefined if indices are out of bounds.

  • order (str) – "C" (default) or "F", specifying whether to assume C-style row-major order or Fortran-style column-major order.

  • dtype (Union[str, type[Any], dtype, SupportsDType, None]) – the desired output dtype. If not specified, the dtype is determined by standard type promotion rules of the input multi_index.

Return type:

Array

Returns:

array of flattened indices

See also

jax.numpy.unravel_index: inverse of this function.

Examples

Define a 2-dimensional array and a sequence of indices of even values:

>>> x = jnp.array([[2., 3., 4.],
...                [5., 6., 7.]])
>>> indices = jnp.where(x % 2 == 0)
>>> indices
(Array([0, 0, 1], dtype=int32), Array([0, 2, 1], dtype=int32))
>>> x[indices]
Array([2., 4., 6.], dtype=float32)

Compute the flattened indices:

>>> indices_flat = jnp.ravel_multi_index(indices, x.shape)
>>> indices_flat
Array([0, 2, 4], dtype=int32)

These flattened indices can be used to extract the same values from the flattened x array:

>>> x_flat = x.ravel()
>>> x_flat
Array([2., 3., 4., 5., 6., 7.], dtype=float32)
>>> x_flat[indices_flat]
Array([2., 4., 6.], dtype=float32)

The original indices can be recovered with unravel_index:

>>> jnp.unravel_index(indices_flat, x.shape)
(Array([0, 0, 1], dtype=int32), Array([0, 2, 1], dtype=int32))
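
The flat index is a weighted sum of the per-axis indices, where the weights (strides) depend on order. The plain-Python sketch below spells this out. ravel_multi_index_sketch is a hypothetical helper written for illustration, not part of JAX or SCICO; it mirrors the documented semantics of the mode and order options, not the actual JAX implementation.

```python
def ravel_multi_index_sketch(multi_index, dims, mode="raise", order="C"):
    """Hypothetical plain-Python illustration of ravel_multi_index semantics."""
    if len(multi_index) != len(dims):
        raise ValueError("len(multi_index) must equal len(dims)")
    # Element strides: in C order the last axis varies fastest, in F order the first.
    axes = range(len(dims)) if order == "F" else reversed(range(len(dims)))
    strides = [0] * len(dims)
    step = 1
    for ax in axes:
        strides[ax] = step
        step *= dims[ax]
    flat = []
    for idx in zip(*multi_index):  # one index tuple per output element
        total = 0
        for ax, i in enumerate(idx):
            n = dims[ax]
            if mode == "clip":
                i = min(max(i, 0), n - 1)  # clamp into [0, n - 1]
            elif mode == "wrap":
                i %= n  # Python's % always returns a value in [0, n)
            elif mode == "raise" and not 0 <= i < n:
                raise ValueError(f"index {i} is out of bounds for axis {ax} of size {n}")
            # mode == "ignore": the index is used unchanged
            total += i * strides[ax]
        flat.append(total)
    return flat


rows, cols = [0, 0, 1], [0, 2, 1]  # indices of the even values in the example above
print(ravel_multi_index_sketch((rows, cols), (2, 3)))             # [0, 2, 4]
print(ravel_multi_index_sketch((rows, cols), (2, 3), order="F"))  # [0, 4, 3]
```

In C order the row index is weighted by the number of columns (flat = 3 * i + j for shape (2, 3)); in F order the column index is weighted by the number of rows (flat = i + 2 * j).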
scico.numpy.real(val, /)

Return element-wise real part of the complex argument.

JAX implementation of numpy.real.

Parameters:

val (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array or scalar.

Return type:

Array

Returns:

An array containing the real part of the elements of val.

Examples

>>> jnp.real(5)
Array(5, dtype=int32, weak_type=True)
>>> jnp.real(2j)
Array(0., dtype=float32, weak_type=True)
>>> x = jnp.array([3-2j, 4+7j, -2j])
>>> jnp.real(x)
Array([ 3.,  4., -0.], dtype=float32)
scico.numpy.reciprocal(x, /)

Calculate element-wise reciprocal of the input.

JAX implementation of numpy.reciprocal.

The reciprocal is calculated by 1/x.

Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array or scalar.

Return type:

Array

Returns:

An array of the same shape as x containing the reciprocal of each element of x.

Note

For integer inputs, np.reciprocal returns rounded integer output, while jnp.reciprocal promotes integer inputs to floating point.

Examples

>>> jnp.reciprocal(2)
Array(0.5, dtype=float32, weak_type=True)
>>> jnp.reciprocal(0.)
Array(inf, dtype=float32, weak_type=True)
>>> x = jnp.array([1, 5., 4.])
>>> jnp.reciprocal(x)
Array([1.  , 0.2 , 0.25], dtype=float32)
scico.numpy.remainder(x1, x2, /)

Returns element-wise remainder of the division.

JAX implementation of numpy.remainder.

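Parameters:
  • x1 (Union[Array, ndarray, bool, number, bool, int, float, complex]) – scalar or array. Specifies the dividend.

  • x2 (Union[Array, ndarray, bool, number, bool, int, float, complex]) – scalar or array. Specifies the divisor. x1 and x2 must have the same shape or be broadcast-compatible.

Return type: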

Array

Returns:

An array containing the remainder of element-wise division of x1 by x2, with the same sign as the elements of x2.

Note

The result of jnp.remainder is equivalent to x1 - x2 * jnp.floor(x1 / x2).

See also

  • jax.numpy.mod: Returns the element-wise remainder of the division.

  • jax.numpy.fmod: Calculates the element-wise floating-point modulo operation.

  • jax.numpy.divmod: Calculates the integer quotient and remainder of x1 by x2, element-wise.

Examples

>>> x1 = jnp.array([[3, -1, 4],
...                 [8, 5, -2]])
>>> x2 = jnp.array([2, 3, -5])
>>> jnp.remainder(x1, x2)
Array([[ 1,  2, -1],
       [ 0,  2, -2]], dtype=int32)
>>> x1 - x2 * jnp.floor(x1 / x2)
Array([[ 1.,  2., -1.],
       [ 0.,  2., -2.]], dtype=float32)
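
To contrast the floored remainder in the Note above with the truncated remainder computed by jax.numpy.fmod, here is a plain-Python sketch using only the standard library. remainder_sketch and fmod_sketch are hypothetical names; because remainder_sketch uses float division, it is illustrative only and can lose precision for very large integers.

```python
import math


def remainder_sketch(x1, x2):
    # Floored remainder, as in the Note above: x1 - x2 * floor(x1 / x2).
    # The result takes the sign of the divisor x2.
    return x1 - x2 * math.floor(x1 / x2)


def fmod_sketch(x1, x2):
    # Truncated remainder (C-style fmod): the result takes the sign of the dividend x1.
    return math.fmod(x1, x2)


x1 = [[3, -1, 4], [8, 5, -2]]
x2 = [2, 3, -5]  # broadcast across the rows of x1, as in the example above
print([[remainder_sketch(a, b) for a, b in zip(row, x2)] for row in x1])
# [[1, 2, -1], [0, 2, -2]]
print([[fmod_sketch(a, b) for a, b in zip(row, x2)] for row in x1])
# [[1.0, -1.0, 4.0], [0.0, 2.0, -2.0]]
```

For integer inputs, Python's own % operator follows the same floored convention as jnp.remainder.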
scico.numpy.repeat(a, repeats, axis=None, *, total_repeat_length=None, out_sharding=None)

Construct an array from repeated elements.

JAX implementation of numpy.repeat.

Parameters:
  • a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – N-dimensional array

  • repeats (Union[Array, ndarray, bool, number, bool, int, float, complex]) – integer or 1D integer array specifying the number of repeats. A scalar repeats every element the same number of times; an array must match the length of the repeated axis.

  • axis (int | None) – integer specifying the axis of a along which to construct the repeated array. If None (default) then a is first flattened.

  • total_repeat_length (int | None) – this must be specified statically for jnp.repeat to be compatible with jit and other JAX transformations. If sum(repeats) is larger than the specified total_repeat_length, the remaining values will be discarded. If sum(repeats) is smaller than total_repeat_length, the final value will be repeated.

Return type:

Array

Returns:

an array constructed from repeated values of a.

Examples

Repeat each value twice along the last axis:

>>> a = jnp.array([[1, 2],
...                [3, 4]])
>>> jnp.repeat(a, 2, axis=-1)
Array([[1, 1, 2, 2],
       [3, 3, 4, 4]], dtype=int32)

If axis is not specified, the input array will be flattened:

>>> jnp.repeat(a, 2)
Array([1, 1, 2, 2, 3, 3, 4, 4], dtype=int32)

Pass an array to repeats to repeat each value a different number of times:

>>> repeats = jnp.array([2, 3])
>>> jnp.repeat(a, repeats, axis=1)
Array([[1, 1, 2, 2, 2],
       [3, 3, 4, 4, 4]], dtype=int32)

In order to use repeat within jit and other JAX transformations, the size of the output must be specified statically using total_repeat_length:

>>> jit_repeat = jax.jit(jnp.repeat, static_argnames=['axis', 'total_repeat_length'])
>>> jit_repeat(a, repeats, axis=1, total_repeat_length=5)
Array([[1, 1, 2, 2, 2],
       [3, 3, 4, 4, 4]], dtype=int32)

If total_repeat_length is smaller than sum(repeats), the result will be truncated:

>>> jit_repeat(a, repeats, axis=1, total_repeat_length=4)
Array([[1, 1, 2, 2],
       [3, 3, 4, 4]], dtype=int32)

If it is larger, then the additional entries will be filled with the final value:

>>> jit_repeat(a, repeats, axis=1, total_repeat_length=7)
Array([[1, 1, 2, 2, 2, 2, 2],
       [3, 3, 4, 4, 4, 4, 4]], dtype=int32)
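
The truncation and padding rules for total_repeat_length can be sketched in plain Python for the 1D case. repeat_sketch is a hypothetical helper that reproduces the behavior shown in the jit examples above. It is not JAX's implementation, and raising an error when there is nothing to pad with is an assumption of the sketch.

```python
def repeat_sketch(values, repeats, total_repeat_length=None):
    """Hypothetical 1D illustration of jnp.repeat's total_repeat_length semantics."""
    if isinstance(repeats, int):
        repeats = [repeats] * len(values)  # a scalar applies to every element
    if len(repeats) != len(values):
        raise ValueError("repeats must match the length of the repeated axis")
    out = [v for v, r in zip(values, repeats) for _ in range(r)]
    if total_repeat_length is None:
        return out
    if len(out) >= total_repeat_length:
        return out[:total_repeat_length]  # too long: truncate
    if not out:
        raise ValueError("nothing to pad with")  # assumption of this sketch only
    # Too short: pad with the final repeated value.
    return out + [out[-1]] * (total_repeat_length - len(out))


print(repeat_sketch([1, 2], [2, 3]))                         # [1, 1, 2, 2, 2]
print(repeat_sketch([1, 2], [2, 3], total_repeat_length=4))  # [1, 1, 2, 2]
print(repeat_sketch([1, 2], [2, 3], total_repeat_length=7))  # [1, 1, 2, 2, 2, 2, 2]
```

Fixing the output length up front is what makes the result shape static, which is why total_repeat_length is required under jit.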
scico.numpy.reshape(a, shape, order='C', *, copy=None, out_sharding=None)

Return a reshaped copy of an array.

JAX implementation of numpy.reshape, implemented in terms of jax.lax.reshape.

Parameters:
  • a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array to reshape

  • shape (Union[int, Any, Sequence[Union[int, Any]]]) – integer or sequence of integers giving the new shape, which must match the size of the input array. If any single dimension is given size -1, it will be replaced with a value such that the output has the correct size.

  • order (str) – 'F' or 'C', specifies whether the reshape should apply column-major (Fortran-style, "F") or row-major (C-style, "C") order; default is "C". JAX does not support order="A".

  • copy (bool | None) – unused by JAX; JAX always returns a copy, though under JIT the compiler may optimize such copies away.

Return type:

Array

Returns:

reshaped copy of input array with the specified shape.

Notes

Unlike numpy.reshape, jax.numpy.reshape will return a copy rather than a view of the input array. However, under JIT, the compiler will optimize away such copies when possible, so this generally has no performance impact in practice.

Examples

>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> jnp.reshape(x, 6)
Array([1, 2, 3, 4, 5, 6], dtype=int32)
>>> jnp.reshape(x, (3, 2))
Array([[1, 2],
       [3, 4],
       [5, 6]], dtype=int32)

You can use -1 to automatically compute a shape that is consistent with the input size:

>>> jnp.reshape(x, -1)  # -1 is inferred to be 6
Array([1, 2, 3, 4, 5, 6], dtype=int32)
>>> jnp.reshape(x, (-1, 2))  # -1 is inferred to be 3
Array([[1, 2],
       [3, 4],
       [5, 6]], dtype=int32)

The default ordering of axes in the reshape is C-style row-major ordering. To use Fortran-style column-major ordering, specify order='F':

>>> jnp.reshape(x, 6, order='F')
Array([1, 4, 2, 5, 3, 6], dtype=int32)
>>> jnp.reshape(x, (3, 2), order='F')
Array([[1, 5],
       [4, 3],
       [2, 6]], dtype=int32)

For convenience, this functionality is also available via the jax.Array.reshape method:

>>> x.reshape(3, 2)
Array([[1, 2],
       [3, 4],
       [5, 6]], dtype=int32)
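
The difference between C and Fortran order comes down to which index varies fastest when reading and writing the flat sequence of elements. This plain-Python sketch for 2D nested lists (flatten2d and reshape2d are hypothetical helpers, not JAX functions) reproduces the order='F' and -1 examples above.

```python
def flatten2d(rows, order="C"):
    # C order reads row by row; F order reads column by column.
    if order == "C":
        return [v for row in rows for v in row]
    return [rows[i][j] for j in range(len(rows[0])) for i in range(len(rows))]


def reshape2d(rows, new_shape, order="C"):
    """Hypothetical sketch of 2D reshape semantics, including a single -1 entry."""
    flat = flatten2d(rows, order)
    nrows, ncols = new_shape
    if nrows == -1 and ncols == -1:
        raise ValueError("only one dimension may be -1")
    if nrows == -1:
        nrows = len(flat) // ncols
    elif ncols == -1:
        ncols = len(flat) // nrows
    if nrows * ncols != len(flat):
        raise ValueError("new shape must match the size of the input")
    # Element (i, j) sits at flat position i * ncols + j (C) or i + j * nrows (F).
    if order == "C":
        return [[flat[i * ncols + j] for j in range(ncols)] for i in range(nrows)]
    return [[flat[i + j * nrows] for j in range(ncols)] for i in range(nrows)]


x = [[1, 2, 3], [4, 5, 6]]
print(flatten2d(x, order="F"))          # [1, 4, 2, 5, 3, 6]
print(reshape2d(x, (-1, 2)))            # [[1, 2], [3, 4], [5, 6]]
print(reshape2d(x, (3, 2), order="F"))  # [[1, 5], [4, 3], [2, 6]]
```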
scico.numpy.resize(a, new_shape)

Return a new array with specified shape.

JAX implementation of numpy.resize.

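Parameters:
  • a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array or scalar.

  • new_shape – sequence of integers specifying the shape of the resized array.

Return type: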

Array

Returns:

A resized array with the specified shape. If the resized array is larger than the original array, the elements of a are repeated to fill it.

Examples

>>> x = jnp.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
>>> jnp.resize(x, (3, 3))
Array([[1, 2, 3],
       [4, 5, 6],
       [7, 8, 9]], dtype=int32)
>>> jnp.resize(x, (3, 4))
Array([[1, 2, 3, 4],
       [5, 6, 7, 8],
       [9, 1, 2, 3]], dtype=int32)
>>> jnp.resize(4, (3, 2))
Array([[4, 4],
       [4, 4],
       [4, 4]], dtype=int32, weak_type=True)
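
As the examples show, jnp.resize fills the new shape by cycling through the flattened input. A minimal plain-Python sketch of that rule for 2D target shapes follows. resize_sketch is hypothetical, and rejecting an empty input is an assumption of the sketch rather than documented JAX behavior.

```python
import math


def resize_sketch(values, new_shape):
    """Hypothetical sketch: fill a 2D shape by cycling through the flattened input."""
    if not values:
        raise ValueError("this sketch assumes a non-empty input")
    nrows, ncols = new_shape
    size = math.prod(new_shape)
    # Output element k is input element k mod len(values), so the input is
    # repeated when the new shape is larger and truncated when it is smaller.
    data = [values[k % len(values)] for k in range(size)]
    return [data[i * ncols:(i + 1) * ncols] for i in range(nrows)]


x = [1, 2, 3, 4, 5, 6, 7, 8, 9]
print(resize_sketch(x, (3, 4)))    # [[1, 2, 3, 4], [5, 6, 7, 8], [9, 1, 2, 3]]
print(resize_sketch([4], (3, 2)))  # [[4, 4], [4, 4], [4, 4]]
```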
scico.numpy.result_type(*args)

Return the result of applying JAX promotion rules to the inputs.

JAX implementation of numpy.result_type.

JAX’s dtype promotion behavior is described in Type promotion semantics.

Parameters:

args (Any) – one or more arrays or dtype-like objects.

Return type:

dtype

Returns:

A numpy.dtype instance representing the result of type promotion for the inputs.

Examples

Inputs can be dtype specifiers:

>>> jnp.result_type('int32', 'float32')
dtype('float32')
>>> jnp.result_type(np.uint16, np.dtype('int32'))
dtype('int32')

Inputs may also be scalars or arrays:

>>> jnp.result_type(1.0, jnp.bfloat16(2))
dtype(bfloat16)
>>> jnp.result_type(jnp.arange(4), jnp.zeros(4))
dtype('float32')

Be aware that the result type will be canonicalized based on the state of the jax_enable_x64 configuration flag, meaning that 64-bit types may be downcast to 32-bit:

>>> jnp.result_type('float64')  
dtype('float32')

For details on 64-bit values, refer to Sharp bits - double precision.

scico.numpy.rint(x, /)

Rounds the elements of x to the nearest integer.

JAX implementation of numpy.rint.

Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – Input array

Return type:

Array

Returns:

An array-like object containing the rounded elements of x. Always promotes to inexact.

Note

If an element of x is exactly halfway between two integers, e.g. 0.5 or 1.5, rint rounds it to the nearest even integer.

Examples

>>> x1 = jnp.array([5, 4, 7])
>>> jnp.rint(x1)
Array([5., 4., 7.], dtype=float32)
>>> x2 = jnp.array([-2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5, 4.5])
>>> jnp.rint(x2)
Array([-2., -2., -0.,  0.,  2.,  2.,  4.,  4.], dtype=float32)
>>> x3 = jnp.array([-2.5+3.5j, 4.5-0.5j])
>>> jnp.rint(x3)
Array([-2.+4.j,  4.-0.j], dtype=complex64)
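
Round-half-to-even is also what Python's built-in round() does, so the real-valued example above can be reproduced in plain Python for finite inputs. rint_sketch is a hypothetical helper; math.copysign is used so that negative inputs that round to zero give -0.0, as in the jnp.rint output above.

```python
import math


def rint_sketch(x):
    # round() already sends halfway cases to the nearest even integer;
    # copysign restores the sign of zero, so -0.5 becomes -0.0.
    return math.copysign(float(round(x)), x)


x2 = [-2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5, 4.5]
print([rint_sketch(v) for v in x2])
# [-2.0, -2.0, -0.0, 0.0, 2.0, 2.0, 4.0, 4.0]
```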
scico.numpy.roll(a, shift, axis=None)

Roll the elements of an array along a specified axis.

JAX implementation of numpy.roll.

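Parameters:
  • a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array.

  • shift – number of positions by which elements are shifted. When axis is a sequence, shift may be a sequence giving the shift for each corresponding axis.

  • axis – axis or axes along which to roll. If None (default), the array is flattened before rolling and then restored to its original shape.

Return type: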

Array

Returns:

A copy of a with elements rolled along the specified axis or axes.

Examples

>>> a = jnp.array([0, 1, 2, 3, 4, 5])
>>> jnp.roll(a, 2)
Array([4, 5, 0, 1, 2, 3], dtype=int32)

Roll elements along a specific axis:

>>> a = jnp.array([[ 0,  1,  2,  3],
...                [ 4,  5,  6,  7],
...                [ 8,  9, 10, 11]])
>>> jnp.roll(a, 1, axis=0)
Array([[ 8,  9, 10, 11],
       [ 0,  1,  2,  3],
       [ 4,  5,  6,  7]], dtype=int32)
>>> jnp.roll(a, [2, 3], axis=[0, 1])
Array([[ 5,  6,  7,  4],
       [ 9, 10, 11,  8],
       [ 1,  2,  3,  0]], dtype=int32)
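
Rolling is a rotation of the sequence: the last shift elements wrap around to the front. This plain-Python sketch (roll1d and roll2d are hypothetical helpers) reproduces the 1D example and the two-axis example above by rolling the rows first and then each row.

```python
def roll1d(values, shift):
    # Positive shifts move elements toward higher indices, wrapping around the end.
    n = len(values)
    if n == 0:
        return list(values)
    k = shift % n
    return values[n - k:] + values[:n - k]


def roll2d(rows, shifts):
    """Hypothetical sketch of jnp.roll(a, [s0, s1], axis=[0, 1]) on nested lists."""
    s0, s1 = shifts
    rolled_rows = roll1d(rows, s0)                   # roll along axis 0 (rows)
    return [roll1d(row, s1) for row in rolled_rows]  # roll along axis 1 (columns)


print(roll1d([0, 1, 2, 3, 4, 5], 2))  # [4, 5, 0, 1, 2, 3]
a = [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
print(roll2d(a, [2, 3]))  # [[5, 6, 7, 4], [9, 10, 11, 8], [1, 2, 3, 0]]
```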
scico.numpy.rollaxis(a, axis, start=0)

Roll the specified axis to a given position.

JAX implementation of numpy.rollaxis.

This function exists for compatibility with NumPy, but in most cases the newer jax.numpy.moveaxis should be used instead, because the meaning of its arguments is more intuitive.

Parameters:
  • a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array.

  • axis (int) – index of the axis to roll forward.

  • start (int) – index toward which the axis will be rolled (default = 0). After normalizing negative axes, if start <= axis, the axis is rolled to the start index; if start > axis, the axis is rolled until the position before start.

Return type:

Array

Returns:

Copy of a with rolled axis.

Notes

Unlike numpy.rollaxis, jax.numpy.rollaxis will return a copy rather than a view of the input array. However, under JIT, the compiler will optimize away such copies when possible, so this generally has no performance impact in practice.

Examples

>>> a = jnp.ones((2, 3, 4, 5))

Roll axis 2 to the start of the array:

>>> jnp.rollaxis(a, 2).shape
(4, 2, 3, 5)

Roll axis 1 to the end of the array:

>>> jnp.rollaxis(a, 1, a.ndim).shape
(2, 4, 5, 3)

The equivalent of these two operations using moveaxis:

>>> jnp.moveaxis(a, 2, 0).shape
(4, 2, 3, 5)
>>> jnp.moveaxis(a, 1, -1).shape
(2, 4, 5, 3)
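
The two functions differ only in how the target position is interpreted. The sketch below computes the resulting axis permutation in plain Python (rollaxis_order, moveaxis_order and permute are hypothetical helpers) and reproduces the shapes from the examples above, including the "position before start" rule that applies when start > axis.

```python
def rollaxis_order(ndim, axis, start=0):
    """Hypothetical sketch of the axis permutation performed by jnp.rollaxis."""
    axis %= ndim          # normalize a negative axis
    if start < 0:
        start += ndim     # normalize a negative start
    if start > axis:
        start -= 1        # roll to the position *before* start
    order = [i for i in range(ndim) if i != axis]
    order.insert(start, axis)
    return order


def moveaxis_order(ndim, source, destination):
    # moveaxis places the source axis exactly at the destination position.
    source %= ndim
    order = [i for i in range(ndim) if i != source]
    order.insert(destination % ndim, source)
    return order


def permute(shape, order):
    # Shape of the result after reordering the axes according to `order`.
    return tuple(shape[i] for i in order)


shape = (2, 3, 4, 5)
print(permute(shape, rollaxis_order(4, 2)))         # (4, 2, 3, 5)
print(permute(shape, rollaxis_order(4, 1, 4)))      # (2, 4, 5, 3)
print(permute(shape, moveaxis_order(4, 1, -1)))     # (2, 4, 5, 3)
```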
scico.numpy.roots(p, *, strip_zeros=True)

Returns the roots of a polynomial given the coefficients p.

JAX implementation of numpy.roots.

Parameters:
  • p (Union[Array, ndarray, bool, number, bool, int, float, complex]) – Array of polynomial coefficients having rank-1.

  • strip_zeros (bool) – bool, default=True. If True, then leading zeros in the coefficients will be stripped, similar to numpy.roots. If set to False, leading zeros will not be stripped, and undefined roots will be represented by NaN values in the function output. strip_zeros must be set to False for the function to be compatible with jax.jit and other JAX transformations.

Return type:

Array

Returns:

An array containing the roots of the polynomial.

Note

Unlike np.roots, jnp.roots always returns the roots in a complex array, regardless of whether the roots are real.

Examples

>>> coeffs = jnp.array([0, 1, 2])

The default behavior matches numpy and strips leading zeros:

>>> jnp.roots(coeffs)
Array([-2.+0.j], dtype=complex64)

With strip_zeros=False, extra roots are set to NaN:

>>> jnp.roots(coeffs, strip_zeros=False)
Array([-2. +0.j, nan+nanj], dtype=complex64)
scico.numpy.rot90(m, k=1, axes=(0, 1))

Rotate an array by 90 degrees counterclockwise in the plane specified by axes.

JAX implementation of numpy.rot90.

Parameters:
  • m (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array. Must have m.ndim >= 2.

  • k (int) – int, optional, default=1. Specifies the number of times the array is rotated. For negative values of k, the array is rotated in clockwise direction.

  • axes (tuple[int, int]) – tuple of 2 integers, optional, default= (0, 1). The axes define the plane in which the array is rotated. Both the axes must be different.

Return type:

Array

Returns:

A copy of the input m, rotated by 90 degrees k times.

Examples

>>> m = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> jnp.rot90(m)
Array([[3, 6],
       [2, 5],
       [1, 4]], dtype=int32)
>>> jnp.rot90(m, k=2)
Array([[6, 5, 4],
       [3, 2, 1]], dtype=int32)

jnp.rot90(m, k=1, axes=(1, 0)) is equivalent to jnp.rot90(m, k=-1, axes=(0, 1)):

>>> jnp.rot90(m, axes=(1, 0))
Array([[4, 1],
       [5, 2],
       [6, 3]], dtype=int32)
>>> jnp.rot90(m, k=-1, axes=(0, 1))
Array([[4, 1],
       [5, 2],
       [6, 3]], dtype=int32)

When the input array has ndim > 2:

>>> m1 = jnp.array([[[1, 2, 3],
...                  [4, 5, 6]],
...                 [[7, 8, 9],
...                  [10, 11, 12]]])
>>> jnp.rot90(m1, k=1, axes=(2, 1))
Array([[[ 4,  1],
        [ 5,  2],
        [ 6,  3]],

       [[10,  7],
        [11,  8],
        [12,  9]]], dtype=int32)
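
A counterclockwise quarter turn of a 2D array is a transpose followed by reversing the row order. The following plain-Python sketch (rot90_2d is hypothetical and handles only 2D nested lists with the default axes=(0, 1)) reproduces the k=1, k=2 and k=-1 results shown above.

```python
def rot90_2d(m, k=1):
    """Hypothetical sketch of jnp.rot90 for a 2D nested list with axes=(0, 1)."""
    m = [list(row) for row in m]  # work on a copy, as jnp.rot90 returns a new array
    # k % 4 maps negative k to the equivalent number of counterclockwise turns,
    # e.g. k=-1 becomes 3 counterclockwise turns, i.e. one clockwise turn.
    for _ in range(k % 4):
        # Transposing and then reversing the row order rotates counterclockwise.
        m = [list(col) for col in zip(*m)][::-1]
    return m


m = [[1, 2, 3], [4, 5, 6]]
print(rot90_2d(m))        # [[3, 6], [2, 5], [1, 4]]
print(rot90_2d(m, k=2))   # [[6, 5, 4], [3, 2, 1]]
print(rot90_2d(m, k=-1))  # [[4, 1], [5, 2], [6, 3]]
```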
scico.numpy.round(a, decimals=0, out=None)

Round input evenly to the given number of decimals.

JAX implementation of numpy.round.

Parameters:
  • a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array or scalar.

  • decimals (int) – int, default=0. Number of decimal places to which the input is rounded. It must be specified statically. Not implemented for decimals < 0.

  • out (None) – Unused by JAX.

Return type:

Array

Returns:

An array containing the values of a rounded to the specified number of decimals, with the same shape and dtype as a.

Note

jnp.round rounds to the nearest even integer for the values exactly halfway between rounded decimal values.

See also

  • jax.numpy.floor: Rounds the input to the nearest integer downwards.

  • jax.numpy.ceil: Rounds the input to the nearest integer upwards.

  • jax.numpy.fix and jax.numpy.trunc: Round the input to the nearest integer towards zero.

Examples

>>> x = jnp.array([1.532, 3.267, 6.149])
>>> jnp.round(x)
Array([2., 3., 6.], dtype=float32)
>>> jnp.round(x, decimals=2)
Array([1.53, 3.27, 6.15], dtype=float32)

For values exactly halfway between rounded values:

>>> x1 = jnp.array([10.5, 21.5, 12.5, 31.5])
>>> jnp.round(x1)
Array([10., 22., 12., 32.], dtype=float32)
scico.numpy.searchsorted(a, v, side='left', sorter=None, *, method='scan')

Perform a binary search within a sorted array.

JAX implementation of numpy.searchsorted.

This will return the indices within a sorted array a where values in v can be inserted to maintain its sort order.

Parameters:
  • a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – one-dimensional array, assumed to be in sorted order unless sorter is specified.

  • v (Union[Array, ndarray, bool, number, bool, int, float, complex]) – N-dimensional array of query values

  • side (str) – 'left' (default) or 'right'; specifies whether insertion indices will be to the left or the right in case of ties.

  • sorter (Union[Array, ndarray, bool, number, bool, int, float, complex, None]) – optional array of indices specifying the sort order of a. If specified, then the algorithm assumes that a[sorter] is in sorted order.

  • method (str) – one of 'scan' (default), 'scan_unrolled', 'sort' or 'compare_all'. See Note below.

Return type:

Array

Returns:

Array of insertion indices of shape v.shape.

Note

The method argument controls the algorithm used to compute the insertion indices.

  • 'scan' (the default) tends to be more performant on CPU, particularly when a is very large.

  • 'scan_unrolled' is more performant on GPU at the expense of additional compile time.

  • 'sort' is often more performant on accelerator backends like GPU and TPU, particularly when v is very large.

  • 'compare_all' tends to be the most performant when a is very small.

Examples

Searching for a single value:

>>> a = jnp.array([1, 2, 2, 3, 4, 5, 5])
>>> jnp.searchsorted(a, 2)
Array(1, dtype=int32)
>>> jnp.searchsorted(a, 2, side='right')
Array(3, dtype=int32)

Searching for a batch of values:

>>> vals = jnp.array([0, 3, 8, 1.5, 2])
>>> jnp.searchsorted(a, vals)
Array([0, 3, 7, 1, 1], dtype=int32)

Optionally, the sorter argument can be used to find insertion indices into an array sorted via jax.numpy.argsort:

>>> a = jnp.array([4, 3, 5, 1, 2])
>>> sorter = jnp.argsort(a)
>>> jnp.searchsorted(a, vals, sorter=sorter)
Array([0, 2, 5, 1, 1], dtype=int32)

The result is equivalent to passing the sorted array:

>>> jnp.searchsorted(jnp.sort(a), vals)
Array([0, 2, 5, 1, 1], dtype=int32)
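
For a sorted array, the insertion index is the number of elements that must come before the query value. The plain-Python sketch below counts them directly. This brute-force formulation is linear in len(a), and searchsorted_count is a hypothetical helper that is not claimed to match how any of JAX's method options is implemented. The result is cross-checked against the standard library's binary search in bisect.

```python
import bisect


def searchsorted_count(a, v, side="left"):
    """Hypothetical brute-force sketch of searchsorted for a sorted sequence a."""
    if side == "left":
        return sum(x < v for x in a)   # ties are placed left of equal elements
    return sum(x <= v for x in a)      # ties are placed right of equal elements


a = [1, 2, 2, 3, 4, 5, 5]
vals = [0, 3, 8, 1.5, 2]
print([searchsorted_count(a, v) for v in vals])  # [0, 3, 7, 1, 1]
print(searchsorted_count(a, 2, side="right"))    # 3

# Binary search from the standard library gives the same insertion indices.
print([bisect.bisect_left(a, v) for v in vals])  # [0, 3, 7, 1, 1]

# With a sorter, search a[sorter] instead of a itself.
b = [4, 3, 5, 1, 2]
sorter = sorted(range(len(b)), key=b.__getitem__)  # argsort of b
print([searchsorted_count([b[i] for i in sorter], v) for v in vals])  # [0, 2, 5, 1, 1]
```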
scico.numpy.select(condlist, choicelist, default=0)

Select values based on a series of conditions.

JAX implementation of numpy.select, implemented in terms of jax.lax.select_n.

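Parameters:
  • condlist – sequence of boolean arrays specifying the conditions.

  • choicelist – sequence of arrays from which output values are selected; must have the same length as condlist.

  • default – value used at locations where every entry of condlist is False (default: 0).

Return type: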

Array

Returns:

Array of selected values from choicelist corresponding to the first True entry in condlist at each location.

See also

Examples

>>> condlist = [
...    jnp.array([False, True, False, False]),
...    jnp.array([True, False, False, False]),
...    jnp.array([False, True, True, False]),
... ]
>>> choicelist = [
...    jnp.array([1, 2, 3, 4]),
...    jnp.array([10, 20, 30, 40]),
...    jnp.array([100, 200, 300, 400]),
... ]
>>> jnp.select(condlist, choicelist, default=0)
Array([ 10,   2, 300,   0], dtype=int32)

This is logically equivalent to the following nested where statement:

>>> default = 0
>>> jnp.where(condlist[0],
...   choicelist[0],
...   jnp.where(condlist[1],
...     choicelist[1],
...     jnp.where(condlist[2],
...       choicelist[2],
...       default)))
Array([ 10,   2, 300,   0], dtype=int32)

However, for efficiency it is implemented in terms of jax.lax.select_n.
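
The first-True-wins rule can also be written as an explicit loop. select_sketch is a hypothetical plain-Python helper for equal-length 1D inputs with at least one condition; it reproduces the example above but, unlike jnp.select, does no broadcasting.

```python
def select_sketch(condlist, choicelist, default=0):
    """Hypothetical sketch of jnp.select for equal-length 1D lists (no broadcasting)."""
    if len(condlist) != len(choicelist):
        raise ValueError("condlist and choicelist must have the same length")
    out = []
    for k in range(len(condlist[0])):
        value = default
        for cond, choice in zip(condlist, choicelist):
            if cond[k]:  # the first True condition wins
                value = choice[k]
                break
        out.append(value)
    return out


condlist = [
    [False, True, False, False],
    [True, False, False, False],
    [False, True, True, False],
]
choicelist = [
    [1, 2, 3, 4],
    [10, 20, 30, 40],
    [100, 200, 300, 400],
]
print(select_sketch(condlist, choicelist))  # [10, 2, 300, 0]
```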

scico.numpy.set_printoptions(*args, **kwargs)

Alias of numpy.set_printoptions.

JAX arrays are printed via NumPy, so NumPy’s printoptions configurations will apply to printed JAX arrays.

See the numpy.set_printoptions documentation for details on the available options and their meanings.

scico.numpy.setdiff1d(ar1, ar2, assume_unique=False, *, size=None, fill_value=None)

Compute the set difference of two 1D arrays.

JAX implementation of numpy.setdiff1d.

Because the size of the output of setdiff1d is data-dependent, the function is not typically compatible with jit and other JAX transformations. The JAX version adds the optional size argument which must be specified statically for jnp.setdiff1d to be used in such contexts.

Parameters:
  • ar1 (Union[Array, ndarray, bool, number, bool, int, float, complex]) – first array of elements to be differenced.

  • ar2 (Union[Array, ndarray, bool, number, bool, int, float, complex]) – second array of elements to be differenced.

  • assume_unique (bool) – if True, assume the input arrays contain unique values. This allows a more efficient implementation, but if assume_unique is True and the input arrays contain duplicates, the behavior is undefined. default: False.

  • size (int | None) – if specified, return only the first size sorted elements. If there are fewer elements than size indicates, the return value will be padded with fill_value.

  • fill_value (Union[Array, ndarray, bool, number, bool, int, float, complex, None]) – when size is specified and there are fewer than the indicated number of elements, fill the remaining entries with fill_value. Defaults to the minimum value.

Return type:

Array

Returns:

an array containing the set difference of elements in the input arrays, i.e. the elements in ar1 that are not contained in ar2.

Examples

Computing the set difference of two arrays:

>>> ar1 = jnp.array([1, 2, 3, 4])
>>> ar2 = jnp.array([3, 4, 5, 6])
>>> jnp.setdiff1d(ar1, ar2)
Array([1, 2], dtype=int32)

Because the output shape is dynamic, this will fail under jit and other transformations:

>>> jax.jit(jnp.setdiff1d)(ar1, ar2)  
Traceback (most recent call last):
   ...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[4].
The error occurred while tracing the function setdiff1d at /Users/vanderplas/github/jax-ml/jax/jax/_src/numpy/setops.py:64 for jit. This concrete value was not available in Python because it depends on the value of the argument ar1.

In order to ensure statically-known output shapes, you can pass a static size argument:

>>> jit_setdiff1d = jax.jit(jnp.setdiff1d, static_argnames=['size'])
>>> jit_setdiff1d(ar1, ar2, size=2)
Array([1, 2], dtype=int32)

If size is too small, the difference is truncated:

>>> jit_setdiff1d(ar1, ar2, size=1)
Array([1], dtype=int32)

If size is too large, then the output is padded with fill_value:

>>> jit_setdiff1d(ar1, ar2, size=4, fill_value=0)
Array([1, 2, 0, 0], dtype=int32)
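
The role of size and fill_value can be sketched in plain Python. setdiff1d_sketch is a hypothetical helper; its fill_value defaults to 0 purely for simplicity, which differs from the default documented above.

```python
def setdiff1d_sketch(ar1, ar2, size=None, fill_value=0):
    """Hypothetical sketch: sorted unique values of ar1 that are not in ar2."""
    result = sorted(set(ar1) - set(ar2))
    if size is None:
        return result  # data-dependent length, which is why jit needs a static size
    result = result[:size]                                 # truncate if too long
    return result + [fill_value] * (size - len(result))    # pad if too short


ar1, ar2 = [1, 2, 3, 4], [3, 4, 5, 6]
print(setdiff1d_sketch(ar1, ar2))          # [1, 2]
print(setdiff1d_sketch(ar1, ar2, size=1))  # [1]
print(setdiff1d_sketch(ar1, ar2, size=4))  # [1, 2, 0, 0]
```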
scico.numpy.setxor1d(ar1, ar2, assume_unique=False, *, size=None, fill_value=None)

Compute the set-wise xor of elements in two arrays.

JAX implementation of numpy.setxor1d.

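Because the size of the output of setxor1d is data-dependent, the function is not typically compatible with JIT or other JAX transformations. The optional size argument, which must be specified statically, is provided for use in such contexts.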

Parameters:
  • ar1 (Union[Array, ndarray, bool, number, bool, int, float, complex]) – first array of values.

  • ar2 (Union[Array, ndarray, bool, number, bool, int, float, complex]) – second array of values.

  • assume_unique (bool) – if True, assume the input arrays contain unique values. This allows a more efficient implementation, but if assume_unique is True and the input arrays contain duplicates, the behavior is undefined. default: False.

  • size (int | None) – if specified, return only the first size sorted elements. If there are fewer elements than size indicates, the return value will be padded with fill_value.

  • fill_value (Union[Array, ndarray, bool, number, bool, int, float, complex, None]) – when size is specified and there are fewer than the indicated number of elements, fill the remaining entries with fill_value. Defaults to the smallest value in the xor result.

Return type:

Array

Returns:

An array of values that are found in exactly one of the input arrays.

Examples

>>> ar1 = jnp.array([1, 2, 3, 4])
>>> ar2 = jnp.array([3, 4, 5, 6])
>>> jnp.setxor1d(ar1, ar2)
Array([1, 2, 5, 6], dtype=int32)
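
The values found in exactly one input form the symmetric difference of the two sets, which Python exposes as the ^ operator on set objects. A minimal plain-Python sketch follows; setxor1d_sketch is hypothetical and ignores size and fill_value.

```python
def setxor1d_sketch(ar1, ar2):
    # Values that appear in exactly one input: the symmetric difference of the sets.
    return sorted(set(ar1) ^ set(ar2))


print(setxor1d_sketch([1, 2, 3, 4], [3, 4, 5, 6]))  # [1, 2, 5, 6]
```

Equivalently, this is the union of the two sets minus their intersection.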
scico.numpy.shape(a)

Return the shape of an array.

JAX implementation of numpy.shape. Unlike np.shape, this function raises a TypeError if the input is a collection such as a list or tuple.

Parameters:

a (Union[Array, ndarray, bool, number, bool, int, float, complex, SupportsShape]) – array-like object, or any object with a shape attribute.

Return type:

tuple[int, ...]

Returns:

A tuple of integers representing the shape of a.

Examples

Shape for arrays:

>>> x = jnp.arange(10)
>>> jnp.shape(x)
(10,)
>>> y = jnp.ones((2, 3))
>>> jnp.shape(y)
(2, 3)

This also works for scalars:

>>> jnp.shape(3.14)
()

For arrays, this can also be accessed via the jax.Array.shape property:

>>> x.shape
(10,)
scico.numpy.sign(x, /)

Return an element-wise indication of sign of the input.

JAX implementation of numpy.sign.

The sign of x for real-valued input is:

\[\begin{split}\mathrm{sign}(x) = \begin{cases} 1, & x > 0\\ 0, & x = 0\\ -1, & x < 0 \end{cases}\end{split}\]

For complex-valued input, jnp.sign returns a unit vector representing the phase. In the general case, the sign of x is given by:

\[\begin{split}\mathrm{sign}(x) = \begin{cases} \frac{x}{|x|}, & x \ne 0\\ 0, & x = 0 \end{cases}\end{split}\]
Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array or scalar.

Return type:

Array

Returns:

An array with the same shape and dtype as x containing the sign indication.

See also

  • jax.numpy.positive: Returns element-wise positive values of the input.

  • jax.numpy.negative: Returns element-wise negative values of the input.

Examples

For real-valued inputs:

>>> x = jnp.array([0., -3., 7.])
>>> jnp.sign(x)
Array([ 0., -1.,  1.], dtype=float32)

For complex inputs:

>>> x1 = jnp.array([1, 3+4j, 5j])
>>> jnp.sign(x1)
Array([1. +0.j , 0.6+0.8j, 0. +1.j ], dtype=complex64)
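
Both formulas above can be checked with a small scalar sketch in plain Python. sign_sketch is a hypothetical helper; note that it returns a Python int for real input rather than preserving the input dtype as jnp.sign does.

```python
def sign_sketch(x):
    """Hypothetical scalar sketch of the two sign formulas above."""
    if isinstance(x, complex):
        # Complex input: the unit-modulus phase x / |x|, or 0 at the origin.
        return x / abs(x) if x != 0 else 0j
    # Real input: 1, 0, or -1.
    return (x > 0) - (x < 0)


print([sign_sketch(v) for v in [0.0, -3.0, 7.0]])           # [0, -1, 1]
print([sign_sketch(complex(v)) for v in [1, 3 + 4j, 5j]])   # unit-modulus phases
```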
scico.numpy.signbit(x, /)

Return the sign bit of array elements.

JAX implementation of numpy.signbit.

Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array. Complex values are not supported.

Return type:

Array

Returns:

A boolean array of the same shape as x, containing True where the sign of x is negative, and False otherwise.

See also

  • jax.numpy.sign: return the mathematical sign of array elements, i.e. -1, 0, or +1.

Examples

signbit on boolean values is always False:

>>> x = jnp.array([True, False])
>>> jnp.signbit(x)
Array([False, False], dtype=bool)

signbit on integer values is equivalent to x < 0:

>>> x = jnp.array([-2, -1, 0, 1, 2])
>>> jnp.signbit(x)
Array([ True,  True, False, False, False], dtype=bool)

signbit on floating point values returns the value of the actual sign bit from the float representation, including signed zero:

>>> x = jnp.array([-1.5, -0.0, 0.0, 1.5])
>>> jnp.signbit(x)
Array([ True,  True, False, False], dtype=bool)

This also returns the sign bit for special values such as signed NaN and signed infinity:

>>> x = jnp.array([jnp.nan, -jnp.nan, jnp.inf, -jnp.inf])
>>> jnp.signbit(x)
Array([False,  True, False,  True], dtype=bool)
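
The difference between signbit and a comparison such as x < 0 is visible only at the bit level. This plain-Python sketch (signbit_sketch is a hypothetical helper) packs each value as an IEEE 754 double with the standard struct module and reads the most significant bit, which distinguishes -0.0 from 0.0.

```python
import struct


def signbit_sketch(x):
    # Pack as a big-endian IEEE 754 double and read the most significant bit.
    return struct.pack(">d", float(x))[0] >> 7 == 1


print([signbit_sketch(v) for v in [-1.5, -0.0, 0.0, 1.5]])         # [True, True, False, False]
print(-0.0 < 0)                                                     # False: comparison ignores the sign of zero
print([signbit_sketch(v) for v in [float("inf"), float("-inf")]])  # [False, True]
```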
scico.numpy.sin(x, /)

Compute the trigonometric sine of each element of the input.

JAX implementation of numpy.sin.

Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – array or scalar. Angle in radians.

Return type:

Array

Returns:

An array containing the sine of each element in x, promotes to inexact dtype.

Examples

>>> pi = jnp.pi
>>> x = jnp.array([pi/4, pi/2, 3*pi/4, pi])
>>> with jnp.printoptions(precision=3, suppress=True):
...   print(jnp.sin(x))
[ 0.707  1.     0.707 -0.   ]
scico.numpy.sinc(x, /)

Calculate the normalized sinc function.

JAX implementation of numpy.sinc.

The normalized sinc function is given by

\[\mathrm{sinc}(x) = \frac{\sin({\pi x})}{\pi x}\]

where sinc(0) returns the limit value of 1. The sinc function is smooth and infinitely differentiable.

Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array; will be promoted to an inexact type.

Return type:

Array

Returns:

An array of the same shape as x containing the result.

Examples

>>> x = jnp.array([-1, -0.5, 0, 0.5, 1])
>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.sinc(x)
Array([-0.   ,  0.637,  1.   ,  0.637, -0.   ], dtype=float32)

Compare this to the naive approach to computing the function, which is undefined at zero:

>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.sin(jnp.pi * x) / (jnp.pi * x)
Array([-0.   ,  0.637,    nan,  0.637, -0.   ], dtype=float32)

JAX defines a custom gradient rule for sinc to allow accurate evaluation of the gradient at zero even for higher-order derivatives:

>>> f = jnp.sinc
>>> for i in range(1, 6):
...   f = jax.grad(f)
...   print(f"(d/dx)^{i} f(0.0) = {f(0.0):.2f}")
...
(d/dx)^1 f(0.0) = 0.00
(d/dx)^2 f(0.0) = -3.29
(d/dx)^3 f(0.0) = 0.00
(d/dx)^4 f(0.0) = 19.48
(d/dx)^5 f(0.0) = 0.00
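
The derivative values printed above follow from the Taylor series of the normalized sinc function about zero,

\[\mathrm{sinc}(x) = \sum_{k=0}^{\infty} \frac{(-1)^k (\pi x)^{2k}}{(2k+1)!}\]

so the second derivative at zero is -π²/3 ≈ -3.29, the fourth is π⁴/5 ≈ 19.48, and all odd derivatives vanish because sinc is even. The plain-Python sketch below (sinc_sketch and sinc_derivative_at_zero are hypothetical helpers) evaluates the series coefficients directly; it does not use automatic differentiation.

```python
import math


def sinc_sketch(x):
    # Normalized sinc; the removable singularity at 0 is filled with its limit, 1.
    if x == 0:
        return 1.0
    px = math.pi * x
    return math.sin(px) / px


def sinc_derivative_at_zero(n):
    # The n-th derivative at 0 is n! times the coefficient of x**n in the series.
    if n % 2 == 1:
        return 0.0  # sinc is even, so every odd derivative vanishes at 0
    k = n // 2
    return math.factorial(n) * (-1) ** k * math.pi ** (2 * k) / math.factorial(2 * k + 1)


for n in range(1, 6):
    print(f"(d/dx)^{n} f(0.0) = {sinc_derivative_at_zero(n):.2f}")
```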
scico.numpy.sinh(x, /)

Calculate element-wise hyperbolic sine of input.

JAX implementation of numpy.sinh.

The hyperbolic sine is defined by:

\[\sinh(x) = \frac{e^x - e^{-x}}{2}\]
Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array or scalar.

Return type:

Array

Returns:

An array containing the hyperbolic sine of each element of x, promoting to inexact dtype.

Note

jnp.sinh is equivalent to computing -1j * jnp.sin(1j * x).

See also

  • jax.numpy.cosh: Computes the element-wise hyperbolic cosine of the input.

  • jax.numpy.tanh: Computes the element-wise hyperbolic tangent of the input.

  • jax.numpy.arcsinh: Computes the element-wise inverse of hyperbolic sine of the input.

Examples

>>> x = jnp.array([[-2, 3, 5],
...                [0, -1, 4]])
>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.sinh(x)
Array([[-3.627, 10.018, 74.203],
       [ 0.   , -1.175, 27.29 ]], dtype=float32)
>>> with jnp.printoptions(precision=3, suppress=True):
...   -1j * jnp.sin(1j * x)
Array([[-3.627+0.j, 10.018-0.j, 74.203-0.j],
       [ 0.   -0.j, -1.175+0.j, 27.29 -0.j]],
      dtype=complex64, weak_type=True)

For complex-valued input:

>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.sinh(3-2j)
Array(-4.169-9.154j, dtype=complex64, weak_type=True)
>>> with jnp.printoptions(precision=3, suppress=True):
...   -1j * jnp.sin(1j * (3-2j))
Array(-4.169-9.154j, dtype=complex64, weak_type=True)
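
The identity in the Note can be checked against the exponential definition using Python's cmath module. sinh_from_definition is a hypothetical helper; the comparison is numerical, so the values agree only up to floating-point rounding.

```python
import cmath


def sinh_from_definition(z):
    # sinh(z) = (e^z - e^{-z}) / 2, evaluated with complex exponentials.
    return (cmath.exp(z) - cmath.exp(-z)) / 2


z = 3 - 2j
print(sinh_from_definition(z))  # approximately (-4.169-9.154j)
print(-1j * cmath.sin(1j * z))  # the same value, via the identity in the Note
print(cmath.sinh(z))            # the standard library's sinh, for comparison
```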
scico.numpy.size(a, axis=None)

Return number of elements along a given axis.

JAX implementation of numpy.size. Unlike np.size, this function raises a TypeError if the input is a collection such as a list or tuple.

Parameters:
  • a (Union[Array, ndarray, bool, number, bool, int, float, complex, SupportsSize, SupportsShape]) – array-like object, or any object with a size attribute when axis is not specified, or with a shape attribute when axis is specified.

  • axis (int | Sequence[int] | None) – optional integer or sequence of integers indicating which axis or axes to count elements along. None (the default) returns the total number of elements.

Return type:

int

Returns:

An integer specifying the number of elements in a.

Examples

Size for arrays:

>>> x = jnp.arange(10)
>>> jnp.size(x)
10
>>> y = jnp.ones((2, 3))
>>> jnp.size(y)
6
>>> jnp.size(y, axis=1)
3
>>> jnp.size(y, axis=(1,))
3
>>> jnp.size(y, axis=(0, 1))
6

This also works for scalars:

>>> jnp.size(3.14)
1

For arrays, this can also be accessed via the jax.Array.size property:

>>> y.size
6
scico.numpy.sort(a, axis=-1, *, kind=None, order=None, stable=True, descending=False)

Return a sorted copy of an array.

JAX implementation of numpy.sort.

Parameters:
  • a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – array to sort

  • axis (int | None) – integer axis along which to sort. Defaults to -1, i.e. the last axis. If None, then a is flattened before being sorted.

  • stable (bool) – boolean specifying whether a stable sort should be used. Default=True.

  • descending (bool) – boolean specifying whether to sort in descending order. Default=False.

  • kind (None) – deprecated; instead specify sort algorithm using stable=True or stable=False.

  • order (None) – not supported by JAX

Return type:

Array

Returns:

Sorted array of shape a.shape (if axis is an integer) or of shape (a.size,) (if axis is None).

Examples

Simple one-dimensional sort:

>>> x = jnp.array([1, 3, 5, 4, 2, 1])
>>> jnp.sort(x)
Array([1, 1, 2, 3, 4, 5], dtype=int32)

Sort along the last axis of an array:

>>> x = jnp.array([[2, 1, 3],
...                [4, 3, 6]])
>>> jnp.sort(x, axis=1)
Array([[1, 2, 3],
       [3, 4, 6]], dtype=int32)
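The sketch below, with an illustrative input, shows the axis=None and axis=0 cases and checks that indexing with jnp.argsort yields the same ordering as jnp.sort.

```python
import jax.numpy as jnp

x = jnp.array([[4, 1, 6],
               [2, 3, 3]])
# axis=None flattens first, so the result is 1-D with x.size elements.
flat = jnp.sort(x, axis=None)
# axis=0 sorts each column independently.
by_column = jnp.sort(x, axis=0)
# Gathering with the argsort permutation reproduces the sorted values.
row = x[0]
same = jnp.array_equal(row[jnp.argsort(row)], jnp.sort(row))
print(flat)  # [1 2 3 3 4 6]
print(by_column)
```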

scico.numpy.sort_complex(a)

Return a sorted copy of complex array.

JAX implementation of numpy.sort_complex.

Complex numbers are sorted lexicographically, meaning by their real part first, and then by their imaginary part if real parts are equal.

Parameters:

a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array. If dtype is not complex, the array will be upcast to complex.

Return type:

Array

Returns:

A sorted array of the same shape and complex dtype as the input. If a is multi-dimensional, it is sorted along the last axis.

Examples

>>> a = jnp.array([1+2j, 2+4j, 3-1j, 2+3j])
>>> jnp.sort_complex(a)
Array([1.+2.j, 2.+3.j, 2.+4.j, 3.-1.j], dtype=complex64)

Multi-dimensional arrays are sorted along the last axis:

>>> a = jnp.array([[5, 3, 4],
...                [6, 9, 2]])
>>> jnp.sort_complex(a)
Array([[3.+0.j, 4.+0.j, 5.+0.j],
       [2.+0.j, 6.+0.j, 9.+0.j]], dtype=complex64)
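The lexicographic rule can be reproduced with jnp.lexsort, which treats its last key as the primary key. This is an illustrative equivalence check, not a description of how sort_complex is implemented.

```python
import jax.numpy as jnp

a = jnp.array([1+2j, 2+4j, 3-1j, 2+3j])
# lexsort uses the LAST key as the primary key, so (imag, real) orders
# by real part first and breaks ties using the imaginary part.
order = jnp.lexsort((a.imag, a.real))
manual = a[order]
print(jnp.array_equal(manual, jnp.sort_complex(a)))
```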
scico.numpy.split(ary, indices_or_sections, axis=0)

Split an array into sub-arrays.

JAX implementation of numpy.split.

Parameters:
  • ary (Union[Array, ndarray, bool, number, bool, int, float, complex]) – N-dimensional array-like object to split

  • indices_or_sections (Union[int, Sequence[int], Array, ndarray, bool, number, bool, float, complex]) –

    either a single integer or a sequence of indices.

    • if indices_or_sections is an integer N, then N must evenly divide ary.shape[axis] and ary will be divided into N equally-sized chunks along axis.

    • if indices_or_sections is a sequence of integers, then these integers specify the boundary between unevenly-sized chunks along axis; see examples below.

  • axis (int) – the axis along which to split; defaults to 0.

Return type:

list[Array]

Returns:

A list of arrays. If indices_or_sections is an integer N, then the list is of length N. If indices_or_sections is a sequence seq, then the list is of length len(seq) + 1.

Examples

Splitting a 1-dimensional array:

>>> x = jnp.array([1, 2, 3, 4, 5, 6, 7, 8, 9])

Split into three equal sections:

>>> chunks = jnp.split(x, 3)
>>> print(*chunks)
[1 2 3] [4 5 6] [7 8 9]

Split into sections by index:

>>> chunks = jnp.split(x, [2, 7])  # [x[0:2], x[2:7], x[7:]]
>>> print(*chunks)
[1 2] [3 4 5 6 7] [8 9]

Splitting a two-dimensional array along axis 1:

>>> x = jnp.array([[1, 2, 3, 4],
...                [5, 6, 7, 8]])
>>> x1, x2 = jnp.split(x, 2, axis=1)
>>> print(x1)
[[1 2]
 [5 6]]
>>> print(x2)
[[3 4]
 [7 8]]
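A short sketch showing that splitting by indices matches plain slicing, and that jnp.concatenate along the same axis undoes the split.

```python
import jax.numpy as jnp

x = jnp.arange(1, 10)
indices = [2, 7]
chunks = jnp.split(x, indices)
# Each boundary index starts a new chunk, so the pieces equal these slices.
expected = [x[:2], x[2:7], x[7:]]
matches = all(bool(jnp.array_equal(c, e)) for c, e in zip(chunks, expected))
# Concatenating along the split axis recovers the original array.
restored = jnp.concatenate(chunks)
```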

scico.numpy.sqrt(x, /)

Calculates element-wise non-negative square root of the input array.

JAX implementation of numpy.sqrt.

Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array or scalar.

Return type:

Array

Returns:

An array containing the non-negative square root of the elements of x.

Note

  • For real-valued negative inputs, jnp.sqrt produces a nan output.

  • For complex-valued negative inputs, jnp.sqrt produces a complex output.

Examples

>>> x = jnp.array([-8-6j, 1j, 4])
>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.sqrt(x)
Array([1.   -3.j   , 0.707+0.707j, 2.   +0.j   ], dtype=complex64)
>>> jnp.sqrt(-1)
Array(nan, dtype=float32, weak_type=True)
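To obtain the principal complex root of a negative real value, one option is to cast the input to a complex dtype before calling jnp.sqrt, as in this sketch (complex64 is an illustrative choice).

```python
import jax.numpy as jnp

x = jnp.array([-4.0, 9.0])
# A real-valued negative input produces nan ...
real_result = jnp.sqrt(x)
# ... whereas a complex-valued input returns the principal square root.
complex_result = jnp.sqrt(x.astype(jnp.complex64))
print(real_result)
print(complex_result)
```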
scico.numpy.square(x, /)

Calculate element-wise square of the input array.

JAX implementation of numpy.square.

Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array or scalar.

Return type:

Array

Returns:

An array containing the square of the elements of x.

Note

jnp.square is equivalent to computing jnp.power(x, 2).

See also

  • jax.numpy.sqrt: Calculates the element-wise non-negative square root of the input array.

  • jax.numpy.power: Calculates the element-wise base x1 exponential of x2.

  • jax.lax.integer_pow: Computes element-wise power \(x^y\), where \(y\) is a fixed integer.

  • jax.numpy.float_power: Computes the first array raised to the power of second array, element-wise, by promoting to the inexact dtype.

Examples

>>> x = jnp.array([3, -2, 5.3, 1])
>>> jnp.square(x)
Array([ 9.      ,  4.      , 28.090002,  1.      ], dtype=float32)
>>> jnp.power(x, 2)
Array([ 9.      ,  4.      , 28.090002,  1.      ], dtype=float32)

For integer inputs:

>>> x1 = jnp.array([2, 4, 5, 6])
>>> jnp.square(x1)
Array([ 4, 16, 25, 36], dtype=int32)

For complex-valued inputs:

>>> x2 = jnp.array([1-3j, -1j, 2])
>>> jnp.square(x2)
Array([-8.-6.j, -1.+0.j,  4.+0.j], dtype=complex64)
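For complex input, jnp.square computes x * x, which is not the squared magnitude. The sketch below contrasts the two using the values from the complex example above.

```python
import jax.numpy as jnp

x = jnp.array([1 - 3j, -1j, 2])
# jnp.square is the element-wise product of x with itself ...
sq = jnp.square(x)
# ... while the squared magnitude |x|**2 needs the complex conjugate.
mag_sq = (x * jnp.conj(x)).real
print(sq)
print(mag_sq)
```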
scico.numpy.squeeze(a, axis=None)

Remove one or more length-1 axes from an array.

JAX implementation of numpy.squeeze, implemented via jax.lax.squeeze.

Parameters:
  • a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array

  • axis (int | Sequence[int] | None) – integer or sequence of integers specifying axes to remove. If any specified axis does not have a length of 1, an error is raised. If not specified, squeeze all length-1 axes in a.

Return type:

Array

Returns:

copy of a with length-1 axes removed.

Notes

Unlike numpy.squeeze, jax.numpy.squeeze will return a copy rather than a view of the input array. However, under JIT, the compiler will optimize away such copies when possible, so this doesn’t have performance impacts in practice.

Examples

>>> x = jnp.array([[[0]], [[1]], [[2]]])
>>> x.shape
(3, 1, 1)

Squeeze all length-1 dimensions:

>>> jnp.squeeze(x)
Array([0, 1, 2], dtype=int32)
>>> _.shape
(3,)

Equivalent while specifying the axes explicitly:

>>> jnp.squeeze(x, axis=(1, 2))
Array([0, 1, 2], dtype=int32)

Attempting to squeeze a non-unit axis results in an error:

>>> jnp.squeeze(x, axis=0)  
Traceback (most recent call last):
  ...
ValueError: cannot select an axis to squeeze out which has size not equal to one, got shape=(3, 1, 1) and dimensions=(0,)

For convenience, this functionality is also available via the jax.Array.squeeze method:

>>> x.squeeze()
Array([0, 1, 2], dtype=int32)
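jnp.expand_dims can re-insert the removed length-1 axes. A minimal round-trip sketch:

```python
import jax.numpy as jnp

x = jnp.array([[[0]], [[1]], [[2]]])   # shape (3, 1, 1)
y = jnp.squeeze(x, axis=(1, 2))        # shape (3,)
# Re-insert length-1 axes at positions 1 and 2, undoing the squeeze.
restored = jnp.expand_dims(y, axis=(1, 2))
print(restored.shape)  # (3, 1, 1)
```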
scico.numpy.stack(arrays, axis=0, out=None, dtype=None)

Join arrays along a new axis.

JAX implementation of numpy.stack.

Parameters:
  • arrays (ndarray | Array | Sequence[Union[Array, ndarray, bool, number, bool, int, float, complex]]) – a sequence of arrays to stack; each must have the same shape. If a single array is given it will be treated equivalently to arrays = unstack(arrays), but the implementation will avoid explicit unstacking.

  • axis (int) – specify the axis along which to stack.

  • out (None) – unused by JAX

  • dtype (Union[str, type[Any], dtype, SupportsDType, None]) – optional dtype of the resulting array. If not specified, the dtype will be determined via type promotion rules described in Type promotion semantics.

Return type:

Array

Returns:

the stacked result.

Examples

>>> x = jnp.array([1, 2, 3])
>>> y = jnp.array([4, 5, 6])
>>> jnp.stack([x, y])
Array([[1, 2, 3],
       [4, 5, 6]], dtype=int32)
>>> jnp.stack([x, y], axis=1)
Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)

unstack performs the inverse operation:

>>> arr = jnp.stack([x, y], axis=1)
>>> x, y = jnp.unstack(arr, axis=1)
>>> x
Array([1, 2, 3], dtype=int32)
>>> y
Array([4, 5, 6], dtype=int32)
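To illustrate what the new axis means, jnp.stack can be reproduced by giving each input a length-1 axis at the target position with jnp.expand_dims and concatenating along it. This is an equivalence check, not necessarily how stack is implemented.

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3])
y = jnp.array([4, 5, 6])
# Add a length-1 axis at position 1 to each input, then concatenate along it.
manual = jnp.concatenate([jnp.expand_dims(x, 1), jnp.expand_dims(y, 1)], axis=1)
stacked = jnp.stack([x, y], axis=1)
print(manual.shape)  # (3, 2)
```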
scico.numpy.std(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False, *, where=None, mean=None, correction=None)

Compute the standard deviation along a given axis.

JAX implementation of numpy.std.

Parameters:
  • a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array.

  • axis (Union[int, Sequence[int], None]) – optional, int or sequence of ints, default=None. Axis along which the standard deviation is computed. If None, standard deviation is computed along all the axes.

  • dtype (Union[str, type[Any], dtype, SupportsDType, None]) – The type of the output array. Default=None.

  • ddof (int) – int, default=0. Degrees of freedom. The divisor in the standard deviation computation is N - ddof, where N is the number of elements along the given axis.

  • keepdims (bool) – bool, default=False. If True, reduced axes are left in the result with size 1.

  • where (Union[Array, ndarray, bool, number, bool, int, float, complex, None]) – optional, boolean array, default=None. The elements to be used in the standard deviation. Array should be broadcast compatible to the input.

  • mean (Union[Array, ndarray, bool, number, bool, int, float, complex, None]) – optional, mean of the input array, computed along the given axis. If provided, it will be used to compute the standard deviation instead of computing it from the input array. If specified, mean must be broadcast-compatible with the input array. In the general case, this can be achieved by computing the mean with keepdims=True and axis matching this function’s axis argument.

  • correction (int | float | None) – int or float, default=None. Alternative name for ddof. Both ddof and correction can’t be provided simultaneously.

  • out (None) – Unused by JAX.

Return type:

Array

Returns:

An array of the standard deviation along the given axis.

See also

  • jax.numpy.var: Compute the variance of array elements over given axis.

  • jax.numpy.mean: Compute the mean of array elements over a given axis.

  • jax.numpy.nanvar: Compute the variance along a given axis, ignoring NaN values.

  • jax.numpy.nanstd: Compute the standard deviation along a given axis, ignoring NaN values.

Examples

By default, jnp.std computes the standard deviation along all axes.

>>> x = jnp.array([[1, 3, 4, 2],
...                [4, 2, 5, 3],
...                [5, 4, 2, 3]])
>>> with jnp.printoptions(precision=2, suppress=True):
...   jnp.std(x)
Array(1.21, dtype=float32)

With axis=0, the standard deviation is computed along axis 0:

>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jnp.std(x, axis=0))
[1.7  0.82 1.25 0.47]

To preserve the dimensions of input, you can set keepdims=True.

>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jnp.std(x, axis=0, keepdims=True))
[[1.7  0.82 1.25 0.47]]

If ddof=1:

>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jnp.std(x, axis=0, keepdims=True, ddof=1))
[[2.08 1.   1.53 0.58]]

To include only specific elements in the standard deviation computation, you can use where:

>>> where = jnp.array([[1, 0, 1, 0],
...                    [0, 1, 0, 1],
...                    [1, 1, 1, 0]], dtype=bool)
>>> jnp.std(x, axis=0, keepdims=True, where=where)
Array([[2., 1., 1., 0.]], dtype=float32)
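The ddof formula from the parameter list can be written out explicitly. This sketch compares the manual computation with jnp.std and jnp.var on the example data (converted to float for illustration).

```python
import jax.numpy as jnp

x = jnp.array([[1., 3., 4., 2.],
               [4., 2., 5., 3.],
               [5., 4., 2., 3.]])
ddof = 1
n = x.shape[0]  # number of elements along axis 0
# sqrt(sum((x - mean)**2) / (N - ddof)), computed per column.
mean = jnp.mean(x, axis=0, keepdims=True)
manual = jnp.sqrt(jnp.sum((x - mean) ** 2, axis=0) / (n - ddof))
library = jnp.std(x, axis=0, ddof=ddof)
# The standard deviation is also the square root of the variance with the same ddof.
via_var = jnp.sqrt(jnp.var(x, axis=0, ddof=ddof))
```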
scico.numpy.subtract(*args: ArrayLike, out: None = None, where: None = None) → Any

Subtract two arrays element-wise.

JAX implementation of numpy.subtract. This is a universal function, and supports the additional APIs described at jax.numpy.ufunc. This function provides the implementation of the - operator for JAX arrays.

Parameters:
  • x – arrays to subtract. Must be broadcastable to a common shape.

  • y – arrays to subtract. Must be broadcastable to a common shape.

Returns:

Array containing the result of the element-wise subtraction.

Examples

Calling subtract explicitly:

>>> x = jnp.arange(4)
>>> jnp.subtract(x, 10)
Array([-10,  -9,  -8,  -7], dtype=int32)

Calling subtract via the - operator:

>>> x - 10
Array([-10,  -9,  -8,  -7], dtype=int32)
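Because the inputs only need to be broadcastable to a common shape, subtracting a row vector from a column vector produces a full table of differences. An illustrative sketch:

```python
import jax.numpy as jnp

col = jnp.arange(3).reshape(3, 1)   # shape (3, 1)
row = jnp.arange(4)                 # shape (4,)
# Shapes (3, 1) and (4,) broadcast to (3, 4); entry [i, j] is col[i] - row[j].
diff = jnp.subtract(col, row)
print(diff.shape)  # (3, 4)
```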
scico.numpy.sum(a, axis=None, dtype=None, out=None, keepdims=False, initial=None, where=None, promote_integers=True)

Sum of the elements of the array over a given axis.

JAX implementation of numpy.sum.

Parameters:
  • a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – Input array.

  • axis (Union[int, Sequence[int], None]) – int or sequence of ints, default=None. Axis or axes along which the sum is computed. If None, the sum is computed along all the axes.

  • dtype (Union[str, type[Any], dtype, SupportsDType, None]) – The type of the output array. Default=None.

  • out (None) – Unused by JAX

  • keepdims (bool) – bool, default=False. If True, reduced axes are left in the result with size 1.

  • initial (Union[Array, ndarray, bool, number, bool, int, float, complex, None]) – int or array, default=None. Initial value for the sum.

  • where (Union[Array, ndarray, bool, number, bool, int, float, complex, None]) – optional, boolean array, default=None. The elements to be used in the sum. Array should be broadcast compatible to the input.

  • promote_integers (bool) – bool, default=True. If True, then integer inputs will be promoted to the widest available integer dtype, following numpy’s behavior. If False, the result will have the same dtype as the input. promote_integers is ignored if dtype is specified.

Return type:

Array

Returns:

An array of the sum along the given axis.

See also

  • jax.numpy.prod: Compute the product of array elements over a given axis.

  • jax.numpy.max: Compute the maximum of array elements over given axis.

  • jax.numpy.min: Compute the minimum of array elements over given axis.

Examples

By default, the sum is computed along all the axes.

>>> x = jnp.array([[1, 3, 4, 2],
...                [5, 2, 6, 3],
...                [8, 1, 3, 9]])
>>> jnp.sum(x)
Array(47, dtype=int32)

If axis=1, the sum is computed along axis 1.

>>> jnp.sum(x, axis=1)
Array([10, 16, 21], dtype=int32)

If keepdims=True, ndim of the output is equal to that of the input.

>>> jnp.sum(x, axis=1, keepdims=True)
Array([[10],
       [16],
       [21]], dtype=int32)

To include only specific elements in the sum, you can use where.

>>> where=jnp.array([[0, 0, 1, 0],
...                  [0, 0, 1, 1],
...                  [1, 1, 1, 0]], dtype=bool)
>>> jnp.sum(x, axis=1, keepdims=True, where=where)
Array([[ 4],
       [ 9],
       [12]], dtype=int32)
>>> where=jnp.array([[False],
...                  [False],
...                  [False]])
>>> jnp.sum(x, axis=0, keepdims=True, where=where)
Array([[0, 0, 0, 0]], dtype=int32)
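A sketch of the promote_integers and initial parameters. It assumes JAX's default 32-bit mode (jax_enable_x64 off), in which an int8 input is promoted to a 32-bit integer by default.

```python
import jax.numpy as jnp

x = jnp.array([100, 100, 100], dtype=jnp.int8)
# Default promote_integers=True: the result uses a wider integer dtype,
# so the total of 300 is represented exactly.
promoted = jnp.sum(x)
# promote_integers=False keeps int8, which cannot represent 300.
kept = jnp.sum(x, promote_integers=False)
print(promoted, promoted.dtype)
print(kept.dtype)  # int8
# `initial` is added to the sum of the elements.
print(jnp.sum(jnp.array([1, 2, 3]), initial=10))  # 16
```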
scico.numpy.swapaxes(a, axis1, axis2)

Swap two axes of an array.

JAX implementation of numpy.swapaxes, implemented in terms of jax.lax.transpose.

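Parameters:
  • a – input array.

  • axis1 – index of the first axis to swap.

  • axis2 – index of the second axis to swap.
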
Return type:

Array

Returns:

Copy of a with specified axes swapped.

Notes

Unlike numpy.swapaxes, jax.numpy.swapaxes will return a copy rather than a view of the input array. However, under JIT, the compiler will optimize away such copies when possible, so this doesn’t have performance impacts in practice.

Examples

>>> a = jnp.ones((2, 3, 4, 5))
>>> jnp.swapaxes(a, 1, 3).shape
(2, 5, 4, 3)

Equivalent output via the swapaxes array method:

>>> a.swapaxes(1, 3).shape
(2, 5, 4, 3)

Equivalent output via transpose:

>>> a.transpose(0, 3, 2, 1).shape
(2, 5, 4, 3)
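As the description notes, swapaxes is expressible through transpose. The helper swap_via_transpose below is a hypothetical illustration of that relationship, not part of scico or JAX.

```python
import jax.numpy as jnp

def swap_via_transpose(a, axis1, axis2):
    """Hypothetical helper: swap two axes by building an explicit permutation."""
    perm = list(range(a.ndim))
    # Exchange the two entries and leave every other axis in place.
    perm[axis1], perm[axis2] = perm[axis2], perm[axis1]
    return jnp.transpose(a, perm)

a = jnp.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
result = swap_via_transpose(a, 1, 3)
print(result.shape)  # (2, 5, 4, 3)
```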
scico.numpy.take(a, indices, axis=None, out=None, mode=None, unique_indices=False, indices_are_sorted=False, fill_value=None)

Take elements from an array.

JAX implementation of numpy.take, implemented in terms of jax.lax.gather. JAX’s behavior differs from NumPy in the case of out-of-bound indices; see the mode parameter below.

Parameters:
  • a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – array from which to take values.

  • indices (Union[Array, ndarray, bool, number, bool, int, float, complex]) – N-dimensional array of integer indices of values to take from the array.

  • axis (int | None) – the axis along which to take values. If not specified, the array will be flattened before indexing is applied.

  • mode (str | None) – Out-of-bounds indexing mode, either "fill" or "clip". The default mode="fill" returns invalid values (e.g. NaN) for out-of-bounds indices; the fill_value argument gives control over this value. For more discussion of mode options, see jax.numpy.ndarray.at.

  • fill_value (Union[bool, number, bool, int, float, complex, None]) – The fill value to return for out-of-bounds slices when mode is ‘fill’. Ignored otherwise. Defaults to NaN for inexact types, the most negative representable value for signed integer types, the largest representable value for unsigned integer types, and True for booleans.

  • unique_indices (bool) – If True, the implementation will assume that the indices are unique after normalization of negative indices, which lets the compiler emit more efficient code during the backward pass. If set to True and normalized indices are not unique, the result is implementation-defined and may be non-deterministic.

  • indices_are_sorted (bool) – If True, the implementation will assume that the indices are sorted in ascending order after normalization of negative indices, which can lead to more efficient execution on some backends. If set to True and normalized indices are not sorted, the output is implementation-defined.

Return type:

Array

Returns:

Array of values extracted from a.

Examples

>>> x = jnp.array([[1., 2., 3.],
...                [4., 5., 6.]])
>>> indices = jnp.array([2, 0])

Passing no axis results in indexing into the flattened array:

>>> jnp.take(x, indices)
Array([3., 1.], dtype=float32)
>>> x.ravel()[indices]  # equivalent indexing syntax
Array([3., 1.], dtype=float32)

Passing an axis results in applying the index to every subarray along the axis:

>>> jnp.take(x, indices, axis=1)
Array([[3., 1.],
       [6., 4.]], dtype=float32)
>>> x[:, indices]  # equivalent indexing syntax
Array([[3., 1.],
       [6., 4.]], dtype=float32)

Out-of-bounds indices are filled with invalid values. For float inputs, this is NaN:

>>> jnp.take(x, indices, axis=0)
Array([[nan, nan, nan],
       [ 1.,  2.,  3.]], dtype=float32)
>>> x.at[indices].get(mode='fill', fill_value=jnp.nan)  # equivalent indexing syntax
Array([[nan, nan, nan],
       [ 1.,  2.,  3.]], dtype=float32)

This default out-of-bounds behavior can be adjusted using the mode parameter; for example, we can instead clip to the last valid value:

>>> jnp.take(x, indices, axis=0, mode='clip')
Array([[4., 5., 6.],
       [1., 2., 3.]], dtype=float32)
>>> x.at[indices].get(mode='clip')  # equivalent indexing syntax
Array([[4., 5., 6.],
       [1., 2., 3.]], dtype=float32)
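A sketch of the fill_value parameter alongside mode='clip' on a one-dimensional input; index 5 is deliberately out of bounds.

```python
import jax.numpy as jnp

x = jnp.array([10., 20., 30.])
idx = jnp.array([0, 2, 5])   # index 5 is out of bounds for length 3
# mode="fill" with an explicit fill_value replaces the out-of-bounds entry.
filled = jnp.take(x, idx, mode="fill", fill_value=-1.0)
# mode="clip" maps index 5 to the last valid index, 2.
clipped = jnp.take(x, idx, mode="clip")
print(filled)   # [10. 30. -1.]
print(clipped)  # [10. 30. 30.]
```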
scico.numpy.tan(x, /)

Compute the trigonometric tangent of each element of the input.

JAX implementation of numpy.tan.

Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – scalar or array. Angle in radians.

Return type:

Array

Returns:

An array containing the tangent of each element in x, promotes to inexact dtype.

Examples

>>> pi = jnp.pi
>>> x = jnp.array([0, pi/6, pi/4, 3*pi/4, 5*pi/6])
>>> with jnp.printoptions(precision=3, suppress=True):
...   print(jnp.tan(x))
[ 0.     0.577  1.    -1.    -0.577]
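A quick sketch checking the identity tan(x) = sin(x) / cos(x) at points where cos(x) is nonzero.

```python
import jax.numpy as jnp

x = jnp.array([0.0, jnp.pi / 6, jnp.pi / 4, 3 * jnp.pi / 4])
# The identity holds wherever cos(x) != 0, which is true for these points.
ratio = jnp.sin(x) / jnp.cos(x)
print(jnp.allclose(jnp.tan(x), ratio, atol=1e-6))
```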
scico.numpy.tanh(x, /)

Calculate element-wise hyperbolic tangent of input.

JAX implementation of numpy.tanh.

The hyperbolic tangent is defined by:

\[\tanh(x) = \frac{\sinh(x)}{\cosh(x)} = \frac{e^x - e^{-x}}{e^x + e^{-x}}\]

Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array or scalar.

Return type:

Array

Returns:

An array containing the hyperbolic tangent of each element of x, promoting to inexact dtype.

Note

jnp.tanh is equivalent to computing -1j * jnp.tan(1j * x).

See also

  • jax.numpy.sinh: Computes the element-wise hyperbolic sine of the input.

  • jax.numpy.cosh: Computes the element-wise hyperbolic cosine of the input.

  • jax.numpy.arctanh: Computes the element-wise inverse of hyperbolic tangent of the input.

Examples

>>> x = jnp.array([[-1, 0, 1],
...                [3, -2, 5]])
>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.tanh(x)
Array([[-0.762,  0.   ,  0.762],
       [ 0.995, -0.964,  1.   ]], dtype=float32)
>>> with jnp.printoptions(precision=3, suppress=True):
...   -1j * jnp.tan(1j * x)
Array([[-0.762+0.j,  0.   -0.j,  0.762-0.j],
       [ 0.995-0.j, -0.964+0.j,  1.   -0.j]],      dtype=complex64, weak_type=True)

For complex-valued input:

>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.tanh(2-5j)
Array(1.031+0.021j, dtype=complex64, weak_type=True)
>>> with jnp.printoptions(precision=3, suppress=True):
...   -1j * jnp.tan(1j * (2-5j))
Array(1.031+0.021j, dtype=complex64, weak_type=True)
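The sketch below evaluates the exponential definition directly and compares it with jnp.tanh. Assuming float32 inputs, it also shows why the library function is preferable for large inputs: the naive formula overflows.

```python
import jax.numpy as jnp

x = jnp.array([-2.0, -0.5, 0.0, 0.5, 2.0])
# Direct evaluation of (e^x - e^-x) / (e^x + e^-x).
via_exp = (jnp.exp(x) - jnp.exp(-x)) / (jnp.exp(x) + jnp.exp(-x))
print(jnp.allclose(jnp.tanh(x), via_exp, atol=1e-6))

# In float32, exp(100) overflows to inf, so the naive formula gives inf / inf = nan,
# while jnp.tanh stays finite.
big = jnp.array(100.0, dtype=jnp.float32)
naive = (jnp.exp(big) - jnp.exp(-big)) / (jnp.exp(big) + jnp.exp(-big))
print(naive, jnp.tanh(big))
```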
scico.numpy.tensordot(a, b, axes=2, *, precision=None, preferred_element_type=None, out_sharding=None)

Compute the tensor dot product of two N-dimensional arrays.

JAX implementation of numpy.linalg.tensordot.

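Parameters:
  • a – first input array.

  • b – second input array.

  • axes – integer or pair of integer sequences, default=2. If an integer k, the last k axes of a are contracted with the first k axes of b. If a pair, axes[0] lists the axes of a and axes[1] the matching axes of b to contract.
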
Return type:

Array

Returns:

array containing the tensor dot product of the inputs

Examples

>>> x1 = jnp.arange(24.).reshape(2, 3, 4)
>>> x2 = jnp.ones((3, 4, 5))
>>> jnp.tensordot(x1, x2)
Array([[ 66.,  66.,  66.,  66.,  66.],
       [210., 210., 210., 210., 210.]], dtype=float32)

Equivalent result when specifying the axes as explicit sequences:

>>> jnp.tensordot(x1, x2, axes=([1, 2], [0, 1]))
Array([[ 66.,  66.,  66.,  66.,  66.],
       [210., 210., 210., 210., 210.]], dtype=float32)

Equivalent result via einsum:

>>> jnp.einsum('ijk,jkm->im', x1, x2)
Array([[ 66.,  66.,  66.,  66.,  66.],
       [210., 210., 210., 210., 210.]], dtype=float32)

Setting axes=1 for two-dimensional inputs is equivalent to a matrix multiplication:

>>> x1 = jnp.array([[1, 2],
...                 [3, 4]])
>>> x2 = jnp.array([[1, 2, 3],
...                 [4, 5, 6]])
>>> jnp.linalg.tensordot(x1, x2, axes=1)
Array([[ 9, 12, 15],
       [19, 26, 33]], dtype=int32)
>>> x1 @ x2
Array([[ 9, 12, 15],
       [19, 26, 33]], dtype=int32)

Setting axes=0 for one-dimensional inputs is equivalent to jax.numpy.outer:

>>> x1 = jnp.array([1, 2])
>>> x2 = jnp.array([1, 2, 3])
>>> jnp.linalg.tensordot(x1, x2, axes=0)
Array([[1, 2, 3],
       [2, 4, 6]], dtype=int32)
>>> jnp.outer(x1, x2)
Array([[1, 2, 3],
       [2, 4, 6]], dtype=int32)
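When axes is given as a pair of sequences, the listed axes are paired in order, so the contracted axes need not appear in the same order in both inputs. An illustrative sketch, checked against jnp.einsum:

```python
import jax.numpy as jnp

a = jnp.arange(24.).reshape(2, 3, 4)
b = jnp.arange(12.).reshape(4, 3)
# Contract a's axis 1 (length 3) with b's axis 1, and a's axis 2 (length 4) with b's axis 0.
t = jnp.tensordot(a, b, axes=([1, 2], [1, 0]))
# The same contraction written as an einsum: sum over j and k of a[i, j, k] * b[k, j].
e = jnp.einsum('ijk,kj->i', a, b)
print(t.shape)  # (2,)
```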
scico.numpy.tile(A, reps)

Construct an array by repeating A along specified dimensions.

JAX implementation of numpy.tile.

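If A is an array of shape (d1, d2, ..., dn) and reps is a sequence of n integers, the resulting array will have a shape of (reps[0] * d1, reps[1] * d2, ..., reps[n-1] * dn), with A tiled along each dimension.

Parameters:
  • A – input array to be repeated.

  • reps – integer or sequence of integers giving the number of repetitions along each axis.
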
Return type:

Array

Returns:

a new array where the input array has been repeated according to reps.

Examples

>>> arr = jnp.array([1, 2])
>>> jnp.tile(arr, 2)
Array([1, 2, 1, 2], dtype=int32)
>>> arr = jnp.array([[1, 2],
...                  [3, 4]])
>>> jnp.tile(arr, (2, 1))
Array([[1, 2],
       [3, 4],
       [1, 2],
       [3, 4]], dtype=int32)
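A sketch of the shape rule, including the NumPy-compatible case where reps is longer than A.ndim and A is treated as having leading length-1 axes.

```python
import jax.numpy as jnp

A = jnp.array([[1, 2],
               [3, 4]])          # shape (2, 2)
out = jnp.tile(A, (2, 3))
# Each dimension d_i is multiplied by reps[i].
print(out.shape)  # (4, 6)
# With three reps for a 2-D array, A behaves as if it had shape (1, 2, 2).
print(jnp.tile(A, (3, 1, 1)).shape)  # (3, 2, 2)
```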
scico.numpy.trace(a, offset=0, axis1=0, axis2=1, dtype=None, out=None)

Calculate the sum of the diagonal of the input along the given axes.

JAX implementation of numpy.trace.

Parameters:
  • a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array. Must have a.ndim >= 2.

  • offset (Union[int, Array, ndarray, bool, number, bool, float, complex]) – optional, int, default=0. Diagonal offset from the main diagonal. Can be positive or negative.

  • axis1 (int) – optional, default=0. The first axis along which to take the sum of diagonal. Must be a static integer value.

  • axis2 (int) – optional, default=1. The second axis along which to take the sum of diagonal. Must be a static integer value.

  • dtype (Union[str, type[Any], dtype, SupportsDType, None]) – optional. The dtype of the output array. Should be provided as static argument in JIT compilation.

  • out (None) – Not used by JAX.

Return type:

Array

Returns:

An array of dimension a.ndim - 2 containing the sum of the diagonal elements along axes (axis1, axis2).

Examples

>>> x = jnp.arange(1, 9).reshape(2, 2, 2)
>>> x
Array([[[1, 2],
        [3, 4]],

       [[5, 6],
        [7, 8]]], dtype=int32)
>>> jnp.trace(x)
Array([ 8, 10], dtype=int32)
>>> jnp.trace(x, offset=1)
Array([3, 4], dtype=int32)
>>> jnp.trace(x, axis1=1, axis2=2)
Array([ 5, 13], dtype=int32)
>>> jnp.trace(x, offset=1, axis1=1, axis2=2)
Array([2, 6], dtype=int32)
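For a two-dimensional input, the trace with an offset equals the sum of the corresponding diagonal returned by jnp.diagonal. A small sketch:

```python
import jax.numpy as jnp

m = jnp.arange(1, 10).reshape(3, 3)
# Compare trace and diagonal-sum for the main diagonal and its neighbours.
checks = [bool(jnp.trace(m, offset=k) == jnp.sum(jnp.diagonal(m, offset=k)))
          for k in (-1, 0, 1)]
print(jnp.trace(m))  # 15
```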
scico.numpy.transpose(a, axes=None)

Return a transposed version of an N-dimensional array.

JAX implementation of numpy.transpose, implemented in terms of jax.lax.transpose.

Parameters:
  • a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array

  • axes (Sequence[int] | None) – optionally specify the permutation using a length-a.ndim sequence of integers i satisfying 0 <= i < a.ndim. Defaults to range(a.ndim)[::-1], i.e. reverses the order of all axes.

Return type:

Array

Returns:

transposed copy of the array.

Note

Unlike numpy.transpose, jax.numpy.transpose will return a copy rather than a view of the input array. However, under JIT, the compiler will optimize away such copies when possible, so this doesn’t have performance impacts in practice.

Examples

For a 1D array, the transpose is the identity:

>>> x = jnp.array([1, 2, 3, 4])
>>> jnp.transpose(x)
Array([1, 2, 3, 4], dtype=int32)

For a 2D array, the transpose is a matrix transpose:

>>> x = jnp.array([[1, 2],
...                [3, 4]])
>>> jnp.transpose(x)
Array([[1, 3],
       [2, 4]], dtype=int32)

For an N-dimensional array, the transpose reverses the order of the axes:

>>> x = jnp.zeros(shape=(3, 4, 5))
>>> jnp.transpose(x).shape
(5, 4, 3)

The axes argument can be specified to change this default behavior:

>>> jnp.transpose(x, (0, 2, 1)).shape
(3, 5, 4)

Since swapping the last two axes is a common operation, it can be done via its own API, jax.numpy.matrix_transpose:

>>> jnp.matrix_transpose(x).shape
(3, 5, 4)

For convenience, transposes may also be performed using the jax.Array.transpose method or the jax.Array.T property:

>>> x = jnp.array([[1, 2],
...                [3, 4]])
>>> x.transpose()
Array([[1, 3],
       [2, 4]], dtype=int32)
>>> x.T
Array([[1, 3],
       [2, 4]], dtype=int32)
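A permutation passed as axes can be undone with its inverse permutation, which jnp.argsort computes. A minimal round-trip sketch with an illustrative permutation:

```python
import jax.numpy as jnp

x = jnp.arange(60).reshape(3, 4, 5)
axes = (2, 0, 1)
y = jnp.transpose(x, axes)          # shape (5, 3, 4)
# argsort of a permutation is its inverse permutation.
inverse = tuple(int(i) for i in jnp.argsort(jnp.array(axes)))
restored = jnp.transpose(y, inverse)
print(y.shape, inverse)  # (5, 3, 4) (1, 2, 0)
```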
scico.numpy.tri(N, M=None, k=0, dtype=None)

Return an array with ones on and below the diagonal and zeros elsewhere.

JAX implementation of numpy.tri.

Parameters:
  • N (int) – int. Dimension of the rows of the returned array.

  • M (int | None) – optional, int. Dimension of the columns of the returned array. If not specified, then M = N.

  • k (int) – optional, int, default=0. Specifies the diagonal on and below which the array is filled with ones. k=0 refers to the main diagonal, k<0 to a sub-diagonal below the main diagonal, and k>0 to a diagonal above the main diagonal.

  • dtype (Union[str, type[Any], dtype, SupportsDType, None]) – optional, data type of the returned array. The default type is float.

Return type:

Array

Returns:

An array of shape (N, M) in which the elements on and below the diagonal specified by k are set to one and all other elements are zero.

Examples

>>> jnp.tri(3)
Array([[1., 0., 0.],
       [1., 1., 0.],
       [1., 1., 1.]], dtype=float32)

When M is not equal to N:

>>> jnp.tri(3, 4)
Array([[1., 0., 0., 0.],
       [1., 1., 0., 0.],
       [1., 1., 1., 0.]], dtype=float32)

When k>0:

>>> jnp.tri(3, k=1)
Array([[1., 1., 0.],
       [1., 1., 1.],
       [1., 1., 1.]], dtype=float32)

When k<0:

>>> jnp.tri(3, 4, k=-1)
Array([[0., 0., 0., 0.],
       [1., 0., 0., 0.],
       [1., 1., 0., 0.]], dtype=float32)
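jnp.tri(N, M, k) matches the lower-triangular part of an all-ones matrix, which this sketch checks with jnp.tril for a few illustrative values of k.

```python
import jax.numpy as jnp

N, M = 3, 4
# Keep ones on and below diagonal k of an all-ones (N, M) matrix.
results = [bool(jnp.array_equal(jnp.tri(N, M, k), jnp.tril(jnp.ones((N, M)), k)))
           for k in (-1, 0, 1)]
print(results)
```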
scico.numpy.tril_indices(n, k=0, m=None)

Return the indices of the lower triangle of an array of shape (n, m).

JAX implementation of numpy.tril_indices.

Parameters:
  • n (Union[int, Any]) – int. Number of rows of the array for which the indices are returned.

  • k (Union[int, Any]) – optional, int, default=0. Specifies the diagonal on and below which the indices of the lower triangle are returned. k=0 refers to the main diagonal, k<0 to a sub-diagonal below the main diagonal, and k>0 to a diagonal above the main diagonal.

  • m (Union[int, Any, None]) – optional, int. Number of columns of the array for which the indices are returned. If not specified, then m = n.

Return type:

tuple[Array, Array]

Returns:

A tuple of two arrays containing the indices of the lower triangle, one along each axis.

Examples

If only n is provided, the indices of the lower triangle of an (n, n) array are returned.

>>> jnp.tril_indices(3)
(Array([0, 1, 1, 2, 2, 2], dtype=int32), Array([0, 0, 1, 0, 1, 2], dtype=int32))

If both n and m are provided, the indices of the lower triangle of an (n, m) array are returned.

>>> jnp.tril_indices(3, m=2)
(Array([0, 1, 1, 2, 2], dtype=int32), Array([0, 0, 1, 0, 1], dtype=int32))

If k = 1, the indices on and below the first sub-diagonal above the main diagonal are returned.

>>> jnp.tril_indices(3, k=1)
(Array([0, 0, 1, 1, 1, 2, 2, 2], dtype=int32), Array([0, 1, 0, 1, 2, 0, 1, 2], dtype=int32))

If k = -1, the indices on and below the first sub-diagonal below the main diagonal are returned.

>>> jnp.tril_indices(3, k=-1)
(Array([1, 2, 2], dtype=int32), Array([0, 0, 1], dtype=int32))
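A sketch using the returned index arrays to gather the lower triangle of a matrix, and checking the entry count n * (n + 1) / 2 for k=0.

```python
import jax.numpy as jnp

n = 4
rows, cols = jnp.tril_indices(n)
# On and below the main diagonal there are n * (n + 1) / 2 entries.
count = rows.shape[0]
a = jnp.arange(n * n).reshape(n, n)
# Indexing with the pair of arrays gathers the lower triangle row by row.
lower = a[rows, cols]
print(lower)  # [ 0  4  5  8  9 10 12 13 14 15]
```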
scico.numpy.tril_indices_from(arr, k=0)

Return the indices of lower triangle of a given array.

JAX implementation of numpy.tril_indices_from.

Parameters:
  • arr (Union[Array, ndarray, bool, number, bool, int, float, complex, SupportsShape]) – input array. Must have arr.ndim == 2.

  • k (int) – optional, int, default=0. Specifies the diagonal on and below which the indices of the lower triangle are returned. k=0 refers to the main diagonal, k<0 to a sub-diagonal below the main diagonal, and k>0 to a diagonal above the main diagonal.

Return type:

tuple[Array, Array]

Returns:

A tuple of two arrays containing the indices of the lower triangle, one along each axis.

Examples

>>> arr = jnp.array([[1, 2, 3],
...                  [4, 5, 6],
...                  [7, 8, 9]])
>>> jnp.tril_indices_from(arr)
(Array([0, 1, 1, 2, 2, 2], dtype=int32), Array([0, 0, 1, 0, 1, 2], dtype=int32))

Elements indexed by jnp.tril_indices_from correspond to those in the output of jnp.tril.

>>> ind = jnp.tril_indices_from(arr)
>>> arr[ind]
Array([1, 4, 5, 7, 8, 9], dtype=int32)
>>> jnp.tril(arr)
Array([[1, 0, 0],
       [4, 5, 0],
       [7, 8, 9]], dtype=int32)

When k > 0:

>>> jnp.tril_indices_from(arr, k=1)
(Array([0, 0, 1, 1, 1, 2, 2, 2], dtype=int32), Array([0, 1, 0, 1, 2, 0, 1, 2], dtype=int32))

When k < 0:

>>> jnp.tril_indices_from(arr, k=-1)
(Array([1, 2, 2], dtype=int32), Array([0, 0, 1], dtype=int32))
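The index tuple can also be used for scattering. This sketch writes the gathered lower-triangle values into a zero matrix and recovers jnp.tril(arr).

```python
import jax.numpy as jnp

arr = jnp.arange(1, 10).reshape(3, 3)
ind = jnp.tril_indices_from(arr)
# Gather the lower triangle, then scatter it back into a zero matrix.
rebuilt = jnp.zeros_like(arr).at[ind].set(arr[ind])
print(rebuilt)
```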
scico.numpy.trim_zeros(filt, trim='fb', axis=None)

Trim leading and/or trailing zeros of the input array.

JAX implementation of numpy.trim_zeros.

Parameters:
  • filt (Union[Array, ndarray, bool, number, bool, int, float, complex]) – N-dimensional input array.

  • trim (str) –

    string, optional, default='fb'. Specifies from which end the input is trimmed.

    • f - trims only the leading zeros.

    • b - trims only the trailing zeros.

    • fb - trims both leading and trailing zeros.

  • axis (int | Sequence[int] | None) – optional axis or axes along which to trim. If not specified, trim along all axes of the array.

Return type:

Array

Returns:

An array containing the trimmed input, with the same dtype as filt.

Examples

One-dimensional input:

>>> x = jnp.array([0, 0, 2, 0, 1, 4, 3, 0, 0, 0])
>>> jnp.trim_zeros(x)
Array([2, 0, 1, 4, 3], dtype=int32)
>>> jnp.trim_zeros(x, trim='f')
Array([2, 0, 1, 4, 3, 0, 0, 0], dtype=int32)
>>> jnp.trim_zeros(x, trim='b')
Array([0, 0, 2, 0, 1, 4, 3], dtype=int32)

Two-dimensional input:

>>> x = jnp.zeros((4, 5)).at[1:3, 1:4].set(1)
>>> x
Array([[0., 0., 0., 0., 0.],
       [0., 1., 1., 1., 0.],
       [0., 1., 1., 1., 0.],
       [0., 0., 0., 0., 0.]], dtype=float32)
>>> jnp.trim_zeros(x)
Array([[1., 1., 1.],
       [1., 1., 1.]], dtype=float32)
>>> jnp.trim_zeros(x, trim='f')
Array([[1., 1., 1., 0.],
       [1., 1., 1., 0.],
       [0., 0., 0., 0.]], dtype=float32)
>>> jnp.trim_zeros(x, axis=0)
Array([[0., 1., 1., 1., 0.],
       [0., 1., 1., 1., 0.]], dtype=float32)
>>> jnp.trim_zeros(x, axis=1)
Array([[0., 0., 0.],
       [1., 1., 1.],
       [1., 1., 1.],
       [0., 0., 0.]], dtype=float32)
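For 1D input with the default trim='fb', the result can be understood as slicing from the first to the last nonzero entry. The helper below, trim_zeros_1d_sketch, is a hypothetical illustration of that idea, not JAX's implementation; it assumes at least one nonzero entry and, because its output shape depends on the data, it is not usable under jit.

```python
import jax.numpy as jnp

def trim_zeros_1d_sketch(x):
    """Hypothetical helper: slice x from its first to its last nonzero entry."""
    nonzero = jnp.nonzero(x)[0]          # positions of the nonzero entries
    start, stop = int(nonzero[0]), int(nonzero[-1]) + 1
    return x[start:stop]

x = jnp.array([0, 0, 2, 0, 1, 4, 3, 0, 0, 0])
print(trim_zeros_1d_sketch(x))   # expected to match jnp.trim_zeros(x)
```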
scico.numpy.triu_indices(n, k=0, m=None)

Return the indices of the upper triangle of an array of shape (n, m).

JAX implementation of numpy.triu_indices.

Parameters:
  • n (Union[int, Any]) – int. Number of rows of the array for which the indices are returned.

  • k (Union[int, Any]) – optional, int, default=0. Specifies the sub-diagonal on and above which the indices of upper triangle are returned. k=0 refers to main diagonal, k<0 refers to sub-diagonal below the main diagonal and k>0 refers to sub-diagonal above the main diagonal.

  • m (Union[int, Any, None]) – optional, int. Number of columns of the array for which the indices are returned. If not specified, then m = n.

Return type:

tuple[Array, Array]

Returns:

A tuple of two arrays containing the indices of the upper triangle, one along each axis.

Examples

If only n is provided, the indices of the upper triangle of an (n, n) array are returned.

>>> jnp.triu_indices(3)
(Array([0, 0, 0, 1, 1, 2], dtype=int32), Array([0, 1, 2, 1, 2, 2], dtype=int32))

If both n and m are provided, the indices of the upper triangle of an (n, m) array are returned.

>>> jnp.triu_indices(3, m=2)
(Array([0, 0, 1], dtype=int32), Array([0, 1, 1], dtype=int32))

If k = 1, the indices on and above the first sub-diagonal above the main diagonal are returned.

>>> jnp.triu_indices(3, k=1)
(Array([0, 0, 1], dtype=int32), Array([1, 2, 2], dtype=int32))

If k = -1, the indices on and above the first sub-diagonal below the main diagonal are returned.

>>> jnp.triu_indices(3, k=-1)
(Array([0, 0, 0, 1, 1, 1, 2, 2], dtype=int32), Array([0, 1, 2, 0, 1, 2, 1, 2], dtype=int32))
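With k=1 the diagonal is excluded, so each unordered pair (i, j) with i < j appears exactly once. A typical application is reducing a symmetric pairwise matrix to its unique entries. The sketch below assumes three illustrative 2D points chosen so that the pairwise distances are 3, 4 and 5:

```python
import jax.numpy as jnp

# Three 2D points chosen so the pairwise distances are 3, 4 and 5.
points = jnp.array([[0.0, 0.0],
                    [3.0, 0.0],
                    [0.0, 4.0]])

# Full (n, n) matrix of Euclidean distances via broadcasting.
diff = points[:, None, :] - points[None, :, :]
dist = jnp.sqrt(jnp.sum(diff ** 2, axis=-1))

# k=1 skips the diagonal, so each pair i < j is listed once.
i, j = jnp.triu_indices(points.shape[0], k=1)
pair_dist = dist[i, j]
print(pair_dist)  # [3. 4. 5.]
```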
scico.numpy.triu_indices_from(arr, k=0)

Return the indices of the upper triangle of a given array.

JAX implementation of numpy.triu_indices_from.

Parameters:
  • arr (Union[Array, ndarray, bool, number, bool, int, float, complex, SupportsShape]) – input array. Must have arr.ndim == 2.

  • k (int) – optional, int, default=0. Specifies the sub-diagonal on and above which the indices of upper triangle are returned. k=0 refers to main diagonal, k<0 refers to sub-diagonal below the main diagonal and k>0 refers to sub-diagonal above the main diagonal.

Return type:

tuple[Array, Array]

Returns:

A tuple of two arrays containing the indices of the upper triangle, one along each axis.

Examples

>>> arr = jnp.array([[1, 2, 3],
...                  [4, 5, 6],
...                  [7, 8, 9]])
>>> jnp.triu_indices_from(arr)
(Array([0, 0, 0, 1, 1, 2], dtype=int32), Array([0, 1, 2, 1, 2, 2], dtype=int32))

Elements indexed by jnp.triu_indices_from correspond to those in the output of jnp.triu.

>>> ind = jnp.triu_indices_from(arr)
>>> arr[ind]
Array([1, 2, 3, 5, 6, 9], dtype=int32)
>>> jnp.triu(arr)
Array([[1, 2, 3],
       [0, 5, 6],
       [0, 0, 9]], dtype=int32)

When k > 0:

>>> jnp.triu_indices_from(arr, k=1)
(Array([0, 0, 1], dtype=int32), Array([1, 2, 2], dtype=int32))

When k < 0:

>>> jnp.triu_indices_from(arr, k=-1)
(Array([0, 0, 0, 1, 1, 1, 2, 2], dtype=int32), Array([0, 1, 2, 0, 1, 2, 1, 2], dtype=int32))
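The same indices can also be used in the other direction, to unpack a vector of upper-triangle entries into a full symmetric matrix. The sketch below assumes the packed entries are listed row by row (the order in which triu_indices_from returns them); the names packed, template and sym are illustrative:

```python
import jax.numpy as jnp

# Upper-triangle entries of a 3x3 symmetric matrix, listed row by row.
packed = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])

template = jnp.zeros((3, 3))
i, j = jnp.triu_indices_from(template)
upper = template.at[i, j].set(packed)

# Mirror into the lower triangle; subtract the diagonal once because it
# appears in both upper and upper.T.
sym = upper + upper.T - jnp.diag(jnp.diag(upper))
print(sym)
```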
scico.numpy.true_divide(x1, x2, /)

Calculates the division of x1 by x2 element-wise.

JAX implementation of numpy.true_divide.

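Parameters:
  • x1 – input array or scalar, the dividend.

  • x2 – input array or scalar, the divisor. Must be broadcast-compatible with x1.
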
Return type:

Array

Returns:

An array containing the element-wise quotients. Floating-point division is always used, even for integer inputs.

Examples

>>> x1 = jnp.array([3, 4, 5])
>>> x2 = 2
>>> jnp.true_divide(x1, x2)
Array([1.5, 2. , 2.5], dtype=float32)
>>> x1 = 24
>>> x2 = jnp.array([3, 4, 6j])
>>> jnp.true_divide(x1, x2)
Array([8.+0.j, 6.+0.j, 0.-4.j], dtype=complex64)
>>> x1 = jnp.array([1j, 9+5j, -4+2j])
>>> x2 = 3j
>>> jnp.true_divide(x1, x2)
Array([0.33333334+0.j       , 1.6666666 -3.j       ,
       0.6666667 +1.3333334j], dtype=complex64)

See also

jax.numpy.floor_divide for integer division
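To make the contrast with floor_divide concrete, the sketch below divides the same illustrative integer array both ways. Note how floor division rounds toward negative infinity, so -7 // 2 gives -4:

```python
import jax.numpy as jnp

x1 = jnp.array([7, -7, 9])
x2 = 2

q_true = jnp.true_divide(x1, x2)     # floating-point quotients 3.5, -3.5, 4.5
q_floor = jnp.floor_divide(x1, x2)   # integer quotients rounded toward -inf: 3, -4, 4

# The / operator on JAX arrays also performs true division.
same = jnp.array_equal(q_true, x1 / x2)
```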

scico.numpy.trunc(x)

Round input to the nearest integer towards zero.

JAX implementation of numpy.trunc.

Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array or scalar.

Return type:

Array

Returns:

An array with same shape and dtype as x containing the rounded values.

See also

  • jax.numpy.fix: Rounds the input to the nearest integer towards zero.

  • jax.numpy.ceil: Rounds the input up to the nearest integer.

  • jax.numpy.floor: Rounds the input down to the nearest integer.

Examples

>>> key = jax.random.key(42)
>>> x = jax.random.uniform(key, (3, 3), minval=-10, maxval=10)
>>> with jnp.printoptions(precision=2, suppress=True):
...     print(x)
[[-0.23  3.6   2.33]
 [ 1.22 -0.99  1.72]
 [-8.5   5.5   3.98]]
>>> jnp.trunc(x)
Array([[-0.,  3.,  2.],
       [ 1., -0.,  1.],
       [-8.,  5.,  3.]], dtype=float32)
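The difference between trunc and the related rounding functions only shows for negative inputs. The sketch below, using an illustrative input array, compares them and checks that truncation equals ceil for negative values and floor otherwise:

```python
import jax.numpy as jnp

x = jnp.array([-2.7, -0.5, 0.5, 2.7])

t = jnp.trunc(x)   # -2, -0, 0, 2 (toward zero)
f = jnp.floor(x)   # -3, -1, 0, 2 (toward -inf)
c = jnp.ceil(x)    # -2, -0, 1, 3 (toward +inf)

# Truncation is ceil for negative inputs and floor for non-negative ones.
t_manual = jnp.where(x < 0, jnp.ceil(x), jnp.floor(x))
```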
scico.numpy.union1d(ar1, ar2, *, size=None, fill_value=None)

Compute the set union of two 1D arrays.

JAX implementation of numpy.union1d.

Because the size of the output of union1d is data-dependent, the function is not typically compatible with jit and other JAX transformations. The JAX version adds the optional size argument which must be specified statically for jnp.union1d to be used in such contexts.

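Parameters:
  • ar1 – first 1D input array.

  • ar2 – second 1D input array.

  • size – optional static integer specifying the output size. If the union has more elements, the output is truncated; if it has fewer, the output is padded with fill_value.

  • fill_value – when size is specified and the union has fewer elements, the remaining entries are filled with fill_value.
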
Return type:

Array

Returns:

An array containing the union of elements in the input arrays.

Examples

Computing the union of two arrays:

>>> ar1 = jnp.array([1, 2, 3, 4])
>>> ar2 = jnp.array([3, 4, 5, 6])
>>> jnp.union1d(ar1, ar2)
Array([1, 2, 3, 4, 5, 6], dtype=int32)

Because the output shape is dynamic, this will fail under jit and other transformations:

>>> jax.jit(jnp.union1d)(ar1, ar2)  
Traceback (most recent call last):
   ...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[4].
The error occurred while tracing the function union1d at /Users/vanderplas/github/jax-ml/jax/jax/_src/numpy/setops.py:101 for jit. This concrete value was not available in Python because it depends on the value of the argument ar1.

In order to ensure statically-known output shapes, you can pass a static size argument:

>>> jit_union1d = jax.jit(jnp.union1d, static_argnames=['size'])
>>> jit_union1d(ar1, ar2, size=6)
Array([1, 2, 3, 4, 5, 6], dtype=int32)

If size is too small, the union is truncated:

>>> jit_union1d(ar1, ar2, size=4)
Array([1, 2, 3, 4], dtype=int32)

If size is too large, then the output is padded with fill_value:

>>> jit_union1d(ar1, ar2, size=8, fill_value=0)
Array([1, 2, 3, 4, 5, 6, 0, 0], dtype=int32)
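Conceptually, union1d returns the sorted unique elements of both inputs combined, which the sketch below checks against jnp.unique of the concatenation. It then wraps the jit-compatible form in a hypothetical helper, padded_union, that pads with -1; choosing -1 as a sentinel assumes it never occurs in the real data, so that padding can be told apart from genuine entries (padding with 0, as above, would be ambiguous if 0 were a valid value).

```python
from functools import partial

import jax
import jax.numpy as jnp

ar1 = jnp.array([1, 2, 3, 4])
ar2 = jnp.array([3, 4, 5, 6])

# Union as the sorted unique values of the concatenated inputs.
manual = jnp.unique(jnp.concatenate([ar1, ar2]))

@partial(jax.jit, static_argnames=["size"])
def padded_union(a, b, size):
    # Hypothetical helper: fixed-size union, padded with the sentinel -1.
    return jnp.union1d(a, b, size=size, fill_value=-1)

out = padded_union(ar1, ar2, size=8)
n_valid = jnp.sum(out != -1)   # number of genuine union elements
```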
scico.numpy.unique(ar, return_index=False, return_inverse=False, return_counts=False, axis=None, *, equal_nan=True, size=None, fill_value=None, sorted=True)

Return the unique values from an array.

JAX implementation of numpy.unique.

Because the size of the output of unique is data-dependent, the function is not typically compatible with jit and other JAX transformations. The JAX version adds the optional size argument which must be specified statically for jnp.unique to be used in such contexts.

Parameters:
  • ar (Union[Array, ndarray, bool, number, bool, int, float, complex]) – N-dimensional array from which unique values will be extracted.

  • return_index (bool) – if True, also return the indices in ar of the first occurrence of each unique value.

  • return_inverse (bool) – if True, also return the indices that can be used to reconstruct ar from the unique values.

  • return_counts (bool) – if True, also return the number of occurrences of each unique value.

  • axis (int | None) – if specified, compute unique values along the specified axis. If None (default), then flatten ar before computing the unique values.

  • equal_nan (bool) – if True, consider NaN values equivalent when determining uniqueness.

  • size (int | None) – if specified, return only the first size sorted unique elements. If there are fewer unique elements than size indicates, the return value will be padded with fill_value.

  • fill_value (Union[Array, ndarray, bool, number, bool, int, float, complex, None]) – when size is specified and there are fewer than the indicated number of elements, fill the remaining entries with fill_value. Defaults to the minimum unique value.

  • sorted (bool) – unused by JAX.

Returns:

An array or tuple of arrays, depending on the values of return_index, return_inverse, and return_counts. Returned values are

  • unique_values:

    if axis is None, a 1D array of length n_unique. If axis is specified, the shape is (*ar.shape[:axis], n_unique, *ar.shape[axis + 1:]).

  • unique_index:

    (returned only if return_index is True) An array of shape (n_unique,). Contains the indices of the first occurrence of each unique value in ar. For 1D inputs, ar[unique_index] is equivalent to unique_values.

  • unique_inverse:

    (returned only if return_inverse is True) An array of shape (ar.size,) if axis is None, or of shape (ar.shape[axis],) if axis is specified. Contains the indices within unique_values of each value in ar. For 1D inputs, unique_values[unique_inverse] is equivalent to ar.

  • unique_counts:

    (returned only if return_counts is True) An array of shape (n_unique,). Contains the number of occurrences of each unique value in ar.

Examples

>>> x = jnp.array([3, 4, 1, 3, 1])
>>> jnp.unique(x)
Array([1, 3, 4], dtype=int32)

JIT compilation & the size argument

If you try this under jit or another transformation, you will get an error because the output shape is dynamic:

>>> jax.jit(jnp.unique)(x)  
Traceback (most recent call last):
   ...
jax.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[5].
The error arose for the first argument of jnp.unique(). To make jnp.unique() compatible with JIT and other transforms, you can specify a concrete value for the size argument, which will determine the output size.

The issue is that the output of transformed functions must have static shapes. In order to make this work, you can pass a static size parameter:

>>> jit_unique = jax.jit(jnp.unique, static_argnames=['size'])
>>> jit_unique(x, size=3)
Array([1, 3, 4], dtype=int32)

If your static size is smaller than the true number of unique values, they will be truncated.

>>> jit_unique(x, size=2)
Array([1, 3], dtype=int32)

If the static size is larger than the true number of unique values, they will be padded with fill_value, which defaults to the minimum unique value:

>>> jit_unique(x, size=5)
Array([1, 3, 4, 1, 1], dtype=int32)
>>> jit_unique(x, size=5, fill_value=0)
Array([1, 3, 4, 0, 0], dtype=int32)

Multi-dimensional unique values

If you pass a multi-dimensional array to unique, it will be flattened by default:

>>> M = jnp.array([[1, 2],
...                [2, 3],
...                [1, 2]])
>>> jnp.unique(M)
Array([1, 2, 3], dtype=int32)

If you pass an axis keyword, you can find unique slices of the array along that axis:

>>> jnp.unique(M, axis=0)
Array([[1, 2],
       [2, 3]], dtype=int32)

Returning indices

If you set return_index=True, then unique returns the indices of the first occurrence of each unique value:

>>> x = jnp.array([3, 4, 1, 3, 1])
>>> values, indices = jnp.unique(x, return_index=True)
>>> print(values)
[1 3 4]
>>> print(indices)
[2 0 1]
>>> jnp.all(values == x[indices])
Array(True, dtype=bool)

In multiple dimensions, the unique values can be extracted with jax.numpy.take evaluated along the specified axis:

>>> values, indices = jnp.unique(M, axis=0, return_index=True)
>>> jnp.all(values == jnp.take(M, indices, axis=0))
Array(True, dtype=bool)

Returning inverse

If you set return_inverse=True, then unique returns the indices within the unique values for every entry in the input array:

>>> x = jnp.array([3, 4, 1, 3, 1])
>>> values, inverse = jnp.unique(x, return_inverse=True)
>>> print(values)
[1 3 4]
>>> print(inverse)
[1 2 0 1 0]
>>> jnp.all(values[inverse] == x)
Array(True, dtype=bool)

In multiple dimensions, the input can be reconstructed using jax.numpy.take:

>>> values, inverse = jnp.unique(M, axis=0, return_inverse=True)
>>> jnp.all(jnp.take(values, inverse, axis=0) == M)
Array(True, dtype=bool)

Returning counts

If you set return_counts=True, then unique returns the number of occurrences within the input for every unique value:

>>> x = jnp.array([3, 4, 1, 3, 1])
>>> values, counts = jnp.unique(x, return_counts=True)
>>> print(values)
[1 3 4]
>>> print(counts)
[2 2 1]

For multi-dimensional arrays with axis specified, this also returns a 1D array of counts, giving the number of occurrences of each unique slice along that axis:

>>> values, counts = jnp.unique(M, axis=0, return_counts=True)
>>> print(values)
[[1 2]
 [2 3]]
>>> print(counts)
[2 1]
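Because inverse[i] gives the position of ar[i] within the unique values, the inverse indices can serve directly as segment ids for a group-by reduction. The sketch below (keys and vals are illustrative) sums values per key with jax.ops.segment_sum. As written it is not jit-compatible, since groups has a data-dependent shape; see the size argument above.

```python
import jax
import jax.numpy as jnp

keys = jnp.array([3, 4, 1, 3, 1])
vals = jnp.array([10.0, 20.0, 30.0, 40.0, 50.0])

groups, inverse = jnp.unique(keys, return_inverse=True)

# inverse[i] is the position of keys[i] within groups, i.e. a segment id.
sums = jax.ops.segment_sum(vals, inverse, num_segments=groups.shape[0])
print(groups, sums)
```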
scico.numpy.unique_all(x, /, *, size=None, fill_value=None)

Return unique values from x, along with indices, inverse indices, and counts.

JAX implementation of numpy.unique_all; this is equivalent to calling jax.numpy.unique with return_index, return_inverse, return_counts, and equal_nan set to True.

Because the size of the output of unique_all is data-dependent, the function is not typically compatible with jit and other JAX transformations. The JAX version adds the optional size argument which must be specified statically for jnp.unique_all to be used in such contexts.

Parameters:
  • x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – N-dimensional array from which unique values will be extracted.

  • size (int | None) – if specified, return only the first size sorted unique elements. If there are fewer unique elements than size indicates, the return value will be padded with fill_value.

  • fill_value (Union[Array, ndarray, bool, number, bool, int, float, complex, None]) – when size is specified and there are fewer than the indicated number of elements, fill the remaining entries with fill_value. Defaults to the minimum unique value.

Return type:

_UniqueAllResult

Returns:

A tuple (values, indices, inverse_indices, counts), with the following properties –

  • values:

    an array of shape (n_unique,) containing the unique values from x.

  • indices:

    An array of shape (n_unique,). Contains the indices of the first occurrence of each unique value in x. For 1D inputs, x[indices] is equivalent to values.

  • inverse_indices:

    An array of shape x.shape. Contains the indices within values of each value in x. For 1D inputs, values[inverse_indices] is equivalent to x.

  • counts:

    An array of shape (n_unique,). Contains the number of occurrences of each unique value in x.

Examples

Here we compute the unique values in a 1D array:

>>> x = jnp.array([3, 4, 1, 3, 1])
>>> result = jnp.unique_all(x)

The result is a NamedTuple with four named attributes. The values attribute contains the unique values from the array:

>>> result.values
Array([1, 3, 4], dtype=int32)

The indices attribute contains the indices of the unique values within the input array:

>>> result.indices
Array([2, 0, 1], dtype=int32)
>>> jnp.all(result.values == x[result.indices])
Array(True, dtype=bool)

The inverse_indices attribute contains the indices of the input within values:

>>> result.inverse_indices
Array([1, 2, 0, 1, 0], dtype=int32)
>>> jnp.all(x == result.values[result.inverse_indices])
Array(True, dtype=bool)

The counts attribute contains the counts of each unique value in the input:

>>> result.counts
Array([2, 2, 1], dtype=int32)

For examples of the size and fill_value arguments, see jax.numpy.unique.
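Since the result is a NamedTuple, it can also be unpacked like a plain tuple. The sketch below does so and checks the stated equivalence with jax.numpy.unique called with all four flags enabled:

```python
import jax.numpy as jnp

x = jnp.array([3, 4, 1, 3, 1])

# The NamedTuple result unpacks like a plain tuple.
values, indices, inverse_indices, counts = jnp.unique_all(x)

# The same four arrays from jnp.unique with every flag set.
ref = jnp.unique(x, return_index=True, return_inverse=True,
                 return_counts=True, equal_nan=True)
```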

scico.numpy.unique_counts(x, /, *, size=None, fill_value=None)

Return unique values from x, along with counts.

JAX implementation of numpy.unique_counts; this is equivalent to calling jax.numpy.unique with return_counts and equal_nan set to True.

Because the size of the output of unique_counts is data-dependent, the function is not typically compatible with jit and other JAX transformations. The JAX version adds the optional size argument which must be specified statically for jnp.unique_counts to be used in such contexts.

Parameters:
  • x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – N-dimensional array from which unique values will be extracted.

  • size (int | None) – if specified, return only the first size sorted unique elements. If there are fewer unique elements than size indicates, the return value will be padded with fill_value.

  • fill_value (Union[Array, ndarray, bool, number, bool, int, float, complex, None]) – when size is specified and there are fewer than the indicated number of elements, fill the remaining entries with fill_value. Defaults to the minimum unique value.

Return type:

_UniqueCountsResult

Returns:

A tuple (values, counts), with the following properties –

  • values:

    an array of shape (n_unique,) containing the unique values from x.

  • counts:

    An array of shape (n_unique,). Contains the number of occurrences of each unique value in x.

Examples

Here we compute the unique values in a 1D array:

>>> x = jnp.array([3, 4, 1, 3, 1])
>>> result = jnp.unique_counts(x)

The result is a NamedTuple with two named attributes. The values attribute contains the unique values from the array:

>>> result.values
Array([1, 3, 4], dtype=int32)

The counts attribute contains the counts of each unique value in the input:

>>> result.counts
Array([2, 2, 1], dtype=int32)

For examples of the size and fill_value arguments, see jax.numpy.unique.
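One small application is finding the most frequent value (the mode). In the sketch below, argmax returns the first maximal count, and because values is sorted, ties resolve to the smallest value; that tie-breaking rule is a choice of this sketch, not something unique_counts defines.

```python
import jax.numpy as jnp

x = jnp.array([3, 4, 1, 3, 1])
values, counts = jnp.unique_counts(x)

# argmax picks the first maximal count; values are sorted, so ties
# resolve to the smallest value.
mode = values[jnp.argmax(counts)]
print(mode)  # 1
```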

scico.numpy.unique_inverse(x, /, *, size=None, fill_value=None)

Return unique values from x, along with inverse indices.

JAX implementation of numpy.unique_inverse; this is equivalent to calling jax.numpy.unique with return_inverse and equal_nan set to True.

Because the size of the output of unique_inverse is data-dependent, the function is not typically compatible with jit and other JAX transformations. The JAX version adds the optional size argument which must be specified statically for jnp.unique_inverse to be used in such contexts.

Parameters:
  • x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – N-dimensional array from which unique values will be extracted.

  • size (int | None) – if specified, return only the first size sorted unique elements. If there are fewer unique elements than size indicates, the return value will be padded with fill_value.

  • fill_value (Union[Array, ndarray, bool, number, bool, int, float, complex, None]) – when size is specified and there are fewer than the indicated number of elements, fill the remaining entries with fill_value. Defaults to the minimum unique value.

Return type:

_UniqueInverseResult

Returns:

A tuple (values, inverse_indices), with the following properties –

  • values:

    an array of shape (n_unique,) containing the unique values from x.

  • inverse_indices:

    An array of shape x.shape. Contains the indices within values of each value in x. For 1D inputs, values[inverse_indices] is equivalent to x.

Examples

Here we compute the unique values in a 1D array:

>>> x = jnp.array([3, 4, 1, 3, 1])
>>> result = jnp.unique_inverse(x)

The result is a NamedTuple with two named attributes. The values attribute contains the unique values from the array:

>>> result.values
Array([1, 3, 4], dtype=int32)

The inverse_indices attribute contains the indices of the input within values:

>>> result.inverse_indices
Array([1, 2, 0, 1, 0], dtype=int32)
>>> jnp.all(x == result.values[result.inverse_indices])
Array(True, dtype=bool)

For examples of the size and fill_value arguments, see jax.numpy.unique.
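The inverse indices are integer codes in the range [0, len(values)), which makes unique_inverse a convenient way to label-encode categorical data. The sketch below (labels is an illustrative input) encodes the labels and then builds a one-hot matrix with jax.nn.one_hot:

```python
import jax
import jax.numpy as jnp

labels = jnp.array([10, 30, 10, 20, 30])
values, codes = jnp.unique_inverse(labels)   # codes lie in [0, len(values))

# One row per label, one column per distinct value.
one_hot = jax.nn.one_hot(codes, values.shape[0])
```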

scico.numpy.unique_values(x, /, *, size=None, fill_value=None)

Return unique values from x.

JAX implementation of numpy.unique_values; this is equivalent to calling jax.numpy.unique with equal_nan set to True.

Because the size of the output of unique_values is data-dependent, the function is not typically compatible with jit and other JAX transformations. The JAX version adds the optional size argument which must be specified statically for jnp.unique_values to be used in such contexts.

Parameters:
  • x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – N-dimensional array from which unique values will be extracted.

  • size (int | None) – if specified, return only the first size sorted unique elements. If there are fewer unique elements than size indicates, the return value will be padded with fill_value.

  • fill_value (Union[Array, ndarray, bool, number, bool, int, float, complex, None]) – when size is specified and there are fewer than the indicated number of elements, fill the remaining entries with fill_value. Defaults to the minimum unique value.

Return type:

Array

Returns:

An array values of shape (n_unique,) containing the unique values from x.

Examples

Here we compute the unique values in a 1D array:

>>> x = jnp.array([3, 4, 1, 3, 1])
>>> jnp.unique_values(x)
Array([1, 3, 4], dtype=int32)

For examples of the size and fill_value arguments, see jax.numpy.unique.
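The same size mechanism applies here. The sketch below jit-compiles unique_values with a static size and pads with -1; the sentinel is an assumption of this example (it presumes -1 never appears in the data), chosen so that padding differs from the default fill of the minimum unique value.

```python
import jax
import jax.numpy as jnp

x = jnp.array([3, 4, 1, 3, 1])

# size must be static under jit; -1 marks padding in this example.
jit_unique_values = jax.jit(jnp.unique_values, static_argnames=["size"])
padded = jit_unique_values(x, size=5, fill_value=-1)
```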

scico.numpy.unravel_index(indices, shape)

Convert flat indices into multi-dimensional indices.

JAX implementation of numpy.unravel_index. The JAX version differs in its treatment of out-of-bound indices: unlike NumPy, negative indices are supported, and out-of-bound indices are clipped to the nearest valid value.

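Parameters:
  • indices – integer array of flat indices.

  • shape – shape of the array into which indices are unraveled.
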
Return type:

tuple[Array, ...]

Returns:

A tuple of unraveled indices, one array per dimension of shape.

See also

jax.numpy.ravel_multi_index: Inverse of this function.

Examples

Start with a 1D array of values and an array of indices into it:

>>> x = jnp.array([2., 3., 4., 5., 6., 7.])
>>> indices = jnp.array([1, 3, 5])
>>> print(x[indices])
[3. 5. 7.]

Now if x is reshaped, unravel_index can be used to convert the flat indices into a tuple of indices that access the same entries:

>>> shape = (2, 3)
>>> x_2D = x.reshape(shape)
>>> indices_2D = jnp.unravel_index(indices, shape)
>>> indices_2D
(Array([0, 1, 1], dtype=int32), Array([1, 0, 2], dtype=int32))
>>> print(x_2D[indices_2D])
[3. 5. 7.]

The inverse function, ravel_multi_index, can be used to obtain the original indices:

>>> jnp.ravel_multi_index(indices_2D, shape)
Array([1, 3, 5], dtype=int32)
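A frequent practical use is locating the maximum of a multi-dimensional array: jnp.argmax without an axis returns a flat index, and unravel_index converts it to row and column coordinates. The array below is illustrative:

```python
import jax.numpy as jnp

x = jnp.array([[3.0, 9.0, 1.0],
               [4.0, 2.0, 8.0]])

flat = jnp.argmax(x)                         # index into x.ravel()
row, col = jnp.unravel_index(flat, x.shape)
print(row, col)  # 0 1
```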
scico.numpy.unwrap(p, discont=None, axis=-1, period=6.283185307179586)

Unwrap a periodic signal.

JAX implementation of numpy.unwrap.

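Parameters:
  • p – input array.

  • discont – maximum discontinuity between successive values; larger jumps are corrected by multiples of period. Defaults to period / 2.

  • axis – axis along which to unwrap (default: -1).

  • period – period of the signal (default: 2π).
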
Return type:

Array

Returns:

An unwrapped copy of p.

Notes

This implementation follows that of numpy.unwrap, and is not well-suited for integer-period unwrapping of narrow-width integers (e.g. int8, int16) or unsigned integers.

Examples

Consider a situation in which you are making measurements of the position of a rotating disk via the x and y locations of some point on that disk. The underlying variable is an always-increasing angle which we’ll generate this way, using degrees for ease of representation:

>>> rng = np.random.default_rng(0)
>>> theta = rng.integers(0, 90, size=(20,)).cumsum()
>>> theta
array([ 76, 133, 179, 203, 230, 233, 239, 240, 255, 328, 386, 468, 513,
       567, 654, 719, 775, 823, 873, 957])

Our observations of this angle are the x and y coordinates, given by the sine and cosine of this underlying angle:

>>> x, y = jnp.sin(jnp.deg2rad(theta)), jnp.cos(jnp.deg2rad(theta))

Now, say that given these x and y coordinates, we wish to recover the original angle theta. We might do this via the atan2 function:

>>> theta_out = jnp.rad2deg(jnp.atan2(x, y)).round()
>>> theta_out
Array([  76.,  133.,  179., -157., -130., -127., -121., -120., -105.,
        -32.,   26.,  108.,  153., -153.,  -66.,   -1.,   55.,  103.,
        153., -123.], dtype=float32)

The first few values match the input angle theta above, but after this the values are wrapped because the sin and cos observations obscure the phase information. The purpose of the unwrap function is to recover the original signal from this wrapped view of it:

>>> jnp.unwrap(theta_out, period=360)
Array([ 76., 133., 179., 203., 230., 233., 239., 240., 255., 328., 386.,
       468., 513., 567., 654., 719., 775., 823., 873., 957.],      dtype=float32)

It does this by assuming that the true underlying sequence does not differ by more than discont (which defaults to period / 2) within a single step, and when it encounters a larger discontinuity it adds factors of the period to the data. For periodic signals that satisfy this assumption, unwrap can recover the original phased signal.
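The procedure described above can be sketched for 1D input with the default discont = period / 2. The function unwrap_sketch below is a hypothetical, illustrative reimplementation following that description, not the library's code; it is checked against jnp.unwrap on a small hand-made sequence of wrapped angles in degrees.

```python
import jax.numpy as jnp

def unwrap_sketch(p, period=2 * jnp.pi):
    """Hypothetical 1D unwrap using the default discont = period / 2."""
    half = period / 2
    dd = jnp.diff(p)
    # Map each step into [-half, half).
    ddmod = jnp.mod(dd + half, period) - half
    # A positive step landing exactly on the boundary stays positive.
    ddmod = jnp.where((ddmod == -half) & (dd > 0), half, ddmod)
    # Only steps whose magnitude reaches the threshold are corrected.
    correction = jnp.where(jnp.abs(dd) < half, 0.0, ddmod - dd)
    # Accumulate the multiples of the period added so far.
    return jnp.concatenate([p[:1], p[1:] + jnp.cumsum(correction)])

p = jnp.array([170.0, -170.0, -150.0, 170.0, 10.0])  # wrapped angles in degrees
print(unwrap_sketch(p, period=360.0))  # 170, 190, 210, 170, 10
```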

scico.numpy.var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False, *, where=None, mean=None, correction=None)

Compute the variance along a given axis.

JAX implementation of numpy.var.

Parameters:
  • a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array.

  • axis (Union[int, Sequence[int], None]) – optional, int or sequence of ints, default=None. Axis along which the variance is computed. If None, variance is computed along all the axes.

  • dtype (Union[str, type[Any], dtype, SupportsDType, None]) – The type of the output array. Default=None.

  • ddof (int) – int, default=0. Degrees of freedom. The divisor in the variance computation is N - ddof, where N is the number of elements along the given axis.

  • keepdims (bool) – bool, default=False. If True, reduced axes are left in the result with size 1.

  • where (Union[Array, ndarray, bool, number, bool, int, float, complex, None]) – optional, boolean array, default=None. The elements to be used in the variance. Must be broadcast-compatible with the input.

  • mean (Union[Array, ndarray, bool, number, bool, int, float, complex, None]) – optional, mean of the input array, computed along the given axis. If provided, it will be used to compute the variance instead of computing it from the input array. If specified, mean must be broadcast-compatible with the input array. In the general case, this can be achieved by computing the mean with keepdims=True and axis matching this function’s axis argument.

  • correction (int | float | None) – int or float, default=None. Alternative name for ddof; ddof and correction cannot both be provided.

  • out (None) – Unused by JAX.

Return type:

Array

Returns:

An array of the variance along the given axis.

See also

  • jax.numpy.mean: Compute the mean of array elements over a given axis.

  • jax.numpy.std: Compute the standard deviation of array elements over given axis.

  • jax.numpy.nanvar: Compute the variance along a given axis, ignoring NaN values.

  • jax.numpy.nanstd: Compute the standard deviation along a given axis, ignoring NaN values.

Examples

By default, jnp.var computes the variance along all axes.

>>> x = jnp.array([[1, 3, 4, 2],
...                [5, 2, 6, 3],
...                [8, 4, 2, 9]])
>>> with jnp.printoptions(precision=2, suppress=True):
...   jnp.var(x)
Array(5.74, dtype=float32)

If axis=1, variance is computed along axis 1.

>>> jnp.var(x, axis=1)
Array([1.25  , 2.5   , 8.1875], dtype=float32)

To preserve the dimensions of input, you can set keepdims=True.

>>> jnp.var(x, axis=1, keepdims=True)
Array([[1.25  ],
       [2.5   ],
       [8.1875]], dtype=float32)

If ddof=1:

>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jnp.var(x, axis=1, keepdims=True, ddof=1))
[[ 1.67]
 [ 3.33]
 [10.92]]

To include specific elements of the array to compute variance, you can use where.

>>> where = jnp.array([[1, 0, 1, 0],
...                    [0, 1, 1, 0],
...                    [1, 1, 1, 0]], dtype=bool)
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jnp.var(x, axis=1, keepdims=True, where=where))
[[2.25]
 [4.  ]
 [6.22]]
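The role of ddof can be checked against the definition in the parameter list: the sum of squared deviations from the mean, divided by N - ddof. The sketch below does this for the example array with ddof=1, using keepdims=True so the mean broadcasts against x, and also confirms that the square root matches jnp.std with the same ddof:

```python
import jax.numpy as jnp

x = jnp.array([[1.0, 3.0, 4.0, 2.0],
               [5.0, 2.0, 6.0, 3.0],
               [8.0, 4.0, 2.0, 9.0]])
ddof = 1

mu = jnp.mean(x, axis=1, keepdims=True)   # keepdims so mu broadcasts against x
n = x.shape[1]                            # N: elements along the reduced axis
manual = jnp.sum((x - mu) ** 2, axis=1) / (n - ddof)
builtin = jnp.var(x, axis=1, ddof=ddof)
```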
scico.numpy.vdot(a, b, *, precision=None, preferred_element_type=None)

Perform a conjugate multiplication of two 1D vectors.

JAX implementation of numpy.vdot.

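Parameters:
  • a – first input array, conjugated if complex.

  • b – second input array.

  • precision – either None (default), which means the default precision for the backend, a Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST) or a tuple of two such values indicating precision of a and b.

  • preferred_element_type – either None (default), which means the default accumulation type for the input types, or a datatype, indicating to accumulate results to and return a result with that datatype.
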
Return type:

Array

Returns:

Scalar array (shape ()) containing the conjugate vector product of the inputs.

Examples

>>> x = jnp.array([1j, 2j, 3j])
>>> y = jnp.array([1., 2., 3.])
>>> jnp.vdot(x, y)
Array(0.-14.j, dtype=complex64)

Note the difference between this and dot, which does not conjugate the first input when complex:

>>> jnp.dot(x, y)
Array(0.+14.j, dtype=complex64)
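Written out element-wise, vdot conjugates its first argument, multiplies, and sums. A useful consequence is that vdot(x, x) is the real squared Euclidean norm. The sketch below checks both on illustrative complex vectors:

```python
import jax.numpy as jnp

x = jnp.array([1 + 2j, 3 - 1j])
y = jnp.array([2.0, 1j])

# vdot conjugates its first argument before multiplying and summing.
manual = jnp.sum(jnp.conj(x) * y)

# Hence vdot(x, x) = |1+2j|^2 + |3-1j|^2 = 5 + 10.
sq_norm = jnp.vdot(x, x)
```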
scico.numpy.vecdot(x1, x2, /, *, axis=-1, precision=None, preferred_element_type=None)

Perform a conjugate multiplication of two batched vectors.

JAX implementation of numpy.vecdot.

Parameters:
  • x1 – left-hand side array.

  • x2 – right-hand side array. Size of x2[axis] must match size of x1[axis], and remaining dimensions must be broadcast-compatible.

  • axis (int) – axis along which to compute the dot product (default: -1).

  • precision (Union[None, str, Precision, tuple[str, str], tuple[Precision, Precision], DotAlgorithm, DotAlgorithmPreset]) – either None (default), which means the default precision for the backend, a Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST) or a tuple of two such values indicating precision of x1 and x2.

  • preferred_element_type (Union[str, type[Any], dtype, SupportsDType, None]) – either None (default), which means the default accumulation type for the input types, or a datatype, indicating to accumulate results to and return a result with that datatype.

Return type:

Array

Returns:

An array containing the conjugate dot product of x1 and x2 along axis. The non-contracted dimensions are broadcast together.

Examples

Vector conjugate-dot product of two 1D arrays:

>>> a = jnp.array([1j, 2j, 3j])
>>> b = jnp.array([4., 5., 6.])
>>> jnp.linalg.vecdot(a, b)
Array(0.-32.j, dtype=complex64)

Batched vector dot product of two 2D arrays:

>>> a = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> b = jnp.array([[2, 3, 4]])
>>> jnp.linalg.vecdot(a, b, axis=-1)
Array([20, 47], dtype=int32)
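The batched form is equivalent to conjugating the first argument, multiplying with broadcasting, and reducing along axis. The sketch below checks this on an illustrative mix of a complex row and a real row, with a 1D right-hand side broadcast against each row:

```python
import jax.numpy as jnp

a = jnp.array([[1j, 2j],
               [3.0, 4.0]])
b = jnp.array([1.0, 1.0])   # broadcast against each row of a

out = jnp.linalg.vecdot(a, b, axis=-1)
# Equivalent element-wise form: conjugate a, multiply, reduce along axis.
manual = jnp.sum(jnp.conj(a) * b, axis=-1)
```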
scico.numpy.vectorize(pyfunc, *, excluded=frozenset({}), signature=None)

Define a vectorized function with broadcasting.

vectorize is a convenience wrapper for defining vectorized functions with broadcasting, in the style of NumPy’s generalized universal functions. It allows for defining functions that are automatically repeated across any leading dimensions, without the implementation of the function needing to be concerned about how to handle higher dimensional inputs.

jax.numpy.vectorize has the same interface as numpy.vectorize, but it is syntactic sugar for an auto-batching transformation (vmap) rather than a Python loop. This should be considerably more efficient, but the implementation must be written in terms of functions that act on JAX arrays.

Parameters:
  • pyfunc – function to vectorize.

  • excluded – optional set of integers representing positional arguments for which the function will not be vectorized. These will be passed directly to pyfunc unmodified.

  • signature – optional generalized universal function signature, e.g., (m,n),(n)->(m) for vectorized matrix-vector multiplication. If provided, pyfunc will be called with (and expected to return) arrays with shapes given by the size of corresponding core dimensions. By default, pyfunc is assumed to take scalar arrays as input, and if signature is None, pyfunc can produce outputs of any shape.

Returns:

Vectorized version of the given function.

Examples

Here are a few examples of how one could write vectorized linear algebra routines using vectorize:

>>> from functools import partial
>>> @partial(jnp.vectorize, signature='(k),(k)->(k)')
... def cross_product(a, b):
...   assert a.shape == b.shape and a.ndim == b.ndim == 1
...   return jnp.array([a[1] * b[2] - a[2] * b[1],
...                     a[2] * b[0] - a[0] * b[2],
...                     a[0] * b[1] - a[1] * b[0]])
>>> @partial(jnp.vectorize, signature='(n,m),(m)->(n)')
... def matrix_vector_product(matrix, vector):
...   assert matrix.ndim == 2 and matrix.shape[1:] == vector.shape
...   return matrix @ vector

These functions are only written to handle 1D or 2D arrays (the assert statements will never be violated), but with vectorize they support arbitrary dimensional inputs with NumPy style broadcasting, e.g.,

>>> cross_product(jnp.ones(3), jnp.ones(3)).shape
(3,)
>>> cross_product(jnp.ones((2, 3)), jnp.ones(3)).shape
(2, 3)
>>> cross_product(jnp.ones((1, 2, 3)), jnp.ones((2, 1, 3))).shape
(2, 2, 3)
>>> matrix_vector_product(jnp.ones(3), jnp.ones(3))
Traceback (most recent call last):
ValueError: input with shape (3,) does not have enough dimensions for all
core dimensions ('n', 'm') on vectorized function with excluded=frozenset()
and signature='(n,m),(m)->(n)'
>>> matrix_vector_product(jnp.ones((2, 3)), jnp.ones(3)).shape
(2,)
>>> matrix_vector_product(jnp.ones((2, 3)), jnp.ones((4, 3))).shape
(4, 2)

Note that this has different semantics than jnp.matmul:

>>> jnp.matmul(jnp.ones((2, 3)), jnp.ones((4, 3)))  
Traceback (most recent call last):
TypeError: dot_general requires contracting dimensions to have the same shape, got [3] and [4].
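
The examples above do not use the excluded parameter. The sketch below assumes a hypothetical Horner-rule helper, poly_eval, whose second positional argument is a Python tuple of coefficients. Listing index 1 in excluded passes that tuple to poly_eval unmodified, so only x is vectorized.

```python
import jax.numpy as jnp

def poly_eval(x, coeffs):
    # Horner's rule for a scalar x; coeffs ordered from highest degree down
    result = 0.0
    for c in coeffs:
        result = result * x + c
    return result

# Argument 1 (coeffs) is excluded: it reaches poly_eval as-is instead of
# being converted to an array and broadcast against x
vpoly = jnp.vectorize(poly_eval, excluded={1})

y = vpoly(jnp.array([0.0, 1.0, 2.0]), (1.0, 2.0, 3.0))
print(y)  # x**2 + 2*x + 3 evaluated at x = 0, 1, 2 -> values 3, 6, 11
```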
scico.numpy.vsplit(ary, indices_or_sections)

Split an array into sub-arrays vertically.

JAX implementation of numpy.vsplit.

Refer to the documentation of jax.numpy.split for details; vsplit is equivalent to split with axis=0.

Return type:

list[Array]

Examples

1D array:

>>> x = jnp.array([1, 2, 3, 4, 5, 6])
>>> x1, x2 = jnp.vsplit(x, 2)
>>> print(x1, x2)
[1 2 3] [4 5 6]

2D array:

>>> x = jnp.array([[1, 2, 3, 4],
...                [5, 6, 7, 8]])
>>> x1, x2 = jnp.vsplit(x, 2)
>>> print(x1, x2)
[[1 2 3 4]] [[5 6 7 8]]
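
Like split, vsplit also accepts a sequence of split indices. The sketch below splits an illustrative (4, 3) array before rows 1 and 3, then checks the pieces against jnp.split with axis=0. The indices are passed as a tuple here.

```python
import jax.numpy as jnp

x = jnp.arange(12).reshape(4, 3)

# Split before rows 1 and 3 -> row blocks [0:1], [1:3], [3:4]
top, middle, bottom = jnp.vsplit(x, (1, 3))
print(top.shape, middle.shape, bottom.shape)  # (1, 3) (2, 3) (1, 3)

# Equivalent call through split along axis 0
ref = jnp.split(x, (1, 3), axis=0)
```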

scico.numpy.vstack(tup, dtype=None)

Vertically stack arrays.

JAX implementation of numpy.vstack.

For arrays of two or more dimensions, this is equivalent to jax.numpy.concatenate with axis=0.

Parameters:
  • tup (ndarray | Array | Sequence[Union[Array, ndarray, bool, number, bool, int, float, complex]]) – a sequence of arrays to stack; each must have the same shape along all but the first axis. If a single array is given it will be treated equivalently to tup = unstack(tup), but the implementation will avoid explicit unstacking.

  • dtype (Union[str, type[Any], dtype, SupportsDType, None]) – optional dtype of the resulting array. If not specified, the dtype will be determined via type promotion rules described in Type promotion semantics.

Return type:

Array

Returns:

the stacked result.

Examples

Scalar values:

>>> jnp.vstack([1, 2, 3])
Array([[1],
       [2],
       [3]], dtype=int32, weak_type=True)

1D arrays:

>>> x = jnp.arange(4)
>>> y = jnp.ones(4)
>>> jnp.vstack([x, y])
Array([[0., 1., 2., 3.],
       [1., 1., 1., 1.]], dtype=float32)

2D arrays:

>>> x = x.reshape(1, 4)
>>> y = y.reshape(1, 4)
>>> jnp.vstack([x, y])
Array([[0., 1., 2., 3.],
       [1., 1., 1., 1.]], dtype=float32)
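
The sketch below checks two properties stated above. For 2D inputs, vstack matches jnp.concatenate with axis=0. A single array argument is handled as if it had been unstacked along its first axis. The arrays used are illustrative.

```python
import jax.numpy as jnp

x = jnp.arange(6).reshape(2, 3)
y = jnp.arange(6, 12).reshape(2, 3)

# For 2D inputs, vstack is concatenation along axis 0
stacked = jnp.vstack([x, y])
print(stacked.shape)  # (4, 3)

# A single 3D array is treated like the sequence of its (3, 4) slices
z = jnp.arange(24).reshape(2, 3, 4)
from_array = jnp.vstack(z)
from_slices = jnp.vstack([z[0], z[1]])
print(from_array.shape)  # (6, 4)
```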
scico.numpy.where(condition, x=None, y=None, /, *, size=None, fill_value=None)

Select elements from two arrays based on a condition.

JAX implementation of numpy.where.

Note

When only condition is provided, jnp.where(condition) is equivalent to jnp.nonzero(condition); for that case, refer to the documentation of jax.numpy.nonzero. The description below focuses on the case where x and y are specified.

The three-term version of jnp.where lowers to jax.lax.select.

Parameters:
  • condition – boolean array. Must be broadcast-compatible with x and y when they are specified.

  • x – arraylike. Should be broadcast-compatible with condition and y, and typecast-compatible with y.

  • y – arraylike. Should be broadcast-compatible with condition and x, and typecast-compatible with x.

  • size – integer, only referenced when x and y are None. For details, see jax.numpy.nonzero.

  • fill_value – only referenced when x and y are None. For details, see jax.numpy.nonzero.

Returns:

An array of dtype jnp.result_type(x, y) with values drawn from x where condition is True, and from y where condition is False. If x and y are None, the function behaves differently; see jax.numpy.nonzero for a description of the return type.

Notes

Special care is needed when the x or y input to jax.numpy.where could have a value of NaN. Specifically, when a gradient is taken with jax.grad (reverse-mode differentiation), a NaN in either x or y will propagate into the gradient, regardless of the value of condition. More information on this behavior and workarounds is available in the JAX FAQ.
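
A common workaround, described in the JAX FAQ, applies where twice. The first where replaces problematic inputs with a safe value, and the second selects the final result. In this sketch, log is an illustrative function that is undefined at 0. The safe value 1.0 is an assumption of the example.

```python
import jax
import jax.numpy as jnp

def naive_log(x):
    # log(x) for positive x, 0 otherwise; log is still evaluated at x <= 0
    return jnp.where(x > 0, jnp.log(x), 0.0)

def safe_log(x):
    # First where: substitute a harmless input so log never sees x <= 0
    safe_x = jnp.where(x > 0, x, 1.0)
    # Second where: select the final value exactly as before
    return jnp.where(x > 0, jnp.log(safe_x), 0.0)

print(jax.grad(naive_log)(0.0))  # nan: the unselected branch poisons the gradient
print(jax.grad(safe_log)(0.0))   # 0.0
print(jax.grad(safe_log)(2.0))   # 0.5, i.e. d/dx log(x) at x = 2
```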

Examples

When x and y are not provided, where behaves equivalently to jax.numpy.nonzero:

>>> x = jnp.arange(10)
>>> jnp.where(x > 4)
(Array([5, 6, 7, 8, 9], dtype=int32),)
>>> jnp.nonzero(x > 4)
(Array([5, 6, 7, 8, 9], dtype=int32),)
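
Without x and y, the number of returned indices depends on the data. That form therefore cannot be traced under jax.jit unless size is given. The sketch below shows the size and fill_value parameters; the function name first_three_above is hypothetical. The result is truncated or padded to exactly size entries.

```python
import jax
import jax.numpy as jnp

@jax.jit
def first_three_above(x, threshold):
    # size fixes the output shape at trace time; fill_value pads missing entries
    return jnp.where(x > threshold, size=3, fill_value=-1)

x = jnp.arange(10)
(idx_many,) = first_three_above(x, 4)  # 5 matches -> truncated to the first 3
(idx_few,) = first_three_above(x, 8)   # 1 match -> padded with -1
print(idx_many)  # values [5, 6, 7]
print(idx_few)   # values [9, -1, -1]
```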

When x and y are provided, where selects between them based on the specified condition:

>>> jnp.where(x > 4, x, 0)
Array([0, 0, 0, 0, 0, 5, 6, 7, 8, 9], dtype=int32)
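
The three-argument form broadcasts condition, x and y together. As stated under Returns, the output dtype follows jnp.result_type(x, y). The sketch below demonstrates both with illustrative arrays.

```python
import jax.numpy as jnp

# condition of shape (2, 1) broadcasts against x and y of shape (3,)
rows = jnp.array([[True], [False]])
a = jnp.arange(3)
b = -jnp.arange(3)
picked = jnp.where(rows, a, b)
print(picked)  # values [[0, 1, 2], [0, -1, -2]]

# Mixing an integer array with a Python float promotes to a floating dtype
x = jnp.arange(-2, 3)
clipped = jnp.where(x < 0, 0.5, x)
print(clipped.dtype == jnp.result_type(0.5, x))  # True
```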
scico.numpy.zeros(shape, dtype=None, *, device=None, out_sharding=None)

Create an array full of zeros.

JAX implementation of numpy.zeros.

Parameters:
  • shape (Any) – int or sequence of ints specifying the shape of the created array.

  • dtype (Union[str, type[Any], dtype, SupportsDType, None]) – optional dtype for the created array; defaults to float32 or float64 depending on the X64 configuration (see Default dtypes and the X64 flag).

  • device (Device | Sharding | None) – (optional) Device or Sharding to which the created array will be committed. This argument exists for compatibility with the Python Array API standard.

  • out_sharding (NamedSharding | P | None) – (optional) PartitionSpec or NamedSharding representing the sharding of the created array (see explicit sharding for more details). This argument exists for consistency with other array creation routines across JAX. Specifying both out_sharding and device will result in an error.

Return type:

Array

Returns:

Array of the specified shape and dtype, with the given device/sharding if specified.

Examples

>>> jnp.zeros(4)
Array([0., 0., 0., 0.], dtype=float32)
>>> jnp.zeros((2, 3), dtype=bool)
Array([[False, False, False],
       [False, False, False]], dtype=bool)
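
The sketch below checks the default dtype and the device argument described above. It assumes a standard JAX installation with at least one visible device. It uses jax.dtypes.canonicalize_dtype to obtain the float dtype implied by the current X64 setting.

```python
import jax
import jax.numpy as jnp

# Without dtype, zeros uses the default float type: float32 unless X64 is enabled
z = jnp.zeros((2, 3))
default_float = jax.dtypes.canonicalize_dtype(jnp.float64)
print(z.dtype == default_float)  # True

# Commit the result to an explicit device (here simply the first available one)
dev = jax.devices()[0]
z_dev = jnp.zeros(4, dtype=jnp.int32, device=dev)
print(z_dev.devices() == {dev})  # True
```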
scico.numpy.zeros_like(a, dtype=None, shape=None, *, device=None, out_sharding=None)

Create an array full of zeros with the same shape and dtype as an array.

JAX implementation of numpy.zeros_like.

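Parameters:
  • a – array-like object whose shape and dtype are used for the created array.

  • dtype – optional dtype for the created array; if not specified, the dtype of a is used.

  • shape – optional shape for the created array; if not specified, the shape of a is used.

  • device – (optional) Device or Sharding to which the created array will be committed.

  • out_sharding – (optional) PartitionSpec or NamedSharding representing the sharding of the created array.
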
Return type:

Array

Returns:

Array of the specified shape and dtype, on the specified device if specified.

Examples

>>> x = jnp.arange(4)
>>> jnp.zeros_like(x)
Array([0, 0, 0, 0], dtype=int32)
>>> jnp.zeros_like(x, dtype=bool)
Array([False, False, False, False], dtype=bool)
>>> jnp.zeros_like(x, shape=(2, 3))
Array([[0, 0, 0],
       [0, 0, 0]], dtype=int32)
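
The sketch below contrasts zeros_like with zeros. zeros_like inherits the dtype of its argument unless dtype is overridden, whereas zeros falls back to the default float dtype. The input array is illustrative.

```python
import jax.numpy as jnp

x = jnp.arange(6).reshape(2, 3)                  # integer array

like = jnp.zeros_like(x)                         # same shape and integer dtype as x
plain = jnp.zeros(x.shape)                       # same shape, default float dtype
as_float = jnp.zeros_like(x, dtype=jnp.float32)  # dtype overridden

print(like.dtype == x.dtype, jnp.issubdtype(plain.dtype, jnp.floating))  # True True
```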