scico.numpy
BlockArray and compatible functions.
This module consists of BlockArray and functions that support
both instances of this class and jax arrays. It includes all the
functions from jax.numpy and numpy.testing, many of which have
been extended to automatically map over the blocks of a block array,
as described in NumPy and SciPy Functions. Additional functions
unique to SCICO are provided in util.
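The block-mapping idea can be pictured with plain JAX containers. The sketch below is not SCICO's implementation: it stands in for a BlockArray with a tuple of jax arrays, and the helpers `map_over_blocks` and `block_sum` are hypothetical names introduced only for illustration. It shows an element-wise function applied block by block and a reduction that combines the per-block results.

```python
import jax
import jax.numpy as jnp

# Stand-in for a BlockArray: a tuple of arrays with different shapes.
blocks = (
    jnp.array([[1, 3, 7], [2, 2, 1]]),  # shape (2, 3)
    jnp.array([2, 4, 8]),               # shape (3,)
)


def map_over_blocks(fn, blocks):
    """Apply an element-wise function to every block (hypothetical helper)."""
    return jax.tree_util.tree_map(fn, blocks)


def block_sum(blocks):
    """Sum over all blocks by adding the per-block sums (hypothetical helper)."""
    return sum(jnp.sum(b) for b in blocks)


negated = map_over_blocks(jnp.negative, blocks)
shapes = tuple(b.shape for b in negated)  # block shapes are preserved
total = block_sum(blocks)                 # 16 + 14 = 30
```

The total of 30 agrees with the snp.sum result in the BlockArray class example on this page, which uses the same block values.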
Modules
Discrete Fourier Transform functions. |
|
Linear algebra functions. |
|
Test support functions. |
|
Utility functions for working with jax arrays and BlockArrays. |
Functions
|
Alias of |
|
Calculate the absolute value element-wise. |
|
Add two arrays element-wise. |
|
Test whether all array elements along a given axis evaluate to True. |
|
Check if two arrays are element-wise approximately equal within a tolerance. |
|
Alias of |
|
Alias of |
|
Return the angle of a complex valued number or array. |
|
Test whether any of the array elements along a given axis evaluate to True. |
|
Return a new array with values appended to the end of the original array. |
|
Apply a function to 1D array slices along an axis. |
|
Apply a function repeatedly over specified axes. |
|
Create an array of evenly-spaced values. |
|
Compute element-wise inverse of trigonometric cosine of input. |
|
Calculate element-wise inverse of hyperbolic cosine of input. |
|
Compute element-wise inverse of trigonometric sine of input. |
|
Calculate element-wise inverse of hyperbolic sine of input. |
|
Compute element-wise inverse of trigonometric tangent of input. |
|
Compute the arctangent of x1/x2, choosing the correct quadrant. |
|
Calculate element-wise inverse of hyperbolic tangent of input. |
|
Return the index of the maximum value of an array. |
|
Return the index of the minimum value of an array. |
|
Return indices that sort an array. |
|
Find the indices of nonzero array elements |
|
Alias of |
|
Convert an object to a JAX array. |
|
Check if two arrays are element-wise equal. |
|
Check if two arrays are element-wise equal. |
|
Split an array into sub-arrays. |
|
Convert an object to a JAX array. |
|
Convert an array to a specified dtype. |
|
Convert inputs to arrays with at least 1 dimension. |
|
Convert inputs to arrays with at least 2 dimensions. |
|
Convert inputs to arrays with at least 3 dimensions. |
|
Compute the weighted average. |
|
Return a Bartlett window of size M. |
|
Count the number of occurrences of each value in an integer array. |
|
Return a Blackman window of size M. |
|
Create an array from a list of blocks. |
|
Construct a |
|
Broadcast arrays to a common shape. |
|
Broadcast input shapes to a common output shape. |
|
Broadcast an array to a specified shape. |
|
Calculates element-wise cube root of the input array. |
|
Round input to the nearest integer upwards. |
|
Construct an array by stacking slices of choice arrays. |
|
Clip array values to a specified range. |
|
Stack arrays column-wise. |
|
Compress an array along a given axis using a boolean condition. |
|
Join arrays along an existing axis. |
|
Join arrays along an existing axis. |
|
Alias of |
|
Return element-wise complex-conjugate of the input. |
|
Convolution of two one dimensional arrays. |
|
Return a copy of the array. |
|
Copies the sign of each element in |
|
Compute a trigonometric cosine of each element of input. |
|
Calculate element-wise hyperbolic cosine of input. |
|
Return the number of nonzero elements along a given axis. |
|
Compute the (batched) cross product of two arrays. |
|
Cumulative product of elements along an axis. |
|
Cumulative sum of elements along an axis. |
|
Cumulative product along the axis of an array. |
|
Cumulative sum along the axis of an array. |
|
Convert angles from degrees to radians. |
|
Alias of |
|
Delete entry or entries from an array. |
|
Returns the specified diagonal or constructs a diagonal array. |
|
Return indices for accessing the main diagonal of a multidimensional array. |
|
Return indices for accessing the main diagonal of a given array. |
|
Return a 2-D array with the flattened input array laid out on the diagonal. |
|
Calculate n-th order difference between array elements along a given axis. |
|
Alias of |
|
Calculates the integer quotient and remainder of x1 by x2 element-wise |
|
Compute the dot product of two arrays. |
|
Split an array into sub-arrays depth-wise. |
|
Stack arrays depth-wise. |
|
Compute the differences of the elements of the flattened array. |
|
Einstein summation |
|
Evaluates the optimal contraction path without evaluating the einsum. |
|
Create an empty array. |
|
Create an empty array with the same shape and dtype as an array. |
|
Returns element-wise truth value of |
|
Calculate element-wise exponential of the input. |
|
Calculate element-wise base-2 exponential of input. |
|
Insert dimensions of length 1 into array |
|
Calculate |
|
Return the elements of an array that satisfy a condition. |
|
Create a square or rectangular identity matrix |
|
Compute the element-wise absolute values of the real-valued input. |
|
Return a copy of the array with the diagonal overwritten. |
|
Return indices of nonzero elements in a flattened array |
|
Reverse the order of elements of an array along the given axis. |
|
Reverse the order of elements of an array along axis 1. |
|
Reverse the order of elements of an array along axis 0. |
|
Calculate element-wise base |
|
Round input to the nearest integer downwards. |
|
Calculates the floor division of x1 by x2 element-wise |
|
Return element-wise maximum of the input arrays. |
|
Return element-wise minimum of the input arrays. |
|
Calculate element-wise floating-point modulo operation. |
|
Split floating point values into mantissa and twos exponent. |
|
Construct a JAX array via DLPack. |
|
Convert a buffer into a 1-D JAX array. |
|
Unimplemented JAX wrapper for jnp.fromfile. |
|
Create an array from a function applied over indices. |
|
Unimplemented JAX wrapper for jnp.fromiter. |
|
Create a JAX ufunc from an arbitrary JAX-compatible scalar function. |
|
Convert a string of text into 1-D JAX array. |
|
Create an array full of a specified value. |
|
Create an array full of a specified value with the same shape and dtype as an array. |
|
Compute the greatest common divisor of two arrays. |
|
Generate geometrically-spaced values. |
|
Alias of |
|
Compute the numerical gradient of a sampled function. |
|
Return element-wise truth value of |
|
Return element-wise truth value of |
|
Return a Hamming window of size M. |
|
Return a Hanning window of size M. |
|
Compute the heaviside step function. |
|
Compute a 1-dimensional histogram. |
|
Compute a 2-dimensional histogram. |
|
Compute the bin edges for a histogram. |
|
Compute an N-dimensional histogram. |
|
Split an array into sub-arrays horizontally. |
|
Horizontally stack arrays. |
|
Return element-wise hypotenuse for the given legs of a right angle triangle. |
|
Calculate modified Bessel function of first kind, zeroth order. |
|
Create a square identity matrix |
|
Return element-wise imaginary part of the complex argument. |
|
Generate arrays of grid indices. |
|
Compute the inner product of two arrays. |
|
Insert entries into an array at specified indices. |
|
One-dimensional linear interpolation. |
|
Compute the set intersection of two 1D arrays. |
|
Check if the elements of two arrays are approximately equal within a tolerance. |
|
Return boolean array showing where the input is complex. |
|
Check if the input is a complex number or an array containing complex elements. |
|
Returns a boolean indicating whether a provided dtype is of a specified kind. |
|
Return a boolean array indicating whether each element of input is finite. |
|
Determine whether elements in |
|
Return a boolean array indicating whether each element of input is infinite. |
|
Returns a boolean array indicating whether each element of input is |
|
Return boolean array indicating whether each element of input is negative infinite. |
|
Return boolean array indicating whether each element of input is positive infinite. |
|
Return boolean array showing where the input is real. |
|
Check if the input is not a complex number or an array containing complex elements. |
|
Return True if the input is a scalar. |
|
Return True if arg1 is equal or lower than arg2 in the type hierarchy. |
|
Check whether or not an object can be iterated over. |
|
Return a multi-dimensional grid (open mesh) from N one-dimensional sequences. |
|
Return a Kaiser window of size M. |
|
Compute the Kronecker product of two input arrays. |
|
Compute the least common multiple of two arrays. |
|
Compute x1 * 2 ** x2 |
|
Return element-wise truth value of |
|
Return element-wise truth value of |
|
Sort a sequence of keys in lexicographic order. |
|
Return evenly-spaced numbers within an interval. |
|
Load JAX arrays from npy files. |
|
Calculate element-wise natural logarithm of the input. |
|
Calculates the base-10 logarithm of x element-wise |
|
Calculates element-wise logarithm of one plus input, |
|
Calculates the base-2 logarithm of |
|
Compute |
|
Logarithm of the sum of exponentials of inputs in base-2 avoiding overflow. |
|
Compute the logical AND operation elementwise. |
|
Compute NOT bool(x) element-wise. |
|
Compute the logical OR operation elementwise. |
|
Compute the logical XOR operation elementwise. |
|
Generate logarithmically-spaced values. |
|
Return indices of a mask of an (n, n) array. |
|
Perform a matrix multiplication. |
|
Transpose the last two dimensions of an array. |
|
Return the maximum of the array elements along a given axis. |
|
Return element-wise maximum of the input arrays. |
|
Return the mean of array elements along a given axis. |
|
Construct N-dimensional grid arrays from N 1-dimensional vectors. |
|
Return the minimum of array elements along a given axis. |
|
Return element-wise minimum of the input arrays. |
|
Alias of |
|
Return element-wise fractional and integral parts of the input array. |
|
Move an array axis to a new position |
|
Multiply two arrays element-wise. |
|
Replace NaN and infinite entries in an array. |
|
Return the index of the maximum value of an array, ignoring NaNs. |
|
Return the index of the minimum value of an array, ignoring NaNs. |
|
Cumulative product of elements along an axis, ignoring NaN values. |
|
Cumulative sum of elements along an axis, ignoring NaN values. |
|
Return the maximum of the array elements along a given axis, ignoring NaNs. |
|
Return the minimum of the array elements along a given axis, ignoring NaNs. |
|
Return the product of the array elements along a given axis, ignoring NaNs. |
|
Return the sum of the array elements along a given axis, ignoring NaNs. |
|
Return the number of dimensions of an array. |
|
Return element-wise negative values of the input. |
|
Return element-wise next floating point value after |
|
Return indices of nonzero elements of an array. |
|
Returns element-wise truth value of |
|
Create an array full of ones. |
|
Create an array of ones with the same shape and dtype as an array. |
|
Compute the outer product of two arrays. |
|
Add padding to an array. |
|
Returns a partially-sorted copy of an array. |
|
Permute the axes/dimensions of an array. |
|
Evaluate a function defined piecewise across the domain. |
|
Update array elements based on a mask. |
|
Returns the quotient and remainder of polynomial division. |
|
Returns the product of two polynomials. |
|
Return element-wise positive values of the input. |
|
Alias of |
|
Calculate element-wise base |
|
Alias of |
|
Return product of the array elements over a given axis. |
|
Returns the type to which a binary operation should cast its arguments. |
|
Return the peak-to-peak range along a given axis. |
|
Put elements into an array at given indices. |
|
Convert angles from radians to degrees. |
|
Alias of |
|
Completely flatten a |
|
Convert multi-dimensional indices into flat indices. |
|
Return element-wise real part of the complex argument. |
|
Calculate element-wise reciprocal of the input. |
|
Returns element-wise remainder of the division. |
|
Construct an array from repeated elements. |
|
Return a reshaped copy of an array. |
|
Return a new array with specified shape. |
|
Return the result of applying JAX promotion rules to the inputs. |
|
Rounds the elements of x to the nearest integer |
|
Roll the elements of an array along a specified axis. |
|
Roll the specified axis to a given position. |
|
Returns the roots of a polynomial given the coefficients |
|
Rotate an array by 90 degrees counterclockwise in the plane specified by axes. |
|
Round input evenly to the given number of decimals. |
|
Perform a binary search within a sorted array. |
|
Select values based on a series of conditions. |
|
Alias of |
|
Compute the set difference of two 1D arrays. |
|
Compute the set-wise xor of elements in two arrays. |
|
Return the shape of an array. |
|
Return an element-wise indication of sign of the input. |
|
Return the sign bit of array elements. |
|
Compute a trigonometric sine of each element of input. |
|
Calculate the normalized sinc function. |
|
Calculate element-wise hyperbolic sine of input. |
|
Return number of elements along a given axis. |
|
Return a sorted copy of an array. |
|
Return a sorted copy of complex array. |
|
Split an array into sub-arrays. |
|
Calculates element-wise non-negative square root of the input array. |
|
Calculate element-wise square of the input array. |
|
Remove one or more length-1 axes from array |
|
Join arrays along a new axis. |
|
Compute the standard deviation along a given axis. |
|
Subtract two arrays element-wise. |
|
Sum of the elements of the array over a given axis. |
|
Swap two axes of an array. |
|
Take elements from an array. |
|
Compute a trigonometric tangent of each element of input. |
|
Calculate element-wise hyperbolic tangent of input. |
|
Compute the tensor dot product of two N-dimensional arrays. |
|
Construct an array by repeating |
|
Calculate sum of the diagonal of input along the given axes. |
|
Return a transposed version of an N-dimensional array. |
|
Return an array with ones on and below the diagonal and zeros elsewhere. |
|
Return the indices of lower triangle of an array of size |
|
Return the indices of lower triangle of a given array. |
|
Trim leading and/or trailing zeros of the input array. |
|
Return the indices of upper triangle of an array of size |
|
Return the indices of upper triangle of a given array. |
|
Calculates the division of x1 by x2 element-wise |
|
Round input to the nearest integer towards zero. |
|
Compute the set union of two 1D arrays. |
|
Return the unique values from an array. |
|
Return unique values from x, along with indices, inverse indices, and counts. |
|
Return unique values from x, along with counts. |
|
Return unique values from x, along with indices, inverse indices, and counts. |
|
Return unique values from x, along with indices, inverse indices, and counts. |
|
Convert flat indices into multi-dimensional indices. |
|
Unwrap a periodic signal. |
|
Compute the variance along a given axis. |
|
Perform a conjugate multiplication of two 1D vectors. |
|
Perform a conjugate multiplication of two batched vectors. |
|
Define a vectorized function with broadcasting. |
|
Split an array into sub-arrays vertically. |
|
Vertically stack arrays. |
|
Select elements from two arrays based on a condition. |
|
Create an array full of zeros. |
|
Create an array full of zeros with the same shape and dtype as an array. |
Classes
|
Block array class. |
- class scico.numpy.BlockArray(inputs)
Bases: object
Block array class.
A block array provides a way to combine arrays of different shapes into a single object for use with other SCICO classes. For further information, see the detailed BlockArray documentation.
Example
>>> x = snp.blockarray((
...     [[1, 3, 7],
...      [2, 2, 1]],
...     [2, 4, 8]
... ))
>>> x.shape
((2, 3), (3,))
>>> snp.sum(x)
Array(30, dtype=int32)
- property T
Compute the all-axis array transpose.
Refer to jax.numpy.transpose for details.
- all(axis=None, out=None, keepdims=False, *, where=None)
Test whether all array elements along a given axis evaluate to True.
Refer to jax.numpy.all for the full documentation.
- Return type:
- any(axis=None, out=None, keepdims=False, *, where=None)
Test whether any array elements along a given axis evaluate to True.
Refer to jax.numpy.any for the full documentation.
- Return type:
- argmax(axis=None, out=None, keepdims=None)
Return the index of the maximum value.
Refer to jax.numpy.argmax for the full documentation.
- Return type:
- argmin(axis=None, out=None, keepdims=None)
Return the index of the minimum value.
Refer to jax.numpy.argmin for the full documentation.
- Return type:
- argpartition(kth, axis=-1)
Return the indices that partially sort the array.
Refer to jax.numpy.argpartition for the full documentation.
- Return type:
- argsort(axis=-1, *, kind=None, order=None, stable=True, descending=False)
Return the indices that sort the array.
Refer to jax.numpy.argsort for the full documentation.
- Return type:
- astype(dtype, copy=False, device=None)
Copy the array and cast to a specified dtype.
This is implemented via jax.lax.convert_element_type, which may have slightly different behavior than numpy.ndarray.astype in some cases. In particular, the details of float-to-int and int-to-float casts are implementation dependent.
- Return type:
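As a rough illustration of the relationship stated above, the following sketch casts a plain jax array both through the astype method and through jax.lax.convert_element_type and compares the results. It is shown on a jax array rather than a BlockArray, and no particular rounding behaviour is asserted, since float-to-int details are described as implementation dependent.

```python
import jax
import jax.numpy as jnp

x = jnp.array([1.7, -1.7, 3.0], dtype=jnp.float32)

# Cast through the array method ...
via_astype = x.astype(jnp.int32)
# ... and through the lax function that the method is documented to use.
via_lax = jax.lax.convert_element_type(x, jnp.int32)

# Both routes should give the same dtype and the same values.
same_dtype = via_astype.dtype == via_lax.dtype
same_values = bool(jnp.array_equal(via_astype, via_lax))
```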
- block_until_ready
(self) -> object
- static blockarray(iterable)
Construct a BlockArray from a list or tuple of existing array-like objects.
- byteswap()
Swap the bytes of the array elements.
This switches between a little-endian and big-endian data representation.
- Return type:
- Returns:
An array with the same dtype as self, with underlying bytes of each entry reversed.
Examples
>>> import jax.numpy as jnp
>>> x = jnp.arange(5, dtype='int32')
>>> x
Array([0, 1, 2, 3, 4], dtype=int32)
>>> x.byteswap()
Array([ 0, 16777216, 33554432, 50331648, 67108864], dtype=int32)
When the resulting bytes are viewed as a big-endian dtype (possible in NumPy, but not in JAX) they represent the original values:
>>> import numpy as np
>>> np.array(x.byteswap()).view('>i4')  # view as big-endian
array([0, 1, 2, 3, 4], dtype='>i4')
Calling byteswap twice will return the original array:
>>> x.byteswap().byteswap()
Array([0, 1, 2, 3, 4], dtype=int32)
- choose(choices, out=None, mode='raise')
Construct an array choosing from elements of multiple arrays.
Refer to jax.numpy.choose for the full documentation.
- Return type:
- clip(min=None, max=None)
Return an array whose values are limited to a specified range.
Refer to jax.numpy.clip for full documentation.
- Return type:
- clone
(self) -> Array
- compress(condition, axis=None, *, out=None, size=None, fill_value=0)
Return selected slices of this array along given axis.
Refer to jax.numpy.compress for full documentation.
- Return type:
- conj()
Return the complex conjugate of the array.
Refer to jax.numpy.conj for the full documentation.
- Return type:
- conjugate()
Return the complex conjugate of the array.
Refer to jax.numpy.conjugate for the full documentation.
- Return type:
- copy()
Return a copy of the array.
Refer to jax.numpy.copy for the full documentation.
- Return type:
- cumprod(axis=None, dtype=None, out=None)
Return the cumulative product of the array.
Refer to jax.numpy.cumprod for the full documentation.
- Return type:
- cumsum(axis=None, dtype=None, out=None)
Return the cumulative sum of the array.
Refer to jax.numpy.cumsum for the full documentation.
- Return type:
- delete
(self) -> None
- diagonal(offset=0, axis1=0, axis2=1)
Return the specified diagonal from the array.
Refer to jax.numpy.diagonal for the full documentation.
- Return type:
- dot(b, *, precision=None, preferred_element_type=None)
Compute the dot product of two arrays.
Refer to jax.numpy.dot for the full documentation.
- Return type:
- property dtype
Return the dtype of the blocks, which must currently be homogeneous.
This allows snp.zeros(x.shape, x.dtype) to work without a mechanism to handle lists of dtypes.
- flatten(order='C', *, out_sharding=None)
Flatten array into a 1-dimensional shape.
Refer to jax.numpy.ravel for the full documentation.
- Return type:
- property global_shards
Returns list of all Shards of the Array across all devices.
The result includes shards that are not addressable by the current process. If a Shard is not addressable, then its data will be None.
- property imag
Return the imaginary part of the array.
- is_deleted
(self) -> bool
- property is_fully_addressable
Is this Array fully addressable?
A jax.Array is fully addressable if the current process can address all of the devices named in the Sharding. is_fully_addressable is equivalent to “is_local” in multi-process JAX.
Note that fully replicated is not equal to fully addressable, i.e. a jax.Array which is fully replicated can span across multiple hosts and is not fully addressable.
- is_ready
(self) -> bool
- item(*args)
Copy an element of an array to a standard Python scalar and return it.
- property itemsize
Length of one array element in bytes.
- property mT
Compute the (batched) matrix transpose.
Refer to jax.numpy.matrix_transpose for details.
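The difference between the all-axis transpose (T) and the batched matrix transpose (mT) is easiest to see on an array with more than two axes. The sketch below uses a plain jax array; it only illustrates the jax.numpy behaviour that these properties refer to, and how the properties act on a multi-block BlockArray is not shown here.

```python
import jax.numpy as jnp

x = jnp.zeros((2, 3, 4))  # e.g. a batch of two 3x4 matrices

all_axes = x.T.shape   # every axis reversed: (4, 3, 2)
batched = x.mT.shape   # only the last two axes swapped: (2, 4, 3)
```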
- max(axis=None, out=None, keepdims=False, initial=None, where=None)
Return the maximum of array elements along a given axis.
Refer to jax.numpy.max for the full documentation.
- Return type:
- mean(axis=None, dtype=None, out=None, keepdims=False, *, where=None)
Return the mean of array elements along a given axis.
Refer to jax.numpy.mean for the full documentation.
- Return type:
- min(axis=None, out=None, keepdims=False, initial=None, where=None)
Return the minimum of array elements along a given axis.
Refer to jax.numpy.min for the full documentation.
- Return type:
- property nbytes
Total bytes consumed by the elements of the array.
- nonzero(*, fill_value=None, size=None)
Return indices of nonzero elements of an array.
Refer to jax.numpy.nonzero for the full documentation.
- on_device_size_in_bytes
(self) -> int
- platform
(self) -> str
- prod(axis=None, dtype=None, out=None, keepdims=False, initial=None, where=None, promote_integers=True)
Return product of the array elements over a given axis.
Refer to jax.numpy.prod for the full documentation.
- Return type:
- ptp(axis=None, out=None, keepdims=False)
Return the peak-to-peak range along a given axis.
Refer to jax.numpy.ptp for the full documentation.
- Return type:
- ravel(order='C', *, out_sharding=None)
Flatten array into a 1-dimensional shape.
Refer to jax.numpy.ravel for the full documentation.
- Return type:
- property real
Return the real part of the array.
- repeat(repeats, axis=None, *, total_repeat_length=None, out_sharding=None)
Construct an array from repeated elements.
Refer to jax.numpy.repeat for the full documentation.
- Return type:
- reshape(*args, order='C', out_sharding=None)
Returns an array containing the same data with a new shape.
Refer to jax.numpy.reshape for full documentation.
- Return type:
- round(decimals=0, out=None)
Round array elements to a given decimal.
Refer to jax.numpy.round for full documentation.
- Return type:
- searchsorted(v, side='left', sorter=None, *, method='scan')
Perform a binary search within a sorted array.
Refer to jax.numpy.searchsorted for full documentation.
- Return type:
- sort(axis=-1, *, kind=None, order=None, stable=True, descending=False)
Return a sorted copy of an array.
Refer to jax.numpy.sort for full documentation.
- Return type:
- squeeze(axis=None)
Remove one or more length-1 axes from array.
Refer to jax.numpy.squeeze for full documentation.
- Return type:
- stack(axis=0)
Collapse a BlockArray to jax.Array.
Collapse a BlockArray to jax.Array by stacking the blocks on axis axis.
- Parameters:
axis – Index of new axis on which blocks are to be stacked.
- Returns:
A jax.Array obtained by stacking.
- Raises:
ValueError – When called on a BlockArray that is not stackable.
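The stacking behaviour can be sketched without SCICO by treating a tuple of jax arrays as the blocks. The helper `stack_blocks` below is hypothetical and only mirrors the documented contract, under the assumption that "stackable" means all blocks share one shape: such blocks are joined along a new axis, and anything else raises ValueError.

```python
import jax.numpy as jnp


def stack_blocks(blocks, axis=0):
    """Stack equally shaped blocks into one array (hypothetical helper)."""
    shapes = {tuple(b.shape) for b in blocks}
    if len(shapes) != 1:
        # Mirrors the documented failure mode for non-stackable inputs.
        raise ValueError(f"blocks have differing shapes: {sorted(shapes)}")
    return jnp.stack(blocks, axis=axis)


# Two blocks of shape (3,) stack into an array of shape (2, 3).
stacked = stack_blocks((jnp.array([1, 2, 3]), jnp.array([4, 5, 6])))

# Blocks of shape (2, 3) and (3,) cannot be stacked.
try:
    stack_blocks((jnp.zeros((2, 3)), jnp.zeros(3)))
    stackable = True
except ValueError:
    stackable = False
```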
- std(axis=None, dtype=None, out=None, ddof=0, keepdims=False, *, where=None, correction=None)
Compute the standard deviation along a given axis.
Refer to jax.numpy.std for full documentation.
- Return type:
- sum(axis=None, dtype=None, out=None, keepdims=False, initial=None, where=None, promote_integers=True)
Sum of the elements of the array over a given axis.
Refer to jax.numpy.sum for full documentation.
- Return type:
- swapaxes(axis1, axis2)
Swap two axes of an array.
Refer to jax.numpy.swapaxes for full documentation.
- Return type:
- take(indices, axis=None, out=None, mode=None, unique_indices=False, indices_are_sorted=False, fill_value=None)
Take elements from an array.
Refer to jax.numpy.take for full documentation.
- Return type:
- to_device(device, *, stream=None)
Return a copy of the array on the specified device.
- trace(offset=0, axis1=0, axis2=1, dtype=None, out=None)
Return the sum along the diagonal.
Refer to jax.numpy.trace for full documentation.
- Return type:
- transpose(*args)
Returns a copy of the array with axes transposed.
Refer to jax.numpy.transpose for full documentation.
- Return type:
- unsafe_buffer_pointer
(self) -> int
- var(axis=None, dtype=None, out=None, ddof=0, keepdims=False, *, where=None, correction=None)
Compute the variance along a given axis.
Refer to jax.numpy.var for full documentation.
- Return type:
- view(dtype=None, type=None)
Return a bitwise copy of the array, viewed as a new dtype.
This is a fuller-featured wrapper around jax.lax.bitcast_convert_type.
If the source and target dtype have the same bitwidth, the result has the same shape as the input array. If the bitwidth of the target dtype is different from the source, the size of the last axis of the result is adjusted accordingly.
>>> jnp.zeros([1,2,3], dtype=jnp.int16).view(jnp.int8).shape
(1, 2, 6)
>>> jnp.zeros([1,2,4], dtype=jnp.int8).view(jnp.int16).shape
(1, 2, 2)
Conversions involving booleans are not well-defined in all situations. With regards to the shape of result as explained above, booleans are treated as having a bitwidth of 8. However, when converting to a boolean array, the input should only contain 0 or 1 bytes. Otherwise, results may be unpredictable or may change depending on how the result is used.
This conversion is guaranteed and safe:
>>> jnp.array([1, 0, 1], dtype=jnp.int8).view(jnp.bool_)
Array([ True, False, True], dtype=bool)
However, there are no guarantees about the results of any expression involving a view such as this: jnp.array([1, 2, 3], dtype=jnp.int8).view(jnp.bool_). In particular, the results may change between JAX releases and depending on the platform. To safely convert such an array to a boolean array, compare it with 0:
>>> jnp.array([1, 2, 0], dtype=jnp.int8) != 0
Array([ True, True, False], dtype=bool)
- Parameters:
- Return type:
- Returns:
The array, viewed as the new dtype. Unlike NumPy, the array may or may not be a copy of the input array.
- scico.numpy.blockarray(iterable)
Construct a BlockArray from a list or tuple of existing array-like objects.
- scico.numpy.ravel(ba)
Completely flatten a BlockArray into a single Array.
When called on an Array, flattens the array.
- Parameters:
ba (Union[Array, BlockArray]) – The BlockArray to flatten.
- Return type:
- Returns:
ba flattened into a single Array.
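One way to picture flattening a block array into a single array is to flatten each block and join the pieces end to end. The sketch below does this for a tuple of jax arrays; `ravel_blocks` is a hypothetical helper, and the block-by-block, row-major ordering it produces is an assumption about the layout rather than a statement of how scico.numpy.ravel is implemented.

```python
import jax.numpy as jnp


def ravel_blocks(blocks):
    """Flatten each block and join the results end to end (hypothetical helper)."""
    return jnp.concatenate([jnp.ravel(b) for b in blocks])


blocks = (
    jnp.array([[1, 3, 7], [2, 2, 1]]),  # shape (2, 3)
    jnp.array([2, 4, 8]),               # shape (3,)
)

# 6 + 3 = 9 elements: the first block in row-major order, then the second.
flat = ravel_blocks(blocks)
```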
- scico.numpy.abs(x, /)
Alias of jax.numpy.absolute.
- Return type:
- scico.numpy.absolute(x, /)
Calculate the absolute value element-wise.
JAX implementation of numpy.absolute.
This is the same function as jax.numpy.abs.
- Parameters:
x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – Input array
- Return type:
- Returns:
An array-like object containing the absolute value of each element in x, with the same shape as x. For complex valued input, \(a + ib\), the absolute value is \(\sqrt{a^2+b^2}\).
Examples
>>> x1 = jnp.array([5, -2, 0, 12])
>>> jnp.absolute(x1)
Array([ 5, 2, 0, 12], dtype=int32)
>>> x2 = jnp.array([[ 8, -3, 1],[ 0, 9, -6]])
>>> jnp.absolute(x2)
Array([[8, 3, 1],
       [0, 9, 6]], dtype=int32)
>>> x3 = jnp.array([8 + 15j, 3 - 4j, -5 + 0j])
>>> jnp.absolute(x3)
Array([17., 5., 5.], dtype=float32)
- scico.numpy.add(*args: ArrayLike, out: None = None, where: None = None) → Any
Add two arrays element-wise.
JAX implementation of numpy.add. This is a universal function, and supports the additional APIs described at jax.numpy.ufunc. This function provides the implementation of the + operator for JAX arrays.
- Parameters:
x – arrays to add. Must be broadcastable to a common shape.
y – arrays to add. Must be broadcastable to a common shape.
- Returns:
Array containing the result of the element-wise addition.
Examples
Calling add explicitly:
>>> x = jnp.arange(4)
>>> jnp.add(x, 10)
Array([10, 11, 12, 13], dtype=int32)
Calling add via the + operator:
>>> x + 10
Array([10, 11, 12, 13], dtype=int32)
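The requirement that the inputs be broadcastable to a common shape can be illustrated with a column and a row. This is a small sketch using jax.numpy directly; the variable names are arbitrary.

```python
import jax.numpy as jnp

col = jnp.arange(3).reshape(3, 1)  # shape (3, 1)
row = jnp.arange(4)                # shape (4,)

# Shapes (3, 1) and (4,) broadcast to the common shape (3, 4),
# so entry [i, j] of the result is i + j.
table = jnp.add(col, row)
common_shape = jnp.broadcast_shapes(col.shape, row.shape)
```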
- scico.numpy.all(a, axis=None, out=None, keepdims=False, *, where=None)
Test whether all array elements along a given axis evaluate to True.
JAX implementation of numpy.all.
- Parameters:
a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – Input array.
axis (Union[int, Sequence[int], None]) – int or array, default=None. Axis along which to be tested. If None, tests along all the axes.
keepdims (bool) – bool, default=False. If true, reduced axes are left in the result with size 1.
where (Union[Array, ndarray, bool, number, bool, int, float, complex, None]) – int or array of boolean dtype, default=None. The elements to be used in the test. Array should be broadcast compatible to the input.
out (None) – Unused by JAX.
- Return type:
- Returns:
An array of boolean values.
Examples
By default, jnp.all tests for True values along all the axes.
>>> x = jnp.array([[True, True, True, False],
...                [True, False, True, False],
...                [True, True, False, False]])
>>> jnp.all(x)
Array(False, dtype=bool)
If axis=0, tests for True values along axis 0.
>>> jnp.all(x, axis=0)
Array([ True, False, False, False], dtype=bool)
If keepdims=True, ndim of the output will be the same as that of the input.
>>> jnp.all(x, axis=0, keepdims=True)
Array([[ True, False, False, False]], dtype=bool)
To include specific elements in testing for True values, you can use where.
>>> where=jnp.array([[1, 0, 1, 0],
...                  [0, 0, 1, 1],
...                  [1, 1, 1, 0]], dtype=bool)
>>> jnp.all(x, axis=0, keepdims=True, where=where)
Array([[ True, True, False, False]], dtype=bool)
- scico.numpy.allclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False)
Check if two arrays are element-wise approximately equal within a tolerance.
JAX implementation of numpy.allclose.
Essentially this function evaluates the following condition:
\[|a - b| \le \mathtt{atol} + \mathtt{rtol} * |b|\]
jnp.inf in a will be considered equal to jnp.inf in b.
- Parameters:
a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – first input array to compare.
b (Union[Array, ndarray, bool, number, bool, int, float, complex]) – second input array to compare.
rtol (Union[Array, ndarray, bool, number, bool, int, float, complex]) – relative tolerance used for approximate equality. Default = 1e-05.
atol (Union[Array, ndarray, bool, number, bool, int, float, complex]) – absolute tolerance used for approximate equality. Default = 1e-08.
equal_nan (bool) – Boolean. If True, NaNs in a will be considered equal to NaNs in b. Default is False.
- Return type:
- Returns:
Boolean scalar array indicating whether the input arrays are element-wise approximately equal within the specified tolerances.
See also
Examples
>>> jnp.allclose(jnp.array([1e6, 2e6, 3e6]), jnp.array([1e6, 2e6, 3e7]))
Array(False, dtype=bool)
>>> jnp.allclose(jnp.array([1e6, 2e6, 3e6]),
...              jnp.array([1.00008e6, 2.00008e7, 3.00008e8]), rtol=1e3)
Array(True, dtype=bool)
>>> jnp.allclose(jnp.array([1e6, 2e6, 3e6]),
...              jnp.array([1.00001e6, 2.00002e6, 3.00009e6]), atol=1e3)
Array(True, dtype=bool)
>>> jnp.allclose(jnp.array([jnp.nan, 1, 2]),
...              jnp.array([jnp.nan, 1, 2]), equal_nan=True)
Array(True, dtype=bool)
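Because the tolerance in the condition above scales with |b| only, the check is not symmetric in its arguments. The sketch below evaluates the inequality by hand for finite inputs and compares it with jnp.allclose; `within_tolerance` is a hypothetical helper that ignores the special handling of infinities and NaNs.

```python
import jax.numpy as jnp


def within_tolerance(a, b, rtol=1e-05, atol=1e-08):
    """Evaluate |a - b| <= atol + rtol * |b| and reduce (hypothetical helper)."""
    a, b = jnp.asarray(a), jnp.asarray(b)
    return bool(jnp.all(jnp.abs(a - b) <= atol + rtol * jnp.abs(b)))


# The tolerance scales with |b|, so swapping the arguments can change the result.
forward = bool(jnp.allclose(10.0, 12.0, rtol=0.18))   # 2 <= 0.18 * 12 = 2.16 holds
backward = bool(jnp.allclose(12.0, 10.0, rtol=0.18))  # 2 <= 0.18 * 10 = 1.80 fails
```

When a symmetric comparison is wanted, one option is to apply the check in both directions and combine the two results.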
- scico.numpy.amax(a, axis=None, out=None, keepdims=False, initial=None, where=None)¶
Alias of
jax.numpy.max.- Return type:
- scico.numpy.amin(a, axis=None, out=None, keepdims=False, initial=None, where=None)¶
Alias of jax.numpy.min.
- Return type:
- scico.numpy.angle(z, deg=False)¶
Return the angle of a complex valued number or array.
JAX implementation of numpy.angle.
- Parameters:
- Return type:
- Returns:
An array of the counterclockwise angle of each element of z, with the same shape as z and of dtype float.
Examples
If z is a number:
>>> z1 = 2+3j
>>> jnp.angle(z1)
Array(0.98279375, dtype=float32, weak_type=True)
If z is an array:
>>> z2 = jnp.array([[1+3j, 2-5j],
...                 [4-3j, 3+2j]])
>>> with jnp.printoptions(precision=2, suppress=True):
...     print(jnp.angle(z2))
[[ 1.25 -1.19]
 [-0.64  0.59]]
If deg=True:
>>> with jnp.printoptions(precision=2, suppress=True):
...     print(jnp.angle(z2, deg=True))
[[ 71.57 -68.2 ]
 [-36.87  33.69]]
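The angle of a complex number can also be read as the quadrant-aware arctangent of its imaginary and real parts. The sketch below checks that relationship numerically, assuming jax.numpy is imported as jnp; the sample values are illustrative only.

```python
import jax.numpy as jnp

z = jnp.array([1 + 3j, 2 - 5j, -4 + 0.5j])

# The angle is the quadrant-aware arctangent of imag / real.
from_angle = jnp.angle(z)
from_arctan2 = jnp.arctan2(z.imag, z.real)

# deg=True corresponds to converting the radian result to degrees.
in_degrees = jnp.angle(z, deg=True)

print(from_angle, in_degrees)
```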
- scico.numpy.any(a, axis=None, out=None, keepdims=False, *, where=None)¶
Test whether any of the array elements along a given axis evaluate to True.
JAX implementation of numpy.any.
- Parameters:
a (Union[Array,ndarray,bool,number,bool,int,float,complex]) – Input array.
axis (Union[int,Sequence[int],None]) – int or array, default=None. Axis along which to be tested. If None, tests along all the axes.
keepdims (bool) – bool, default=False. If true, reduced axes are left in the result with size 1.
where (Union[Array,ndarray,bool,number,bool,int,float,complex,None]) – int or array of boolean dtype, default=None. The elements to be used in the test. Array should be broadcast compatible to the input.
out (None) – Unused by JAX.
- Return type:
- Returns:
An array of boolean values.
Examples
By default, jnp.any tests along all the axes.
>>> x = jnp.array([[True, True, True, False],
...                [True, False, True, False],
...                [True, True, False, False]])
>>> jnp.any(x)
Array(True, dtype=bool)
If axis=0, tests along axis 0.
>>> jnp.any(x, axis=0)
Array([ True,  True,  True, False], dtype=bool)
If keepdims=True, ndim of the output will be the same as that of the input.
>>> jnp.any(x, axis=0, keepdims=True)
Array([[ True,  True,  True, False]], dtype=bool)
To include specific elements in testing for True values, you can use the where argument.
>>> where=jnp.array([[1, 0, 1, 0],
...                  [0, 1, 0, 1],
...                  [1, 0, 1, 0]], dtype=bool)
>>> jnp.any(x, axis=0, keepdims=True, where=where)
Array([[ True, False,  True, False]], dtype=bool)
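One way to read the where argument is that excluded elements behave as if they were False. The following sketch compares the two formulations, assuming jax.numpy is imported as jnp; the arrays x and mask are made up for the illustration.

```python
import jax.numpy as jnp

x = jnp.array([[True, False, False],
               [False, False, True]])
mask = jnp.array([[False, True, True],
                  [True, True, False]])

# Elements where mask is False are ignored by the reduction.
with_where = jnp.any(x, axis=0, where=mask)

# Equivalent formulation: force the ignored elements to False first.
by_hand = jnp.any(x & mask, axis=0)

print(with_where, by_hand)
```

Here every True entry of x is masked out, so both results contain only False even though jnp.any(x) on its own is True.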
- scico.numpy.append(arr, values, axis=None)¶
Return a new array with values appended to the end of the original array.
JAX implementation of numpy.append.
- Parameters:
arr (Union[Array,ndarray,bool,number,bool,int,float,complex]) – original array.
values (Union[Array,ndarray,bool,number,bool,int,float,complex]) – values to be appended to the array. The values must have the same number of dimensions as arr, and all dimensions must match except in the specified axis.
axis (int|None) – axis along which to append values. If None (default), both arr and values will be flattened before appending.
- Return type:
- Returns:
A new array with values appended to arr.
See also
Examples
>>> a = jnp.array([1, 2, 3])
>>> b = jnp.array([4, 5, 6])
>>> jnp.append(a, b)
Array([1, 2, 3, 4, 5, 6], dtype=int32)
Appending along a specific axis:
>>> a = jnp.array([[1, 2],
...                [3, 4]])
>>> b = jnp.array([[5, 6]])
>>> jnp.append(a, b, axis=0)
Array([[1, 2],
       [3, 4],
       [5, 6]], dtype=int32)
Appending along a trailing axis:
>>> a = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> b = jnp.array([[7], [8]])
>>> jnp.append(a, b, axis=1)
Array([[1, 2, 3, 7],
       [4, 5, 6, 8]], dtype=int32)
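The two modes of append (flattening when axis is None, joining along an axis otherwise) can be compared against jnp.concatenate. This is a sketch of the relationship under the assumption that jax.numpy is imported as jnp, not a statement about how append is implemented internally.

```python
import jax.numpy as jnp

a = jnp.array([[1, 2], [3, 4]])
b = jnp.array([[5, 6]])

# axis=None flattens both inputs before joining them.
flat = jnp.append(a, b)
flat_by_hand = jnp.concatenate([a.ravel(), b.ravel()])

# With an explicit axis, the result matches concatenation along that axis.
rows = jnp.append(a, b, axis=0)
rows_by_hand = jnp.concatenate([a, b], axis=0)

# A new array is returned; the original keeps its shape.
print(a.shape, flat.shape, rows.shape)
```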
- scico.numpy.apply_along_axis(func1d, axis, arr, *args, **kwargs)¶
Apply a function to 1D array slices along an axis.
JAX implementation of numpy.apply_along_axis. While NumPy implements this iteratively, JAX implements this via jax.vmap, and so func1d must be compatible with vmap.
- Parameters:
func1d (Callable) – a callable function with signature func1d(arr, /, *args, **kwargs) where *args and **kwargs are the additional positional and keyword arguments passed to apply_along_axis.
axis (int) – integer axis along which to apply the function.
arr (Union[Array,ndarray,bool,number,bool,int,float,complex]) – the array over which to apply the function.
args – additional positional arguments, passed through to func1d.
kwargs – additional keyword arguments, passed through to func1d.
- Return type:
- Returns:
The result of func1d applied along the specified axis.
See also
jax.vmap: a more direct way to create a vectorized version of a function.
jax.numpy.apply_over_axes: repeatedly apply a function over multiple axes.
jax.numpy.vectorize: create a vectorized version of a function.
Examples
A simple example in two dimensions, where the function is applied either row-wise or column-wise:
>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> def func1d(x):
...     return jnp.sum(x ** 2)
>>> jnp.apply_along_axis(func1d, 0, x)
Array([17, 29, 45], dtype=int32)
>>> jnp.apply_along_axis(func1d, 1, x)
Array([14, 77], dtype=int32)
For 2D inputs, this can be equivalently expressed using jax.vmap, though note that vmap specifies the mapped axis rather than the applied axis:
>>> jax.vmap(func1d, in_axes=1)(x)  # same as applying along axis 0
Array([17, 29, 45], dtype=int32)
>>> jax.vmap(func1d, in_axes=0)(x)  # same as applying along axis 1
Array([14, 77], dtype=int32)
For 3D inputs, apply_along_axis is equivalent to mapping over two dimensions:
>>> x_3d = jnp.arange(24).reshape(2, 3, 4)
>>> jnp.apply_along_axis(func1d, 2, x_3d)
Array([[  14,  126,  366],
       [ 734, 1230, 1854]], dtype=int32)
>>> jax.vmap(jax.vmap(func1d))(x_3d)
Array([[  14,  126,  366],
       [ 734, 1230, 1854]], dtype=int32)
The applied function may also take arbitrary positional or keyword arguments, which should be passed directly as additional arguments to apply_along_axis:
>>> def func1d(x, exponent):
...     return jnp.sum(x ** exponent)
>>> jnp.apply_along_axis(func1d, 0, x, exponent=3)
Array([ 65, 133, 243], dtype=int32)
- scico.numpy.apply_over_axes(func, a, axes)¶
Apply a function repeatedly over specified axes.
JAX implementation of numpy.apply_over_axes.
- Parameters:
func (Callable[[Union[Array,ndarray,bool,number,bool,int,float,complex],int],Array]) – the function to apply, with signature func(Array, int) -> Array, and where y = func(x, axis) must satisfy y.ndim in [x.ndim, x.ndim - 1].
a (Union[Array,ndarray,bool,number,bool,int,float,complex]) – N-dimensional array over which to apply the function.
axes (Sequence[int]) – the sequence of axes over which to apply the function.
- Return type:
- Returns:
An N-dimensional array containing the result of the repeated function application.
See also
jax.numpy.apply_along_axis: apply a 1D function along a single axis.
Examples
This function is designed to have similar semantics to typical associative jax.numpy reductions over one or more axes with keepdims=True. For example:
>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> jnp.apply_over_axes(jnp.sum, x, [0])
Array([[5, 7, 9]], dtype=int32)
>>> jnp.sum(x, [0], keepdims=True)
Array([[5, 7, 9]], dtype=int32)
>>> jnp.apply_over_axes(jnp.min, x, [1])
Array([[1],
       [4]], dtype=int32)
>>> jnp.min(x, [1], keepdims=True)
Array([[1],
       [4]], dtype=int32)
>>> jnp.apply_over_axes(jnp.prod, x, [0, 1])
Array([[720]], dtype=int32)
>>> jnp.prod(x, [0, 1], keepdims=True)
Array([[720]], dtype=int32)
- scico.numpy.arange(start, stop=None, step=None, dtype=None, *, device=None, out_sharding=None)¶
Create an array of evenly-spaced values.
JAX implementation of numpy.arange, implemented in terms of jax.lax.iota.
Similar to Python’s range function, this can be called with a few different positional signatures:
jnp.arange(stop): generate values from 0 to stop, stepping by 1.
jnp.arange(start, stop): generate values from start to stop, stepping by 1.
jnp.arange(start, stop, step): generate values from start to stop, stepping by step.
Like with Python’s range function, the starting value is inclusive, and the stop value is exclusive.
- Parameters:
start (Union[Array,ndarray,bool,number,bool,int,float,complex,Any]) – start of the interval, inclusive.
stop (Union[Array,ndarray,bool,number,bool,int,float,complex,Any,None]) – optional end of the interval, exclusive. If not specified, then (start, stop) = (0, start).
step (Union[Array,ndarray,bool,number,bool,int,float,complex,None]) – optional step size for the interval. Default = 1.
dtype (Union[str,type[Any],dtype,SupportsDType,None]) – optional dtype for the returned array; if not specified it will be determined via type promotion of start, stop, and step.
device (Device|Sharding|None) – (optional) Device or Sharding to which the created array will be committed.
out_sharding (NamedSharding|P|None) – (optional) NamedSharding or P to which the created array will be committed. Use the out_sharding argument when using explicit sharding (https://docs.jax.dev/en/latest/parallel.html).
- Return type:
- Returns:
Array of evenly-spaced values from start to stop, separated by step.
Note
Using arange with a floating-point step argument can lead to unexpected results due to accumulation of floating-point errors, especially with lower-precision data types like float8_* and bfloat16. To avoid precision errors, consider generating a range of integers, and scaling it to the desired range. For example, instead of this:
jnp.arange(-1, 1, 0.01, dtype='bfloat16')
it can be more accurate to generate a sequence of integers, and scale them:
(jnp.arange(-100, 100) * 0.01).astype('bfloat16')
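The integer-scaling advice in this note can be sketched as follows, assuming jax.numpy is imported as jnp. The linspace variant is an additional option that is not part of the note itself; it fixes the element count explicitly, with endpoint=False mimicking the exclusive stop value of arange.

```python
import jax.numpy as jnp

# Integer range scaled to [-1, 1): the element count is exact by construction.
scaled = (jnp.arange(-100, 100) * 0.01).astype('bfloat16')

# Alternative with an explicit element count and an exclusive endpoint.
spaced = jnp.linspace(-1.0, 1.0, 200, endpoint=False).astype('bfloat16')

print(scaled.shape, spaced.shape)
```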
Examples
Single-argument version specifies only the stop value:
>>> jnp.arange(4)
Array([0, 1, 2, 3], dtype=int32)
Passing a floating-point stop value leads to a floating-point result:
>>> jnp.arange(4.0)
Array([0., 1., 2., 3.], dtype=float32)
Two-argument version specifies start and stop, with step=1:
>>> jnp.arange(1, 6)
Array([1, 2, 3, 4, 5], dtype=int32)
Three-argument version specifies start, stop, and step:
>>> jnp.arange(0, 2, 0.5)
Array([0. , 0.5, 1. , 1.5], dtype=float32)
See also
jax.numpy.linspace: generate a fixed number of evenly-spaced values.
jax.lax.iota: directly generate integer sequences in XLA.
- scico.numpy.arccos(x, /)¶
Compute element-wise inverse of trigonometric cosine of input.
JAX implementation of numpy.arccos.
- Parameters:
x (Union[Array,ndarray,bool,number,bool,int,float,complex]) – input array or scalar.
- Return type:
- Returns:
An array containing the inverse trigonometric cosine of each element of x in radians in the range [0, pi], promoting to inexact dtype.
Note
jnp.arccos returns nan when x is real-valued and not in the closed interval [-1, 1].
jnp.arccos follows the branch cut convention of numpy.arccos for complex inputs.
See also
jax.numpy.cos: Computes a trigonometric cosine of each element of input.
jax.numpy.arcsin and jax.numpy.asin: Computes the inverse of trigonometric sine of each element of input.
jax.numpy.arctan and jax.numpy.atan: Computes the inverse of trigonometric tangent of each element of input.
Examples
>>> x = jnp.array([-2, -1, -0.5, 0, 0.5, 1, 2])
>>> with jnp.printoptions(precision=3, suppress=True):
...     jnp.arccos(x)
Array([  nan, 3.142, 2.094, 1.571, 1.047, 0.   ,   nan], dtype=float32)
For complex inputs:
>>> with jnp.printoptions(precision=3, suppress=True):
...     jnp.arccos(4-1j)
Array(0.252+2.097j, dtype=complex64, weak_type=True)
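The domain restriction described in the note can be checked with a round trip through jnp.cos. This is a minimal sketch, assuming jax.numpy is imported as jnp; the tolerance of 1e-6 is an illustrative choice for float32 rounding.

```python
import jax.numpy as jnp

x = jnp.array([-1.0, -0.5, 0.0, 0.5, 1.0])

# Inside [-1, 1], cos undoes arccos up to floating-point rounding.
round_trip = jnp.cos(jnp.arccos(x))

# Outside [-1, 1], real-valued inputs give nan.
outside = jnp.arccos(jnp.array([-2.0, 2.0]))

print(round_trip, outside)
```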
- scico.numpy.arccosh(x, /)¶
Calculate element-wise inverse of hyperbolic cosine of input.
JAX implementation of numpy.arccosh.
The inverse of hyperbolic cosine is defined by:
\[\operatorname{arccosh}(x) = \ln(x + \sqrt{x^2 - 1})\]
- Parameters:
x (Union[Array,ndarray,bool,number,bool,int,float,complex]) – input array or scalar.
- Return type:
- Returns:
An array of same shape as x containing the inverse of hyperbolic cosine of each element of x, promoting to inexact dtype.
Note
jnp.arccosh returns nan for real-values in the range [-inf, 1).
jnp.arccosh follows the branch cut convention of numpy.arccosh for complex inputs.
See also
jax.numpy.cosh: Computes the element-wise hyperbolic cosine of the input.
jax.numpy.arcsinh: Computes the element-wise inverse of hyperbolic sine of the input.
jax.numpy.arctanh: Computes the element-wise inverse of hyperbolic tangent of the input.
Examples
>>> x = jnp.array([[1, 3, -4],
...                [-5, 2, 7]])
>>> with jnp.printoptions(precision=3, suppress=True):
...     jnp.arccosh(x)
Array([[0.   , 1.763,   nan],
       [  nan, 1.317, 2.634]], dtype=float32)
For complex-valued input:
>>> x1 = jnp.array([-jnp.inf+0j, 1+2j, -5+0j])
>>> with jnp.printoptions(precision=3, suppress=True):
...     jnp.arccosh(x1)
Array([  inf+3.142j, 1.529+1.144j, 2.292+3.142j], dtype=complex64)
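The logarithmic definition of arccosh given in this entry can be compared with the library function for real inputs of at least 1. The sketch below assumes jax.numpy is imported as jnp; the sample points are arbitrary.

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0, 7.0])

# Closed form from the definition: ln(x + sqrt(x^2 - 1)), valid for x >= 1.
closed_form = jnp.log(x + jnp.sqrt(x**2 - 1))
library = jnp.arccosh(x)

print(closed_form, library)
```

The two agree only up to floating-point rounding, so a tolerance-based comparison such as jnp.allclose is the appropriate check.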
- scico.numpy.arcsin(x, /)¶
Compute element-wise inverse of trigonometric sine of input.
JAX implementation of numpy.arcsin.
- Parameters:
x (Union[Array,ndarray,bool,number,bool,int,float,complex]) – input array or scalar.
- Return type:
- Returns:
An array containing the inverse trigonometric sine of each element of x in radians in the range [-pi/2, pi/2], promoting to inexact dtype.
Note
jnp.arcsin returns nan when x is real-valued and not in the closed interval [-1, 1].
jnp.arcsin follows the branch cut convention of numpy.arcsin for complex inputs.
See also
jax.numpy.sin: Computes a trigonometric sine of each element of input.
jax.numpy.arccos and jax.numpy.acos: Computes the inverse of trigonometric cosine of each element of input.
jax.numpy.arctan and jax.numpy.atan: Computes the inverse of trigonometric tangent of each element of input.
Examples
>>> x = jnp.array([-2, -1, -0.5, 0, 0.5, 1, 2])
>>> with jnp.printoptions(precision=3, suppress=True):
...     jnp.arcsin(x)
Array([   nan, -1.571, -0.524,  0.   ,  0.524,  1.571,    nan], dtype=float32)
For complex-valued inputs:
>>> with jnp.printoptions(precision=3, suppress=True):
...     jnp.arcsin(3+4j)
Array(0.634+2.306j, dtype=complex64, weak_type=True)
- scico.numpy.arcsinh(x, /)¶
Calculate element-wise inverse of hyperbolic sine of input.
JAX implementation of numpy.arcsinh.
The inverse of hyperbolic sine is defined by:
\[\operatorname{arcsinh}(x) = \ln(x + \sqrt{1 + x^2})\]
- Parameters:
x (Union[Array,ndarray,bool,number,bool,int,float,complex]) – input array or scalar.
- Return type:
- Returns:
An array of same shape as x containing the inverse of hyperbolic sine of each element of x, promoting to inexact dtype.
Note
jnp.arcsinh returns nan for values outside the range (-inf, inf).
jnp.arcsinh follows the branch cut convention of numpy.arcsinh for complex inputs.
See also
jax.numpy.sinh: Computes the element-wise hyperbolic sine of the input.
jax.numpy.arccosh: Computes the element-wise inverse of hyperbolic cosine of the input.
jax.numpy.arctanh: Computes the element-wise inverse of hyperbolic tangent of the input.
Examples
>>> x = jnp.array([[-2, 3, 1],
...                [4, 9, -5]])
>>> with jnp.printoptions(precision=3, suppress=True):
...     jnp.arcsinh(x)
Array([[-1.444,  1.818,  0.881],
       [ 2.095,  2.893, -2.312]], dtype=float32)
For complex-valued inputs:
>>> x1 = jnp.array([4-3j, 2j])
>>> with jnp.printoptions(precision=3, suppress=True):
...     jnp.arcsinh(x1)
Array([2.306-0.634j, 1.317+1.571j], dtype=complex64)
- scico.numpy.arctan(x, /)¶
Compute element-wise inverse of trigonometric tangent of input.
JAX implementation of numpy.arctan.
- Parameters:
x (Union[Array,ndarray,bool,number,bool,int,float,complex]) – input array or scalar.
- Return type:
- Returns:
An array containing the inverse trigonometric tangent of each element of x in radians in the range [-pi/2, pi/2], promoting to inexact dtype.
Note
jnp.arctan follows the branch cut convention of numpy.arctan for complex inputs.
See also
jax.numpy.tan: Computes a trigonometric tangent of each element of input.
jax.numpy.arcsin and jax.numpy.asin: Computes the inverse of trigonometric sine of each element of input.
jax.numpy.arccos and jax.numpy.acos: Computes the inverse of trigonometric cosine of each element of input.
Examples
>>> x = jnp.array([-jnp.inf, -20, -1, 0, 1, 20, jnp.inf])
>>> with jnp.printoptions(precision=3, suppress=True):
...     jnp.arctan(x)
Array([-1.571, -1.521, -0.785,  0.   ,  0.785,  1.521,  1.571], dtype=float32)
For complex-valued inputs:
>>> with jnp.printoptions(precision=3, suppress=True):
...     jnp.arctan(2+7j)
Array(1.532+0.133j, dtype=complex64, weak_type=True)
- scico.numpy.arctan2(x1, x2, /)¶
Compute the arctangent of x1/x2, choosing the correct quadrant.
JAX implementation of numpy.arctan2.
- Parameters:
- Return type:
- Returns:
The elementwise arctangent of x1 / x2, tracking the correct quadrant.
See also
jax.numpy.tan: compute the tangent of an angle.
jax.numpy.atan2: the array API version of this function.
Examples
Consider a sequence of angles in radians between \(-\pi\) and \(\pi\):
>>> theta = jnp.linspace(-jnp.pi, jnp.pi, 9)
>>> with jnp.printoptions(precision=2, suppress=True):
...     print(theta)
[-3.14 -2.36 -1.57 -0.79  0.    0.79  1.57  2.36  3.14]
These angles can equivalently be represented by (x, y) coordinates on a unit circle:
>>> x, y = jnp.cos(theta), jnp.sin(theta)
To reconstruct the input angle, we might be tempted to use the identity \(\tan(\theta) = y / x\), and compute \(\theta = \tan^{-1}(y/x)\). Unfortunately, this does not recover the input angle:
>>> with jnp.printoptions(precision=2, suppress=True):
...     print(jnp.arctan(y / x))
[-0.    0.79  1.57 -0.79  0.    0.79  1.57 -0.79  0.  ]
The problem is that \(y/x\) contains some ambiguity: although \((y, x) = (-1, -1)\) and \((y, x) = (1, 1)\) represent different points in Cartesian space, in both cases \(y / x = 1\), and so the simple arctan approach loses information about which quadrant the angle lies in.
arctan2 is built to address this:
>>> with jnp.printoptions(precision=2, suppress=True):
...     print(jnp.arctan2(y, x))
[ 3.14 -2.36 -1.57 -0.79  0.    0.79  1.57  2.36 -3.14]
The results match the input theta, except at the endpoints where \(+\pi\) and \(-\pi\) represent indistinguishable points on the unit circle. By convention, arctan2 always returns values between \(-\pi\) and \(+\pi\) inclusive.
- scico.numpy.arctanh(x, /)¶
Calculate element-wise inverse of hyperbolic tangent of input.
JAX implementation of numpy.arctanh.
The inverse of hyperbolic tangent is defined by:
\[\operatorname{arctanh}(x) = \frac{1}{2} [\ln(1 + x) - \ln(1 - x)]\]
- Parameters:
x (Union[Array,ndarray,bool,number,bool,int,float,complex]) – input array or scalar.
- Return type:
- Returns:
An array of same shape as x containing the inverse of hyperbolic tangent of each element of x, promoting to inexact dtype.
Note
jnp.arctanh returns nan for real-values outside the range [-1, 1].
jnp.arctanh follows the branch cut convention of numpy.arctanh for complex inputs.
See also
jax.numpy.tanh: Computes the element-wise hyperbolic tangent of the input.
jax.numpy.arcsinh: Computes the element-wise inverse of hyperbolic sine of the input.
jax.numpy.arccosh: Computes the element-wise inverse of hyperbolic cosine of the input.
Examples
>>> x = jnp.array([-2, -1, -0.5, 0, 0.5, 1, 2])
>>> with jnp.printoptions(precision=3, suppress=True):
...     jnp.arctanh(x)
Array([   nan,   -inf, -0.549,  0.   ,  0.549,    inf,    nan], dtype=float32)
For complex-valued input:
>>> x1 = jnp.array([-2+0j, 3+0j, 4-1j])
>>> with jnp.printoptions(precision=3, suppress=True):
...     jnp.arctanh(x1)
Array([-0.549+1.571j,  0.347+1.571j,  0.239-1.509j], dtype=complex64)
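The logarithmic definition of arctanh can be evaluated directly for real inputs strictly inside (-1, 1) and compared with the library function. This is a sketch under the assumption that jax.numpy is imported as jnp; the sample points and tolerances are illustrative.

```python
import jax.numpy as jnp

x = jnp.array([-0.9, -0.5, 0.0, 0.5, 0.9])

# Closed form from the definition: 0.5 * (ln(1 + x) - ln(1 - x)).
closed_form = 0.5 * (jnp.log(1 + x) - jnp.log(1 - x))
library = jnp.arctanh(x)

# At the endpoints of the domain the result is infinite rather than nan.
at_one = jnp.arctanh(jnp.array([1.0]))

print(closed_form, library, at_one)
```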
- scico.numpy.argmax(a, axis=None, out=None, keepdims=None)¶
Return the index of the maximum value of an array.
JAX implementation of numpy.argmax.
- Parameters:
a (Union[Array,ndarray,bool,number,bool,int,float,complex]) – input array
axis (int|None) – optional integer specifying the axis along which to find the maximum value. If axis is not specified, a will be flattened.
out (None) – unused by JAX
keepdims (bool|None) – if True, then return an array with the same number of dimensions as a.
- Return type:
- Returns:
an array containing the index of the maximum value along the specified axis.
See also
jax.numpy.argmin: return the index of the minimum value.
jax.numpy.nanargmax: compute argmax while ignoring NaN values.
Note
When the maximum value occurs more than once along a particular axis, the smallest index is returned.
Examples
>>> x = jnp.array([1, 3, 5, 4, 2])
>>> jnp.argmax(x)
Array(2, dtype=int32)
>>> x = jnp.array([[1, 3, 2],
...                [5, 4, 1]])
>>> jnp.argmax(x, axis=1)
Array([1, 0], dtype=int32)
>>> jnp.argmax(x, axis=1, keepdims=True)
Array([[1],
       [0]], dtype=int32)
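The tie-breaking rule from the note, and one use of keepdims=True, can be sketched as follows. The example assumes jax.numpy is imported as jnp; pairing argmax with jnp.take_along_axis is an illustrative pattern, not something this entry prescribes.

```python
import jax.numpy as jnp

# The maximum value 3 occurs at indices 0 and 2; the smallest index is returned.
ties = jnp.array([3, 1, 3])
first = jnp.argmax(ties)

# keepdims=True keeps the reduced axis, so the indices can be passed
# straight to take_along_axis to recover the maximum values.
x = jnp.array([[1, 3, 2],
               [5, 4, 1]])
idx = jnp.argmax(x, axis=1, keepdims=True)
values = jnp.take_along_axis(x, idx, axis=1)

print(first, values)
```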
- scico.numpy.argmin(a, axis=None, out=None, keepdims=None)¶
Return the index of the minimum value of an array.
JAX implementation of numpy.argmin.
- Parameters:
a (Union[Array,ndarray,bool,number,bool,int,float,complex]) – input array
axis (int|None) – optional integer specifying the axis along which to find the minimum value. If axis is not specified, a will be flattened.
out (None) – unused by JAX
keepdims (bool|None) – if True, then return an array with the same number of dimensions as a.
- Return type:
- Returns:
an array containing the index of the minimum value along the specified axis.
Note
When the minimum value occurs more than once along a particular axis, the smallest index is returned.
See also
jax.numpy.argmax: return the index of the maximum value.
jax.numpy.nanargmin: compute argmin while ignoring NaN values.
Examples
>>> x = jnp.array([1, 3, 5, 4, 2])
>>> jnp.argmin(x)
Array(0, dtype=int32)
>>> x = jnp.array([[1, 3, 2],
...                [5, 4, 1]])
>>> jnp.argmin(x, axis=1)
Array([0, 2], dtype=int32)
>>> jnp.argmin(x, axis=1, keepdims=True)
Array([[0],
       [2]], dtype=int32)
- scico.numpy.argsort(a, axis=-1, *, kind=None, order=None, stable=True, descending=False, dtype=None)¶
Return indices that sort an array.
JAX implementation of numpy.argsort.
- Parameters:
a (Union[Array,ndarray,bool,number,bool,int,float,complex]) – array to sort
axis (int|None) – integer axis along which to sort. Defaults to -1, i.e. the last axis. If None, then a is flattened before being sorted.
stable (bool) – boolean specifying whether a stable sort should be used. Default=True.
descending (bool) – boolean specifying whether to sort in descending order. Default=False.
kind (None) – deprecated; instead specify sort algorithm using stable=True or stable=False.
order (None) – not supported by JAX
dtype (Union[str,type[Any],dtype,SupportsDType,None]) – optionally specify the dtype of the resulting indices. If not specified, the default integer dtype will be used.
- Return type:
- Returns:
Array of indices that sort an array. Returned array will be of shape a.shape (if axis is an integer) or of shape (a.size,) (if axis is None).
Examples
Simple 1-dimensional sort:
>>> x = jnp.array([1, 3, 5, 4, 2, 1])
>>> indices = jnp.argsort(x)
>>> indices
Array([0, 5, 4, 1, 3, 2], dtype=int32)
>>> x[indices]
Array([1, 1, 2, 3, 4, 5], dtype=int32)
Sort along the last axis of an array:
>>> x = jnp.array([[2, 1, 3],
...                [6, 4, 3]])
>>> indices = jnp.argsort(x, axis=1)
>>> indices
Array([[1, 0, 2],
       [2, 1, 0]], dtype=int32)
>>> jnp.take_along_axis(x, indices, axis=1)
Array([[1, 2, 3],
       [3, 4, 6]], dtype=int32)
See also
jax.numpy.sort: return sorted values directly.
jax.numpy.lexsort: lexicographical sort of multiple arrays.
jax.lax.sort: lower-level function wrapping XLA’s Sort operator.
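The descending parameter listed for argsort has no example in this entry; the following is a minimal sketch of its effect, assuming jax.numpy is imported as jnp and a JAX version whose argsort accepts descending, as the signature shown here does.

```python
import jax.numpy as jnp

x = jnp.array([1, 3, 5, 4, 2])

# Indices that order x from largest to smallest.
order = jnp.argsort(x, descending=True)
sorted_desc = x[order]

print(order, sorted_desc)
```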
- scico.numpy.argwhere(a, *, size=None, fill_value=None)¶
Find the indices of nonzero array elements
JAX implementation of numpy.argwhere.
jnp.argwhere(x) is essentially equivalent to jnp.column_stack(jnp.nonzero(x)) with special handling for zero-dimensional (i.e. scalar) inputs.
Because the size of the output of argwhere is data-dependent, the function is not typically compatible with JIT. The JAX version adds the optional size argument, which specifies the size of the leading dimension of the output - it must be specified statically for jnp.argwhere to be compiled with non-static operands. See jax.numpy.nonzero for a full discussion of size and its semantics.
- Parameters:
a (Union[Array,ndarray,bool,number,bool,int,float,complex]) – array for which to find nonzero elements
size (int|None) – optional integer specifying statically the number of expected nonzero elements. This must be specified in order to use argwhere within JAX transformations like jax.jit. See jax.numpy.nonzero for more information.
fill_value (Union[Array,ndarray,bool,number,bool,int,float,complex,None]) – optional array specifying the fill value when size is specified. See jax.numpy.nonzero for more information.
- Return type:
- Returns:
a two-dimensional array of shape [size, a.ndim]. If size is not specified as an argument, it is equal to the number of nonzero elements in a.
See also
Examples
Two-dimensional array:
>>> x = jnp.array([[1, 0, 2],
...                [0, 3, 0]])
>>> jnp.argwhere(x)
Array([[0, 0],
       [0, 2],
       [1, 1]], dtype=int32)
Equivalent computation using jax.numpy.column_stack and jax.numpy.nonzero:
>>> jnp.column_stack(jnp.nonzero(x))
Array([[0, 0],
       [0, 2],
       [1, 1]], dtype=int32)
Special case for zero-dimensional (i.e. scalar) inputs:
>>> jnp.argwhere(1)
Array([], shape=(1, 0), dtype=int32)
>>> jnp.argwhere(0)
Array([], shape=(0, 0), dtype=int32)
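The size and fill_value arguments matter mainly under jax.jit, where the output shape must be known at trace time. The sketch below assumes jax and jax.numpy are importable; the helper name nonzero_indices and the choices size=4 and fill_value=-1 are hypothetical and picked only for the illustration.

```python
import jax
import jax.numpy as jnp

@jax.jit
def nonzero_indices(x):
    # size fixes the number of output rows at trace time;
    # rows beyond the actual nonzero count are padded with fill_value.
    return jnp.argwhere(x, size=4, fill_value=-1)

x = jnp.array([[1, 0, 2],
               [0, 3, 0]])
indices = nonzero_indices(x)

print(indices)
```

Here x has three nonzero elements, so the fourth row of the result is padding. A size that is too small would instead truncate the result, so it should be chosen as an upper bound on the expected count.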
- scico.numpy.around(a, decimals=0, out=None)¶
Alias of jax.numpy.round.
- Return type:
- scico.numpy.array(object, dtype=None, *args, copy=True, order='K', ndmin=0, device=None, out_sharding=None)¶
Convert an object to a JAX array.
JAX implementation of numpy.array.
- Parameters:
object (Any) – an object that is convertible to an array. This includes JAX arrays, NumPy arrays, Python scalars, Python collections like lists and tuples, objects with a __jax_array__ method, and objects supporting the Python buffer protocol.
dtype (Union[str,type[Any],dtype,SupportsDType,None]) – optionally specify the dtype of the output array. If not specified it will be inferred from the input.
copy (bool) – specify whether to force a copy of the input. Default: True.
ndmin (int) – integer specifying the minimum number of dimensions in the output array.
device (Device|Sharding|None) – optional Device or Sharding to which the created array will be committed.
out_sharding (NamedSharding|P|None) – (optional) PartitionSpec or NamedSharding representing the sharding of the created array (see explicit sharding for more details). This argument exists for consistency with other array creation routines across JAX. Specifying both out_sharding and device will result in an error.
- Return type:
- Returns:
A JAX array constructed from the input.
See also
jax.numpy.asarray: like array, but by default only copies when necessary.
jax.numpy.from_dlpack: construct a JAX array from an object that implements the dlpack interface.
jax.numpy.frombuffer: construct a JAX array from an object that implements the buffer interface.
Examples
Constructing JAX arrays from Python scalars:
>>> jnp.array(True)
Array(True, dtype=bool)
>>> jnp.array(42)
Array(42, dtype=int32, weak_type=True)
>>> jnp.array(3.5)
Array(3.5, dtype=float32, weak_type=True)
>>> jnp.array(1 + 1j)
Array(1.+1.j, dtype=complex64, weak_type=True)
Constructing JAX arrays from Python collections:
>>> jnp.array([1, 2, 3])  # list of ints -> 1D array
Array([1, 2, 3], dtype=int32)
>>> jnp.array([(1, 2, 3), (4, 5, 6)])  # list of tuples of ints -> 2D array
Array([[1, 2, 3],
       [4, 5, 6]], dtype=int32)
>>> jnp.array(range(5))
Array([0, 1, 2, 3, 4], dtype=int32)
Constructing JAX arrays from NumPy arrays:
>>> jnp.array(np.linspace(0, 2, 5))
Array([0. , 0.5, 1. , 1.5, 2. ], dtype=float32)
Constructing a JAX array via the Python buffer interface, using Python’s built-in array module.
>>> from array import array
>>> pybuffer = array('i', [2, 3, 5, 7])
>>> jnp.array(pybuffer)
Array([2, 3, 5, 7], dtype=int32)
- scico.numpy.array_equal(a1, a2, equal_nan=False)¶
Check if two arrays are element-wise equal.
JAX implementation of numpy.array_equal.
- Parameters:
a1 (Union[Array,ndarray,bool,number,bool,int,float,complex]) – first input array to compare.
a2 (Union[Array,ndarray,bool,number,bool,int,float,complex]) – second input array to compare.
equal_nan (bool) – Boolean. If True, NaNs in a1 will be considered equal to NaNs in a2. Default is False.
- Return type:
- Returns:
Boolean scalar array indicating whether the input arrays are element-wise equal.
See also
Examples
>>> jnp.array_equal(jnp.array([1, 2, 3]), jnp.array([1, 2, 3]))
Array(True, dtype=bool)
>>> jnp.array_equal(jnp.array([1, 2, 3]), jnp.array([1, 2]))
Array(False, dtype=bool)
>>> jnp.array_equal(jnp.array([1, 2, 3]), jnp.array([1, 2, 4]))
Array(False, dtype=bool)
>>> jnp.array_equal(jnp.array([1, 2, float('nan')]),
...                 jnp.array([1, 2, float('nan')]))
Array(False, dtype=bool)
>>> jnp.array_equal(jnp.array([1, 2, float('nan')]),
...                 jnp.array([1, 2, float('nan')]), equal_nan=True)
Array(True, dtype=bool)
- scico.numpy.array_equiv(a1, a2)¶
Check if two arrays are element-wise equal.
JAX implementation of numpy.array_equiv.
This function will return False if the input arrays cannot be broadcast to the same shape.
- Parameters:
- Return type:
- Returns:
Boolean scalar array indicating whether the input arrays are element-wise equal after broadcasting.
See also
Examples
>>> jnp.array_equiv(jnp.array([1, 2, 3]), jnp.array([1, 2, 3]))
Array(True, dtype=bool)
>>> jnp.array_equiv(jnp.array([1, 2, 3]), jnp.array([1, 2, 4]))
Array(False, dtype=bool)
>>> jnp.array_equiv(jnp.array([[1, 2, 3], [1, 2, 3]]),
...                 jnp.array([1, 2, 3]))
Array(True, dtype=bool)
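array_equal and array_equiv share the same one-line summary, so the practical difference (whether broadcasting is applied before the comparison) is easiest to see side by side. This is a minimal sketch assuming jax.numpy is imported as jnp.

```python
import jax.numpy as jnp

a = jnp.array([[1, 2, 3],
               [1, 2, 3]])
b = jnp.array([1, 2, 3])

# array_equal requires identical shapes, so (2, 3) vs (3,) compares unequal.
strict = bool(jnp.array_equal(a, b))

# array_equiv broadcasts b against a before comparing element-wise.
broadcasted = bool(jnp.array_equiv(a, b))

print(strict, broadcasted)
```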
- scico.numpy.array_split(ary, indices_or_sections, axis=0)¶
Split an array into sub-arrays.
JAX implementation of numpy.array_split.
Refer to the documentation of jax.numpy.split for details; array_split is equivalent to split, but allows an integer indices_or_sections that does not evenly divide the split axis.
Examples
>>> x = jnp.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
>>> chunks = jnp.array_split(x, 4)
>>> print(*chunks)
[1 2 3] [4 5] [6 7] [8 9]
See also
jax.numpy.split: split an array along any axis.
jax.numpy.vsplit: split vertically, i.e. along axis=0
jax.numpy.hsplit: split horizontally, i.e. along axis=1
jax.numpy.dsplit: split depth-wise, i.e. along axis=2
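The difference between array_split and split for an uneven division can be sketched as follows, assuming jax.numpy is imported as jnp. Catching ValueError reflects the error type NumPy uses for this case and is assumed here to match JAX's behaviour.

```python
import jax.numpy as jnp

x = jnp.arange(9)

# 9 elements into 4 sections: the first 9 % 4 = 1 section gets one extra element.
sizes = [int(chunk.size) for chunk in jnp.array_split(x, 4)]

# split requires an even division and raises otherwise.
try:
    jnp.split(x, 4)
    split_ok = True
except ValueError:
    split_ok = False

print(sizes, split_ok)
```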
- scico.numpy.asarray(a, dtype=None, order=None, *, copy=None, device=None, out_sharding=None)¶
Convert an object to a JAX array.
JAX implementation of numpy.asarray.
- Parameters:
a (Any) – an object that is convertible to an array. This includes JAX arrays, NumPy arrays, Python scalars, Python collections like lists and tuples, objects with a __jax_array__ method, and objects supporting the Python buffer protocol.
dtype (Union[str,type[Any],dtype,SupportsDType,None]) – optionally specify the dtype of the output array. If not specified it will be inferred from the input.
copy (bool|None) – optional boolean specifying the copy mode. If True, then always return a copy. If False, then error if a copy is necessary. Default is None, which will only copy when necessary.
device (Device|Sharding|None) – optional Device or Sharding to which the created array will be committed.
- Return type:
- Returns:
A JAX array constructed from the input.
See also
jax.numpy.array: like asarray, but defaults to copy=True.
jax.numpy.from_dlpack: construct a JAX array from an object that implements the dlpack interface.
jax.numpy.frombuffer: construct a JAX array from an object that implements the buffer interface.
Examples
Constructing JAX arrays from Python scalars:
>>> jnp.asarray(True)
Array(True, dtype=bool)
>>> jnp.asarray(42)
Array(42, dtype=int32, weak_type=True)
>>> jnp.asarray(3.5)
Array(3.5, dtype=float32, weak_type=True)
>>> jnp.asarray(1 + 1j)
Array(1.+1.j, dtype=complex64, weak_type=True)
Constructing JAX arrays from Python collections:
>>> jnp.asarray([1, 2, 3])  # list of ints -> 1D array
Array([1, 2, 3], dtype=int32)
>>> jnp.asarray([(1, 2, 3), (4, 5, 6)])  # list of tuples of ints -> 2D array
Array([[1, 2, 3],
       [4, 5, 6]], dtype=int32)
>>> jnp.asarray(range(5))
Array([0, 1, 2, 3, 4], dtype=int32)
Constructing JAX arrays from NumPy arrays:
>>> jnp.asarray(np.linspace(0, 2, 5))
Array([0. , 0.5, 1. , 1.5, 2. ], dtype=float32)
Constructing a JAX array via the Python buffer interface, using Python’s built-in array module.
>>> from array import array
>>> pybuffer = array('i', [2, 3, 5, 7])
>>> jnp.asarray(pybuffer)
Array([2, 3, 5, 7], dtype=int32)
- scico.numpy.astype(x, dtype, /, *, copy=False, device=None)¶
Convert an array to a specified dtype.
JAX implementation of numpy.astype.
This is implemented via jax.lax.convert_element_type, which may have slightly different behavior than numpy.astype in some cases. In particular, the details of float-to-int and int-to-float casts are implementation dependent.
- Parameters:
x (Union[Array,ndarray,bool,number,bool,int,float,complex]) – input array to convert
dtype (Union[str,type[Any],dtype,SupportsDType,None]) – output dtype
copy (bool) – if True, then always return a copy. If False (default) then only return a copy if necessary.
device (Device|Sharding|None) – optionally specify the device to which the output will be committed.
- Return type:
- Returns:
An array with the same shape as
x, containing values of the specified dtype.
See also
jax.lax.convert_element_type: lower-level function for XLA-style dtype conversions.
Examples
>>> x = jnp.array([0, 1, 2, 3]) >>> x Array([0, 1, 2, 3], dtype=int32) >>> x.astype('float32') Array([0.0, 1.0, 2.0, 3.0], dtype=float32)
>>> y = jnp.array([0.0, 0.5, 1.0]) >>> y.astype(int) # truncates fractional values Array([0, 0, 1], dtype=int32)
- scico.numpy.atleast_1d(*arys)¶
Convert inputs to arrays with at least 1 dimension.
JAX implementation of
numpy.atleast_1d.- Parameters:
arys – zero or more array-like arguments.
- Return type:
- Returns:
an array or list of arrays corresponding to the input values. Arrays of shape
()are converted to shape(1,), and arrays with other shapes are returned unchanged.
Examples
Scalar arguments are converted to 1D, length-1 arrays:
>>> x = jnp.float32(1.0) >>> jnp.atleast_1d(x) Array([1.], dtype=float32)
Higher dimensional inputs are returned unchanged:
>>> y = jnp.arange(4) >>> jnp.atleast_1d(y) Array([0, 1, 2, 3], dtype=int32)
Multiple arguments can be passed to the function at once, in which case a list of results is returned:
>>> jnp.atleast_1d(x, y) [Array([1.], dtype=float32), Array([0, 1, 2, 3], dtype=int32)]
- scico.numpy.atleast_2d(*arys)¶
Convert inputs to arrays with at least 2 dimensions.
JAX implementation of
numpy.atleast_2d.- Parameters:
arys – zero or more array-like arguments.
- Return type:
- Returns:
an array or list of arrays corresponding to the input values. Arrays of shape
()are converted to shape(1, 1), 1D arrays of shape(N,)are converted to shape(1, N), and arrays of all other shapes are returned unchanged.
Examples
Scalar arguments are converted to 2D, size-1 arrays:
>>> x = jnp.float32(1.0) >>> jnp.atleast_2d(x) Array([[1.]], dtype=float32)
One-dimensional arguments have a unit dimension prepended to the shape:
>>> y = jnp.arange(4) >>> jnp.atleast_2d(y) Array([[0, 1, 2, 3]], dtype=int32)
Higher dimensional inputs are returned unchanged:
>>> z = jnp.ones((2, 3)) >>> jnp.atleast_2d(z) Array([[1., 1., 1.], [1., 1., 1.]], dtype=float32)
Multiple arguments can be passed to the function at once, in which case a list of results is returned:
>>> jnp.atleast_2d(x, y) [Array([[1.]], dtype=float32), Array([[0, 1, 2, 3]], dtype=int32)]
- scico.numpy.atleast_3d(*arys)¶
Convert inputs to arrays with at least 3 dimensions.
JAX implementation of
numpy.atleast_3d.- Parameters:
arys – zero or more array-like arguments.
- Return type:
- Returns:
an array or list of arrays corresponding to the input values. Arrays of shape
()are converted to shape(1, 1, 1), 1D arrays of shape(N,)are converted to shape(1, N, 1), 2D arrays of shape(M, N)are converted to shape(M, N, 1), and arrays of all other shapes are returned unchanged.
Examples
Scalar arguments are converted to 3D, size-1 arrays:
>>> x = jnp.float32(1.0) >>> jnp.atleast_3d(x) Array([[[1.]]], dtype=float32)
1D arrays have a unit dimension prepended and appended:
>>> y = jnp.arange(4) >>> jnp.atleast_3d(y).shape (1, 4, 1)
2D arrays have a unit dimension appended:
>>> z = jnp.ones((2, 3)) >>> jnp.atleast_3d(z).shape (2, 3, 1)
Multiple arguments can be passed to the function at once, in which case a list of results is returned:
>>> x3, y3 = jnp.atleast_3d(x, y) >>> print(x3) [[[1.]]] >>> print(y3) [[[0] [1] [2] [3]]]
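The three `atleast_*d` functions differ only in where unit axes are inserted. As a minimal sketch, assuming `jax.numpy` is imported as `jnp` as in the examples above (and that `scico.numpy` behaves the same way on plain JAX arrays), the output shapes can be compared side by side:

```python
import jax.numpy as jnp

scalar = jnp.float32(1.0)   # shape ()
vector = jnp.arange(4)      # shape (4,)
matrix = jnp.ones((2, 3))   # shape (2, 3)

# Collect the output shape of each function for each input rank.
shapes = {
    name: [fn(arr).shape for arr in (scalar, vector, matrix)]
    for name, fn in [
        ("atleast_1d", jnp.atleast_1d),
        ("atleast_2d", jnp.atleast_2d),
        ("atleast_3d", jnp.atleast_3d),
    ]
}
for name, result in shapes.items():
    print(name, result)
```

Inputs that already have enough dimensions keep their shape in every case.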
- scico.numpy.average(a, axis=None, weights=None, returned=False, keepdims=False)¶
Compute the weighted average.
JAX implementation of
numpy.average.- Parameters:
a (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – array to be averagedaxis (
Union[int,Sequence[int],None]) – an optional integer or sequence of integers specifying the axis or axes along which the mean is to be computed. If not specified, the mean is computed over all axes.weights (
Union[Array,ndarray,bool,number,bool,int,float,complex,None]) – an optional array of weights for a weighted average. This must either exactly match the shape of a, or if axis is specified, it must have shapea.shape[axis]for a single axis, or shapetuple(a.shape[ax] for ax in axis)for multiple axes.returned (
bool) – If False (default) then return only the average. If True then return both the average and the normalization factor (i.e. the sum of weights).keepdims (
bool) – If True, reduced axes are left in the result with size 1. If False (default) then reduced axes are squeezed out.
- Return type:
- Returns:
An array
averageor tuple of arrays(average, normalization)ifreturnedis True.
See also
jax.numpy.mean: unweighted mean.
Examples
Simple average:
>>> x = jnp.array([1, 2, 3, 2, 4]) >>> jnp.average(x) Array(2.4, dtype=float32)
Weighted average:
>>> weights = jnp.array([2, 1, 3, 2, 2]) >>> jnp.average(x, weights=weights) Array(2.5, dtype=float32)
Use
returned=Trueto optionally return the normalization, i.e. the sum of weights:>>> jnp.average(x, returned=True) (Array(2.4, dtype=float32), Array(5., dtype=float32)) >>> jnp.average(x, weights=weights, returned=True) (Array(2.5, dtype=float32), Array(10., dtype=float32))
Weighted average along a specified axis:
>>> x = jnp.array([[8, 2, 7], ... [3, 6, 4]]) >>> weights = jnp.array([1, 2, 3]) >>> jnp.average(x, weights=weights, axis=1) Array([5.5, 4.5], dtype=float32)
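The weighted average is the sum of `weights * a` divided by the sum of the weights. The following sketch, assuming `jax.numpy` is imported as `jnp`, checks that relationship by hand and also shows `keepdims=True` along an axis:

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3, 2, 4])
weights = jnp.array([2, 1, 3, 2, 2])

# Weighted average written out: sum(w_i * x_i) / sum(w_i).
manual = jnp.sum(weights * x) / jnp.sum(weights)
avg, norm = jnp.average(x, weights=weights, returned=True)

# Along axis 1 with keepdims=True the reduced axis is kept with size 1.
x2 = jnp.array([[8, 2, 7],
                [3, 6, 4]])
w2 = jnp.array([1, 2, 3])
row_avg = jnp.average(x2, axis=1, weights=w2, keepdims=True)
print(manual, avg, norm, row_avg.shape)
```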
- scico.numpy.bartlett(M)¶
Return a Bartlett window of size M.
JAX implementation of
numpy.bartlett.- Parameters:
M (
int) – The window size.- Return type:
- Returns:
An array of size M containing the Bartlett window.
Examples
>>> with jnp.printoptions(precision=2, suppress=True): ... print(jnp.bartlett(4)) [0. 0.67 0.67 0. ]
See also
jax.numpy.blackman: return a Blackman window of size M.jax.numpy.hamming: return a Hamming window of size M.jax.numpy.hanning: return a Hanning window of size M.jax.numpy.kaiser: return a Kaiser window of size M.
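For reference, the Bartlett window is the triangular window usually written as w[n] = 1 - |2n/(M-1) - 1|. The helper `bartlett_reference` below is a hypothetical name used only for illustration; it is a sketch of that textbook formula (assuming `M > 1`), not the library's implementation:

```python
import jax.numpy as jnp

def bartlett_reference(M):
    """Triangular window w[n] = 1 - |2n/(M-1) - 1| for n = 0..M-1 (assumes M > 1)."""
    n = jnp.arange(M)
    return 1.0 - jnp.abs(2.0 * n / (M - 1) - 1.0)

window = jnp.bartlett(5)
reference = bartlett_reference(5)
print(window)
print(reference)
```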
- scico.numpy.bincount(x, weights=None, minlength=0, *, length=None)¶
Count the number of occurrences of each value in an integer array.
JAX implementation of
numpy.bincount.For an array of non-negative integers
x, this function returns an arraycountsof sizex.max() + 1, such thatcounts[i]contains the number of occurrences of the valueiinx.The JAX version has a few differences from the NumPy version:
In NumPy, passing an array
xwith negative entries will result in an error. In JAX, negative values are clipped to zero.JAX adds an optional
length parameter which can be used to statically specify the length of the output array so that this function can be used with transformations like jax.jit. In this case, values greater than or equal to length will be dropped.
- Parameters:
x (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – 1-dimensional array of non-negative integersweights (
Union[Array,ndarray,bool,number,bool,int,float,complex,None]) – optional array of weights associated withx. If not specified, the weight for each entry will be1.minlength (
int) – the minimum length of the output counts array.length (
int|None) – the length of the output counts array. Must be specified statically forbincountto be used withjax.jitand other JAX transformations.
- Return type:
- Returns:
An array of counts or summed weights reflecting the number of occurrences of values in
x.
Examples
Basic bincount:
>>> x = jnp.array([1, 1, 2, 3, 3, 3]) >>> jnp.bincount(x) Array([0, 2, 1, 3], dtype=int32)
Weighted bincount:
>>> weights = jnp.array([1, 2, 3, 4, 5, 6]) >>> jnp.bincount(x, weights) Array([ 0, 3, 3, 15], dtype=int32)
Specifying a static
lengthmakes this jit-compatible:>>> jit_bincount = jax.jit(jnp.bincount, static_argnames=['length']) >>> jit_bincount(x, length=5) Array([0, 2, 1, 3, 0], dtype=int32)
Any negative numbers are clipped to the first bin, and numbers beyond the specified
lengthare dropped:>>> x = jnp.array([-1, -1, 1, 3, 10]) >>> jnp.bincount(x, length=5) Array([2, 1, 0, 1, 0], dtype=int32)
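The `minlength`, `weights`, and `length` arguments are easy to confuse. The sketch below, assuming `jax` and `jax.numpy` are imported as shown, contrasts padding with `minlength`, summing weights per bin, and fixing the output size with a static `length`:

```python
import jax
import jax.numpy as jnp

x = jnp.array([1, 1, 2, 3, 3, 3])

# minlength pads the output with empty bins but never truncates it.
padded = jnp.bincount(x, minlength=6)

# With weights, each bin holds the sum of the weights of its entries.
w = jnp.array([0.5, 0.5, 1.0, 2.0, 2.0, 2.0])
weighted = jnp.bincount(x, weights=w)

# A static length fixes the output shape, so the call can be jit-compiled.
jit_bincount = jax.jit(jnp.bincount, static_argnames=["length"])
fixed = jit_bincount(x, length=5)
print(padded, weighted, fixed)
```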
- scico.numpy.blackman(M)¶
Return a Blackman window of size M.
JAX implementation of
numpy.blackman.- Parameters:
M (
int) – The window size.- Return type:
- Returns:
An array of size M containing the Blackman window.
Examples
>>> with jnp.printoptions(precision=2, suppress=True): ... print(jnp.blackman(4)) [-0. 0.63 0.63 -0. ]
See also
jax.numpy.bartlett: return a Bartlett window of size M.jax.numpy.hamming: return a Hamming window of size M.jax.numpy.hanning: return a Hanning window of size M.jax.numpy.kaiser: return a Kaiser window of size M.
- scico.numpy.block(arrays)¶
Create an array from a list of blocks.
JAX implementation of
numpy.block.- Parameters:
arrays (
Union[Array,ndarray,bool,number,bool,int,float,complex,list[Union[Array,ndarray,bool,number,bool,int,float,complex]]]) – an array, or nested list of arrays which will be concatenated together to form the final array.- Return type:
- Returns:
a single array constructed from the inputs.
Examples
Consider these blocks:
>>> zeros = jnp.zeros((2, 2)) >>> ones = jnp.ones((2, 2)) >>> twos = jnp.full((2, 2), 2) >>> threes = jnp.full((2, 2), 3)
Passing a single array to
blockreturns the array:>>> jnp.block(zeros) Array([[0., 0.], [0., 0.]], dtype=float32)
Passing a simple list of arrays concatenates them along the last axis:
>>> jnp.block([zeros, ones]) Array([[0., 0., 1., 1.], [0., 0., 1., 1.]], dtype=float32)
Passing a doubly-nested list of arrays concatenates the inner list along the last axis, and the outer list along the second-to-last axis:
>>> jnp.block([[zeros, ones], ... [twos, threes]]) Array([[0., 0., 1., 1.], [0., 0., 1., 1.], [2., 2., 3., 3.], [2., 2., 3., 3.]], dtype=float32)
Note that blocks need not align in all dimensions, though the size along the axis of concatenation must match. For example, this is valid because after the inner, horizontal concatenation, the resulting blocks have a valid shape for the outer, vertical concatenation.
>>> a = jnp.zeros((2, 1)) >>> b = jnp.ones((2, 3)) >>> c = jnp.full((1, 2), 2) >>> d = jnp.full((1, 2), 3) >>> jnp.block([[a, b], [c, d]]) Array([[0., 1., 1., 1.], [0., 1., 1., 1.], [2., 2., 3., 3.]], dtype=float32)
Note also that this logic generalizes to blocks in 3 or more dimensions. Here’s a 3-dimensional block-wise array:
>>> x = jnp.arange(6).reshape((1, 2, 3)) >>> blocks = [[[x for i in range(3)] for j in range(4)] for k in range(5)] >>> jnp.block(blocks).shape (5, 8, 9)
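The nesting rule for two levels can be checked against explicit concatenation. This is a sketch assuming `jax.numpy` is imported as `jnp`; it rebuilds the non-aligned example above by hand:

```python
import jax.numpy as jnp

a = jnp.zeros((2, 1))
b = jnp.ones((2, 3))
c = jnp.full((1, 2), 2.0)
d = jnp.full((1, 2), 3.0)

blocked = jnp.block([[a, b], [c, d]])

# Same result built by hand: inner lists join along the last axis,
# the outer list joins along the second-to-last axis.
top = jnp.concatenate([a, b], axis=-1)      # shape (2, 4)
bottom = jnp.concatenate([c, d], axis=-1)   # shape (1, 4)
manual = jnp.concatenate([top, bottom], axis=-2)
print(blocked.shape)
```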
- scico.numpy.broadcast_arrays(*args)¶
Broadcast arrays to a common shape.
JAX implementation of
numpy.broadcast_arrays. JAX uses NumPy-style broadcasting rules, which you can read more about at NumPy broadcasting.- Parameters:
args (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – zero or more array-like objects to be broadcasted.- Return type:
- Returns:
a list of arrays containing broadcasted copies of the inputs.
See also
jax.numpy.broadcast_shapes: broadcast input shapes to a common shape.jax.numpy.broadcast_to: broadcast an array to a specified shape.
Examples
>>> x = jnp.arange(3) >>> y = jnp.int32(1) >>> jnp.broadcast_arrays(x, y) [Array([0, 1, 2], dtype=int32), Array([1, 1, 1], dtype=int32)]
>>> x = jnp.array([[1, 2, 3]]) >>> y = jnp.array([[10], ... [20]]) >>> x2, y2 = jnp.broadcast_arrays(x, y) >>> x2 Array([[1, 2, 3], [1, 2, 3]], dtype=int32) >>> y2 Array([[10, 10, 10], [20, 20, 20]], dtype=int32)
- scico.numpy.broadcast_shapes(*shapes)¶
Broadcast input shapes to a common output shape.
JAX implementation of
numpy.broadcast_shapes. JAX uses NumPy-style broadcasting rules, which you can read more about at NumPy broadcasting.- Parameters:
shapes – 0 or more shapes specified as sequences of integers
- Returns:
The broadcasted shape as a tuple of integers.
See also
jax.numpy.broadcast_arrays: broadcast arrays to a common shape.jax.numpy.broadcast_to: broadcast an array to a specified shape.
Examples
Some compatible shapes:
>>> jnp.broadcast_shapes((1,), (4,)) (4,) >>> jnp.broadcast_shapes((3, 1), (4,)) (3, 4) >>> jnp.broadcast_shapes((3, 1), (1, 4), (5, 1, 1)) (5, 3, 4)
Incompatible shapes:
>>> jnp.broadcast_shapes((3, 1), (4, 1)) Traceback (most recent call last): ValueError: Incompatible shapes for broadcasting: shapes=[(3, 1), (4, 1)]
- scico.numpy.broadcast_to(array, shape, *, out_sharding=None)¶
Broadcast an array to a specified shape.
JAX implementation of
numpy.broadcast_to. JAX uses NumPy-style broadcasting rules, which you can read more about at NumPy broadcasting.- Parameters:
- Return type:
- Returns:
a copy of array broadcast to the specified shape.
See also
jax.numpy.broadcast_arrays: broadcast arrays to a common shape.jax.numpy.broadcast_shapes: broadcast input shapes to a common shape.
Examples
>>> x = jnp.int32(1) >>> jnp.broadcast_to(x, (1, 4)) Array([[1, 1, 1, 1]], dtype=int32)
>>> x = jnp.array([1, 2, 3]) >>> jnp.broadcast_to(x, (2, 3)) Array([[1, 2, 3], [1, 2, 3]], dtype=int32)
>>> x = jnp.array([[2], [4]]) >>> jnp.broadcast_to(x, (2, 4)) Array([[2, 2, 2, 2], [4, 4, 4, 4]], dtype=int32)
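The three broadcasting helpers fit together: `broadcast_shapes` computes the common shape, `broadcast_to` applies it to one array, and `broadcast_arrays` does both for several arrays. A minimal sketch, assuming `jax.numpy` is imported as `jnp`:

```python
import jax.numpy as jnp

x = jnp.array([[1, 2, 3]])     # shape (1, 3)
y = jnp.array([[10], [20]])    # shape (2, 1)

# Compute the common shape first, then broadcast each input to it.
common = jnp.broadcast_shapes(x.shape, y.shape)
x_b = jnp.broadcast_to(x, common)
y_b = jnp.broadcast_to(y, common)

# broadcast_arrays performs both steps in one call.
x2, y2 = jnp.broadcast_arrays(x, y)
print(common)
```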
- scico.numpy.cbrt(x, /)¶
Calculates element-wise cube root of the input array.
JAX implementation of
numpy.cbrt.- Parameters:
x (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – input array or scalar.complexdtypes are not supported.- Return type:
- Returns:
An array containing the cube root of the elements of
x.
See also
jax.numpy.sqrt: Calculates the element-wise non-negative square root of the input.jax.numpy.square: Calculates the element-wise square of the input.
Examples
>>> x = jnp.array([[216, 125, 64], ... [-27, -8, -1]]) >>> with jnp.printoptions(precision=3, suppress=True): ... jnp.cbrt(x) Array([[ 6., 5., 4.], [-3., -2., -1.]], dtype=float32)
- scico.numpy.ceil(x, /)¶
Round input to the nearest integer upwards.
JAX implementation of
numpy.ceil.- Parameters:
x (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – input array or scalar. Must not have complex dtype.- Return type:
- Returns:
An array with same shape and dtype as
xcontaining the values rounded to the nearest integer that is greater than or equal to the value itself.
See also
jax.numpy.fix: Rounds the input to the nearest integer towards zero.jax.numpy.trunc: Rounds the input to the nearest integer towards zero.jax.numpy.floor: Rounds the input down to the nearest integer.
Examples
>>> key = jax.random.key(1) >>> x = jax.random.uniform(key, (3, 3), minval=-5, maxval=5) >>> with jnp.printoptions(precision=2, suppress=True): ... print(x) [[-0.61 0.34 -0.54] [-0.62 3.97 0.59] [ 4.84 3.42 -1.14]] >>> jnp.ceil(x) Array([[-0., 1., -0.], [-0., 4., 1.], [ 5., 4., -1.]], dtype=float32)
- scico.numpy.choose(a, choices, out=None, mode='raise')¶
Construct an array by stacking slices of choice arrays.
JAX implementation of
numpy.choose.The semantics of this function can be confusing, but in the simplest case where
ais a one-dimensional array,choicesis a two-dimensional array, and all entries ofaare in-bounds (i.e.0 <= a_i < len(choices)), then the function is equivalent to the following:def choose(a, choices): return jnp.array([choices[a_i, i] for i, a_i in enumerate(a)])
In the more general case,
amay have any number of dimensions andchoicesmay be an arbitrary sequence of broadcast-compatible arrays. In this case, again for in-bound indices, the logic is equivalent to:def choose(a, choices): a, *choices = jnp.broadcast_arrays(a, *choices) choices = jnp.array(choices) return jnp.array([choices[a[idx], *idx] for idx in np.ndindex(a.shape)])
The only additional complexity comes from the
modeargument, which controls the behavior for out-of-bound indices inaas described below.- Parameters:
a (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – an N-dimensional array of integer indices.choices (
Array|ndarray|Sequence[Union[Array,ndarray,bool,number,bool,int,float,complex]]) – an array or sequence of arrays. All arrays in the sequence must be mutually broadcast compatible witha.out (
None) – unused by JAXmode (
str) – specify the out-of-bounds indexing mode; one of'raise'(default),'wrap', or'clip'. Note that the default mode of'raise'is not compatible with JAX transformations.
- Return type:
- Returns:
an array containing stacked slices from
choicesat the indices specified bya. The shape of the result isbroadcast_shapes(a.shape, *(c.shape for c in choices)).
See also
jax.lax.switch: choose between N functions based on an index.
Examples
Here is the simplest case of a 1D index array with a 2D choice array, in which case this chooses the indexed value from each column:
>>> choices = jnp.array([[ 1, 2, 3, 4], ... [ 5, 6, 7, 8], ... [ 9, 10, 11, 12]]) >>> a = jnp.array([2, 0, 1, 0]) >>> jnp.choose(a, choices) Array([9, 2, 7, 4], dtype=int32)
The
modeargument specifies what to do with out-of-bound indices; options are to eitherwraporclip:>>> a2 = jnp.array([2, 0, 1, 4]) # last index out-of-bound >>> jnp.choose(a2, choices, mode='clip') Array([ 9, 2, 7, 12], dtype=int32) >>> jnp.choose(a2, choices, mode='wrap') Array([9, 2, 7, 8], dtype=int32)
In the more general case,
choicesmay be a sequence of array-like objects with any broadcast-compatible shapes.>>> choice_1 = jnp.array([1, 2, 3, 4]) >>> choice_2 = 99 >>> choice_3 = jnp.array([[10], ... [20], ... [30]]) >>> a = jnp.array([[0, 1, 2, 0], ... [1, 2, 0, 1], ... [2, 0, 1, 2]]) >>> jnp.choose(a, [choice_1, choice_2, choice_3], mode='wrap') Array([[ 1, 99, 10, 4], [99, 20, 3, 99], [30, 2, 99, 30]], dtype=int32)
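The simplest-case equivalence described above can be written as a runnable check. `choose_reference` is a hypothetical helper name for this sketch; it covers only a 1D index array, a 2D `choices` array, and in-bounds indices:

```python
import jax.numpy as jnp

def choose_reference(a, choices):
    """Reference for 1D `a` and 2D `choices` with in-bounds indices."""
    return jnp.array([choices[a_i, i] for i, a_i in enumerate(a)])

choices = jnp.array([[1, 2, 3, 4],
                     [5, 6, 7, 8],
                     [9, 10, 11, 12]])
a = jnp.array([2, 0, 1, 0])

expected = choose_reference(a, choices)
# mode="wrap" is used instead of the default "raise", which is not
# compatible with JAX transformations; all indices here are in bounds.
actual = jnp.choose(a, choices, mode="wrap")
print(expected, actual)
```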
- scico.numpy.clip(arr=None, /, min=None, max=None)¶
Clip array values to a specified range.
JAX implementation of
numpy.clip.- Parameters:
arr (
Union[Array,ndarray,bool,number,bool,int,float,complex,None]) – N-dimensional array to be clipped.min (
Union[Array,ndarray,bool,number,bool,int,float,complex,None]) – optional minimum value of the clipped range; ifNone(default) then result will not be clipped to any minimum value. If specified, it should be broadcast-compatible witharrandmax.max (
Union[Array,ndarray,bool,number,bool,int,float,complex,None]) – optional maximum value of the clipped range; ifNone(default) then result will not be clipped to any maximum value. If specified, it should be broadcast-compatible witharrandmin.
- Return type:
- Returns:
An array containing values from
arr, with values smaller thanminset tomin, and values larger thanmaxset tomax. Whereverminis larger thanmax, the value ofmaxis returned.
See also
jax.numpy.minimum: Compute the element-wise minimum value of two arrays.jax.numpy.maximum: Compute the element-wise maximum value of two arrays.
Examples
>>> arr = jnp.array([0, 1, 2, 3, 4, 5, 6, 7]) >>> jnp.clip(arr, 2, 5) Array([2, 2, 2, 3, 4, 5, 5, 5], dtype=int32)
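Clipping can be understood as an element-wise maximum followed by an element-wise minimum. The sketch below, assuming `jax.numpy` is imported as `jnp`, checks this and shows one-sided clipping by passing `None` for the unused bound:

```python
import jax.numpy as jnp

arr = jnp.array([0, 1, 2, 3, 4, 5, 6, 7])

clipped = jnp.clip(arr, 2, 5)
# Same result from the element-wise functions listed under "See also".
manual = jnp.minimum(jnp.maximum(arr, 2), 5)

# One-sided clipping: pass None for the bound that should not apply.
lower_only = jnp.clip(arr, 3, None)
upper_only = jnp.clip(arr, None, 3)
print(clipped, lower_only, upper_only)
```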
- scico.numpy.column_stack(tup)¶
Stack arrays column-wise.
JAX implementation of
numpy.column_stack.For arrays of two or more dimensions, this is equivalent to
jax.numpy.concatenatewithaxis=1.- Parameters:
tup (
ndarray|Array|Sequence[Union[Array,ndarray,bool,number,bool,int,float,complex]]) – a sequence of arrays to stack; each must have the same leading dimension. Input arrays will be promoted to at least rank 2. If a single array is given it will be treated equivalently to tup = unstack(tup), but the implementation will avoid explicit unstacking.dtype – optional dtype of the resulting array. If not specified, the dtype will be determined via type promotion rules described in Type promotion semantics.
- Return type:
- Returns:
the stacked result.
See also
jax.numpy.stack: stack along arbitrary axesjax.numpy.concatenate: concatenation along existing axes.jax.numpy.vstack: stack vertically, i.e. along axis 0.jax.numpy.hstack: stack horizontally, i.e. along axis 1.jax.numpy.dstack: stack depth-wise, i.e. along axis 2.
Examples
Scalar values:
>>> jnp.column_stack([1, 2, 3]) Array([[1, 2, 3]], dtype=int32, weak_type=True)
1D arrays:
>>> x = jnp.arange(3) >>> y = jnp.ones(3) >>> jnp.column_stack([x, y]) Array([[0., 1.], [1., 1.], [2., 1.]], dtype=float32)
2D arrays:
>>> x = x.reshape(3, 1) >>> y = y.reshape(3, 1) >>> jnp.column_stack([x, y]) Array([[0., 1.], [1., 1.], [2., 1.]], dtype=float32)
- scico.numpy.compress(condition, a, axis=None, *, size=None, fill_value=0, out=None)¶
Compress an array along a given axis using a boolean condition.
JAX implementation of
numpy.compress.- Parameters:
condition (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – 1-dimensional array of conditions. Will be converted to boolean.a (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – N-dimensional array of values.axis (
int|None) – axis along which to compress. If None (default) thenawill be flattened, and axis will be set to 0.size (
int|None) – optional static size for output. Must be specified in order forcompressto be compatible with JAX transformations likejitorvmap.fill_value (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – ifsizeis specified, fill padded entries with this value (default: 0).out (
None) – not implemented by JAX.
- Return type:
- Returns:
An array of dimension
a.ndim, compressed along the specified axis.
See also
jax.numpy.extract: 1D version ofcompress.jax.Array.compress: equivalent functionality as an array method.
Notes
This function does not require strict shape agreement between
conditionanda. Ifcondition.size > a.shape[axis], thenconditionwill be truncated, and ifa.shape[axis] > condition.size, thenawill be truncated.Examples
Compressing along the rows of a 2D array:
>>> a = jnp.array([[1, 2, 3, 4], ... [5, 6, 7, 8], ... [9, 10, 11, 12]]) >>> condition = jnp.array([True, False, True]) >>> jnp.compress(condition, a, axis=0) Array([[ 1, 2, 3, 4], [ 9, 10, 11, 12]], dtype=int32)
For convenience, you can equivalently use the
compressmethod of JAX arrays:>>> a.compress(condition, axis=0) Array([[ 1, 2, 3, 4], [ 9, 10, 11, 12]], dtype=int32)
Note that the condition need not match the shape of the specified axis; here we compress the columns with the length-3 condition. Values beyond the size of the condition are ignored:
>>> jnp.compress(condition, a, axis=1) Array([[ 1, 3], [ 5, 7], [ 9, 11]], dtype=int32)
The optional
sizeargument lets you specify a static output size so that the output is statically-shaped, and so this function can be used with transformations likejitandvmap:>>> f = lambda c, a: jnp.extract(c, a, size=len(a), fill_value=0) >>> mask = (a % 3 == 0) >>> jax.vmap(f)(mask, a) Array([[ 3, 0, 0, 0], [ 6, 0, 0, 0], [ 9, 12, 0, 0]], dtype=int32)
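The same static-size idea applies to `compress` directly. The following is a sketch, assuming `jax` and `jax.numpy` are imported as shown; `keep_rows` is a hypothetical helper name, and the expected padding behavior follows the description of `size` and `fill_value` above:

```python
import jax
import jax.numpy as jnp

a = jnp.array([[1, 2, 3, 4],
               [5, 6, 7, 8],
               [9, 10, 11, 12]])
condition = jnp.array([True, False, True])

@jax.jit
def keep_rows(cond, arr):
    # size=3 fixes the output to three rows; rows beyond the number of
    # selected entries are filled with fill_value.
    return jnp.compress(cond, arr, axis=0, size=3, fill_value=-1)

padded = keep_rows(condition, a)
print(padded)
```

Because the output shape no longer depends on the values in `cond`, the function can be traced by `jax.jit`.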
- scico.numpy.concat(arrays, /, *, axis=0)¶
Join arrays along an existing axis.
JAX implementation of
array_api.concat.- Parameters:
arrays (
Sequence[Union[Array,ndarray,bool,number,bool,int,float,complex]]) – a sequence of arrays to concatenate; each must have the same shape except along the specified axis. If a single array is given it will be treated equivalently to arrays = unstack(arrays), but the implementation will avoid explicit unstacking.axis (
int|None) – specify the axis along which to concatenate.
- Return type:
- Returns:
the concatenated result.
See also
jax.lax.concatenate: XLA concatenation API.jax.numpy.concatenate: NumPy version of this function.jax.numpy.stack: concatenate arrays along a new axis.
Examples
One-dimensional concatenation:
>>> x = jnp.arange(3) >>> y = jnp.zeros(3, dtype=int) >>> jnp.concat([x, y]) Array([0, 1, 2, 0, 0, 0], dtype=int32)
Two-dimensional concatenation:
>>> x = jnp.ones((2, 3)) >>> y = jnp.zeros((2, 1)) >>> jnp.concat([x, y], axis=1) Array([[1., 1., 1., 0.], [1., 1., 1., 0.]], dtype=float32)
- scico.numpy.concatenate(arrays, axis=0, dtype=None)¶
Join arrays along an existing axis.
JAX implementation of
numpy.concatenate.- Parameters:
arrays (
ndarray|Array|Sequence[Union[Array,ndarray,bool,number,bool,int,float,complex]]) – a sequence of arrays to concatenate; each must have the same shape except along the specified axis. If a single array is given it will be treated equivalently to arrays = unstack(arrays), but the implementation will avoid explicit unstacking.axis (
int|None) – specify the axis along which to concatenate. If None, the arrays are flattened before concatenation.dtype (
Union[str,type[Any],dtype,SupportsDType,None]) – optional dtype of the resulting array. If not specified, the dtype will be determined via type promotion rules described in Type promotion semantics.
- Return type:
- Returns:
the concatenated result.
See also
jax.lax.concatenate: XLA concatenation API.jax.numpy.concat: Array API version of this function.jax.numpy.stack: concatenate arrays along a new axis.
Examples
One-dimensional concatenation:
>>> x = jnp.arange(3) >>> y = jnp.zeros(3, dtype=int) >>> jnp.concatenate([x, y]) Array([0, 1, 2, 0, 0, 0], dtype=int32)
Two-dimensional concatenation:
>>> x = jnp.ones((2, 3)) >>> y = jnp.zeros((2, 1)) >>> jnp.concatenate([x, y], axis=1) Array([[1., 1., 1., 0.], [1., 1., 1., 0.]], dtype=float32)
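Two parameters described above are not covered by the examples: `axis=None`, which flattens the inputs first, and `dtype`, which sets the result type explicitly. A minimal sketch, assuming `jax.numpy` is imported as `jnp`:

```python
import jax.numpy as jnp

x = jnp.ones((2, 3))
y = jnp.zeros((2, 1))

# axis=None flattens every input before joining, so shapes need not match.
flat = jnp.concatenate([x, y], axis=None)
manual = jnp.concatenate([x.ravel(), y.ravel()])

# dtype overrides the type that promotion would otherwise select.
as_float = jnp.concatenate([jnp.arange(3), jnp.arange(2)], dtype=jnp.float32)
print(flat, as_float)
```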
- scico.numpy.conj(x, /)¶
Alias of
jax.numpy.conjugate- Return type:
- scico.numpy.conjugate(x, /)¶
Return element-wise complex-conjugate of the input.
JAX implementation of
numpy.conjugate.- Parameters:
x (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – input array or scalar.- Return type:
- Returns:
An array containing the complex-conjugate of
x.
See also
jax.numpy.real: Returns the element-wise real part of the complex argument.jax.numpy.imag: Returns the element-wise imaginary part of the complex argument.
Examples
>>> jnp.conjugate(3) Array(3, dtype=int32, weak_type=True) >>> x = jnp.array([2-1j, 3+5j, 7]) >>> jnp.conjugate(x) Array([2.+1.j, 3.-5.j, 7.-0.j], dtype=complex64)
- scico.numpy.convolve(a, v, mode='full', *, precision=None, preferred_element_type=None)¶
Convolution of two one dimensional arrays.
JAX implementation of
numpy.convolve.Convolution of one dimensional arrays is defined as:
\[c_k = \sum_j a_{k - j} v_j\]- Parameters:
a (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – left-hand input to the convolution. Must havea.ndim == 1.v (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – right-hand input to the convolution. Must havev.ndim == 1.mode (
str) –controls the size of the output. Available operations are:
"full": (default) output the full convolution of the inputs. "same": return a centered portion of the "full" output which is the same size as a. "valid": return the portion of the "full" output which does not depend on padding at the array edges.
precision (
Union[None,str,Precision,tuple[str,str],tuple[Precision,Precision],DotAlgorithm,DotAlgorithmPreset]) – Specify the precision of the computation. Refer tojax.lax.Precisionfor a description of available values.preferred_element_type (
Union[str,type[Any],dtype,SupportsDType,None]) – A datatype, indicating to accumulate results to and return a result with that datatype. Default isNone, which means the default accumulation type for the input types.
- Return type:
- Returns:
Array containing the convolved result.
See also
jax.scipy.signal.convolve: ND convolutionjax.numpy.correlate: 1D correlation
Examples
A few 1D convolution examples:
>>> x = jnp.array([1, 2, 3, 2, 1]) >>> y = jnp.array([4, 1, 2])
jax.numpy.convolve, by default, returns full convolution using implicit zero-padding at the edges:>>> jnp.convolve(x, y) Array([ 4., 9., 16., 15., 12., 5., 2.], dtype=float32)
Specifying
mode = 'same'returns a centered convolution the same size as the first input:>>> jnp.convolve(x, y, mode='same') Array([ 9., 16., 15., 12., 5.], dtype=float32)
Specifying
mode = 'valid'returns only the portion where the two arrays fully overlap:>>> jnp.convolve(x, y, mode='valid') Array([16., 15., 12.], dtype=float32)
For complex-valued inputs:
>>> x1 = jnp.array([3+1j, 2, 4-3j]) >>> y1 = jnp.array([1, 2-3j, 4+5j]) >>> jnp.convolve(x1, y1) Array([ 3. +1.j, 11. -7.j, 15.+10.j, 7. -8.j, 31. +8.j], dtype=complex64)
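The defining sum c_k = sum_j a[k - j] * v[j] can be evaluated directly and compared with the library result. `convolve_full` is a hypothetical helper written for this sketch; it is a plain-Python reference for the `"full"` mode, not how JAX computes the convolution:

```python
import jax.numpy as jnp

def convolve_full(a, v):
    """Direct evaluation of c_k = sum_j a[k - j] * v[j] (full mode)."""
    n, m = len(a), len(v)
    out = [0.0] * (n + m - 1)
    for k in range(n + m - 1):
        for j in range(m):
            if 0 <= k - j < n:   # terms outside `a` are implicit zeros
                out[k] += a[k - j] * v[j]
    return out

x = [1, 2, 3, 2, 1]
y = [4, 1, 2]
direct = convolve_full(x, y)
library = jnp.convolve(jnp.array(x), jnp.array(y))
print(direct)
print(library)
```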
- scico.numpy.copy(a, order=None)¶
Return a copy of the array.
JAX implementation of
numpy.copy.- Parameters:
- Return type:
- Returns:
a copy of the input array
a.
See also
jax.numpy.array: create an array with or without a copy.jax.Array.copy: same function accessed as an array method.
Examples
Since JAX arrays are immutable, in most cases explicit array copies are not necessary. One exception is when using a function with donated arguments (see the
donate_argnumsargument tojax.jit).>>> f = jax.jit(lambda x: 2 * x, donate_argnums=0) >>> x = jnp.arange(4) >>> y = f(x) >>> print(y) [0 2 4 6]
Because we marked
xas being donated, the original array is no longer available:>>> print(x) Traceback (most recent call last): RuntimeError: Array has been deleted with shape=int32[4].
In situations like this, an explicit copy will let you keep access to the original buffer:
>>> x = jnp.arange(4) >>> y = f(x.copy()) >>> print(y) [0 2 4 6] >>> print(x) [0 1 2 3]
- scico.numpy.copysign(x1, x2, /)¶
Copies the sign of each element in
x2to the corresponding element inx1.JAX implementation of
numpy.copysign.- Parameters:
- Return type:
- Returns:
An array object containing the potentially changed elements of
x1, always promotes to inexact dtype, and has a shape ofjnp.broadcast_shapes(x1.shape, x2.shape)
Examples
>>> x1 = jnp.array([5, 2, 0]) >>> x2 = -1 >>> jnp.copysign(x1, x2) Array([-5., -2., -0.], dtype=float32)
>>> x1 = jnp.array([6, 8, 0]) >>> x2 = 2 >>> jnp.copysign(x1, x2) Array([6., 8., 0.], dtype=float32)
>>> x1 = jnp.array([2, -3]) >>> x2 = jnp.array([[1],[-4], [5]]) >>> jnp.copysign(x1, x2) Array([[ 2., 3.], [-2., -3.], [ 2., 3.]], dtype=float32)
- scico.numpy.cos(x, /)¶
Compute a trigonometric cosine of each element of input.
JAX implementation of
numpy.cos.- Parameters:
x (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – scalar or array. Angle in radians.- Return type:
- Returns:
An array containing the cosine of each element in
x, promotes to inexact dtype.
See also
jax.numpy.sin: Computes a trigonometric sine of each element of input.jax.numpy.tan: Computes a trigonometric tangent of each element of input.jax.numpy.arccosandjax.numpy.acos: Computes the inverse of trigonometric cosine of each element of input.
Examples
>>> pi = jnp.pi >>> x = jnp.array([pi/4, pi/2, 3*pi/4, 5*pi/6]) >>> with jnp.printoptions(precision=3, suppress=True): ... print(jnp.cos(x)) [ 0.707 -0. -0.707 -0.866]
- scico.numpy.cosh(x, /)¶
Calculate element-wise hyperbolic cosine of input.
JAX implementation of
numpy.cosh.The hyperbolic cosine is defined by:
\[\cosh(x) = \frac{e^x + e^{-x}}{2}\]- Parameters:
x (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – input array or scalar.- Return type:
- Returns:
An array containing the hyperbolic cosine of each element of
x, promoting to inexact dtype.
Note
jnp.coshis equivalent to computingjnp.cos(1j * x).See also
jax.numpy.sinh: Computes the element-wise hyperbolic sine of the input.jax.numpy.tanh: Computes the element-wise hyperbolic tangent of the input.jax.numpy.arccosh: Computes the element-wise inverse of hyperbolic cosine of the input.
Examples
>>> x = jnp.array([[3, -1, 0], ... [4, 7, -5]]) >>> with jnp.printoptions(precision=3, suppress=True): ... jnp.cosh(x) Array([[ 10.068, 1.543, 1. ], [ 27.308, 548.317, 74.21 ]], dtype=float32) >>> with jnp.printoptions(precision=3, suppress=True): ... jnp.cos(1j * x) Array([[ 10.068+0.j, 1.543+0.j, 1. +0.j], [ 27.308+0.j, 548.317+0.j, 74.21 +0.j]], dtype=complex64, weak_type=True)
For complex-valued input:
>>> with jnp.printoptions(precision=3, suppress=True): ... jnp.cosh(5+1j) Array(40.096+62.44j, dtype=complex64, weak_type=True) >>> with jnp.printoptions(precision=3, suppress=True): ... jnp.cos(1j * (5+1j)) Array(40.096+62.44j, dtype=complex64, weak_type=True)
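The definition given above can be checked numerically. This is a small sketch, assuming `jax.numpy` is imported as `jnp`, that evaluates the exponential form and compares it with `jnp.cosh`:

```python
import jax.numpy as jnp

x = jnp.array([3.0, -1.0, 0.0])

# cosh(x) = (e^x + e^-x) / 2, evaluated term by term.
from_definition = (jnp.exp(x) + jnp.exp(-x)) / 2
library = jnp.cosh(x)
print(from_definition, library)
```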
- scico.numpy.count_nonzero(a, axis=None, keepdims=False)¶
Return the number of nonzero elements along a given axis.
JAX implementation of
numpy.count_nonzero.- Parameters:
a (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – input array.axis (
Union[int,Sequence[int],None]) – optional, int or sequence of ints, default=None. Axis or axes along which nonzero elements are counted. If None, counts over the flattened array.keepdims (
bool) – bool, default=False. If true, reduced axes are left in the result with size 1.
- Return type:
- Returns:
An array with the number of nonzero elements along the specified axis of the input.
Examples
By default, jnp.count_nonzero counts the nonzero values along all axes.
>>> x = jnp.array([[1, 0, 0, 0],
...                [0, 0, 1, 0],
...                [1, 1, 1, 0]])
>>> jnp.count_nonzero(x)
Array(5, dtype=int32)
If axis=1, counts along axis 1.
>>> jnp.count_nonzero(x, axis=1)
Array([1, 1, 3], dtype=int32)
To preserve the dimensions of input, you can set keepdims=True.
>>> jnp.count_nonzero(x, axis=1, keepdims=True)
Array([[1],
       [1],
       [3]], dtype=int32)
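As a cross-check, counting nonzeros can be expressed as summing a boolean mask. This is only a sketch of an equivalent formulation, not a description of how JAX implements the function.

```python
import jax.numpy as jnp

x = jnp.array([[1, 0, 0, 0],
               [0, 0, 1, 0],
               [1, 1, 1, 0]])

# Comparing against zero yields a boolean mask; summing it counts the True entries.
manual = jnp.sum(x != 0, axis=1)
builtin = jnp.count_nonzero(x, axis=1)
print(manual, builtin)
```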
- scico.numpy.cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None)¶
Compute the (batched) cross product of two arrays.
JAX implementation of numpy.cross.
This computes the 2-dimensional or 3-dimensional cross product,
\[c = a \times b\]
In 3 dimensions, c is a length-3 array. In 2 dimensions, c is a scalar.
- Parameters:
a – N-dimensional array. a.shape[axisa] indicates the dimension of the cross product, and must be 2 or 3.
b – N-dimensional array. Must have b.shape[axisb] == a.shape[axisa], and other dimensions of a and b must be broadcast compatible.
axisa (int) – specify the axis of a along which to compute the cross product.
axisb (int) – specify the axis of b along which to compute the cross product.
axisc (int) – specify the axis of c along which the cross product result will be stored.
axis (int | None) – if specified, this overrides axisa, axisb, and axisc with a single value.
- Returns:
The array c containing the (batched) cross product of a and b along the specified axes.
See also
jax.numpy.linalg.cross: an array API compatible function for computing cross products over 3-vectors.
Examples
A 2-dimensional cross product returns a scalar:
>>> a = jnp.array([1, 2])
>>> b = jnp.array([3, 4])
>>> jnp.cross(a, b)
Array(-2, dtype=int32)
A 3-dimensional cross product returns a length-3 vector:
>>> a = jnp.array([1, 2, 3])
>>> b = jnp.array([4, 5, 6])
>>> jnp.cross(a, b)
Array([-3,  6, -3], dtype=int32)
With multi-dimensional inputs, the cross-product is computed along the last axis by default. Here’s a batched 3-dimensional cross product, operating on the rows of the inputs:
>>> a = jnp.array([[1, 2, 3],
...                [3, 4, 3]])
>>> b = jnp.array([[2, 3, 2],
...                [4, 5, 6]])
>>> jnp.cross(a, b)
Array([[-5,  4, -1],
       [ 9, -6, -1]], dtype=int32)
Specifying axis=0 makes this a batched 2-dimensional cross product, operating on the columns of the inputs:
>>> jnp.cross(a, b, axis=0)
Array([-2, -2, 12], dtype=int32)
Equivalently, we can independently specify the axis of the inputs a and b and the output c:
>>> jnp.cross(a, b, axisa=0, axisb=0, axisc=0)
Array([-2, -2, 12], dtype=int32)
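Two standard properties of the 3-dimensional cross product make a convenient sanity check: the result is orthogonal to both inputs, and swapping the inputs flips the sign. The sketch below verifies both on the batched inputs from the example above; the variable names are illustrative.

```python
import jax.numpy as jnp

a = jnp.array([[1, 2, 3],
               [3, 4, 3]])
b = jnp.array([[2, 3, 2],
               [4, 5, 6]])
c = jnp.cross(a, b)  # one 3-vector per row

# Row-wise dot products of c with each input; all should be zero.
c_dot_a = jnp.sum(c * a, axis=-1)
c_dot_b = jnp.sum(c * b, axis=-1)

# Anticommutativity: b x a == -(a x b).
anticommutes = bool(jnp.array_equal(jnp.cross(b, a), -c))
print(c_dot_a, c_dot_b, anticommutes)
```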
- scico.numpy.cumprod(a, axis=None, dtype=None, out=None)¶
Cumulative product of elements along an axis.
JAX implementation of numpy.cumprod.
- Parameters:
a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – N-dimensional array to be accumulated.
axis (int | None) – integer axis along which to accumulate. If None (default), then array will be flattened and accumulated along the flattened axis.
dtype (Union[str, type[Any], dtype, SupportsDType, None]) – optionally specify the dtype of the output. If not specified, then the output dtype will match the input dtype.
out (None) – unused by JAX
- Return type:
- Returns:
An array containing the accumulated product along the given axis.
See also
jax.numpy.multiply.accumulate: cumulative product via ufunc methods.
jax.numpy.nancumprod: cumulative product ignoring NaN values.
jax.numpy.prod: product along axis
Examples
>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> jnp.cumprod(x)  # flattened cumulative product
Array([  1,   2,   6,  24, 120, 720], dtype=int32)
>>> jnp.cumprod(x, axis=1)  # cumulative product along axis 1
Array([[  1,   2,   6],
       [  4,  20, 120]], dtype=int32)
- scico.numpy.cumsum(a, axis=None, dtype=None, out=None)¶
Cumulative sum of elements along an axis.
JAX implementation of numpy.cumsum.
- Parameters:
a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – N-dimensional array to be accumulated.
axis (int | None) – integer axis along which to accumulate. If None (default), then array will be flattened and accumulated along the flattened axis.
dtype (Union[str, type[Any], dtype, SupportsDType, None]) – optionally specify the dtype of the output. If not specified, then the output dtype will match the input dtype.
out (None) – unused by JAX
- Return type:
- Returns:
An array containing the accumulated sum along the given axis.
See also
jax.numpy.cumulative_sum: cumulative sum via the array API standard.
jax.numpy.add.accumulate: cumulative sum via ufunc methods.
jax.numpy.nancumsum: cumulative sum ignoring NaN values.
jax.numpy.sum: sum along axis
Examples
>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> jnp.cumsum(x)  # flattened cumulative sum
Array([ 1,  3,  6, 10, 15, 21], dtype=int32)
>>> jnp.cumsum(x, axis=1)  # cumulative sum along axis 1
Array([[ 1,  3,  6],
       [ 4,  9, 15]], dtype=int32)
- scico.numpy.cumulative_prod(x, /, *, axis=None, dtype=None, include_initial=False)¶
Cumulative product along the axis of an array.
JAX implementation of numpy.cumulative_prod.
- Parameters:
x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – N-dimensional array
axis (int | None) – integer axis along which to accumulate. If x is one-dimensional, this argument is optional and defaults to zero.
dtype (Union[str, type[Any], dtype, SupportsDType, None]) – optional dtype of the output.
include_initial (bool) – if True, then include the initial value in the cumulative product. Default is False.
- Return type:
- Returns:
An array containing the accumulated values.
See also
jax.numpy.cumprod: alternative API for cumulative product.
jax.numpy.nancumprod: cumulative product while ignoring NaN values.
jax.numpy.multiply.accumulate: cumulative product via the ufunc API.
Examples
>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> jnp.cumulative_prod(x, axis=1)
Array([[  1,   2,   6],
       [  4,  20, 120]], dtype=int32)
>>> jnp.cumulative_prod(x, axis=1, include_initial=True)
Array([[  1,   1,   2,   6],
       [  1,   4,  20, 120]], dtype=int32)
- scico.numpy.cumulative_sum(x, /, *, axis=None, dtype=None, include_initial=False)¶
Cumulative sum along the axis of an array.
JAX implementation of numpy.cumulative_sum.
- Parameters:
x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – N-dimensional array
axis (int | None) – integer axis along which to accumulate. If x is one-dimensional, this argument is optional and defaults to zero.
dtype (Union[str, type[Any], dtype, SupportsDType, None]) – optional dtype of the output.
include_initial (bool) – if True, then include the initial value in the cumulative sum. Default is False.
- Return type:
- Returns:
An array containing the accumulated values.
See also
jax.numpy.cumsum: alternative API for cumulative sum.
jax.numpy.nancumsum: cumulative sum while ignoring NaN values.
jax.numpy.add.accumulate: cumulative sum via the ufunc API.
Examples
>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> jnp.cumulative_sum(x, axis=1)
Array([[ 1,  3,  6],
       [ 4,  9, 15]], dtype=int32)
>>> jnp.cumulative_sum(x, axis=1, include_initial=True)
Array([[ 0,  1,  3,  6],
       [ 0,  4,  9, 15]], dtype=int32)
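The effect of include_initial=True can be reproduced with jnp.cumsum by prepending a column of zeros along the accumulation axis. The following sketch assumes a JAX version that provides jnp.cumulative_sum, as documented in this entry; the helper variable names are illustrative.

```python
import jax.numpy as jnp

x = jnp.array([[1, 2, 3],
               [4, 5, 6]])

with_initial = jnp.cumulative_sum(x, axis=1, include_initial=True)

# Manual equivalent: a leading zero (the additive identity) followed by cumsum.
zeros = jnp.zeros((x.shape[0], 1), dtype=x.dtype)
manual = jnp.concatenate([zeros, jnp.cumsum(x, axis=1)], axis=1)

same = bool(jnp.array_equal(with_initial, manual))
print(same)
```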
- scico.numpy.deg2rad(x, /)¶
Convert angles from degrees to radians.
JAX implementation of numpy.deg2rad.
The angle in degrees is converted to radians by:
\[\mathrm{deg2rad}(x) = x \cdot \frac{\pi}{180}\]
- Parameters:
x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – scalar or array. Specifies the angle in degrees.
- Return type:
- Returns:
An array containing the angles in radians.
See also
jax.numpy.rad2deg and jax.numpy.degrees: Converts the angles from radians to degrees.
jax.numpy.radians: Alias of deg2rad.
Examples
>>> x = jnp.array([60, 90, 120, 180])
>>> jnp.deg2rad(x)
Array([1.0471976, 1.5707964, 2.0943952, 3.1415927], dtype=float32)
>>> x * jnp.pi / 180
Array([1.0471976, 1.5707964, 2.0943952, 3.1415927], dtype=float32, weak_type=True)
- scico.numpy.degrees(x, /)¶
Alias of jax.numpy.rad2deg.
- Return type:
- scico.numpy.delete(arr, obj, axis=None, *, assume_unique_indices=False)¶
Delete entry or entries from an array.
JAX implementation of numpy.delete.
- Parameters:
arr (Union[Array, ndarray, bool, number, bool, int, float, complex]) – array from which entries will be deleted.
obj (Union[Array, ndarray, bool, number, bool, int, float, complex, slice]) – index, indices, or slice to be deleted.
axis (int | None) – axis along which entries will be deleted.
assume_unique_indices (bool) – In case of array-like integer (not boolean) indices, assume the indices are unique, and perform the deletion in a way that is compatible with JIT and other JAX transformations.
- Return type:
- Returns:
Copy of arr with specified indices deleted.
Note
delete() usually requires the index specification to be static. If the index is an integer array that is guaranteed to contain unique entries, you may specify assume_unique_indices=True to perform the operation in a manner that does not require static indices.
See also
jax.numpy.insert: insert entries into an array.
Examples
Delete entries from a 1D array:
>>> a = jnp.array([4, 5, 6, 7, 8, 9])
>>> jnp.delete(a, 2)
Array([4, 5, 7, 8, 9], dtype=int32)
>>> jnp.delete(a, slice(1, 4))  # delete a[1:4]
Array([4, 8, 9], dtype=int32)
>>> jnp.delete(a, slice(None, None, 2))  # delete a[::2]
Array([5, 7, 9], dtype=int32)
Delete entries from a 2D array along a specified axis:
>>> a2 = jnp.array([[4, 5, 6],
...                 [7, 8, 9]])
>>> jnp.delete(a2, 1, axis=1)
Array([[4, 6],
       [7, 9]], dtype=int32)
Delete multiple entries via a sequence of indices:
>>> indices = jnp.array([0, 1, 3])
>>> jnp.delete(a, indices)
Array([6, 8, 9], dtype=int32)
This will fail under jit and other transformations, because the output shape cannot be known with the possibility of duplicate indices:
>>> jax.jit(jnp.delete)(a, indices)
Traceback (most recent call last):
    ...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[3].
If you can ensure that the indices are unique, pass assume_unique_indices to allow this to be executed under JIT:
>>> jit_delete = jax.jit(jnp.delete, static_argnames=['assume_unique_indices'])
>>> jit_delete(a, indices, assume_unique_indices=True)
Array([6, 8, 9], dtype=int32)
- scico.numpy.diag(v, k=0)¶
Returns the specified diagonal or constructs a diagonal array.
JAX implementation of numpy.diag.
The JAX version always returns a copy of the input, although if this is used within a JIT compilation, the compiler may avoid the copy.
- Parameters:
v (Union[Array, ndarray, bool, number, bool, int, float, complex]) – Input array. Can be a 1-D array to create a diagonal matrix or a 2-D array to extract a diagonal.
k (int) – optional, default=0. Diagonal offset. Positive values place the diagonal above the main diagonal, negative values place it below the main diagonal.
- Return type:
- Returns:
If v is a 2-D array, a 1-D array containing the diagonal elements. If v is a 1-D array, a 2-D array with the input elements placed along the specified diagonal.
See also
Examples
Creating a diagonal matrix from a 1-D array:
>>> jnp.diag(jnp.array([1, 2, 3]))
Array([[1, 0, 0],
       [0, 2, 0],
       [0, 0, 3]], dtype=int32)
Specifying a diagonal offset:
>>> jnp.diag(jnp.array([1, 2, 3]), k=1)
Array([[0, 1, 0, 0],
       [0, 0, 2, 0],
       [0, 0, 0, 3],
       [0, 0, 0, 0]], dtype=int32)
Extracting a diagonal from a 2-D array:
>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6],
...                [7, 8, 9]])
>>> jnp.diag(x)
Array([1, 5, 9], dtype=int32)
- scico.numpy.diag_indices(n, ndim=2)¶
Return indices for accessing the main diagonal of a multidimensional array.
JAX implementation of numpy.diag_indices.
- Parameters:
- Return type:
- Returns:
A tuple of arrays, each of length n, containing the indices to access the main diagonal.
Examples
>>> jnp.diag_indices(3)
(Array([0, 1, 2], dtype=int32), Array([0, 1, 2], dtype=int32))
>>> jnp.diag_indices(4, ndim=3)
(Array([0, 1, 2, 3], dtype=int32), Array([0, 1, 2, 3], dtype=int32), Array([0, 1, 2, 3], dtype=int32))
- scico.numpy.diag_indices_from(arr)¶
Return indices for accessing the main diagonal of a given array.
JAX implementation of numpy.diag_indices_from.
- Parameters:
arr (Union[Array, ndarray, bool, number, bool, int, float, complex]) – Input array. Must be at least 2-dimensional and have equal length along all dimensions.
- Return type:
- Returns:
A tuple of arrays containing the indices to access the main diagonal of the input array.
See also
Examples
>>> arr = jnp.array([[1, 2, 3],
...                  [4, 5, 6],
...                  [7, 8, 9]])
>>> jnp.diag_indices_from(arr)
(Array([0, 1, 2], dtype=int32), Array([0, 1, 2], dtype=int32))
>>> arr = jnp.array([[[1, 2], [3, 4]],
...                  [[5, 6], [7, 8]]])
>>> jnp.diag_indices_from(arr)
(Array([0, 1], dtype=int32), Array([0, 1], dtype=int32), Array([0, 1], dtype=int32))
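The returned index tuple is typically used to read or update the diagonal. Because JAX arrays are immutable, updates go through the .at[...] indexed-update syntax rather than in-place assignment. The sketch below is one illustrative use; the variable names are assumptions.

```python
import jax.numpy as jnp

arr = jnp.array([[1, 2, 3],
                 [4, 5, 6],
                 [7, 8, 9]])
idx = jnp.diag_indices_from(arr)

# Read the main diagonal using the index tuple.
diagonal = arr[idx]

# Functional update: returns a new array with a zeroed diagonal; arr is unchanged.
zeroed = arr.at[idx].set(0)
print(diagonal)
print(zeroed)
```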
- scico.numpy.diagflat(v, k=0)¶
Return a 2-D array with the flattened input array laid out on the diagonal.
JAX implementation of numpy.diagflat.
This differs from np.diagflat for some scalar values of v. JAX always returns a two-dimensional array, whereas NumPy may return a scalar depending on the type of v.
- Parameters:
- Return type:
- Returns:
A 2D array with the input elements placed along the diagonal with the specified offset (k). The remaining entries are filled with zeros.
See also
Examples
>>> jnp.diagflat(jnp.array([1, 2, 3]))
Array([[1, 0, 0],
       [0, 2, 0],
       [0, 0, 3]], dtype=int32)
>>> jnp.diagflat(jnp.array([1, 2, 3]), k=1)
Array([[0, 1, 0, 0],
       [0, 0, 2, 0],
       [0, 0, 0, 3],
       [0, 0, 0, 0]], dtype=int32)
>>> a = jnp.array([[1, 2],
...                [3, 4]])
>>> jnp.diagflat(a)
Array([[1, 0, 0, 0],
       [0, 2, 0, 0],
       [0, 0, 3, 0],
       [0, 0, 0, 4]], dtype=int32)
- scico.numpy.diff(a, n=1, axis=-1, prepend=None, append=None)¶
Calculate n-th order difference between array elements along a given axis.
JAX implementation of numpy.diff.
The first order difference is computed by a[i+1] - a[i], and the n-th order difference is computed n times recursively.
- Parameters:
a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array. Must have a.ndim >= 1.
n (int) – int, optional, default=1. Order of the difference. Specifies the number of times the difference is computed. If n=0, no difference is computed and input is returned as is.
axis (int) – int, optional, default=-1. Specifies the axis along which the difference is computed. The difference is computed along axis -1 by default.
prepend (Union[Array, ndarray, bool, number, bool, int, float, complex, None]) – scalar or array, optional, default=None. Specifies the values to be prepended along axis before computing the difference.
append (Union[Array, ndarray, bool, number, bool, int, float, complex, None]) – scalar or array, optional, default=None. Specifies the values to be appended along axis before computing the difference.
- Return type:
- Returns:
An array containing the n-th order difference between the elements of a.
See also
jax.numpy.ediff1d: Computes the differences between consecutive elements of an array.
jax.numpy.cumsum: Computes the cumulative sum of the elements of the array along a given axis.
jax.numpy.gradient: Computes the gradient of an N-dimensional array.
Examples
jnp.diff computes the first order difference along axis, by default.
>>> a = jnp.array([[1, 5, 2, 9],
...                [3, 8, 7, 4]])
>>> jnp.diff(a)
Array([[ 4, -3,  7],
       [ 5, -1, -3]], dtype=int32)
When n = 2, second order difference is computed along axis.
>>> jnp.diff(a, n=2)
Array([[-7, 10],
       [-6, -2]], dtype=int32)
When prepend = 2, it is prepended to a along axis before computing the difference.
>>> jnp.diff(a, prepend=2)
Array([[-1,  4, -3,  7],
       [ 1,  5, -1, -3]], dtype=int32)
When append = jnp.array([[3],[1]]), it is appended to a along axis before computing the difference.
>>> jnp.diff(a, append=jnp.array([[3],[1]]))
Array([[ 4, -3,  7, -6],
       [ 5, -1, -3, -3]], dtype=int32)
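Two relationships help when reasoning about diff: the n-th order difference is the first-order difference applied n times, and a first-order difference with prepend=0 undoes a cumulative sum along the same axis. The sketch below checks both on the array from the examples above.

```python
import jax.numpy as jnp

a = jnp.array([[1, 5, 2, 9],
               [3, 8, 7, 4]])

# The second-order difference equals diff applied twice.
second = jnp.diff(jnp.diff(a))
same_second = bool(jnp.array_equal(second, jnp.diff(a, n=2)))

# diff with a prepended zero inverts cumsum along the last axis.
recovered = jnp.diff(jnp.cumsum(a, axis=-1), prepend=0)
round_trip = bool(jnp.array_equal(recovered, a))
print(same_second, round_trip)
```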
- scico.numpy.divide(x1, x2, /)¶
Alias of jax.numpy.true_divide.
- Return type:
- scico.numpy.divmod(x1, x2, /)¶
Calculates the integer quotient and remainder of x1 by x2 element-wise.
JAX implementation of numpy.divmod.
- Parameters:
- Return type:
- Returns:
A tuple of arrays (x1 // x2, x1 % x2).
See also
jax.numpy.floor_divide: floor division function
jax.numpy.remainder: remainder function
Examples
>>> x1 = jnp.array([10, 20, 30])
>>> x2 = jnp.array([3, 4, 7])
>>> jnp.divmod(x1, x2)
(Array([3, 5, 4], dtype=int32), Array([1, 0, 2], dtype=int32))
>>> x1 = jnp.array([-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5])
>>> x2 = 3
>>> jnp.divmod(x1, x2)
(Array([-2, -2, -1, -1, -1,  0,  0,  0,  1,  1,  1], dtype=int32),
 Array([1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2], dtype=int32))
>>> x1 = jnp.array([6, 6, 6], dtype=jnp.int32)
>>> x2 = jnp.array([1.9, 2.5, 3.1], dtype=jnp.float32)
>>> jnp.divmod(x1, x2)
(Array([3., 2., 1.], dtype=float32),
 Array([0.30000007, 1.        , 2.9       ], dtype=float32))
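The quotient and remainder together reconstruct the dividend: x1 == q * x2 + r. The sketch below checks this identity for the integer example above, and also confirms that with a positive divisor the remainder stays in the range [0, x2).

```python
import jax.numpy as jnp

x1 = jnp.array([-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5])
x2 = 3

q, r = jnp.divmod(x1, x2)

# Quotient and remainder must reconstruct the dividend exactly.
reconstructed = q * x2 + r
identity_holds = bool(jnp.array_equal(reconstructed, x1))

# For a positive divisor, the remainder lies in [0, x2).
remainder_in_range = bool(jnp.all((r >= 0) & (r < x2)))
print(identity_holds, remainder_in_range)
```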
- scico.numpy.dot(a, b, *, precision=None, preferred_element_type=None, out_sharding=None)¶
Compute the dot product of two arrays.
JAX implementation of numpy.dot.
This differs from jax.numpy.matmul in two respects:
if either a or b is a scalar, the result of dot is equivalent to jax.numpy.multiply, while the result of matmul is an error.
if a and b have more than 2 dimensions, the batch indices are stacked rather than broadcast.
- Parameters:
a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – first input array, of shape (..., N).
b (Union[Array, ndarray, bool, number, bool, int, float, complex]) – second input array. Must have shape (N,) or (..., N, M). In the multi-dimensional case, leading dimensions must be broadcast-compatible with the leading dimensions of a.
precision (Union[None, str, Precision, tuple[str, str], tuple[Precision, Precision], DotAlgorithm, DotAlgorithmPreset]) – either None (default), which means the default precision for the backend, a Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST) or a tuple of two such values indicating precision of a and b.
preferred_element_type (Union[str, type[Any], dtype, SupportsDType, None]) – either None (default), which means the default accumulation type for the input types, or a datatype, indicating to accumulate results to and return a result with that datatype.
- Return type:
- Returns:
array containing the dot product of the inputs, with batch dimensions of a and b stacked rather than broadcast.
See also
jax.numpy.matmul: broadcasted batched matmul.
jax.lax.dot_general: general batched matrix multiplication.
Examples
For scalar inputs, dot computes the element-wise product:
>>> x = jnp.array([1, 2, 3])
>>> jnp.dot(x, 2)
Array([2, 4, 6], dtype=int32)
For vector or matrix inputs, dot computes the vector or matrix product:
>>> M = jnp.array([[2, 3, 4],
...                [5, 6, 7],
...                [8, 9, 0]])
>>> jnp.dot(M, x)
Array([20, 38, 26], dtype=int32)
>>> jnp.dot(M, M)
Array([[ 51,  60,  29],
       [ 96, 114,  62],
       [ 61,  78,  95]], dtype=int32)
For higher-dimensional matrix products, batch dimensions are stacked, whereas in matmul they are broadcast. For example:
>>> a = jnp.zeros((3, 2, 4))
>>> b = jnp.zeros((3, 4, 1))
>>> jnp.dot(a, b).shape
(3, 2, 3, 1)
>>> jnp.matmul(a, b).shape
(3, 2, 1)
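The difference between stacked and broadcast batch dimensions is easiest to see when both operations are written as einsum expressions: dot keeps a separate output index for each input's leading dimension, while matmul shares one batch index. This is an illustrative sketch with arbitrary input values, not a description of the internal implementation.

```python
import jax.numpy as jnp

a = jnp.arange(24.0).reshape(3, 2, 4)
b = jnp.arange(12.0).reshape(3, 4, 1)

# dot: the leading axis of a (i) and the leading axis of b (l) both survive.
dot_result = jnp.dot(a, b)
dot_einsum = jnp.einsum('ijk,lkm->ijlm', a, b)

# matmul: a single shared batch index i.
matmul_result = jnp.matmul(a, b)
matmul_einsum = jnp.einsum('ijk,ikm->ijm', a, b)

print(dot_result.shape, matmul_result.shape)
```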
- scico.numpy.dsplit(ary, indices_or_sections)¶
Split an array into sub-arrays depth-wise.
JAX implementation of numpy.dsplit.
Refer to the documentation of jax.numpy.split for details. dsplit is equivalent to split with axis=2.
Examples
>>> x = jnp.arange(12).reshape(3, 1, 4)
>>> print(x)
[[[ 0  1  2  3]]

 [[ 4  5  6  7]]

 [[ 8  9 10 11]]]
>>> x1, x2 = jnp.dsplit(x, 2)
>>> print(x1)
[[[0 1]]

 [[4 5]]

 [[8 9]]]
>>> print(x2)
[[[ 2  3]]

 [[ 6  7]]

 [[10 11]]]
See also
jax.numpy.split: split an array along any axis.
jax.numpy.vsplit: split vertically, i.e. along axis=0
jax.numpy.hsplit: split horizontally, i.e. along axis=1
jax.numpy.array_split: like split, but allows indices_or_sections to be an integer that does not evenly divide the size of the array.
- scico.numpy.dstack(tup, dtype=None)¶
Stack arrays depth-wise.
JAX implementation of numpy.dstack.
For arrays of three or more dimensions, this is equivalent to jax.numpy.concatenate with axis=2.
- Parameters:
tup (ndarray | Array | Sequence[Union[Array, ndarray, bool, number, bool, int, float, complex]]) – a sequence of arrays to stack; each must have the same shape along all but the third axis. Input arrays will be promoted to at least rank 3. If a single array is given it will be treated equivalently to tup = unstack(tup), but the implementation will avoid explicit unstacking.
dtype (Union[str, type[Any], dtype, SupportsDType, None]) – optional dtype of the resulting array. If not specified, the dtype will be determined via type promotion rules described in Type promotion semantics.
- Return type:
- Returns:
the stacked result.
See also
jax.numpy.stack: stack along arbitrary axes
jax.numpy.concatenate: concatenation along existing axes.
jax.numpy.vstack: stack vertically, i.e. along axis 0.
jax.numpy.hstack: stack horizontally, i.e. along axis 1.
Examples
Scalar values:
>>> jnp.dstack([1, 2, 3])
Array([[[1, 2, 3]]], dtype=int32, weak_type=True)
1D arrays:
>>> x = jnp.arange(3)
>>> y = jnp.ones(3)
>>> jnp.dstack([x, y])
Array([[[0., 1.],
        [1., 1.],
        [2., 1.]]], dtype=float32)
2D arrays:
>>> x = x.reshape(1, 3)
>>> y = y.reshape(1, 3)
>>> jnp.dstack([x, y])
Array([[[0., 1.],
        [1., 1.],
        [2., 1.]]], dtype=float32)
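For inputs that are already three-dimensional, dstack and dsplit act as inverses along axis 2, and dstack matches concatenate with axis=2 as stated above. The following sketch demonstrates the round trip; the array shape is an arbitrary choice for illustration.

```python
import jax.numpy as jnp

x = jnp.arange(12).reshape(3, 1, 4)

# Split depth-wise into two equal pieces, each of shape (3, 1, 2).
parts = jnp.dsplit(x, 2)

# Stacking the pieces depth-wise restores the original array.
rebuilt = jnp.dstack(parts)

# For 3D inputs this matches concatenation along axis 2.
via_concat = jnp.concatenate(parts, axis=2)

print(bool(jnp.array_equal(rebuilt, x)), bool(jnp.array_equal(via_concat, x)))
```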
- scico.numpy.ediff1d(ary, to_end=None, to_begin=None)¶
Compute the differences of the elements of the flattened array.
JAX implementation of numpy.ediff1d.
- Parameters:
ary (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array or scalar.
to_end (Union[Array, ndarray, bool, number, bool, int, float, complex, None]) – scalar or array, optional, default=None. Specifies the numbers to append to the resulting array.
to_begin (Union[Array, ndarray, bool, number, bool, int, float, complex, None]) – scalar or array, optional, default=None. Specifies the numbers to prepend to the resulting array.
- Return type:
- Returns:
An array containing the differences between the elements of the input array.
Note
Unlike NumPy’s implementation of ediff1d, jax.numpy.ediff1d will not issue an error if casting to_end or to_begin to the type of ary loses precision.
See also
jax.numpy.diff: Computes the n-th order difference between elements of the array along a given axis.
jax.numpy.cumsum: Computes the cumulative sum of the elements of the array along a given axis.
jax.numpy.gradient: Computes the gradient of an N-dimensional array.
Examples
>>> a = jnp.array([2, 3, 5, 9, 1, 4])
>>> jnp.ediff1d(a)
Array([ 1,  2,  4, -8,  3], dtype=int32)
>>> jnp.ediff1d(a, to_begin=-10)
Array([-10,   1,   2,   4,  -8,   3], dtype=int32)
>>> jnp.ediff1d(a, to_end=jnp.array([20, 30]))
Array([ 1,  2,  4, -8,  3, 20, 30], dtype=int32)
>>> jnp.ediff1d(a, to_begin=-10, to_end=jnp.array([20, 30]))
Array([-10,   1,   2,   4,  -8,   3,  20,  30], dtype=int32)
For array with ndim > 1, the differences are computed after flattening the input array.
>>> a1 = jnp.array([[2, -1, 4, 7],
...                 [3, 5, -6, 9]])
>>> jnp.ediff1d(a1)
Array([ -3,   5,   3,  -4,   2, -11,  15], dtype=int32)
>>> a2 = jnp.array([2, -1, 4, 7, 3, 5, -6, 9])
>>> jnp.ediff1d(a2)
Array([ -3,   5,   3,  -4,   2, -11,  15], dtype=int32)
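Put differently, when to_begin and to_end are not given, ediff1d gives the same values as a first-order diff of the flattened input. The sketch below checks that equivalence for the 2D example above.

```python
import jax.numpy as jnp

a1 = jnp.array([[2, -1, 4, 7],
                [3, 5, -6, 9]])

e = jnp.ediff1d(a1)

# Flatten first, then take the ordinary first-order difference.
flat_diff = jnp.diff(a1.ravel())

same = bool(jnp.array_equal(e, flat_diff))
print(same)
```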
- scico.numpy.einsum(subscripts, /, *operands, out=None, optimize='auto', precision=None, preferred_element_type=None, _dot_general=<function dot_general>, out_sharding=None)¶
Einstein summation
JAX implementation of numpy.einsum.
einsum is a powerful and generic API for computing various reductions, inner products, outer products, axis reorderings, and combinations thereof across one or more input arrays. It has a somewhat complicated overloaded API; the arguments below reflect the most common calling convention. The Examples section below demonstrates some of the alternative calling conventions.
- Parameters:
subscripts – string containing axes names separated by commas.
*operands – sequence of one or more arrays corresponding to the subscripts.
optimize (str | bool | list[tuple[int, ...]]) – specify how to optimize the order of computation. In JAX this defaults to "auto" which produces optimized expressions via the opt_einsum package. Other options are True (same as "optimal"), False (unoptimized), or any string supported by opt_einsum, which includes "optimal", "greedy", "eager", and others. It may also be a pre-computed path (see einsum_path).
precision (Union[None, str, Precision, tuple[str, str], tuple[Precision, Precision], DotAlgorithm, DotAlgorithmPreset]) – either None (default), which means the default precision for the backend, a Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST).
preferred_element_type (Union[str, type[Any], dtype, SupportsDType, None]) – either None (default), which means the default accumulation type for the input types, or a datatype, indicating to accumulate results to and return a result with that datatype.
out (None) – unsupported by JAX
_dot_general (Callable[..., Array]) – optionally override the dot_general callable used by einsum. This parameter is experimental, and may be removed without warning at any time.
- Return type:
- Returns:
array containing the result of the Einstein summation.
See also
Examples
The mechanics of einsum are perhaps best demonstrated by example. Here we show how to use einsum to compute a number of quantities from one or more arrays. For more discussion and examples of einsum, see the documentation of numpy.einsum.
>>> M = jnp.arange(16).reshape(4, 4)
>>> x = jnp.arange(4)
>>> y = jnp.array([5, 4, 3, 2])
Vector product
>>> jnp.einsum('i,i', x, y)
Array(16, dtype=int32)
>>> jnp.vecdot(x, y)
Array(16, dtype=int32)
Here are some alternative einsum calling conventions to compute the same result:
>>> jnp.einsum('i,i->', x, y)  # explicit form
Array(16, dtype=int32)
>>> jnp.einsum(x, (0,), y, (0,))  # implicit form via indices
Array(16, dtype=int32)
>>> jnp.einsum(x, (0,), y, (0,), ())  # explicit form via indices
Array(16, dtype=int32)
Matrix product
>>> jnp.einsum('ij,j->i', M, x)  # explicit form
Array([14, 38, 62, 86], dtype=int32)
>>> jnp.matmul(M, x)
Array([14, 38, 62, 86], dtype=int32)
Here are some alternative einsum calling conventions to compute the same result:
>>> jnp.einsum('ij,j', M, x)  # implicit form
Array([14, 38, 62, 86], dtype=int32)
>>> jnp.einsum(M, (0, 1), x, (1,), (0,))  # explicit form via indices
Array([14, 38, 62, 86], dtype=int32)
>>> jnp.einsum(M, (0, 1), x, (1,))  # implicit form via indices
Array([14, 38, 62, 86], dtype=int32)
Outer product
>>> jnp.einsum("i,j->ij", x, y)
Array([[ 0,  0,  0,  0],
       [ 5,  4,  3,  2],
       [10,  8,  6,  4],
       [15, 12,  9,  6]], dtype=int32)
>>> jnp.outer(x, y)
Array([[ 0,  0,  0,  0],
       [ 5,  4,  3,  2],
       [10,  8,  6,  4],
       [15, 12,  9,  6]], dtype=int32)
Some other ways of computing outer products:
>>> jnp.einsum("i,j", x, y)  # implicit form
Array([[ 0,  0,  0,  0],
       [ 5,  4,  3,  2],
       [10,  8,  6,  4],
       [15, 12,  9,  6]], dtype=int32)
>>> jnp.einsum(x, (0,), y, (1,), (0, 1))  # explicit form via indices
Array([[ 0,  0,  0,  0],
       [ 5,  4,  3,  2],
       [10,  8,  6,  4],
       [15, 12,  9,  6]], dtype=int32)
>>> jnp.einsum(x, (0,), y, (1,))  # implicit form via indices
Array([[ 0,  0,  0,  0],
       [ 5,  4,  3,  2],
       [10,  8,  6,  4],
       [15, 12,  9,  6]], dtype=int32)
1D array sum
>>> jnp.einsum("i->", x)  # requires explicit form
Array(6, dtype=int32)
>>> jnp.einsum(x, (0,), ())  # explicit form via indices
Array(6, dtype=int32)
>>> jnp.sum(x)
Array(6, dtype=int32)
Sum along an axis
>>> jnp.einsum("...j->...", M)  # requires explicit form
Array([ 6, 22, 38, 54], dtype=int32)
>>> jnp.einsum(M, (..., 0), (...,))  # explicit form via indices
Array([ 6, 22, 38, 54], dtype=int32)
>>> M.sum(-1)
Array([ 6, 22, 38, 54], dtype=int32)
Matrix transpose
>>> y = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> jnp.einsum("ij->ji", y)  # explicit form
Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)
>>> jnp.einsum("ji", y)  # implicit form
Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)
>>> jnp.einsum(y, (1, 0))  # implicit form via indices
Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)
>>> jnp.einsum(y, (0, 1), (1, 0))  # explicit form via indices
Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)
>>> jnp.transpose(y)
Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)
Matrix diagonal
>>> jnp.einsum("ii->i", M)
Array([ 0,  5, 10, 15], dtype=int32)
>>> jnp.diagonal(M)
Array([ 0,  5, 10, 15], dtype=int32)
Matrix trace
>>> jnp.einsum("ii", M)
Array(30, dtype=int32)
>>> jnp.trace(M)
Array(30, dtype=int32)
Tensor products
>>> x = jnp.arange(30).reshape(2, 3, 5)
>>> y = jnp.arange(60).reshape(3, 4, 5)
>>> jnp.einsum('ijk,jlk->il', x, y)  # explicit form
Array([[ 3340,  3865,  4390,  4915],
       [ 8290,  9940, 11590, 13240]], dtype=int32)
>>> jnp.tensordot(x, y, axes=[(1, 2), (0, 2)])
Array([[ 3340,  3865,  4390,  4915],
       [ 8290,  9940, 11590, 13240]], dtype=int32)
>>> jnp.einsum('ijk,jlk', x, y)  # implicit form
Array([[ 3340,  3865,  4390,  4915],
       [ 8290,  9940, 11590, 13240]], dtype=int32)
>>> jnp.einsum(x, (0, 1, 2), y, (1, 3, 2), (0, 3))  # explicit form via indices
Array([[ 3340,  3865,  4390,  4915],
       [ 8290,  9940, 11590, 13240]], dtype=int32)
>>> jnp.einsum(x, (0, 1, 2), y, (1, 3, 2))  # implicit form via indices
Array([[ 3340,  3865,  4390,  4915],
       [ 8290,  9940, 11590, 13240]], dtype=int32)
Chained dot products
>>> w = jnp.arange(5, 9).reshape(2, 2)
>>> x = jnp.arange(6).reshape(2, 3)
>>> y = jnp.arange(-2, 4).reshape(3, 2)
>>> z = jnp.array([[2, 4, 6], [3, 5, 7]])
>>> jnp.einsum('ij,jk,kl,lm->im', w, x, y, z)
Array([[ 481,  831, 1181],
       [ 651, 1125, 1599]], dtype=int32)
>>> jnp.einsum(w, (0, 1), x, (1, 2), y, (2, 3), z, (3, 4))  # implicit, via indices
Array([[ 481,  831, 1181],
       [ 651, 1125, 1599]], dtype=int32)
>>> w @ x @ y @ z  # direct chain of matmuls
Array([[ 481,  831, 1181],
       [ 651, 1125, 1599]], dtype=int32)
>>> jnp.linalg.multi_dot([w, x, y, z])
Array([[ 481,  831, 1181],
       [ 651, 1125, 1599]], dtype=int32)
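One more common pattern, not shown in the list above, is a batched matrix product, where a shared leading index appears in both operands and in the output. The sketch below uses arbitrary illustrative shapes and compares the result against jnp.matmul.

```python
import jax.numpy as jnp

# Two batches (b=2) of matrices with compatible inner dimensions.
A = jnp.arange(24).reshape(2, 3, 4)
B = jnp.arange(40).reshape(2, 4, 5)

# 'b' is kept in the output, 'j' is summed over.
batched = jnp.einsum('bij,bjk->bik', A, B)

# matmul broadcasts over the leading batch dimension in the same way.
reference = jnp.matmul(A, B)

print(batched.shape, bool(jnp.array_equal(batched, reference)))
```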
- scico.numpy.einsum_path(subscripts, /, *operands, optimize='auto')¶
Evaluates the optimal contraction path without evaluating the einsum.
JAX implementation of numpy.einsum_path. This function calls into the opt_einsum package, and makes use of its optimization routines.
- Parameters:
subscripts – string containing axes names separated by commas.
*operands – sequence of one or more arrays corresponding to the subscripts.
optimize (bool | str | list[tuple[int, ...]]) – specify how to optimize the order of computation. In JAX this defaults to "auto". Other options are True (same as "optimal"), False (unoptimized), or any string supported by opt_einsum, which includes "optimal", "greedy", "eager", and others.
- Return type:
- Returns:
A tuple containing the path that may be passed to einsum, and a printable object representing this optimal path.
Examples
>>> key1, key2, key3 = jax.random.split(jax.random.key(0), 3)
>>> x = jax.random.randint(key1, minval=-5, maxval=5, shape=(2, 3))
>>> y = jax.random.randint(key2, minval=-5, maxval=5, shape=(3, 100))
>>> z = jax.random.randint(key3, minval=-5, maxval=5, shape=(100, 5))
>>> path, path_info = jnp.einsum_path("ij,jk,kl", x, y, z, optimize="optimal")
>>> print(path)
[(1, 2), (0, 1)]
>>> print(path_info)
  Complete contraction:  ij,jk,kl->il
         Naive scaling:  4
     Optimized scaling:  3
      Naive FLOP count:  9.000e+3
  Optimized FLOP count:  3.060e+3
   Theoretical speedup:  2.941e+0
  Largest intermediate:  1.500e+1 elements
--------------------------------------------------------------------------------
scaling        BLAS                current                             remaining
--------------------------------------------------------------------------------
   3           GEMM              kl,jk->lj                             ij,lj->il
   3           GEMM              lj,ij->il                                il->il
Use the computed path in einsum:
>>> jnp.einsum("ij,jk,kl", x, y, z, optimize=path)
Array([[-754,  324, -142,   82,   50],
       [ 408,  -50,   87,  -29,    7]], dtype=int32)
- scico.numpy.empty(shape, dtype=None, *, device=None, out_sharding=None)¶
Create an empty array.
JAX implementation of numpy.empty.
Note
For historical reasons, jax.numpy.empty is currently equivalent to jax.numpy.zeros: i.e. it returns a buffer initialized with zeros. To create a buffer of uninitialized values, please use jax.lax.empty.
- Parameters:
shape (Any) – int or sequence of ints specifying the shape of the created array.
dtype (Union[str, type[Any], dtype, SupportsDType, None]) – optional dtype for the created array; defaults to float32 or float64 depending on the X64 configuration (see Default dtypes and the X64 flag).
device (Device | Sharding | None) – (optional) Device or Sharding to which the created array will be committed. This argument exists for compatibility with the Python Array API standard.
out_sharding (NamedSharding | P | None) – (optional) PartitionSpec or NamedSharding representing the sharding of the created array (see explicit sharding for more details). This argument exists for consistency with other array creation routines across JAX. Specifying both out_sharding and device will result in an error.
- Return type:
- Returns:
Array of the specified shape and dtype, with the given device/sharding if specified.
Examples
>>> jnp.empty(4)
Array([0., 0., 0., 0.], dtype=float32)
>>> jnp.empty((2, 3), dtype=bool)
Array([[False, False, False],
       [False, False, False]], dtype=bool)
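The note above says that `empty` currently behaves like `zeros`. A minimal sketch of checking that claim, written against `jax.numpy` directly on the assumption that `scico.numpy.empty` forwards to it for plain arrays:

```python
import jax.numpy as jnp

# Per the note above, jnp.empty returns a zero-initialized buffer,
# so it should match jnp.zeros in both values and dtype.
a = jnp.empty((2, 3))
b = jnp.zeros((2, 3))

same_values = bool(jnp.array_equal(a, b))
same_dtype = a.dtype == b.dtype
print(same_values, same_dtype)
```

Code that relies on this should still treat the contents of an `empty` array as unspecified, since the note describes the behavior as historical.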
- scico.numpy.empty_like(prototype, dtype=None, shape=None, *, device=None)¶
Create an empty array with the same shape and dtype as an array.
JAX implementation of numpy.empty_like. Because XLA cannot create an un-initialized array, jax.numpy.empty_like will always return an array full of zeros.
- Parameters:
prototype – Array-like object with shape and dtype attributes.
shape (Any) – optionally override the shape of the created array.
dtype (Union[str,type[Any],dtype,SupportsDType,None]) – optionally override the dtype of the created array.
device (Device|Sharding|None) – (optional) Device or Sharding to which the created array will be committed.
- Return type:
- Returns:
Array of the specified shape and dtype, on the specified device if specified.
Examples
>>> x = jnp.arange(4) >>> jnp.empty_like(x) Array([0, 0, 0, 0], dtype=int32) >>> jnp.empty_like(x, dtype=bool) Array([False, False, False, False], dtype=bool) >>> jnp.empty_like(x, shape=(2, 3)) Array([[0, 0, 0], [0, 0, 0]], dtype=int32)
- scico.numpy.equal(x, y, /)¶
Returns element-wise truth value of
x == y.JAX implementation of
numpy.equal. This function provides the implementation of the==operator for JAX arrays.- Parameters:
- Return type:
- Returns:
A boolean array containing True where the corresponding elements of x and y are equal, and False otherwise.
See also
jax.numpy.not_equal: Returns element-wise truth value of x != y.
jax.numpy.greater_equal: Returns element-wise truth value of x >= y.
jax.numpy.less_equal: Returns element-wise truth value of x <= y.
jax.numpy.greater: Returns element-wise truth value of x > y.
jax.numpy.less: Returns element-wise truth value of x < y.
Examples
>>> jnp.equal(0., -0.)
Array(True, dtype=bool, weak_type=True)
>>> jnp.equal(1, 1.)
Array(True, dtype=bool, weak_type=True)
>>> jnp.equal(5, jnp.array(5))
Array(True, dtype=bool, weak_type=True)
>>> jnp.equal(2, -2)
Array(False, dtype=bool, weak_type=True)
>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6],
...                [7, 8, 9]])
>>> y = jnp.array([1, 5, 9])
>>> jnp.equal(x, y)
Array([[ True, False, False],
       [False,  True, False],
       [False, False,  True]], dtype=bool)
>>> x == y
Array([[ True, False, False],
       [False,  True, False],
       [False, False,  True]], dtype=bool)
- scico.numpy.exp(x, /)¶
Calculate element-wise exponential of the input.
JAX implementation of
numpy.exp.- Parameters:
x (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – input array or scalar- Return type:
- Returns:
An array containing the exponential of each element in
x, promotes to inexact dtype.
See also
jax.numpy.log: Calculates element-wise logarithm of the input.jax.numpy.expm1: Calculates \(e^x-1\) of each element of the input.jax.numpy.exp2: Calculates base-2 exponential of each element of the input.
Examples
jnp.exp follows the properties of the exponential function, such as \(e^{(a+b)} = e^a * e^b\).
>>> x1 = jnp.array([2, 4, 3, 1])
>>> x2 = jnp.array([1, 3, 2, 3])
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jnp.exp(x1+x2))
[  20.09 1096.63  148.41   54.6 ]
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jnp.exp(x1)*jnp.exp(x2))
[  20.09 1096.63  148.41   54.6 ]
This property holds for complex input also:
>>> jnp.allclose(jnp.exp(3-4j), jnp.exp(3)*jnp.exp(-4j)) Array(True, dtype=bool)
- scico.numpy.exp2(x, /)¶
Calculate element-wise base-2 exponential of input.
JAX implementation of
numpy.exp2.- Parameters:
x (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – input array or scalar- Return type:
- Returns:
An array containing the base-2 exponential of each element in
x, promotes to inexact dtype.
See also
jax.numpy.log2: Calculates base-2 logarithm of each element of input.jax.numpy.exp: Calculates exponential of each element of the input.jax.numpy.expm1: Calculates \(e^x-1\) of each element of the input.
Examples
jnp.exp2follows the properties of the exponential such as \(2^{a+b} = 2^a * 2^b\).>>> x1 = jnp.array([2, -4, 3, -1]) >>> x2 = jnp.array([-1, 3, -2, 3]) >>> jnp.exp2(x1+x2) Array([2. , 0.5, 2. , 4. ], dtype=float32) >>> jnp.exp2(x1)*jnp.exp2(x2) Array([2. , 0.5, 2. , 4. ], dtype=float32)
- scico.numpy.expand_dims(a, axis)¶
Insert dimensions of length 1 into array
JAX implementation of
numpy.expand_dims, implemented viajax.lax.expand_dims.- Parameters:
- Return type:
- Returns:
Copy of
awith added dimensions.
Notes
Unlike numpy.expand_dims, jax.numpy.expand_dims will return a copy rather than a view of the input array. However, under JIT, the compiler will optimize away such copies when possible, so this doesn’t have performance impacts in practice.
See also
jax.numpy.squeeze: inverse of this operation, i.e. remove length-1 dimensions.
jax.lax.expand_dims: XLA version of this functionality.
Examples
>>> x = jnp.array([1, 2, 3]) >>> x.shape (3,)
Expand the leading dimension:
>>> jnp.expand_dims(x, 0) Array([[1, 2, 3]], dtype=int32) >>> _.shape (1, 3)
Expand the trailing dimension:
>>> jnp.expand_dims(x, 1) Array([[1], [2], [3]], dtype=int32) >>> _.shape (3, 1)
Expand multiple dimensions:
>>> jnp.expand_dims(x, (0, 1, 3)) Array([[[[1], [2], [3]]]], dtype=int32) >>> _.shape (1, 1, 3, 1)
Dimensions can also be expanded more succinctly by indexing with
None:>>> x[None] # equivalent to jnp.expand_dims(x, 0) Array([[1, 2, 3]], dtype=int32) >>> x[:, None] # equivalent to jnp.expand_dims(x, 1) Array([[1], [2], [3]], dtype=int32) >>> x[None, None, :, None] # equivalent to jnp.expand_dims(x, (0, 1, 3)) Array([[[[1], [2], [3]]]], dtype=int32)
- scico.numpy.expm1(x, /)¶
Calculate
exp(x)-1of each element of the input.JAX implementation of
numpy.expm1.- Parameters:
x (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – input array or scalar.- Return type:
- Returns:
An array containing
exp(x)-1of each element inx, promotes to inexact dtype.
Note
jnp.expm1 has much higher precision than the naive computation of exp(x)-1 for small values of x.
See also
jax.numpy.log1p: Calculates element-wise logarithm of one plus input.
jax.numpy.exp: Calculates element-wise exponential of the input.
jax.numpy.exp2: Calculates base-2 exponential of each element of the input.
Examples
>>> x = jnp.array([2, -4, 3, -1]) >>> with jnp.printoptions(precision=2, suppress=True): ... print(jnp.expm1(x)) [ 6.39 -0.98 19.09 -0.63] >>> with jnp.printoptions(precision=2, suppress=True): ... print(jnp.exp(x)-1) [ 6.39 -0.98 19.09 -0.63]
For values very close to 0, jnp.expm1(x) is much more accurate than jnp.exp(x)-1:
>>> x1 = jnp.array([1e-4, 1e-6, 2e-10])
>>> jnp.expm1(x1)
Array([1.0000500e-04, 1.0000005e-06, 2.0000000e-10], dtype=float32)
>>> jnp.exp(x1)-1
Array([1.00016594e-04, 9.53674316e-07, 0.00000000e+00], dtype=float32)
- scico.numpy.extract(condition, arr, *, size=None, fill_value=0)¶
Return the elements of an array that satisfy a condition.
JAX implementation of
numpy.extract.- Parameters:
condition (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – array of conditions. Will be converted to boolean and flattened to 1D.arr (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – array of values to extract. Will be flattened to 1D.size (
int|None) – optional static size for output. Must be specified in order forextractto be compatible with JAX transformations likejitorvmap.fill_value (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – ifsizeis specified, fill padded entries with this value (default: 0).
- Return type:
- Returns:
1D array of extracted entries. If size is specified, the result will have shape (size,) and be right-padded with fill_value. If size is not specified, the output shape will depend on the number of True entries in condition.
Notes
This function does not require strict shape agreement between condition and arr. If condition.size > arr.size, then condition will be truncated, and if arr.size > condition.size, then arr will be truncated.
See also
jax.numpy.compress: multi-dimensional version of extract.
Examples
Extract values from a 1D array:
>>> x = jnp.array([1, 2, 3, 4, 5, 6]) >>> mask = (x % 2 == 0) >>> jnp.extract(mask, x) Array([2, 4, 6], dtype=int32)
In the simplest case, this is equivalent to boolean indexing:
>>> x[mask] Array([2, 4, 6], dtype=int32)
For use with JAX transformations, you can pass the
sizeargument to specify a static shape for the output, along with an optionalfill_valuethat defaults to zero:>>> jnp.extract(mask, x, size=len(x), fill_value=0) Array([2, 4, 6, 0, 0, 0], dtype=int32)
Notice that unlike with boolean indexing, extract does not require strict agreement between the sizes of the array and condition, and will effectively truncate both to the minimum size:
>>> short_mask = jnp.array([False, True])
>>> jnp.extract(short_mask, x)
Array([2], dtype=int32)
>>> long_mask = jnp.array([True, False, True, False, False, False, False, False])
>>> jnp.extract(long_mask, x)
Array([1, 3], dtype=int32)
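The `size` parameter is described above as the requirement for using `extract` under `jit`. A minimal sketch of that use, assuming `jax.numpy` directly; the function name `evens_padded` and the padding value `-1` are illustrative choices, not part of the API:

```python
import jax
import jax.numpy as jnp

@jax.jit
def evens_padded(x):
    # size is a static Python int, so the output shape (4,) is known
    # when the function is traced; unused slots receive fill_value.
    return jnp.extract(x % 2 == 0, x, size=4, fill_value=-1)

x = jnp.array([1, 2, 3, 4, 5, 6])
result = evens_padded(x)
print(result)
```

Choosing a `fill_value` that cannot occur in the data (here a negative number among positive inputs) makes the padding easy to mask out afterwards.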
- scico.numpy.eye(N, M=None, k=0, dtype=None, *, device=None)¶
Create a square or rectangular identity matrix
JAX implementation of
numpy.eye.- Parameters:
N (
Union[int,Any]) – integer specifying the first dimension of the array.M (
Union[int,Any,None]) – optional integer specifying the second dimension of the array; defaults to the same value asN.k (
Union[int,Array,ndarray,bool,number,bool,float,complex]) – optional integer specifying the offset of the diagonal. Use positive values for upper diagonals, and negative values for lower diagonals. Default is zero.dtype (
Union[str,type[Any],dtype,SupportsDType,None]) – optional dtype; defaults to floating point.device (
Device|Sharding|None) – optionalDeviceorShardingto which the created array will be committed.
- Return type:
- Returns:
Identity array of shape
(N, M), or(N, N)ifMis not specified.
See also
jax.numpy.identity: Simpler API for generating square identity matrices.Examples
A simple 3x3 identity matrix:
>>> jnp.eye(3) Array([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]], dtype=float32)
Integer identity matrices with offset diagonals:
>>> jnp.eye(3, k=1, dtype=int) Array([[0, 1, 0], [0, 0, 1], [0, 0, 0]], dtype=int32) >>> jnp.eye(3, k=-1, dtype=int) Array([[0, 0, 0], [1, 0, 0], [0, 1, 0]], dtype=int32)
Non-square identity matrix:
>>> jnp.eye(3, 5, k=1) Array([[0., 1., 0., 0., 0.], [0., 0., 1., 0., 0.], [0., 0., 0., 1., 0.]], dtype=float32)
- scico.numpy.fabs(x, /)¶
Compute the element-wise absolute values of the real-valued input.
JAX implementation of
numpy.fabs.- Parameters:
x (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – input array or scalar. Must not have a complex dtype.- Return type:
- Returns:
An array with same shape as
xand dtype float, containing the element-wise absolute values.
See also
jax.numpy.absolute: Computes the absolute values of the input including complex dtypes.jax.numpy.abs: Computes the absolute values of the input including complex dtypes.
Examples
For integer inputs:
>>> x = jnp.array([-5, -9, 1, 10, 15]) >>> jnp.fabs(x) Array([ 5., 9., 1., 10., 15.], dtype=float32)
For float type inputs:
>>> x1 = jnp.array([-1.342, 5.649, 3.927]) >>> jnp.fabs(x1) Array([1.342, 5.649, 3.927], dtype=float32)
For boolean inputs:
>>> x2 = jnp.array([True, False]) >>> jnp.fabs(x2) Array([1., 0.], dtype=float32)
- scico.numpy.fill_diagonal(a, val, wrap=False, *, inplace=True)¶
Return a copy of the array with the diagonal overwritten.
JAX implementation of numpy.fill_diagonal.
The semantics of numpy.fill_diagonal are to modify arrays in-place, which is not possible for JAX’s immutable arrays. The JAX version returns a modified copy of the input, and adds the inplace parameter, which must be set to False by the user as a reminder of this API difference.
- Parameters:
a (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – input array. Must havea.ndim >= 2. Ifa.ndim >= 3, then all dimensions must be the same size.val (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – scalar or array with which to fill the diagonal. If an array, it will be flattened and repeated to fill the diagonal entries.wrap (
bool) – Not implemented by JAX. Only the default value ofFalseis supported.inplace (
bool) – must be set to False to indicate that the input is not modified in-place, but rather a modified copy is returned.
- Return type:
- Returns:
A copy of
awith the diagonal set toval.
Examples
>>> x = jnp.zeros((3, 3), dtype=int) >>> jnp.fill_diagonal(x, jnp.array([1, 2, 3]), inplace=False) Array([[1, 0, 0], [0, 2, 0], [0, 0, 3]], dtype=int32)
Unlike numpy.fill_diagonal, the input x is not modified.
If the diagonal value has too many entries, it will be truncated:
>>> jnp.fill_diagonal(x, jnp.arange(100, 200), inplace=False) Array([[100, 0, 0], [ 0, 101, 0], [ 0, 0, 102]], dtype=int32)
If the diagonal has too few entries, it will be repeated:
>>> x = jnp.zeros((4, 4), dtype=int) >>> jnp.fill_diagonal(x, jnp.array([3, 4]), inplace=False) Array([[3, 0, 0, 0], [0, 4, 0, 0], [0, 0, 3, 0], [0, 0, 0, 4]], dtype=int32)
For non-square arrays, the diagonal of the leading square slice is filled:
>>> x = jnp.zeros((3, 5), dtype=int) >>> jnp.fill_diagonal(x, 1, inplace=False) Array([[1, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 1, 0, 0]], dtype=int32)
And for square N-dimensional arrays, the N-dimensional diagonal is filled:
>>> y = jnp.zeros((2, 2, 2)) >>> jnp.fill_diagonal(y, 1, inplace=False) Array([[[1., 0.], [0., 0.]], [[0., 0.], [0., 1.]]], dtype=float32)
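Because `fill_diagonal` returns a copy, the result has to be bound to a name. The sketch below, assuming `jax.numpy` directly, shows that pattern next to what should be an equivalent indexed update for the 2D case; the variable names are illustrative only:

```python
import jax.numpy as jnp

x = jnp.zeros((3, 3), dtype=int)
vals = jnp.array([1, 2, 3])

# fill_diagonal returns a new array; x itself is left untouched.
filled = jnp.fill_diagonal(x, vals, inplace=False)

# The same update written with JAX's indexed-update syntax.
rows, cols = jnp.diag_indices(3)
filled_at = x.at[rows, cols].set(vals)

print(filled)
print(bool(jnp.array_equal(filled, filled_at)))
```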
- scico.numpy.flatnonzero(a, *, size=None, fill_value=None)¶
Return indices of nonzero elements in a flattened array
JAX implementation of numpy.flatnonzero.
jnp.flatnonzero(a) is equivalent to nonzero(ravel(a))[0]. For a full discussion of the parameters to this function, refer to jax.numpy.nonzero.
- Parameters:
a (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – N-dimensional array.size (
int|None) – optional static integer specifying the number of nonzero entries to return. Seejax.numpy.nonzerofor more discussion of this parameter.fill_value (
Union[None,Array,ndarray,bool,number,bool,int,float,complex,tuple[Union[Array,ndarray,bool,number,bool,int,float,complex],...]]) – optional padding value whensizeis specified. Defaults to 0. Seejax.numpy.nonzerofor more discussion of this parameter.
- Return type:
- Returns:
Array containing the indices of each nonzero value in the flattened array.
Examples
>>> x = jnp.array([[0, 5, 0], ... [6, 0, 8]]) >>> jnp.flatnonzero(x) Array([1, 3, 5], dtype=int32)
This is equivalent to calling
nonzeroon the flattened array, and extracting the first entry in the resulting tuple:>>> jnp.nonzero(x.ravel())[0] Array([1, 3, 5], dtype=int32)
The returned indices can be used to extract nonzero entries from the flattened array:
>>> indices = jnp.flatnonzero(x) >>> x.ravel()[indices] Array([5, 6, 8], dtype=int32)
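The `size` and `fill_value` parameters matter mainly when the call sits inside a transformed function. A minimal sketch under `jit`, assuming `jax.numpy` directly; `nonzero_positions` and the padding value `-1` are illustrative choices:

```python
import jax
import jax.numpy as jnp

@jax.jit
def nonzero_positions(a):
    # A static size fixes the output shape; positions beyond the actual
    # number of nonzero entries are padded with fill_value.
    return jnp.flatnonzero(a, size=4, fill_value=-1)

x = jnp.array([[0, 5, 0],
               [6, 0, 8]])
idx = nonzero_positions(x)
print(idx)
```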
- scico.numpy.flip(m, axis=None)¶
Reverse the order of elements of an array along the given axis.
JAX implementation of
numpy.flip.- Parameters:
- Return type:
- Returns:
An array with the elements in reverse order along
axis.
See also
jax.numpy.fliplr: reverse the order along axis 1 (left/right)jax.numpy.flipud: reverse the order along axis 0 (up/down)
Examples
>>> x1 = jnp.array([[1, 2], ... [3, 4]]) >>> jnp.flip(x1) Array([[4, 3], [2, 1]], dtype=int32)
If
axisis specified with an integer, thenjax.numpy.flipreverses the array along that particular axis only.>>> jnp.flip(x1, axis=1) Array([[2, 1], [4, 3]], dtype=int32)
>>> x2 = jnp.arange(1, 9).reshape(2, 2, 2) >>> x2 Array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=int32) >>> jnp.flip(x2) Array([[[8, 7], [6, 5]], [[4, 3], [2, 1]]], dtype=int32)
When
axisis specified with a sequence of integers, thenjax.numpy.flipreverses the array along the specified axes.>>> jnp.flip(x2, axis=[1, 2]) Array([[[4, 3], [2, 1]], [[8, 7], [6, 5]]], dtype=int32)
- scico.numpy.fliplr(m)¶
Reverse the order of elements of an array along axis 1.
JAX implementation of
numpy.fliplr.- Parameters:
m (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – Array with at least two dimensions.- Return type:
- Returns:
An array with the elements in reverse order along axis 1.
See also
jax.numpy.flip: reverse the order along the given axisjax.numpy.flipud: reverse the order along axis 0
Examples
>>> x = jnp.array([[1, 2], ... [3, 4]]) >>> jnp.fliplr(x) Array([[2, 1], [4, 3]], dtype=int32)
- scico.numpy.flipud(m)¶
Reverse the order of elements of an array along axis 0.
JAX implementation of
numpy.flipud.- Parameters:
m (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – Array with at least one dimension.- Return type:
- Returns:
An array with the elements in reverse order along axis 0.
See also
jax.numpy.flip: reverse the order along the given axisjax.numpy.fliplr: reverse the order along axis 1
Examples
>>> x = jnp.array([[1, 2], ... [3, 4]]) >>> jnp.flipud(x) Array([[3, 4], [1, 2]], dtype=int32)
- scico.numpy.float_power(x, y, /)¶
Calculate x raised to the power y, element-wise.
JAX implementation of numpy.float_power.
- Parameters:
- Return type:
- Returns:
An array containing x raised to the power y, element-wise, promoted to an inexact dtype.
See also
jax.numpy.exp: Calculates element-wise exponential of the input.jax.numpy.exp2: Calculates base-2 exponential of each element of the input.
Examples
Inputs with same shape:
>>> x = jnp.array([3, 1, -5]) >>> y = jnp.array([2, 4, -1]) >>> jnp.float_power(x, y) Array([ 9. , 1. , -0.2], dtype=float32)
Inputs with broadcast compatibility:
>>> x1 = jnp.array([[2, -4, 1], ... [-1, 2, 3]]) >>> y1 = jnp.array([-2, 1, 4]) >>> jnp.float_power(x1, y1) Array([[ 0.25, -4. , 1. ], [ 1. , 2. , 81. ]], dtype=float32)
jnp.float_power produces nan for negative values raised to a non-integer power:
>>> jnp.float_power(-3, 1.7)
Array(nan, dtype=float32, weak_type=True)
- scico.numpy.floor(x, /)¶
Round input to the nearest integer downwards.
JAX implementation of
numpy.floor.- Parameters:
x (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – input array or scalar. Must not have complex dtype.- Return type:
- Returns:
An array with same shape and dtype as
xcontaining the values rounded to the nearest integer that is less than or equal to the value itself.
See also
jax.numpy.fix: Rounds the input to the nearest integer towards zero.jax.numpy.trunc: Rounds the input to the nearest integer towards zero.jax.numpy.ceil: Rounds the input up to the nearest integer.
Examples
>>> key = jax.random.key(42) >>> x = jax.random.uniform(key, (3, 3), minval=-5, maxval=5) >>> with jnp.printoptions(precision=2, suppress=True): ... print(x) [[-0.11 1.8 1.16] [ 0.61 -0.49 0.86] [-4.25 2.75 1.99]] >>> jnp.floor(x) Array([[-1., 1., 1.], [ 0., -1., 0.], [-5., 2., 1.]], dtype=float32)
- scico.numpy.floor_divide(x1, x2, /)¶
Calculates the floor division of x1 by x2 element-wise
JAX implementation of
numpy.floor_divide.- Parameters:
- Return type:
- Returns:
An array-like object containing each of the quotients rounded down to the nearest integer towards negative infinity. This is equivalent to
x1 // x2in Python.
Note
x1 // x2 is equivalent to jnp.floor_divide(x1, x2) for arrays x1 and x2.
See also
jax.numpy.divide and jax.numpy.true_divide for floating point division.
Examples
>>> x1 = jnp.array([10, 20, 30]) >>> x2 = jnp.array([3, 4, 7]) >>> jnp.floor_divide(x1, x2) Array([3, 5, 4], dtype=int32)
>>> x1 = jnp.array([-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5]) >>> x2 = 3 >>> jnp.floor_divide(x1, x2) Array([-2, -2, -1, -1, -1, 0, 0, 0, 1, 1, 1], dtype=int32)
>>> x1 = jnp.array([6, 6, 6], dtype=jnp.int32) >>> x2 = jnp.array([2.0, 2.5, 3.0], dtype=jnp.float32) >>> jnp.floor_divide(x1, x2) Array([3., 2., 2.], dtype=float32)
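Rounding toward negative infinity only differs from truncation when the quotient is negative. A small sketch of the contrast, assuming `jax.numpy` directly; the variable names are illustrative:

```python
import jax.numpy as jnp

x1 = jnp.array([-7, 7])
x2 = jnp.array([2, 2])

floored = jnp.floor_divide(x1, x2)   # rounds toward negative infinity
truncated = jnp.trunc(x1 / x2)       # rounds toward zero

print(floored)    # the negative quotient -3.5 rounds down to -4
print(truncated)  # the negative quotient -3.5 rounds toward zero to -3
```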
- scico.numpy.fmax(x1, x2)¶
Return element-wise maximum of the input arrays.
JAX implementation of
numpy.fmax.- Parameters:
- Return type:
- Returns:
An array containing the element-wise maximum of x1 and x2.
Note
- For each pair of elements, jnp.fmax returns:
  - the larger of the two if both elements are finite numbers.
  - the finite number if one element is finite and the other is nan.
  - nan if both elements are nan.
  - inf if one element is inf and the other is finite or nan.
  - -inf if one element is -inf and the other is nan.
Examples
>>> jnp.fmax(3, 7) Array(7, dtype=int32, weak_type=True) >>> jnp.fmax(5, jnp.array([1, 7, 9, 4])) Array([5, 7, 9, 5], dtype=int32)
>>> x1 = jnp.array([1, 3, 7, 8]) >>> x2 = jnp.array([-1, 4, 6, 9]) >>> jnp.fmax(x1, x2) Array([1, 4, 7, 9], dtype=int32)
>>> x3 = jnp.array([[2, 3, 5, 10], ... [11, 9, 7, 5]]) >>> jnp.fmax(x1, x3) Array([[ 2, 3, 7, 10], [11, 9, 7, 8]], dtype=int32)
>>> nan = jnp.nan
>>> x4 = jnp.array([jnp.inf, 6, -jnp.inf, nan])
>>> x5 = jnp.array([[3, 5, 7, nan],
...                 [nan, 9, nan, -1]])
>>> jnp.fmax(x4, x5)
Array([[ inf,   6.,   7.,  nan],
       [ inf,   9., -inf,  -1.]], dtype=float32)
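The NaN rules in the note are what separate `fmax` from `jnp.maximum`, which propagates NaN. A short sketch of the difference, assuming `jax.numpy` directly; the variable names are illustrative:

```python
import jax.numpy as jnp

a = jnp.array([1.0, jnp.nan, jnp.nan])
b = jnp.array([jnp.nan, 2.0, jnp.nan])

ignoring_nan = jnp.fmax(a, b)     # nan only where both inputs are nan
propagating = jnp.maximum(a, b)   # nan wherever either input is nan

print(ignoring_nan)
print(propagating)
```

The same contrast applies to `fmin` and `jnp.minimum`.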
- scico.numpy.fmin(x1, x2)¶
Return element-wise minimum of the input arrays.
JAX implementation of
numpy.fmin.- Parameters:
- Return type:
- Returns:
An array containing the element-wise minimum of x1 and x2.
Note
- For each pair of elements, jnp.fmin returns:
  - the smaller of the two if both elements are finite numbers.
  - the finite number if one element is finite and the other is nan.
  - -inf if one element is -inf and the other is finite or nan.
  - inf if one element is inf and the other is nan.
  - nan if both elements are nan.
Examples
>>> jnp.fmin(2, 3) Array(2, dtype=int32, weak_type=True) >>> jnp.fmin(2, jnp.array([1, 4, 2, -1])) Array([ 1, 2, 2, -1], dtype=int32)
>>> x1 = jnp.array([1, 3, 2]) >>> x2 = jnp.array([2, 1, 4]) >>> jnp.fmin(x1, x2) Array([1, 1, 2], dtype=int32)
>>> x3 = jnp.array([1, 5, 3]) >>> x4 = jnp.array([[2, 3, 1], ... [5, 6, 7]]) >>> jnp.fmin(x3, x4) Array([[1, 3, 1], [1, 5, 3]], dtype=int32)
>>> nan = jnp.nan >>> x5 = jnp.array([jnp.inf, 5, nan]) >>> x6 = jnp.array([[2, 3, nan], ... [nan, 6, 7]]) >>> jnp.fmin(x5, x6) Array([[ 2., 3., nan], [inf, 5., 7.]], dtype=float32)
- scico.numpy.fmod(x1, x2, /)¶
Calculate element-wise floating-point modulo operation.
JAX implementation of
numpy.fmod.- Parameters:
- Return type:
- Returns:
An array containing the result of the element-wise floating-point modulo operation of
x1andx2with same sign as the elements ofx1.
Note
The result of jnp.fmod is equivalent to x1 - x2 * jnp.trunc(x1 / x2).
See also
jax.numpy.mod and jax.numpy.remainder: Returns the element-wise remainder of the division.
jax.numpy.divmod: Calculates the integer quotient and remainder of x1 by x2, element-wise.
Examples
>>> x1 = jnp.array([[3, -1, 4],
...                 [8, 5, -2]])
>>> x2 = jnp.array([2, 3, -5])
>>> jnp.fmod(x1, x2)
Array([[ 1, -1,  4],
       [ 0,  2, -2]], dtype=int32)
>>> x1 - x2 * jnp.trunc(x1 / x2)
Array([[ 1., -1.,  4.],
       [ 0.,  2., -2.]], dtype=float32)
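The sign convention is the practical difference between `fmod` and `remainder`. A minimal sketch covering all four sign combinations, assuming `jax.numpy` directly; the variable names are illustrative:

```python
import jax.numpy as jnp

x1 = jnp.array([-7, 7, -7, 7])
x2 = jnp.array([3, 3, -3, -3])

fmod_result = jnp.fmod(x1, x2)        # sign follows x1 (the dividend)
mod_result = jnp.remainder(x1, x2)    # sign follows x2 (the divisor)

print(fmod_result)
print(mod_result)
```

`jnp.remainder` matches Python's `%` operator on integers, while `jnp.fmod` matches the C library convention.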
- scico.numpy.frexp(x, /)¶
Split floating point values into a mantissa and a power-of-two exponent.
JAX implementation of
numpy.frexp.- Parameters:
x (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – real-valued array- Return type:
- Returns:
A tuple
(mantissa, exponent)wheremantissais a floating point value between -1 and 1, andexponentis an integer such thatx == mantissa * 2 ** exponent.
See also
jax.numpy.ldexp: compute the inverse offrexp.
Examples
Split values into mantissa and exponent:
>>> x = jnp.array([1., 2., 3., 4., 5.]) >>> m, e = jnp.frexp(x) >>> m Array([0.5 , 0.5 , 0.75 , 0.5 , 0.625], dtype=float32) >>> e Array([1, 2, 2, 3, 3], dtype=int32)
Reconstruct the original array:
>>> m * 2 ** e Array([1., 2., 3., 4., 5.], dtype=float32)
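The See also entry names `jax.numpy.ldexp` as the inverse. A small round-trip sketch, assuming `jax.numpy` directly and inputs that are finite and nonzero (for which the mantissa magnitude is expected to lie in [0.5, 1)):

```python
import jax.numpy as jnp

x = jnp.array([0.1, 1.0, 6.0, -40.0])

mantissa, exponent = jnp.frexp(x)
rebuilt = jnp.ldexp(mantissa, exponent)   # mantissa * 2 ** exponent

print(mantissa, exponent)
print(bool(jnp.allclose(rebuilt, x)))
```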
- scico.numpy.from_dlpack(x, /, *, device=None, copy=None)¶
Construct a JAX array via DLPack.
JAX implementation of
numpy.from_dlpack.- Parameters:
x (Any) – An object that implements the DLPack protocol via the __dlpack__ and __dlpack_device__ methods, or a legacy DLPack tensor on either CPU or GPU.
device (Device|Sharding|None) – An optional Device or Sharding, representing the single device onto which the returned array should be placed. If given, then the result is committed to the device. If unspecified, the resulting array will be unpacked onto the same device it originated from. Setting device to a device different from the source of x will require a copy, meaning copy must be set to either True or None.
copy (bool|None) – An optional boolean, controlling whether or not a copy is performed. If copy=True then a copy is always performed, even if unpacked onto the same device. If copy=False then a copy is never performed, and an error is raised if one would be required. When copy=None (default) then a copy may be performed if needed for a device transfer.
- Return type:
- Returns:
A JAX array of the input buffer.
Note
While JAX arrays are always immutable, dlpack buffers cannot be marked as immutable, and it is possible for processes external to JAX to mutate them in-place. If a JAX Array is constructed from a dlpack buffer without copying and the source buffer is later modified in-place, it may lead to undefined behavior when using the associated JAX array.
Examples
Passing data between NumPy and JAX via DLPack:
>>> import numpy as np >>> rng = np.random.default_rng(42) >>> x_numpy = rng.random(4, dtype='float32') >>> print(x_numpy) [0.08925092 0.773956 0.6545715 0.43887842] >>> hasattr(x_numpy, "__dlpack__") # NumPy supports the DLPack interface True
>>> import jax.numpy as jnp >>> x_jax = jnp.from_dlpack(x_numpy) >>> print(x_jax) [0.08925092 0.773956 0.6545715 0.43887842] >>> hasattr(x_jax, "__dlpack__") # JAX supports the DLPack interface True
>>> x_numpy_round_trip = np.from_dlpack(x_jax) >>> print(x_numpy_round_trip) [0.08925092 0.773956 0.6545715 0.43887842]
- scico.numpy.frombuffer(buffer, dtype=<class 'float'>, count=-1, offset=0)¶
Convert a buffer into a 1-D JAX array.
JAX implementation of
numpy.frombuffer.- Parameters:
buffer (
bytes|Any) – an object containing the data. It must be either a bytes object with a length that is an integer multiple of the dtype element size, or it must be an object exporting the Python buffer interface.dtype (
Union[str,type[Any],dtype,SupportsDType]) – optional. Desired data type for the array. Default isfloat64. This specifies the dtype used to parse the buffer, but note that after parsing, 64-bit values will be cast to 32-bit JAX arrays if thejax_enable_x64flag is set toFalse.count (
int) – optional integer specifying the number of items to read from the buffer. If -1 (default), all items from the buffer are read.offset (
int) – optional integer specifying the number of bytes to skip at the beginning of the buffer. Default is 0.
- Return type:
- Returns:
A 1-D JAX array representing the interpreted data from the buffer.
See also
jax.numpy.fromstring: convert a string of text into 1-D JAX array.
Examples
Using a bytes buffer:
>>> buf = b"\x00\x01\x02\x03\x04" >>> jnp.frombuffer(buf, dtype=jnp.uint8) Array([0, 1, 2, 3, 4], dtype=uint8) >>> jnp.frombuffer(buf, dtype=jnp.uint8, offset=1) Array([1, 2, 3, 4], dtype=uint8)
Constructing a JAX array via the Python buffer interface, using Python’s built-in
arraymodule.>>> from array import array >>> pybuffer = array('i', [0, 1, 2, 3, 4]) >>> jnp.frombuffer(pybuffer, dtype=jnp.int32) Array([0, 1, 2, 3, 4], dtype=int32)
- scico.numpy.fromfile(*args, **kwargs)¶
Unimplemented JAX wrapper for jnp.fromfile.
This function is left deliberately unimplemented because it may be non-pure and thus unsafe for use with JIT and other JAX transformations. Consider using jnp.asarray(np.fromfile(...)) instead, although care should be taken if np.fromfile is used within JAX transformations because of its potential side-effect of consuming the file object; for more information see Common Gotchas: Pure Functions.
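A minimal sketch of the suggested workaround, reading the file eagerly with NumPy outside any transformed function and then converting. The temporary file and the `float32` dtype are assumptions made only so the example is self-contained:

```python
import os
import tempfile

import numpy as np
import jax.numpy as jnp

# Write a small binary file so there is something to read back.
data = np.arange(5, dtype=np.float32)
with tempfile.NamedTemporaryFile(suffix=".bin", delete=False) as f:
    path = f.name
data.tofile(path)

# Read with NumPy first (a side-effecting step), then hand the
# resulting array to JAX.
loaded = jnp.asarray(np.fromfile(path, dtype=np.float32))
os.remove(path)

print(loaded)
```

Keeping the read outside of `jit` means the file is consumed once, at a predictable point, rather than at trace time.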
- scico.numpy.fromfunction(function, shape, *, dtype=<class 'float'>, **kwargs)¶
Create an array from a function applied over indices.
JAX implementation of numpy.fromfunction. The JAX implementation differs in that it dispatches via jax.vmap, so unlike in NumPy the function logically operates on scalar inputs and need not explicitly handle broadcasted inputs (see Examples below).
- Parameters:
function (
Callable[...,Array]) – a function that takes N dynamic scalars and outputs a scalar.shape (
Any) – a length-N tuple of integers specifying the output shape.dtype (
Union[str,type[Any],dtype,SupportsDType]) – optionally specify the dtype of the inputs. Defaults to floating-point.kwargs – additional keyword arguments are passed statically to
function.
- Return type:
- Returns:
An array of shape
shapeiffunctionreturns a scalar, or in general a pytree of arrays with leading dimensionsshape, as determined by the output offunction.
See also
jax.vmap: the core transformation that thefromfunctionAPI is built on.
Examples
Generate a multiplication table of a given shape:
>>> jnp.fromfunction(jnp.multiply, shape=(3, 6), dtype=int) Array([[ 0, 0, 0, 0, 0, 0], [ 0, 1, 2, 3, 4, 5], [ 0, 2, 4, 6, 8, 10]], dtype=int32)
When function returns a non-scalar, the output will have leading dimensions shape:
>>> def f(x):
...   return (x + 1) * jnp.arange(3)
>>> jnp.fromfunction(f, shape=(2,))
Array([[0., 1., 2.],
       [0., 2., 4.]], dtype=float32)
function may return multiple results, in which case each is mapped independently:
>>> def f(x, y):
...   return x + y, x * y
>>> x_plus_y, x_times_y = jnp.fromfunction(f, shape=(3, 5))
>>> print(x_plus_y)
[[0. 1. 2. 3. 4.]
 [1. 2. 3. 4. 5.]
 [2. 3. 4. 5. 6.]]
>>> print(x_times_y)
[[0. 0. 0. 0. 0.]
 [0. 1. 2. 3. 4.]
 [0. 2. 4. 6. 8.]]
The JAX implementation differs slightly from NumPy’s implementation. In
numpy.fromfunction, the function is expected to explicitly operate element-wise on the full grid of input values:>>> def f(x, y): ... print(f"{x.shape = }\n{y.shape = }") ... return x + y ... >>> np.fromfunction(f, (2, 3)) x.shape = (2, 3) y.shape = (2, 3) array([[0., 1., 2.], [1., 2., 3.]])
In
jax.numpy.fromfunction, the function is vectorized viajax.vmap, and so is expected to operate on scalar values:>>> jnp.fromfunction(f, (2, 3)) x.shape = () y.shape = () Array([[0., 1., 2.], [1., 2., 3.]], dtype=float32)
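To make the `vmap` dispatch concrete, the sketch below compares `fromfunction` with a hand-written nested `jax.vmap` over index ranges. It illustrates the idea and is not a copy of the library's implementation; the function `f` and the intermediate names are hypothetical:

```python
import jax
import jax.numpy as jnp

def f(i, j):
    # Written for scalar i and j only.
    return 10 * i + j

shape = (2, 3)
via_fromfunction = jnp.fromfunction(f, shape)

# Hand-written equivalent: one vmap per output dimension.
rows = jnp.arange(shape[0], dtype=float)
cols = jnp.arange(shape[1], dtype=float)
over_cols = jax.vmap(f, in_axes=(None, 0))          # vary j, hold i fixed
over_rows = jax.vmap(over_cols, in_axes=(0, None))  # vary i, pass cols whole
via_vmap = over_rows(rows, cols)

print(via_fromfunction)
print(bool(jnp.array_equal(via_fromfunction, via_vmap)))
```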
- scico.numpy.fromiter(*args, **kwargs)¶
Unimplemented JAX wrapper for jnp.fromiter.
This function is left deliberately unimplemented because it may be non-pure and thus unsafe for use with JIT and other JAX transformations. Consider using jnp.asarray(np.fromiter(...)) instead, although care should be taken if np.fromiter is used within JAX transformations because of its potential side-effect of consuming the iterable object; for more information see Common Gotchas: Pure Functions.
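A minimal sketch of the suggested workaround, which also shows the side effect the warning refers to: the iterator is exhausted after the first read. The generator and the `int32` dtype are illustrative choices:

```python
import numpy as np
import jax.numpy as jnp

squares = (k * k for k in range(5))   # a one-shot generator

# Materialize with NumPy outside any transformed function, then convert.
arr = jnp.asarray(np.fromiter(squares, dtype=np.int32))

# The generator has been consumed, so a second pass yields nothing.
leftover = list(squares)

print(arr)
print(leftover)
```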
- scico.numpy.frompyfunc(func, /, nin, nout, *, identity=None)¶
Create a JAX ufunc from an arbitrary JAX-compatible scalar function.
- Parameters:
func (
Callable[...,Any]) – a callable that takes nin scalar arguments and returns nout outputs.nin (
int) – integer specifying the number of scalar inputsnout (
int) – integer specifying the number of scalar outputsidentity (
Any) – (optional) a scalar specifying the identity of the operation, if any.
- Return type:
- Returns:
wrapped – jax.numpy.ufunc wrapper of func.
Examples
Here is an example of creating a ufunc similar to jax.numpy.add:
>>> import operator
>>> add = jnp.frompyfunc(operator.add, nin=2, nout=1, identity=0)
Now all the standard jax.numpy.ufunc methods are available:
>>> x = jnp.arange(4)
>>> add(x, 10)
Array([10, 11, 12, 13], dtype=int32)
>>> add.outer(x, x)
Array([[0, 1, 2, 3],
       [1, 2, 3, 4],
       [2, 3, 4, 5],
       [3, 4, 5, 6]], dtype=int32)
>>> add.reduce(x)
Array(6, dtype=int32)
>>> add.accumulate(x)
Array([0, 1, 3, 6], dtype=int32)
>>> add.at(x, 1, 10, inplace=False)
Array([ 0, 11,  2,  3], dtype=int32)
- scico.numpy.fromstring(string, dtype=<class 'float'>, count=-1, *, sep)¶
Convert a string of text into 1-D JAX array.
JAX implementation of
numpy.fromstring.- Parameters:
string (
str) – input string containing the data.dtype (
Union[str,type[Any],dtype,SupportsDType]) – optional. Desired data type for the array. Default isfloat.count (
int) – optional integer specifying the number of items to read from the string. If -1 (default), all items are read.sep (
str) – the string used to separate values in the input string.
- Return type:
- Returns:
A 1-D JAX array containing the parsed data from the input string.
See also
jax.numpy.frombuffer: construct a JAX array from an object that implements the buffer interface.
Examples
>>> jnp.fromstring("1 2 3", dtype=int, sep=" ") Array([1, 2, 3], dtype=int32) >>> jnp.fromstring("0.1, 0.2, 0.3", dtype=float, count=2, sep=",") Array([0.1, 0.2], dtype=float32)
- scico.numpy.full(shape, fill_value, dtype=None, *, device=None)¶
Create an array full of a specified value.
JAX implementation of
numpy.full.- Parameters:
shape (
Any) – int or sequence of ints specifying the shape of the created array.fill_value (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – scalar or array with which to fill the created array.dtype (
Union[str,type[Any],dtype,SupportsDType,None]) – optional dtype for the created array; defaults to the dtype of the fill value.device (
Device|Sharding|None) – (optional)DeviceorShardingto which the created array will be committed.
- Return type:
- Returns:
Array of the specified shape and dtype, on the specified device if specified.
Examples
>>> jnp.full(4, 2, dtype=float) Array([2., 2., 2., 2.], dtype=float32) >>> jnp.full((2, 3), 0, dtype=bool) Array([[False, False, False], [False, False, False]], dtype=bool)
fill_value may also be an array that is broadcast to the specified shape:
>>> jnp.full((2, 3), fill_value=jnp.arange(3)) Array([[0, 1, 2], [0, 1, 2]], dtype=int32)
- scico.numpy.full_like(a, fill_value, dtype=None, shape=None, *, device=None)¶
Create an array full of a specified value with the same shape and dtype as an array.
JAX implementation of
numpy.full_like.- Parameters:
a (
Union[Array,ndarray,bool,number,bool,int,float,complex,DuckTypedArray]) – Array-like object withshapeanddtypeattributes.fill_value (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – scalar or array with which to fill the created array.shape (
Any) – optionally override the shape of the created array.dtype (
Union[str,type[Any],dtype,SupportsDType,None]) – optionally override the dtype of the created array.device (
Device|Sharding|None) – (optional)DeviceorShardingto which the created array will be committed.
- Return type:
- Returns:
Array of the specified shape and dtype, on the specified device if specified.
Examples
>>> x = jnp.arange(4.0) >>> jnp.full_like(x, 2) Array([2., 2., 2., 2.], dtype=float32) >>> jnp.full_like(x, 0, shape=(2, 3)) Array([[0., 0., 0.], [0., 0., 0.]], dtype=float32)
fill_value may also be an array that is broadcast to the specified shape:
>>> x = jnp.arange(6).reshape(2, 3) >>> jnp.full_like(x, fill_value=jnp.array([[1], [2]])) Array([[1, 1, 1], [2, 2, 2]], dtype=int32)
- scico.numpy.gcd(x1, x2)¶
Compute the greatest common divisor of two arrays.
JAX implementation of
numpy.gcd.- Parameters:
- Return type:
- Returns:
An array containing the greatest common divisors of the corresponding elements from the absolute values of x1 and x2.
See also
jax.numpy.lcm: compute the least common multiple of two arrays.
Examples
Scalar inputs:
>>> jnp.gcd(12, 18) Array(6, dtype=int32, weak_type=True)
Array inputs:
>>> x1 = jnp.array([12, 18, 24]) >>> x2 = jnp.array([5, 10, 15]) >>> jnp.gcd(x1, x2) Array([1, 2, 3], dtype=int32)
Broadcasting:
>>> x1 = jnp.array([12]) >>> x2 = jnp.array([6, 9, 12]) >>> jnp.gcd(x1, x2) Array([ 6, 3, 12], dtype=int32)
- scico.numpy.geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0)¶
Generate geometrically-spaced values.
JAX implementation of
numpy.geomspace.- Parameters:
start (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – scalar or array. Specifies the starting values.stop (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – scalar or array. Specifies the stop values.num (
int) – int, optional, default=50. Number of values to generate.endpoint (
bool) – bool, optional, default=True. If True, then include thestopvalue in the result. If False, then exclude thestopvalue.dtype (
Union[str,type[Any],dtype,SupportsDType,None]) – optional. Specifies the dtype of the output.axis (
int) – int, optional, default=0. Axis along which to generate the geomspace.
- Return type:
- Returns:
An array containing the geometrically-spaced values.
See also
jax.numpy.arange: Generate N evenly-spaced values given a starting point and a step value.
jax.numpy.linspace: Generate evenly-spaced values.
jax.numpy.logspace: Generate logarithmically-spaced values.
Examples
List 5 geometrically-spaced values between 1 and 16:
>>> with jnp.printoptions(precision=3, suppress=True): ... jnp.geomspace(1, 16, 5) Array([ 1., 2., 4., 8., 16.], dtype=float32)
List 4 geometrically-spaced values between 1 and 16, with endpoint=False:
>>> with jnp.printoptions(precision=3, suppress=True):
...     jnp.geomspace(1, 16, 4, endpoint=False)
Array([1., 2., 4., 8.], dtype=float32)
Multi-dimensional geomspace:
>>> start = jnp.array([1, 1000]) >>> stop = jnp.array([27, 1]) >>> with jnp.printoptions(precision=3, suppress=True): ... jnp.geomspace(start, stop, 4) Array([[ 1., 1000.], [ 3., 100.], [ 9., 10.], [ 27., 1.]], dtype=float32)
- scico.numpy.get_printoptions()¶
Alias of numpy.get_printoptions.
JAX arrays are printed via NumPy, so NumPy’s printoptions configuration applies to printed JAX arrays.
See the numpy.set_printoptions documentation for details on the available options and their meanings.
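As a small illustration of the point above, the sketch below changes NumPy's print precision and checks that a JAX array's string form follows it. The variable names are illustrative, and jax.numpy is used directly on the assumption that scico.numpy arrays print the same way.

```python
import numpy as np
import jax.numpy as jnp

x = jnp.array([1.23456, 2.34567])

# Temporarily lower NumPy's print precision; the JAX array picks it up
# because its string form is produced by NumPy.
with np.printoptions(precision=2):
    short = str(x)

# Outside the context manager the previous setting is restored.
current_precision = np.get_printoptions()["precision"]
print(short, current_precision)
```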
- scico.numpy.gradient(f, *varargs, axis=None, edge_order=None)¶
Compute the numerical gradient of a sampled function.
JAX implementation of
numpy.gradient.The gradient in
jnp.gradientis computed using second-order finite differences across the array of sampled function values. This should not be confused withjax.grad, which computes a precise gradient of a callable function via automatic differentiation.- Parameters:
f (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – N-dimensional array of function values.varargs (
Union[Array,ndarray,bool,number,bool,int,float,complex]) –optional list of scalars or arrays specifying spacing of function evaluations. Options are:
not specified: unit spacing in all dimensions.
a single scalar: constant spacing in all dimensions.
N values: specify different spacing in each dimension:
scalar values indicate constant spacing in that dimension.
array values must match the length of the corresponding dimension, and specify the coordinates at which
fis evaluated.
axis (
int|Sequence[int] |None) – integer or tuple of integers specifying the axis along which to compute the gradient. If None (default) calculates the gradient along all axes.
- Return type:
- Returns:
an array or tuple of arrays containing the numerical gradient along each specified axis.
See also
jax.grad: automatic differentiation of a function with a single output.
Examples
Comparing numerical and automatic differentiation of a simple function:
>>> def f(x): ... return jnp.sin(x) * jnp.exp(-x / 4) ... >>> def gradf_exact(x): ... # exact analytical gradient of f(x) ... return -f(x) / 4 + jnp.cos(x) * jnp.exp(-x / 4) ... >>> x = jnp.linspace(0, 5, 10)
>>> with jnp.printoptions(precision=2, suppress=True): ... print("numerical gradient:", jnp.gradient(f(x), x)) ... print("automatic gradient:", jax.vmap(jax.grad(f))(x)) ... print("exact gradient: ", gradf_exact(x)) ... numerical gradient: [ 0.83 0.61 0.18 -0.2 -0.43 -0.49 -0.39 -0.21 -0.02 0.08] automatic gradient: [ 1. 0.62 0.17 -0.23 -0.46 -0.51 -0.41 -0.21 -0.01 0.15] exact gradient: [ 1. 0.62 0.17 -0.23 -0.46 -0.51 -0.41 -0.21 -0.01 0.15]
Notice that, as expected, the numerical gradient has some approximation error compared to the automatic gradient computed via
jax.grad.
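The spacing options listed under varargs can be sketched as below, assuming a uniform grid with an illustrative step dx. For linearly varying samples the finite-difference result matches the true slope, which makes the output easy to check.

```python
import jax.numpy as jnp

dx = 0.5                       # illustrative grid spacing
x = jnp.arange(6) * dx
y = 2.0 * x + 1.0              # linear samples with slope 2

# A single scalar gives constant spacing along the only axis.
g = jnp.gradient(y, dx)

# For a 2-D input and axis=None, one gradient array is returned per axis,
# and the single scalar spacing is applied to every axis.
z = jnp.outer(x, x)
gz = jnp.gradient(z, dx)
print(g, len(gz))
```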
- scico.numpy.greater(x, y, /)¶
Return element-wise truth value of
x > y.JAX implementation of
numpy.greater.- Parameters:
- Return type:
- Returns:
An array containing boolean values.
Trueif the elements ofx > y, andFalseotherwise.
See also
jax.numpy.less: Returns element-wise truth value ofx < y.jax.numpy.greater_equal: Returns element-wise truth value ofx >= y.jax.numpy.less_equal: Returns element-wise truth value ofx <= y.
Examples
Scalar inputs:
>>> jnp.greater(5, 2) Array(True, dtype=bool, weak_type=True)
Inputs with same shape:
>>> x = jnp.array([5, 9, -2]) >>> y = jnp.array([4, -1, 6]) >>> jnp.greater(x, y) Array([ True, True, False], dtype=bool)
Inputs with broadcast compatibility:
>>> x1 = jnp.array([[5, -6, 7], ... [-2, 5, 9]]) >>> y1 = jnp.array([-4, 3, 10]) >>> jnp.greater(x1, y1) Array([[ True, False, False], [ True, True, False]], dtype=bool)
- scico.numpy.greater_equal(x, y, /)¶
Return element-wise truth value of
x >= y.JAX implementation of
numpy.greater_equal.- Parameters:
- Return type:
- Returns:
An array containing boolean values.
Trueif the elements ofx >= y, andFalseotherwise.
See also
jax.numpy.less_equal: Returns element-wise truth value ofx <= y.jax.numpy.greater: Returns element-wise truth value ofx > y.jax.numpy.less: Returns element-wise truth value ofx < y.
Examples
Scalar inputs:
>>> jnp.greater_equal(4, 7) Array(False, dtype=bool, weak_type=True)
Inputs with same shape:
>>> x = jnp.array([2, 5, -1]) >>> y = jnp.array([-6, 4, 3]) >>> jnp.greater_equal(x, y) Array([ True, True, False], dtype=bool)
Inputs with broadcast compatibility:
>>> x1 = jnp.array([[3, -1, 4], ... [5, 9, -6]]) >>> y1 = jnp.array([-1, 4, 2]) >>> jnp.greater_equal(x1, y1) Array([[ True, False, True], [ True, True, False]], dtype=bool)
- scico.numpy.hamming(M)¶
Return a Hamming window of size M.
JAX implementation of
numpy.hamming.- Parameters:
M (
int) – The window size.- Return type:
- Returns:
An array of size M containing the Hamming window.
Examples
>>> with jnp.printoptions(precision=2, suppress=True): ... print(jnp.hamming(4)) [0.08 0.77 0.77 0.08]
See also
jax.numpy.bartlett: return a Bartlett window of size M.jax.numpy.blackman: return a Blackman window of size M.jax.numpy.hanning: return a Hanning window of size M.jax.numpy.kaiser: return a Kaiser window of size M.
- scico.numpy.hanning(M)¶
Return a Hanning window of size M.
JAX implementation of
numpy.hanning.- Parameters:
M (
int) – The window size.- Return type:
- Returns:
An array of size M containing the Hanning window.
Examples
>>> with jnp.printoptions(precision=2, suppress=True): ... print(jnp.hanning(4)) [0. 0.75 0.75 0. ]
See also
jax.numpy.bartlett: return a Bartlett window of size M.jax.numpy.blackman: return a Blackman window of size M.jax.numpy.hamming: return a Hamming window of size M.jax.numpy.kaiser: return a Kaiser window of size M.
- scico.numpy.heaviside(x1, x2, /)¶
Compute the heaviside step function.
JAX implementation of numpy.heaviside.
The heaviside step function is defined by:
\[\begin{split}\mathrm{heaviside}(x1, x2) = \begin{cases} 0, & x1 < 0\\ x2, & x1 = 0\\ 1, & x1 > 0. \end{cases}\end{split}\]
- Parameters:
x1 (Union[Array,ndarray,bool,number,bool,int,float,complex]) – input array or scalar. complex dtypes are not supported.
x2 (Union[Array,ndarray,bool,number,bool,int,float,complex]) – scalar or array. Specifies the return values when x1 is 0. complex dtypes are not supported. x1 and x2 must either have the same shape or be broadcast compatible.
- Return type:
- Returns:
An array containing the heaviside step function of
x1, promoting to inexact dtype.
Examples
>>> x1 = jnp.array([[-2, 0, 3], ... [5, -1, 0], ... [0, 7, -3]]) >>> x2 = jnp.array([2, 0.5, 1]) >>> jnp.heaviside(x1, x2) Array([[0. , 0.5, 1. ], [1. , 0. , 1. ], [2. , 1. , 0. ]], dtype=float32) >>> jnp.heaviside(x1, 0.5) Array([[0. , 0.5, 1. ], [1. , 0. , 0.5], [0.5, 1. , 0. ]], dtype=float32) >>> jnp.heaviside(-3, x2) Array([0., 0., 0.], dtype=float32)
- scico.numpy.histogram(a, bins=10, range=None, weights=None, density=None)¶
Compute a 1-dimensional histogram.
JAX implementation of
numpy.histogram.- Parameters:
a (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – array of values to be binned. May be any size or dimension.bins (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – Specify the number of bins in the histogram (default: 10).binsmay also be an array specifying the locations of the bin edges.range (
Sequence[Union[Array,ndarray,bool,number,bool,int,float,complex]] |None) – tuple of scalars. Specifies the range of the data. If not specified, the range is inferred from the data.weights (
Union[Array,ndarray,bool,number,bool,int,float,complex,None]) – An optional array specifying the weights of the data points. Should be broadcast-compatible witha. If not specified, each data point is weighted equally.density (
bool|None) – If True, return the normalized histogram in units of counts per unit length. If False (default) return the (weighted) counts per bin.
- Return type:
- Returns:
A tuple of arrays
(histogram, bin_edges), wherehistogramcontains the aggregated data, andbin_edgesspecifies the boundaries of the bins.
See also
jax.numpy.bincount: Count the number of occurrences of each value in an array.jax.numpy.histogram2d: Compute the histogram of a 2D array.jax.numpy.histogramdd: Compute the histogram of an N-dimensional array.jax.numpy.histogram_bin_edges: Compute the bin edges for a histogram.
Examples
>>> a = jnp.array([1, 2, 3, 10, 11, 15, 19, 25]) >>> counts, bin_edges = jnp.histogram(a, bins=8) >>> print(counts) [3. 0. 0. 2. 1. 0. 1. 1.] >>> print(bin_edges) [ 1. 4. 7. 10. 13. 16. 19. 22. 25.]
Specifying the bin range:
>>> counts, bin_edges = jnp.histogram(a, range=(0, 25), bins=5) >>> print(counts) [3. 0. 2. 2. 1.] >>> print(bin_edges) [ 0. 5. 10. 15. 20. 25.]
Specifying the bin edges explicitly:
>>> bin_edges = jnp.array([0, 10, 20, 30]) >>> counts, _ = jnp.histogram(a, bins=bin_edges) >>> print(counts) [3. 4. 1.]
Using
density=Truereturns a normalized histogram:>>> density, bin_edges = jnp.histogram(a, density=True) >>> dx = jnp.diff(bin_edges) >>> normed_sum = jnp.sum(density * dx) >>> jnp.allclose(normed_sum, 1.0) Array(True, dtype=bool)
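The weights parameter is not exercised in the examples above. A minimal sketch, with made-up data and bin edges, comparing plain and weighted counts:

```python
import jax.numpy as jnp

a = jnp.array([0.5, 1.5, 1.5, 2.5])      # illustrative data
w = jnp.array([1.0, 2.0, 3.0, 4.0])      # one weight per data point
edges = jnp.array([0.0, 1.0, 2.0, 3.0])  # explicit bin edges

# Unweighted: each point contributes 1 to its bin.
counts, _ = jnp.histogram(a, bins=edges)

# Weighted: each point contributes its weight to its bin.
weighted, _ = jnp.histogram(a, bins=edges, weights=w)
print(counts, weighted)
```

Here the middle bin holds two points, so its weighted total is the sum of their two weights.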
- scico.numpy.histogram2d(x, y, bins=10, range=None, weights=None, density=None)¶
Compute a 2-dimensional histogram.
JAX implementation of
numpy.histogram2d.- Parameters:
x (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – one-dimensional array of x-values for points to be binned.y (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – one-dimensional array of y-values for points to be binned.bins (
Union[Array,ndarray,bool,number,bool,int,float,complex,list[Union[Array,ndarray,bool,number,bool,int,float,complex]]]) – Specify the number of bins in the histogram (default: 10).binsmay also be an array specifying the locations of the bin edges, or a pair of integers or pair of arrays specifying the number of bins in each dimension.range (
Sequence[None|Array|Sequence[Union[Array,ndarray,bool,number,bool,int,float,complex]]] |None) – Pair of arrays or lists of the form[[xmin, xmax], [ymin, ymax]]specifying the range of the data in each dimension. If not specified, the range is inferred from the data.weights (
Union[Array,ndarray,bool,number,bool,int,float,complex,None]) – An optional array specifying the weights of the data points. Should be the same shape asxandy. If not specified, each data point is weighted equally.density (
bool|None) – If True, return the normalized histogram in units of counts per unit area. If False (default) return the (weighted) counts per bin.
- Return type:
- Returns:
A tuple of arrays
(histogram, x_edges, y_edges), wherehistogramcontains the aggregated data, andx_edgesandy_edgesspecify the boundaries of the bins.
See also
jax.numpy.histogram: Compute the histogram of a 1D array.jax.numpy.histogramdd: Compute the histogram of an N-dimensional array.jax.numpy.histogram_bin_edges: Compute the bin edges for a histogram.
Examples
>>> x = jnp.array([1, 2, 3, 10, 11, 15, 19, 25]) >>> y = jnp.array([2, 5, 6, 8, 13, 16, 17, 18]) >>> counts, x_edges, y_edges = jnp.histogram2d(x, y, bins=8) >>> counts.shape (8, 8) >>> x_edges Array([ 1., 4., 7., 10., 13., 16., 19., 22., 25.], dtype=float32) >>> y_edges Array([ 2., 4., 6., 8., 10., 12., 14., 16., 18.], dtype=float32)
Specifying the bin range:
>>> counts, x_edges, y_edges = jnp.histogram2d(x, y, range=[(0, 25), (0, 25)], bins=5) >>> counts.shape (5, 5) >>> x_edges Array([ 0., 5., 10., 15., 20., 25.], dtype=float32) >>> y_edges Array([ 0., 5., 10., 15., 20., 25.], dtype=float32)
Specifying the bin edges explicitly:
>>> x_edges = jnp.array([0, 10, 20, 30]) >>> y_edges = jnp.array([0, 10, 20, 30]) >>> counts, _, _ = jnp.histogram2d(x, y, bins=[x_edges, y_edges]) >>> counts Array([[3, 0, 0], [1, 3, 0], [0, 1, 0]], dtype=int32)
Using
density=Truereturns a normalized histogram:>>> density, x_edges, y_edges = jnp.histogram2d(x, y, density=True) >>> dx = jnp.diff(x_edges) >>> dy = jnp.diff(y_edges) >>> normed_sum = jnp.sum(density * dx[:, None] * dy[None, :]) >>> jnp.allclose(normed_sum, 1.0) Array(True, dtype=bool)
- scico.numpy.histogram_bin_edges(a, bins=10, range=None, weights=None)¶
Compute the bin edges for a histogram.
JAX implementation of
numpy.histogram_bin_edges.- Parameters:
a (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – array of values to be binnedbins (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – Specify the number of bins in the histogram (default: 10).range (
None|Array|Sequence[Union[Array,ndarray,bool,number,bool,int,float,complex]]) – tuple of scalars. Specifies the range of the data. If not specified, the range is inferred from the data.weights (
Union[Array,ndarray,bool,number,bool,int,float,complex,None]) – unused by JAX.
- Return type:
- Returns:
An array of bin edges for the histogram.
See also
jax.numpy.histogram: compute a 1D histogram.jax.numpy.histogram2d: compute a 2D histogram.jax.numpy.histogramdd: compute an N-dimensional histogram.
Examples
>>> a = jnp.array([2, 5, 3, 6, 4, 1]) >>> jnp.histogram_bin_edges(a, bins=5) Array([1., 2., 3., 4., 5., 6.], dtype=float32) >>> jnp.histogram_bin_edges(a, bins=5, range=(-10, 10)) Array([-10., -6., -2., 2., 6., 10.], dtype=float32)
- scico.numpy.histogramdd(sample, bins=10, range=None, weights=None, density=None)¶
Compute an N-dimensional histogram.
JAX implementation of
numpy.histogramdd.- Parameters:
sample (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – input array of shape(N, D)representingNpoints inDdimensions.bins (
Union[Array,ndarray,bool,number,bool,int,float,complex,list[Union[Array,ndarray,bool,number,bool,int,float,complex]]]) – Specify the number of bins in each dimension of the histogram. (default: 10). May also be a length-D sequence of integers or arrays of bin edges.range (
Sequence[None|Array|Sequence[Union[Array,ndarray,bool,number,bool,int,float,complex]]] |None) – Length-D sequence of pairs specifying the range for each dimension. If not specified, the range is inferred from the data.weights (
Union[Array,ndarray,bool,number,bool,int,float,complex,None]) – An optional shape (N,) array specifying the weights of the data points, with one weight per point in sample. If not specified, each data point is weighted equally. density (
bool|None) – If True, return the normalized histogram in units of counts per unit volume. If False (default) return the (weighted) counts per bin.
- Return type:
- Returns:
A tuple of arrays
(histogram, bin_edges), wherehistogramcontains the aggregated data, andbin_edgesspecifies the boundaries of the bins.
See also
jax.numpy.histogram: Compute the histogram of a 1D array.jax.numpy.histogram2d: Compute the histogram of a 2D array.jax.numpy.histogram_bin_edges: Compute the bin edges for a histogram.
Examples
A histogram over 100 points in three dimensions:
>>> key = jax.random.key(42) >>> a = jax.random.normal(key, (100, 3)) >>> counts, bin_edges = jnp.histogramdd(a, bins=6, ... range=[(-3, 3), (-3, 3), (-3, 3)]) >>> counts.shape (6, 6, 6) >>> bin_edges [Array([-3., -2., -1., 0., 1., 2., 3.], dtype=float32), Array([-3., -2., -1., 0., 1., 2., 3.], dtype=float32), Array([-3., -2., -1., 0., 1., 2., 3.], dtype=float32)]
Using
density=Truereturns a normalized histogram:>>> density, bin_edges = jnp.histogramdd(a, density=True) >>> bin_widths = map(jnp.diff, bin_edges) >>> dx, dy, dz = jnp.meshgrid(*bin_widths, indexing='ij') >>> normed = jnp.sum(density * dx * dy * dz) >>> jnp.allclose(normed, 1.0) Array(True, dtype=bool)
- scico.numpy.hsplit(ary, indices_or_sections)¶
Split an array into sub-arrays horizontally.
JAX implementation of numpy.hsplit.
Refer to the documentation of jax.numpy.split for details. hsplit is equivalent to split with axis=1, or axis=0 for one-dimensional arrays.
Examples
1D array:
>>> x = jnp.array([1, 2, 3, 4, 5, 6]) >>> x1, x2 = jnp.hsplit(x, 2) >>> print(x1, x2) [1 2 3] [4 5 6]
2D array:
>>> x = jnp.array([[1, 2, 3, 4], ... [5, 6, 7, 8]]) >>> x1, x2 = jnp.hsplit(x, 2) >>> print(x1) [[1 2] [5 6]] >>> print(x2) [[3 4] [7 8]]
See also
jax.numpy.split: split an array along any axis.
jax.numpy.vsplit: split vertically, i.e. along axis=0.
jax.numpy.dsplit: split depth-wise, i.e. along axis=2.
jax.numpy.array_split: like split, but allows indices_or_sections to be an integer that does not evenly divide the size of the array.
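The stated equivalence between hsplit and split with axis=1 can be checked with a short sketch. The array and the split indices are illustrative; a list of indices is used here rather than a section count.

```python
import jax.numpy as jnp

x = jnp.arange(12).reshape(3, 4)

# Split at column indices 1 and 3: columns [0], [1, 2], [3].
parts = jnp.hsplit(x, [1, 3])

# The same split expressed through jnp.split along axis 1.
ref = jnp.split(x, [1, 3], axis=1)

for p, r in zip(parts, ref):
    print(p.shape, bool(jnp.array_equal(p, r)))
```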
- scico.numpy.hstack(tup, dtype=None)¶
Horizontally stack arrays.
JAX implementation of
numpy.hstack.For arrays of one or more dimensions, this is equivalent to
jax.numpy.concatenatewithaxis=1.- Parameters:
tup (
ndarray|Array|Sequence[Union[Array,ndarray,bool,number,bool,int,float,complex]]) – a sequence of arrays to stack; each must have the same shape along all but the second axis. Input arrays will be promoted to at least rank 1. If a single array is given it will be treated equivalently to tup = unstack(tup), but the implementation will avoid explicit unstacking.dtype (
Union[str,type[Any],dtype,SupportsDType,None]) – optional dtype of the resulting array. If not specified, the dtype will be determined via type promotion rules described in Type promotion semantics.
- Return type:
- Returns:
the stacked result.
See also
jax.numpy.stack: stack along arbitrary axesjax.numpy.concatenate: concatenation along existing axes.jax.numpy.vstack: stack vertically, i.e. along axis 0.jax.numpy.dstack: stack depth-wise, i.e. along axis 2.
Examples
Scalar values:
>>> jnp.hstack([1, 2, 3]) Array([1, 2, 3], dtype=int32, weak_type=True)
1D arrays:
>>> x = jnp.arange(3) >>> y = jnp.ones(3) >>> jnp.hstack([x, y]) Array([0., 1., 2., 1., 1., 1.], dtype=float32)
2D arrays:
>>> x = x.reshape(3, 1) >>> y = y.reshape(3, 1) >>> jnp.hstack([x, y]) Array([[0., 1.], [1., 1.], [2., 1.]], dtype=float32)
- scico.numpy.hypot(x1, x2, /)¶
Return element-wise hypotenuse for the given legs of a right angle triangle.
JAX implementation of
numpy.hypot.- Parameters:
x1 (Union[Array,ndarray,bool,number,bool,int,float,complex]) – scalar or array. Specifies one of the legs of a right-angle triangle. complex dtypes are not supported.
x2 (Union[Array,ndarray,bool,number,bool,int,float,complex]) – scalar or array. Specifies the other leg of the right-angle triangle. complex dtypes are not supported. x1 and x2 must either have the same shape or be broadcast compatible.
- Return type:
- Returns:
An array containing the hypotenuse for the given legs x1 and x2 of a right-angle triangle, promoting to inexact dtype.
Note
jnp.hypot is a more numerically stable way of computing jnp.sqrt(x1 ** 2 + x2 ** 2).
Examples
>>> jnp.hypot(3, 4) Array(5., dtype=float32, weak_type=True) >>> x1 = jnp.array([[3, -2, 5], ... [9, 1, -4]]) >>> x2 = jnp.array([-5, 6, 8]) >>> with jnp.printoptions(precision=3, suppress=True): ... jnp.hypot(x1, x2) Array([[ 5.831, 6.325, 9.434], [10.296, 6.083, 8.944]], dtype=float32)
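The stability remark in the note above can be illustrated with inputs whose squares exceed the float32 range. This is a sketch with an illustrative value of 1e20; the naive expression is expected to overflow while hypot should stay finite.

```python
import jax.numpy as jnp

# 1e20 is representable in float32, but its square (1e40) is not.
x = jnp.array([1e20], dtype=jnp.float32)

# Naive formula: the intermediate squares overflow.
naive = jnp.sqrt(x ** 2 + x ** 2)

# hypot avoids forming the squares of the raw inputs.
stable = jnp.hypot(x, x)
print(naive, stable)
```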
- scico.numpy.i0(x)¶
Calculate the modified Bessel function of the first kind, zeroth order.
JAX implementation of numpy.i0.
The modified Bessel function of the first kind, zeroth order, is defined by:
\[\mathrm{i0}(x) = I_0(x) = \sum_{k=0}^{\infty} \frac{(x^2/4)^k}{(k!)^2}\]
- Parameters:
x (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – scalar or array. Specifies the argument of Bessel function. Complex inputs are not supported.- Return type:
- Returns:
An array containing the corresponding values of the modified Bessel function of
x.
See also
jax.scipy.special.i0: Calculates the modified Bessel function of zeroth order.jax.scipy.special.i1: Calculates the modified Bessel function of first order.jax.scipy.special.i0e: Calculates the exponentially scaled modified Bessel function of zeroth order.
Examples
>>> x = jnp.array([-2, -1, 0, 1, 2]) >>> jnp.i0(x) Array([2.2795851, 1.266066 , 1.0000001, 1.266066 , 2.2795851], dtype=float32)
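The series definition above can be checked against jnp.i0 with a truncated sum. The helper i0_series and the number of terms are illustrative choices, not part of the library.

```python
import math
import jax.numpy as jnp


def i0_series(x, terms=20):
    """Truncated power series for I_0(x); hypothetical helper for illustration."""
    return sum((x * x / 4.0) ** k / math.factorial(k) ** 2 for k in range(terms))


x = 1.5
approx = i0_series(x)     # plain Python float
ref = float(jnp.i0(x))    # float32 result from JAX
print(approx, ref)
```

For moderate arguments the terms decay quickly, so a small number of terms should agree with the library value to within float32 precision.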
- scico.numpy.identity(n, dtype=None)¶
Create a square identity matrix.
JAX implementation of
numpy.identity.- Parameters:
- Return type:
- Returns:
Identity array of shape
(n, n).
See also
jax.numpy.eye: non-square and/or offset identity matrices.Examples
A simple 3x3 identity matrix:
>>> jnp.identity(3) Array([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]], dtype=float32)
A 2x2 integer identity matrix:
>>> jnp.identity(2, dtype=int) Array([[1, 0], [0, 1]], dtype=int32)
- scico.numpy.imag(val, /)¶
Return the element-wise imaginary part of the complex argument.
JAX implementation of
numpy.imag.- Parameters:
val (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – input array or scalar.- Return type:
- Returns:
An array containing the imaginary part of the elements of
val.
See also
jax.numpy.conjugateandjax.numpy.conj: Returns the element-wise complex-conjugate of the input.jax.numpy.real: Returns the element-wise real part of the complex argument.
Examples
>>> jnp.imag(4) Array(0, dtype=int32, weak_type=True) >>> jnp.imag(5j) Array(5., dtype=float32, weak_type=True) >>> x = jnp.array([2+3j, 5-1j, -3]) >>> jnp.imag(x) Array([ 3., -1., 0.], dtype=float32)
- scico.numpy.indices(dimensions, dtype=None, sparse=False)¶
Generate arrays of grid indices.
JAX implementation of
numpy.indices.- Parameters:
- Return type:
- Returns:
An array of shape
(len(dimensions), *dimensions) if sparse is False, or a sequence of arrays of the same length as dimensions if sparse is True.
See also
jax.numpy.meshgrid: generate a grid from arbitrary input arrays.jax.numpy.mgrid: generate dense indices using a slicing syntax.jax.numpy.ogrid: generate sparse indices using a slicing syntax.
Examples
>>> jnp.indices((2, 3)) Array([[[0, 0, 0], [1, 1, 1]], [[0, 1, 2], [0, 1, 2]]], dtype=int32) >>> jnp.indices((2, 3), sparse=True) (Array([[0], [1]], dtype=int32), Array([[0, 1, 2]], dtype=int32))
- scico.numpy.inner(a, b, *, precision=None, preferred_element_type=None)¶
Compute the inner product of two arrays.
JAX implementation of numpy.inner.
Unlike jax.numpy.matmul or jax.numpy.dot, this always performs a contraction along the last dimension of each input.
- Parameters:
a (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – array of shape(..., N)b (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – array of shape(..., N)precision (
Union[None,str,Precision,tuple[str,str],tuple[Precision,Precision],DotAlgorithm,DotAlgorithmPreset]) – eitherNone(default), which means the default precision for the backend, aPrecisionenum value (Precision.DEFAULT,Precision.HIGHorPrecision.HIGHEST) or a tuple of two such values indicating precision ofaandb.preferred_element_type (
Union[str,type[Any],dtype,SupportsDType,None]) – eitherNone(default), which means the default accumulation type for the input types, or a datatype, indicating to accumulate results to and return a result with that datatype.
- Return type:
- Returns:
array of shape
(*a.shape[:-1], *b.shape[:-1])containing the batched vector product of the inputs.
See also
jax.numpy.vecdot: conjugate multiplication along a specified axis.jax.numpy.tensordot: general tensor multiplication.jax.numpy.matmul: general batched matrix & vector multiplication.
Examples
For 1D inputs, this implements standard (non-conjugate) vector multiplication:
>>> a = jnp.array([1j, 3j, 4j]) >>> b = jnp.array([4., 2., 5.]) >>> jnp.inner(a, b) Array(0.+30.j, dtype=complex64)
For multi-dimensional inputs, batch dimensions are stacked rather than broadcast:
>>> a = jnp.ones((2, 3)) >>> b = jnp.ones((5, 3)) >>> jnp.inner(a, b).shape (2, 5)
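The contraction over the last dimension, with batch dimensions stacked, can be written explicitly with einsum. A sketch with illustrative 2-D inputs:

```python
import jax.numpy as jnp

a = jnp.arange(6.0).reshape(2, 3)
b = jnp.arange(15.0).reshape(5, 3)

# inner contracts the last axis of each input.
out = jnp.inner(a, b)

# Equivalent explicit contraction: out[i, j] = sum_n a[i, n] * b[j, n].
ref = jnp.einsum("in,jn->ij", a, b)
print(out.shape, bool(jnp.allclose(out, ref)))
```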
- scico.numpy.insert(arr, obj, values, axis=None)¶
Insert entries into an array at specified indices.
JAX implementation of
numpy.insert.- Parameters:
arr (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – array object into which values will be inserted.obj (
Union[Array,ndarray,bool,number,bool,int,float,complex,slice]) – slice or array of indices specifying insertion locations.values (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – array of values to be inserted.axis (
int|None) – specify the insertion axis in the case of multi-dimensional arrays. If unspecified,arrwill be flattened.
- Return type:
- Returns:
A copy of
arrwith values inserted at the specified locations.
See also
jax.numpy.delete: delete entries from an array.
Examples
Inserting a single value:
>>> x = jnp.arange(5) >>> jnp.insert(x, 2, 99) Array([ 0, 1, 99, 2, 3, 4], dtype=int32)
Inserting multiple identical values using a slice:
>>> jnp.insert(x, slice(None, None, 2), -1) Array([-1, 0, 1, -1, 2, 3, -1, 4], dtype=int32)
Inserting multiple values using an index:
>>> indices = jnp.array([4, 2, 5]) >>> values = jnp.array([10, 11, 12]) >>> jnp.insert(x, indices, values) Array([ 0, 1, 11, 2, 3, 10, 4, 12], dtype=int32)
Inserting columns into a 2D array:
>>> x = jnp.array([[1, 2, 3], ... [4, 5, 6]]) >>> indices = jnp.array([1, 3]) >>> values = jnp.array([[10, 11], ... [12, 13]]) >>> jnp.insert(x, indices, values, axis=1) Array([[ 1, 10, 2, 3, 11], [ 4, 12, 5, 6, 13]], dtype=int32)
- scico.numpy.interp(x, xp, fp, left=None, right=None, period=None)¶
One-dimensional linear interpolation.
JAX implementation of
numpy.interp.- Parameters:
x (Union[Array,ndarray,bool,number,bool,int,float,complex]) – N-dimensional array of x coordinates at which to evaluate the interpolation.
xp (Union[Array,ndarray,bool,number,bool,int,float,complex]) – one-dimensional sorted array of points to be interpolated.
fp (Union[Array,ndarray,bool,number,bool,int,float,complex]) – array of shape xp.shape containing the function values associated with xp.
left (Union[Array,ndarray,bool,number,bool,int,float,complex,str,None]) – specify how to handle points x < xp[0]. Default is to return fp[0]. If left is a scalar value, it will return this value. If left is the string "extrapolate", then the value will be determined by linear extrapolation. left is ignored if period is specified.
right (Union[Array,ndarray,bool,number,bool,int,float,complex,str,None]) – specify how to handle points x > xp[-1]. Default is to return fp[-1]. If right is a scalar value, it will return this value. If right is the string "extrapolate", then the value will be determined by linear extrapolation. right is ignored if period is specified.
period (Union[Array,ndarray,bool,number,bool,int,float,complex,None]) – optionally specify the period for the x coordinates, e.g. for interpolation in angular space.
- Return type:
- Returns:
an array of shape
x.shapecontaining the interpolated function at valuesx.
Examples
>>> xp = jnp.arange(10)
>>> fp = 2 * xp
>>> x = jnp.array([0.5, 2.0, 3.5])
>>> jnp.interp(x, xp, fp)
Array([1., 4., 7.], dtype=float32)
Unless otherwise specified, extrapolation will be constant:
>>> x = jnp.array([-10., 10.])
>>> jnp.interp(x, xp, fp)
Array([ 0., 18.], dtype=float32)
Use "extrapolate" mode for linear extrapolation:
>>> jnp.interp(x, xp, fp, left='extrapolate', right='extrapolate')
Array([-20., 20.], dtype=float32)
For periodic interpolation, specify the period:
>>> xp = jnp.array([0, jnp.pi / 2, jnp.pi, 3 * jnp.pi / 2])
>>> fp = jnp.sin(xp)
>>> x = 2 * jnp.pi  # note: not in input array
>>> jnp.interp(x, xp, fp, period=2 * jnp.pi)
Array(0., dtype=float32)
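The left and right parameters also accept scalar fill values, a case the examples above do not cover. A minimal sketch, using -1.0 as an illustrative sentinel for out-of-range points:

```python
import jax.numpy as jnp

xp = jnp.arange(10.0)
fp = 2 * xp

# One point below the range, one inside, one above.
x = jnp.array([-10.0, 4.5, 10.0])

# Out-of-range points are replaced by the given scalars
# instead of the edge values fp[0] and fp[-1].
out = jnp.interp(x, xp, fp, left=-1.0, right=-1.0)
print(out)
```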
- scico.numpy.intersect1d(ar1, ar2, assume_unique=False, return_indices=False, *, size=None, fill_value=None)¶
Compute the set intersection of two 1D arrays.
JAX implementation of numpy.intersect1d.
Because the size of the output of intersect1d is data-dependent, the function is not typically compatible with jit and other JAX transformations. The JAX version adds the optional size argument, which must be specified statically for jnp.intersect1d to be used in such contexts.
- Parameters:
ar1 (Union[Array,ndarray,bool,number,bool,int,float,complex]) – first array of values to intersect.
ar2 (Union[Array,ndarray,bool,number,bool,int,float,complex]) – second array of values to intersect.
assume_unique (bool) – if True, assume the input arrays contain unique values. This allows a more efficient implementation, but if assume_unique is True and the input arrays contain duplicates, the behavior is undefined. Default: False.
return_indices (bool) – if True, return arrays of indices specifying where the intersected values first appear in the input arrays.
size (int|None) – if specified, return only the first size sorted elements. If there are fewer elements than size indicates, the return value will be padded with fill_value, and returned indices will be padded with an out-of-bound index.
fill_value (Union[Array,ndarray,bool,number,bool,int,float,complex,None]) – when size is specified and there are fewer than the indicated number of elements, fill the remaining entries with fill_value. Defaults to the smallest value in the intersection.
- Return type:
- Returns:
An array intersection, or if return_indices=True, a tuple of arrays (intersection, ar1_indices, ar2_indices). Returned values are:
intersection: A 1D array containing each value that appears in both ar1 and ar2.
ar1_indices: (returned if return_indices=True) an array of shape intersection.shape containing the indices in flattened ar1 of values in intersection. For 1D inputs, intersection is equivalent to ar1[ar1_indices].
ar2_indices: (returned if return_indices=True) an array of shape intersection.shape containing the indices in flattened ar2 of values in intersection. For 1D inputs, intersection is equivalent to ar2[ar2_indices].
See also
jax.numpy.union1d: the set union of two 1D arrays.jax.numpy.setxor1d: the set XOR of two 1D arrays.jax.numpy.setdiff1d: the set difference of two 1D arrays.
Examples
>>> ar1 = jnp.array([1, 2, 3, 4]) >>> ar2 = jnp.array([3, 4, 5, 6]) >>> jnp.intersect1d(ar1, ar2) Array([3, 4], dtype=int32)
Computing intersection with indices:
>>> intersection, ar1_indices, ar2_indices = jnp.intersect1d(ar1, ar2, return_indices=True) >>> intersection Array([3, 4], dtype=int32)
ar1_indicesgives the indices of the intersected values withinar1:>>> ar1_indices Array([2, 3], dtype=int32) >>> jnp.all(intersection == ar1[ar1_indices]) Array(True, dtype=bool)
ar2_indicesgives the indices of the intersected values withinar2:>>> ar2_indices Array([0, 1], dtype=int32) >>> jnp.all(intersection == ar2[ar2_indices]) Array(True, dtype=bool)
- scico.numpy.isclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False)¶
Check if the elements of two arrays are approximately equal within a tolerance.
JAX implementation of numpy.isclose.
Essentially this function evaluates the following condition:
\[|a - b| \le \mathtt{atol} + \mathtt{rtol} * |b|\]
jnp.inf in a will be considered equal to jnp.inf in b.
- Parameters:
a (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – first input array to compare.b (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – second input array to compare.rtol (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – relative tolerance used for approximate equality. Default = 1e-05.atol (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – absolute tolerance used for approximate equality. Default = 1e-08.equal_nan (
bool) – Boolean. IfTrue, NaNs inawill be considered equal to NaNs inb. Default isFalse.
- Return type:
- Returns:
A new array containing boolean values indicating whether the input arrays are element-wise approximately equal within the specified tolerances.
Examples
>>> jnp.isclose(jnp.array([1e6, 2e6, jnp.inf]), jnp.array([1e6, 2e7, jnp.inf])) Array([ True, False, True], dtype=bool) >>> jnp.isclose(jnp.array([1e6, 2e6, 3e6]), ... jnp.array([1.00008e6, 2.00008e7, 3.00008e8]), rtol=1e3) Array([ True, True, True], dtype=bool) >>> jnp.isclose(jnp.array([1e6, 2e6, 3e6]), ... jnp.array([1.00001e6, 2.00002e6, 3.00009e6]), atol=1e3) Array([ True, True, True], dtype=bool) >>> jnp.isclose(jnp.array([jnp.nan, 1, 2]), ... jnp.array([jnp.nan, 1, 2]), equal_nan=True) Array([ True, True, True], dtype=bool)
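Because the relative tolerance in the condition above is scaled by |b| only, the comparison is not symmetric in its arguments. The following sketch, with values chosen purely for illustration and jax.numpy called directly, shows the same pair of numbers passing in one order and failing in the other.

```python
import jax.numpy as jnp

# |100 - 120| = 20 <= 0.18 * |120| = 21.6, so the test passes.
forward = jnp.isclose(100.0, 120.0, rtol=0.18, atol=0.0)

# Swapping the arguments changes the bound to 0.18 * |100| = 18, so it fails.
backward = jnp.isclose(120.0, 100.0, rtol=0.18, atol=0.0)

print(bool(forward), bool(backward))
```

When a symmetric test is needed, one option is to pass the reference value consistently as b.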
- scico.numpy.iscomplex(x)¶
Return boolean array showing where the input is complex.
JAX implementation of
numpy.iscomplex.- Parameters:
x (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – Input array to check.- Return type:
- Returns:
A new array containing boolean values indicating complex elements.
Examples
>>> jnp.iscomplex(jnp.array([True, 0, 1, 2j, 1+2j])) Array([False, False, False, True, True], dtype=bool)
- scico.numpy.iscomplexobj(x)¶
Check if the input is a complex number or an array containing complex elements.
JAX implementation of
numpy.iscomplexobj.The function evaluates based on input type rather than value. Inputs with zero imaginary parts are still considered complex.
- Parameters:
x (
Any) – input object to check.- Return type:
- Returns:
True if
xis a complex number or an array containing at least one complex element, False otherwise.
Examples
>>> jnp.iscomplexobj(True) False >>> jnp.iscomplexobj(0) False >>> jnp.iscomplexobj(jnp.array([1, 2])) False >>> jnp.iscomplexobj(1+2j) True >>> jnp.iscomplexobj(jnp.array([0, 1+2j])) True
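The distinction between checking by type and checking by value can be seen by comparing iscomplexobj with iscomplex on the same array. This is a small sketch using jax.numpy directly, with made-up values.

```python
import jax.numpy as jnp

# Complex dtype, but the first element has a zero imaginary part.
z = jnp.array([1 + 0j, 2 + 3j])

by_type = jnp.iscomplexobj(z)   # looks only at the dtype of z
by_value = jnp.iscomplex(z)     # looks at each element's imaginary part

print(by_type, by_value)
```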
- scico.numpy.isdtype(dtype, kind)¶
Returns a boolean indicating whether a provided dtype is of a specified kind.
- Parameters:
dtype (Union[str,type[Any],dtype,SupportsDType]) – the input dtype.
kind (Union[str,type[Any],dtype,SupportsDType,tuple[Union[str,type[Any],dtype,SupportsDType],...]]) – the data type kind. If kind is dtype-like, return dtype == kind. If kind is a string, then return True if the dtype is in the specified category:
'bool': {bool}
'signed integer': {int4, int8, int16, int32, int64}
'unsigned integer': {uint4, uint8, uint16, uint32, uint64}
'integral': shorthand for ('signed integer', 'unsigned integer')
'real floating': {float8_*, float16, bfloat16, float32, float64}
'complex floating': {complex64, complex128}
'numeric': shorthand for ('integral', 'real floating', 'complex floating')
If kind is a tuple, then return True if dtype matches any entry of the tuple.
- Return type:
- Returns:
True or False
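This entry has no example, so here is a minimal sketch of the three forms of kind (string, tuple, and dtype-like). It calls jax.numpy.isdtype directly and takes the dtypes from small throwaway arrays; the variable names are illustrative only.

```python
import jax.numpy as jnp

f32 = jnp.zeros(2, dtype=jnp.float32).dtype
i32 = jnp.zeros(2, dtype=jnp.int32).dtype
c64 = jnp.zeros(2, dtype=jnp.complex64).dtype

is_float = jnp.isdtype(f32, "real floating")                 # string category
is_bool_or_float = jnp.isdtype(i32, ("bool", "real floating"))  # tuple: any match
is_same = jnp.isdtype(f32, jnp.float32)                      # dtype-like: equality
is_numeric = jnp.isdtype(c64, "numeric")                     # shorthand category

print(is_float, is_bool_or_float, is_same, is_numeric)
```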
- scico.numpy.isfinite(x, /)¶
Return a boolean array indicating whether each element of input is finite.
JAX implementation of
numpy.isfinite.- Parameters:
x (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – input array or scalar.- Return type:
- Returns:
A boolean array of same shape as
xcontainingTruewherexis notinf,-inf, orNaN, andFalseotherwise.
See also
jax.numpy.isinf: Returns a boolean array indicating whether each element of input is either positive or negative infinity.jax.numpy.isposinf: Returns a boolean array indicating whether each element of input is positive infinity.jax.numpy.isneginf: Returns a boolean array indicating whether each element of input is negative infinity.jax.numpy.isnan: Returns a boolean array indicating whether each element of input is not a number (NaN).
Examples
>>> x = jnp.array([-1, 3, jnp.inf, jnp.nan]) >>> jnp.isfinite(x) Array([ True, True, False, False], dtype=bool) >>> jnp.isfinite(3-4j) Array(True, dtype=bool, weak_type=True)
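A common use of the boolean mask is to replace non-finite entries before further computation. The sketch below uses jax.numpy directly; the replacement value 0.0 is an arbitrary choice for illustration.

```python
import jax.numpy as jnp

x = jnp.array([-1.0, 3.0, jnp.inf, jnp.nan])

# Keep finite entries and substitute 0.0 wherever x is inf, -inf, or NaN.
cleaned = jnp.where(jnp.isfinite(x), x, 0.0)
print(cleaned)
```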
- scico.numpy.isin(element, test_elements, assume_unique=False, invert=False, *, method='auto')¶
Determine whether elements in
elementappear intest_elements.JAX implementation of
numpy.isin.- Parameters:
element (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – input array of elements for which membership will be checked.test_elements (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – N-dimensional array of test values to check for the presence of each element.invert (
bool) – If True, return~isin(element, test_elements). Default is False.assume_unique (
bool) – if true, input arrays are assumed to be unique, which can lead to more efficient computation. If the input arrays are not unique and assume_unique is set to True, the results are undefined.method – string specifying the method used to compute the result. Supported options are ‘compare_all’, ‘binary_search’, ‘sort’, and ‘auto’ (default).
- Return type:
- Returns:
A boolean array of shape
element.shapethat specifies whether each element appears intest_elements.
Examples
>>> elements = jnp.array([1, 2, 3, 4]) >>> test_elements = jnp.array([[1, 5, 6, 3, 7, 1]]) >>> jnp.isin(elements, test_elements) Array([ True, False, True, False], dtype=bool)
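The invert flag described in the parameter list can be sketched as follows, using jax.numpy directly with illustrative values. It should give the same result as negating the default output.

```python
import jax.numpy as jnp

elements = jnp.array([1, 2, 3, 4])
test_elements = jnp.array([1, 3])

# True where an element does NOT appear in test_elements.
missing = jnp.isin(elements, test_elements, invert=True)

# Equivalent formulation using the ~ operator on the default result.
negated = ~jnp.isin(elements, test_elements)
print(missing, negated)
```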
- scico.numpy.isinf(x, /)¶
Return a boolean array indicating whether each element of input is infinite.
JAX implementation of
numpy.isinf.- Parameters:
x (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – input array or scalar.- Return type:
- Returns:
A boolean array of same shape as
xcontainingTruewherexisinfor-inf, andFalseotherwise.
See also
jax.numpy.isposinf: Returns a boolean array indicating whether each element of input is positive infinity.jax.numpy.isneginf: Returns a boolean array indicating whether each element of input is negative infinity.jax.numpy.isfinite: Returns a boolean array indicating whether each element of input is finite.jax.numpy.isnan: Returns a boolean array indicating whether each element of input is not a number (NaN).
Examples
>>> jnp.isinf(jnp.inf) Array(True, dtype=bool) >>> x = jnp.array([2+3j, -jnp.inf, 6, jnp.inf, jnp.nan]) >>> jnp.isinf(x) Array([False, True, False, True, False], dtype=bool)
- scico.numpy.isnan(x, /)¶
Returns a boolean array indicating whether each element of input is
NaN.JAX implementation of
numpy.isnan.- Parameters:
x (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – input array or scalar.- Return type:
- Returns:
A boolean array of same shape as
xcontainingTruewherexis not a number (i.e.NaN) andFalseotherwise.
See also
jax.numpy.isfinite: Returns a boolean array indicating whether each element of input is finite.jax.numpy.isinf: Returns a boolean array indicating whether each element of input is either positive or negative infinity.jax.numpy.isposinf: Returns a boolean array indicating whether each element of input is positive infinity.jax.numpy.isneginf: Returns a boolean array indicating whether each element of input is negative infinity.
Examples
>>> jnp.isnan(6) Array(False, dtype=bool, weak_type=True) >>> x = jnp.array([2, 1+4j, jnp.inf, jnp.nan]) >>> jnp.isnan(x) Array([False, False, False, True], dtype=bool)
- scico.numpy.isneginf(x, /, out=None)¶
Return boolean array indicating whether each element of input is negative infinite.
JAX implementation of
numpy.isneginf.- Parameters:
x – input array or scalar.
complexdtype are not supported.- Returns:
A boolean array of same shape as
xcontainingTruewherexis-inf, andFalseotherwise.
See also
jax.numpy.isinf: Returns a boolean array indicating whether each element of input is either positive or negative infinity.jax.numpy.isposinf: Returns a boolean array indicating whether each element of input is positive infinity.jax.numpy.isfinite: Returns a boolean array indicating whether each element of input is finite.jax.numpy.isnan: Returns a boolean array indicating whether each element of input is not a number (NaN).
Examples
>>> jnp.isneginf(jnp.inf) Array(False, dtype=bool) >>> x = jnp.array([-jnp.inf, 5, jnp.inf, jnp.nan, 1]) >>> jnp.isneginf(x) Array([ True, False, False, False, False], dtype=bool)
- scico.numpy.isposinf(x, /, out=None)¶
Return boolean array indicating whether each element of input is positive infinite.
JAX implementation of
numpy.isposinf.- Parameters:
x – input array or scalar.
complexdtype are not supported.- Returns:
A boolean array of same shape as
xcontainingTruewherexisinf, andFalseotherwise.
See also
jax.numpy.isinf: Returns a boolean array indicating whether each element of input is either positive or negative infinity.jax.numpy.isneginf: Returns a boolean array indicating whether each element of input is negative infinity.jax.numpy.isfinite: Returns a boolean array indicating whether each element of input is finite.jax.numpy.isnan: Returns a boolean array indicating whether each element of input is not a number (NaN).
Examples
>>> jnp.isposinf(5) Array(False, dtype=bool) >>> x = jnp.array([-jnp.inf, 5, jnp.inf, jnp.nan, 1]) >>> jnp.isposinf(x) Array([False, False, True, False, False], dtype=bool)
- scico.numpy.isreal(x)¶
Return boolean array showing where the input is real.
JAX implementation of
numpy.isreal.- Parameters:
x (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – input array to check.- Return type:
- Returns:
A new array containing boolean values indicating real elements.
Examples
>>> jnp.isreal(jnp.array([False, 0j, 1, 2.1, 1+2j])) Array([ True, True, True, True, False], dtype=bool)
- scico.numpy.isrealobj(x)¶
Check if the input is not a complex number or an array containing complex elements.
JAX implementation of
numpy.isrealobj.The function evaluates based on input type rather than value. Inputs with zero imaginary parts are still considered complex.
- Parameters:
x (
Any) – input object to check.- Return type:
- Returns:
False if
xis a complex number or an array containing at least one complex element, True otherwise.
Examples
>>> jnp.isrealobj(0) True >>> jnp.isrealobj(1.2) True >>> jnp.isrealobj(jnp.array([1, 2])) True >>> jnp.isrealobj(1+2j) False >>> jnp.isrealobj(jnp.array([0, 1+2j])) False
- scico.numpy.isscalar(element)¶
Return True if the input is a scalar.
JAX implementation of
numpy.isscalar. JAX’s implementation differs from NumPy’s in that it considers zero-dimensional arrays to be scalars; see the Note below for more details.- Parameters:
element (
Any) – input object to check; any type is valid input.- Return type:
- Returns:
True if
elementis a scalar value or an array-like object with zero dimensions, False otherwise.
Note
JAX and NumPy differ in their representation of scalar values. NumPy has special scalar objects (e.g. np.int32(0)) which are distinct from zero-dimensional arrays (e.g. np.array(0)), and numpy.isscalar returns True for the former and False for the latter.
JAX does not define special scalar objects, but rather represents scalars as zero-dimensional arrays. As such, jax.numpy.isscalar returns True for both scalar objects (e.g. 0.0 or np.float32(0.0)) and array-like objects with zero dimensions (e.g. jnp.array(0.0), np.array(0.0)).
One reason for the different conventions in isscalar is to maintain JIT-invariance: i.e. the property that the result of a function should not change when it is JIT-compiled. Because scalar inputs are cast to zero-dimensional JAX arrays at JIT boundaries, the semantics of numpy.isscalar are such that the result changes under JIT:
>>> np.isscalar(1.0)
True
>>> jax.jit(np.isscalar)(1.0)
Array(False, dtype=bool)
By treating zero-dimensional arrays as scalars, jax.numpy.isscalar avoids this issue:
>>> jnp.isscalar(1.0)
True
>>> jax.jit(jnp.isscalar)(1.0)
Array(True, dtype=bool)
Examples
In JAX, both scalars and zero-dimensional array-like objects are considered scalars:
>>> jnp.isscalar(1.0) True >>> jnp.isscalar(1 + 1j) True >>> jnp.isscalar(jnp.array(1)) # zero-dimensional JAX array True >>> jnp.isscalar(jnp.int32(1)) # JAX scalar constructor True >>> jnp.isscalar(np.array(1.0)) # zero-dimensional NumPy array True >>> jnp.isscalar(np.int32(1)) # NumPy scalar type True
Arrays with one or more dimension are not considered scalars:
>>> jnp.isscalar(jnp.array([1])) False >>> jnp.isscalar(np.array([1])) False
Compare this to
numpy.isscalar, which returnsTruefor scalar-typed objects, andFalsefor all arrays, even those with zero dimensions:>>> np.isscalar(np.int32(1)) # scalar object True >>> np.isscalar(np.array(1)) # zero-dimensional array False
In JAX, as in NumPy, objects which are not array-like are not considered scalars:
>>> jnp.isscalar(None) False >>> jnp.isscalar([1]) False >>> jnp.isscalar(()) False >>> jnp.isscalar(slice(10)) False
- scico.numpy.issubdtype(arg1, arg2)¶
Return True if arg1 is equal or lower than arg2 in the type hierarchy.
JAX implementation of
numpy.issubdtype.The main difference in JAX’s implementation is that it properly handles dtype extensions such as
bfloat16.- Parameters:
arg1 (
Union[str,type[Any],dtype,SupportsDType]) – dtype-like object. In typical usage, this will be a dtype specifier, such as"float32"(i.e. a string),np.dtype('int32')(i.e. an instance ofnumpy.dtype),jnp.complex64(i.e. a JAX scalar constructor), ornp.uint8(i.e. a NumPy scalar type).arg2 (
Union[str,type[Any],dtype,SupportsDType]) – dtype-like object. In typical usage, this will be a generic scalar type, such asjnp.integer,jnp.floating, orjnp.complexfloating.
- Return type:
- Returns:
True if arg1 represents a dtype that is equal or lower in the type hierarchy than arg2.
See also
jax.numpy.isdtype: similar function aligning with the array API standard.
Examples
>>> jnp.issubdtype('uint32', jnp.unsignedinteger) True >>> jnp.issubdtype(np.int32, jnp.integer) True >>> jnp.issubdtype(jnp.bfloat16, jnp.floating) True >>> jnp.issubdtype(np.dtype('complex64'), jnp.complexfloating) True >>> jnp.issubdtype('complex64', jnp.integer) False
Be aware that while this is very similar to numpy.issubdtype, the results of these differ in the case of JAX’s custom floating point types:
>>> np.issubdtype('bfloat16', np.floating)
False
>>> jnp.issubdtype('bfloat16', jnp.floating)
True
- scico.numpy.iterable(y)¶
Check whether or not an object can be iterated over.
- Parameters:
y (object) – Input object.
- Returns:
b (bool) – Return
Trueif the object has an iterator method or is a sequence andFalseotherwise.
Examples
>>> import numpy as np >>> np.iterable([1, 2, 3]) True >>> np.iterable(2) False
Notes
In most cases, the results of np.iterable(obj) are consistent with isinstance(obj, collections.abc.Iterable). One notable exception is the treatment of 0-dimensional arrays:
>>> from collections.abc import Iterable
>>> a = np.array(1.0)  # 0-dimensional numpy array
>>> isinstance(a, Iterable)
True
>>> np.iterable(a)
False
- scico.numpy.ix_(*args)¶
Return a multi-dimensional grid (open mesh) from N one-dimensional sequences.
JAX implementation of
numpy.ix_.- Parameters:
*args (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – N one-dimensional arrays- Return type:
- Returns:
Tuple of Jax arrays forming an open mesh, each with N dimensions.
Examples
>>> rows = jnp.array([0, 2]) >>> cols = jnp.array([1, 3]) >>> open_mesh = jnp.ix_(rows, cols) >>> open_mesh (Array([[0], [2]], dtype=int32), Array([[1, 3]], dtype=int32)) >>> [grid.shape for grid in open_mesh] [(2, 1), (1, 2)] >>> x = jnp.array([[10, 20, 30, 40], ... [50, 60, 70, 80], ... [90, 100, 110, 120], ... [130, 140, 150, 160]]) >>> x[open_mesh] Array([[ 20, 40], [100, 120]], dtype=int32)
- scico.numpy.kaiser(M, beta)¶
Return a Kaiser window of size M.
JAX implementation of
numpy.kaiser.- Parameters:
- Return type:
- Returns:
An array of size M containing the Kaiser window.
Examples
>>> with jnp.printoptions(precision=2, suppress=True): ... print(jnp.kaiser(4, 1.5)) [0.61 0.95 0.95 0.61]
See also
jax.numpy.bartlett: return a Bartlett window of size M.jax.numpy.blackman: return a Blackman window of size M.jax.numpy.hamming: return a Hamming window of size M.jax.numpy.hanning: return a Hanning window of size M.
- scico.numpy.kron(a, b)¶
Compute the Kronecker product of two input arrays.
JAX implementation of numpy.kron.
The Kronecker product is an operation on two matrices of arbitrary size that produces a block matrix. Each element of the first matrix a is multiplied by the entire second matrix b. If a has shape (m, n) and b has shape (p, q), the resulting matrix will have shape (m * p, n * q).
- Parameters:
- Return type:
- Returns:
A new array representing the Kronecker product of the inputs
aandb. The shape of the output is the element-wise product of the input shapes.
See also
jax.numpy.outer: compute the outer product of two arrays.
Examples
>>> a = jnp.array([[1, 2], ... [3, 4]]) >>> b = jnp.array([[5, 6], ... [7, 8]]) >>> jnp.kron(a, b) Array([[ 5, 6, 10, 12], [ 7, 8, 14, 16], [15, 18, 20, 24], [21, 24, 28, 32]], dtype=int32)
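The shape rule and block structure described above can be checked on non-square inputs. This is a minimal sketch using jax.numpy directly; the shapes (2, 3) and (4, 5) are arbitrary.

```python
import jax.numpy as jnp

a = jnp.arange(6.0).reshape(2, 3)   # shape (m, n) = (2, 3)
b = jnp.ones((4, 5))                # shape (p, q) = (4, 5)

out = jnp.kron(a, b)                # expected shape (m * p, n * q)

# Block (i, j) of the result is a[i, j] * b; extract block (1, 1) as a check.
i, j = 1, 1
block = out[i * 4:(i + 1) * 4, j * 5:(j + 1) * 5]
print(out.shape, bool(jnp.all(block == a[i, j] * b)))
```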
- scico.numpy.lcm(x1, x2)¶
Compute the least common multiple of two arrays.
JAX implementation of
numpy.lcm.- Parameters:
- Return type:
- Returns:
An array containing the least common multiple of the corresponding elements from the absolute values of x1 and x2.
See also
jax.numpy.gcd: compute the greatest common divisor of two arrays.
Examples
Scalar inputs:
>>> jnp.lcm(12, 18) Array(36, dtype=int32, weak_type=True)
Array inputs:
>>> x1 = jnp.array([12, 18, 24]) >>> x2 = jnp.array([5, 10, 15]) >>> jnp.lcm(x1, x2) Array([ 60, 90, 120], dtype=int32)
Broadcasting:
>>> x1 = jnp.array([12]) >>> x2 = jnp.array([6, 9, 12]) >>> jnp.lcm(x1, x2) Array([12, 36, 12], dtype=int32)
- scico.numpy.ldexp(x1, x2, /)¶
Compute x1 * 2 ** x2
JAX implementation of numpy.ldexp.
Note that XLA does not provide an ldexp operation, so this is implemented in JAX via a standard multiplication and exponentiation.
- Parameters:
- Return type:
- Returns:
x1 * 2 ** x2computed element-wise.
See also
jax.numpy.frexp: decompose values into mantissa and exponent.
Examples
>>> x1 = jnp.arange(5.0) >>> x2 = 10 >>> jnp.ldexp(x1, x2) Array([ 0., 1024., 2048., 3072., 4096.], dtype=float32)
ldexpcan be used to reconstruct the input tofrexp:>>> x = jnp.array([2., 3., 5., 11.]) >>> m, e = jnp.frexp(x) >>> m Array([0.5 , 0.75 , 0.625 , 0.6875], dtype=float32) >>> e Array([2, 2, 3, 4], dtype=int32) >>> jnp.ldexp(m, e) Array([ 2., 3., 5., 11.], dtype=float32)
- scico.numpy.less(x, y, /)¶
Return element-wise truth value of
x < y.JAX implementation of
numpy.less.- Parameters:
- Return type:
- Returns:
An array containing boolean values.
Trueif the elements ofx < y, andFalseotherwise.
See also
jax.numpy.greater: Returns element-wise truth value ofx > y.jax.numpy.greater_equal: Returns element-wise truth value ofx >= y.jax.numpy.less_equal: Returns element-wise truth value ofx <= y.
Examples
Scalar inputs:
>>> jnp.less(3, 7) Array(True, dtype=bool, weak_type=True)
Inputs with same shape:
>>> x = jnp.array([5, 9, -3]) >>> y = jnp.array([1, 6, 4]) >>> jnp.less(x, y) Array([False, False, True], dtype=bool)
Inputs with broadcast compatibility:
>>> x1 = jnp.array([[2, -4, 6, -8], ... [-1, 5, -3, 7]]) >>> y1 = jnp.array([0, 3, -5, 9]) >>> jnp.less(x1, y1) Array([[False, True, False, True], [ True, False, False, True]], dtype=bool)
- scico.numpy.less_equal(x, y, /)¶
Return element-wise truth value of
x <= y.JAX implementation of
numpy.less_equal.- Parameters:
- Return type:
- Returns:
An array containing the boolean values.
Trueif the elements ofx <= y, andFalseotherwise.
See also
jax.numpy.greater_equal: Returns element-wise truth value ofx >= y.jax.numpy.greater: Returns element-wise truth value ofx > y.jax.numpy.less: Returns element-wise truth value ofx < y.
Examples
Scalar inputs:
>>> jnp.less_equal(6, -2) Array(False, dtype=bool, weak_type=True)
Inputs with same shape:
>>> x = jnp.array([-4, 1, 7]) >>> y = jnp.array([2, -3, 8]) >>> jnp.less_equal(x, y) Array([ True, False, True], dtype=bool)
Inputs with broadcast compatibility:
>>> x1 = jnp.array([2, -5, 9]) >>> y1 = jnp.array([[1, -6, 5], ... [-2, 4, -6]]) >>> jnp.less_equal(x1, y1) Array([[False, False, False], [False, True, False]], dtype=bool)
- scico.numpy.lexsort(keys, axis=-1)¶
Sort a sequence of keys in lexicographic order.
JAX implementation of
numpy.lexsort.- Parameters:
- Return type:
- Returns:
An array of integers of shape
keys[0].shapegiving the indices of the entries in lexicographically-sorted order.
See also
jax.numpy.argsort: sort a single entry by index.jax.lax.sort: direct XLA sorting API.
Examples
lexsortwith a single key is equivalent toargsort:>>> key1 = jnp.array([4, 2, 3, 2, 5]) >>> jnp.lexsort([key1]) Array([1, 3, 2, 0, 4], dtype=int32) >>> jnp.argsort(key1) Array([1, 3, 2, 0, 4], dtype=int32)
With multiple keys,
lexsortuses the last key as the primary key:>>> key2 = jnp.array([2, 1, 1, 2, 2]) >>> jnp.lexsort([key1, key2]) Array([1, 2, 3, 0, 4], dtype=int32)
The meaning of the indices becomes clearer when printing the sorted keys:
>>> indices = jnp.lexsort([key1, key2])
>>> print(f"{key1[indices]}\n{key2[indices]}")
[2 3 2 4 5]
[1 1 2 2 2]
Notice that the elements of key2 appear in order, and within the sequences of duplicated values the corresponding elements of key1 appear in order.
For multi-dimensional inputs, lexsort defaults to sorting along the last axis:
>>> key1 = jnp.array([[2, 4, 2, 3],
...                   [3, 1, 2, 2]])
>>> key2 = jnp.array([[1, 2, 1, 3],
...                   [2, 1, 2, 1]])
>>> jnp.lexsort([key1, key2])
Array([[0, 2, 1, 3],
       [1, 3, 2, 0]], dtype=int32)
A different sort axis can be chosen using the
axiskeyword; here we sort along the leading axis:>>> jnp.lexsort([key1, key2], axis=0) Array([[0, 1, 0, 1], [1, 0, 1, 0]], dtype=int32)
- scico.numpy.linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0, *, device=None)¶
Return evenly-spaced numbers within an interval.
JAX implementation of
numpy.linspace.- Parameters:
start (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – scalar or array of starting values.stop (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – scalar or array of stop values.num (
int) – number of values to generate. Default: 50.endpoint (
bool) – if True (default) then include thestopvalue in the result. If False, then exclude thestopvalue.retstep (
bool) – If True, then return a(result, step)tuple, wherestepis the interval between adjacent values inresult.axis (
int) – integer axis along which to generate the linspace. Defaults to zero.device (
Device|Sharding|None) – optionalDeviceorShardingto which the created array will be committed.
- Return type:
- Returns:
An array values, or a tuple (values, step) if retstep is True, where:
values is an array of evenly-spaced values from start to stop
step is the interval between adjacent values.
See also
jax.numpy.arange: GenerateNevenly-spaced values given a starting point and a stepjax.numpy.logspace: Generate logarithmically-spaced values.jax.numpy.geomspace: Generate geometrically-spaced values.
Examples
List of 5 values between 0 and 10:
>>> jnp.linspace(0, 10, 5) Array([ 0. , 2.5, 5. , 7.5, 10. ], dtype=float32)
List of 8 values between 0 and 10, excluding the endpoint:
>>> jnp.linspace(0, 10, 8, endpoint=False) Array([0. , 1.25, 2.5 , 3.75, 5. , 6.25, 7.5 , 8.75], dtype=float32)
List of values and the step size between them:
>>> vals, step = jnp.linspace(0, 10, 9, retstep=True) >>> vals Array([ 0. , 1.25, 2.5 , 3.75, 5. , 6.25, 7.5 , 8.75, 10. ], dtype=float32) >>> step Array(1.25, dtype=float32)
Multi-dimensional linspace:
>>> start = jnp.array([0, 5]) >>> stop = jnp.array([5, 10]) >>> jnp.linspace(start, stop, 5) Array([[ 0. , 5. ], [ 1.25, 6.25], [ 2.5 , 7.5 ], [ 3.75, 8.75], [ 5. , 10. ]], dtype=float32)
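The axis parameter controls where the generated samples are placed when start and stop are arrays. The sketch below uses jax.numpy directly and reuses the values from the multi-dimensional example; it compares the default axis=0 with axis=-1.

```python
import jax.numpy as jnp

start = jnp.array([0.0, 5.0])
stop = jnp.array([5.0, 10.0])

# Default axis=0: samples run down the leading axis, giving shape (5, 2).
by_rows = jnp.linspace(start, stop, 5)

# axis=-1: samples run along the trailing axis, giving shape (2, 5).
by_cols = jnp.linspace(start, stop, 5, axis=-1)

print(by_rows.shape, by_cols.shape)
```

For this two-dimensional case the two results are transposes of each other.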
- scico.numpy.load(file, *args, **kwargs)¶
Load JAX arrays from npy files.
JAX wrapper of numpy.load.
This function is a simple wrapper of numpy.load, but in the case of .npy files created with numpy.save or jax.numpy.save, the output will be returned as a jax.Array, and bfloat16 data types will be restored. For .npz files, results will be returned as normal NumPy arrays.
This function requires concrete array inputs, and is not compatible with transformations like jax.jit or jax.vmap.
- Parameters:
file (
Union[IO[bytes],str,PathLike[Any]]) – string, bytes, or path-like object containing the array data.args (
Any) – for additional arguments, seenumpy.loadkwargs (
Any) – for additional arguments, seenumpy.load
- Return type:
- Returns:
the array stored in the file.
See also
jax.numpy.save: save an array to a file.
Examples
>>> import io >>> f = io.BytesIO() # use an in-memory file-like object. >>> x = jnp.array([2, 4, 6, 8], dtype='bfloat16') >>> jnp.save(f, x) >>> f.seek(0) 0 >>> jnp.load(f) Array([2, 4, 6, 8], dtype=bfloat16)
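The .npz case mentioned above behaves differently from the .npy case: entries come back as NumPy arrays and need an explicit conversion if JAX arrays are wanted. The sketch below uses jax.numpy.load directly with an in-memory buffer; the archive key "a" and the variable names are illustrative.

```python
import io

import numpy as np
import jax.numpy as jnp

buffer = io.BytesIO()               # in-memory stand-in for an .npz file
np.savez(buffer, a=np.arange(3))
buffer.seek(0)

archive = jnp.load(buffer)          # an .npz archive rather than a single array
raw = archive["a"]                  # entry is a plain NumPy array
converted = jnp.asarray(raw)        # explicit conversion to a JAX array

print(type(raw).__name__, converted)
```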
- scico.numpy.log(x, /)¶
Calculate element-wise natural logarithm of the input.
JAX implementation of
numpy.log.- Parameters:
x (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – input array or scalar.- Return type:
- Returns:
An array containing the logarithm of each element in
x, promotes to inexact dtype.
See also
jax.numpy.exp: Calculates element-wise exponential of the input.jax.numpy.log2: Calculates base-2 logarithm of each element of input.jax.numpy.log1p: Calculates element-wise logarithm of one plus input.
Examples
jnp.logandjnp.expare inverse functions of each other. Applyingjnp.logon the result ofjnp.exp(x)yields the original inputx.>>> x = jnp.array([2, 3, 4, 5]) >>> jnp.log(jnp.exp(x)) Array([2., 3., 4., 5.], dtype=float32)
Using
jnp.logwe can demonstrate well-known properties of logarithms, such as \(log(a*b) = log(a)+log(b)\).>>> x1 = jnp.array([2, 1, 3, 1]) >>> x2 = jnp.array([1, 3, 2, 4]) >>> jnp.allclose(jnp.log(x1*x2), jnp.log(x1)+jnp.log(x2)) Array(True, dtype=bool)
- scico.numpy.log10(x, /)¶
Calculates the base-10 logarithm of x element-wise
JAX implementation of
numpy.log10.- Parameters:
x (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – Input array- Return type:
- Returns:
An array containing the base-10 logarithm of each element in
x, promotes to inexact dtype.
Examples
>>> x1 = jnp.array([0.01, 0.1, 1, 10, 100, 1000]) >>> with jnp.printoptions(precision=2, suppress=True): ... print(jnp.log10(x1)) [-2. -1. 0. 1. 2. 3.]
- scico.numpy.log1p(x, /)¶
Calculates element-wise logarithm of one plus input,
log(x+1).JAX implementation of
numpy.log1p.- Parameters:
x (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – input array or scalar.- Return type:
- Returns:
An array containing the logarithm of one plus each element in x, promoted to an inexact dtype.
Note
jnp.log1p is more accurate than the naive computation log(x+1) for small values of x.
See also
jax.numpy.expm1: Calculates \(e^x-1\) of each element of the input.jax.numpy.log2: Calculates base-2 logarithm of each element of input.jax.numpy.log: Calculates element-wise logarithm of the input.
Examples
>>> x = jnp.array([2, 5, 9, 4]) >>> jnp.allclose(jnp.log1p(x), jnp.log(x+1)) Array(True, dtype=bool)
For values very close to 0,
jnp.log1p(x)is more accurate thanjnp.log(x+1):>>> x1 = jnp.array([1e-4, 1e-6, 2e-10]) >>> jnp.expm1(jnp.log1p(x1)) Array([1.00000005e-04, 9.99999997e-07, 2.00000003e-10], dtype=float32) >>> jnp.expm1(jnp.log(x1+1)) Array([1.000166e-04, 9.536743e-07, 0.000000e+00], dtype=float32)
- scico.numpy.log2(x, /)¶
Calculates the base-2 logarithm of
xelement-wise.JAX implementation of
numpy.log2.- Parameters:
x (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – Input array- Return type:
- Returns:
An array containing the base-2 logarithm of each element in
x, promotes to inexact dtype.
Examples
>>> x1 = jnp.array([0.25, 0.5, 1, 2, 4, 8]) >>> jnp.log2(x1) Array([-2., -1., 0., 1., 2., 3.], dtype=float32)
- scico.numpy.logaddexp(*args: ArrayLike, out: None = None, where: None = None) Any¶
Compute log(exp(x1) + exp(x2)) avoiding overflow.
JAX implementation of numpy.logaddexp.
- Parameters:
x1 – input array
x2 – input array
- Returns:
array containing the result.
Examples:
>>> x1 = jnp.array([1, 2, 3]) >>> x2 = jnp.array([4, 5, 6]) >>> result1 = jnp.logaddexp(x1, x2) >>> result2 = jnp.log(jnp.exp(x1) + jnp.exp(x2)) >>> print(jnp.allclose(result1, result2)) True
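The example above uses small inputs, where the naive formula also works. The overflow-avoidance claim is easier to see with large inputs, as in this sketch using jax.numpy directly with the default float32 precision; the value 1000.0 is chosen only because exp(1000) is not representable.

```python
import jax.numpy as jnp

x1 = jnp.array([1000.0])
x2 = jnp.array([1000.0])

# exp(1000) overflows to inf in floating point, so the naive formula returns inf.
naive = jnp.log(jnp.exp(x1) + jnp.exp(x2))

# logaddexp stays finite; mathematically the result is 1000 + log(2).
stable = jnp.logaddexp(x1, x2)

print(naive, stable)
```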
- scico.numpy.logaddexp2(*args: ArrayLike, out: None = None, where: None = None) Any¶
Logarithm of the sum of exponentials of inputs in base-2 avoiding overflow.
JAX implementation of numpy.logaddexp2.
- Parameters:
x1 – input array or scalar.
x2 – input array or scalar.
x1 and x2 should either have the same shape or be broadcast compatible.
- Returns:
An array containing the result, \(\log_2(2^{x1}+2^{x2})\), element-wise.
See also
jax.numpy.logaddexp: Computes log(exp(x1) + exp(x2)), element-wise.
jax.numpy.log2: Calculates the base-2 logarithm of x element-wise.
Examples
>>> x1 = jnp.array([[3, -1, 4],
...                 [8, 5, -2]])
>>> x2 = jnp.array([2, 3, -5])
>>> result1 = jnp.logaddexp2(x1, x2)
>>> result2 = jnp.log2(jnp.exp2(x1) + jnp.exp2(x2))
>>> jnp.allclose(result1, result2)
Array(True, dtype=bool)
- scico.numpy.logical_and(*args: ArrayLike, out: None = None, where: None = None) Any¶
Compute the logical AND operation elementwise.
JAX implementation of numpy.logical_and. This is a universal function, and supports the additional APIs described at jax.numpy.ufunc.
- Parameters:
x – input arrays. Must be broadcastable to a common shape.
y – input arrays. Must be broadcastable to a common shape.
- Returns:
Array containing the result of the element-wise logical AND.
Examples
>>> x = jnp.arange(4)
>>> jnp.logical_and(x, 1)
Array([False, True, True, True], dtype=bool)
- scico.numpy.logical_not(x, /)¶
Compute NOT bool(x) element-wise.
JAX implementation of numpy.logical_not.
- Parameters:
x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array of any dtype.
- Return type:
- Returns:
A boolean array that computes NOT bool(x) element-wise.
See also
jax.numpy.invert or jax.numpy.bitwise_invert: bitwise NOT operation
Examples
Compute NOT x element-wise on a boolean array:
>>> x = jnp.array([True, False, True])
>>> jnp.logical_not(x)
Array([False, True, False], dtype=bool)
For boolean input, this is equivalent to invert, which implements the unary ~ operator:
>>> ~x
Array([False, True, False], dtype=bool)
For non-boolean input, the input of logical_not is implicitly cast to boolean:
>>> x = jnp.array([-1, 0, 1])
>>> jnp.logical_not(x)
Array([False, True, False], dtype=bool)
- scico.numpy.logical_or(*args: ArrayLike, out: None = None, where: None = None) Any¶
Compute the logical OR operation elementwise.
JAX implementation of numpy.logical_or. This is a universal function, and supports the additional APIs described at jax.numpy.ufunc.
- Parameters:
x – input arrays. Must be broadcastable to a common shape.
y – input arrays. Must be broadcastable to a common shape.
- Returns:
Array containing the result of the element-wise logical OR.
Examples
>>> x = jnp.arange(4)
>>> jnp.logical_or(x, 1)
Array([ True, True, True, True], dtype=bool)
- scico.numpy.logical_xor(*args: ArrayLike, out: None = None, where: None = None) Any¶
Compute the logical XOR operation elementwise.
JAX implementation of numpy.logical_xor. This is a universal function, and supports the additional APIs described at jax.numpy.ufunc.
- Parameters:
x – input arrays. Must be broadcastable to a common shape.
y – input arrays. Must be broadcastable to a common shape.
- Returns:
Array containing the result of the element-wise logical XOR.
Examples
>>> x = jnp.arange(4)
>>> jnp.logical_xor(x, 1)
Array([ True, False, False, False], dtype=bool)
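The four logical functions treat nonzero values as True and can be combined to build masks. The following is a small sketch, assuming integer inputs and jax.numpy imported as jnp; the variable names are illustrative only.

```python
import jax.numpy as jnp

x = jnp.array([0, 1, 2, 0])
y = jnp.array([0, 0, 3, 4])

both = jnp.logical_and(x, y)         # nonzero in both inputs
either = jnp.logical_or(x, y)        # nonzero in at least one input
exactly_one = jnp.logical_xor(x, y)  # nonzero in exactly one input

# XOR can be rebuilt from the other three operations.
rebuilt = jnp.logical_and(either, jnp.logical_not(both))

print(both, either, exactly_one)
```

Because non-boolean input is implicitly cast to boolean, the results are boolean arrays regardless of the input dtype.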
- scico.numpy.logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None, axis=0)¶
Generate logarithmically-spaced values.
JAX implementation of numpy.logspace.
- Parameters:
start (Union[Array, ndarray, bool, number, bool, int, float, complex]) – scalar or array. Used to specify the start value. The start value is base ** start.
stop (Union[Array, ndarray, bool, number, bool, int, float, complex]) – scalar or array. Used to specify the stop value. The end value is base ** stop.
num (int) – int, optional, default=50. Number of values to generate.
endpoint (bool) – bool, optional, default=True. If True, then include the stop value in the result. If False, then exclude the stop value.
base (Union[Array, ndarray, bool, number, bool, int, float, complex]) – scalar or array, optional, default=10. Specifies the base of the logarithm.
dtype (Union[str, type[Any], dtype, SupportsDType, None]) – optional. Specifies the dtype of the output.
axis (int) – int, optional, default=0. Axis along which to generate the logspace.
- Return type:
- Returns:
An array of logarithmically-spaced values.
See also
jax.numpy.arange: Generate N evenly-spaced values given a starting point and a step value.
jax.numpy.linspace: Generate evenly-spaced values.
jax.numpy.geomspace: Generate geometrically-spaced values.
Examples
List 5 logarithmically-spaced values between 1 (10 ** 0) and 100 (10 ** 2):
>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.logspace(0, 2, 5)
Array([ 1. , 3.162, 10. , 31.623, 100. ], dtype=float32)
List 5 logarithmically-spaced values between 1 (10 ** 0) and 100 (10 ** 2), excluding endpoint:
>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.logspace(0, 2, 5, endpoint=False)
Array([ 1. , 2.512, 6.31 , 15.849, 39.811], dtype=float32)
List 7 logarithmically-spaced values between 1 (2 ** 0) and 4 (2 ** 2) with base 2:
>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.logspace(0, 2, 7, base=2)
Array([1. , 1.26 , 1.587, 2. , 2.52 , 3.175, 4. ], dtype=float32)
Multi-dimensional logspace:
>>> start = jnp.array([0, 5])
>>> stop = jnp.array([5, 0])
>>> base = jnp.array([2, 3])
>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.logspace(start, stop, 5, base=base)
Array([[ 1. , 243. ],
       [ 2.378, 61.547],
       [ 5.657, 15.588],
       [ 13.454, 3.948],
       [ 32. , 1. ]], dtype=float32)
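Since the start and stop values are base ** start and base ** stop, logspace can be read as raising base to a linearly-spaced set of exponents. The following sketch checks that reading numerically, assuming jax.numpy imported as jnp; it illustrates the relationship and is not a statement about how the function is implemented.

```python
import jax.numpy as jnp

# Default base 10: exponents 0, 0.5, 1, 1.5, 2.
spaced = jnp.logspace(0, 2, 5)
manual = jnp.power(10.0, jnp.linspace(0, 2, 5))

# Explicit base 2 with 7 points.
spaced2 = jnp.logspace(0, 2, 7, base=2)
manual2 = jnp.power(2.0, jnp.linspace(0, 2, 7))

print(jnp.allclose(spaced, manual), jnp.allclose(spaced2, manual2))
```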
- scico.numpy.mask_indices(n, mask_func, k=0, *, size=None)¶
Return indices of a mask of an (n, n) array.
- Parameters:
n (int) – static integer array dimension.
mask_func (Callable[[Union[Array, ndarray, bool, number, bool, int, float, complex], int], Array]) – a function that takes a shape (n, n) array and an optional offset k, and returns a shape (n, n) mask. Examples of functions with this signature are triu and tril.
k (int) – a scalar value passed to mask_func.
size (int | None) – optional argument specifying the static size of the output arrays. This is passed to nonzero when generating the indices from the mask.
- Return type:
- Returns:
a tuple of indices where mask_func is nonzero.
See also
jax.numpy.triu_indices: compute mask_indices for triu.
jax.numpy.tril_indices: compute mask_indices for tril.
Examples
Calling mask_indices on built-in masking functions:
>>> jnp.mask_indices(3, jnp.triu)
(Array([0, 0, 0, 1, 1, 2], dtype=int32), Array([0, 1, 2, 1, 2, 2], dtype=int32))
>>> jnp.mask_indices(3, jnp.tril)
(Array([0, 1, 1, 2, 2, 2], dtype=int32), Array([0, 0, 1, 0, 1, 2], dtype=int32))
Calling mask_indices on a custom masking function:
>>> def mask_func(x, k=0):
...   i = jnp.arange(x.shape[0])[:, None]
...   j = jnp.arange(x.shape[1])
...   return (i + 1) % (j + 1 + k) == 0
>>> mask_func(jnp.ones((3, 3)))
Array([[ True, False, False],
       [ True, True, False],
       [ True, False, True]], dtype=bool)
>>> jnp.mask_indices(3, mask_func)
(Array([0, 1, 1, 2, 2], dtype=int32), Array([0, 0, 1, 0, 2], dtype=int32))
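The returned tuple can be used directly to index an (n, n) array. The following is a minimal sketch, assuming a 3x3 integer array and jax.numpy imported as jnp, that extracts the upper triangle with and without the diagonal.

```python
import jax.numpy as jnp

x = jnp.arange(9).reshape(3, 3)

# Indices of the upper triangle, including the diagonal (k=0).
rows, cols = jnp.mask_indices(3, jnp.triu)
upper = x[rows, cols]

# Offset k=1 is passed to the mask function and excludes the diagonal.
rows1, cols1 = jnp.mask_indices(3, jnp.triu, 1)
strict_upper = x[rows1, cols1]

print(upper, strict_upper)
```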
- scico.numpy.matmul(a, b, *, precision=None, preferred_element_type=None, out_sharding=None)¶
Perform a matrix multiplication.
JAX implementation of numpy.matmul.
- Parameters:
a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – first input array, of shape (N,) or (..., K, N).
b (Union[Array, ndarray, bool, number, bool, int, float, complex]) – second input array. Must have shape (N,) or (..., N, M). In the multi-dimensional case, leading dimensions must be broadcast-compatible with the leading dimensions of a.
precision (Union[None, str, Precision, tuple[str, str], tuple[Precision, Precision], DotAlgorithm, DotAlgorithmPreset]) – either None (default), which means the default precision for the backend, a Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST) or a tuple of two such values indicating precision of a and b.
preferred_element_type (Union[str, type[Any], dtype, SupportsDType, None]) – either None (default), which means the default accumulation type for the input types, or a datatype, indicating to accumulate results to and return a result with that datatype.
- Return type:
- Returns:
array containing the matrix product of the inputs. Shape is a.shape[:-1] if b.ndim == 1, otherwise the shape is (..., K, M), where leading dimensions of a and b are broadcast together.
See also
jax.numpy.linalg.vecdot: batched vector product.
jax.numpy.linalg.tensordot: batched tensor product.
jax.lax.dot_general: general N-dimensional batched dot product.
Examples
Vector dot products:
>>> a = jnp.array([1, 2, 3])
>>> b = jnp.array([4, 5, 6])
>>> jnp.matmul(a, b)
Array(32, dtype=int32)
Matrix dot product:
>>> a = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> b = jnp.array([[1, 2],
...                [3, 4],
...                [5, 6]])
>>> jnp.matmul(a, b)
Array([[22, 28],
       [49, 64]], dtype=int32)
For convenience, in all cases you can do the same computation using the @ operator:
>>> a @ b
Array([[22, 28],
       [49, 64]], dtype=int32)
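The shape rules in the Returns description can be checked with a short sketch. It assumes jax.numpy imported as jnp and uses arrays of ones so that only the shapes matter; the sizes are illustrative.

```python
import jax.numpy as jnp

a = jnp.ones((4, 2, 3))   # batch of four (K=2, N=3) matrices
b = jnp.ones((3, 5))      # a single (N=3, M=5) matrix
v = jnp.ones(3)           # a vector of length N=3

# Leading dimensions are broadcast: result has shape (..., K, M).
batched = jnp.matmul(a, b)

# With b.ndim == 1 the result has shape a.shape[:-1].
mat_vec = jnp.matmul(a, v)

print(batched.shape, mat_vec.shape)
```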
- scico.numpy.matrix_transpose(x, /)¶
Transpose the last two dimensions of an array.
JAX implementation of numpy.matrix_transpose, implemented in terms of jax.lax.transpose.
- Parameters:
x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array. Must have x.ndim >= 2.
- Return type:
- Returns:
matrix-transposed copy of the array.
See also
jax.Array.mT: same operation accessed via an Array property.
jax.numpy.transpose: general multi-axis transpose
Note
Unlike numpy.matrix_transpose, jax.numpy.matrix_transpose will return a copy rather than a view of the input array. However, under JIT, the compiler will optimize away such copies when possible, so this doesn’t have performance impacts in practice.
Examples
Here is a 2x2x2 matrix representing a batched 2x2 matrix:
>>> x = jnp.array([[[1, 2],
...                 [3, 4]],
...                [[5, 6],
...                 [7, 8]]])
>>> jnp.matrix_transpose(x)
Array([[[1, 3],
        [2, 4]],
       [[5, 7],
        [6, 8]]], dtype=int32)
For convenience, you can perform the same transpose via the mT property of jax.Array:
>>> x.mT
Array([[[1, 3],
        [2, 4]],
       [[5, 7],
        [6, 8]]], dtype=int32)
- scico.numpy.max(a, axis=None, out=None, keepdims=False, initial=None, where=None)¶
Return the maximum of the array elements along a given axis.
JAX implementation of numpy.max.
- Parameters:
a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – Input array.
axis (Union[int, Sequence[int], None]) – int or array, default=None. Axis along which the maximum is to be computed. If None, the maximum is computed along all the axes.
keepdims (bool) – bool, default=False. If true, reduced axes are left in the result with size 1.
initial (Union[Array, ndarray, bool, number, bool, int, float, complex, None]) – int or array, default=None. Initial value for the maximum.
where (Union[Array, ndarray, bool, number, bool, int, float, complex, None]) – int or array of boolean dtype, default=None. The elements to be used in the maximum. Array should be broadcast compatible to the input. initial must be specified when where is used.
out (None) – Unused by JAX.
- Return type:
- Returns:
An array of maximum values along the given axis.
See also
jax.numpy.min: Compute the minimum of array elements along a given axis.
jax.numpy.sum: Compute the sum of array elements along a given axis.
jax.numpy.prod: Compute the product of array elements along a given axis.
Examples
By default, jnp.max computes the maximum of elements along all the axes.
>>> x = jnp.array([[9, 3, 4, 5],
...                [5, 2, 7, 4],
...                [8, 1, 3, 6]])
>>> jnp.max(x)
Array(9, dtype=int32)
If axis=1, the maximum will be computed along axis 1.
>>> jnp.max(x, axis=1)
Array([9, 7, 8], dtype=int32)
If keepdims=True, ndim of the output will be the same as that of the input.
>>> jnp.max(x, axis=1, keepdims=True)
Array([[9],
       [7],
       [8]], dtype=int32)
To include only specific elements in computing the maximum, you can use where. It can either have the same dimension as the input:
>>> where=jnp.array([[0, 0, 1, 0],
...                  [0, 0, 1, 1],
...                  [1, 1, 1, 0]], dtype=bool)
>>> jnp.max(x, axis=1, keepdims=True, initial=0, where=where)
Array([[4],
       [7],
       [8]], dtype=int32)
or be broadcast compatible with the input:
>>> where = jnp.array([[False],
...                    [False],
...                    [False]])
>>> jnp.max(x, axis=0, keepdims=True, initial=0, where=where)
Array([[0, 0, 0, 0]], dtype=int32)
- scico.numpy.maximum(*args: ArrayLike, out: None = None, where: None = None) Any¶
Return element-wise maximum of the input arrays.
JAX implementation of numpy.maximum.
- Parameters:
x – input array or scalar.
y – input array or scalar. Both x and y should either have the same shape or be broadcast compatible.
- Returns:
An array containing the element-wise maximum of x and y.
Note
For each pair of elements, jnp.maximum returns:
- the larger of the two if both elements are finite numbers.
- nan if one element is nan.
See also
jax.numpy.minimum: Returns element-wise minimum of the input arrays.
jax.numpy.fmax: Returns element-wise maximum of the input arrays, ignoring NaNs.
jax.numpy.amax: Returns the maximum of array elements along a given axis.
jax.numpy.nanmax: Returns the maximum of the array elements along a given axis, ignoring NaNs.
Examples
Inputs with x.shape == y.shape:
>>> x = jnp.array([1, -5, 3, 2])
>>> y = jnp.array([-2, 4, 7, -6])
>>> jnp.maximum(x, y)
Array([1, 4, 7, 2], dtype=int32)
Inputs with broadcast compatibility:
>>> x1 = jnp.array([[-2, 5, 7, 4],
...                 [1, -6, 3, 8]])
>>> y1 = jnp.array([-5, 3, 6, 9])
>>> jnp.maximum(x1, y1)
Array([[-2, 5, 7, 9],
       [ 1, 3, 6, 9]], dtype=int32)
Inputs having nan:
>>> nan = jnp.nan
>>> x2 = jnp.array([nan, -3, 9])
>>> y2 = jnp.array([[4, -2, nan],
...                 [-3, -5, 10]])
>>> jnp.maximum(x2, y2)
Array([[nan, -2., nan],
       [nan, -3., 10.]], dtype=float32)
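The difference between maximum, which propagates NaNs, and fmax, which ignores them, can be seen side by side. This is a small sketch with illustrative inputs, assuming jax.numpy imported as jnp.

```python
import jax.numpy as jnp

x = jnp.array([jnp.nan, 1.0, 3.0])
y = jnp.array([0.0, jnp.nan, 2.0])

# maximum returns nan wherever either input is nan.
propagating = jnp.maximum(x, y)

# fmax returns the non-nan element when only one input is nan.
ignoring = jnp.fmax(x, y)

print(propagating)
print(ignoring)
```

The same distinction applies to minimum and fmin.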
- scico.numpy.mean(a, axis=None, dtype=None, out=None, keepdims=False, *, where=None)¶
Return the mean of array elements along a given axis.
JAX implementation of numpy.mean.
- Parameters:
a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array.
axis (Union[int, Sequence[int], None]) – optional, int or sequence of ints, default=None. Axis along which the mean is to be computed. If None, the mean is computed along all the axes.
dtype (Union[str, type[Any], dtype, SupportsDType, None]) – The type of the output array. If None (default), then the output dtype will match the input dtype for floating point inputs, or be set to float32 or float64 for non-floating-point inputs.
keepdims (bool) – bool, default=False. If true, reduced axes are left in the result with size 1.
where (Union[Array, ndarray, bool, number, bool, int, float, complex, None]) – optional, boolean array, default=None. The elements to be used in the mean. Array should be broadcast compatible to the input.
out (None) – Unused by JAX.
- Return type:
- Returns:
An array of the mean along the given axis.
Notes
For inputs of type float16 or bfloat16, the reductions will be performed at float32 precision.
See also
jax.numpy.average: Compute the weighted average of array elements
jax.numpy.sum: Compute the sum of array elements.
Examples
By default, the mean is computed along all the axes.
>>> x = jnp.array([[1, 3, 4, 2],
...                [5, 2, 6, 3],
...                [8, 1, 2, 9]])
>>> jnp.mean(x)
Array(3.8333335, dtype=float32)
If axis=1, the mean is computed along axis 1.
>>> jnp.mean(x, axis=1)
Array([2.5, 4. , 5. ], dtype=float32)
If keepdims=True, ndim of the output is equal to that of the input.
>>> jnp.mean(x, axis=1, keepdims=True)
Array([[2.5],
       [4. ],
       [5. ]], dtype=float32)
To use only specific elements of x to compute the mean, you can use where.
>>> where = jnp.array([[1, 0, 1, 0],
...                    [0, 1, 0, 1],
...                    [1, 1, 0, 1]], dtype=bool)
>>> jnp.mean(x, axis=1, keepdims=True, where=where)
Array([[2.5],
       [2.5],
       [6. ]], dtype=float32)
- scico.numpy.meshgrid(*xi, copy=True, sparse=False, indexing='xy')¶
Construct N-dimensional grid arrays from N 1-dimensional vectors.
JAX implementation of numpy.meshgrid.
- Parameters:
xi (Union[Array, ndarray, bool, number, bool, int, float, complex]) – N arrays to convert to a grid.
copy (bool) – whether to copy the input arrays. JAX supports only copy=True, though under JIT compilation the compiler may opt to avoid copies.
sparse (bool) – if False (default), then each returned array will be of shape [len(x1), len(x2), ..., len(xN)]. If True, then returned arrays will be of shape [1, 1, ..., len(xi), ..., 1, 1].
indexing (str) – options are 'xy' for cartesian indexing (default) or 'ij' for matrix indexing.
- Return type:
- Returns:
A length-N list of grid arrays.
See also
jax.numpy.indices: generate a grid of indices.
jax.numpy.mgrid: create a meshgrid using indexing syntax.
jax.numpy.ogrid: create an open meshgrid using indexing syntax.
Examples
For the following examples, we’ll use these 1D arrays as inputs:
>>> x = jnp.array([1, 2])
>>> y = jnp.array([10, 20, 30])
2D cartesian mesh grid:
>>> x_grid, y_grid = jnp.meshgrid(x, y)
>>> print(x_grid)
[[1 2]
 [1 2]
 [1 2]]
>>> print(y_grid)
[[10 10]
 [20 20]
 [30 30]]
2D sparse cartesian mesh grid:
>>> x_grid, y_grid = jnp.meshgrid(x, y, sparse=True)
>>> print(x_grid)
[[1 2]]
>>> print(y_grid)
[[10]
 [20]
 [30]]
2D matrix-index mesh grid:
>>> x_grid, y_grid = jnp.meshgrid(x, y, indexing='ij')
>>> print(x_grid)
[[1 1 1]
 [2 2 2]]
>>> print(y_grid)
[[10 20 30]
 [10 20 30]]
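A common use of the grids is to evaluate a function of two variables on every (x, y) pair. The following sketch, assuming jax.numpy imported as jnp and reusing the x and y values from the examples above, shows that dense and sparse grids give the same result because the sparse grids broadcast against each other.

```python
import jax.numpy as jnp

x = jnp.array([1, 2])
y = jnp.array([10, 20, 30])

# Dense grids: both arrays have the full (len(y), len(x)) shape.
xd, yd = jnp.meshgrid(x, y)
dense_sum = xd + yd

# Sparse grids: shapes (1, len(x)) and (len(y), 1), combined by broadcasting.
xs, ys = jnp.meshgrid(x, y, sparse=True)
sparse_sum = xs + ys

print(dense_sum)
```

The sparse form stores fewer elements before the function is evaluated, which may matter for large grids.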
- scico.numpy.min(a, axis=None, out=None, keepdims=False, initial=None, where=None)¶
Return the minimum of array elements along a given axis.
JAX implementation of numpy.min.
- Parameters:
a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – Input array.
axis (Union[int, Sequence[int], None]) – int or array, default=None. Axis along which the minimum is to be computed. If None, the minimum is computed along all the axes.
keepdims (bool) – bool, default=False. If true, reduced axes are left in the result with size 1.
initial (Union[Array, ndarray, bool, number, bool, int, float, complex, None]) – int or array, default=None. Initial value for the minimum.
where (Union[Array, ndarray, bool, number, bool, int, float, complex, None]) – int or array, default=None. The elements to be used in the minimum. Array should be broadcast compatible to the input. initial must be specified when where is used.
out (None) – Unused by JAX.
- Return type:
- Returns:
An array of minimum values along the given axis.
See also
jax.numpy.max: Compute the maximum of array elements along a given axis.
jax.numpy.sum: Compute the sum of array elements along a given axis.
jax.numpy.prod: Compute the product of array elements along a given axis.
Examples
By default, the minimum is computed along all the axes.
>>> x = jnp.array([[2, 5, 1, 6],
...                [3, -7, -2, 4],
...                [8, -4, 1, -3]])
>>> jnp.min(x)
Array(-7, dtype=int32)
If axis=1, the minimum is computed along axis 1.
>>> jnp.min(x, axis=1)
Array([ 1, -7, -4], dtype=int32)
If keepdims=True, ndim of the output will be the same as that of the input.
>>> jnp.min(x, axis=1, keepdims=True)
Array([[ 1],
       [-7],
       [-4]], dtype=int32)
To include only specific elements in computing the minimum, you can use where. where can either have the same dimension as the input:
>>> where=jnp.array([[1, 0, 1, 0],
...                  [0, 0, 1, 1],
...                  [1, 1, 1, 0]], dtype=bool)
>>> jnp.min(x, axis=1, keepdims=True, initial=0, where=where)
Array([[ 0],
       [-2],
       [-4]], dtype=int32)
or be broadcast compatible with the input:
>>> where = jnp.array([[False],
...                    [False],
...                    [False]])
>>> jnp.min(x, axis=0, keepdims=True, initial=0, where=where)
Array([[0, 0, 0, 0]], dtype=int32)
- scico.numpy.minimum(*args: ArrayLike, out: None = None, where: None = None) Any¶
Return element-wise minimum of the input arrays.
JAX implementation of numpy.minimum.
- Parameters:
x – input array or scalar.
y – input array or scalar. Both x and y should either have the same shape or be broadcast compatible.
- Returns:
An array containing the element-wise minimum of x and y.
Note
For each pair of elements, jnp.minimum returns:
- the smaller of the two if both elements are finite numbers.
- nan if one element is nan.
See also
jax.numpy.maximum: Returns element-wise maximum of the input arrays.
jax.numpy.fmin: Returns element-wise minimum of the input arrays, ignoring NaNs.
jax.numpy.amin: Returns the minimum of array elements along a given axis.
jax.numpy.nanmin: Returns the minimum of the array elements along a given axis, ignoring NaNs.
Examples
Inputs with x.shape == y.shape:
>>> x = jnp.array([2, 3, 5, 1])
>>> y = jnp.array([-3, 6, -4, 7])
>>> jnp.minimum(x, y)
Array([-3, 3, -4, 1], dtype=int32)
Inputs having broadcast compatibility:
>>> x1 = jnp.array([[1, 5, 2],
...                 [-3, 4, 7]])
>>> y1 = jnp.array([-2, 3, 6])
>>> jnp.minimum(x1, y1)
Array([[-2, 3, 2],
       [-3, 3, 6]], dtype=int32)
Inputs with nan:
>>> nan = jnp.nan
>>> x2 = jnp.array([[2.5, nan, -2],
...                 [nan, 5, 6],
...                 [-4, 3, 7]])
>>> y2 = jnp.array([1, nan, 5])
>>> jnp.minimum(x2, y2)
Array([[ 1., nan, -2.],
       [nan, nan, 5.],
       [-4., nan, 5.]], dtype=float32)
- scico.numpy.mod(x1, x2, /)¶
Alias of jax.numpy.remainder
- Return type:
- scico.numpy.modf(x, /, out=None)¶
Return element-wise fractional and integral parts of the input array.
JAX implementation of numpy.modf.
- Parameters:
- Return type:
- Returns:
A tuple of two arrays containing the fractional and integral parts of the elements of x, promoting dtypes to inexact.
See also
jax.numpy.divmod: Calculates the integer quotient and remainder of x1 by x2 element-wise.
Examples
>>> jnp.modf(4.8)
(Array(0.8000002, dtype=float32, weak_type=True), Array(4., dtype=float32, weak_type=True))
>>> x = jnp.array([-3.4, -5.7, 0.6, 1.5, 2.3])
>>> jnp.modf(x)
(Array([-0.4000001 , -0.6999998 , 0.6 , 0.5 , 0.29999995], dtype=float32), Array([-3., -5., 0., 1., 2.], dtype=float32))
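As the examples above show, the result is a pair (fractional part, integral part), and both parts carry the sign of the input. The following is a minimal sketch, assuming jax.numpy imported as jnp and inputs chosen to be exactly representable in float32, that unpacks the pair and confirms the two parts add back to the input.

```python
import jax.numpy as jnp

x = jnp.array([-3.5, -0.25, 0.75, 2.5])

# Unpack the fractional and integral parts.
frac, whole = jnp.modf(x)

# The two parts sum back to the original values.
reconstructed = frac + whole

print(frac)
print(whole)
```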
- scico.numpy.moveaxis(a, source, destination)¶
Move an array axis to a new position.
JAX implementation of numpy.moveaxis, implemented in terms of jax.lax.transpose.
- Parameters:
- Return type:
- Returns:
Copy of a with axes moved from source to destination.
Notes
Unlike numpy.moveaxis, jax.numpy.moveaxis will return a copy rather than a view of the input array. However, under JIT, the compiler will optimize away such copies when possible, so this doesn’t have performance impacts in practice.
See also
jax.numpy.swapaxes: swap two axes.
jax.numpy.rollaxis: older API for moving an axis.
jax.numpy.transpose: general axes permutation.
Examples
>>> a = jnp.ones((2, 3, 4, 5))
Move axis 1 to the end of the array:
>>> jnp.moveaxis(a, 1, -1).shape
(2, 4, 5, 3)
Move the last axis to position 1:
>>> jnp.moveaxis(a, -1, 1).shape
(2, 5, 3, 4)
Move multiple axes:
>>> jnp.moveaxis(a, (0, 1), (-1, -2)).shape
(4, 5, 3, 2)
This can also be accomplished via transpose:
>>> a.transpose(2, 3, 1, 0).shape
(4, 5, 3, 2)
- scico.numpy.multiply(*args: ArrayLike, out: None = None, where: None = None) Any¶
Multiply two arrays element-wise.
JAX implementation of numpy.multiply. This is a universal function, and supports the additional APIs described at jax.numpy.ufunc. This function provides the implementation of the * operator for JAX arrays.
- Parameters:
x – arrays to multiply. Must be broadcastable to a common shape.
y – arrays to multiply. Must be broadcastable to a common shape.
- Returns:
Array containing the result of the element-wise multiplication.
Examples
Calling multiply explicitly:
>>> x = jnp.arange(4)
>>> jnp.multiply(x, 10)
Array([ 0, 10, 20, 30], dtype=int32)
Calling multiply via the * operator:
>>> x * 10
Array([ 0, 10, 20, 30], dtype=int32)
- scico.numpy.nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None)¶
Replace NaN and infinite entries in an array.
JAX implementation of numpy.nan_to_num.
- Parameters:
x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – array of values to be replaced. If it does not have an inexact dtype it will be returned unmodified.
copy (bool) – unused by JAX
nan (Union[Array, ndarray, bool, number, bool, int, float, complex]) – value to substitute for NaN entries. Defaults to 0.0.
posinf (Union[Array, ndarray, bool, number, bool, int, float, complex, None]) – value to substitute for positive infinite entries. Defaults to the maximum representable value.
neginf (Union[Array, ndarray, bool, number, bool, int, float, complex, None]) – value to substitute for negative infinite entries. Defaults to the minimum representable value.
- Return type:
- Returns:
A copy of x with the requested substitutions.
See also
jax.numpy.isnan: return True where the array contains NaN
jax.numpy.isposinf: return True where the array contains +inf
jax.numpy.isneginf: return True where the array contains -inf
Examples
>>> x = jnp.array([0, jnp.nan, 1, jnp.inf, 2, -jnp.inf])
Default substitution values:
>>> jnp.nan_to_num(x)
Array([ 0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 3.4028235e+38, 2.0000000e+00, -3.4028235e+38], dtype=float32)
Overriding substitutions for -inf and +inf:
>>> jnp.nan_to_num(x, posinf=999, neginf=-999)
Array([ 0., 0., 1., 999., 2., -999.], dtype=float32)
If you only wish to substitute for NaN values while leaving inf values untouched, using where with jax.numpy.isnan is a better option:
>>> jnp.where(jnp.isnan(x), 0, x)
Array([ 0., 0., 1., inf, 2., -inf], dtype=float32)
- scico.numpy.nanargmax(a, axis=None, out=None, keepdims=None)¶
Return the index of the maximum value of an array, ignoring NaNs.
JAX implementation of numpy.nanargmax.
- Parameters:
a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array
axis (int | None) – optional integer specifying the axis along which to find the maximum value. If axis is not specified, a will be flattened.
out (None) – unused by JAX
keepdims (bool | None) – if True, then return an array with the same number of dimensions as a.
- Return type:
- Returns:
an array containing the index of the maximum value along the specified axis.
Note
In the case of an axis with all-NaN values, the returned index will be -1. This differs from the behavior of numpy.nanargmax, which raises an error.
See also
jax.numpy.argmax: return the index of the maximum value.
jax.numpy.nanargmin: compute argmin while ignoring NaN values.
Examples
>>> x = jnp.array([1, 3, 5, 4, jnp.nan])
Using a standard argmax leads to potentially unexpected results:
>>> jnp.argmax(x)
Array(4, dtype=int32)
Using nanargmax returns the index of the maximum non-NaN value.
>>> jnp.nanargmax(x)
Array(2, dtype=int32)
>>> x = jnp.array([[1, 3, jnp.nan],
...                [5, 4, jnp.nan]])
>>> jnp.nanargmax(x, axis=1)
Array([1, 0], dtype=int32)
>>> jnp.nanargmax(x, axis=1, keepdims=True)
Array([[1],
       [0]], dtype=int32)
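The Note above says that an all-NaN axis yields an index of -1 rather than an error. The following sketch, assuming jax.numpy imported as jnp and an illustrative two-row array, shows one way to detect such rows before using the indices.

```python
import jax.numpy as jnp

x = jnp.array([[1.0, jnp.nan, 3.0],
               [jnp.nan, jnp.nan, jnp.nan]])

# The second row is all-NaN, so its index is the sentinel -1.
idx = jnp.nanargmax(x, axis=1)

# Rows with a usable index can be selected with a simple mask.
valid = idx >= 0

print(idx, valid)
```

Since -1 is also a valid Python index for the last element, checking the mask before indexing avoids silently picking the wrong entry.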
- scico.numpy.nanargmin(a, axis=None, out=None, keepdims=None)¶
Return the index of the minimum value of an array, ignoring NaNs.
JAX implementation of numpy.nanargmin.
- Parameters:
a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array
axis (int | None) – optional integer specifying the axis along which to find the minimum value. If axis is not specified, a will be flattened.
out (None) – unused by JAX
keepdims (bool | None) – if True, then return an array with the same number of dimensions as a.
- Return type:
- Returns:
an array containing the index of the minimum value along the specified axis.
Note
In the case of an axis with all-NaN values, the returned index will be -1. This differs from the behavior of numpy.nanargmin, which raises an error.
See also
jax.numpy.argmin: return the index of the minimum value.
jax.numpy.nanargmax: compute argmax while ignoring NaN values.
Examples
>>> x = jnp.array([jnp.nan, 3, 5, 4, 2])
>>> jnp.nanargmin(x)
Array(4, dtype=int32)
>>> x = jnp.array([[1, 3, jnp.nan],
...                [5, 4, jnp.nan]])
>>> jnp.nanargmin(x, axis=1)
Array([0, 1], dtype=int32)
>>> jnp.nanargmin(x, axis=1, keepdims=True)
Array([[0],
       [1]], dtype=int32)
- scico.numpy.nancumprod(a, axis=None, dtype=None, out=None)¶
Cumulative product of elements along an axis, ignoring NaN values.
JAX implementation of numpy.nancumprod.
- Parameters:
a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – N-dimensional array to be accumulated.
axis (int | None) – integer axis along which to accumulate. If None (default), then array will be flattened and accumulated along the flattened axis.
dtype (Union[str, type[Any], dtype, SupportsDType, None]) – optionally specify the dtype of the output. If not specified, then the output dtype will match the input dtype.
out (None) – unused by JAX
- Return type:
- Returns:
An array containing the accumulated product along the given axis.
See also
jax.numpy.cumprod: cumulative product without ignoring NaN values.
jax.numpy.multiply.accumulate: cumulative product via ufunc methods.
jax.numpy.prod: product along axis
Examples
>>> x = jnp.array([[1., 2., jnp.nan],
...                [4., jnp.nan, 6.]])
The standard cumulative product will propagate NaN values:
>>> jnp.cumprod(x)
Array([ 1., 2., nan, nan, nan, nan], dtype=float32)
nancumprod will ignore NaN values, effectively replacing them with ones:
>>> jnp.nancumprod(x)
Array([ 1., 2., 2., 8., 8., 48.], dtype=float32)
Cumulative product along axis 1:
>>> jnp.nancumprod(x, axis=1)
Array([[ 1., 2., 2.],
       [ 4., 4., 24.]], dtype=float32)
- scico.numpy.nancumsum(a, axis=None, dtype=None, out=None)¶
Cumulative sum of elements along an axis, ignoring NaN values.
JAX implementation of
numpy.nancumsum.
- Parameters:
a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – N-dimensional array to be accumulated.
axis (int | None) – integer axis along which to accumulate. If None (default), then array will be flattened and accumulated along the flattened axis.
dtype (Union[str, type[Any], dtype, SupportsDType, None]) – optionally specify the dtype of the output. If not specified, then the output dtype will match the input dtype.
out (None) – unused by JAX
- Return type:
- Returns:
An array containing the accumulated sum along the given axis.
See also
jax.numpy.cumsum: cumulative sum without ignoring NaN values.
jax.numpy.cumulative_sum: cumulative sum via the array API standard.
jax.numpy.add.accumulate: cumulative sum via ufunc methods.
jax.numpy.sum: sum along axis
Examples
>>> x = jnp.array([[1., 2., jnp.nan],
...                [4., jnp.nan, 6.]])
The standard cumulative sum will propagate NaN values:
>>> jnp.cumsum(x)
Array([ 1., 3., nan, nan, nan, nan], dtype=float32)
nancumsum will ignore NaN values, effectively replacing them with zeros:
>>> jnp.nancumsum(x)
Array([ 1., 3., 3., 7., 7., 13.], dtype=float32)
Cumulative sum along axis 1:
>>> jnp.nancumsum(x, axis=1)
Array([[ 1., 3., 3.],
       [ 4., 4., 10.]], dtype=float32)
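The phrases "effectively replacing them with zeros" (for nancumsum) and "with ones" (for nancumprod) can be checked by doing the replacement by hand. This is a sketch for illustration, assuming jax.numpy imported as jnp; it is not a description of how either function is implemented.

```python
import jax.numpy as jnp

x = jnp.array([[1., 2., jnp.nan],
               [4., jnp.nan, 6.]])
is_nan = jnp.isnan(x)

# Replace NaN with the additive identity, then accumulate.
manual_sum = jnp.cumsum(jnp.where(is_nan, 0.0, x), axis=1)
builtin_sum = jnp.nancumsum(x, axis=1)

# Replace NaN with the multiplicative identity, then accumulate.
manual_prod = jnp.cumprod(jnp.where(is_nan, 1.0, x), axis=1)
builtin_prod = jnp.nancumprod(x, axis=1)

print(builtin_sum)
print(builtin_prod)
```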
- scico.numpy.nanmax(a, axis=None, out=None, keepdims=False, initial=None, where=None)¶
Return the maximum of the array elements along a given axis, ignoring NaNs.
JAX implementation of numpy.nanmax.
- Parameters:
a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – Input array.
axis (Union[int, Sequence[int], None]) – int or sequence of ints, default=None. Axis along which the maximum is computed. If None, the maximum is computed along the flattened array.
keepdims (bool) – bool, default=False. If True, reduced axes are left in the result with size 1.
initial (Union[Array, ndarray, bool, number, bool, int, float, complex, None]) – int or array, default=None. Initial value for the maximum.
where (Union[Array, ndarray, bool, number, bool, int, float, complex, None]) – array of boolean dtype, default=None. The elements to be used in the maximum. Array should be broadcast compatible to the input. initial must be specified when where is used.
out (None) – Unused by JAX.
- Return type:
- Returns:
An array of maximum values along the given axis, ignoring NaNs. If all values are NaNs along the given axis, returns nan.
See also
jax.numpy.nanmin: Compute the minimum of array elements along a given axis, ignoring NaNs.
jax.numpy.nansum: Compute the sum of array elements along a given axis, ignoring NaNs.
jax.numpy.nanprod: Compute the product of array elements along a given axis, ignoring NaNs.
jax.numpy.nanmean: Compute the mean of array elements along a given axis, ignoring NaNs.
Examples
By default,
jnp.nanmaxcomputes the maximum of elements along the flattened array.>>> nan = jnp.nan >>> x = jnp.array([[8, nan, 4, 6], ... [nan, -2, nan, -4], ... [-2, 1, 7, nan]]) >>> jnp.nanmax(x) Array(8., dtype=float32)
If
axis=1, the maximum will be computed along axis 1.>>> jnp.nanmax(x, axis=1) Array([ 8., -2., 7.], dtype=float32)
If
keepdims=True,ndimof the output will be same of that of the input.>>> jnp.nanmax(x, axis=1, keepdims=True) Array([[ 8.], [-2.], [ 7.]], dtype=float32)
To include only specific elements in computing the maximum, you can use
where. It can either have same dimension as input>>> where=jnp.array([[0, 0, 1, 0], ... [0, 0, 1, 1], ... [1, 1, 1, 0]], dtype=bool) >>> jnp.nanmax(x, axis=1, keepdims=True, initial=0, where=where) Array([[4.], [0.], [7.]], dtype=float32)
or be broadcast compatible with the input.
>>> where = jnp.array([[True], ... [False], ... [False]]) >>> jnp.nanmax(x, axis=0, keepdims=True, initial=0, where=where) Array([[8., 0., 4., 6.]], dtype=float32)
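The all-NaN case described under Returns can be sketched as follows. This is a minimal illustration, assuming `jax.numpy` is imported as `jnp` as in the examples above; the array `x` is made-up data.

```python
import jax.numpy as jnp

# Illustrative data: the second row holds only NaNs.
x = jnp.array([[1.0, jnp.nan, 3.0],
               [jnp.nan, jnp.nan, jnp.nan]])

# NaNs are skipped where a valid element exists; a row with no
# valid element yields nan.
row_max = jnp.nanmax(x, axis=1)
print(row_max)
```

The first entry is the largest non-NaN value of the first row, while the second entry is nan because there is nothing left to compare.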
- scico.numpy.nanmin(a, axis=None, out=None, keepdims=False, initial=None, where=None)¶
Return the minimum of the array elements along a given axis, ignoring NaNs.
JAX implementation of
numpy.nanmin.- Parameters:
a (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – Input array.axis (
Union[int,Sequence[int],None]) – int or sequence of ints, default=None. Axis along which the minimum is computed. If None, the minimum is computed along the flattened array.keepdims (
bool) – bool, default=False. If True, reduced axes are left in the result with size 1.initial (
Union[Array,ndarray,bool,number,bool,int,float,complex,None]) – int or array, default=None. Initial value for the minimum.where (
Union[Array,ndarray,bool,number,bool,int,float,complex,None]) – array of boolean dtype, default=None. The elements to be used in the minimum. Array should be broadcast compatible to the input.initialmust be specified whenwhereis used.out (
None) – Unused by JAX.
- Return type:
- Returns:
An array of minimum values along the given axis, ignoring NaNs. If all values are NaNs along the given axis, returns
nan.
See also
jax.numpy.nanmax: Compute the maximum of array elements along a given axis, ignoring NaNs.
jax.numpy.nansum: Compute the sum of array elements along a given axis, ignoring NaNs.
jax.numpy.nanprod: Compute the product of array elements along a given axis, ignoring NaNs.
jax.numpy.nanmean: Compute the mean of array elements along a given axis, ignoring NaNs.
Examples
By default,
jnp.nanmincomputes the minimum of elements along the flattened array.>>> nan = jnp.nan >>> x = jnp.array([[1, nan, 4, 5], ... [nan, -2, nan, -4], ... [2, 1, 3, nan]]) >>> jnp.nanmin(x) Array(-4., dtype=float32)
If
axis=1, the minimum will be computed along axis 1.>>> jnp.nanmin(x, axis=1) Array([ 1., -4., 1.], dtype=float32)
If
keepdims=True, ndim of the output will be the same as that of the input.>>> jnp.nanmin(x, axis=1, keepdims=True) Array([[ 1.], [-4.], [ 1.]], dtype=float32)
To include only specific elements in computing the minimum, you can use
where. It can either have same dimension as input>>> where=jnp.array([[0, 0, 1, 0], ... [0, 0, 1, 1], ... [1, 1, 1, 0]], dtype=bool) >>> jnp.nanmin(x, axis=1, keepdims=True, initial=0, where=where) Array([[ 0.], [-4.], [ 0.]], dtype=float32)
or be broadcast compatible with the input.
>>> where = jnp.array([[False], ... [True], ... [False]]) >>> jnp.nanmin(x, axis=0, keepdims=True, initial=0, where=where) Array([[ 0., -2., 0., -4.]], dtype=float32)
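In the examples above, `initial=0` takes part in the comparison and can become the result. A hedged alternative, assuming you only want the selected elements to matter, is to pass `initial=jnp.inf`, which is the identity for a minimum. The data and the name `mask` below are illustrative.

```python
import jax.numpy as jnp

# Illustrative data and selection mask.
x = jnp.array([[5.0, jnp.nan, 2.0],
               [7.0, 1.0, jnp.nan]])
mask = jnp.array([[True, True, False],
                  [True, False, True]])

# inf never wins a minimum against a finite selected value, so the
# result depends only on the selected, non-NaN entries of each row.
row_min = jnp.nanmin(x, axis=1, initial=jnp.inf, where=mask)
print(row_min)
```

Each row here has one selected finite value (5 and 7 respectively), so those are the values returned.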
- scico.numpy.nanprod(a, axis=None, dtype=None, out=None, keepdims=False, initial=None, where=None)¶
Return the product of the array elements along a given axis, ignoring NaNs.
JAX implementation of
numpy.nanprod.- Parameters:
a (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – Input array.axis (
Union[int,Sequence[int],None]) – int or sequence of ints, default=None. Axis along which the product is computed. If None, the product is computed along the flattened array.dtype (
Union[str,type[Any],dtype,SupportsDType,None]) – The type of the output array. Default=None.keepdims (
bool) – bool, default=False. If True, reduced axes are left in the result with size 1.initial (
Union[Array,ndarray,bool,number,bool,int,float,complex,None]) – int or array, default=None. Initial value for the product.where (
Union[Array,ndarray,bool,number,bool,int,float,complex,None]) – array of boolean dtype, default=None. The elements to be used in the product. Array should be broadcast compatible to the input.out (
None) – Unused by JAX.
- Return type:
- Returns:
An array containing the product of array elements along the given axis, ignoring NaNs. If all elements along the given axis are NaNs, returns 1.
See also
jax.numpy.nanmin: Compute the minimum of array elements along a given axis, ignoring NaNs.
jax.numpy.nanmax: Compute the maximum of array elements along a given axis, ignoring NaNs.
jax.numpy.nansum: Compute the sum of array elements along a given axis, ignoring NaNs.
jax.numpy.nanmean: Compute the mean of array elements along a given axis, ignoring NaNs.
Examples
By default,
jnp.nanprodcomputes the product of elements along the flattened array.>>> nan = jnp.nan >>> x = jnp.array([[nan, 3, 4, nan], ... [5, nan, 1, 3], ... [2, 1, nan, 1]]) >>> jnp.nanprod(x) Array(360., dtype=float32)
If
axis=1, the product will be computed along axis 1.>>> jnp.nanprod(x, axis=1) Array([12., 15., 2.], dtype=float32)
If
keepdims=True, ndim of the output will be the same as that of the input.>>> jnp.nanprod(x, axis=1, keepdims=True) Array([[12.], [15.], [ 2.]], dtype=float32)
To include only specific elements in computing the product, you can use
where.>>> where=jnp.array([[1, 0, 1, 0], ... [0, 0, 1, 1], ... [1, 1, 1, 0]], dtype=bool) >>> jnp.nanprod(x, axis=1, keepdims=True, where=where) Array([[4.], [3.], [2.]], dtype=float32)
If
whereisFalseat all elements,jnp.nanprodreturns 1 along the given axis.>>> where = jnp.array([[False], ... [False], ... [False]]) >>> jnp.nanprod(x, axis=0, keepdims=True, where=where) Array([[1., 1., 1., 1.]], dtype=float32)
- scico.numpy.nansum(a, axis=None, dtype=None, out=None, keepdims=False, initial=None, where=None)¶
Return the sum of the array elements along a given axis, ignoring NaNs.
JAX implementation of
numpy.nansum.- Parameters:
a (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – Input array.axis (
Union[int,Sequence[int],None]) – int or sequence of ints, default=None. Axis along which the sum is computed. If None, the sum is computed along the flattened array.dtype (
Union[str,type[Any],dtype,SupportsDType,None]) – The type of the output array. Default=None.keepdims (
bool) – bool, default=False. If True, reduced axes are left in the result with size 1.initial (
Union[Array,ndarray,bool,number,bool,int,float,complex,None]) – int or array, default=None. Initial value for the sum.where (
Union[Array,ndarray,bool,number,bool,int,float,complex,None]) – array of boolean dtype, default=None. The elements to be used in the sum. Array should be broadcast compatible to the input.out (
None) – Unused by JAX.
- Return type:
- Returns:
An array containing the sum of array elements along the given axis, ignoring NaNs. If all elements along the given axis are NaNs, returns 0.
See also
jax.numpy.nanmin: Compute the minimum of array elements along a given axis, ignoring NaNs.
jax.numpy.nanmax: Compute the maximum of array elements along a given axis, ignoring NaNs.
jax.numpy.nanprod: Compute the product of array elements along a given axis, ignoring NaNs.
jax.numpy.nanmean: Compute the mean of array elements along a given axis, ignoring NaNs.
Examples
By default,
jnp.nansumcomputes the sum of elements along the flattened array.>>> nan = jnp.nan >>> x = jnp.array([[3, nan, 4, 5], ... [nan, -2, nan, 7], ... [2, 1, 6, nan]]) >>> jnp.nansum(x) Array(26., dtype=float32)
If
axis=1, the sum will be computed along axis 1.>>> jnp.nansum(x, axis=1) Array([12., 5., 9.], dtype=float32)
If
keepdims=True, ndim of the output will be the same as that of the input.>>> jnp.nansum(x, axis=1, keepdims=True) Array([[12.], [ 5.], [ 9.]], dtype=float32)
To include only specific elements in computing the sum, you can use
where.>>> where=jnp.array([[1, 0, 1, 0], ... [0, 0, 1, 1], ... [1, 1, 1, 0]], dtype=bool) >>> jnp.nansum(x, axis=1, keepdims=True, where=where) Array([[7.], [7.], [9.]], dtype=float32)
If
whereisFalseat all elements,jnp.nansumreturns 0 along the given axis.>>> where = jnp.array([[False], ... [False], ... [False]]) >>> jnp.nansum(x, axis=0, keepdims=True, where=where) Array([[0., 0., 0., 0.]], dtype=float32)
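The all-NaN return values of `nansum` (0) and `nanprod` (1) are easy to confuse with genuine results. The sketch below, using made-up data, shows both side by side, together with a count of valid entries that can be used to tell the cases apart. The name `valid_count` is illustrative.

```python
import jax.numpy as jnp

# Illustrative data: the second row holds only NaNs.
x = jnp.array([[2.0, jnp.nan, 4.0],
               [jnp.nan, jnp.nan, jnp.nan]])

row_sum = jnp.nansum(x, axis=1)    # all-NaN row gives the additive identity
row_prod = jnp.nanprod(x, axis=1)  # all-NaN row gives the multiplicative identity

# Number of non-NaN entries per row; zero marks rows whose sum or
# product is only the identity value.
valid_count = jnp.sum(~jnp.isnan(x), axis=1)
print(row_sum, row_prod, valid_count)
```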
- scico.numpy.ndim(a)¶
Return the number of dimensions of an array.
JAX implementation of
numpy.ndim. Unlikenp.ndim, this function raises aTypeErrorif the input is a collection such as a list or tuple.- Parameters:
a (
Union[Array,ndarray,bool,number,bool,int,float,complex,SupportsNdim]) – array-like object, or any object with anndimattribute.- Return type:
- Returns:
An integer specifying the number of dimensions of
a.
Examples
Number of dimensions for arrays:
>>> x = jnp.arange(10) >>> jnp.ndim(x) 1 >>> y = jnp.ones((2, 3)) >>> jnp.ndim(y) 2
This also works for scalars:
>>> jnp.ndim(3.14) 0
For arrays, this can also be accessed via the
jax.Array.ndimproperty:>>> x.ndim 1
- scico.numpy.negative(*args: ArrayLike, out: None = None, where: None = None) Any¶
Return element-wise negative values of the input.
JAX implementation of
numpy.negative.- Parameters:
x – input array or scalar.
- Returns:
An array with same shape and dtype as
xcontaining-x.
See also
jax.numpy.positive: Returns element-wise positive values of the input.
jax.numpy.sign: Returns element-wise indication of sign of the input.
Note
jnp.negative, when applied to unsigned integers, produces their two’s complement negation, which typically yields unexpectedly large positive values because of integer wraparound.
Examples
For real-valued inputs:
>>> x = jnp.array([0., -3., 7]) >>> jnp.negative(x) Array([-0., 3., -7.], dtype=float32)
For complex inputs:
>>> x1 = jnp.array([1-2j, -3+4j, 5-6j]) >>> jnp.negative(x1) Array([-1.+2.j, 3.-4.j, -5.+6.j], dtype=complex64)
For uint32:
>>> x2 = jnp.array([5, 0, -7]).astype(jnp.uint32) >>> x2 Array([ 5, 0, 4294967289], dtype=uint32) >>> jnp.negative(x2) Array([4294967291, 0, 7], dtype=uint32)
- scico.numpy.nextafter(x, y, /)¶
Return element-wise next floating point value after
xtowardsy.JAX implementation of
numpy.nextafter.- Parameters:
x (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – scalar or array. Specifies the value after which the next number is found.y (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – scalar or array. Specifies the direction towards which the next number is found.xandyshould either have same shape or be broadcast compatible.
- Return type:
- Returns:
An array containing the next representable number of
xin the direction ofy.
Examples
>>> jnp.nextafter(2, 1) Array(1.9999999, dtype=float32, weak_type=True) >>> x = jnp.array([3, -2, 1]) >>> y = jnp.array([2, -1, 2]) >>> jnp.nextafter(x, y) Array([ 2.9999998, -1.9999999, 1.0000001], dtype=float32)
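One common use of `nextafter` is to measure the spacing between adjacent floating point values. The following sketch, assuming float32 inputs, checks that the gap just above 1.0 equals the machine epsilon reported by `jnp.finfo`.

```python
import jax.numpy as jnp

one = jnp.float32(1.0)

# Next representable float32 above 1.0, stepping in the direction of 2.0.
next_up = jnp.nextafter(one, jnp.float32(2.0))

# The difference is exactly representable, so it can be compared
# directly with the float32 machine epsilon.
gap = next_up - one
eps = jnp.finfo(jnp.float32).eps
print(gap, eps)
```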
- scico.numpy.nonzero(a, *, size=None, fill_value=None)¶
Return indices of nonzero elements of an array.
JAX implementation of
numpy.nonzero.Because the size of the output of
nonzerois data-dependent, the function is not compatible with JIT and other transformations. The JAX version adds the optionalsizeargument which must be specified statically forjnp.nonzeroto be used within JAX’s transformations.- Parameters:
a (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – N-dimensional array.size (
int|None) – optional static integer specifying the number of nonzero entries to return. If there are more nonzero elements than the specifiedsize, then indices will be truncated at the end. If there are fewer nonzero elements than the specified size, then indices will be padded withfill_value, which defaults to zero.fill_value (
Union[None,Array,ndarray,bool,number,bool,int,float,complex,tuple[Union[Array,ndarray,bool,number,bool,int,float,complex],...]]) – optional padding value whensizeis specified. Defaults to 0.
- Return type:
- Returns:
Tuple of JAX Arrays of length
a.ndim, containing the indices of each nonzero value.
See also
Examples
One-dimensional array returns a length-1 tuple of indices:
>>> x = jnp.array([0, 5, 0, 6, 0, 7]) >>> jnp.nonzero(x) (Array([1, 3, 5], dtype=int32),)
Two-dimensional array returns a length-2 tuple of indices:
>>> x = jnp.array([[0, 5, 0], ... [6, 0, 7]]) >>> jnp.nonzero(x) (Array([0, 1, 1], dtype=int32), Array([1, 0, 2], dtype=int32))
In either case, the resulting tuple of indices can be used directly to extract the nonzero values:
>>> indices = jnp.nonzero(x) >>> x[indices] Array([5, 6, 7], dtype=int32)
The output of
nonzerohas a dynamic shape, because the number of returned indices depends on the contents of the input array. As such, it is incompatible with JIT and other JAX transformations:>>> x = jnp.array([0, 5, 0, 6, 0, 7]) >>> jax.jit(jnp.nonzero)(x) Traceback (most recent call last): ... ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[]. The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations.
This can be addressed by passing a static
sizeparameter to specify the desired output shape:>>> nonzero_jit = jax.jit(jnp.nonzero, static_argnames='size') >>> nonzero_jit(x, size=3) (Array([1, 3, 5], dtype=int32),)
If
sizedoes not match the true size, the result will be either truncated or padded:>>> nonzero_jit(x, size=2) # size < 3: indices are truncated (Array([1, 3], dtype=int32),) >>> nonzero_jit(x, size=5) # size > 3: indices are padded with zeros. (Array([1, 3, 5, 0, 0], dtype=int32),)
You can specify a custom fill value for the padding using the
fill_valueargument:>>> nonzero_jit(x, size=5, fill_value=len(x)) (Array([1, 3, 5, 6, 6], dtype=int32),)
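When `size` is fixed inside a jitted function, padded entries need to be distinguishable from real indices. A minimal sketch, assuming a sentinel of -1 is acceptable for your use, is shown below; the function name `nonzero_indices` and the chosen size are illustrative.

```python
import jax
import jax.numpy as jnp

@jax.jit
def nonzero_indices(x):
    # size is a Python constant here, so the output shape is static.
    (idx,) = jnp.nonzero(x, size=4, fill_value=-1)
    return idx

x = jnp.array([0, 5, 0, 6, 0, 7])
idx = nonzero_indices(x)

# Padded slots carry the sentinel, so a mask recovers the real entries.
valid = idx >= 0
print(idx, valid)
```

Note that a negative sentinel should be masked out before it is used for indexing, since negative indices wrap around to the end of the array.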
- scico.numpy.not_equal(x, y, /)¶
Returns element-wise truth value of
x != y.JAX implementation of
numpy.not_equal. This function provides the implementation of the!=operator for JAX arrays.- Parameters:
- Return type:
- Returns:
A boolean array containing
Truewhere the elements ofx != yandFalseotherwise.
See also
jax.numpy.equal: Returns element-wise truth value of x == y.
jax.numpy.greater_equal: Returns element-wise truth value of x >= y.
jax.numpy.less_equal: Returns element-wise truth value of x <= y.
jax.numpy.greater: Returns element-wise truth value of x > y.
jax.numpy.less: Returns element-wise truth value of x < y.
Examples
>>> jnp.not_equal(0., -0.) Array(False, dtype=bool, weak_type=True) >>> jnp.not_equal(-2, 2) Array(True, dtype=bool, weak_type=True) >>> jnp.not_equal(1, 1.) Array(False, dtype=bool, weak_type=True) >>> jnp.not_equal(5, jnp.array(5)) Array(False, dtype=bool, weak_type=True) >>> x = jnp.array([[1, 2, 3], ... [4, 5, 6], ... [7, 8, 9]]) >>> y = jnp.array([1, 5, 9]) >>> jnp.not_equal(x, y) Array([[False, True, True], [ True, False, True], [ True, True, False]], dtype=bool) >>> x != y Array([[False, True, True], [ True, False, True], [ True, True, False]], dtype=bool)
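A case not covered by the examples above is NaN, which under IEEE 754 comparison rules is not equal to anything, including itself. The sketch below uses this to flag NaN entries; the data is illustrative.

```python
import jax.numpy as jnp

x = jnp.array([1.0, jnp.nan, 3.0])

# NaN compares unequal to itself, so x != x is True only at NaN entries.
is_nan = jnp.not_equal(x, x)
print(is_nan)
```

In practice `jnp.isnan` expresses the same check more directly.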
- scico.numpy.ones(shape, dtype=None, *, device=None, out_sharding=None)¶
Create an array full of ones.
JAX implementation of
numpy.ones.- Parameters:
shape (
Any) – int or sequence of ints specifying the shape of the created array.dtype (
Union[str,type[Any],dtype,SupportsDType,None]) – optional dtype for the created array; defaults to float32 or float64 depending on the X64 configuration (see Default dtypes and the X64 flag).device (
Device|Sharding|None) – (optional)DeviceorShardingto which the created array will be committed. This argument exists for compatibility with the Python Array API standard.out_sharding (
NamedSharding|P|None) – (optional)PartitionSpecorNamedShardingrepresenting the sharding of the created array (see explicit sharding for more details). This argument exists for consistency with other array creation routines across JAX. Specifying bothout_shardinganddevicewill result in an error.
- Return type:
- Returns:
Array of the specified shape and dtype, with the given device/sharding if specified.
Examples
>>> jnp.ones(4) Array([1., 1., 1., 1.], dtype=float32) >>> jnp.ones((2, 3), dtype=bool) Array([[ True, True, True], [ True, True, True]], dtype=bool)
- scico.numpy.ones_like(a, dtype=None, shape=None, *, device=None, out_sharding=None)¶
Create an array of ones with the same shape and dtype as an array.
JAX implementation of
numpy.ones_like.- Parameters:
a (
Union[Array,ndarray,bool,number,bool,int,float,complex,DuckTypedArray]) – Array-like object withshapeanddtypeattributes.shape (
Any) – optionally override the shape of the created array.dtype (
Union[str,type[Any],dtype,SupportsDType,None]) – optionally override the dtype of the created array.device (
Device|Sharding|None) – (optional)DeviceorShardingto which the created array will be committed.
- Return type:
- Returns:
Array of the specified shape and dtype, on the specified device if specified.
Examples
>>> x = jnp.arange(4) >>> jnp.ones_like(x) Array([1, 1, 1, 1], dtype=int32) >>> jnp.ones_like(x, dtype=bool) Array([ True, True, True, True], dtype=bool) >>> jnp.ones_like(x, shape=(2, 3)) Array([[1, 1, 1], [1, 1, 1]], dtype=int32)
- scico.numpy.outer(a, b, out=None)¶
Compute the outer product of two arrays.
JAX implementation of
numpy.outer.- Parameters:
- Return type:
- Returns:
The outer product of the inputs
aandb. Returned array will be of shape(a.size, b.size).
See also
jax.numpy.inner: compute the inner product of two arrays.
jax.numpy.einsum: Einstein summation.
Examples
>>> a = jnp.array([1, 2, 3]) >>> b = jnp.array([4, 5, 6]) >>> jnp.outer(a, b) Array([[ 4, 5, 6], [ 8, 10, 12], [12, 15, 18]], dtype=int32)
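The stated output shape `(a.size, b.size)` implies that multi-dimensional inputs are flattened first. The following sketch, with illustrative inputs, compares `outer` with the equivalent broadcasting expression and checks the shape for a 2D input.

```python
import jax.numpy as jnp

a = jnp.array([1, 2, 3])
b = jnp.array([4, 5])

via_outer = jnp.outer(a, b)
# Same result written as an explicit broadcasted product.
via_broadcast = a[:, None] * b[None, :]

# A 2D input contributes a.size rows to the result.
m = jnp.ones((2, 2))
flattened_shape = jnp.outer(m, b).shape
print(via_outer, flattened_shape)
```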
- scico.numpy.pad(array, pad_width, mode='constant', **kwargs)¶
Add padding to an array.
JAX implementation of
numpy.pad.- Parameters:
array (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – array to pad.pad_width (
Union[int,Array,ndarray,Sequence[int|Array|ndarray],Sequence[Sequence[int|Array|ndarray]]]) –specify the pad width for each dimension of an array. Padding widths may be separately specified for before and after the array. Options are:
int or (int,): pad each array dimension with the same number of values both before and after.
(before, after): pad each array with before elements before, and after elements after.
((before_1, after_1), (before_2, after_2), ... (before_N, after_N)): specify distinct before and after values for each array dimension.
mode (
str|Callable[...,Any]) –a string or callable. Supported pad modes are:
'constant' (default): pad with a constant value, which defaults to zero.
'empty': pad with empty values (i.e. zero).
'edge': pad with the edge values of the array.
'wrap': pad by wrapping the array.
'linear_ramp': pad with a linear ramp to specified end_values.
'maximum': pad with the maximum value.
'mean': pad with the mean value.
'median': pad with the median value.
'minimum': pad with the minimum value.
'reflect': pad by reflection.
'symmetric': pad by symmetric reflection.
<callable>: a callable function. See Notes below.
constant_values – referenced for
mode = 'constant'. Specify the constant value to pad with.stat_length – referenced for
mode in ['maximum', 'mean', 'median', 'minimum']. An integer or tuple specifying the number of edge values to use when calculating the statistic.end_values – referenced for
mode = 'linear_ramp'. Specify the end values to ramp the padding values to.reflect_type – referenced for
mode in ['reflect', 'symmetric']. Specify whether to use even or odd reflection.
- Return type:
- Returns:
A padded copy of
array.
Notes
When
modeis callable, it should have the following signature:def pad_func(row: Array, pad_width: tuple[int, int], iaxis: int, kwargs: dict) -> Array: ...
Here
rowis a 1D slice of the padded array along axisiaxis, with the pad values filled with zeros.pad_widthis a tuple specifying the(before, after)padding sizes, andkwargsare any additional keyword arguments passed to thejax.numpy.padfunction.Note that while in NumPy, the function should modify
rowin-place, in JAX the function should return the modifiedrow. In JAX, the custom padding function will be mapped across the padded axis using thejax.vmaptransformation.See also
jax.numpy.resize: resize an array.
jax.numpy.tile: create a larger array by tiling a smaller array.
jax.numpy.repeat: create a larger array by repeating values of a smaller array.
Examples
Pad a 1-dimensional array with zeros:
>>> x = jnp.array([10, 20, 30, 40]) >>> jnp.pad(x, 2) Array([ 0, 0, 10, 20, 30, 40, 0, 0], dtype=int32) >>> jnp.pad(x, (2, 4)) Array([ 0, 0, 10, 20, 30, 40, 0, 0, 0, 0], dtype=int32)
Pad a 1-dimensional array with specified values:
>>> jnp.pad(x, 2, constant_values=99) Array([99, 99, 10, 20, 30, 40, 99, 99], dtype=int32)
Pad a 1-dimensional array with the mean array value:
>>> jnp.pad(x, 2, mode='mean') Array([25, 25, 10, 20, 30, 40, 25, 25], dtype=int32)
Pad a 1-dimensional array with reflected values:
>>> jnp.pad(x, 2, mode='reflect') Array([30, 20, 10, 20, 30, 40, 30, 20], dtype=int32)
Pad a 2-dimensional array with different paddings in each dimension:
>>> x = jnp.array([[1, 2, 3], ... [4, 5, 6]]) >>> jnp.pad(x, ((1, 2), (3, 0))) Array([[0, 0, 0, 0, 0, 0], [0, 0, 0, 1, 2, 3], [0, 0, 0, 4, 5, 6], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0]], dtype=int32)
Pad a 1-dimensional array with a custom padding function:
>>> def custom_pad(row, pad_width, iaxis, kwargs): ... # row represents a 1D slice of the zero-padded array. ... before, after = pad_width ... before_value = kwargs.get('before_value', 0) ... after_value = kwargs.get('after_value', 0) ... row = row.at[:before].set(before_value) ... return row.at[len(row) - after:].set(after_value) >>> x = jnp.array([2, 3, 4]) >>> jnp.pad(x, 2, custom_pad, before_value=-10, after_value=10) Array([-10, -10, 2, 3, 4, 10, 10], dtype=int32)
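Several of the modes listed above differ only in how they treat the boundary. As a small side-by-side sketch on illustrative data, the following pads the same array with `'edge'`, `'wrap'`, and `'symmetric'`.

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3])

# 'edge' repeats the boundary value.
edge = jnp.pad(x, 2, mode='edge')
# 'wrap' continues the array periodically.
wrap = jnp.pad(x, 2, mode='wrap')
# 'symmetric' mirrors the array including the boundary value,
# whereas 'reflect' (shown above) mirrors without repeating it.
symmetric = jnp.pad(x, 2, mode='symmetric')
print(edge, wrap, symmetric)
```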
- scico.numpy.partition(a, kth, axis=-1)¶
Returns a partially-sorted copy of an array.
JAX implementation of
numpy.partition. The JAX version differs from NumPy in the treatment of NaN entries: NaNs which have the negative bit set are sorted to the beginning of the array.- Parameters:
- Return type:
- Returns:
A copy of
a partitioned at the kth value along axis. The entries before kth are values smaller than take(a, kth, axis), and entries after kth are values larger than take(a, kth, axis).
Note
The JAX version requires the
kthargument to be a static integer rather than a general array. This is implemented via two calls tojax.lax.top_k. If you’re only accessing the top or bottom k values of the output, it may be more efficient to calljax.lax.top_kdirectly.See also
jax.numpy.sort: full sort.
jax.numpy.argpartition: indirect partial sort.
jax.lax.top_k: directly find the top k entries.
jax.lax.approx_max_k: compute the approximate top k entries.
jax.lax.approx_min_k: compute the approximate bottom k entries.
Examples
>>> x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3]) >>> kth = 4 >>> x_partitioned = jnp.partition(x, kth) >>> x_partitioned Array([1, 2, 3, 3, 4, 9, 8, 7, 6, 5], dtype=int32)
The result is a partially-sorted copy of the input. All values before
kth are smaller than the pivot value, and all values after kth are larger than the pivot value:>>> smallest_values = x_partitioned[:kth] >>> pivot_value = x_partitioned[kth] >>> largest_values = x_partitioned[kth + 1:] >>> print(smallest_values, pivot_value, largest_values) [1 2 3 3] 4 [9 8 7 6 5]
Notice that among
smallest_valuesandlargest_values, the returned order is arbitrary and implementation-dependent.
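Because the order within each side is arbitrary, code should rely only on the pivot position and the ordering relative to it. The sketch below checks those two properties on the array from the example above; equal values are allowed on either side, so the comparisons are non-strict.

```python
import jax.numpy as jnp

x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3])
kth = 4

part = jnp.partition(x, kth)
pivot = part[kth]

# The pivot is the element a full sort would place at index kth.
same_as_sorted = pivot == jnp.sort(x)[kth]
# Everything before is no larger, everything after is no smaller.
left_ok = jnp.all(part[:kth] <= pivot)
right_ok = jnp.all(part[kth + 1:] >= pivot)
print(pivot, same_as_sorted, left_ok, right_ok)
```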
- scico.numpy.permute_dims(a, /, axes)¶
Permute the axes/dimensions of an array.
JAX implementation of
array_api.permute_dims.- Parameters:
- Return type:
- Returns:
a copy of
awith axes permuted.
Examples
>>> a = jnp.array([[1, 2, 3], ... [4, 5, 6]]) >>> jnp.permute_dims(a, (1, 0)) Array([[1, 4], [2, 5], [3, 6]], dtype=int32)
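For arrays with more than two dimensions it helps to track the shape: output axis `i` takes its length from input axis `axes[i]`. A short sketch on an illustrative 3D array, also comparing against `jnp.transpose` with the same axes:

```python
import jax.numpy as jnp

a = jnp.arange(24).reshape(2, 3, 4)

# Output axis 0 comes from input axis 2, axis 1 from 0, axis 2 from 1.
permuted = jnp.permute_dims(a, (2, 0, 1))
transposed = jnp.transpose(a, (2, 0, 1))
print(permuted.shape)
```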
- scico.numpy.piecewise(x, condlist, funclist, *args, **kw)¶
Evaluate a function defined piecewise across the domain.
JAX implementation of
numpy.piecewise, in terms ofjax.lax.switch.Note
Unlike
numpy.piecewise,jax.numpy.piecewiserequires functions infunclistto be traceable by JAX, as it is implemented viajax.lax.switch.- Parameters:
x (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – array of input values.condlist (
Array|Sequence[Union[Array,ndarray,bool,number,bool,int,float,complex]]) – boolean array or sequence of boolean arrays corresponding to the functions infunclist. If a sequence of arrays, the length of each array must match the length ofxfunclist (
list[Union[Array,ndarray,bool,number,bool,int,float,complex,Callable[...,Array]]]) – list of arrays or functions; must either be the same length ascondlist, or have lengthlen(condlist) + 1, in which case the last entry is the default applied when none of the conditions are True. Alternatively, entries offunclistmay be numerical values, in which case they indicate a constant function.args – additional arguments are passed to each function in
funclist.kwargs – additional arguments are passed to each function in
funclist.
- Return type:
- Returns:
An array which is the result of evaluating the functions on
xat the specified conditions.
See also
jax.lax.switch: choose between N functions based on an index.
jax.lax.cond: choose between two functions based on a boolean condition.
jax.numpy.where: choose between two results based on a boolean mask.
jax.lax.select: choose between two results based on a boolean mask.
jax.lax.select_n: choose between N results based on a boolean mask.
Examples
Here’s an example of a function which is zero for negative values, and linear for positive values:
>>> x = jnp.array([-4, -3, -2, -1, 0, 1, 2, 3, 4])
>>> condlist = [x < 0, x >= 0] >>> funclist = [lambda x: 0 * x, lambda x: x] >>> jnp.piecewise(x, condlist, funclist) Array([0, 0, 0, 0, 0, 1, 2, 3, 4], dtype=int32)
funclistcan also contain a simple scalar value for constant functions:>>> condlist = [x < 0, x >= 0] >>> funclist = [0, lambda x: x] >>> jnp.piecewise(x, condlist, funclist) Array([0, 0, 0, 0, 0, 1, 2, 3, 4], dtype=int32)
You can specify a default value by appending an extra entry to
funclist:>>> condlist = [x < -1, x > 1] >>> funclist = [lambda x: 1 + x, lambda x: x - 1, 0] >>> jnp.piecewise(x, condlist, funclist) Array([-3, -2, -1, 0, 0, 0, 1, 2, 3], dtype=int32)
condlistmay also be a simple array of scalar conditions, in which case the associated function applies to the whole range>>> condlist = jnp.array([False, True, False]) >>> funclist = [lambda x: x * 0, lambda x: x * 10, lambda x: x * 100] >>> jnp.piecewise(x, condlist, funclist) Array([-40, -30, -20, -10, 0, 10, 20, 30, 40], dtype=int32)
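Constant entries and a default function can be combined. As a minimal sketch on illustrative floating point data, the following clips values to the range [-1, 1] by using constants for the two outer pieces and the identity function as the default.

```python
import jax.numpy as jnp

x = jnp.array([-2.0, -0.5, 0.5, 2.0])

# Two conditions plus one extra entry in funclist: the last entry is
# the default applied where neither condition holds.
condlist = [x < -1, x > 1]
funclist = [-1.0, 1.0, lambda v: v]

clipped = jnp.piecewise(x, condlist, funclist)
print(clipped)
```

For this particular case `jnp.clip` is the more direct tool; the point here is only to show the constant-plus-default form of `funclist`.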
- scico.numpy.place(arr, mask, vals, *, inplace=True)¶
Update array elements based on a mask.
JAX implementation of
numpy.place.The semantics of
numpy.place are to modify arrays in-place, which is not possible for JAX’s immutable arrays. The JAX version returns a modified copy of the input, and adds the inplace parameter, which must be set to False by the user as a reminder of this API difference.- Parameters:
arr (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – array into which values will be placed.mask (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – boolean mask with the same size asarr.vals (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – values to be inserted intoarrat the locations indicated by mask. If too many values are supplied, they will be truncated. If not enough values are supplied, they will be repeated.inplace (
bool) – must be set to False to indicate that the input is not modified in-place, but rather a modified copy is returned.
- Return type:
- Returns:
A copy of
arrwith masked values set to entries from vals.
See also
jax.numpy.put: put elements into an array at numerical indices.
jax.numpy.ndarray.at: array updates using NumPy-style indexing.
Examples
>>> x = jnp.zeros((3, 5), dtype=int) >>> mask = (jnp.arange(x.size) % 3 == 0).reshape(x.shape) >>> mask Array([[ True, False, False, True, False], [False, True, False, False, True], [False, False, True, False, False]], dtype=bool)
Placing a scalar value:
>>> jnp.place(x, mask, 1, inplace=False) Array([[1, 0, 0, 1, 0], [0, 1, 0, 0, 1], [0, 0, 1, 0, 0]], dtype=int32)
In this case,
jnp.placeis similar to the masked array update syntax:>>> x.at[mask].set(1) Array([[1, 0, 0, 1, 0], [0, 1, 0, 0, 1], [0, 0, 1, 0, 0]], dtype=int32)
placediffers when placing values from an array. The array is repeated to fill the masked entries:>>> vals = jnp.array([1, 3, 5]) >>> jnp.place(x, mask, vals, inplace=False) Array([[1, 0, 0, 3, 0], [0, 5, 0, 0, 1], [0, 0, 3, 0, 0]], dtype=int32)
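Since the result is a copy, the return value has to be kept and the input stays as it was. The sketch below, on illustrative data, replaces negative entries with zero and leaves `x` untouched.

```python
import jax.numpy as jnp

x = jnp.array([3, -1, 4, -1, 5])

# inplace=False is required; the updated array is the return value.
cleaned = jnp.place(x, x < 0, 0, inplace=False)
print(cleaned)
print(x)  # the original array is unchanged
```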
- scico.numpy.polydiv(u, v, *, trim_leading_zeros=False)¶
Returns the quotient and remainder of polynomial division.
JAX implementation of
numpy.polydiv.- Parameters:
u (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – Array of dividend polynomial coefficients.v (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – Array of divisor polynomial coefficients.trim_leading_zeros (
bool) – Default is False. If True, removes the leading zeros in the return value to match the result of NumPy, but prevents the function from being used in compiled code. Due to differences in accumulation of floating point arithmetic errors, the cutoff for values to be considered zero may lead to inconsistent results between NumPy and JAX, and even between different JAX backends. The result may lead to inconsistent output shapes when trim_leading_zeros=True.
- Return type:
- Returns:
A tuple of quotient and remainder arrays. The dtype of the output is always promoted to inexact.
Note
jax.numpy.polydivonly accepts arrays as input unlikenumpy.polydivwhich accepts scalar inputs as well.See also
jax.numpy.polyadd: Computes the sum of two polynomials.
jax.numpy.polysub: Computes the difference of two polynomials.
jax.numpy.polymul: Computes the product of two polynomials.
Examples
>>> x1 = jnp.array([5, 7, 9]) >>> x2 = jnp.array([4, 1]) >>> np.polydiv(x1, x2) (array([1.25 , 1.4375]), array([7.5625])) >>> jnp.polydiv(x1, x2) (Array([1.25 , 1.4375], dtype=float32), Array([0. , 0. , 7.5625], dtype=float32))
If
trim_leading_zeros=True, the result matches np.polydiv’s.>>> jnp.polydiv(x1, x2, trim_leading_zeros=True) (Array([1.25 , 1.4375], dtype=float32), Array([7.5625], dtype=float32))
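The quotient and remainder satisfy the usual division identity, dividend = quotient × divisor + remainder. A hedged sketch of that check, using the coefficients from the example above as floating point arrays and the related `polymul` and `polyadd` functions:

```python
import jax.numpy as jnp

u = jnp.array([5.0, 7.0, 9.0])  # dividend coefficients
v = jnp.array([4.0, 1.0])       # divisor coefficients

q, r = jnp.polydiv(u, v)

# Rebuild the dividend from quotient, divisor, and remainder.
reconstructed = jnp.polyadd(jnp.polymul(q, v), r)
print(reconstructed)
```

The check works with or without `trim_leading_zeros`, because `polyadd` aligns coefficient arrays of different lengths.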
- scico.numpy.polymul(a1, a2, *, trim_leading_zeros=False)¶
Returns the product of two polynomials.
JAX implementation of
numpy.polymul.- Parameters:
a1 (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – 1D array of polynomial coefficients.a2 (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – 1D array of polynomial coefficients.trim_leading_zeros (
bool) – Default is False. If True, removes the leading zeros in the return value to match the result of NumPy, but prevents the function from being used in compiled code. Due to differences in accumulation of floating point arithmetic errors, the cutoff for values to be considered zero may lead to inconsistent results between NumPy and JAX, and even between different JAX backends. The result may lead to inconsistent output shapes when trim_leading_zeros=True.
- Return type:
- Returns:
An array of the coefficients of the product of the two polynomials. The dtype of the output is always promoted to inexact.
Note
jax.numpy.polymulonly accepts arrays as input unlikenumpy.polymulwhich accepts scalar inputs as well.See also
jax.numpy.polyadd: Computes the sum of two polynomials.
jax.numpy.polysub: Computes the difference of two polynomials.
jax.numpy.polydiv: Computes the quotient and remainder of polynomial division.
Examples
>>> x1 = np.array([2, 1, 0]) >>> x2 = np.array([0, 5, 0, 3]) >>> np.polymul(x1, x2) array([10, 5, 6, 3, 0]) >>> jnp.polymul(x1, x2) Array([ 0., 10., 5., 6., 3., 0.], dtype=float32)
If
trim_leading_zeros=True, the result matches np.polymul’s.>>> jnp.polymul(x1, x2, trim_leading_zeros=True) Array([10., 5., 6., 3., 0.], dtype=float32)
For input arrays of dtype
complex:>>> x3 = np.array([2., 1+2j, 1-2j]) >>> x4 = np.array([0, 5, 0, 3]) >>> np.polymul(x3, x4) array([10. +0.j, 5.+10.j, 11.-10.j, 3. +6.j, 3. -6.j]) >>> jnp.polymul(x3, x4) Array([ 0. +0.j, 10. +0.j, 5.+10.j, 11.-10.j, 3. +6.j, 3. -6.j], dtype=complex64) >>> jnp.polymul(x3, x4, trim_leading_zeros=True) Array([10. +0.j, 5.+10.j, 11.-10.j, 3. +6.j, 3. -6.j], dtype=complex64)
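A quick way to sanity-check a product of polynomials is to evaluate both sides at a point. The sketch below multiplies (x + 2) by (x + 3) and compares evaluations using `jnp.polyval`; the coefficients and the evaluation point `t` are illustrative.

```python
import jax.numpy as jnp

p = jnp.array([1, 2])  # x + 2
q = jnp.array([1, 3])  # x + 3

# Coefficients of (x + 2)(x + 3) = x^2 + 5x + 6, promoted to float.
prod = jnp.polymul(p, q)

t = 2.0
lhs = jnp.polyval(prod, t)
rhs = jnp.polyval(p, t) * jnp.polyval(q, t)
print(prod, lhs, rhs)
```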
- scico.numpy.positive(x, /)¶
Return element-wise positive values of the input.
JAX implementation of
numpy.positive.- Parameters:
x (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – input array or scalar- Return type:
- Returns:
An array of the same shape and dtype as x, containing +x.
Note
jnp.positive is equivalent to x.copy() and is defined only for types that support arithmetic operations.
See also
jax.numpy.negative: Returns element-wise negative values of the input.
jax.numpy.sign: Returns element-wise indication of sign of the input.
Examples
For real-valued inputs:
>>> x = jnp.array([-5, 4, 7., -9.5]) >>> jnp.positive(x) Array([-5. , 4. , 7. , -9.5], dtype=float32) >>> x.copy() Array([-5. , 4. , 7. , -9.5], dtype=float32)
For complex inputs:
>>> x1 = jnp.array([1-2j, -3+4j, 5-6j]) >>> jnp.positive(x1) Array([ 1.-2.j, -3.+4.j, 5.-6.j], dtype=complex64) >>> x1.copy() Array([ 1.-2.j, -3.+4.j, 5.-6.j], dtype=complex64)
For uint32:
>>> x2 = jnp.array([6, 0, -4]).astype(jnp.uint32) >>> x2 Array([ 6, 0, 4294967292], dtype=uint32) >>> jnp.positive(x2) Array([ 6, 0, 4294967292], dtype=uint32)
- scico.numpy.pow(x1, x2, /)¶
Alias of
jax.numpy.power- Return type:
- scico.numpy.power(x1, x2, /)¶
Calculate element-wise base x1 exponential of x2.
JAX implementation of numpy.power.
- Parameters:
- Return type:
- Returns:
An array containing the base x1 exponentials of x2, with the same dtype as the input.
Note
When x2 is a concrete integer scalar, jnp.power lowers to jax.lax.integer_pow.
When x2 is a traced scalar or an array, jnp.power lowers to jax.lax.pow.
jnp.power raises a TypeError for an integer type raised to a concrete negative integer power. For a non-concrete power, the operation is invalid and the returned value is implementation-defined.
jnp.power returns nan for a negative value raised to a non-integer power.
See also
jax.lax.pow: Computes element-wise power, \(x^y\).
jax.lax.integer_pow: Computes element-wise power \(x^y\), where \(y\) is a fixed integer.
jax.numpy.float_power: Computes the first array raised to the power of second array, element-wise, by promoting to the inexact dtype.
jax.numpy.pow: Computes the first array raised to the power of second array, element-wise.
Examples
Inputs with scalar integers:
>>> jnp.power(4, 3) Array(64, dtype=int32, weak_type=True)
Inputs with same shape:
>>> x1 = jnp.array([2, 4, 5]) >>> x2 = jnp.array([3, 0.5, 2]) >>> jnp.power(x1, x2) Array([ 8., 2., 25.], dtype=float32)
Inputs with broadcast compatibility:
>>> x3 = jnp.array([-2, 3, 1]) >>> x4 = jnp.array([[4, 1, 6], ... [1.3, 3, 5]]) >>> jnp.power(x3, x4) Array([[16., 3., 1.], [nan, 27., 1.]], dtype=float32)
- scico.numpy.printoptions(*args, **kwargs)¶
Alias of numpy.printoptions.
JAX arrays are printed via NumPy, so NumPy’s printoptions configurations will apply to printed JAX arrays.
See the numpy.set_printoptions documentation for details on the available options and their meanings.
- scico.numpy.prod(a, axis=None, dtype=None, out=None, keepdims=False, initial=None, where=None, promote_integers=True)¶
Return product of the array elements over a given axis.
JAX implementation of
numpy.prod.- Parameters:
a (Union[Array,ndarray,bool,number,bool,int,float,complex]) – Input array.
axis (Union[int,Sequence[int],None]) – int or array, default=None. Axis along which the product is computed. If None, the product is computed along all the axes.
dtype (Union[str,type[Any],dtype,SupportsDType,None]) – The type of the output array. Default=None.
keepdims (bool) – bool, default=False. If true, reduced axes are left in the result with size 1.
initial (Union[Array,ndarray,bool,number,bool,int,float,complex,None]) – int or array, default=None. Initial value for the product.
where (Union[Array,ndarray,bool,number,bool,int,float,complex,None]) – int or array, default=None. The elements to be used in the product. Array should be broadcast compatible to the input.
promote_integers (bool) – bool, default=True. If True, then integer inputs will be promoted to the widest available integer dtype, following numpy’s behavior. If False, the result will have the same dtype as the input. promote_integers is ignored if dtype is specified.
out (None) – Unused by JAX.
- Return type:
- Returns:
An array of the product along the given axis.
See also
jax.numpy.sum: Compute the sum of array elements over a given axis.jax.numpy.max: Compute the maximum of array elements over given axis.jax.numpy.min: Compute the minimum of array elements over given axis.
Examples
By default, jnp.prod computes along all the axes.
>>> x = jnp.array([[1, 3, 4, 2],
...                [5, 2, 1, 3],
...                [2, 1, 3, 1]])
>>> jnp.prod(x)
Array(4320, dtype=int32)
If axis=1, the product is computed along axis 1.
>>> jnp.prod(x, axis=1)
Array([24, 30, 6], dtype=int32)
If keepdims=True, ndim of the output is equal to that of the input.
>>> jnp.prod(x, axis=1, keepdims=True)
Array([[24],
       [30],
       [ 6]], dtype=int32)
To include only specific elements in the product, you can use where:
>>> where=jnp.array([[1, 0, 1, 0], ... [0, 0, 1, 1], ... [1, 1, 1, 0]], dtype=bool) >>> jnp.prod(x, axis=1, keepdims=True, where=where) Array([[4], [3], [6]], dtype=int32) >>> where = jnp.array([[False], ... [False], ... [False]]) >>> jnp.prod(x, axis=1, keepdims=True, where=where) Array([[1], [1], [1]], dtype=int32)
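The all-False case above returns 1 for every row because excluded elements contribute the multiplicative identity. A minimal pure-Python sketch of that behavior (the helper masked_prod is hypothetical and only models the semantics; it is not how JAX computes the reduction):

```python
from math import prod


def masked_prod(row, mask, initial=1):
    """Product of the entries of `row` whose mask entry is true."""
    # Entries with a false mask are skipped, which is the same as multiplying by 1.
    return initial * prod(value for value, keep in zip(row, mask) if keep)


x = [[1, 3, 4, 2], [5, 2, 1, 3], [2, 1, 3, 1]]
where = [[1, 0, 1, 0], [0, 0, 1, 1], [1, 1, 1, 0]]

row_products = [masked_prod(r, m) for r, m in zip(x, where)]
empty_products = [masked_prod(r, [False] * len(r)) for r in x]
```

With these inputs, row_products is [4, 3, 6] and empty_products is [1, 1, 1], matching the two where examples.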
- scico.numpy.promote_types(a, b)¶
Returns the type to which a binary operation should cast its arguments.
JAX implementation of
numpy.promote_types. For details of JAX’s type promotion semantics, see Type promotion semantics.- Parameters:
a (Union[str,type[Any],dtype,SupportsDType]) – a numpy.dtype or a dtype specifier.
b (Union[str,type[Any],dtype,SupportsDType]) – a numpy.dtype or a dtype specifier.
- Return type:
- Returns:
A numpy.dtype object.
Examples
Type specifiers may be strings, dtypes, or scalar types, and the return value is always a dtype:
>>> jnp.promote_types('int32', 'float32') # strings
dtype('float32')
>>> jnp.promote_types(jnp.dtype('int32'), jnp.dtype('float32')) # dtypes
dtype('float32')
>>> jnp.promote_types(jnp.int32, jnp.float32) # scalar types
dtype('float32')
Built-in scalar types (int, float, or complex) are treated as weakly-typed and will not change the bit width of a strongly-typed counterpart (see discussion in Type promotion semantics):
>>> jnp.promote_types('uint8', int)
dtype('uint8')
>>> jnp.promote_types('float16', float)
dtype('float16')
This differs from the NumPy version of this function, which treats built-in scalar types as equivalent to 64-bit types:
>>> import numpy
>>> numpy.promote_types('uint8', int)
dtype('int64')
>>> numpy.promote_types('float16', float)
dtype('float64')
- scico.numpy.ptp(a, axis=None, out=None, keepdims=False)¶
Return the peak-to-peak range along a given axis.
JAX implementation of
numpy.ptp.- Parameters:
a (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – input array.axis (
Union[int,Sequence[int],None]) – optional, int or sequence of ints, default=None. Axis along which the range is computed. If None, the range is computed on the flattened array.keepdims (
bool) – bool, default=False. If true, reduced axes are left in the result with size 1.out (
None) – Unused by JAX.
- Return type:
- Returns:
An array with the range of elements along specified axis of input.
Examples
By default,
jnp.ptpcomputes the range along all axes.>>> x = jnp.array([[1, 3, 5, 2], ... [4, 6, 8, 1], ... [7, 9, 3, 4]]) >>> jnp.ptp(x) Array(8, dtype=int32)
If
axis=1, computes the range along axis 1.>>> jnp.ptp(x, axis=1) Array([4, 7, 6], dtype=int32)
To preserve the dimensions of input, you can set
keepdims=True.>>> jnp.ptp(x, axis=1, keepdims=True) Array([[4], [7], [6]], dtype=int32)
- scico.numpy.put(a, ind, v, mode=None, *, inplace=True)¶
Put elements into an array at given indices.
JAX implementation of numpy.put.
The semantics of numpy.put are to modify arrays in-place, which is not possible for JAX’s immutable arrays. The JAX version returns a modified copy of the input, and adds the inplace parameter, which must be set to False by the user as a reminder of this API difference.
- Parameters:
a (Union[Array,ndarray,bool,number,bool,int,float,complex]) – array into which values will be placed.
ind (Union[Array,ndarray,bool,number,bool,int,float,complex]) – array of indices over the flattened array at which to put values.
v (Union[Array,ndarray,bool,number,bool,int,float,complex]) – array of values to put into the array.
mode – string specifying how to handle out-of-bound indices. Supported values:
"clip" (default): clip out-of-bound indices to the final index.
"wrap": wrap out-of-bound indices to the beginning of the array.
inplace (bool) – must be set to False to indicate that the input is not modified in-place, but rather a modified copy is returned.
- Return type:
- Returns:
A copy of
awith specified entries updated.
See also
jax.numpy.place: place elements into an array via boolean mask.jax.numpy.ndarray.at: array updates using NumPy-style indexing.jax.numpy.take: extract values from an array at given indices.
Examples
>>> x = jnp.zeros(5, dtype=int) >>> indices = jnp.array([0, 2, 4]) >>> values = jnp.array([10, 20, 30]) >>> jnp.put(x, indices, values, inplace=False) Array([10, 0, 20, 0, 30], dtype=int32)
This is equivalent to the following
jax.numpy.ndarray.atindexing syntax:>>> x.at[indices].set(values) Array([10, 0, 20, 0, 30], dtype=int32)
There are two modes for handling out-of-bound indices. By default they are clipped:
>>> indices = jnp.array([0, 2, 6]) >>> jnp.put(x, indices, values, inplace=False, mode='clip') Array([10, 0, 20, 0, 30], dtype=int32)
Alternatively, they can be wrapped to the beginning of the array:
>>> jnp.put(x, indices, values, inplace=False, mode='wrap') Array([10, 30, 20, 0, 0], dtype=int32)
For N-dimensional inputs, the indices refer to the flattened array:
>>> x = jnp.zeros((3, 5), dtype=int) >>> indices = jnp.array([0, 7, 14]) >>> jnp.put(x, indices, values, inplace=False) Array([[10, 0, 0, 0, 0], [ 0, 0, 20, 0, 0], [ 0, 0, 0, 0, 30]], dtype=int32)
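The two out-of-bound modes can be modelled on a plain Python list. The sketch below is an illustration under stated assumptions, not the JAX implementation: put_copy is a hypothetical helper, it works on an already-flattened list, and it only covers non-negative indices.

```python
def put_copy(flat_values, indices, new_values, mode="clip"):
    """Return a modified copy of `flat_values`; the input list is left untouched."""
    n = len(flat_values)
    out = list(flat_values)  # copy, mirroring the inplace=False contract
    for index, value in zip(indices, new_values):
        if mode == "clip":
            index = min(index, n - 1)  # too-large indices go to the final index
        elif mode == "wrap":
            index = index % n  # too-large indices wrap to the beginning
        else:
            raise ValueError(f"unsupported mode: {mode!r}")
        out[index] = value
    return out


original = [0] * 5
clipped = put_copy(original, [0, 2, 6], [10, 20, 30], mode="clip")
wrapped = put_copy(original, [0, 2, 6], [10, 20, 30], mode="wrap")
```

The out-of-bound index 6 lands on position 4 under "clip" and on position 1 under "wrap", which reproduces the two results shown in the examples.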
- scico.numpy.rad2deg(x, /)¶
Convert angles from radians to degrees.
JAX implementation of numpy.rad2deg.
The angle in radians is converted to degrees by:
\[\mathrm{rad2deg}(x) = x \cdot \frac{180}{\pi}\]
- Parameters:
x (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – scalar or array. Specifies the angle in radians.- Return type:
- Returns:
An array containing the angles in degrees.
See also
jax.numpy.deg2rad and jax.numpy.radians: Converts the angles from degrees to radians.
jax.numpy.degrees: Alias of rad2deg.
Examples
>>> pi = jnp.pi >>> x = jnp.array([pi/4, pi/2, 2*pi/3]) >>> jnp.rad2deg(x) Array([ 45. , 90. , 120.00001], dtype=float32) >>> x * 180 / pi Array([ 45., 90., 120.], dtype=float32)
- scico.numpy.radians(x, /)¶
Alias of
jax.numpy.deg2rad- Return type:
- scico.numpy.ravel_multi_index(multi_index, dims, mode='raise', order='C', *, dtype=None)¶
Convert multi-dimensional indices into flat indices.
JAX implementation of
numpy.ravel_multi_index- Parameters:
multi_index (Sequence[Union[Array,ndarray,bool,number,bool,int,float,complex]]) – sequence of integer arrays containing indices in each dimension.
dims (Sequence[int]) – sequence of integer sizes; must have len(dims) == len(multi_index).
mode (str) – how to handle out-of-bound indices. Options are:
"raise" (default): raise a ValueError. This mode is incompatible with jit or other JAX transformations.
"clip": clip out-of-bound indices to valid range.
"wrap": wrap out-of-bound indices to valid range.
"ignore": do not coerce or check input indices. Behavior is undefined if indices are out of bounds.
order (str) – "C" (default) or "F", specifies whether to assume C-style row-major order or Fortran-style column-major order.
dtype (Union[str,type[Any],dtype,SupportsDType,None]) – the desired output dtype. If not specified, the dtype is determined by standard type promotion rules of the input multi_index.
- Return type:
- Returns:
array of flattened indices
See also
jax.numpy.unravel_index: inverse of this function.Examples
Define a 2-dimensional array and a sequence of indices of even values:
>>> x = jnp.array([[2., 3., 4.], ... [5., 6., 7.]]) >>> indices = jnp.where(x % 2 == 0) >>> indices (Array([0, 0, 1], dtype=int32), Array([0, 2, 1], dtype=int32)) >>> x[indices] Array([2., 4., 6.], dtype=float32)
Compute the flattened indices:
>>> indices_flat = jnp.ravel_multi_index(indices, x.shape) >>> indices_flat Array([0, 2, 4], dtype=int32)
These flattened indices can be used to extract the same values from the flattened
xarray:>>> x_flat = x.ravel() >>> x_flat Array([2., 3., 4., 5., 6., 7.], dtype=float32) >>> x_flat[indices_flat] Array([2., 4., 6.], dtype=float32)
The original indices can be recovered with
unravel_index:>>> jnp.unravel_index(indices_flat, x.shape) (Array([0, 0, 1], dtype=int32), Array([0, 2, 1], dtype=int32))
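The flat index is a weighted sum of the per-axis indices, with weights (strides) determined by the order argument. A small pure-Python sketch for a single multi-index, assuming in-bounds indices (ravel_index is a hypothetical helper, not part of JAX or SCICO):

```python
def ravel_index(multi_index, dims, order="C"):
    """Flatten one multi-dimensional index for an array of shape `dims`."""
    pairs = list(zip(multi_index, dims))
    if order == "C":
        pairs = pairs[::-1]  # row-major: the last axis varies fastest
    flat, stride = 0, 1
    for index, size in pairs:
        flat += index * stride
        stride *= size  # the next axis advances by the product of sizes so far
    return flat


dims = (2, 3)
c_flat = [ravel_index(idx, dims) for idx in [(0, 0), (0, 2), (1, 1)]]
f_flat = [ravel_index(idx, dims, order="F") for idx in [(0, 0), (0, 2), (1, 1)]]
```

In C order the three indices map to 0, 2 and 4, as in the example above; in Fortran order the same indices map to 0, 4 and 3.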
- scico.numpy.real(val, /)¶
Return element-wise real part of the complex argument.
JAX implementation of
numpy.real.- Parameters:
val (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – input array or scalar.- Return type:
- Returns:
An array containing the real part of the elements of
val.
See also
jax.numpy.conjugateandjax.numpy.conj: Returns the element-wise complex-conjugate of the input.jax.numpy.imag: Returns the element-wise imaginary part of the complex argument.
Examples
>>> jnp.real(5) Array(5, dtype=int32, weak_type=True) >>> jnp.real(2j) Array(0., dtype=float32, weak_type=True) >>> x = jnp.array([3-2j, 4+7j, -2j]) >>> jnp.real(x) Array([ 3., 4., -0.], dtype=float32)
- scico.numpy.reciprocal(x, /)¶
Calculate element-wise reciprocal of the input.
JAX implementation of
numpy.reciprocal.The reciprocal is calculated by
1/x.- Parameters:
x (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – input array or scalar.- Return type:
- Returns:
An array of the same shape as x containing the reciprocal of each element of x.
Note
For integer inputs, np.reciprocal returns rounded integer output, while jnp.reciprocal promotes integer inputs to floating point.
Examples
>>> jnp.reciprocal(2) Array(0.5, dtype=float32, weak_type=True) >>> jnp.reciprocal(0.) Array(inf, dtype=float32, weak_type=True) >>> x = jnp.array([1, 5., 4.]) >>> jnp.reciprocal(x) Array([1. , 0.2 , 0.25], dtype=float32)
- scico.numpy.remainder(x1, x2, /)¶
Returns element-wise remainder of the division.
JAX implementation of
numpy.remainder.- Parameters:
- Return type:
- Returns:
An array containing the remainder of element-wise division of x1 by x2, with the same sign as the elements of x2.
Note
The result of jnp.remainder is equivalent to x1 - x2 * jnp.floor(x1 / x2).
See also
jax.numpy.mod: Returns the element-wise remainder of the division.
jax.numpy.fmod: Calculates the element-wise floating-point modulo operation.
jax.numpy.divmod: Calculates the integer quotient and remainder of x1 by x2, element-wise.
Examples
>>> x1 = jnp.array([[3, -1, 4], ... [8, 5, -2]]) >>> x2 = jnp.array([2, 3, -5]) >>> jnp.remainder(x1, x2) Array([[ 1, 2, -1], [ 0, 2, -2]], dtype=int32) >>> x1 - x2 * jnp.floor(x1 / x2) Array([[ 1., 2., -1.], [ 0., 2., -2.]], dtype=float32)
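The sign convention is the main practical difference between a floor-based remainder and an fmod-style one. The following stdlib-only sketch illustrates the formula from the Note on scalars; it relies on Python's own % operator and math.fmod rather than on JAX, and floor_remainder is a hypothetical helper name:

```python
import math


def floor_remainder(x1, x2):
    """Remainder defined as x1 - x2 * floor(x1 / x2); it takes the sign of x2."""
    return x1 - x2 * math.floor(x1 / x2)


pairs = [(3, 2), (-1, 3), (4, -5), (-2, -5)]
floor_results = [floor_remainder(a, b) for a, b in pairs]
python_mod = [a % b for a, b in pairs]  # Python's % uses the same convention
fmod_results = [math.fmod(a, b) for a, b in pairs]  # sign follows the dividend
```

The floor-based results are [1, 2, -1, -2], agreeing with Python's %, while math.fmod gives [1.0, -1.0, 4.0, -2.0]: the two conventions differ exactly when the operands have opposite signs.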
- scico.numpy.repeat(a, repeats, axis=None, *, total_repeat_length=None, out_sharding=None)¶
Construct an array from repeated elements.
JAX implementation of
numpy.repeat.- Parameters:
a (Union[Array,ndarray,bool,number,bool,int,float,complex]) – N-dimensional array.
repeats (Union[Array,ndarray,bool,number,bool,int,float,complex]) – 1D integer array specifying the number of repeats. Must match the length of the repeated axis.
axis (int|None) – integer specifying the axis of a along which to construct the repeated array. If None (default) then a is first flattened.
total_repeat_length (int|None) – this must be specified statically for jnp.repeat to be compatible with jit and other JAX transformations. If sum(repeats) is larger than the specified total_repeat_length, the remaining values will be discarded. If sum(repeats) is smaller than total_repeat_length, the final value will be repeated.
- Return type:
- Returns:
an array constructed from repeated values of
a.
See also
jax.numpy.tile: repeat a full array rather than individual values.
Examples
Repeat each value twice along the last axis:
>>> a = jnp.array([[1, 2], ... [3, 4]]) >>> jnp.repeat(a, 2, axis=-1) Array([[1, 1, 2, 2], [3, 3, 4, 4]], dtype=int32)
If
axisis not specified, the input array will be flattened:>>> jnp.repeat(a, 2) Array([1, 1, 2, 2, 3, 3, 4, 4], dtype=int32)
Pass an array to
repeatsto repeat each value a different number of times:>>> repeats = jnp.array([2, 3]) >>> jnp.repeat(a, repeats, axis=1) Array([[1, 1, 2, 2, 2], [3, 3, 4, 4, 4]], dtype=int32)
In order to use
repeatwithinjitand other JAX transformations, the size of the output must be specified statically usingtotal_repeat_length:>>> jit_repeat = jax.jit(jnp.repeat, static_argnames=['axis', 'total_repeat_length']) >>> jit_repeat(a, repeats, axis=1, total_repeat_length=5) Array([[1, 1, 2, 2, 2], [3, 3, 4, 4, 4]], dtype=int32)
If total_repeat_length is smaller than
sum(repeats), the result will be truncated:>>> jit_repeat(a, repeats, axis=1, total_repeat_length=4) Array([[1, 1, 2, 2], [3, 3, 4, 4]], dtype=int32)
If it is larger, then the additional entries will be filled with the final value:
>>> jit_repeat(a, repeats, axis=1, total_repeat_length=7) Array([[1, 1, 2, 2, 2, 2, 2], [3, 3, 4, 4, 4, 4, 4]], dtype=int32)
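The truncation and padding rules for total_repeat_length can be modelled on a single row with plain Python lists. This is a sketch of the documented behavior only (repeat_fixed is a hypothetical helper, and padding with the last input value is an assumption drawn from the description above):

```python
def repeat_fixed(values, repeats, total_repeat_length):
    """Repeat each value, then force the result to a fixed, static length."""
    out = [v for v, r in zip(values, repeats) for _ in range(r)]
    out = out[:total_repeat_length]  # too long: discard the remaining values
    out += [values[-1]] * (total_repeat_length - len(out))  # too short: pad
    return out


exact = repeat_fixed([1, 2], [2, 3], 5)
truncated = repeat_fixed([1, 2], [2, 3], 4)
padded = repeat_fixed([1, 2], [2, 3], 7)
```

For the first row of a this gives [1, 1, 2, 2, 2], [1, 1, 2, 2] and [1, 1, 2, 2, 2, 2, 2], matching the three jit_repeat results. Fixing the output length up front is what makes the shape independent of the values in repeats.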
- scico.numpy.reshape(a, shape, order='C', *, copy=None, out_sharding=None)¶
Return a reshaped copy of an array.
JAX implementation of
numpy.reshape, implemented in terms ofjax.lax.reshape.- Parameters:
a (Union[Array,ndarray,bool,number,bool,int,float,complex]) – input array to reshape.
shape (Union[int,Any,Sequence[Union[int,Any]]]) – integer or sequence of integers giving the new shape, which must match the size of the input array. If any single dimension is given size -1, it will be replaced with a value such that the output has the correct size.
order (str) – 'F' or 'C', specifies whether the reshape should apply column-major (Fortran-style, "F") or row-major (C-style, "C") order; default is "C". JAX does not support order="A".
copy (bool|None) – unused by JAX; JAX always returns a copy, though under JIT the compiler may optimize such copies away.
- Return type:
- Returns:
reshaped copy of input array with the specified shape.
Notes
Unlike numpy.reshape, jax.numpy.reshape will return a copy rather than a view of the input array. However, under JIT, the compiler will optimize away such copies when possible, so this doesn’t have performance impacts in practice.
See also
jax.Array.reshape: equivalent functionality via an array method.
jax.numpy.ravel: flatten an array into a 1D shape.
jax.numpy.squeeze: remove one or more length-1 axes from an array’s shape.
Examples
>>> x = jnp.array([[1, 2, 3], ... [4, 5, 6]]) >>> jnp.reshape(x, 6) Array([1, 2, 3, 4, 5, 6], dtype=int32) >>> jnp.reshape(x, (3, 2)) Array([[1, 2], [3, 4], [5, 6]], dtype=int32)
You can use
-1to automatically compute a shape that is consistent with the input size:>>> jnp.reshape(x, -1) # -1 is inferred to be 6 Array([1, 2, 3, 4, 5, 6], dtype=int32) >>> jnp.reshape(x, (-1, 2)) # -1 is inferred to be 3 Array([[1, 2], [3, 4], [5, 6]], dtype=int32)
The default ordering of axes in the reshape is C-style row-major ordering. To use Fortran-style column-major ordering, specify
order='F':>>> jnp.reshape(x, 6, order='F') Array([1, 4, 2, 5, 3, 6], dtype=int32) >>> jnp.reshape(x, (3, 2), order='F') Array([[1, 5], [4, 3], [2, 6]], dtype=int32)
For convenience, this functionality is also available via the
jax.Array.reshapemethod:>>> x.reshape(3, 2) Array([[1, 2], [3, 4], [5, 6]], dtype=int32)
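The difference between order='C' and order='F' comes down to the order in which elements are read out and written back. A pure-Python sketch for the 2×3 example above, using nested lists in place of arrays (an illustration of the ordering only, not of how JAX performs the reshape):

```python
x = [[1, 2, 3],
     [4, 5, 6]]
rows, cols = 2, 3

# Read the elements out in row-major (C) and column-major (Fortran) order.
c_flat = [x[i][j] for i in range(rows) for j in range(cols)]
f_flat = [x[i][j] for j in range(cols) for i in range(rows)]

# Write them back into a (3, 2) result using the same ordering convention.
new_rows, new_cols = 3, 2
c_reshaped = [[c_flat[i * new_cols + j] for j in range(new_cols)]
              for i in range(new_rows)]
f_reshaped = [[f_flat[i + j * new_rows] for j in range(new_cols)]
              for i in range(new_rows)]
```

c_reshaped is [[1, 2], [3, 4], [5, 6]] and f_reshaped is [[1, 5], [4, 3], [2, 6]], the same values as the order='C' and order='F' results above.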
- scico.numpy.resize(a, new_shape)¶
Return a new array with specified shape.
JAX implementation of
numpy.resize.- Parameters:
- Return type:
- Returns:
A resized array with specified shape. The elements of
aare repeated in the resized array, if the resized array is larger than the original array.
See also
jax.numpy.reshape: Returns a reshaped copy of an array.jax.numpy.repeat: Constructs an array from repeated elements.
Examples
>>> x = jnp.array([1, 2, 3, 4, 5, 6, 7, 8, 9]) >>> jnp.resize(x, (3, 3)) Array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=int32) >>> jnp.resize(x, (3, 4)) Array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 1, 2, 3]], dtype=int32) >>> jnp.resize(4, (3, 2)) Array([[4, 4], [4, 4], [4, 4]], dtype=int32, weak_type=True)
- scico.numpy.result_type(*args)¶
Return the result of applying JAX promotion rules to the inputs.
JAX implementation of
numpy.result_type.JAX’s dtype promotion behavior is described in Type promotion semantics.
- Parameters:
args (
Any) – one or more arrays or dtype-like objects.- Return type:
- Returns:
A
numpy.dtypeinstance representing the result of type promotion for the inputs.
Examples
Inputs can be dtype specifiers:
>>> jnp.result_type('int32', 'float32') dtype('float32') >>> jnp.result_type(np.uint16, np.dtype('int32')) dtype('int32')
Inputs may also be scalars or arrays:
>>> jnp.result_type(1.0, jnp.bfloat16(2)) dtype(bfloat16) >>> jnp.result_type(jnp.arange(4), jnp.zeros(4)) dtype('float32')
Be aware that the result type will be canonicalized based on the state of the
jax_enable_x64configuration flag, meaning that 64-bit types may be downcast to 32-bit:>>> jnp.result_type('float64') dtype('float32')
For details on 64-bit values, refer to Sharp bits - double precision.
- scico.numpy.rint(x, /)¶
Rounds the elements of x to the nearest integer.
JAX implementation of
numpy.rint.- Parameters:
x (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – Input array- Return type:
- Returns:
An array-like object containing the rounded elements of x. Always promotes to inexact.
Note
If an element of x is exactly halfway between two integers, e.g. 0.5 or 1.5, rint will round to the nearest even integer.
Examples
>>> x1 = jnp.array([5, 4, 7]) >>> jnp.rint(x1) Array([5., 4., 7.], dtype=float32)
>>> x2 = jnp.array([-2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5, 4.5]) >>> jnp.rint(x2) Array([-2., -2., -0., 0., 2., 2., 4., 4.], dtype=float32)
>>> x3 = jnp.array([-2.5+3.5j, 4.5-0.5j]) >>> jnp.rint(x3) Array([-2.+4.j, 4.-0.j], dtype=complex64)
- scico.numpy.roll(a, shift, axis=None)¶
Roll the elements of an array along a specified axis.
JAX implementation of
numpy.roll.- Parameters:
a (Union[Array,ndarray,bool,number,bool,int,float,complex]) – input array.
shift (Union[Array,ndarray,bool,number,bool,int,float,complex,Sequence[int]]) – the number of positions to shift the specified axis. If an integer, all axes are shifted by the same amount. If a tuple, the shift for each axis is specified individually.
axis (int|Sequence[int]|None) – the axis or axes to roll. If None, the array is flattened, shifted, and then reshaped to its original shape.
- Return type:
- Returns:
A copy of
awith elements rolled along the specified axis or axes.
See also
jax.numpy.rollaxis: roll the specified axis to a given position.
Examples
>>> a = jnp.array([0, 1, 2, 3, 4, 5]) >>> jnp.roll(a, 2) Array([4, 5, 0, 1, 2, 3], dtype=int32)
Roll elements along a specific axis:
>>> a = jnp.array([[ 0, 1, 2, 3], ... [ 4, 5, 6, 7], ... [ 8, 9, 10, 11]]) >>> jnp.roll(a, 1, axis=0) Array([[ 8, 9, 10, 11], [ 0, 1, 2, 3], [ 4, 5, 6, 7]], dtype=int32) >>> jnp.roll(a, [2, 3], axis=[0, 1]) Array([[ 5, 6, 7, 4], [ 9, 10, 11, 8], [ 1, 2, 3, 0]], dtype=int32)
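Rolling is a cyclic shift: elements pushed past the end reappear at the start. A minimal list-based sketch of the one-dimensional case, applied per axis for the 2D example (roll_1d is a hypothetical helper and not how JAX implements the operation):

```python
def roll_1d(values, shift):
    """Cyclically shift a list to the right by `shift` positions."""
    n = len(values)
    if n == 0:
        return []
    k = shift % n  # negative shifts and shifts larger than n wrap around
    return values[n - k:] + values[:n - k]


flat = roll_1d([0, 1, 2, 3, 4, 5], 2)

a = [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
rows_rolled = roll_1d(a, 1)                     # like axis=0, shift=1
both_rolled = [roll_1d(row, 3) for row in roll_1d(a, 2)]  # shifts (2, 3)
```

flat is [4, 5, 0, 1, 2, 3], and both_rolled is [[5, 6, 7, 4], [9, 10, 11, 8], [1, 2, 3, 0]], the same as the jnp.roll(a, [2, 3], axis=[0, 1]) result.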
- scico.numpy.rollaxis(a, axis, start=0)¶
Roll the specified axis to a given position.
JAX implementation of numpy.rollaxis.
This function exists for compatibility with NumPy, but in most cases the newer jax.numpy.moveaxis should be used instead, because the meaning of its arguments is more intuitive.
- Parameters:
a (Union[Array,ndarray,bool,number,bool,int,float,complex]) – input array.
axis (int) – index of the axis to roll forward.
start (int) – index toward which the axis will be rolled (default = 0). After normalizing negative axes, if start <= axis, the axis is rolled to the start index; if start > axis, the axis is rolled until the position before start.
- Return type:
- Returns:
Copy of
awith rolled axis.
Notes
Unlike numpy.rollaxis, jax.numpy.rollaxis will return a copy rather than a view of the input array. However, under JIT, the compiler will optimize away such copies when possible, so this doesn’t have performance impacts in practice.
See also
jax.numpy.moveaxis: newer API with clearer semantics than rollaxis; this should be preferred to rollaxis in most cases.
jax.numpy.swapaxes: swap two axes.
jax.numpy.transpose: general permutation of axes.
Examples
>>> a = jnp.ones((2, 3, 4, 5))
Roll axis 2 to the start of the array:
>>> jnp.rollaxis(a, 2).shape (4, 2, 3, 5)
Roll axis 1 to the end of the array:
>>> jnp.rollaxis(a, 1, a.ndim).shape (2, 4, 5, 3)
Equivalent of these two with
moveaxis>>> jnp.moveaxis(a, 2, 0).shape (4, 2, 3, 5) >>> jnp.moveaxis(a, 1, -1).shape (2, 4, 5, 3)
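The start rule is easier to follow when applied to a shape tuple alone. The sketch below models the rule described for the start parameter (if start > axis, the axis ends up just before start); rollaxis_shape is a hypothetical helper that computes only the resulting shape:

```python
def rollaxis_shape(shape, axis, start=0):
    """Shape obtained by rolling `axis` of an array with `shape` toward `start`."""
    n = len(shape)
    if axis < 0:
        axis += n
    if start < 0:
        start += n
    if start > axis:
        start -= 1  # removing the axis first shifts later positions down by one
    axes = list(range(n))
    axes.insert(start, axes.pop(axis))
    return tuple(shape[i] for i in axes)


to_front = rollaxis_shape((2, 3, 4, 5), 2)
to_end = rollaxis_shape((2, 3, 4, 5), 1, 4)
```

to_front is (4, 2, 3, 5) and to_end is (2, 4, 5, 3), as in the examples. The off-by-one adjustment for start > axis is the part that moveaxis avoids, since its destination argument is simply the final position of the axis.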
- scico.numpy.roots(p, *, strip_zeros=True)¶
Returns the roots of a polynomial given the coefficients p.
JAX implementation of numpy.roots.
- Parameters:
p (Union[Array,ndarray,bool,number,bool,int,float,complex]) – Array of polynomial coefficients having rank-1.
strip_zeros (bool) – bool, default=True. If True, then leading zeros in the coefficients will be stripped, similar to numpy.roots. If set to False, leading zeros will not be stripped, and undefined roots will be represented by NaN values in the function output. strip_zeros must be set to False for the function to be compatible with jax.jit and other JAX transformations.
- Return type:
- Returns:
An array containing the roots of the polynomial.
Note
Unlike np.roots, jnp.roots returns the roots in a complex array regardless of the values of the roots.
See also
jax.numpy.poly: Finds the polynomial coefficients of the given sequence of roots.
jax.numpy.polyfit: Least squares polynomial fit to data.
jax.numpy.polyval: Evaluate a polynomial at specific values.
Examples
>>> coeffs = jnp.array([0, 1, 2])
The default behavior matches numpy and strips leading zeros:
>>> jnp.roots(coeffs) Array([-2.+0.j], dtype=complex64)
With
strip_zeros=False, extra roots are set to NaN:>>> jnp.roots(coeffs, strip_zeros=False) Array([-2. +0.j, nan+nanj], dtype=complex64)
- scico.numpy.rot90(m, k=1, axes=(0, 1))¶
Rotate an array by 90 degrees counterclockwise in the plane specified by axes.
JAX implementation of
numpy.rot90.- Parameters:
m (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – input array. Must havem.ndim >= 2.k (
int) – int, optional, default=1. Specifies the number of times the array is rotated. For negative values ofk, the array is rotated in clockwise direction.axes (
tuple[int,int]) – tuple of 2 integers, optional, default= (0, 1). The axes define the plane in which the array is rotated. Both the axes must be different.
- Return type:
- Returns:
An array containing the copy of the input,
mrotated by 90 degrees.
See also
jax.numpy.flip: reverse the order along the given axisjax.numpy.fliplr: reverse the order along axis 1 (left/right)jax.numpy.flipud: reverse the order along axis 0 (up/down)
Examples
>>> m = jnp.array([[1, 2, 3], ... [4, 5, 6]]) >>> jnp.rot90(m) Array([[3, 6], [2, 5], [1, 4]], dtype=int32) >>> jnp.rot90(m, k=2) Array([[6, 5, 4], [3, 2, 1]], dtype=int32)
jnp.rot90(m, k=1, axes=(1, 0)) is equivalent to jnp.rot90(m, k=-1, axes=(0, 1)).
>>> jnp.rot90(m, axes=(1, 0))
Array([[4, 1],
       [5, 2],
       [6, 3]], dtype=int32)
>>> jnp.rot90(m, k=-1, axes=(0, 1))
Array([[4, 1],
       [5, 2],
       [6, 3]], dtype=int32)
When the input array has ndim > 2:
>>> m1 = jnp.array([[[1, 2, 3],
...                  [4, 5, 6]],
...                 [[7, 8, 9],
...                  [10, 11, 12]]])
>>> jnp.rot90(m1, k=1, axes=(2, 1))
Array([[[ 4, 1],
        [ 5, 2],
        [ 6, 3]],
       [[10, 7],
        [11, 8],
        [12, 9]]], dtype=int32)
- scico.numpy.round(a, decimals=0, out=None)¶
Round input evenly to the given number of decimals.
JAX implementation of
numpy.round.- Parameters:
- Return type:
- Returns:
An array containing the values rounded to the specified decimals, with the same shape and dtype as a.
Note
jnp.round rounds to the nearest even integer for the values exactly halfway between rounded decimal values.
See also
jax.numpy.floor: Rounds the input to the nearest integer downwards.
jax.numpy.ceil: Rounds the input to the nearest integer upwards.
jax.numpy.fix and numpy.trunc: Rounds the input to the nearest integer towards zero.
Examples
>>> x = jnp.array([1.532, 3.267, 6.149]) >>> jnp.round(x) Array([2., 3., 6.], dtype=float32) >>> jnp.round(x, decimals=2) Array([1.53, 3.27, 6.15], dtype=float32)
For values exactly halfway between rounded values:
>>> x1 = jnp.array([10.5, 21.5, 12.5, 31.5]) >>> jnp.round(x1) Array([10., 22., 12., 32.], dtype=float32)
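Round-half-to-even differs from the round-half-up rule taught in school only at exact ties. Python's built-in round follows the same half-to-even convention, so the contrast can be sketched with the standard library alone (this illustrates the rounding rule, not the JAX implementation):

```python
import math

halfway = [10.5, 21.5, 12.5, 31.5]

# Ties go to the nearest even integer (the convention described in the Note).
half_to_even = [round(v) for v in halfway]

# For comparison: ties always go up.
half_up = [math.floor(v + 0.5) for v in halfway]
```

half_to_even is [10, 22, 12, 32], matching the jnp.round output above, whereas half_up is [11, 22, 13, 32]. Rounding ties to even avoids the systematic upward bias that round-half-up introduces when many tied values are accumulated.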
- scico.numpy.searchsorted(a, v, side='left', sorter=None, *, method='scan')¶
Perform a binary search within a sorted array.
JAX implementation of numpy.searchsorted.
This will return the indices within a sorted array a where values in v can be inserted to maintain its sort order.
- Parameters:
a (Union[Array,ndarray,bool,number,bool,int,float,complex]) – one-dimensional array, assumed to be in sorted order unless sorter is specified.
v (Union[Array,ndarray,bool,number,bool,int,float,complex]) – N-dimensional array of query values.
side (str) – 'left' (default) or 'right'; specifies whether insertion indices will be to the left or the right in case of ties.
sorter (Union[Array,ndarray,bool,number,bool,int,float,complex,None]) – optional array of indices specifying the sort order of a. If specified, then the algorithm assumes that a[sorter] is in sorted order.
method (str) – one of 'scan' (default), 'scan_unrolled', 'sort' or 'compare_all'. See Note below.
- Return type:
- Returns:
Array of insertion indices of shape
v.shape.
Note
The method argument controls the algorithm used to compute the insertion indices.
'scan' (the default) tends to be more performant on CPU, particularly when a is very large.
'scan_unrolled' is more performant on GPU at the expense of additional compile time.
'sort' is often more performant on accelerator backends like GPU and TPU, particularly when v is very large.
'compare_all' tends to be the most performant when a is very small.
Examples
Searching for a single value:
>>> a = jnp.array([1, 2, 2, 3, 4, 5, 5]) >>> jnp.searchsorted(a, 2) Array(1, dtype=int32) >>> jnp.searchsorted(a, 2, side='right') Array(3, dtype=int32)
Searching for a batch of values:
>>> vals = jnp.array([0, 3, 8, 1.5, 2]) >>> jnp.searchsorted(a, vals) Array([0, 3, 7, 1, 1], dtype=int32)
Optionally, the
sorterargument can be used to find insertion indices into an array sorted viajax.numpy.argsort:>>> a = jnp.array([4, 3, 5, 1, 2]) >>> sorter = jnp.argsort(a) >>> jnp.searchsorted(a, vals, sorter=sorter) Array([0, 2, 5, 1, 1], dtype=int32)
The result is equivalent to passing the sorted array:
>>> jnp.searchsorted(jnp.sort(a), vals) Array([0, 2, 5, 1, 1], dtype=int32)
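The side and sorter semantics correspond closely to Python's bisect module, which makes for a compact reference model. The sketch below uses only the standard library and plain lists; it illustrates what the insertion indices mean and says nothing about the 'scan' or 'sort' methods JAX uses internally:

```python
from bisect import bisect_left, bisect_right

a = [1, 2, 2, 3, 4, 5, 5]
left = bisect_left(a, 2)    # side='left': insert before existing ties
right = bisect_right(a, 2)  # side='right': insert after existing ties

vals = [0, 3, 8, 1.5, 2]
batch = [bisect_left(a, v) for v in vals]

# A sorter is a permutation that puts an unsorted list into sorted order.
unsorted = [4, 3, 5, 1, 2]
sorter = sorted(range(len(unsorted)), key=unsorted.__getitem__)
via_sorter = [bisect_left([unsorted[i] for i in sorter], v) for v in vals]
```

This yields left = 1, right = 3, batch = [0, 3, 7, 1, 1] and via_sorter = [0, 2, 5, 1, 1], the same indices as the examples above.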
- scico.numpy.select(condlist, choicelist, default=0)¶
Select values based on a series of conditions.
JAX implementation of numpy.select, implemented in terms of jax.lax.select_n.
- Parameters:
condlist (Sequence[Union[Array,ndarray,bool,number,bool,int,float,complex]]) – sequence of array-like conditions. All entries must be mutually broadcast-compatible.
choicelist (Sequence[Union[Array,ndarray,bool,number,bool,int,float,complex]]) – sequence of array-like values to choose. Must have the same length as condlist, and all entries must be broadcast-compatible with entries of condlist.
default (Union[Array,ndarray,bool,number,bool,int,float,complex]) – value to return when every condition is False (default: 0).
- Return type:
- Returns:
Array of selected values from choicelist corresponding to the first True entry in condlist at each location.
See also
jax.numpy.where: select between two values based on a single condition.
jax.lax.select_n: select between N values based on an index.
Examples
>>> condlist = [ ... jnp.array([False, True, False, False]), ... jnp.array([True, False, False, False]), ... jnp.array([False, True, True, False]), ... ] >>> choicelist = [ ... jnp.array([1, 2, 3, 4]), ... jnp.array([10, 20, 30, 40]), ... jnp.array([100, 200, 300, 400]), ... ] >>> jnp.select(condlist, choicelist, default=0) Array([ 10, 2, 300, 0], dtype=int32)
This is logically equivalent to the following nested
wherestatement:>>> default = 0 >>> jnp.where(condlist[0], ... choicelist[0], ... jnp.where(condlist[1], ... choicelist[1], ... jnp.where(condlist[2], ... choicelist[2], ... default))) Array([ 10, 2, 300, 0], dtype=int32)
However, for efficiency it is implemented in terms of
jax.lax.select_n.
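The "first True entry wins" rule matters most when conditions overlap. The following is a minimal sketch, assuming `jnp` refers to `jax.numpy` and using made-up values for illustration:

```python
import jax.numpy as jnp

x = jnp.arange(6)  # [0 1 2 3 4 5]

# The conditions overlap: every x < 2 also satisfies x < 4,
# so the order of condlist decides which choice is used.
condlist = [x < 2, x < 4]
choicelist = [jnp.full_like(x, 10), jnp.full_like(x, 20)]

# Positions where no condition holds receive `default`.
result = jnp.select(condlist, choicelist, default=30)
print(result)  # [10 10 20 20 30 30]
```

Swapping the two conditions (and their choices) would assign 20 to every element below 4, because `x < 4` would then be checked first.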
- scico.numpy.set_printoptions(*args, **kwargs)¶
Alias of
numpy.set_printoptions.JAX arrays are printed via NumPy, so NumPy’s printoptions configurations will apply to printed JAX arrays.
See the
numpy.set_printoptionsdocumentation for details on the available options and their meanings.
- scico.numpy.setdiff1d(ar1, ar2, assume_unique=False, *, size=None, fill_value=None)¶
Compute the set difference of two 1D arrays.
JAX implementation of
numpy.setdiff1d.Because the size of the output of
setdiff1dis data-dependent, the function is not typically compatible withjitand other JAX transformations. The JAX version adds the optionalsizeargument which must be specified statically forjnp.setdiff1dto be used in such contexts.- Parameters:
ar1 (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – first array of elements to be differenced.ar2 (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – second array of elements to be differenced.assume_unique (
bool) – if True, assume the input arrays contain unique values. This allows a more efficient implementation, but ifassume_uniqueis True and the input arrays contain duplicates, the behavior is undefined. default: False.size (
int|None) – if specified, return only the firstsizesorted elements. If there are fewer elements thansizeindicates, the return value will be padded withfill_value.fill_value (
Union[Array,ndarray,bool,number,bool,int,float,complex,None]) – when size is specified and there are fewer than the indicated number of elements, fill the remaining entries with fill_value. Defaults to the minimum value.
- Return type:
- Returns:
an array containing the set difference of elements in the input array – i.e. the elements in
ar1that are not contained inar2.
See also
jax.numpy.intersect1d: the set intersection of two 1D arrays.jax.numpy.setxor1d: the set XOR of two 1D arrays.jax.numpy.union1d: the set union of two 1D arrays.
Examples
Computing the set difference of two arrays:
>>> ar1 = jnp.array([1, 2, 3, 4]) >>> ar2 = jnp.array([3, 4, 5, 6]) >>> jnp.setdiff1d(ar1, ar2) Array([1, 2], dtype=int32)
Because the output shape is dynamic, this will fail under
jitand other transformations:>>> jax.jit(jnp.setdiff1d)(ar1, ar2) Traceback (most recent call last): ... ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[4]. The error occurred while tracing the function setdiff1d at /Users/vanderplas/github/jax-ml/jax/jax/_src/numpy/setops.py:64 for jit. This concrete value was not available in Python because it depends on the value of the argument ar1.
In order to ensure statically-known output shapes, you can pass a static
sizeargument:>>> jit_setdiff1d = jax.jit(jnp.setdiff1d, static_argnames=['size']) >>> jit_setdiff1d(ar1, ar2, size=2) Array([1, 2], dtype=int32)
If
sizeis too small, the difference is truncated:>>> jit_setdiff1d(ar1, ar2, size=1) Array([1], dtype=int32)
If
sizeis too large, then the output is padded withfill_value:>>> jit_setdiff1d(ar1, ar2, size=4, fill_value=0) Array([1, 2, 0, 0], dtype=int32)
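When the output is padded, downstream code usually needs to know which entries are real. One possible pattern, sketched here under the assumption that a sentinel value (the hypothetical `SENTINEL` below) never occurs in the data, is to return a validity mask alongside the padded result:

```python
from functools import partial

import jax
import jax.numpy as jnp

SENTINEL = -1  # hypothetical padding value; assumed absent from the inputs


@partial(jax.jit, static_argnames=["size"])
def padded_setdiff(ar1, ar2, size):
    # The static `size` gives the output a fixed shape, so this can be jitted.
    out = jnp.setdiff1d(ar1, ar2, size=size, fill_value=SENTINEL)
    # Mask marking which entries are genuine results rather than padding.
    valid = out != SENTINEL
    return out, valid


ar1 = jnp.array([1, 2, 3, 4])
ar2 = jnp.array([3, 4, 5, 6])
values, valid = padded_setdiff(ar1, ar2, size=4)
print(values)  # [ 1  2 -1 -1]
print(valid)   # [ True  True False False]
```

The helper name `padded_setdiff` is illustrative only; it is not part of JAX or SCICO.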
- scico.numpy.setxor1d(ar1, ar2, assume_unique=False, *, size=None, fill_value=None)¶
Compute the set-wise xor of elements in two arrays.
JAX implementation of
numpy.setxor1d.Because the size of the output of
setxor1dis data-dependent, the function is not compatible with JIT or other JAX transformations.- Parameters:
ar1 (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – first array of input values.ar2 (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – second array of input values.assume_unique (
bool) – if True, assume the input arrays contain unique values. This allows a more efficient implementation, but ifassume_uniqueis True and the input arrays contain duplicates, the behavior is undefined. default: False.size (
int|None) – if specified, return only the firstsizesorted elements. If there are fewer elements thansizeindicates, the return value will be padded withfill_value, and returned indices will be padded with an out-of-bound index.fill_value (
Union[Array,ndarray,bool,number,bool,int,float,complex,None]) – when size is specified and there are fewer than the indicated number of elements, fill the remaining entries with fill_value. Defaults to the smallest value in the xor result.
- Return type:
- Returns:
An array of values that are found in exactly one of the input arrays.
See also
jax.numpy.intersect1d: the set intersection of two 1D arrays.jax.numpy.union1d: the set union of two 1D arrays.jax.numpy.setdiff1d: the set difference of two 1D arrays.
Examples
>>> ar1 = jnp.array([1, 2, 3, 4]) >>> ar2 = jnp.array([3, 4, 5, 6]) >>> jnp.setxor1d(ar1, ar2) Array([1, 2, 5, 6], dtype=int32)
- scico.numpy.shape(a)¶
Return the shape of an array.
JAX implementation of
numpy.shape. Unlikenp.shape, this function raises aTypeErrorif the input is a collection such as a list or tuple.- Parameters:
a (
Union[Array,ndarray,bool,number,bool,int,float,complex,SupportsShape]) – array-like object, or any object with ashapeattribute.- Return type:
- Returns:
A tuple of integers representing the shape of
a.
Examples
Shape for arrays:
>>> x = jnp.arange(10) >>> jnp.shape(x) (10,) >>> y = jnp.ones((2, 3)) >>> jnp.shape(y) (2, 3)
This also works for scalars:
>>> jnp.shape(3.14) ()
For arrays, this can also be accessed via the
jax.Array.shapeproperty:>>> x.shape (10,)
- scico.numpy.sign(x, /)¶
Return an element-wise indication of sign of the input.
JAX implementation of
numpy.sign.The sign of
xfor real-valued input is:\[\begin{split}\mathrm{sign}(x) = \begin{cases} 1, & x > 0\\ 0, & x = 0\\ -1, & x < 0 \end{cases}\end{split}\]For complex valued input,
jnp.sign returns a unit vector representing the phase. In the general case, the sign of x is given by: \[\begin{split}\mathrm{sign}(x) = \begin{cases} \frac{x}{abs(x)}, & x \ne 0\\ 0, & x = 0 \end{cases}\end{split}\]- Parameters:
x (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – input array or scalar.- Return type:
- Returns:
An array with same shape and dtype as
xcontaining the sign indication.
See also
jax.numpy.positive: Returns element-wise positive values of the input.jax.numpy.negative: Returns element-wise negative values of the input.
Examples
For real-valued inputs:
>>> x = jnp.array([0., -3., 7.]) >>> jnp.sign(x) Array([ 0., -1., 1.], dtype=float32)
For complex-valued inputs:
>>> x1 = jnp.array([1, 3+4j, 5j]) >>> jnp.sign(x1) Array([1. +0.j , 0.6+0.8j, 0. +1.j ], dtype=complex64)
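The general formula above can be checked numerically. This is a small sketch, assuming `jnp` is `jax.numpy` and a JAX version whose complex `sign` follows the unit-vector convention described here; the division is guarded so that the zero entry does not produce NaN:

```python
import jax.numpy as jnp

x = jnp.array([3 + 4j, -2j, 0j])

mag = jnp.abs(x)
# Replace zero magnitudes by 1 before dividing to avoid 0/0.
safe_mag = jnp.where(mag == 0, 1.0, mag)
# x / |x| where x is nonzero, and 0 where x is zero.
manual = jnp.where(mag == 0, 0, x / safe_mag)

print(jnp.allclose(jnp.sign(x), manual))
print(jnp.abs(jnp.sign(x)))  # magnitude 1 for nonzero inputs, 0 for zero
```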
- scico.numpy.signbit(x, /)¶
Return the sign bit of array elements.
JAX implementation of
numpy.signbit.- Parameters:
x (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – input array. Complex values are not supported.- Return type:
- Returns:
A boolean array of the same shape as
x, containingTruewhere the sign ofxis negative, andFalseotherwise.
See also
jax.numpy.sign: return the mathematical sign of array elements, i.e.-1,0, or+1.
Examples
signbiton boolean values is alwaysFalse:>>> x = jnp.array([True, False]) >>> jnp.signbit(x) Array([False, False], dtype=bool)
signbiton integer values is equivalent tox < 0:>>> x = jnp.array([-2, -1, 0, 1, 2]) >>> jnp.signbit(x) Array([ True, True, False, False, False], dtype=bool)
signbiton floating point values returns the value of the actual sign bit from the float representation, including signed zero:>>> x = jnp.array([-1.5, -0.0, 0.0, 1.5]) >>> jnp.signbit(x) Array([ True, True, False, False], dtype=bool)
This also returns the sign bit for special values such as signed NaN and signed infinity:
>>> x = jnp.array([jnp.nan, -jnp.nan, jnp.inf, -jnp.inf]) >>> jnp.signbit(x) Array([False, True, False, True], dtype=bool)
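One practical consequence is that `signbit` can distinguish the two floating-point zeros, which ordinary comparisons cannot. A brief sketch, assuming `jnp` is `jax.numpy`:

```python
import jax.numpy as jnp

zeros = jnp.array([-0.0, 0.0])

is_zero = zeros == 0           # both entries compare equal to zero
is_negative = zeros < 0        # neither entry is less than zero
sign_bits = jnp.signbit(zeros)  # only the first entry has its sign bit set

print(is_zero, is_negative, sign_bits)
```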
- scico.numpy.sin(x, /)¶
Compute a trigonometric sine of each element of input.
JAX implementation of
numpy.sin.- Parameters:
x (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – array or scalar. Angle in radians.- Return type:
- Returns:
An array containing the sine of each element in
x, promotes to inexact dtype.
See also
jax.numpy.cos: Computes a trigonometric cosine of each element of input.jax.numpy.tan: Computes a trigonometric tangent of each element of input.jax.numpy.arcsinandjax.numpy.asin: Computes the inverse of trigonometric sine of each element of input.
Examples
>>> pi = jnp.pi >>> x = jnp.array([pi/4, pi/2, 3*pi/4, pi]) >>> with jnp.printoptions(precision=3, suppress=True): ... print(jnp.sin(x)) [ 0.707 1. 0.707 -0. ]
- scico.numpy.sinc(x, /)¶
Calculate the normalized sinc function.
JAX implementation of
numpy.sinc.The normalized sinc function is given by
\[\mathrm{sinc}(x) = \frac{\sin({\pi x})}{\pi x}\]where
sinc(0)returns the limit value of1. The sinc function is smooth and infinitely differentiable.- Parameters:
x (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – input array; will be promoted to an inexact type.- Return type:
- Returns:
An array of the same shape as
xcontaining the result.
Examples
>>> x = jnp.array([-1, -0.5, 0, 0.5, 1]) >>> with jnp.printoptions(precision=3, suppress=True): ... jnp.sinc(x) Array([-0. , 0.637, 1. , 0.637, -0. ], dtype=float32)
Compare this to the naive approach to computing the function, which is undefined at zero:
>>> with jnp.printoptions(precision=3, suppress=True): ... jnp.sin(jnp.pi * x) / (jnp.pi * x) Array([-0. , 0.637, nan, 0.637, -0. ], dtype=float32)
JAX defines a custom gradient rule for sinc to allow accurate evaluation of the gradient at zero even for higher-order derivatives:
>>> f = jnp.sinc >>> for i in range(1, 6): ... f = jax.grad(f) ... print(f"(d/dx)^{i} f(0.0) = {f(0.0):.2f}") ... (d/dx)^1 f(0.0) = 0.00 (d/dx)^2 f(0.0) = -3.29 (d/dx)^3 f(0.0) = 0.00 (d/dx)^4 f(0.0) = 19.48 (d/dx)^5 f(0.0) = 0.00
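The printed derivatives agree with the Taylor series of the normalized sinc function, sinc(x) = 1 - (πx)²/3! + (πx)⁴/5! - …, which gives a second derivative of -π²/3 and a fourth derivative of π⁴/5 at zero. The sketch below compares these closed forms with nested `jax.grad` calls; the helper names `d2` and `d4` are illustrative:

```python
import math

import jax
import jax.numpy as jnp

# Second and fourth derivatives of sinc, built by nesting jax.grad.
d2 = jax.grad(jax.grad(jnp.sinc))
d4 = jax.grad(jax.grad(d2))

d2_at_zero = float(d2(0.0))
d4_at_zero = float(d4(0.0))

# Closed-form values from the Taylor series coefficients.
print(d2_at_zero, -math.pi**2 / 3)
print(d4_at_zero, math.pi**4 / 5)
```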
- scico.numpy.sinh(x, /)¶
Calculate element-wise hyperbolic sine of input.
JAX implementation of
numpy.sinh.The hyperbolic sine is defined by:
\[sinh(x) = \frac{e^x - e^{-x}}{2}\]- Parameters:
x (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – input array or scalar.- Return type:
- Returns:
An array containing the hyperbolic sine of each element of
x, promoting to inexact dtype.
Note
jnp.sinhis equivalent to computing-1j * jnp.sin(1j * x).See also
jax.numpy.cosh: Computes the element-wise hyperbolic cosine of the input.jax.numpy.tanh: Computes the element-wise hyperbolic tangent of the input.jax.numpy.arcsinh: Computes the element-wise inverse of hyperbolic sine of the input.
Examples
>>> x = jnp.array([[-2, 3, 5], ... [0, -1, 4]]) >>> with jnp.printoptions(precision=3, suppress=True): ... jnp.sinh(x) Array([[-3.627, 10.018, 74.203], [ 0. , -1.175, 27.29 ]], dtype=float32) >>> with jnp.printoptions(precision=3, suppress=True): ... -1j * jnp.sin(1j * x) Array([[-3.627+0.j, 10.018-0.j, 74.203-0.j], [ 0. -0.j, -1.175+0.j, 27.29 -0.j]], dtype=complex64, weak_type=True)
For complex-valued input:
>>> with jnp.printoptions(precision=3, suppress=True): ... jnp.sinh(3-2j) Array(-4.169-9.154j, dtype=complex64, weak_type=True) >>> with jnp.printoptions(precision=3, suppress=True): ... -1j * jnp.sin(1j * (3-2j)) Array(-4.169-9.154j, dtype=complex64, weak_type=True)
- scico.numpy.size(a, axis=None)¶
Return number of elements along a given axis.
JAX implementation of
numpy.size. Unlikenp.size, this function raises aTypeErrorif the input is a collection such as a list or tuple.- Parameters:
a (
Union[Array,ndarray,bool,number,bool,int,float,complex,SupportsSize,SupportsShape]) – array-like object, or any object with asizeattribute whenaxisis not specified, or with ashapeattribute whenaxisis specified.axis (
int|Sequence[int] |None) – optional integer or sequence of integers indicating which axis or axes to count elements along.None(the default) returns the total number of elements.
- Return type:
- Returns:
An integer specifying the number of elements in
a.
Examples
Size for arrays:
>>> x = jnp.arange(10) >>> jnp.size(x) 10 >>> y = jnp.ones((2, 3)) >>> jnp.size(y) 6 >>> jnp.size(y, axis=1) 3 >>> jnp.size(y, axis=(1,)) 3 >>> jnp.size(y, axis=(0, 1)) 6
This also works for scalars:
>>> jnp.size(3.14) 1
For arrays, this can also be accessed via the
jax.Array.sizeproperty:>>> y.size 6
- scico.numpy.sort(a, axis=-1, *, kind=None, order=None, stable=True, descending=False)¶
Return a sorted copy of an array.
JAX implementation of
numpy.sort.- Parameters:
a (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – array to sortaxis (
int|None) – integer axis along which to sort. Defaults to-1, i.e. the last axis. IfNone, thenais flattened before being sorted.stable (
bool) – boolean specifying whether a stable sort should be used. Default=True.descending (
bool) – boolean specifying whether to sort in descending order. Default=False.kind (
None) – deprecated; instead specify sort algorithm using stable=True or stable=False.order (
None) – not supported by JAX
- Return type:
- Returns:
Sorted array of shape
a.shape(ifaxisis an integer) or of shape(a.size,)(ifaxisis None).
Examples
Simple 1-dimensional sort
>>> x = jnp.array([1, 3, 5, 4, 2, 1]) >>> jnp.sort(x) Array([1, 1, 2, 3, 4, 5], dtype=int32)
Sort along the last axis of an array:
>>> x = jnp.array([[2, 1, 3], ... [4, 3, 6]]) >>> jnp.sort(x, axis=1) Array([[1, 2, 3], [3, 4, 6]], dtype=int32)
See also
jax.numpy.argsort: return indices of sorted values.jax.numpy.lexsort: lexicographical sort of multiple arrays.jax.lax.sort: lower-level function wrapping XLA’s Sort operator.
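The `descending` and `axis` parameters are not covered by the examples above. A short sketch of their effect, assuming `jnp` is `jax.numpy`:

```python
import jax.numpy as jnp

x = jnp.array([[2, 1, 3],
               [4, 3, 6]])

by_row_desc = jnp.sort(x, descending=True)  # each row sorted high to low
flattened = jnp.sort(x, axis=None)          # flattened first, then sorted
by_column = jnp.sort(x, axis=0)             # sorted down each column

print(by_row_desc)  # [[3 2 1] [6 4 3]]
print(flattened)    # [1 2 3 3 4 6]
print(by_column)    # [[2 1 3] [4 3 6]]
```

The column-wise result equals the input here because each column is already in ascending order.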
- scico.numpy.sort_complex(a)¶
Return a sorted copy of complex array.
JAX implementation of
numpy.sort_complex.Complex numbers are sorted lexicographically, meaning by their real part first, and then by their imaginary part if real parts are equal.
- Parameters:
a (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – input array. If dtype is not complex, the array will be upcast to complex.- Return type:
- Returns:
A sorted array of the same shape and complex dtype as the input. If
ais multi-dimensional, it is sorted along the last axis.
See also
jax.numpy.sort: Return a sorted copy of an array.
Examples
>>> a = jnp.array([1+2j, 2+4j, 3-1j, 2+3j]) >>> jnp.sort_complex(a) Array([1.+2.j, 2.+3.j, 2.+4.j, 3.-1.j], dtype=complex64)
Multi-dimensional arrays are sorted along the last axis:
>>> a = jnp.array([[5, 3, 4], ... [6, 9, 2]]) >>> jnp.sort_complex(a) Array([[3.+0.j, 4.+0.j, 5.+0.j], [2.+0.j, 6.+0.j, 9.+0.j]], dtype=complex64)
- scico.numpy.split(ary, indices_or_sections, axis=0)¶
Split an array into sub-arrays.
JAX implementation of
numpy.split.- Parameters:
ary (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – N-dimensional array-like object to splitindices_or_sections (
Union[int,Sequence[int],Array,ndarray,bool,number,bool,float,complex]) –either a single integer or a sequence of indices.
if
indices_or_sectionsis an integer N, then N must evenly divideary.shape[axis]andarywill be divided into N equally-sized chunks alongaxis.if
indices_or_sectionsis a sequence of integers, then these integers specify the boundary between unevenly-sized chunks alongaxis; see examples below.
axis (
int) – the axis along which to split; defaults to 0.
- Return type:
- Returns:
A list of arrays. If
indices_or_sections is an integer N, then the list is of length N. If indices_or_sections is a sequence seq, then the list is of length len(seq) + 1.
Examples
Splitting a 1-dimensional array:
>>> x = jnp.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
Split into three equal sections:
>>> chunks = jnp.split(x, 3) >>> print(*chunks) [1 2 3] [4 5 6] [7 8 9]
Split into sections by index:
>>> chunks = jnp.split(x, [2, 7]) # [x[0:2], x[2:7], x[7:]] >>> print(*chunks) [1 2] [3 4 5 6 7] [8 9]
Splitting a two-dimensional array along axis 1:
>>> x = jnp.array([[1, 2, 3, 4], ... [5, 6, 7, 8]]) >>> x1, x2 = jnp.split(x, 2, axis=1) >>> print(x1) [[1 2] [5 6]] >>> print(x2) [[3 4] [7 8]]
See also
jax.numpy.array_split: likesplit, but allowsindices_or_sectionsto be an integer that does not evenly divide the size of the array.jax.numpy.vsplit: split vertically, i.e. along axis=0jax.numpy.hsplit: split horizontally, i.e. along axis=1jax.numpy.dsplit: split depth-wise, i.e. along axis=2
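Two properties follow from the description above: concatenating the chunks recovers the original array, and an integer section count that does not divide the axis length is rejected by `split` but accepted by `array_split`. A minimal sketch, assuming `jnp` is `jax.numpy`:

```python
import jax.numpy as jnp

x = jnp.arange(1, 10)  # nine elements: [1 2 ... 9]

# Splitting at explicit indices and concatenating is a round trip.
chunks = jnp.split(x, [2, 7])
roundtrip = jnp.concatenate(chunks)

# 4 does not evenly divide 9, so split raises an error.
try:
    jnp.split(x, 4)
    split_failed = False
except ValueError:
    split_failed = True

# array_split tolerates the uneven division.
uneven = jnp.array_split(x, 4)
sizes = [int(c.shape[0]) for c in uneven]
print(split_failed, sizes)  # True [3, 2, 2, 2]
```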
- scico.numpy.sqrt(x, /)¶
Calculates element-wise non-negative square root of the input array.
JAX implementation of
numpy.sqrt.- Parameters:
x (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – input array or scalar.- Return type:
- Returns:
An array containing the non-negative square root of the elements of
x.
Note
For real-valued negative inputs,
jnp.sqrt produces a nan output. For complex-valued negative inputs,
jnp.sqrt produces a complex output.
See also
jax.numpy.square: Calculates the element-wise square of the input.jax.numpy.power: Calculates the element-wise basex1exponential ofx2.
Examples
>>> x = jnp.array([-8-6j, 1j, 4]) >>> with jnp.printoptions(precision=3, suppress=True): ... jnp.sqrt(x) Array([1. -3.j , 0.707+0.707j, 2. +0.j ], dtype=complex64) >>> jnp.sqrt(-1) Array(nan, dtype=float32, weak_type=True)
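As the note implies, the result for a negative input depends on the input dtype rather than its value. The sketch below, which assumes `jnp` is `jax.numpy`, casts a real array to a complex dtype to obtain the principal complex root instead of NaN:

```python
import jax.numpy as jnp

x = jnp.array([-4.0, 9.0])

# Real dtype: the negative entry has no real square root, so it becomes NaN.
real_result = jnp.sqrt(x)

# Complex dtype: the same values yield the principal complex root.
complex_result = jnp.sqrt(x.astype(jnp.complex64))

print(real_result)     # [nan  3.]
print(complex_result)  # [0.+2.j 3.+0.j]
```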
- scico.numpy.square(x, /)¶
Calculate element-wise square of the input array.
JAX implementation of
numpy.square.- Parameters:
x (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – input array or scalar.- Return type:
- Returns:
An array containing the square of the elements of
x.
Note
jnp.squareis equivalent to computingjnp.power(x, 2).See also
jax.numpy.sqrt: Calculates the element-wise non-negative square root of the input array.jax.numpy.power: Calculates the element-wise basex1exponential ofx2.jax.lax.integer_pow: Computes element-wise power \(x^y\), where \(y\) is a fixed integer.jax.numpy.float_power: Computes the first array raised to the power of second array, element-wise, by promoting to the inexact dtype.
Examples
>>> x = jnp.array([3, -2, 5.3, 1]) >>> jnp.square(x) Array([ 9. , 4. , 28.090002, 1. ], dtype=float32) >>> jnp.power(x, 2) Array([ 9. , 4. , 28.090002, 1. ], dtype=float32)
For integer inputs:
>>> x1 = jnp.array([2, 4, 5, 6]) >>> jnp.square(x1) Array([ 4, 16, 25, 36], dtype=int32)
For complex-valued inputs:
>>> x2 = jnp.array([1-3j, -1j, 2]) >>> jnp.square(x2) Array([-8.-6.j, -1.+0.j, 4.+0.j], dtype=complex64)
- scico.numpy.squeeze(a, axis=None)¶
Remove one or more length-1 axes from an array.
JAX implementation of
numpy.squeeze, implemented via jax.lax.squeeze.- Parameters:
- Return type:
- Returns:
copy of
awith length-1 axes removed.
Notes
Unlike
numpy.squeeze, jax.numpy.squeeze will return a copy rather than a view of the input array. However, under JIT, the compiler will optimize away such copies when possible, so this doesn’t have performance impacts in practice.See also
jax.numpy.expand_dims: the inverse ofsqueeze: add dimensions of length 1.jax.Array.squeeze: equivalent functionality via an array method.jax.lax.squeeze: equivalent XLA API.jax.numpy.ravel: flatten an array into a 1D shape.jax.numpy.reshape: general array reshape.
Examples
>>> x = jnp.array([[[0]], [[1]], [[2]]]) >>> x.shape (3, 1, 1)
Squeeze all length-1 dimensions:
>>> jnp.squeeze(x) Array([0, 1, 2], dtype=int32) >>> _.shape (3,)
Equivalent while specifying the axes explicitly:
>>> jnp.squeeze(x, axis=(1, 2)) Array([0, 1, 2], dtype=int32)
Attempting to squeeze a non-unit axis results in an error:
>>> jnp.squeeze(x, axis=0) Traceback (most recent call last): ... ValueError: cannot select an axis to squeeze out which has size not equal to one, got shape=(3, 1, 1) and dimensions=(0,)
For convenience, this functionality is also available via the
jax.Array.squeezemethod:>>> x.squeeze() Array([0, 1, 2], dtype=int32)
- scico.numpy.stack(arrays, axis=0, out=None, dtype=None)¶
Join arrays along a new axis.
JAX implementation of
numpy.stack.- Parameters:
arrays (
ndarray|Array|Sequence[Union[Array,ndarray,bool,number,bool,int,float,complex]]) – a sequence of arrays to stack; each must have the same shape. If a single array is given it will be treated equivalently to arrays = unstack(arrays), but the implementation will avoid explicit unstacking.axis (
int) – specify the axis along which to stack.out (
None) – unused by JAXdtype (
Union[str,type[Any],dtype,SupportsDType,None]) – optional dtype of the resulting array. If not specified, the dtype will be determined via type promotion rules described in Type promotion semantics.
- Return type:
- Returns:
the stacked result.
See also
jax.numpy.unstack: inverse ofstack.jax.numpy.concatenate: concatenation along existing axes.jax.numpy.vstack: stack vertically, i.e. along axis 0.jax.numpy.hstack: stack horizontally, i.e. along axis 1.jax.numpy.dstack: stack depth-wise, i.e. along axis 2.jax.numpy.column_stack: stack columns.
Examples
>>> x = jnp.array([1, 2, 3]) >>> y = jnp.array([4, 5, 6]) >>> jnp.stack([x, y]) Array([[1, 2, 3], [4, 5, 6]], dtype=int32) >>> jnp.stack([x, y], axis=1) Array([[1, 4], [2, 5], [3, 6]], dtype=int32)
unstackperforms the inverse operation:>>> arr = jnp.stack([x, y], axis=1) >>> x, y = jnp.unstack(arr, axis=1) >>> x Array([1, 2, 3], dtype=int32) >>> y Array([4, 5, 6], dtype=int32)
- scico.numpy.std(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False, *, where=None, mean=None, correction=None)¶
Compute the standard deviation along a given axis.
JAX implementation of
numpy.std.- Parameters:
a (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – input array.axis (
Union[int,Sequence[int],None]) – optional, int or sequence of ints, default=None. Axis along which the standard deviation is computed. If None, standard deviation is computed along all the axes.dtype (
Union[str,type[Any],dtype,SupportsDType,None]) – The type of the output array. Default=None.ddof (
int) – int, default=0. Degrees of freedom. The divisor in the standard deviation computation isN-ddof,Nis number of elements along given axis.keepdims (
bool) – bool, default=False. If true, reduced axes are left in the result with size 1.where (
Union[Array,ndarray,bool,number,bool,int,float,complex,None]) – optional, boolean array, default=None. The elements to be used in the standard deviation. Array should be broadcast compatible to the input.mean (
Union[Array,ndarray,bool,number,bool,int,float,complex,None]) – optional, mean of the input array, computed along the given axis. If provided, it will be used to compute the standard deviation instead of computing it from the input array. If specified, mean must be broadcast-compatible with the input array. In the general case, this can be achieved by computing the mean withkeepdims=Trueandaxismatching this function’saxisargument.correction (
int|float|None) – int or float, default=None. Alternative name forddof. Both ddof and correction can’t be provided simultaneously.out (
None) – Unused by JAX.
- Return type:
- Returns:
An array of the standard deviation along the given axis.
See also
jax.numpy.var: Compute the variance of array elements over given axis.jax.numpy.mean: Compute the mean of array elements over a given axis.jax.numpy.nanvar: Compute the variance along a given axis, ignoring NaN values.jax.numpy.nanstd: Compute the standard deviation along a given axis, ignoring NaN values.
Examples
By default,
jnp.stdcomputes the standard deviation along all axes.>>> x = jnp.array([[1, 3, 4, 2], ... [4, 2, 5, 3], ... [5, 4, 2, 3]]) >>> with jnp.printoptions(precision=2, suppress=True): ... jnp.std(x) Array(1.21, dtype=float32)
If
axis=0, computes along axis 0.>>> with jnp.printoptions(precision=2, suppress=True): ... print(jnp.std(x, axis=0)) [1.7 0.82 1.25 0.47]
To preserve the dimensions of input, you can set
keepdims=True.>>> with jnp.printoptions(precision=2, suppress=True): ... print(jnp.std(x, axis=0, keepdims=True)) [[1.7 0.82 1.25 0.47]]
If
ddof=1:>>> with jnp.printoptions(precision=2, suppress=True): ... print(jnp.std(x, axis=0, keepdims=True, ddof=1)) [[2.08 1. 1.53 0.58]]
To include specific elements of the array to compute standard deviation, you can use
where.>>> where = jnp.array([[1, 0, 1, 0], ... [0, 1, 0, 1], ... [1, 1, 1, 0]], dtype=bool) >>> jnp.std(x, axis=0, keepdims=True, where=where) Array([[2., 1., 1., 0.]], dtype=float32)
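The role of `ddof` as a correction to the divisor can be made explicit by computing the standard deviation by hand. This is a sketch for real-valued input, assuming `jnp` is `jax.numpy`; the variable names are illustrative:

```python
import jax.numpy as jnp

x = jnp.array([[1.0, 3.0, 4.0, 2.0],
               [4.0, 2.0, 5.0, 3.0],
               [5.0, 4.0, 2.0, 3.0]])

ddof = 1
n = x.shape[0]  # number of elements along axis 0

# sqrt( sum((x - mean)^2) / (N - ddof) ), evaluated column by column.
mean = jnp.mean(x, axis=0, keepdims=True)
manual = jnp.sqrt(jnp.sum((x - mean) ** 2, axis=0) / (n - ddof))

builtin = jnp.std(x, axis=0, ddof=ddof)
print(jnp.allclose(manual, builtin))
```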
- scico.numpy.subtract(*args: ArrayLike, out: None = None, where: None = None) Any¶
Subtract two arrays element-wise.
JAX implementation of
numpy.subtract. This is a universal function, and supports the additional APIs described atjax.numpy.ufunc. This function provides the implementation of the-operator for JAX arrays.- Parameters:
x – arrays to subtract. Must be broadcastable to a common shape.
y – arrays to subtract. Must be broadcastable to a common shape.
- Returns:
Array containing the result of the element-wise subtraction.
Examples
Calling
subtractexplicitly:>>> x = jnp.arange(4) >>> jnp.subtract(x, 10) Array([-10, -9, -8, -7], dtype=int32)
Calling
subtractvia the-operator:>>> x - 10 Array([-10, -9, -8, -7], dtype=int32)
- scico.numpy.sum(a, axis=None, dtype=None, out=None, keepdims=False, initial=None, where=None, promote_integers=True)¶
Sum of the elements of the array over a given axis.
JAX implementation of
numpy.sum.- Parameters:
a (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – Input array.axis (
Union[int,Sequence[int],None]) – int or array, default=None. Axis along which the sum is computed. If None, the sum is computed along all the axes.dtype (
Union[str,type[Any],dtype,SupportsDType,None]) – The type of the output array. Default=None.out (
None) – Unused by JAXkeepdims (
bool) – bool, default=False. If true, reduced axes are left in the result with size 1.initial (
Union[Array,ndarray,bool,number,bool,int,float,complex,None]) – int or array, Default=None. Initial value for the sum.where (
Union[Array,ndarray,bool,number,bool,int,float,complex,None]) – int or array, default=None. The elements to be used in the sum. Array should be broadcast compatible to the input.promote_integers (
bool) – bool, default=True. If True, then integer inputs will be promoted to the widest available integer dtype, following numpy’s behavior. If False, the result will have the same dtype as the input.promote_integersis ignored ifdtypeis specified.
- Return type:
- Returns:
An array of the sum along the given axis.
See also
jax.numpy.prod: Compute the product of array elements over a given axis.jax.numpy.max: Compute the maximum of array elements over given axis.jax.numpy.min: Compute the minimum of array elements over given axis.
Examples
By default, the sum is computed along all the axes.
>>> x = jnp.array([[1, 3, 4, 2], ... [5, 2, 6, 3], ... [8, 1, 3, 9]]) >>> jnp.sum(x) Array(47, dtype=int32)
If
axis=1, the sum is computed along axis 1.>>> jnp.sum(x, axis=1) Array([10, 16, 21], dtype=int32)
If
keepdims=True,ndimof the output is equal to that of the input.>>> jnp.sum(x, axis=1, keepdims=True) Array([[10], [16], [21]], dtype=int32)
To include only specific elements in the sum, you can use
where.>>> where=jnp.array([[0, 0, 1, 0], ... [0, 0, 1, 1], ... [1, 1, 1, 0]], dtype=bool) >>> jnp.sum(x, axis=1, keepdims=True, where=where) Array([[ 4], [ 9], [12]], dtype=int32) >>> where=jnp.array([[False], ... [False], ... [False]]) >>> jnp.sum(x, axis=0, keepdims=True, where=where) Array([[0, 0, 0, 0]], dtype=int32)
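The `promote_integers` flag is easiest to see with a narrow integer dtype. A small sketch, assuming `jnp` is `jax.numpy`; the exact promoted dtype depends on whether 64-bit mode is enabled, so only the relative behavior is shown:

```python
import jax.numpy as jnp

x = jnp.array([100, 100], dtype=jnp.int8)

# Default: the input is promoted to a wider integer dtype before summing,
# so the total of 200 is representable.
promoted = jnp.sum(x)

# With promotion disabled the result keeps the int8 input dtype,
# whose range (-128 to 127) cannot hold 200.
same_dtype = jnp.sum(x, promote_integers=False)

print(promoted, promoted.dtype)
print(same_dtype.dtype)
```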
- scico.numpy.swapaxes(a, axis1, axis2)¶
Swap two axes of an array.
JAX implementation of
numpy.swapaxes, implemented in terms ofjax.lax.transpose.- Parameters:
- Return type:
- Returns:
Copy of
awith specified axes swapped.
Notes
Unlike
numpy.swapaxes,jax.numpy.swapaxeswill return a copy rather than a view of the input array. However, under JIT, the compiler will optimize away such copies when possible, so this doesn’t have performance impacts in practice.See also
jax.numpy.moveaxis: move a single axis of an array.jax.numpy.rollaxis: older API formoveaxis.jax.lax.transpose: more general axes permutations.jax.Array.swapaxes: same functionality via an array method.
Examples
>>> a = jnp.ones((2, 3, 4, 5)) >>> jnp.swapaxes(a, 1, 3).shape (2, 5, 4, 3)
Equivalent output via the
swapaxesarray method:>>> a.swapaxes(1, 3).shape (2, 5, 4, 3)
Equivalent output via
transpose:>>> a.transpose(0, 3, 2, 1).shape (2, 5, 4, 3)
- scico.numpy.take(a, indices, axis=None, out=None, mode=None, unique_indices=False, indices_are_sorted=False, fill_value=None)¶
Take elements from an array.
JAX implementation of
numpy.take, implemented in terms ofjax.lax.gather. JAX’s behavior differs from NumPy in the case of out-of-bound indices; see themodeparameter below.- Parameters:
a (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – array from which to take values.indices (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – N-dimensional array of integer indices of values to take from the array.axis (
int|None) – the axis along which to take values. If not specified, the array will be flattened before indexing is applied.mode (
str|None) – Out-of-bounds indexing mode, either "fill" or "clip". The default mode="fill" returns invalid values (e.g. NaN) for out-of-bounds indices; the fill_value argument gives control over this value. For more discussion of mode options, see jax.numpy.ndarray.at.fill_value (
Union[bool,number,bool,int,float,complex,None]) – The fill value to return for out-of-bounds slices when mode is ‘fill’. Ignored otherwise. Defaults to NaN for inexact types, the largest negative value for signed types, the largest positive value for unsigned types, and True for booleans.unique_indices (
bool) – If True, the implementation will assume that the indices are unique after normalization of negative indices, which lets the compiler emit more efficient code during the backward pass. If set to True and normalized indices are not unique, the result is implementation-defined and may be non-deterministic.indices_are_sorted (
bool) – If True, the implementation will assume that the indices are sorted in ascending order after normalization of negative indices, which can lead to more efficient execution on some backends. If set to True and normalized indices are not sorted, the output is implementation-defined.
- Return type:
- Returns:
Array of values extracted from
a.
See also
jax.numpy.ndarray.at: take values via indexing syntax.jax.numpy.take_along_axis: take values along an axis
Examples
>>> x = jnp.array([[1., 2., 3.], ... [4., 5., 6.]]) >>> indices = jnp.array([2, 0])
Passing no axis results in indexing into the flattened array:
>>> jnp.take(x, indices) Array([3., 1.], dtype=float32) >>> x.ravel()[indices] # equivalent indexing syntax Array([3., 1.], dtype=float32)
Passing an axis results in applying the index to every subarray along the axis:
>>> jnp.take(x, indices, axis=1) Array([[3., 1.], [6., 4.]], dtype=float32) >>> x[:, indices] # equivalent indexing syntax Array([[3., 1.], [6., 4.]], dtype=float32)
Out-of-bound indices fill with invalid values. For float inputs, this is NaN:
>>> jnp.take(x, indices, axis=0) Array([[nan, nan, nan], [ 1., 2., 3.]], dtype=float32) >>> x.at[indices].get(mode='fill', fill_value=jnp.nan) # equivalent indexing syntax Array([[nan, nan, nan], [ 1., 2., 3.]], dtype=float32)
This default out-of-bound behavior can be adjusted using the
modeparameter, for example, we can instead clip to the last valid value:>>> jnp.take(x, indices, axis=0, mode='clip') Array([[4., 5., 6.], [1., 2., 3.]], dtype=float32) >>> x.at[indices].get(mode='clip') # equivalent indexing syntax Array([[4., 5., 6.], [1., 2., 3.]], dtype=float32)
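The `fill_value` argument offers a third option between NaN and clipping. The following sketch, assuming `jnp` is `jax.numpy`, compares the three behaviors on a 1D array with one out-of-bounds index:

```python
import jax.numpy as jnp

x = jnp.arange(5.0)      # [0. 1. 2. 3. 4.]
idx = jnp.array([1, 7])  # index 7 is out of bounds

default_fill = jnp.take(x, idx)                 # out-of-bounds entry becomes NaN
custom_fill = jnp.take(x, idx, fill_value=0.0)  # out-of-bounds entry becomes 0
clipped = jnp.take(x, idx, mode="clip")         # index 7 is clipped to 4

print(default_fill)  # [ 1. nan]
print(custom_fill)   # [1. 0.]
print(clipped)       # [1. 4.]
```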
- scico.numpy.tan(x, /)¶
Compute a trigonometric tangent of each element of input.
JAX implementation of
numpy.tan.- Parameters:
x (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – scalar or array. Angle in radians.- Return type:
- Returns:
An array containing the tangent of each element in
x, promotes to inexact dtype.
See also
jax.numpy.sin: Computes a trigonometric sine of each element of input.jax.numpy.cos: Computes a trigonometric cosine of each element of input.jax.numpy.arctanandjax.numpy.atan: Computes the inverse of trigonometric tangent of each element of input.
Examples
>>> pi = jnp.pi >>> x = jnp.array([0, pi/6, pi/4, 3*pi/4, 5*pi/6]) >>> with jnp.printoptions(precision=3, suppress=True): ... print(jnp.tan(x)) [ 0. 0.577 1. -1. -0.577]
- scico.numpy.tanh(x, /)¶
Calculate element-wise hyperbolic tangent of input.
JAX implementation of numpy.tanh.
The hyperbolic tangent is defined by:
\[\tanh(x) = \frac{\sinh(x)}{\cosh(x)} = \frac{e^x - e^{-x}}{e^x + e^{-x}}\]
- Parameters:
x (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – input array or scalar.- Return type:
- Returns:
An array containing the hyperbolic tangent of each element of
x, promoting to inexact dtype.
Note
jnp.tanh is equivalent to computing -1j * jnp.tan(1j * x).
See also
jax.numpy.sinh: Computes the element-wise hyperbolic sine of the input.jax.numpy.cosh: Computes the element-wise hyperbolic cosine of the input.jax.numpy.arctanh: Computes the element-wise inverse of hyperbolic tangent of the input.
Examples
>>> x = jnp.array([[-1, 0, 1], ... [3, -2, 5]]) >>> with jnp.printoptions(precision=3, suppress=True): ... jnp.tanh(x) Array([[-0.762, 0. , 0.762], [ 0.995, -0.964, 1. ]], dtype=float32) >>> with jnp.printoptions(precision=3, suppress=True): ... -1j * jnp.tan(1j * x) Array([[-0.762+0.j, 0. -0.j, 0.762-0.j], [ 0.995-0.j, -0.964+0.j, 1. -0.j]], dtype=complex64, weak_type=True)
For complex-valued input:
>>> with jnp.printoptions(precision=3, suppress=True): ... jnp.tanh(2-5j) Array(1.031+0.021j, dtype=complex64, weak_type=True) >>> with jnp.printoptions(precision=3, suppress=True): ... -1j * jnp.tan(1j * (2-5j)) Array(1.031+0.021j, dtype=complex64, weak_type=True)
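As a quick check of the exponential definition given for tanh, the following sketch evaluates the right-hand side explicitly and compares it with jnp.tanh. The sample points and tolerances are arbitrary choices, and jax.numpy is used directly on the assumption that scico.numpy.tanh behaves the same for plain JAX arrays.

```python
import jax.numpy as jnp

x = jnp.array([-1.0, 0.0, 0.5, 2.0])

# Right-hand side of the definition: (e^x - e^-x) / (e^x + e^-x).
by_definition = (jnp.exp(x) - jnp.exp(-x)) / (jnp.exp(x) + jnp.exp(-x))

matches = bool(jnp.allclose(jnp.tanh(x), by_definition, rtol=1e-4, atol=1e-6))
print(matches)
```

The explicit formula is shown for illustration only: for large |x| in float32 the exponentials overflow to inf and the quotient becomes NaN, whereas jnp.tanh saturates at ±1.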
- scico.numpy.tensordot(a, b, axes=2, *, precision=None, preferred_element_type=None, out_sharding=None)¶
Compute the tensor dot product of two N-dimensional arrays.
JAX implementation of numpy.tensordot.
- Parameters:
a (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – N-dimensional arrayb (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – M-dimensional arrayaxes (
int|Sequence[int] |Sequence[Sequence[int]]) – integer or tuple of sequences of integers. If an integer k, then sum over the last k axes ofaand the first k axes ofb, in order. If a tuple, thenaxes[0]specifies the axes ofaandaxes[1]specifies the axes ofb.precision (
Union[None,str,Precision,tuple[str,str],tuple[Precision,Precision],DotAlgorithm,DotAlgorithmPreset]) – eitherNone(default), which means the default precision for the backend, aPrecisionenum value (Precision.DEFAULT,Precision.HIGHorPrecision.HIGHEST) or a tuple of two such values indicating precision ofaandb.preferred_element_type (
Union[str,type[Any],dtype,SupportsDType,None]) – eitherNone(default), which means the default accumulation type for the input types, or a datatype, indicating to accumulate results to and return a result with that datatype.
- Return type:
- Returns:
array containing the tensor dot product of the inputs
See also
jax.numpy.einsum: NumPy API for more general tensor contractions.jax.lax.dot_general: XLA API for more general tensor contractions.
Examples
>>> x1 = jnp.arange(24.).reshape(2, 3, 4) >>> x2 = jnp.ones((3, 4, 5)) >>> jnp.tensordot(x1, x2) Array([[ 66., 66., 66., 66., 66.], [210., 210., 210., 210., 210.]], dtype=float32)
Equivalent result when specifying the axes as explicit sequences:
>>> jnp.tensordot(x1, x2, axes=([1, 2], [0, 1])) Array([[ 66., 66., 66., 66., 66.], [210., 210., 210., 210., 210.]], dtype=float32)
Equivalent result via
einsum:>>> jnp.einsum('ijk,jkm->im', x1, x2) Array([[ 66., 66., 66., 66., 66.], [210., 210., 210., 210., 210.]], dtype=float32)
Setting
axes=1for two-dimensional inputs is equivalent to a matrix multiplication:>>> x1 = jnp.array([[1, 2], ... [3, 4]]) >>> x2 = jnp.array([[1, 2, 3], ... [4, 5, 6]]) >>> jnp.linalg.tensordot(x1, x2, axes=1) Array([[ 9, 12, 15], [19, 26, 33]], dtype=int32) >>> x1 @ x2 Array([[ 9, 12, 15], [19, 26, 33]], dtype=int32)
Setting
axes=0for one-dimensional inputs is equivalent toouter:>>> x1 = jnp.array([1, 2]) >>> x2 = jnp.array([1, 2, 3]) >>> jnp.linalg.tensordot(x1, x2, axes=0) Array([[1, 2, 3], [2, 4, 6]], dtype=int32) >>> jnp.outer(x1, x2) Array([[1, 2, 3], [2, 4, 6]], dtype=int32)
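When the axes to be contracted are not the trailing axes of a and the leading axes of b, the tuple form of axes pairs them explicitly. The sketch below is a minimal illustration with made-up shapes; it uses jax.numpy directly, assuming scico.numpy.tensordot gives the same result for plain JAX arrays.

```python
import jax.numpy as jnp

a = jnp.arange(24.).reshape(2, 3, 4)
b = jnp.arange(60.).reshape(4, 3, 5)

# Contract axis 1 of a with axis 1 of b, and axis 2 of a with axis 0 of b.
out = jnp.tensordot(a, b, axes=([1, 2], [1, 0]))

# The same contraction written with explicit index labels.
ref = jnp.einsum('ijk,kjm->im', a, b)

print(out.shape)  # (2, 5)
print(bool(jnp.allclose(out, ref)))
```

The order within each axes list matters: axes[0][i] of a is always paired with axes[1][i] of b, so the paired axes must have equal lengths.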
- scico.numpy.tile(A, reps)¶
Construct an array by repeating A along specified dimensions.
JAX implementation of numpy.tile.
If A is an array of shape (d1, d2, ..., dn) and reps is a sequence of n integers, the resulting array will have a shape of (reps[0] * d1, reps[1] * d2, ..., reps[n - 1] * dn), with A tiled along each dimension.
- Parameters:
- Return type:
- Returns:
a new array where the input array has been repeated according to
reps.
See also
jax.numpy.repeat: Construct an array from repeated elements.jax.numpy.broadcast_to: Broadcast an array to a specified shape.
Examples
>>> arr = jnp.array([1, 2]) >>> jnp.tile(arr, 2) Array([1, 2, 1, 2], dtype=int32) >>> arr = jnp.array([[1, 2], ... [3, 4,]]) >>> jnp.tile(arr, (2, 1)) Array([[1, 2], [3, 4], [1, 2], [3, 4]], dtype=int32)
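The shape rule stated for tile can be checked directly, including the NumPy-style case where reps has more entries than the array has dimensions. This is a small sketch with illustrative arrays, using jax.numpy on the assumption that scico.numpy.tile matches it for plain JAX arrays.

```python
import jax.numpy as jnp

arr = jnp.array([[1, 2, 3],
                 [4, 5, 6]])       # shape (2, 3)

tiled = jnp.tile(arr, (2, 3))
print(tiled.shape)                 # (2 * 2, 3 * 3) == (4, 9)

# When reps is longer than arr.ndim, the array is first treated as if
# axes of length 1 were prepended: shape (2,) acts as shape (1, 2).
row = jnp.array([1, 2])
grid = jnp.tile(row, (2, 2))
print(grid)
```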
- scico.numpy.trace(a, offset=0, axis1=0, axis2=1, dtype=None, out=None)¶
Calculate sum of the diagonal of input along the given axes.
JAX implementation of
numpy.trace.- Parameters:
a (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – input array. Must havea.ndim >= 2.offset (
Union[int,Array,ndarray,bool,number,bool,float,complex]) – optional, int, default=0. Diagonal offset from the main diagonal. Can be positive or negative.axis1 (
int) – optional, default=0. The first axis along which to take the sum of diagonal. Must be a static integer value.axis2 (
int) – optional, default=1. The second axis along which to take the sum of diagonal. Must be a static integer value.dtype (
Union[str,type[Any],dtype,SupportsDType,None]) – optional. The dtype of the output array. Should be provided as static argument in JIT compilation.out (
None) – Not used by JAX.
- Return type:
- Returns:
An array of dimension a.ndim - 2 containing the sum of the diagonal elements along axes (axis1, axis2).
See also
jax.numpy.diag: Returns the specified diagonal or constructs a diagonal arrayjax.numpy.diagonal: Returns the specified diagonal of an array.jax.numpy.diagflat: Returns a 2-D array with the flattened input array laid out on the diagonal.
Examples
>>> x = jnp.arange(1, 9).reshape(2, 2, 2) >>> x Array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=int32) >>> jnp.trace(x) Array([ 8, 10], dtype=int32) >>> jnp.trace(x, offset=1) Array([3, 4], dtype=int32) >>> jnp.trace(x, axis1=1, axis2=2) Array([ 5, 13], dtype=int32) >>> jnp.trace(x, offset=1, axis1=1, axis2=2) Array([2, 6], dtype=int32)
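The offset argument can be understood by comparing trace with an explicit sum over jax.numpy.diagonal. The following is a sketch on a small illustrative matrix, assuming scico.numpy.trace behaves like jax.numpy.trace for plain JAX arrays.

```python
import jax.numpy as jnp

m = jnp.arange(9).reshape(3, 3)
# [[0 1 2]
#  [3 4 5]
#  [6 7 8]]

traces = []
for k in (-1, 0, 1):
    via_trace = int(jnp.trace(m, offset=k))
    via_diagonal = int(jnp.diagonal(m, offset=k).sum())
    traces.append((k, via_trace, via_diagonal))

print(traces)  # [(-1, 10, 10), (0, 12, 12), (1, 6, 6)]
```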
- scico.numpy.transpose(a, axes=None)¶
Return a transposed version of an N-dimensional array.
JAX implementation of numpy.transpose, implemented in terms of jax.lax.transpose.
- Parameters:
- Return type:
- Returns:
transposed copy of the array.
See also
jax.Array.transpose: equivalent function via anArraymethod.jax.Array.T: equivalent function via anArrayproperty.jax.numpy.matrix_transpose: transpose the last two axes of an array. This is suitable for working with batched 2D matrices.jax.numpy.swapaxes: swap any two axes in an array.jax.numpy.moveaxis: move an axis to another position in the array.
Note
Unlike numpy.transpose, jax.numpy.transpose will return a copy rather than a view of the input array. However, under JIT, the compiler will optimize away such copies when possible, so this doesn’t have performance impacts in practice.
Examples
For a 1D array, the transpose is the identity:
>>> x = jnp.array([1, 2, 3, 4]) >>> jnp.transpose(x) Array([1, 2, 3, 4], dtype=int32)
For a 2D array, the transpose is a matrix transpose:
>>> x = jnp.array([[1, 2], ... [3, 4]]) >>> jnp.transpose(x) Array([[1, 3], [2, 4]], dtype=int32)
For an N-dimensional array, the transpose reverses the order of the axes:
>>> x = jnp.zeros(shape=(3, 4, 5)) >>> jnp.transpose(x).shape (5, 4, 3)
The
axesargument can be specified to change this default behavior:>>> jnp.transpose(x, (0, 2, 1)).shape (3, 5, 4)
Since swapping the last two axes is a common operation, it can be done via its own API,
jax.numpy.matrix_transpose:>>> jnp.matrix_transpose(x).shape (3, 5, 4)
For convenience, transposes may also be performed using the
jax.Array.transposemethod or thejax.Array.Tproperty:>>> x = jnp.array([[1, 2], ... [3, 4]]) >>> x.transpose() Array([[1, 3], [2, 4]], dtype=int32) >>> x.T Array([[1, 3], [2, 4]], dtype=int32)
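For a general permutation it can help to see how individual elements move. In the sketch below (illustrative shapes only, using jax.numpy on the assumption that scico.numpy.transpose matches it for plain JAX arrays), output axis i is taken from input axis axes[i].

```python
import jax.numpy as jnp

x = jnp.arange(60).reshape(3, 4, 5)

# Output axis i is input axis axes[i], so the shape becomes (5, 3, 4).
y = jnp.transpose(x, (2, 0, 1))
print(y.shape)

# Element correspondence for this permutation: y[k, i, j] == x[i, j, k].
print(int(y[4, 2, 3]), int(x[2, 3, 4]))  # 59 59
```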
- scico.numpy.tri(N, M=None, k=0, dtype=None)¶
Return an array with ones on and below the diagonal and zeros elsewhere.
JAX implementation of numpy.tri.
- Parameters:
N (
int) – int. Dimension of the rows of the returned array.M (
int|None) – optional, int. Dimension of the columns of the returned array. If not specified, thenM = N.k (
int) – optional, int, default=0. Specifies the sub-diagonal on and below which the array is filled with ones.k=0refers to main diagonal,k<0refers to sub-diagonal below the main diagonal andk>0refers to sub-diagonal above the main diagonal.dtype (
Union[str,type[Any],dtype,SupportsDType,None]) – optional, data type of the returned array. The default type is float.
- Return type:
- Returns:
An array of shape (N, M) in which the elements on and below the sub-diagonal specified by k are set to one and all other elements are zero.
See also
jax.numpy.tril: Returns a lower triangle of an array.jax.numpy.triu: Returns an upper triangle of an array.
Examples
>>> jnp.tri(3) Array([[1., 0., 0.], [1., 1., 0.], [1., 1., 1.]], dtype=float32)
When
Mis not equal toN:>>> jnp.tri(3, 4) Array([[1., 0., 0., 0.], [1., 1., 0., 0.], [1., 1., 1., 0.]], dtype=float32)
When
k>0:>>> jnp.tri(3, k=1) Array([[1., 1., 0.], [1., 1., 1.], [1., 1., 1.]], dtype=float32)
When
k<0:>>> jnp.tri(3, 4, k=-1) Array([[0., 0., 0., 0.], [1., 0., 0., 0.], [1., 1., 0., 0.]], dtype=float32)
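One way to see the structure of the matrix returned by tri is to use it as a running-sum operator. This is an illustrative sketch rather than a recommended way to compute cumulative sums; it uses jax.numpy directly, assuming scico.numpy.tri matches it.

```python
import jax.numpy as jnp

v = jnp.array([1., 2., 3., 4.])

# Row i of tri(4) has ones in columns 0..i, so the product sums v[0..i].
running = jnp.tri(4) @ v

print(running)
print(jnp.cumsum(v))  # same values, computed without forming the matrix
```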
- scico.numpy.tril_indices(n, k=0, m=None)¶
Return the indices of lower triangle of an array of size
(n, m).JAX implementation of
numpy.tril_indices.- Parameters:
n (
Union[int,Any]) – int. Number of rows of the array for which the indices are returned.k (
Union[int,Any]) – optional, int, default=0. Specifies the sub-diagonal on and below which the indices of lower triangle are returned.k=0refers to main diagonal,k<0refers to sub-diagonal below the main diagonal andk>0refers to sub-diagonal above the main diagonal.m (
Union[int,Any,None]) – optional, int. Number of columns of the array for which the indices are returned. If not specified, thenm = n.
- Return type:
- Returns:
A tuple of two arrays containing the indices of the lower triangle, one along each axis.
See also
jax.numpy.triu_indices: Returns the indices of upper triangle of an array of size(n, m).jax.numpy.triu_indices_from: Returns the indices of upper triangle of a given array.jax.numpy.tril_indices_from: Returns the indices of lower triangle of a given array.
Examples
If only n is provided in input, the indices of the lower triangle of an array of size (n, n) are returned.
>>> jnp.tril_indices(3) (Array([0, 1, 1, 2, 2, 2], dtype=int32), Array([0, 0, 1, 0, 1, 2], dtype=int32))
If both
nandmare provided in input, the indices of lower triangle of an(n, m)array are returned.>>> jnp.tril_indices(3, m=2) (Array([0, 1, 1, 2, 2], dtype=int32), Array([0, 0, 1, 0, 1], dtype=int32))
If
k = 1, the indices on and below the first sub-diagonal above the main diagonal are returned.>>> jnp.tril_indices(3, k=1) (Array([0, 0, 1, 1, 1, 2, 2, 2], dtype=int32), Array([0, 1, 0, 1, 2, 0, 1, 2], dtype=int32))
If
k = -1, the indices on and below the first sub-diagonal below the main diagonal are returned.>>> jnp.tril_indices(3, k=-1) (Array([1, 2, 2], dtype=int32), Array([0, 0, 1], dtype=int32))
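The returned index pair can be passed straight to the .at indexing syntax. The sketch below (illustrative, using jax.numpy on the assumption that scico.numpy.tril_indices matches it) writes ones at the lower-triangle positions and compares the result with jnp.tri.

```python
import jax.numpy as jnp

n = 3
rows, cols = jnp.tril_indices(n)

# Write ones at the lower-triangle positions of a zero matrix.
mask = jnp.zeros((n, n)).at[rows, cols].set(1.0)

print(mask)
print(bool(jnp.array_equal(mask, jnp.tri(n))))
```

An n-by-n lower triangle, diagonal included, has n * (n + 1) / 2 entries, which is the length of each returned index array.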
- scico.numpy.tril_indices_from(arr, k=0)¶
Return the indices of lower triangle of a given array.
JAX implementation of
numpy.tril_indices_from.- Parameters:
arr (
Union[Array,ndarray,bool,number,bool,int,float,complex,SupportsShape]) – input array. Must havearr.ndim == 2.k (
int) – optional, int, default=0. Specifies the sub-diagonal on and below which the indices of the lower triangle are returned. k=0 refers to the main diagonal, k<0 refers to a sub-diagonal below the main diagonal and k>0 refers to a sub-diagonal above the main diagonal.
- Return type:
- Returns:
A tuple of two arrays containing the indices of the lower triangle, one along each axis.
See also
jax.numpy.triu_indices_from: Returns the indices of upper triangle of a given array.jax.numpy.tril_indices: Returns the indices of lower triangle of an array of size(n, m).jax.numpy.tril: Returns a lower triangle of an array
Examples
>>> arr = jnp.array([[1, 2, 3], ... [4, 5, 6], ... [7, 8, 9]]) >>> jnp.tril_indices_from(arr) (Array([0, 1, 1, 2, 2, 2], dtype=int32), Array([0, 0, 1, 0, 1, 2], dtype=int32))
Elements indexed by
jnp.tril_indices_fromcorrespond to those in the output ofjnp.tril.>>> ind = jnp.tril_indices_from(arr) >>> arr[ind] Array([1, 4, 5, 7, 8, 9], dtype=int32) >>> jnp.tril(arr) Array([[1, 0, 0], [4, 5, 0], [7, 8, 9]], dtype=int32)
When
k > 0:>>> jnp.tril_indices_from(arr, k=1) (Array([0, 0, 1, 1, 1, 2, 2, 2], dtype=int32), Array([0, 1, 0, 1, 2, 0, 1, 2], dtype=int32))
When
k < 0:>>> jnp.tril_indices_from(arr, k=-1) (Array([1, 2, 2], dtype=int32), Array([0, 0, 1], dtype=int32))
- scico.numpy.trim_zeros(filt, trim='fb', axis=None)¶
Trim leading and/or trailing zeros of the input array.
JAX implementation of
numpy.trim_zeros.- Parameters:
filt (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – N-dimensional input array.trim (
str) –string, optional, default =
fb. Specifies from which end the input is trimmed.f- trims only the leading zeros.b- trims only the trailing zeros.fb- trims both leading and trailing zeros.
axis (
int|Sequence[int] |None) – optional axis or axes along which to trim. If not specified, trim along all axes of the array.
- Return type:
- Returns:
An array containing the trimmed input with same dtype as
filt.
Examples
One-dimensional input:
>>> x = jnp.array([0, 0, 2, 0, 1, 4, 3, 0, 0, 0]) >>> jnp.trim_zeros(x) Array([2, 0, 1, 4, 3], dtype=int32) >>> jnp.trim_zeros(x, trim='f') Array([2, 0, 1, 4, 3, 0, 0, 0], dtype=int32) >>> jnp.trim_zeros(x, trim='b') Array([0, 0, 2, 0, 1, 4, 3], dtype=int32)
Two-dimensional input:
>>> x = jnp.zeros((4, 5)).at[1:3, 1:4].set(1) >>> x Array([[0., 0., 0., 0., 0.], [0., 1., 1., 1., 0.], [0., 1., 1., 1., 0.], [0., 0., 0., 0., 0.]], dtype=float32) >>> jnp.trim_zeros(x) Array([[1., 1., 1.], [1., 1., 1.]], dtype=float32) >>> jnp.trim_zeros(x, trim='f') Array([[1., 1., 1., 0.], [1., 1., 1., 0.], [0., 0., 0., 0.]], dtype=float32) >>> jnp.trim_zeros(x, axis=0) Array([[0., 1., 1., 1., 0.], [0., 1., 1., 1., 0.]], dtype=float32) >>> jnp.trim_zeros(x, axis=1) Array([[0., 0., 0.], [1., 1., 1.], [1., 1., 1.], [0., 0., 0.]], dtype=float32)
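A typical use of one-sided trimming is stripping zero padding from a coefficient vector. The example below is a sketch with a made-up polynomial, using jax.numpy on the assumption that scico.numpy.trim_zeros matches it; only trailing zeros are removed, so interior zeros survive.

```python
import jax.numpy as jnp

# Coefficients in increasing-degree order: 1 + 0*t + 3*t**2, zero-padded.
coeffs = jnp.array([1, 0, 3, 0, 0])

trimmed = jnp.trim_zeros(coeffs, trim='b')
degree = trimmed.size - 1

print(trimmed)  # the interior zero is kept
print(degree)
```

Because the output length depends on the data, this kind of call is best kept outside jit-compiled code.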
- scico.numpy.triu_indices(n, k=0, m=None)¶
Return the indices of upper triangle of an array of size
(n, m).JAX implementation of
numpy.triu_indices.- Parameters:
n (
Union[int,Any]) – int. Number of rows of the array for which the indices are returned.k (
Union[int,Any]) – optional, int, default=0. Specifies the sub-diagonal on and above which the indices of upper triangle are returned.k=0refers to main diagonal,k<0refers to sub-diagonal below the main diagonal andk>0refers to sub-diagonal above the main diagonal.m (
Union[int,Any,None]) – optional, int. Number of columns of the array for which the indices are returned. If not specified, thenm = n.
- Return type:
- Returns:
A tuple of two arrays containing the indices of the upper triangle, one along each axis.
See also
jax.numpy.tril_indices: Returns the indices of lower triangle of an array of size(n, m).jax.numpy.triu_indices_from: Returns the indices of upper triangle of a given array.jax.numpy.tril_indices_from: Returns the indices of lower triangle of a given array.
Examples
If only n is provided in input, the indices of the upper triangle of an array of size (n, n) are returned.
>>> jnp.triu_indices(3) (Array([0, 0, 0, 1, 1, 2], dtype=int32), Array([0, 1, 2, 1, 2, 2], dtype=int32))
If both
nandmare provided in input, the indices of upper triangle of an(n, m)array are returned.>>> jnp.triu_indices(3, m=2) (Array([0, 0, 1], dtype=int32), Array([0, 1, 1], dtype=int32))
If
k = 1, the indices on and above the first sub-diagonal above the main diagonal are returned.>>> jnp.triu_indices(3, k=1) (Array([0, 0, 1], dtype=int32), Array([1, 2, 2], dtype=int32))
If
k = -1, the indices on and above the first sub-diagonal below the main diagonal are returned.>>> jnp.triu_indices(3, k=-1) (Array([0, 0, 0, 1, 1, 1, 2, 2], dtype=int32), Array([0, 1, 2, 0, 1, 2, 1, 2], dtype=int32))
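With k=1 the diagonal is excluded, so the indices enumerate each unordered pair (i, j) with i < j exactly once. The sketch below uses that to compute pairwise gaps between a few illustrative points; it relies on jax.numpy directly, assuming scico.numpy.triu_indices matches it.

```python
import jax.numpy as jnp

points = jnp.array([0.0, 1.0, 3.0, 6.0])

# k=1 skips the diagonal, leaving each pair (i, j) with i < j once.
i, j = jnp.triu_indices(points.shape[0], k=1)
gaps = jnp.abs(points[i] - points[j])

pairs = list(zip(i.tolist(), j.tolist()))
print(pairs)
print(gaps)
```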
- scico.numpy.triu_indices_from(arr, k=0)¶
Return the indices of upper triangle of a given array.
JAX implementation of
numpy.triu_indices_from.- Parameters:
arr (
Union[Array,ndarray,bool,number,bool,int,float,complex,SupportsShape]) – input array. Must havearr.ndim == 2.k (
int) – optional, int, default=0. Specifies the sub-diagonal on and above which the indices of upper triangle are returned.k=0refers to main diagonal,k<0refers to sub-diagonal below the main diagonal andk>0refers to sub-diagonal above the main diagonal.
- Return type:
- Returns:
A tuple of two arrays containing the indices of the upper triangle, one along each axis.
See also
jax.numpy.tril_indices_from: Returns the indices of lower triangle of a given array.jax.numpy.triu_indices: Returns the indices of upper triangle of an array of size(n, m).jax.numpy.triu: Return an upper triangle of an array.
Examples
>>> arr = jnp.array([[1, 2, 3], ... [4, 5, 6], ... [7, 8, 9]]) >>> jnp.triu_indices_from(arr) (Array([0, 0, 0, 1, 1, 2], dtype=int32), Array([0, 1, 2, 1, 2, 2], dtype=int32))
Elements indexed by
jnp.triu_indices_fromcorrespond to those in the output ofjnp.triu.>>> ind = jnp.triu_indices_from(arr) >>> arr[ind] Array([1, 2, 3, 5, 6, 9], dtype=int32) >>> jnp.triu(arr) Array([[1, 2, 3], [0, 5, 6], [0, 0, 9]], dtype=int32)
When
k > 0:>>> jnp.triu_indices_from(arr, k=1) (Array([0, 0, 1], dtype=int32), Array([1, 2, 2], dtype=int32))
When
k < 0:>>> jnp.triu_indices_from(arr, k=-1) (Array([0, 0, 0, 1, 1, 1, 2, 2], dtype=int32), Array([0, 1, 2, 0, 1, 2, 1, 2], dtype=int32))
- scico.numpy.true_divide(x1, x2, /)¶
Calculates the division of x1 by x2 element-wise.
JAX implementation of
numpy.true_divide.- Parameters:
- Return type:
- Returns:
An array containing the elementwise quotients, will always use floating point division.
Examples
>>> x1 = jnp.array([3, 4, 5]) >>> x2 = 2 >>> jnp.true_divide(x1, x2) Array([1.5, 2. , 2.5], dtype=float32)
>>> x1 = 24 >>> x2 = jnp.array([3, 4, 6j]) >>> jnp.true_divide(x1, x2) Array([8.+0.j, 6.+0.j, 0.-4.j], dtype=complex64)
>>> x1 = jnp.array([1j, 9+5j, -4+2j]) >>> x2 = 3j >>> jnp.true_divide(x1, x2) Array([0.33333334+0.j , 1.6666666 -3.j , 0.6666667 +1.3333334j], dtype=complex64)
See also
jax.numpy.floor_divide for integer division
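The difference between true division and floor division is easiest to see on integer inputs with a negative entry. This is a minimal sketch with illustrative values, using jax.numpy on the assumption that scico.numpy.true_divide matches it for plain JAX arrays.

```python
import jax.numpy as jnp

x1 = jnp.array([7, -7, 8])
x2 = 2

true_q = jnp.true_divide(x1, x2)    # floating-point quotients
floor_q = jnp.floor_divide(x1, x2)  # rounds toward negative infinity

print(true_q)
print(floor_q)
```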
- scico.numpy.trunc(x)¶
Round input to the nearest integer towards zero.
JAX implementation of
numpy.trunc.- Parameters:
x (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – input array or scalar.- Return type:
- Returns:
An array with same shape and dtype as
xcontaining the rounded values.
See also
jax.numpy.fix: Rounds the input to the nearest integer towards zero.jax.numpy.ceil: Rounds the input up to the nearest integer.jax.numpy.floor: Rounds the input down to the nearest integer.
Examples
>>> key = jax.random.key(42) >>> x = jax.random.uniform(key, (3, 3), minval=-10, maxval=10) >>> with jnp.printoptions(precision=2, suppress=True): ... print(x) [[-0.23 3.6 2.33] [ 1.22 -0.99 1.72] [-8.5 5.5 3.98]] >>> jnp.trunc(x) Array([[-0., 3., 2.], [ 1., -0., 1.], [-8., 5., 3.]], dtype=float32)
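For a deterministic comparison with the related rounding functions, the following sketch applies trunc, floor, and ceil to the same illustrative inputs. It uses jax.numpy directly, assuming scico.numpy.trunc behaves the same for plain JAX arrays.

```python
import jax.numpy as jnp

x = jnp.array([-1.5, -0.5, 0.5, 1.5])

truncated = jnp.trunc(x)  # toward zero
floored = jnp.floor(x)    # toward negative infinity
ceiled = jnp.ceil(x)      # toward positive infinity

print(truncated)
print(floored)
print(ceiled)
```

trunc agrees with ceil for negative inputs and with floor for non-negative inputs.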
- scico.numpy.union1d(ar1, ar2, *, size=None, fill_value=None)¶
Compute the set union of two 1D arrays.
JAX implementation of
numpy.union1d.Because the size of the output of
union1dis data-dependent, the function is not typically compatible withjitand other JAX transformations. The JAX version adds the optionalsizeargument which must be specified statically forjnp.union1dto be used in such contexts.- Parameters:
ar1 (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – first array of elements to be unioned.ar2 (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – second array of elements to be unionedsize (
int|None) – if specified, return only the firstsizesorted elements. If there are fewer elements thansizeindicates, the return value will be padded withfill_value.fill_value (
Union[Array,ndarray,bool,number,bool,int,float,complex,None]) – when size is specified and there are fewer than the indicated number of elements, fill the remaining entries with fill_value. Defaults to the minimum value.
- Return type:
- Returns:
an array containing the union of elements in the input arrays.
See also
jax.numpy.intersect1d: the set intersection of two 1D arrays.jax.numpy.setxor1d: the set XOR of two 1D arrays.jax.numpy.setdiff1d: the set difference of two 1D arrays.
Examples
Computing the union of two arrays:
>>> ar1 = jnp.array([1, 2, 3, 4]) >>> ar2 = jnp.array([3, 4, 5, 6]) >>> jnp.union1d(ar1, ar2) Array([1, 2, 3, 4, 5, 6], dtype=int32)
Because the output shape is dynamic, this will fail under
jitand other transformations:>>> jax.jit(jnp.union1d)(ar1, ar2) Traceback (most recent call last): ... ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[4]. The error occurred while tracing the function union1d at /Users/vanderplas/github/jax-ml/jax/jax/_src/numpy/setops.py:101 for jit. This concrete value was not available in Python because it depends on the value of the argument ar1.
In order to ensure statically-known output shapes, you can pass a static
sizeargument:>>> jit_union1d = jax.jit(jnp.union1d, static_argnames=['size']) >>> jit_union1d(ar1, ar2, size=6) Array([1, 2, 3, 4, 5, 6], dtype=int32)
If
sizeis too small, the union is truncated:>>> jit_union1d(ar1, ar2, size=4) Array([1, 2, 3, 4], dtype=int32)
If
sizeis too large, then the output is padded withfill_value:>>> jit_union1d(ar1, ar2, size=8, fill_value=0) Array([1, 2, 3, 4, 5, 6, 0, 0], dtype=int32)
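Outside of jit, the union can be understood as the sorted unique values of the concatenated inputs. The sketch below checks that equivalence on illustrative, unsorted arrays with duplicates; it uses jax.numpy directly, assuming scico.numpy.union1d matches it.

```python
import jax.numpy as jnp

ar1 = jnp.array([4, 1, 3, 1])
ar2 = jnp.array([6, 3, 5])

union = jnp.union1d(ar1, ar2)
via_unique = jnp.unique(jnp.concatenate([ar1, ar2]))

print(union)
print(bool(jnp.array_equal(union, via_unique)))
```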
- scico.numpy.unique(ar, return_index=False, return_inverse=False, return_counts=False, axis=None, *, equal_nan=True, size=None, fill_value=None, sorted=True)¶
Return the unique values from an array.
JAX implementation of
numpy.unique.Because the size of the output of
uniqueis data-dependent, the function is not typically compatible withjitand other JAX transformations. The JAX version adds the optionalsizeargument which must be specified statically forjnp.uniqueto be used in such contexts.- Parameters:
ar (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – N-dimensional array from which unique values will be extracted.return_index (
bool) – if True, also return the indices inarwhere each value occursreturn_inverse (
bool) – if True, also return the indices that can be used to reconstructarfrom the unique values.return_counts (
bool) – if True, also return the number of occurrences of each unique value.axis (
int|None) – if specified, compute unique values along the specified axis. If None (default), then flattenarbefore computing the unique values.equal_nan (
bool) – if True, consider NaN values equivalent when determining uniqueness.size (
int|None) – if specified, return only the firstsizesorted unique elements. If there are fewer unique elements thansizeindicates, the return value will be padded withfill_value.fill_value (
Union[Array,ndarray,bool,number,bool,int,float,complex,None]) – when size is specified and there are fewer than the indicated number of elements, fill the remaining entries with fill_value. Defaults to the minimum unique value.
sorted (
bool) – unused by JAX.
- Returns:
An array or tuple of arrays, depending on the values of return_index, return_inverse, and return_counts. Returned values are
unique_values:
if axis is None, a 1D array of length n_unique. If axis is specified, shape is (*ar.shape[:axis], n_unique, *ar.shape[axis + 1:]).
unique_index:(returned only if return_index is True) An array of shape
(n_unique,). Contains the indices of the first occurrence of each unique value inar. For 1D inputs,ar[unique_index]is equivalent tounique_values.
unique_inverse:(returned only if return_inverse is True) An array of shape
(ar.size,)ifaxisis None, or of shape(ar.shape[axis],)ifaxisis specified. Contains the indices withinunique_valuesof each value inar. For 1D inputs,unique_values[unique_inverse]is equivalent toar.
unique_counts:(returned only if return_counts is True) An array of shape
(n_unique,). Contains the number of occurrences of each unique value inar.
See also
jax.numpy.unique_counts: shortcut tounique(arr, return_counts=True).jax.numpy.unique_inverse: shortcut tounique(arr, return_inverse=True).jax.numpy.unique_all: shortcut touniquewith all return values.jax.numpy.unique_values: likeunique, but no optional return values.
Examples
>>> x = jnp.array([3, 4, 1, 3, 1]) >>> jnp.unique(x) Array([1, 3, 4], dtype=int32)
JIT compilation & the size argument
If you try this under
jitor another transformation, you will get an error because the output shape is dynamic:>>> jax.jit(jnp.unique)(x) Traceback (most recent call last): ... jax.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[5]. The error arose for the first argument of jnp.unique(). To make jnp.unique() compatible with JIT and other transforms, you can specify a concrete value for the size argument, which will determine the output size.
The issue is that the output of transformed functions must have static shapes. In order to make this work, you can pass a static
sizeparameter:>>> jit_unique = jax.jit(jnp.unique, static_argnames=['size']) >>> jit_unique(x, size=3) Array([1, 3, 4], dtype=int32)
If your static size is smaller than the true number of unique values, they will be truncated.
>>> jit_unique(x, size=2) Array([1, 3], dtype=int32)
If the static size is larger than the true number of unique values, they will be padded with
fill_value, which defaults to the minimum unique value:>>> jit_unique(x, size=5) Array([1, 3, 4, 1, 1], dtype=int32) >>> jit_unique(x, size=5, fill_value=0) Array([1, 3, 4, 0, 0], dtype=int32)
Multi-dimensional unique values
If you pass a multi-dimensional array to
unique, it will be flattened by default:>>> M = jnp.array([[1, 2], ... [2, 3], ... [1, 2]]) >>> jnp.unique(M) Array([1, 2, 3], dtype=int32)
If you pass an
axiskeyword, you can find unique slices of the array along that axis:>>> jnp.unique(M, axis=0) Array([[1, 2], [2, 3]], dtype=int32)
Returning indices
If you set
return_index=True, thenuniquereturns the indices of the first occurrence of each unique value:>>> x = jnp.array([3, 4, 1, 3, 1]) >>> values, indices = jnp.unique(x, return_index=True) >>> print(values) [1 3 4] >>> print(indices) [2 0 1] >>> jnp.all(values == x[indices]) Array(True, dtype=bool)
In multiple dimensions, the unique values can be extracted with
jax.numpy.takeevaluated along the specified axis:>>> values, indices = jnp.unique(M, axis=0, return_index=True) >>> jnp.all(values == jnp.take(M, indices, axis=0)) Array(True, dtype=bool)
Returning inverse
If you set
return_inverse=True, thenuniquereturns the indices within the unique values for every entry in the input array:>>> x = jnp.array([3, 4, 1, 3, 1]) >>> values, inverse = jnp.unique(x, return_inverse=True) >>> print(values) [1 3 4] >>> print(inverse) [1 2 0 1 0] >>> jnp.all(values[inverse] == x) Array(True, dtype=bool)
In multiple dimensions, the input can be reconstructed using
jax.numpy.take:>>> values, inverse = jnp.unique(M, axis=0, return_inverse=True) >>> jnp.all(jnp.take(values, inverse, axis=0) == M) Array(True, dtype=bool)
Returning counts
If you set
return_counts=True, thenuniquereturns the number of occurrences within the input for every unique value:>>> x = jnp.array([3, 4, 1, 3, 1]) >>> values, counts = jnp.unique(x, return_counts=True) >>> print(values) [1 3 4] >>> print(counts) [2 2 1]
For multi-dimensional arrays, this also returns a 1D array of counts indicating number of occurrences along the specified axis:
>>> values, counts = jnp.unique(M, axis=0, return_counts=True) >>> print(values) [[1 2] [2 3]] >>> print(counts) [2 1]
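The values and counts returned together make simple frequency queries straightforward. As an illustrative sketch (made-up data, jax.numpy used directly on the assumption that scico.numpy.unique matches it for plain JAX arrays), the most frequent value can be picked out with argmax:

```python
import jax.numpy as jnp

x = jnp.array([3, 4, 1, 3, 1, 3])

values, counts = jnp.unique(x, return_counts=True)

# argmax returns the first maximum, so ties resolve to the smallest value.
most_common = int(values[jnp.argmax(counts)])

print(values, counts)
print(most_common)
```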
- scico.numpy.unique_all(x, /, *, size=None, fill_value=None)¶
Return unique values from x, along with indices, inverse indices, and counts.
JAX implementation of numpy.unique_all; this is equivalent to calling jax.numpy.unique with return_index, return_inverse, return_counts, and equal_nan set to True.
Because the size of the output of unique_all is data-dependent, the function is not typically compatible with jit and other JAX transformations. The JAX version adds the optional size argument which must be specified statically for jnp.unique_all to be used in such contexts.
- Parameters:
x (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – N-dimensional array from which unique values will be extracted.size (
int|None) – if specified, return only the firstsizesorted unique elements. If there are fewer unique elements thansizeindicates, the return value will be padded withfill_value.fill_value (
Union[Array,ndarray,bool,number,bool,int,float,complex,None]) – when size is specified and there are fewer than the indicated number of elements, fill the remaining entries with fill_value. Defaults to the minimum unique value.
- Return type:
_UniqueAllResult- Returns:
A tuple
(values, indices, inverse_indices, counts), with the following properties –values:an array of shape
(n_unique,)containing the unique values fromx.
indices:An array of shape
(n_unique,). Contains the indices of the first occurrence of each unique value inx. For 1D inputs,x[indices]is equivalent tovalues.
inverse_indices:An array of shape
x.shape. Contains the indices withinvaluesof each value inx. For 1D inputs,values[inverse_indices]is equivalent tox.
counts:An array of shape
(n_unique,). Contains the number of occurrences of each unique value inx.
See also
jax.numpy.unique: general function for computing unique values.jax.numpy.unique_values: compute onlyvalues.jax.numpy.unique_counts: compute onlyvaluesandcounts.jax.numpy.unique_inverse: compute onlyvaluesandinverse.
Examples
Here we compute the unique values in a 1D array:
>>> x = jnp.array([3, 4, 1, 3, 1]) >>> result = jnp.unique_all(x)
The result is a
NamedTuplewith four named attributes. Thevaluesattribute contains the unique values from the array:>>> result.values Array([1, 3, 4], dtype=int32)
The
indicesattribute contains the indices of the uniquevalueswithin the input array:>>> result.indices Array([2, 0, 1], dtype=int32) >>> jnp.all(result.values == x[result.indices]) Array(True, dtype=bool)
The
inverse_indicesattribute contains the indices of the input withinvalues:>>> result.inverse_indices Array([1, 2, 0, 1, 0], dtype=int32) >>> jnp.all(x == result.values[result.inverse_indices]) Array(True, dtype=bool)
The
countsattribute contains the counts of each unique value in the input:>>> result.counts Array([2, 2, 1], dtype=int32)
For examples of the
sizeandfill_valuearguments, seejax.numpy.unique.
- scico.numpy.unique_counts(x, /, *, size=None, fill_value=None)¶
Return unique values from x, along with counts.
JAX implementation of numpy.unique_counts; this is equivalent to calling jax.numpy.unique with return_counts and equal_nan set to True.
Because the size of the output of unique_counts is data-dependent, the function is not typically compatible with jit and other JAX transformations. The JAX version adds the optional size argument which must be specified statically for jnp.unique_counts to be used in such contexts.
- Parameters:
x (
Union[Array,ndarray,bool,number,bool,int,float,complex]) – N-dimensional array from which unique values will be extracted.size (
int|None) – if specified, return only the firstsizesorted unique elements. If there are fewer unique elements thansizeindicates, the return value will be padded withfill_value.fill_value (
Union[Array,ndarray,bool,number,bool,int,float,complex,None]) – when size is specified and there are fewer than the indicated number of elements, fill the remaining entries with fill_value. Defaults to the minimum unique value.
- Return type:
_UniqueCountsResult- Returns:
A tuple
(values, counts), with the following properties –values:an array of shape
(n_unique,)containing the unique values fromx.
counts:An array of shape
(n_unique,). Contains the number of occurrences of each unique value inx.
See also
jax.numpy.unique: general function for computing unique values.
jax.numpy.unique_values: compute only values.
jax.numpy.unique_inverse: compute only values and inverse.
jax.numpy.unique_all: compute values, indices, inverse_indices, and counts.
Examples
Here we compute the unique values in a 1D array:
>>> x = jnp.array([3, 4, 1, 3, 1])
>>> result = jnp.unique_counts(x)
The result is a NamedTuple with two named attributes. The values attribute contains the unique values from the array:
>>> result.values
Array([1, 3, 4], dtype=int32)
The counts attribute contains the counts of each unique value in the input:
>>> result.counts
Array([2, 2, 1], dtype=int32)
For examples of the size and fill_value arguments, see jax.numpy.unique.
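The point about jit and a static size can be sketched as follows. This is a minimal illustration that calls jax.numpy directly (imported as jnp, as in the examples above); the function name padded_counts and the sentinel fill_value=-1 are hypothetical choices made for this sketch.

```python
import jax
import jax.numpy as jnp


@jax.jit
def padded_counts(x):
    # size is a plain Python int, so the output shape is known at trace time.
    # fill_value=-1 is an arbitrary sentinel chosen for this sketch.
    return jnp.unique_counts(x, size=4, fill_value=-1)


x = jnp.array([3, 4, 1, 3, 1])
result = padded_counts(x)
print(result.values)  # three unique values followed by one padding entry
print(result.counts)
```

Because the output length is fixed by size, downstream code typically needs to mask out the padded entries, for example by comparing values against the sentinel.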
- scico.numpy.unique_inverse(x, /, *, size=None, fill_value=None)¶
Return unique values from x, along with inverse indices.
JAX implementation of numpy.unique_inverse; this is equivalent to calling jax.numpy.unique with return_inverse and equal_nan set to True.
Because the size of the output of unique_inverse is data-dependent, the function is not typically compatible with jit and other JAX transformations. The JAX version adds the optional size argument, which must be specified statically for jnp.unique_inverse to be used in such contexts.
- Parameters:
x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – N-dimensional array from which unique values will be extracted.
size (int | None) – if specified, return only the first size sorted unique elements. If there are fewer unique elements than size indicates, the return value will be padded with fill_value.
fill_value (Union[Array, ndarray, bool, number, bool, int, float, complex, None]) – when size is specified and there are fewer than the indicated number of elements, fill the remaining entries with fill_value. Defaults to the minimum unique value.
- Return type:
_UniqueInverseResult
- Returns:
A tuple (values, inverse_indices), with the following properties:
values: an array of shape (n_unique,) containing the unique values from x.
inverse_indices: an array of shape x.shape containing the indices within values of each value in x. For 1D inputs, values[inverse_indices] is equivalent to x.
See also
jax.numpy.unique: general function for computing unique values.
jax.numpy.unique_values: compute only values.
jax.numpy.unique_counts: compute only values and counts.
jax.numpy.unique_all: compute values, indices, inverse_indices, and counts.
Examples
Here we compute the unique values in a 1D array:
>>> x = jnp.array([3, 4, 1, 3, 1])
>>> result = jnp.unique_inverse(x)
The result is a NamedTuple with two named attributes. The values attribute contains the unique values from the array:
>>> result.values
Array([1, 3, 4], dtype=int32)
The inverse_indices attribute contains the indices of the input within values:
>>> result.inverse_indices
Array([1, 2, 0, 1, 0], dtype=int32)
>>> jnp.all(x == result.values[result.inverse_indices])
Array(True, dtype=bool)
For examples of the size and fill_value arguments, see jax.numpy.unique.
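The reconstruction property can also be checked on a multi-dimensional input. The following is a small sketch using jax.numpy directly; the variable names are illustrative, and the reshape is included only as a precaution (it is a no-op when inverse_indices already has shape x.shape, as documented above).

```python
import jax.numpy as jnp

x = jnp.array([[3, 4, 1],
               [3, 1, 4]])
values, inverse_indices = jnp.unique_inverse(x)

# Indexing the unique values with the inverse indices rebuilds the input.
rebuilt = values[inverse_indices].reshape(x.shape)
print(values)
print(rebuilt)
```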
- scico.numpy.unique_values(x, /, *, size=None, fill_value=None)¶
Return unique values from x.
JAX implementation of numpy.unique_values; this is equivalent to calling jax.numpy.unique with equal_nan set to True.
Because the size of the output of unique_values is data-dependent, the function is not typically compatible with jit and other JAX transformations. The JAX version adds the optional size argument, which must be specified statically for jnp.unique_values to be used in such contexts.
- Parameters:
x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – N-dimensional array from which unique values will be extracted.
size (int | None) – if specified, return only the first size sorted unique elements. If there are fewer unique elements than size indicates, the return value will be padded with fill_value.
fill_value (Union[Array, ndarray, bool, number, bool, int, float, complex, None]) – when size is specified and there are fewer than the indicated number of elements, fill the remaining entries with fill_value. Defaults to the minimum unique value.
- Return type:
- Returns:
An array values of shape (n_unique,) containing the unique values from x.
See also
jax.numpy.unique: general function for computing unique values.
jax.numpy.unique_counts: compute only values and counts.
jax.numpy.unique_inverse: compute only values and inverse.
jax.numpy.unique_all: compute values, indices, inverse_indices, and counts.
Examples
Here we compute the unique values in a 1D array:
>>> x = jnp.array([3, 4, 1, 3, 1])
>>> jnp.unique_values(x)
Array([1, 3, 4], dtype=int32)
For examples of the size and fill_value arguments, see jax.numpy.unique.
- scico.numpy.unravel_index(indices, shape)¶
Convert flat indices into multi-dimensional indices.
JAX implementation of numpy.unravel_index. The JAX version differs in its treatment of out-of-bound indices: unlike NumPy, negative indices are supported, and out-of-bound indices are clipped to the nearest valid value.
- Parameters:
- Return type:
- Returns:
Tuple of unraveled indices
See also
jax.numpy.ravel_multi_index: Inverse of this function.
Examples
Start with a 1D array of values and an array of flat indices:
>>> x = jnp.array([2., 3., 4., 5., 6., 7.])
>>> indices = jnp.array([1, 3, 5])
>>> print(x[indices])
[3. 5. 7.]
Now if x is reshaped, unravel_index can be used to convert the flat indices into a tuple of indices that access the same entries:
>>> shape = (2, 3)
>>> x_2D = x.reshape(shape)
>>> indices_2D = jnp.unravel_index(indices, shape)
>>> indices_2D
(Array([0, 1, 1], dtype=int32), Array([1, 0, 2], dtype=int32))
>>> print(x_2D[indices_2D])
[3. 5. 7.]
The inverse function, ravel_multi_index, can be used to obtain the original indices:
>>> jnp.ravel_multi_index(indices_2D, shape)
Array([1, 3, 5], dtype=int32)
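The out-of-bound behavior described at the top of this entry can be illustrated with a short sketch. It uses jax.numpy directly; the particular flat indices are arbitrary values chosen to include one negative and one too-large entry.

```python
import jax.numpy as jnp

shape = (2, 3)
# -1 is negative; 100 is past the end of a 6-element array.
flat = jnp.array([-1, 4, 100])

rows, cols = jnp.unravel_index(flat, shape)
print(rows, cols)

# Every returned index is a valid position in an array of the given shape.
in_bounds = jnp.all(
    (rows >= 0) & (rows < shape[0]) & (cols >= 0) & (cols < shape[1])
)
print(in_bounds)
```

NumPy would raise an error for the same inputs, so code ported from NumPy that relies on that error needs an explicit bounds check.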
- scico.numpy.unwrap(p, discont=None, axis=-1, period=6.283185307179586)¶
Unwrap a periodic signal.
JAX implementation of numpy.unwrap.
- Parameters:
p (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array
discont (Union[Array, ndarray, bool, number, bool, int, float, complex, None]) – the maximum allowable discontinuity in the sequence. The default is period / 2
axis (int) – the axis along which to unwrap; defaults to -1
period (Union[Array, ndarray, bool, number, bool, int, float, complex]) – the period of the signal, which defaults to \(2\pi\)
- Return type:
- Returns:
An unwrapped copy of p.
Notes
This implementation follows that of numpy.unwrap, and is not well-suited for integer-period unwrapping of narrow-width integers (e.g. int8, int16) or unsigned integers.
Examples
Consider a situation in which you are making measurements of the position of a rotating disk via the x and y locations of some point on that disk. The underlying variable is an always-increasing angle which we’ll generate this way, using degrees for ease of representation:
>>> rng = np.random.default_rng(0)
>>> theta = rng.integers(0, 90, size=(20,)).cumsum()
>>> theta
array([ 76, 133, 179, 203, 230, 233, 239, 240, 255, 328, 386, 468, 513,
       567, 654, 719, 775, 823, 873, 957])
Our observations of this angle are the x and y coordinates, given by the sine and cosine of this underlying angle:
>>> x, y = jnp.sin(jnp.deg2rad(theta)), jnp.cos(jnp.deg2rad(theta))
Now, say that given these x and y coordinates, we wish to recover the original angle theta. We might do this via the atan2 function:
>>> theta_out = jnp.rad2deg(jnp.atan2(x, y)).round()
>>> theta_out
Array([  76.,  133.,  179., -157., -130., -127., -121., -120., -105.,
        -32.,   26.,  108.,  153., -153.,  -66.,   -1.,   55.,  103.,
        153., -123.], dtype=float32)
The first few values match the input angle theta above, but after this the values are wrapped because the sin and cos observations obscure the phase information. The purpose of the unwrap function is to recover the original signal from this wrapped view of it:
>>> jnp.unwrap(theta_out, period=360)
Array([ 76., 133., 179., 203., 230., 233., 239., 240., 255., 328., 386.,
       468., 513., 567., 654., 719., 775., 823., 873., 957.], dtype=float32)
It does this by assuming that the true underlying sequence does not differ by more than discont (which defaults to period / 2) within a single step, and when it encounters a larger discontinuity it adds factors of the period to the data. For periodic signals that satisfy this assumption, unwrap can recover the original phased signal.
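The same idea with the default period of \(2\pi\) can be sketched as follows. This is an illustrative example using jax.numpy directly; the phase ramp and sample count are arbitrary choices that keep each step well below period / 2.

```python
import jax.numpy as jnp

# A steadily increasing phase covering three full turns, in radians.
phase = jnp.linspace(0.0, 6 * jnp.pi, 50)

# jnp.angle returns values in (-pi, pi], which wraps the phase.
wrapped = jnp.angle(jnp.exp(1j * phase))

# With the default period, unwrap restores the ramp because consecutive
# samples differ by much less than period / 2.
unwrapped = jnp.unwrap(wrapped)
print(jnp.allclose(unwrapped, phase, atol=1e-3))
```

If the ramp were sampled so coarsely that consecutive samples differed by more than \(\pi\), the assumption stated above would fail and the recovered signal would differ from the original.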
- scico.numpy.var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False, *, where=None, mean=None, correction=None)¶
Compute the variance along a given axis.
JAX implementation of numpy.var.
- Parameters:
a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array.
axis (Union[int, Sequence[int], None]) – optional, int or sequence of ints, default=None. Axis along which the variance is computed. If None, variance is computed along all the axes.
dtype (Union[str, type[Any], dtype, SupportsDType, None]) – The type of the output array. Default=None.
ddof (int) – int, default=0. Degrees of freedom. The divisor in the variance computation is N - ddof, where N is the number of elements along the given axis.
keepdims (bool) – bool, default=False. If true, reduced axes are left in the result with size 1.
where (Union[Array, ndarray, bool, number, bool, int, float, complex, None]) – optional, boolean array, default=None. The elements to be used in the variance. Array should be broadcast compatible to the input.
mean (Union[Array, ndarray, bool, number, bool, int, float, complex, None]) – optional, mean of the input array, computed along the given axis. If provided, it will be used to compute the variance instead of computing it from the input array. If specified, mean must be broadcast-compatible with the input array. In the general case, this can be achieved by computing the mean with keepdims=True and axis matching this function’s axis argument.
correction (int | float | None) – int or float, default=None. Alternative name for ddof. Both ddof and correction can’t be provided simultaneously.
out (None) – Unused by JAX.
- Return type:
- Returns:
An array of the variance along the given axis.
See also
jax.numpy.mean: Compute the mean of array elements over a given axis.
jax.numpy.std: Compute the standard deviation of array elements over a given axis.
jax.numpy.nanvar: Compute the variance along a given axis, ignoring NaN values.
jax.numpy.nanstd: Compute the standard deviation along a given axis, ignoring NaN values.
Examples
By default, jnp.var computes the variance along all axes.
>>> x = jnp.array([[1, 3, 4, 2],
...                [5, 2, 6, 3],
...                [8, 4, 2, 9]])
>>> with jnp.printoptions(precision=2, suppress=True):
...   jnp.var(x)
Array(5.74, dtype=float32)
If axis=1, variance is computed along axis 1.
>>> jnp.var(x, axis=1)
Array([1.25  , 2.5   , 8.1875], dtype=float32)
To preserve the dimensions of input, you can set keepdims=True.
>>> jnp.var(x, axis=1, keepdims=True)
Array([[1.25  ],
       [2.5   ],
       [8.1875]], dtype=float32)
If ddof=1:
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jnp.var(x, axis=1, keepdims=True, ddof=1))
[[ 1.67]
 [ 3.33]
 [10.92]]
To include specific elements of the array to compute variance, you can use where.
>>> where = jnp.array([[1, 0, 1, 0],
...                    [0, 1, 1, 0],
...                    [1, 1, 1, 0]], dtype=bool)
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jnp.var(x, axis=1, keepdims=True, where=where))
[[2.25]
 [4.  ]
 [6.22]]
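The role of ddof as a divisor adjustment can be made concrete by computing the variance by hand. The sketch below uses jax.numpy directly and the same data as the examples above; the names row_mean and manual are illustrative.

```python
import jax.numpy as jnp

x = jnp.array([[1.0, 3.0, 4.0, 2.0],
               [5.0, 2.0, 6.0, 3.0],
               [8.0, 4.0, 2.0, 9.0]])
ddof = 1
n = x.shape[1]  # N: number of elements along axis 1

# Sum of squared deviations from the mean, divided by N - ddof.
row_mean = jnp.mean(x, axis=1, keepdims=True)
manual = jnp.sum((x - row_mean) ** 2, axis=1) / (n - ddof)

print(jnp.allclose(manual, jnp.var(x, axis=1, ddof=ddof)))
```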
- scico.numpy.vdot(a, b, *, precision=None, preferred_element_type=None)¶
Perform a conjugate multiplication of two 1D vectors.
JAX implementation of numpy.vdot.
- Parameters:
a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – first input array, if not 1D it will be flattened.
b (Union[Array, ndarray, bool, number, bool, int, float, complex]) – second input array, if not 1D it will be flattened. Must have a.size == b.size.
precision (Union[None, str, Precision, tuple[str, str], tuple[Precision, Precision], DotAlgorithm, DotAlgorithmPreset]) – either None (default), which means the default precision for the backend, a Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST) or a tuple of two such values indicating precision of a and b.
preferred_element_type (Union[str, type[Any], dtype, SupportsDType, None]) – either None (default), which means the default accumulation type for the input types, or a datatype, indicating to accumulate results to and return a result with that datatype.
- Return type:
- Returns:
Scalar array (shape ()) containing the conjugate vector product of the inputs.
See also
jax.numpy.vecdot: batched vector product.
jax.numpy.matmul: general matrix multiplication.
jax.lax.dot_general: general N-dimensional batched dot product.
Examples
>>> x = jnp.array([1j, 2j, 3j])
>>> y = jnp.array([1., 2., 3.])
>>> jnp.vdot(x, y)
Array(0.-14.j, dtype=complex64)
Note the difference between this and dot, which does not conjugate the first input when complex:
>>> jnp.dot(x, y)
Array(0.+14.j, dtype=complex64)
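The flattening and conjugation described above can be written out explicitly. This is a small sketch using jax.numpy directly; the input values are arbitrary, and explicit is an illustrative name.

```python
import jax.numpy as jnp

a = jnp.array([[1 + 2j, 3 - 1j],
               [0 + 1j, 2 + 0j]])
b = jnp.array([[1.0, 2.0],
               [3.0, 4.0]])

# vdot flattens both inputs and conjugates the first one.
explicit = jnp.sum(jnp.conj(a).ravel() * b.ravel())
print(jnp.vdot(a, b), explicit)
```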
- scico.numpy.vecdot(x1, x2, /, *, axis=-1, precision=None, preferred_element_type=None)¶
Perform a conjugate multiplication of two batched vectors.
JAX implementation of numpy.vecdot.
- Parameters:
x1 – left-hand side array.
x2 – right-hand side array. Size of x2[axis] must match size of x1[axis], and remaining dimensions must be broadcast-compatible.
axis (int) – axis along which to compute the dot product (default: -1)
precision (Union[None, str, Precision, tuple[str, str], tuple[Precision, Precision], DotAlgorithm, DotAlgorithmPreset]) – either None (default), which means the default precision for the backend, a Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST) or a tuple of two such values indicating precision of x1 and x2.
preferred_element_type (Union[str, type[Any], dtype, SupportsDType, None]) – either None (default), which means the default accumulation type for the input types, or a datatype, indicating to accumulate results to and return a result with that datatype.
- Return type:
- Returns:
array containing the conjugate dot product of x1 and x2 along axis. The non-contracted dimensions are broadcast together.
See also
jax.numpy.vdot: flattened vector product.
jax.numpy.vecmat: vector-matrix product.
jax.numpy.matmul: general matrix multiplication.
jax.lax.dot_general: general N-dimensional batched dot product.
Examples
Vector conjugate-dot product of two 1D arrays:
>>> a = jnp.array([1j, 2j, 3j])
>>> b = jnp.array([4., 5., 6.])
>>> jnp.vecdot(a, b)
Array(0.-32.j, dtype=complex64)
Batched vector dot product of two 2D arrays:
>>> a = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> b = jnp.array([[2, 3, 4]])
>>> jnp.vecdot(a, b, axis=-1)
Array([20, 47], dtype=int32)
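The effect of the axis argument can be sketched by contracting over the leading axis instead of the last one. This example uses jax.numpy directly; the inputs are arbitrary, and by_column and explicit are illustrative names.

```python
import jax.numpy as jnp

x1 = jnp.array([[1j, 2.0],
                [3.0, 4j]])
x2 = jnp.array([[1.0, 1.0],
                [2.0, 2.0]])

# Contract over axis 0: one conjugate dot product per column.
by_column = jnp.vecdot(x1, x2, axis=0)
explicit = jnp.sum(jnp.conj(x1) * x2, axis=0)
print(by_column, explicit)
```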
- scico.numpy.vectorize(pyfunc, *, excluded=frozenset({}), signature=None)¶
Define a vectorized function with broadcasting.
vectorize is a convenience wrapper for defining vectorized functions with broadcasting, in the style of NumPy’s generalized universal functions. It allows for defining functions that are automatically repeated across any leading dimensions, without the implementation of the function needing to be concerned about how to handle higher dimensional inputs.
jax.numpy.vectorize has the same interface as numpy.vectorize, but it is syntactic sugar for an auto-batching transformation (vmap) rather than a Python loop. This should be considerably more efficient, but the implementation must be written in terms of functions that act on JAX arrays.
- Parameters:
pyfunc – function to vectorize.
excluded – optional set of integers representing positional arguments for which the function will not be vectorized. These will be passed directly to pyfunc unmodified.
signature – optional generalized universal function signature, e.g., (m,n),(n)->(m) for vectorized matrix-vector multiplication. If provided, pyfunc will be called with (and expected to return) arrays with shapes given by the size of corresponding core dimensions. By default, pyfunc is assumed to take scalar arrays as input, and if signature is None, pyfunc can produce outputs of any shape.
- Returns:
Vectorized version of the given function.
Examples
Here are a few examples of how one could write vectorized linear algebra routines using vectorize:
>>> from functools import partial
>>> @partial(jnp.vectorize, signature='(k),(k)->(k)')
... def cross_product(a, b):
...   assert a.shape == b.shape and a.ndim == b.ndim == 1
...   return jnp.array([a[1] * b[2] - a[2] * b[1],
...                     a[2] * b[0] - a[0] * b[2],
...                     a[0] * b[1] - a[1] * b[0]])
>>> @partial(jnp.vectorize, signature='(n,m),(m)->(n)')
... def matrix_vector_product(matrix, vector):
...   assert matrix.ndim == 2 and matrix.shape[1:] == vector.shape
...   return matrix @ vector
These functions are only written to handle 1D or 2D arrays (the assert statements will never be violated), but with vectorize they support arbitrary dimensional inputs with NumPy style broadcasting, e.g.,
>>> cross_product(jnp.ones(3), jnp.ones(3)).shape
(3,)
>>> cross_product(jnp.ones((2, 3)), jnp.ones(3)).shape
(2, 3)
>>> cross_product(jnp.ones((1, 2, 3)), jnp.ones((2, 1, 3))).shape
(2, 2, 3)
>>> matrix_vector_product(jnp.ones(3), jnp.ones(3))
Traceback (most recent call last):
ValueError: input with shape (3,) does not have enough dimensions for all core dimensions ('n', 'm') on vectorized function with excluded=frozenset() and signature='(n,m),(m)->(n)'
>>> matrix_vector_product(jnp.ones((2, 3)), jnp.ones(3)).shape
(2,)
>>> matrix_vector_product(jnp.ones((2, 3)), jnp.ones((4, 3))).shape
(4, 2)
Note that this has different semantics than jnp.matmul:
>>> jnp.matmul(jnp.ones((2, 3)), jnp.ones((4, 3)))
Traceback (most recent call last):
TypeError: dot_general requires contracting dimensions to have the same shape, got [3] and [4].
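The excluded parameter described above is not covered by these examples; a minimal sketch of it follows. The function poly_eval and its coefficient convention (lowest order first) are hypothetical, chosen only to show one positional argument being passed through without vectorization.

```python
from functools import partial

import jax.numpy as jnp


@partial(jnp.vectorize, excluded={1})
def poly_eval(x, coeffs):
    # x arrives as a scalar; coeffs (positional argument 1) is excluded,
    # so the full coefficient array is passed through unmodified.
    powers = jnp.arange(coeffs.shape[0])
    return jnp.sum(coeffs * x ** powers)


coeffs = jnp.array([3.0, 2.0, 1.0])  # 3 + 2*x + x**2
xs = jnp.array([1.0, 2.0, 3.0])
print(poly_eval(xs, coeffs))
print(poly_eval(jnp.ones((2, 3)), coeffs).shape)
```

Without excluded, coeffs would be broadcast against x element by element, and the body would see a scalar coefficient rather than the whole array.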
- scico.numpy.vsplit(ary, indices_or_sections)¶
Split an array into sub-arrays vertically.
JAX implementation of numpy.vsplit.
Refer to the documentation of jax.numpy.split for details; vsplit is equivalent to split with axis=0.
Examples
1D array:
>>> x = jnp.array([1, 2, 3, 4, 5, 6])
>>> x1, x2 = jnp.vsplit(x, 2)
>>> print(x1, x2)
[1 2 3] [4 5 6]
2D array:
>>> x = jnp.array([[1, 2, 3, 4],
...                [5, 6, 7, 8]])
>>> x1, x2 = jnp.vsplit(x, 2)
>>> print(x1, x2)
[[1 2 3 4]] [[5 6 7 8]]
See also
jax.numpy.split: split an array along any axis.
jax.numpy.hsplit: split horizontally, i.e. along axis=1
jax.numpy.dsplit: split depth-wise, i.e. along axis=2
jax.numpy.array_split: like split, but allows indices_or_sections to be an integer that does not evenly divide the size of the array.
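The examples above pass an integer number of sections; indices_or_sections can also be a list of split points, as with split. A short sketch using jax.numpy directly, with arbitrary data and illustrative names:

```python
import jax.numpy as jnp

x = jnp.arange(12).reshape(4, 3)

# Split points 1 and 3 along axis 0 give row blocks of height 1, 2 and 1.
top, middle, bottom = jnp.vsplit(x, [1, 3])
print(top.shape, middle.shape, bottom.shape)

# Stacking the blocks vertically restores the original array.
print(jnp.array_equal(jnp.vstack([top, middle, bottom]), x))
```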
- scico.numpy.vstack(tup, dtype=None)¶
Vertically stack arrays.
JAX implementation of numpy.vstack.
For arrays of two or more dimensions, this is equivalent to jax.numpy.concatenate with axis=0.
- Parameters:
tup (ndarray | Array | Sequence[Union[Array, ndarray, bool, number, bool, int, float, complex]]) – a sequence of arrays to stack; each must have the same shape along all but the first axis. If a single array is given it will be treated equivalently to tup = unstack(tup), but the implementation will avoid explicit unstacking.
dtype (Union[str, type[Any], dtype, SupportsDType, None]) – optional dtype of the resulting array. If not specified, the dtype will be determined via type promotion rules described in Type promotion semantics.
- Return type:
- Returns:
the stacked result.
See also
jax.numpy.stack: stack along arbitrary axes
jax.numpy.concatenate: concatenation along existing axes.
jax.numpy.hstack: stack horizontally, i.e. along axis 1.
jax.numpy.dstack: stack depth-wise, i.e. along axis 2.
Examples
Scalar values:
>>> jnp.vstack([1, 2, 3])
Array([[1],
       [2],
       [3]], dtype=int32, weak_type=True)
1D arrays:
>>> x = jnp.arange(4)
>>> y = jnp.ones(4)
>>> jnp.vstack([x, y])
Array([[0., 1., 2., 3.],
       [1., 1., 1., 1.]], dtype=float32)
2D arrays:
>>> x = x.reshape(1, 4)
>>> y = y.reshape(1, 4)
>>> jnp.vstack([x, y])
Array([[0., 1., 2., 3.],
       [1., 1., 1., 1.]], dtype=float32)
- scico.numpy.where(condition, x=None, y=None, /, *, size=None, fill_value=None)¶
Select elements from two arrays based on a condition.
JAX implementation of numpy.where.
Note
When only condition is provided, jnp.where(condition) is equivalent to jnp.nonzero(condition). For that case, refer to the documentation of jax.numpy.nonzero. The docstring below focuses on the case where x and y are specified.
The three-term version of jnp.where lowers to jax.lax.select.
- Parameters:
condition – boolean array. Must be broadcast-compatible with x and y when they are specified.
x – arraylike. Should be broadcast-compatible with condition and y, and typecast-compatible with y.
y – arraylike. Should be broadcast-compatible with condition and x, and typecast-compatible with x.
size – integer, only referenced when x and y are None. For details, see jax.numpy.nonzero.
fill_value – only referenced when x and y are None. For details, see jax.numpy.nonzero.
- Returns:
An array of dtype jnp.result_type(x, y) with values drawn from x where condition is True, and from y where condition is False. If x and y are None, the function behaves differently; see jax.numpy.nonzero for a description of the return type.
Notes
Special care is needed when the x or y input to jax.numpy.where could have a value of NaN. Specifically, when a gradient is taken with jax.grad (reverse-mode differentiation), a NaN in either x or y will propagate into the gradient, regardless of the value of condition. More information on this behavior and workarounds is available in the JAX FAQ.
Examples
When x and y are not provided, where behaves equivalently to jax.numpy.nonzero:
>>> x = jnp.arange(10)
>>> jnp.where(x > 4)
(Array([5, 6, 7, 8, 9], dtype=int32),)
>>> jnp.nonzero(x > 4)
(Array([5, 6, 7, 8, 9], dtype=int32),)
When x and y are provided, where selects between them based on the specified condition:
>>> jnp.where(x > 4, x, 0)
Array([0, 0, 0, 0, 0, 5, 6, 7, 8, 9], dtype=int32)
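The gradient caveat from the Notes can be reproduced with a guarded logarithm. The sketch below follows the workaround commonly described as a "double where"; the function names naive_log and guarded_log are illustrative, and the replacement value 1.0 is an arbitrary safe input.

```python
import jax
import jax.numpy as jnp


def naive_log(x):
    # log is still evaluated (and differentiated) at x <= 0.
    return jnp.where(x > 0.0, jnp.log(x), 0.0)


def guarded_log(x):
    # Replace unsafe inputs before calling log, then select as before.
    safe_x = jnp.where(x > 0.0, x, 1.0)
    return jnp.where(x > 0.0, jnp.log(safe_x), 0.0)


print(jax.grad(naive_log)(0.0))    # nan
print(jax.grad(guarded_log)(0.0))  # 0.0
```

Both functions return the same forward values; only the gradient differs, because the guarded version never differentiates log at an input where its derivative is infinite.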
- scico.numpy.zeros(shape, dtype=None, *, device=None, out_sharding=None)¶
Create an array full of zeros.
JAX implementation of numpy.zeros.
- Parameters:
shape (Any) – int or sequence of ints specifying the shape of the created array.
dtype (Union[str, type[Any], dtype, SupportsDType, None]) – optional dtype for the created array; defaults to float32 or float64 depending on the X64 configuration (see Default dtypes and the X64 flag).
device (Device | Sharding | None) – (optional) Device or Sharding to which the created array will be committed. This argument exists for compatibility with the Python Array API standard.
out_sharding (NamedSharding | P | None) – (optional) PartitionSpec or NamedSharding representing the sharding of the created array (see explicit sharding for more details). This argument exists for consistency with other array creation routines across JAX. Specifying both out_sharding and device will result in an error.
- Return type:
- Returns:
Array of the specified shape and dtype, with the given device/sharding if specified.
Examples
>>> jnp.zeros(4)
Array([0., 0., 0., 0.], dtype=float32)
>>> jnp.zeros((2, 3), dtype=bool)
Array([[False, False, False],
       [False, False, False]], dtype=bool)
- scico.numpy.zeros_like(a, dtype=None, shape=None, *, device=None, out_sharding=None)¶
Create an array full of zeros with the same shape and dtype as an array.
JAX implementation of numpy.zeros_like.
- Parameters:
a (Union[Array, ndarray, bool, number, bool, int, float, complex, DuckTypedArray]) – Array-like object with shape and dtype attributes.
shape (Any) – optionally override the shape of the created array.
dtype (Union[str, type[Any], dtype, SupportsDType, None]) – optionally override the dtype of the created array.
device (Device | Sharding | None) – (optional) Device or Sharding to which the created array will be committed.
- Return type:
- Returns:
Array of the specified shape and dtype, on the given device if one is specified.
Examples
>>> x = jnp.arange(4)
>>> jnp.zeros_like(x)
Array([0, 0, 0, 0], dtype=int32)
>>> jnp.zeros_like(x, dtype=bool)
Array([False, False, False, False], dtype=bool)
>>> jnp.zeros_like(x, shape=(2, 3))
Array([[0, 0, 0],
       [0, 0, 0]], dtype=int32)