scico.numpy.linalg
Linear algebra functions.
Functions
| Function | Description |
|---|---|
| cholesky | Compute the Cholesky decomposition of a matrix. |
| cond | Compute the condition number of a matrix. |
| cross | Compute the cross product of two 3D vectors. |
| det | Compute the determinant of an array. |
| diagonal | Extract the diagonal of a matrix or stack of matrices. |
| eig | Compute the eigenvalues and eigenvectors of a square array. |
| eigh | Compute the eigenvalues and eigenvectors of a Hermitian matrix. |
| eigvals | Compute the eigenvalues of a general matrix. |
| eigvalsh | Compute the eigenvalues of a Hermitian matrix. |
| inv | Return the inverse of a square matrix. |
| lstsq | Return the least-squares solution to a linear equation. |
| matmul | Perform a matrix multiplication. |
| matrix_norm | Compute the norm of a matrix or stack of matrices. |
| matrix_power | Raise a square matrix to an integer power. |
| matrix_rank | Compute the rank of a matrix. |
| matrix_transpose | Transpose a matrix or stack of matrices. |
| multi_dot | Efficiently compute matrix products between a sequence of arrays. |
| norm | Compute the norm of a matrix or vector. |
| outer | Compute the outer product of two 1-dimensional arrays. |
| pinv | Compute the (Moore-Penrose) pseudo-inverse of a matrix. |
| qr | Compute the QR decomposition of an array. |
| slogdet | Compute the sign and (natural) logarithm of the determinant of an array. |
| solve | Solve a linear system of equations. |
| svd | Compute the singular value decomposition. |
| svdvals | Compute the singular values of a matrix. |
| tensordot | Compute the tensor dot product of two N-dimensional arrays. |
| tensorinv | Compute the tensor inverse of an array. |
| tensorsolve | Solve the tensor equation a x = b for x. |
| trace | Compute the trace of a matrix. |
| vecdot | Compute the (batched) vector conjugate dot product of two arrays. |
| vector_norm | Compute the vector norm of a vector or batch of vectors. |
- scico.numpy.linalg.cholesky(a, *, upper=False, symmetrize_input=True)
Compute the Cholesky decomposition of a matrix.
JAX implementation of numpy.linalg.cholesky.

The Cholesky decomposition of a matrix A is:

\[A = U^H U\]

or

\[A = L L^H\]

where U is an upper-triangular matrix, L is a lower-triangular matrix, and \(X^H\) is the Hermitian transpose of X.
- Parameters:
a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array representing a (batched) positive-definite Hermitian matrix. Must have shape (..., N, N).
upper (bool) – if True, compute the upper Cholesky decomposition U. If False (default), compute the lower Cholesky decomposition L.
symmetrize_input (bool) – if True (default), the input is symmetrized, which leads to better behavior under automatic differentiation. Note that when this is set to True, both the upper and lower triangles of the input are used in computing the decomposition.
- Return type:
Array
- Returns:
array of shape (..., N, N) representing the Cholesky decomposition of the input. If the input is not Hermitian positive-definite, the result will contain NaN entries.
See also
jax.scipy.linalg.cholesky: SciPy-style Cholesky API
jax.lax.linalg.cholesky: XLA-style Cholesky API
Examples
A small real Hermitian positive-definite matrix:
>>> x = jnp.array([[2., 1.],
...                [1., 2.]])
Lower Cholesky factorization:
>>> jnp.linalg.cholesky(x)
Array([[1.4142135 , 0.        ],
       [0.70710677, 1.2247449 ]], dtype=float32)
Upper Cholesky factorization:
>>> jnp.linalg.cholesky(x, upper=True)
Array([[1.4142135 , 0.70710677],
       [0.        , 1.2247449 ]], dtype=float32)
Reconstructing x from its factorization:
>>> L = jnp.linalg.cholesky(x)
>>> jnp.allclose(x, L @ L.T)
Array(True, dtype=bool)
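A common use of the lower factor is solving a symmetric positive-definite system A x = b with two triangular solves instead of forming an inverse. The sketch below calls jax.numpy.linalg.cholesky directly, assuming scico.numpy.linalg.cholesky behaves the same for ordinary arrays as the text above states, and uses jax.scipy.linalg.solve_triangular for the substitution steps; the matrix and right-hand side are arbitrary illustrative values.

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

# Illustrative symmetric positive-definite matrix and right-hand side.
A = jnp.array([[4., 2., 0.],
               [2., 5., 1.],
               [0., 1., 3.]])
b = jnp.array([1., 2., 3.])

# Lower Cholesky factor, so that A == L @ L.T.
L = jnp.linalg.cholesky(A)

# Forward substitution: solve L @ y = b for y.
y = solve_triangular(L, b, lower=True)

# Back substitution: solve L.T @ x = y for x.
x = solve_triangular(L.T, y, lower=False)

# The result should agree with a direct solve.
print(jnp.allclose(x, jnp.linalg.solve(A, b), atol=1e-5))
```

The factorization costs roughly half as much as a general LU decomposition and can be reused for many right-hand sides.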
- scico.numpy.linalg.cond(x, p=None)
Compute the condition number of a matrix.
JAX implementation of numpy.linalg.cond.

The condition number is defined as norm(x, p) * norm(inv(x), p). For p = 2 (the default), the condition number is the ratio of the largest to the smallest singular value.
- Parameters:
x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – array of shape (..., M, N) for which to compute the condition number.
p – the order of the norm to use. One of {None, 1, -1, 2, -2, inf, -inf, 'fro'}; see jax.numpy.linalg.norm for the meaning of these. The default is p = None, which is equivalent to p = 2. If p is not in {None, 2, -2}, then x must be square, i.e. M = N.
- Returns:
array of shape x.shape[:-2] containing the condition number.
Examples
Well-conditioned matrix:
>>> x = jnp.array([[1, 2],
...                [2, 1]])
>>> jnp.linalg.cond(x)
Array(3., dtype=float32)
Ill-conditioned matrix:
>>> x = jnp.array([[1, 2],
...                [0, 0]])
>>> jnp.linalg.cond(x)
Array(inf, dtype=float32)
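Both characterizations above can be checked numerically. This sketch, using the well-conditioned 2x2 matrix from the example (singular values 3 and 1), compares the default condition number with the singular-value ratio, and the p=1 condition number with the defining product norm(x, 1) * norm(inv(x), 1).

```python
import jax.numpy as jnp

x = jnp.array([[1., 2.],
               [2., 1.]])

# Default (p=None, i.e. the 2-norm): ratio of extreme singular values.
s = jnp.linalg.svd(x, compute_uv=False)
ratio = s.max() / s.min()

# The general definition, here with the 1-norm (requires a square matrix).
c1 = jnp.linalg.norm(x, 1) * jnp.linalg.norm(jnp.linalg.inv(x), 1)

print(jnp.linalg.cond(x), ratio)
print(jnp.linalg.cond(x, p=1), c1)
```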
- scico.numpy.linalg.cross(x1, x2, /, *, axis=-1)
Compute the cross product of two 3D vectors.
JAX implementation of numpy.linalg.cross.
- Parameters:
x1 (Union[Array, ndarray, bool, number, bool, int, float, complex]) – N-dimensional array, with x1.shape[axis] == 3.
x2 (Union[Array, ndarray, bool, number, bool, int, float, complex]) – N-dimensional array, with x2.shape[axis] == 3, and other axes broadcast-compatible with x1.
axis – axis along which to take the cross product (default: -1).
- Returns:
array containing the result of the cross product.
See also
jax.numpy.cross: more flexible cross-product API.
Examples
Showing that \(\hat{x} \times \hat{y} = \hat{z}\):
>>> x = jnp.array([1., 0., 0.])
>>> y = jnp.array([0., 1., 0.])
>>> jnp.linalg.cross(x, y)
Array([0., 0., 1.], dtype=float32)
Cross product of \(\hat{x}\) with all three standard unit vectors, via broadcasting:
>>> xyz = jnp.eye(3)
>>> jnp.linalg.cross(x, xyz, axis=-1)
Array([[ 0.,  0.,  0.],
       [ 0.,  0.,  1.],
       [ 0., -1.,  0.]], dtype=float32)
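Two standard properties of the cross product make a quick sanity check: the result is orthogonal to both inputs, and swapping the inputs flips its sign. The vectors below are arbitrary illustrative values.

```python
import jax.numpy as jnp

a = jnp.array([1., 2., 3.])
b = jnp.array([4., 5., 6.])

c = jnp.linalg.cross(a, b)  # expected [-3., 6., -3.]

# Orthogonality: c is perpendicular to both a and b.
print(jnp.dot(c, a), jnp.dot(c, b))

# Anticommutativity: b x a == -(a x b).
print(jnp.allclose(jnp.linalg.cross(b, a), -c))
```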
- scico.numpy.linalg.det(a)
Compute the determinant of an array.
JAX implementation of numpy.linalg.det.
- Parameters:
a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – array of shape (..., M, M) for which to compute the determinant.
- Return type:
Array
- Returns:
An array of determinants of shape a.shape[:-2].
See also
jax.scipy.linalg.det: SciPy-style API for determinant.
Examples
>>> a = jnp.array([[1, 2],
...                [3, 4]])
>>> jnp.linalg.det(a)
Array(-2., dtype=float32)
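Because the determinant is multiplicative and det accepts stacks of matrices, the following sketch (with arbitrary example matrices) checks det(A @ B) == det(A) * det(B) and evaluates both determinants in a single batched call.

```python
import jax.numpy as jnp

A = jnp.array([[1., 2.],
               [3., 4.]])   # det(A) = -2
B = jnp.array([[2., 0.],
               [1., 3.]])   # det(B) = 6

# Multiplicativity: det(A @ B) == det(A) * det(B) == -12.
print(jnp.linalg.det(A @ B))

# Batched evaluation: a stack of shape (2, 2, 2) gives results of shape (2,).
dets = jnp.linalg.det(jnp.stack([A, B]))
print(dets)
```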
- scico.numpy.linalg.diagonal(x, /, *, offset=0)
Extract the diagonal of a matrix or stack of matrices.
JAX implementation of numpy.linalg.diagonal.
- Parameters:
- Return type:
Array
- Returns:
Array of shape (..., K), where K is the length of the specified diagonal.
See also
jax.numpy.diagonal: more general functionality for extracting diagonals.
jax.numpy.diag: create a diagonal matrix from values.
Examples
Diagonals of a single matrix:
>>> x = jnp.array([[1, 2, 3, 4],
...                [5, 6, 7, 8],
...                [9, 10, 11, 12]])
>>> jnp.linalg.diagonal(x)
Array([ 1,  6, 11], dtype=int32)
>>> jnp.linalg.diagonal(x, offset=1)
Array([ 2,  7, 12], dtype=int32)
>>> jnp.linalg.diagonal(x, offset=-1)
Array([ 5, 10], dtype=int32)
Batched diagonals:
>>> x = jnp.arange(24).reshape(2, 3, 4)
>>> jnp.linalg.diagonal(x)
Array([[ 0,  5, 10],
       [12, 17, 22]], dtype=int32)
- scico.numpy.linalg.eig(a)
Compute the eigenvalues and eigenvectors of a square array.
JAX implementation of numpy.linalg.eig.
- Parameters:
a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – array of shape (..., M, M) for which to compute the eigenvalues and vectors.
- Return type:
EigResult
- Returns:
A namedtuple (eigenvalues, eigenvectors). The namedtuple has fields:
eigenvalues: an array of shape (..., M) containing the eigenvalues.
eigenvectors: an array of shape (..., M, M), where column v[:, i] is the eigenvector corresponding to the eigenvalue w[i].
Notes
This differs from numpy.linalg.eig in that the return type of jax.numpy.linalg.eig is always complex64 for 32-bit input, and complex128 for 64-bit input.
At present, non-symmetric eigendecomposition is only implemented on the CPU and GPU backends. For more details about the GPU implementation, see the documentation for jax.lax.linalg.eig.
Currently, autodiff is not supported for computation of non-symmetric eigenvectors; see https://github.com/jax-ml/jax/issues/2748.
See also
jax.lax.linalg.eig: similar function with different eigenvector options and device-specific implementations.
jax.numpy.linalg.eigh: eigenvectors and eigenvalues of a Hermitian matrix.
jax.numpy.linalg.eigvals: compute eigenvalues only.
Examples
>>> a = jnp.array([[1., 2.],
...                [2., 1.]])
>>> w, v = jnp.linalg.eig(a)
>>> with jax.numpy.printoptions(precision=4):
...   w
Array([ 3.+0.j, -1.+0.j], dtype=complex64)
>>> v
Array([[ 0.70710677+0.j, -0.70710677+0.j],
       [ 0.70710677+0.j,  0.70710677+0.j]], dtype=complex64)
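The defining relation a @ v[:, i] == w[i] * v[:, i] can be verified for all columns at once, since multiplying v by w broadcasts w across the columns. The sketch below uses an arbitrary non-symmetric matrix, so (per the note above) it needs a CPU or GPU backend.

```python
import jax.numpy as jnp

# Arbitrary non-symmetric (upper-triangular) matrix with eigenvalues 2 and 3.
a = jnp.array([[2., 1.],
               [0., 3.]])
w, v = jnp.linalg.eig(a)   # both results are complex64

# v * w scales column i of v by w[i], so this checks every eigenpair.
print(jnp.allclose(a @ v, v * w, atol=1e-5))

# The order of eigenvalues is not guaranteed; sort the real parts for display.
print(jnp.sort(w.real))
```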
- scico.numpy.linalg.eigh(a, UPLO=None, symmetrize_input=True)
Compute the eigenvalues and eigenvectors of a Hermitian matrix.
JAX implementation of numpy.linalg.eigh.
- Parameters:
a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – array of shape (..., M, M), containing the Hermitian (if complex) or symmetric (if real) matrix.
UPLO (str | None) – specifies whether the calculation is done with the lower triangular part of a ('L', default) or the upper triangular part ('U').
symmetrize_input (bool) – if True (default), the input is symmetrized, which leads to better behavior under automatic differentiation. Note that when this is set to True, both the upper and lower triangles of the input are used in computing the decomposition.
- Return type:
EighResult
- Returns:
A namedtuple (eigenvalues, eigenvectors), where:
eigenvalues: an array of shape (..., M) containing the eigenvalues, sorted in ascending order.
eigenvectors: an array of shape (..., M, M), where column v[:, i] is the normalized eigenvector corresponding to the eigenvalue w[i].
See also
jax.numpy.linalg.eig: general eigenvalue decomposition.
jax.numpy.linalg.eigvalsh: compute eigenvalues only.
jax.scipy.linalg.eigh: SciPy API for Hermitian eigendecomposition.
jax.lax.linalg.eigh: XLA API for Hermitian eigendecomposition.
Examples
>>> a = jnp.array([[1, -2j],
...                [2j, 1]])
>>> w, v = jnp.linalg.eigh(a)
>>> w
Array([-1.,  3.], dtype=float32)
>>> with jnp.printoptions(precision=3):
...   v
Array([[-0.707+0.j   , -0.707+0.j   ],
       [ 0.   +0.707j,  0.   -0.707j]], dtype=complex64)
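Since the eigenvectors returned by eigh are orthonormal, the input can be rebuilt as v @ diag(w) @ v^H. This sketch uses the Hermitian matrix from the example above and checks both orthonormality and the reconstruction.

```python
import jax.numpy as jnp

a = jnp.array([[1, -2j],
               [2j, 1]])
w, v = jnp.linalg.eigh(a)   # w is real and ascending; v is complex

# Eigenvector columns are orthonormal: v^H v == I.
print(jnp.allclose(v.conj().T @ v, jnp.eye(2), atol=1e-5))

# Reconstruction a == v diag(w) v^H; (v * w) scales each column by its eigenvalue.
a_rebuilt = (v * w) @ v.conj().T
print(jnp.allclose(a_rebuilt, a, atol=1e-5))
```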
- scico.numpy.linalg.eigvals(a)
Compute the eigenvalues of a general matrix.
JAX implementation of numpy.linalg.eigvals.
- Parameters:
a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – array of shape (..., M, M) for which to compute the eigenvalues.
- Return type:
Array
- Returns:
An array of shape (..., M) containing the eigenvalues.
See also
jax.numpy.linalg.eig: computes eigenvalues and eigenvectors of a general matrix.
jax.numpy.linalg.eigh: computes eigenvalues and eigenvectors of a Hermitian matrix.
Notes
This differs from numpy.linalg.eigvals in that the return type of jax.numpy.linalg.eigvals is always complex64 for 32-bit input, and complex128 for 64-bit input.
At present, non-symmetric eigendecomposition is only implemented on the CPU and GPU backends.
Examples
>>> a = jnp.array([[1., 2.],
...                [2., 1.]])
>>> w = jnp.linalg.eigvals(a)
>>> with jnp.printoptions(precision=2):
...   w
Array([ 3.+0.j, -1.+0.j], dtype=complex64)
- scico.numpy.linalg.eigvalsh(a, UPLO='L', *, symmetrize_input=True)
Compute the eigenvalues of a Hermitian matrix.
JAX implementation of numpy.linalg.eigvalsh.
- Parameters:
a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – array of shape (..., M, M), containing the Hermitian (if complex) or symmetric (if real) matrix.
UPLO (str | None) – specifies whether the calculation is done with the lower triangular part of a ('L', default) or the upper triangular part ('U').
symmetrize_input (bool) – if True (default), the input is symmetrized, which leads to better behavior under automatic differentiation. Note that when this is set to True, both the upper and lower triangles of the input are used in computing the decomposition.
- Return type:
Array
- Returns:
An array of shape (..., M) containing the eigenvalues, sorted in ascending order.
See also
jax.numpy.linalg.eig: general eigenvalue decomposition.
jax.numpy.linalg.eigh: computes eigenvalues and eigenvectors of a Hermitian matrix.
Examples
>>> a = jnp.array([[1, -2j],
...                [2j, 1]])
>>> w = jnp.linalg.eigvalsh(a)
>>> w
Array([-1.,  3.], dtype=float32)
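Two identities give an independent check on eigvalsh: the eigenvalues sum to the trace and multiply to the determinant. The symmetric tridiagonal matrix below is an arbitrary example whose eigenvalues are 2 - sqrt(2), 2 and 2 + sqrt(2).

```python
import jax.numpy as jnp

a = jnp.array([[2., 1., 0.],
               [1., 2., 1.],
               [0., 1., 2.]])
w = jnp.linalg.eigvalsh(a)

print(w)                                  # ascending order
print(jnp.sum(w), jnp.trace(a))           # both approximately 6.0
print(jnp.prod(w), jnp.linalg.det(a))     # both approximately 4.0
```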
- scico.numpy.linalg.inv(a)
Return the inverse of a square matrix.
JAX implementation of numpy.linalg.inv.
- Parameters:
a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – array of shape (..., N, N) specifying the square array(s) to be inverted.
- Return type:
Array
- Returns:
Array of shape (..., N, N) containing the inverse of the input.
Notes
In most cases, explicitly computing the inverse of a matrix is ill-advised. For example, to compute x = inv(A) @ b, it is faster and more numerically accurate to use a direct solve, such as jax.scipy.linalg.solve.
See also
jax.scipy.linalg.inv: SciPy-style API for matrix inverse
jax.numpy.linalg.solve: direct linear solver
Examples
Compute the inverse of a 3x3 matrix:
>>> a = jnp.array([[1., 2., 3.],
...                [2., 4., 2.],
...                [3., 2., 1.]])
>>> a_inv = jnp.linalg.inv(a)
>>> a_inv
Array([[ 0.        , -0.25      ,  0.5       ],
       [-0.25      ,  0.5       , -0.25000003],
       [ 0.5       , -0.25      ,  0.        ]], dtype=float32)
Check that multiplying with the inverse gives the identity:
>>> jnp.allclose(a @ a_inv, jnp.eye(3), atol=1E-5)
Array(True, dtype=bool)
Multiply the inverse by a vector b to find a solution to a @ x = b:
>>> b = jnp.array([1., 4., 2.])
>>> a_inv @ b
Array([ 0.  ,  1.25, -0.5 ], dtype=float32)
Note, however, that explicitly computing the inverse in such a case can lead to poor performance and loss of precision as the size of the problem grows. Instead, you should use a direct solver like jax.numpy.linalg.solve:
>>> jnp.linalg.solve(a, b)
Array([ 0.  ,  1.25, -0.5 ], dtype=float32)
- scico.numpy.linalg.lstsq(a, b, rcond=None, *, numpy_resid=False)
Return the least-squares solution to a linear equation.
JAX implementation of numpy.linalg.lstsq.
- Parameters:
a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – array of shape (M, N) representing the coefficient matrix.
b (Union[Array, ndarray, bool, number, bool, int, float, complex]) – array of shape (M,) or (M, K) representing the right-hand side.
rcond (float | None) – cut-off ratio for small singular values. Singular values smaller than rcond * largest_singular_value are treated as zero. If None (default), the optimal value will be used to reduce floating point errors.
numpy_resid (bool) – if True, compute and return residuals in the same way as NumPy’s linalg.lstsq. This is necessary if you want to precisely replicate NumPy’s behavior. If False (default), a more efficient method is used to compute residuals.
- Returns:
Tuple of arrays (x, resid, rank, s), where:
x is a shape (N,) or (N, K) array containing the least-squares solution.
resid is the sum of squared residuals, of shape () or (K,).
rank is the rank of the matrix a.
s is the singular values of the matrix a.
Examples
>>> a = jnp.array([[1, 2],
...                [3, 4]])
>>> b = jnp.array([5, 6])
>>> x, _, _, _ = jnp.linalg.lstsq(a, b)
>>> with jnp.printoptions(precision=3):
...   print(x)
[-4.   4.5]
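For an overdetermined system, lstsq returns the coefficients that minimize the sum of squared residuals. As an illustrative sketch, the code below fits a straight line y = c0 + c1 * t (approximately) to four arbitrary points and cross-checks the result against the normal equations (A^T A) c = A^T y. The expected coefficients (0.9, 0.9) and residual 0.7 were worked out by hand for this data.

```python
import jax.numpy as jnp

t = jnp.array([0., 1., 2., 3.])
y = jnp.array([1., 2., 2., 4.])

# Design matrix: a column of ones (intercept) and a column of t (slope).
A = jnp.stack([jnp.ones_like(t), t], axis=1)   # shape (4, 2)

coef, resid, rank, s = jnp.linalg.lstsq(A, y)
print(coef, resid, rank)

# Cross-check with the normal equations (A^T A) c = A^T y.
coef_normal = jnp.linalg.solve(A.T @ A, A.T @ y)
print(jnp.allclose(coef, coef_normal, atol=1e-5))
```

The normal equations square the condition number of A, which is why an SVD-based solver such as lstsq is usually preferred for larger or ill-conditioned problems.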
- scico.numpy.linalg.matmul(x1, x2, /, *, precision=None, preferred_element_type=None)
Perform a matrix multiplication.
JAX implementation of numpy.linalg.matmul.
- Parameters:
x1 (Union[Array, ndarray, bool, number, bool, int, float, complex]) – first input array, of shape (..., N).
x2 (Union[Array, ndarray, bool, number, bool, int, float, complex]) – second input array. Must have shape (N,) or (..., N, M). In the multi-dimensional case, leading dimensions must be broadcast-compatible with the leading dimensions of x1.
precision (Union[None, str, Precision, tuple[str, str], tuple[Precision, Precision], DotAlgorithm, DotAlgorithmPreset]) – either None (default), which means the default precision for the backend, a Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST), or a tuple of two such values indicating the precision of x1 and x2.
preferred_element_type (Union[str, type[Any], dtype, SupportsDType, None]) – either None (default), which means the default accumulation type for the input types, or a datatype, indicating to accumulate results to and return a result with that datatype.
- Return type:
Array
- Returns:
array containing the matrix product of the inputs. The shape is x1.shape[:-1] if x2.ndim == 1; otherwise the shape is (..., M).
See also
jax.numpy.matmul: NumPy API for this function.
jax.numpy.linalg.vecdot: batched vector product.
jax.numpy.linalg.tensordot: batched tensor product.
Examples
Vector dot product:
>>> x1 = jnp.array([1, 2, 3])
>>> x2 = jnp.array([4, 5, 6])
>>> jnp.linalg.matmul(x1, x2)
Array(32, dtype=int32)
Matrix dot product:
>>> x1 = jnp.array([[1, 2, 3],
...                 [4, 5, 6]])
>>> x2 = jnp.array([[1, 2],
...                 [3, 4],
...                 [5, 6]])
>>> jnp.linalg.matmul(x1, x2)
Array([[22, 28],
       [49, 64]], dtype=int32)
For convenience, in all cases you can do the same computation using the @ operator:
>>> x1 @ x2
Array([[22, 28],
       [49, 64]], dtype=int32)
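The broadcasting rule for leading dimensions can be seen from result shapes alone. In this sketch (arrays of ones, chosen only for their shapes), a stack of five 2x3 matrices is multiplied by a single 3x4 matrix and by a length-3 vector.

```python
import jax.numpy as jnp

x1 = jnp.ones((5, 2, 3))
x2 = jnp.ones((3, 4))
v = jnp.ones(3)

# Leading (batch) dimensions broadcast: (5, 2, 3) @ (3, 4) -> (5, 2, 4).
print(jnp.linalg.matmul(x1, x2).shape)

# A 1-D second argument contracts the last axis: (5, 2, 3) @ (3,) -> (5, 2).
print(jnp.linalg.matmul(x1, v).shape)
```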
- scico.numpy.linalg.matrix_norm(x, /, *, keepdims=False, ord='fro')
Compute the norm of a matrix or stack of matrices.
JAX implementation of numpy.linalg.matrix_norm.
- Parameters:
x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – array of shape (..., M, N) for which to take the norm.
keepdims (bool) – if True, keep the reduced dimensions in the output.
ord (str | int) – a string or int specifying the type of norm; the default is the Frobenius norm. See numpy.linalg.norm for details on available options.
- Return type:
Array
- Returns:
array containing the norm of x. Has shape x.shape[:-2] if keepdims is False, or shape (..., 1, 1) if keepdims is True.
See also
jax.numpy.linalg.vector_norm: Norm of a vector or stack of vectors.
jax.numpy.linalg.norm: More general matrix or vector norm.
Examples
>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6],
...                [7, 8, 9]])
>>> jnp.linalg.matrix_norm(x)
Array(16.881943, dtype=float32)
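The default Frobenius norm is the square root of the sum of squared entries, which also equals the 2-norm of the vector of singular values. The sketch below checks both identities on the matrix from the example and shows the effect of keepdims.

```python
import jax.numpy as jnp

x = jnp.array([[1., 2., 3.],
               [4., 5., 6.],
               [7., 8., 9.]])

fro = jnp.linalg.matrix_norm(x)              # default ord='fro'
print(fro, jnp.sqrt(jnp.sum(x ** 2)))        # both sqrt(285)

s = jnp.linalg.svd(x, compute_uv=False)
print(jnp.sqrt(jnp.sum(s ** 2)))             # same value via singular values

# keepdims=True keeps the two reduced axes with size 1.
print(jnp.linalg.matrix_norm(x, keepdims=True).shape)   # (1, 1)
```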
- scico.numpy.linalg.matrix_power(a, n)
Raise a square matrix to an integer power.
JAX implementation of numpy.linalg.matrix_power, implemented via repeated squaring.
- Parameters:
- Return type:
Array
- Returns:
Array of shape (..., M, M) containing a raised to the power n.
Examples
>>> a = jnp.array([[1., 2.],
...                [3., 4.]])
>>> jnp.linalg.matrix_power(a, 3)
Array([[ 37.,  54.],
       [ 81., 118.]], dtype=float32)
>>> a @ a @ a  # equivalent, evaluated directly
Array([[ 37.,  54.],
       [ 81., 118.]], dtype=float32)
This also supports zero powers:
>>> jnp.linalg.matrix_power(a, 0)
Array([[1., 0.],
       [0., 1.]], dtype=float32)
and negative powers:
>>> with jnp.printoptions(precision=3):
...   jnp.linalg.matrix_power(a, -2)
Array([[ 5.5 , -2.5 ],
       [-3.75,  1.75]], dtype=float32)
Negative powers are equivalent to matrix products of the inverse:
>>> inv_a = jnp.linalg.inv(a)
>>> with jnp.printoptions(precision=3):
...   inv_a @ inv_a
Array([[ 5.5 , -2.5 ],
       [-3.75,  1.75]], dtype=float32)
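The description above says matrix_power works by repeated squaring. The helper below, power_by_squaring, is hypothetical (it is not part of the library): a minimal Python-loop sketch of the idea for non-negative n, checked against jnp.linalg.matrix_power. It needs only about log2(n) squarings instead of n - 1 products; negative powers, which would first invert the matrix, are left out.

```python
import jax.numpy as jnp

def power_by_squaring(a, n):
    """Hypothetical helper: compute a**n for a square matrix a and int n >= 0."""
    result = jnp.eye(a.shape[-1], dtype=a.dtype)
    base = a
    while n > 0:
        if n & 1:                 # lowest remaining bit of n is set
            result = result @ base
        base = base @ base        # a, a**2, a**4, a**8, ...
        n >>= 1
    return result

a = jnp.array([[1., 2.],
               [3., 4.]])
print(power_by_squaring(a, 3))    # same as jnp.linalg.matrix_power(a, 3)
```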
- scico.numpy.linalg.matrix_rank(M, rtol=None, *, hermitian=False, tol=None)
Compute the rank of a matrix.
JAX implementation of numpy.linalg.matrix_rank.

The rank is calculated via the singular value decomposition (SVD) and is the number of singular values greater than the specified tolerance.
- Parameters:
M (Union[Array, ndarray, bool, number, bool, int, float, complex]) – array of shape (..., N, K) whose rank is to be computed.
rtol (Union[Array, ndarray, bool, number, bool, int, float, complex, None]) – optional array of shape (...) specifying the tolerance. Singular values smaller than rtol * largest_singular_value are considered to be zero. If rtol is None (the default), a reasonable default is chosen based on the floating point precision of the input.
hermitian (bool) – if True, the input is assumed to be Hermitian, and a more efficient algorithm is used (default: False).
tol (Union[Array, ndarray, bool, number, bool, int, float, complex, None]) – alias of the rtol argument, present for backward compatibility. Only one of rtol or tol may be specified.
- Return type:
Array
- Returns:
array of shape M.shape[:-2] giving the matrix rank.
Notes
The rank calculation may be inaccurate for matrices with very small singular values or those that are numerically ill-conditioned. Consider adjusting the rtol parameter or using a more specialized rank computation method in such cases.
Examples
>>> a = jnp.array([[1, 2],
...                [3, 4]])
>>> jnp.linalg.matrix_rank(a)
Array(2, dtype=int32)
>>> b = jnp.array([[1, 0],  # Rank-deficient matrix
...                [0, 0]])
>>> jnp.linalg.matrix_rank(b)
Array(1, dtype=int32)
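The reported rank depends on the tolerance. In this sketch the matrix is an arbitrary, nearly rank-deficient example: its second row is almost twice its first, giving singular values of roughly 5.008 and 0.002. With the default tolerance both count, while rtol=1e-2 treats the small one as zero; the last line repeats the count by hand from the singular values.

```python
import jax.numpy as jnp

a = jnp.array([[1., 2.],
               [2., 4.01]])   # second row is almost 2x the first

print(jnp.linalg.matrix_rank(a))             # default tolerance: 2
print(jnp.linalg.matrix_rank(a, rtol=1e-2))  # small singular value dropped: 1

# The same count computed directly from the singular values.
s = jnp.linalg.svd(a, compute_uv=False)
print(s, jnp.sum(s > 1e-2 * s.max()))
```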
- scico.numpy.linalg.matrix_transpose(x, /)
Transpose a matrix or stack of matrices.
JAX implementation of numpy.linalg.matrix_transpose.
- Parameters:
x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – array of shape (..., M, N)
- Return type:
Array
- Returns:
array of shape (..., N, M) containing the matrix transpose of x.
See also
jax.numpy.transpose: more general transpose operation.
Examples
Transpose of a single matrix:
>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> jnp.linalg.matrix_transpose(x)
Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)
Transpose of a stack of matrices:
>>> x = jnp.array([[[1, 2],
...                 [3, 4]],
...                [[5, 6],
...                 [7, 8]]])
>>> jnp.linalg.matrix_transpose(x)
Array([[[1, 3],
        [2, 4]],
       [[5, 7],
        [6, 8]]], dtype=int32)
For convenience, the same computation can be done via the mT property of JAX array objects:
>>> x.mT
Array([[[1, 3],
        [2, 4]],
       [[5, 7],
        [6, 8]]], dtype=int32)
- scico.numpy.linalg.multi_dot(arrays, *, precision=None)
Efficiently compute matrix products between a sequence of arrays.
JAX implementation of numpy.linalg.multi_dot.

JAX internally uses the opt_einsum library to compute the most efficient operation order.
- Parameters:
arrays (Sequence[Union[Array, ndarray, bool, number, bool, int, float, complex]]) – sequence of arrays. All must be two-dimensional, except the first and last, which may be one-dimensional.
precision (Union[None, str, Precision, tuple[str, str], tuple[Precision, Precision], DotAlgorithm, DotAlgorithmPreset]) – either None (default), which means the default precision for the backend, or a Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST).
- Return type:
Array
- Returns:
an array representing the equivalent of reduce(jnp.matmul, arrays), but evaluated in the optimal order.
This function exists because the cost of computing sequences of matmul operations can differ vastly depending on the order in which the operations are evaluated. For a single matmul, the number of floating point operations (flops) required to compute a matrix product can be approximated this way:
>>> def approx_flops(x, y):
...   # for 2D x and y, with x.shape[1] == y.shape[0]
...   return 2 * x.shape[0] * x.shape[1] * y.shape[1]
Suppose we have three matrices that we’d like to multiply in sequence:
>>> key1, key2, key3 = jax.random.split(jax.random.key(0), 3)
>>> x = jax.random.normal(key1, shape=(200, 5))
>>> y = jax.random.normal(key2, shape=(5, 100))
>>> z = jax.random.normal(key3, shape=(100, 10))
Because matrix products are associative, there are two orders in which we might evaluate the product x @ y @ z, and both produce equivalent outputs up to floating point precision:
>>> result1 = (x @ y) @ z
>>> result2 = x @ (y @ z)
>>> jnp.allclose(result1, result2, atol=1E-4)
Array(True, dtype=bool)
But the computational costs of these differ greatly:
>>> print("(x @ y) @ z flops:", approx_flops(x, y) + approx_flops(x @ y, z))
(x @ y) @ z flops: 600000
>>> print("x @ (y @ z) flops:", approx_flops(y, z) + approx_flops(x, y @ z))
x @ (y @ z) flops: 30000
The second approach is about 20x more efficient in terms of estimated flops!
multi_dot is a function that will automatically choose the fastest computational path for such problems:
>>> result3 = jnp.linalg.multi_dot([x, y, z])
>>> jnp.allclose(result1, result3, atol=1E-4)
Array(True, dtype=bool)
We can use JAX’s ahead-of-time lowering and compilation tools to estimate the total flops of each approach, and confirm that multi_dot is choosing the more efficient option:
>>> jax.jit(lambda x, y, z: (x @ y) @ z).lower(x, y, z).cost_analysis()['flops']
600000.0
>>> jax.jit(lambda x, y, z: x @ (y @ z)).lower(x, y, z).cost_analysis()['flops']
30000.0
>>> jax.jit(jnp.linalg.multi_dot).lower([x, y, z]).cost_analysis()['flops']
30000.0
- scico.numpy.linalg.norm(x, ord=None, axis=None, keepdims=False)
Compute the norm of a matrix or vector.
JAX implementation of numpy.linalg.norm.
- Parameters:
x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – N-dimensional array for which the norm will be computed.
ord (int | str | None) – specify the kind of norm to take. Default is the Frobenius norm for matrices, and the 2-norm for vectors. For other options, see Notes below.
axis (None | tuple[int, ...] | int) – integer or sequence of integers specifying the axes over which the norm will be computed. For a single axis, compute a vector norm. For two axes, compute a matrix norm. Defaults to all axes of x.
keepdims (bool) – if True, the output array will have the same number of dimensions as the input, with the size of reduced axes replaced by 1 (default: False).
- Return type:
Array
- Returns:
array containing the specified norm of x.
Notes
The flavor of norm computed depends on the value of ord and the number of axes being reduced.
For vector norms (i.e. a single axis reduction):
- ord=None (default) computes the 2-norm
- ord=inf computes max(abs(x))
- ord=-inf computes min(abs(x))
- ord=0 computes sum(x!=0)
- for other numerical values, computes sum(abs(x) ** ord)**(1/ord)
For matrix norms (i.e. two axes reductions):
- ord='fro' or ord=None (default) computes the Frobenius norm
- ord='nuc' computes the nuclear norm, or the sum of the singular values
- ord=1 computes max(abs(x).sum(0))
- ord=-1 computes min(abs(x).sum(0))
- ord=2 computes the 2-norm, i.e. the largest singular value
- ord=-2 computes the smallest singular value
In the special case of ord=None and axis=None, this function accepts an array of any dimension and computes the vector 2-norm of the flattened array.
Examples
Vector norms:
>>> x = jnp.array([3., 4., 12.])
>>> jnp.linalg.norm(x)
Array(13., dtype=float32)
>>> jnp.linalg.norm(x, ord=1)
Array(19., dtype=float32)
>>> jnp.linalg.norm(x, ord=0)
Array(3., dtype=float32)
Matrix norms:
>>> x = jnp.array([[1., 2., 3.],
...                [4., 5., 7.]])
>>> jnp.linalg.norm(x)  # Frobenius norm
Array(10.198039, dtype=float32)
>>> jnp.linalg.norm(x, ord='nuc')  # nuclear norm
Array(10.762535, dtype=float32)
>>> jnp.linalg.norm(x, ord=1)  # 1-norm
Array(10., dtype=float32)
Batched vector norm:
>>> jnp.linalg.norm(x, axis=1)
Array([3.7416575, 9.486833 ], dtype=float32)
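Each ord value listed in the Notes corresponds to a simple formula. The sketch below, with arbitrary example data, checks several of them directly against jnp.linalg.norm.

```python
import jax.numpy as jnp

v = jnp.array([3., -4., 0., 12.])
print(jnp.linalg.norm(v, ord=jnp.inf))     # max(abs(v)) = 12
print(jnp.linalg.norm(v, ord=-jnp.inf))    # min(abs(v)) = 0
print(jnp.linalg.norm(v, ord=0))           # number of nonzeros = 3
p3 = jnp.linalg.norm(v, ord=3)             # sum(abs(v)**3)**(1/3)
print(p3, jnp.sum(jnp.abs(v) ** 3) ** (1 / 3))

m = jnp.array([[1., -2.],
               [3., 4.]])
print(jnp.linalg.norm(m, ord=1))           # max column abs-sum = 6
print(jnp.linalg.norm(m, ord=-1))          # min column abs-sum = 4
s = jnp.linalg.svd(m, compute_uv=False)
print(jnp.linalg.norm(m, ord=2), s.max())  # largest singular value
print(jnp.linalg.norm(m, ord=-2), s.min()) # smallest singular value
```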
- scico.numpy.linalg.outer(x1, x2, /)
Compute the outer product of two 1-dimensional arrays.
JAX implementation of numpy.linalg.outer.
- Parameters:
- Return type:
Array
- Returns:
array containing the outer product of x1 and x2.
See also
jax.numpy.outer: similar function in the main jax.numpy module.
Examples
>>> x1 = jnp.array([1, 2, 3])
>>> x2 = jnp.array([4, 5, 6])
>>> jnp.linalg.outer(x1, x2)
Array([[ 4,  5,  6],
       [ 8, 10, 12],
       [12, 15, 18]], dtype=int32)
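The outer product is plain broadcasting: entry (i, j) is x1[i] * x2[j]. The sketch below checks this equivalence on the example inputs.

```python
import jax.numpy as jnp

x1 = jnp.array([1, 2, 3])
x2 = jnp.array([4, 5, 6])

# Add a trailing axis to x1 and a leading axis to x2, then multiply elementwise.
manual = x1[:, None] * x2[None, :]
print(jnp.array_equal(jnp.linalg.outer(x1, x2), manual))
```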
- scico.numpy.linalg.pinv(a, rtol=None, hermitian=False, *, rcond=None)
Compute the (Moore-Penrose) pseudo-inverse of a matrix.
JAX implementation of numpy.linalg.pinv.
- Parameters:
a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – array of shape (..., M, N) containing matrices to pseudo-invert.
rtol (Union[Array, ndarray, bool, number, bool, int, float, complex, None]) – float or array_like of shape a.shape[:-2] specifying the cutoff for small singular values: singular values smaller than rtol * largest_singular_value are treated as zero. The default is determined based on the floating point precision of the dtype.
hermitian (bool) – if True, the input is assumed to be Hermitian, and a more efficient algorithm is used (default: False).
rcond (Union[Array, ndarray, bool, number, bool, int, float, complex, None]) – alias of the rtol argument, present for backward compatibility. Only one of rtol and rcond may be specified.
- Return type:
Array
- Returns:
An array of shape (..., N, M) containing the pseudo-inverse of a.
See also
jax.numpy.linalg.inv: multiplicative inverse of a square matrix.
Notes
jax.numpy.linalg.pinv differs from numpy.linalg.pinv in the default value of rcond: in NumPy, the default is 1e-15. In JAX, the default is 10. * max(num_rows, num_cols) * jnp.finfo(dtype).eps.
Examples
>>> a = jnp.array([[1, 2],
...                [3, 4],
...                [5, 6]])
>>> a_pinv = jnp.linalg.pinv(a)
>>> a_pinv
Array([[-1.333332  , -0.33333257,  0.6666657 ],
       [ 1.0833322 ,  0.33333272, -0.41666582]], dtype=float32)
The pseudo-inverse acts as a multiplicative (left) inverse so long as the input is not rank-deficient:
>>> jnp.allclose(a_pinv @ a, jnp.eye(2), atol=1E-4)
Array(True, dtype=bool)
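For a rank-deficient input the pseudo-inverse is no longer a true inverse, but it still satisfies the Moore-Penrose conditions such as a @ a_pinv @ a == a. This sketch uses an arbitrary rank-1 matrix, the outer product of u = [1, 2, 3] and w = [1, 2], whose pseudo-inverse has the closed form outer(w, u) / (|u|^2 |w|^2) = outer(w, u) / 70.

```python
import jax.numpy as jnp

u = jnp.array([1., 2., 3.])
w = jnp.array([1., 2.])
a = jnp.outer(u, w)            # [[1, 2], [2, 4], [3, 6]], rank 1
a_pinv = jnp.linalg.pinv(a)    # shape (2, 3)

# Closed form for a rank-1 matrix u w^T.
print(jnp.allclose(a_pinv, jnp.outer(w, u) / 70., atol=1e-5))

# Two of the Moore-Penrose conditions.
print(jnp.allclose(a @ a_pinv @ a, a, atol=1e-4))
print(jnp.allclose(a_pinv @ a @ a_pinv, a_pinv, atol=1e-5))

# Not a left inverse here, because the columns of a are linearly dependent.
print(jnp.allclose(a_pinv @ a, jnp.eye(2), atol=1e-4))   # False
```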
- scico.numpy.linalg.qr(a, mode='reduced')
Compute the QR decomposition of an array.
JAX implementation of numpy.linalg.qr.

The QR decomposition of a matrix A is given by

\[A = QR\]

where Q is a unitary matrix (i.e. \(Q^HQ=I\)) and R is an upper-triangular matrix.
- Parameters:
a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – array of shape (..., M, N)
mode (str) – computational mode. Supported values are:
- "reduced" (default): return Q of shape (..., M, K) and R of shape (..., K, N), where K = min(M, N).
- "complete": return Q of shape (..., M, M) and R of shape (..., M, N).
- "raw": return LAPACK-internal representations of shape (..., M, N) and (..., K).
- "r": return R only.
- Return type:
Array | QRResult
- Returns:
A tuple (Q, R) if mode is not "r", otherwise an array R, where:
Q is an orthogonal matrix of shape (..., M, K) (if mode is "reduced") or (..., M, M) (if mode is "complete").
R is an upper-triangular matrix of shape (..., M, N) (if mode is "r" or "complete") or (..., K, N) (if mode is "reduced"),
with K = min(M, N).
See also
jax.scipy.linalg.qr: SciPy-style QR decomposition API
jax.lax.linalg.qr: XLA-style QR decomposition API
Examples
Compute the QR decomposition of a matrix:
>>> a = jnp.array([[1., 2., 3., 4.],
...                [5., 4., 2., 1.],
...                [6., 3., 1., 5.]])
>>> Q, R = jnp.linalg.qr(a)
>>> Q
Array([[-0.12700021, -0.7581426 , -0.6396022 ],
       [-0.63500065, -0.43322435,  0.63960224],
       [-0.7620008 ,  0.48737738, -0.42640156]], dtype=float32)
>>> R
Array([[-7.8740077, -5.080005 , -2.4130025, -4.953006 ],
       [ 0.       , -1.7870499, -2.6534991, -1.028908 ],
       [ 0.       ,  0.       , -1.0660033, -4.050814 ]], dtype=float32)
Check that Q is orthonormal:
>>> jnp.allclose(Q.T @ Q, jnp.eye(3), atol=1E-5)
Array(True, dtype=bool)
Reconstruct the input:
>>> jnp.allclose(Q @ R, a)
Array(True, dtype=bool)
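For a tall matrix, the reduced QR decomposition gives a numerically stable route to least squares: with A = QR, minimizing |A c - y| reduces to the triangular system R c = Q^T y. This sketch fits a line to arbitrary illustrative data (design matrix with a column of ones and a column of t), uses jax.scipy.linalg.solve_triangular for the back substitution, and prints the Q and R shapes for the "reduced" and "complete" modes described above.

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

t = jnp.array([0., 1., 2., 3.])
y = jnp.array([1., 2., 2., 4.])
A = jnp.stack([jnp.ones_like(t), t], axis=1)   # shape (4, 2)

Q, R = jnp.linalg.qr(A)                        # mode="reduced"
print(Q.shape, R.shape)                        # (4, 2) (2, 2)

# Least squares: solve the upper-triangular system R c = Q^T y.
c = solve_triangular(R, Q.T @ y, lower=False)
print(c)                                       # approximately [0.9, 0.9]

Qc, Rc = jnp.linalg.qr(A, mode="complete")
print(Qc.shape, Rc.shape)                      # (4, 4) (4, 2)
```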
- scico.numpy.linalg.slogdet(a, *, method=None)
Compute the sign and (natural) logarithm of the determinant of an array.
JAX implementation of numpy.linalg.slogdet.
- Parameters:
- Return type:
SlogdetResult
- Returns:
A tuple of arrays (sign, logabsdet), each of shape a.shape[:-2], where:
sign is the sign of the determinant.
logabsdet is the natural log of the determinant’s absolute value.
See also
jax.numpy.linalg.det: direct computation of determinant
Examples
>>> a = jnp.array([[1, 2],
...                [3, 4]])
>>> sign, logabsdet = jnp.linalg.slogdet(a)
>>> sign  # -1 indicates negative determinant
Array(-1., dtype=float32)
>>> jnp.exp(logabsdet)  # absolute value of determinant
Array(2., dtype=float32)
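The main reason to prefer slogdet is range: the determinant of a moderately sized matrix can overflow (or underflow) floating point even when its logarithm is modest. In this sketch, 100 times the 50x50 identity has determinant 100**50 = 1e100, which exceeds the float32 maximum (about 3.4e38), while its log-determinant is 50 * log(100), roughly 230.26. The overflow assumes the default 32-bit mode (jax_enable_x64 off).

```python
import jax.numpy as jnp

a = 100. * jnp.eye(50)              # float32 by default

print(jnp.linalg.det(a))            # overflows float32: inf

sign, logabsdet = jnp.linalg.slogdet(a)
print(sign, logabsdet)              # 1.0 and about 230.26
print(50 * jnp.log(100.))           # the exact log-determinant
```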
- scico.numpy.linalg.solve(a, b)
Solve a linear system of equations.
JAX implementation of numpy.linalg.solve.

This solves a (batched) linear system of equations a @ x = b for x given a and b.

If a is singular, this will return nan or inf values.
- Parameters:
- Return type:
Array
- Returns:
An array containing the result of the linear solve if a is non-singular. The result has shape (..., N) if b is of shape (N,), and has shape (..., N, M) otherwise. If a is singular, the result contains nan or inf values.
See also
jax.scipy.linalg.solve: SciPy-style API for solving linear systems.
jax.lax.custom_linear_solve: matrix-free linear solver.
Examples
A simple 3x3 linear system:
>>> A = jnp.array([[1., 2., 3.],
...                [2., 4., 2.],
...                [3., 2., 1.]])
>>> b = jnp.array([14., 16., 10.])
>>> x = jnp.linalg.solve(A, b)
>>> x
Array([1., 2., 3.], dtype=float32)
Confirming that the result solves the system:
>>> jnp.allclose(A @ x, b)
Array(True, dtype=bool)
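solve also handles stacks of systems and several right-hand sides at once. In the sketch below (arbitrary values), a has shape (2, 2, 2), holding two 2x2 systems, and b has shape (2, 2, 3), giving three right-hand sides per system; x then has the same shape as b.

```python
import jax.numpy as jnp

a = jnp.array([[[2., 1.],
                [1., 3.]],
               [[4., 0.],
                [0., 5.]]])                   # shape (2, 2, 2)
b = jnp.arange(12.).reshape(2, 2, 3)          # shape (2, 2, 3)

x = jnp.linalg.solve(a, b)
print(x.shape)                                # (2, 2, 3)

# Batched matmul confirms a[k] @ x[k] == b[k] for each system k.
print(jnp.allclose(a @ x, b, atol=1e-4))
```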
- scico.numpy.linalg.svd(a, full_matrices=True, compute_uv=True, hermitian=False, subset_by_index=None)¶
Compute the singular value decomposition.
JAX implementation of
numpy.linalg.svd, implemented in terms ofjax.lax.linalg.svd.The SVD of a matrix A is given by
\[A = U\Sigma V^H\]\(U\) contains the left singular vectors and satisfies \(U^HU=I\)
\(V\) contains the right singular vectors and satisfies \(V^HV=I\)
\(\Sigma\) is a diagonal matrix of singular values.
- Parameters:
a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array, of shape (..., N, M).
full_matrices (bool) – if True (default) compute the full matrices; i.e. u and vh have shape (..., N, N) and (..., M, M). If False, then the shapes are (..., N, K) and (..., K, M) with K = min(N, M).
compute_uv (bool) – if True (default), return the full SVD (u, s, vh). If False then return only the singular values s.
hermitian (bool) – if True, assume the matrix is Hermitian, which allows for a more efficient implementation (default=False).
subset_by_index (tuple[int, int] | None) – (TPU-only) Optional 2-tuple [start, end] indicating the range of indices of singular values to compute. For example, if [n-2, n] then svd computes the two largest singular values and their singular vectors. Only compatible with full_matrices=False.
- Return type:
Array | SVDResult
- Returns:
A tuple of arrays (u, s, vh) if compute_uv is True, otherwise the array s.
u: left singular vectors of shape (..., N, N) if full_matrices is True or (..., N, K) otherwise.
s: singular values of shape (..., K).
vh: conjugate-transposed right singular vectors of shape (..., M, M) if full_matrices is True or (..., K, M) otherwise.
where K = min(N, M).
See also
jax.scipy.linalg.svd: SciPy-style SVD API.
jax.lax.linalg.svd: XLA-style SVD API.
Examples
Consider the SVD of a small real-valued array:
>>> x = jnp.array([[1., 2., 3.],
...                [6., 5., 4.]])
>>> u, s, vt = jnp.linalg.svd(x, full_matrices=False)
>>> s
Array([9.361919 , 1.8315067], dtype=float32)
The singular vectors are in the columns of u and v = vt.T. These vectors are orthonormal, which can be demonstrated by comparing the matrix product with the identity matrix:
>>> jnp.allclose(u.T @ u, jnp.eye(2), atol=1E-5)
Array(True, dtype=bool)
>>> v = vt.T
>>> jnp.allclose(v.T @ v, jnp.eye(2), atol=1E-5)
Array(True, dtype=bool)
Given the SVD, x can be reconstructed via matrix multiplication:
>>> x_reconstructed = u @ jnp.diag(s) @ vt
>>> jnp.allclose(x_reconstructed, x)
Array(True, dtype=bool)
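To make the full_matrices and compute_uv options concrete, the sketch below reuses the 2x3 array from the examples above (N = 2, M = 3, so K = 2) and checks the output shapes of each mode. It then keeps only the leading singular triplet to form a rank-1 approximation. By the Eckart-Young-Mirsky theorem (a standard result, not stated on this page), the spectral-norm error of that approximation should equal the second singular value. The code calls jax.numpy.linalg directly, as the examples here do.

```python
import jax.numpy as jnp

x = jnp.array([[1., 2., 3.],
               [6., 5., 4.]])  # N = 2, M = 3, so K = min(N, M) = 2

# Full vs. reduced ("economy") output shapes.
u, s, vh = jnp.linalg.svd(x)                        # full_matrices=True
print(u.shape, s.shape, vh.shape)                   # (2, 2) (2,) (3, 3)
u_r, s_r, vh_r = jnp.linalg.svd(x, full_matrices=False)
print(u_r.shape, s_r.shape, vh_r.shape)             # (2, 2) (2,) (2, 3)

# compute_uv=False returns only the singular values.
s_only = jnp.linalg.svd(x, compute_uv=False)

# Rank-1 truncation: keep only the largest singular value and its vectors.
x1 = s_r[0] * jnp.outer(u_r[:, 0], vh_r[0, :])

# ord=2 on a matrix gives the spectral norm (largest singular value),
# so this error should match the discarded singular value s_r[1].
err = jnp.linalg.norm(x - x1, ord=2)
```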
- scico.numpy.linalg.svdvals(x, /)¶
Compute the singular values of a matrix.
JAX implementation of numpy.linalg.svdvals.
- Parameters:
x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – array of shape (..., M, N) for which singular values will be computed.
- Return type:
- Returns:
array of singular values of shape (..., K) with K = min(M, N).
See also
jax.numpy.linalg.svd: compute singular values and singular vectors.
Examples
>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> jnp.linalg.svdvals(x)
Array([9.508031 , 0.7728694], dtype=float32)
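Two standard identities help to interpret these values: they match svd(..., compute_uv=False), and their squares are the eigenvalues of the Gram matrix x @ x.T. The sketch below checks both for the array in the example above. The Gram-matrix route is shown only as a cross-check; forming x @ x.T squares the condition number, so it is not a recommended way to compute singular values.

```python
import jax.numpy as jnp

x = jnp.array([[1., 2., 3.],
               [4., 5., 6.]])

sv = jnp.linalg.svdvals(x)

# Same values as svd(..., compute_uv=False), without computing u and vh.
sv_svd = jnp.linalg.svd(x, compute_uv=False)

# Squared singular values are the eigenvalues of x @ x.T; eigvalsh returns
# eigenvalues in ascending order, so reverse them to match svdvals.
eig = jnp.linalg.eigvalsh(x @ x.T)[::-1]
```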
- scico.numpy.linalg.tensordot(x1, x2, /, *, axes=2, precision=None, preferred_element_type=None, out_sharding=None)¶
Compute the tensor dot product of two N-dimensional arrays.
JAX implementation of numpy.linalg.tensordot.
- Parameters:
x1 (Union[Array, ndarray, bool, number, bool, int, float, complex]) – N-dimensional array.
x2 (Union[Array, ndarray, bool, number, bool, int, float, complex]) – M-dimensional array.
axes (int | tuple[Sequence[int], Sequence[int]]) – integer or tuple of sequences of integers. If an integer k, then sum over the last k axes of x1 and the first k axes of x2, in order. If a tuple, then axes[0] specifies the axes of x1 and axes[1] specifies the axes of x2.
precision (Union[None, str, Precision, tuple[str, str], tuple[Precision, Precision], DotAlgorithm, DotAlgorithmPreset]) – either None (default), which means the default precision for the backend, a Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST) or a tuple of two such values indicating precision of x1 and x2.
preferred_element_type (Union[str, type[Any], dtype, SupportsDType, None]) – either None (default), which means the default accumulation type for the input types, or a datatype, indicating to accumulate results to and return a result with that datatype.
- Return type:
- Returns:
array containing the tensor dot product of the inputs.
See also
jax.numpy.tensordot: equivalent API in the jax.numpy namespace.
jax.numpy.einsum: NumPy API for more general tensor contractions.
jax.lax.dot_general: XLA API for more general tensor contractions.
Examples
>>> x1 = jnp.arange(24.).reshape(2, 3, 4)
>>> x2 = jnp.ones((3, 4, 5))
>>> jnp.linalg.tensordot(x1, x2)
Array([[ 66.,  66.,  66.,  66.,  66.],
       [210., 210., 210., 210., 210.]], dtype=float32)
Equivalent result when specifying the axes as explicit sequences:
>>> jnp.linalg.tensordot(x1, x2, axes=([1, 2], [0, 1]))
Array([[ 66.,  66.,  66.,  66.,  66.],
       [210., 210., 210., 210., 210.]], dtype=float32)
Equivalent result via einsum:
>>> jnp.einsum('ijk,jkm->im', x1, x2)
Array([[ 66.,  66.,  66.,  66.,  66.],
       [210., 210., 210., 210., 210.]], dtype=float32)
Setting axes=1 for two-dimensional inputs is equivalent to a matrix multiplication:
>>> x1 = jnp.array([[1, 2],
...                 [3, 4]])
>>> x2 = jnp.array([[1, 2, 3],
...                 [4, 5, 6]])
>>> jnp.linalg.tensordot(x1, x2, axes=1)
Array([[ 9, 12, 15],
       [19, 26, 33]], dtype=int32)
>>> x1 @ x2
Array([[ 9, 12, 15],
       [19, 26, 33]], dtype=int32)
Setting axes=0 for one-dimensional inputs is equivalent to jax.numpy.linalg.outer:
>>> x1 = jnp.array([1, 2])
>>> x2 = jnp.array([1, 2, 3])
>>> jnp.linalg.tensordot(x1, x2, axes=0)
Array([[1, 2, 3],
       [2, 4, 6]], dtype=int32)
>>> jnp.linalg.outer(x1, x2)
Array([[1, 2, 3],
       [2, 4, 6]], dtype=int32)
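For an integer axes=k, the contraction described under Parameters can be read as a reshape followed by an ordinary matrix product. Flatten x1 into (free axes, contracted axes) and x2 into (contracted axes, free axes), multiply, then restore the free axes. The sketch below reproduces the first example on this page that way. It illustrates the semantics only and is not meant to describe how JAX implements tensordot internally.

```python
import math

import jax.numpy as jnp

x1 = jnp.arange(24.).reshape(2, 3, 4)
x2 = jnp.ones((3, 4, 5))
k = 2  # contract the last k axes of x1 with the first k axes of x2

# Split shapes into free and contracted parts (x2.shape[:k] must equal
# x1.shape[-k:] for the contraction to be valid).
free1, contracted = x1.shape[:-k], x1.shape[-k:]
free2 = x2.shape[k:]

# Flatten both operands to matrices; the contraction becomes a matmul.
a = x1.reshape(math.prod(free1), math.prod(contracted))  # (2, 12)
b = x2.reshape(math.prod(contracted), math.prod(free2))  # (12, 5)
result = (a @ b).reshape(*free1, *free2)                 # (2, 5)

print(jnp.allclose(result, jnp.linalg.tensordot(x1, x2, axes=k)))  # True
```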
- scico.numpy.linalg.tensorinv(a, ind=2)¶
Compute the tensor inverse of an array.
JAX implementation of numpy.linalg.tensorinv.
This computes the inverse of the tensordot operation with the same ind value.
- Parameters:
- Return type:
- Returns:
array of shape (*a.shape[ind:], *a.shape[:ind]) containing the tensor inverse of a.
Examples
>>> key = jax.random.key(1337)
>>> x = jax.random.normal(key, shape=(2, 2, 4))
>>> xinv = jnp.linalg.tensorinv(x, 2)
>>> xinv_x = jnp.linalg.tensordot(xinv, x, axes=2)
>>> jnp.allclose(xinv_x, jnp.eye(4), atol=1E-4)
Array(True, dtype=bool)
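The output shape given under Returns suggests a simple mental model for tensorinv. Group the first ind axes and the remaining axes into a square matrix, invert that matrix, and reshape the result with the two groups of axes swapped. The sketch below checks this model against jnp.linalg.tensorinv on a hypothetical well-conditioned tensor (a random 4x4 matrix plus 4 * I, reshaped to (2, 2, 4)). It illustrates the semantics and is not a description of JAX's internal code.

```python
import math

import jax
import jax.numpy as jnp

ind = 2
# Hypothetical well-conditioned input: prod(shape[:ind]) == prod(shape[ind:]) == 4.
m = jax.random.normal(jax.random.key(0), (4, 4)) + 4.0 * jnp.eye(4)
a = m.reshape(2, 2, 4)

# Flatten the leading `ind` axes and the trailing axes into a square matrix,
# invert it, then reshape with the two groups of axes swapped.
lead, trail = a.shape[:ind], a.shape[ind:]
a_mat = a.reshape(math.prod(lead), math.prod(trail))
ainv = jnp.linalg.inv(a_mat).reshape(*trail, *lead)  # shape (4, 2, 2)

print(jnp.allclose(ainv, jnp.linalg.tensorinv(a, ind), atol=1e-5))  # True
```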
- scico.numpy.linalg.tensorsolve(a, b, axes=None)¶
Solve the tensor equation a x = b for x.
JAX implementation of numpy.linalg.tensorsolve.
- Parameters:
a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array. After reordering via axes (see below), shape must be (*b.shape, *x.shape).
b (Union[Array, ndarray, bool, number, bool, int, float, complex]) – right-hand-side array.
axes (tuple[int, ...] | None) – optional tuple specifying axes of a that should be moved to the end.
- Return type:
- Returns:
array x such that after reordering of axes of a, tensordot(a, x, x.ndim) is equivalent to b.
Examples
>>> key1, key2 = jax.random.split(jax.random.key(8675309))
>>> a = jax.random.normal(key1, shape=(2, 2, 4))
>>> b = jax.random.normal(key2, shape=(2, 2))
>>> x = jnp.linalg.tensorsolve(a, b)
>>> x.shape
(4,)
Now show that x can be used to reconstruct b using tensordot:
>>> b_reconstructed = jnp.linalg.tensordot(a, x, axes=x.ndim)
>>> jnp.allclose(b, b_reconstructed)
Array(True, dtype=bool)
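Because a must have shape (*b.shape, *x.shape), tensorsolve reduces to an ordinary matrix solve once the b-axes and the x-axes of a are each flattened. The sketch below checks that view against jnp.linalg.tensorsolve on hypothetical well-conditioned inputs. It also shows the axes parameter: if the x-axis of a comes first, naming it in axes moves it to the end and yields the same solution.

```python
import math

import jax
import jax.numpy as jnp

k1, k2 = jax.random.split(jax.random.key(0))
# Hypothetical inputs: a has shape (*b.shape, *x.shape) = (2, 2, 4);
# adding 4*I to the underlying 4x4 matrix keeps it well conditioned.
a = (jax.random.normal(k1, (4, 4)) + 4.0 * jnp.eye(4)).reshape(2, 2, 4)
b = jax.random.normal(k2, (2, 2))

# Equivalent matrix problem: flatten the b-axes and the x-axes of a,
# solve, and restore the shape of x.
x_shape = a.shape[b.ndim:]
x_mat = jnp.linalg.solve(a.reshape(b.size, math.prod(x_shape)), b.ravel())
x_ref = x_mat.reshape(x_shape)

# `axes` moves the listed axes of `a` to the end before solving, so a tensor
# whose x-axis comes first gives the same answer when that axis is named.
a_moved = jnp.moveaxis(a, -1, 0)  # shape (4, 2, 2)
x_axes = jnp.linalg.tensorsolve(a_moved, b, axes=(0,))
```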
- scico.numpy.linalg.trace(x, /, *, offset=0, dtype=None)¶
Compute the trace of a matrix.
JAX implementation of numpy.linalg.trace.
- Parameters:
x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – array of shape (..., M, N) whose innermost two dimensions form MxN matrices for which to take the trace.
offset (int) – positive or negative offset from the main diagonal (default: 0).
dtype (Union[str, type[Any], dtype, SupportsDType, None]) – data type of the returned array (default: None). If None, then output dtype will match the dtype of x, promoted to default precision in the case of integer types.
- Return type:
- Returns:
array of batched traces with shape x.shape[:-2].
See also
jax.numpy.trace: similar API in the jax.numpy namespace.
Examples
Trace of a single matrix:
>>> x = jnp.array([[1, 2, 3, 4],
...                [5, 6, 7, 8],
...                [9, 10, 11, 12]])
>>> jnp.linalg.trace(x)
Array(18, dtype=int32)
>>> jnp.linalg.trace(x, offset=1)
Array(21, dtype=int32)
>>> jnp.linalg.trace(x, offset=-1, dtype="float32")
Array(15., dtype=float32)
Batched traces:
>>> x = jnp.arange(24).reshape(2, 3, 4)
>>> jnp.linalg.trace(x)
Array([15, 51], dtype=int32)
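The See also entry calls jax.numpy.trace only "similar", and the difference matters for stacked inputs. jnp.linalg.trace always reduces over the last two axes, while jnp.trace defaults to axis1=0, axis2=1. The sketch below contrasts the two on the batched array from the example above. It also shows that an offset trace is the sum along the corresponding diagonal.

```python
import jax.numpy as jnp

x = jnp.arange(24).reshape(2, 3, 4)  # a stack of two 3x4 matrices

# linalg.trace reduces over the last two axes: one trace per matrix.
t_linalg = jnp.linalg.trace(x)                 # [15, 51]
t_last_two = jnp.trace(x, axis1=-2, axis2=-1)  # same result

# jnp.trace defaults to axis1=0, axis2=1, so here it sums
# x[0, 0, :] + x[1, 1, :] instead of tracing each matrix.
t_default = jnp.trace(x)                       # [16, 18, 20, 22]

# An offset trace is the sum along the corresponding (batched) diagonal.
t_off = jnp.linalg.trace(x, offset=1)          # [18, 54]
d_off = jnp.diagonal(x, offset=1, axis1=-2, axis2=-1).sum(axis=-1)
```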
- scico.numpy.linalg.vecdot(x1, x2, /, *, axis=-1, precision=None, preferred_element_type=None)¶
Compute the (batched) vector conjugate dot product of two arrays.
JAX implementation of numpy.linalg.vecdot.
- Parameters:
x1 (Union[Array, ndarray, bool, number, bool, int, float, complex]) – left-hand side array.
x2 (Union[Array, ndarray, bool, number, bool, int, float, complex]) – right-hand side array. x2.shape[axis] must match x1.shape[axis], and the remaining dimensions must be broadcast-compatible.
axis (int) – axis along which to compute the dot product (default: -1).
precision (Union[None, str, Precision, tuple[str, str], tuple[Precision, Precision], DotAlgorithm, DotAlgorithmPreset]) – either None (default), which means the default precision for the backend, a Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST) or a tuple of two such values indicating precision of x1 and x2.
preferred_element_type (Union[str, type[Any], dtype, SupportsDType, None]) – either None (default), which means the default accumulation type for the input types, or a datatype, indicating to accumulate results to and return a result with that datatype.
- Return type:
- Returns:
array containing the conjugate dot product of x1 and x2 along axis. The non-contracted dimensions are broadcast together.
See also
jax.numpy.vecdot: similar API in the jax.numpy namespace.
jax.numpy.linalg.matmul: matrix multiplication.
jax.numpy.linalg.tensordot: general tensor dot product.
Examples
Vector dot product of two 1D arrays:
>>> x1 = jnp.array([1, 2, 3])
>>> x2 = jnp.array([4, 5, 6])
>>> jnp.linalg.vecdot(x1, x2)
Array(32, dtype=int32)
Batched vector dot product of two 2D arrays:
>>> x1 = jnp.array([[1, 2, 3],
...                 [4, 5, 6]])
>>> x2 = jnp.array([[2, 3, 4]])
>>> jnp.linalg.vecdot(x1, x2, axis=-1)
Array([20, 47], dtype=int32)
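The integer examples above do not show the "conjugate" part of the description. For complex inputs, the first argument is conjugated, so the result is sum(conj(x1) * x2) along axis. The sketch below uses small hand-picked complex vectors to compare vecdot with the explicit formula and with an unconjugated product.

```python
import jax.numpy as jnp

x1 = jnp.array([1 + 1j, 2 + 0j])
x2 = jnp.array([1j, 1 + 0j])

# vecdot conjugates its first argument: sum(conj(x1) * x2).
v = jnp.linalg.vecdot(x1, x2)          # 3+1j
v_manual = jnp.sum(jnp.conj(x1) * x2)  # 3+1j

# Dropping the conjugation gives a different value.
v_plain = jnp.sum(x1 * x2)             # 1+1j
```

For real inputs the conjugation is a no-op, which is why the integer examples match an ordinary dot product.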
- scico.numpy.linalg.vector_norm(x, /, *, axis=None, keepdims=False, ord=2)¶
Compute the vector norm of a vector or batch of vectors.
JAX implementation of numpy.linalg.vector_norm.
- Parameters:
x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – N-dimensional array for which to take the norm.
axis (int | tuple[int, ...] | None) – optional axis along which to compute the vector norm. If None (default) then x is flattened and the norm is taken over all values.
keepdims (bool) – if True, keep the reduced dimensions in the output.
ord (int | str) – A string or int specifying the type of norm; default is the 2-norm. See numpy.linalg.norm for details on available options.
- Return type:
- Returns:
array containing the norm of x.
See also
jax.numpy.linalg.matrix_norm: Norm of a matrix or stack of matrices.
jax.numpy.linalg.norm: More general matrix or vector norm.
Examples
Norm of a single vector:
>>> x = jnp.array([1., 2., 3.])
>>> jnp.linalg.vector_norm(x)
Array(3.7416575, dtype=float32)
Norm of a batch of vectors:
>>> x = jnp.array([[1., 2., 3.],
...                [4., 5., 7.]])
>>> jnp.linalg.vector_norm(x, axis=1)
Array([3.7416575, 9.486833 ], dtype=float32)
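The ord, axis and keepdims parameters combine as sketched below on the same 2x3 array. The ord values follow the numpy.linalg.norm conventions referenced under Parameters: 1 sums absolute values, inf takes the largest absolute value. Passing the float jnp.inf is an assumption here, since the annotation lists int | str, although it follows the NumPy convention. keepdims=True is used to normalize each row, a common use of the option.

```python
import jax.numpy as jnp

x = jnp.array([[1., 2., 3.],
               [4., 5., 7.]])

# ord follows numpy.linalg.norm: ord=1 sums absolute values,
# ord=inf takes the largest absolute value (assumed accepted, see lead-in).
l1 = jnp.linalg.vector_norm(x, axis=1, ord=1)          # [6., 16.]
linf = jnp.linalg.vector_norm(x, axis=1, ord=jnp.inf)  # [3., 7.]

# With axis=None, x is flattened: sqrt(1 + 4 + 9 + 16 + 25 + 49) = sqrt(104).
total = jnp.linalg.vector_norm(x)

# keepdims=True leaves a length-1 axis, so the result broadcasts against x;
# here it scales each row to unit 2-norm.
row_norms = jnp.linalg.vector_norm(x, axis=1, keepdims=True)  # shape (2, 1)
unit_rows = x / row_norms
```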