scico.numpy.linalg

Linear algebra functions.

Functions

cholesky(a, *[, upper, symmetrize_input])

Compute the Cholesky decomposition of a matrix.

cond(x[, p])

Compute the condition number of a matrix.

cross(x1, x2, /, *[, axis])

Compute the cross-product of two 3D vectors.

det(a)

Compute the determinant of an array.

diagonal(x, /, *[, offset])

Extract the diagonal of a matrix or stack of matrices.

eig(a)

Compute the eigenvalues and eigenvectors of a square array.

eigh(a[, UPLO, symmetrize_input])

Compute the eigenvalues and eigenvectors of a Hermitian matrix.

eigvals(a)

Compute the eigenvalues of a general matrix.

eigvalsh(a[, UPLO, symmetrize_input])

Compute the eigenvalues of a Hermitian matrix.

inv(a)

Return the inverse of a square matrix.

lstsq(a, b[, rcond, numpy_resid])

Return the least-squares solution to a linear equation.

matmul(x1, x2, /, *[, precision, ...])

Perform a matrix multiplication.

matrix_norm(x, /, *[, keepdims, ord])

Compute the norm of a matrix or stack of matrices.

matrix_power(a, n)

Raise a square matrix to an integer power.

matrix_rank(M[, rtol, hermitian, tol])

Compute the rank of a matrix.

matrix_transpose(x, /)

Transpose a matrix or stack of matrices.

multi_dot(arrays, *[, precision])

Efficiently compute matrix products between a sequence of arrays.

norm(x[, ord, axis, keepdims])

Compute the norm of a matrix or vector.

outer(x1, x2, /)

Compute the outer product of two 1-dimensional arrays.

pinv(a[, rtol, hermitian, rcond])

Compute the (Moore-Penrose) pseudo-inverse of a matrix.

qr(a[, mode])

Compute the QR decomposition of an array.

slogdet(a, *[, method])

Compute the sign and (natural) logarithm of the determinant of an array.

solve(a, b)

Solve a linear system of equations.

svd(a[, full_matrices, compute_uv, ...])

Compute the singular value decomposition.

svdvals(x, /)

Compute the singular values of a matrix.

tensordot(x1, x2, /, *[, axes, precision, ...])

Compute the tensor dot product of two N-dimensional arrays.

tensorinv(a[, ind])

Compute the tensor inverse of an array.

tensorsolve(a, b[, axes])

Solve the tensor equation a x = b for x.

trace(x, /, *[, offset, dtype])

Compute the trace of a matrix.

vecdot(x1, x2, /, *[, axis, precision, ...])

Compute the (batched) vector conjugate dot product of two arrays.

vector_norm(x, /, *[, axis, keepdims, ord])

Compute the vector norm of a vector or batch of vectors.

scico.numpy.linalg.cholesky(a, *, upper=False, symmetrize_input=True)

Compute the Cholesky decomposition of a matrix.

JAX implementation of numpy.linalg.cholesky.

The Cholesky decomposition of a matrix A is:

\[A = U^HU\]

or

\[A = LL^H\]

where U is an upper-triangular matrix and L is a lower-triangular matrix, and \(X^H\) is the Hermitian transpose of X.

Parameters:
  • a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array, representing a (batched) positive-definite Hermitian matrix. Must have shape (..., N, N).

  • upper (bool) – if True, compute the upper Cholesky decomposition U. If False (default), compute the lower Cholesky decomposition L.

  • symmetrize_input (bool) – if True (default) then input is symmetrized, which leads to better behavior under automatic differentiation. Note that when this is set to True, both the upper and lower triangles of the input will be used in computing the decomposition.

Return type:

Array

Returns:

array of shape (..., N, N) representing the Cholesky decomposition of the input. If the input is not Hermitian positive-definite, the result will contain NaN entries.

Examples

A small real Hermitian positive-definite matrix:

>>> x = jnp.array([[2., 1.],
...                [1., 2.]])

Lower Cholesky factorization:

>>> jnp.linalg.cholesky(x)
Array([[1.4142135 , 0.        ],
       [0.70710677, 1.2247449 ]], dtype=float32)

Upper Cholesky factorization:

>>> jnp.linalg.cholesky(x, upper=True)
Array([[1.4142135 , 0.70710677],
       [0.        , 1.2247449 ]], dtype=float32)

Reconstructing x from its factorization:

>>> L = jnp.linalg.cholesky(x)
>>> jnp.allclose(x, L @ L.T)
Array(True, dtype=bool)
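
The reconstruction above uses L.T, which is sufficient for real input. For complex input the Hermitian transpose is needed instead. The following is a minimal sketch, assuming a small complex Hermitian positive-definite matrix chosen here purely for illustration (it does not appear in the original examples):

```python
import jax.numpy as jnp

# Illustrative complex Hermitian positive-definite matrix (eigenvalues 1 and 3).
h = jnp.array([[2.0, 1.0j],
               [-1.0j, 2.0]])

L = jnp.linalg.cholesky(h)

# L^H is the conjugate transpose of the lower factor.
L_H = L.conj().T

# h should be recovered (up to floating point error) from L @ L^H.
reconstructed_ok = bool(jnp.allclose(h, L @ L_H, atol=1e-5))
```

Using L.T instead of L.conj().T here would not reproduce h, because the off-diagonal entries of the factor are complex.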
scico.numpy.linalg.cond(x, p=None)

Compute the condition number of a matrix.

JAX implementation of numpy.linalg.cond.

The condition number is defined as norm(x, p) * norm(inv(x), p). For p = 2 (the default), the condition number is the ratio of the largest to the smallest singular value.

Parameters:
  • x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – array of shape (..., M, N) for which to compute the condition number.

  • p – the order of the norm to use. One of {None, 1, -1, 2, -2, inf, -inf, 'fro'}; see jax.numpy.linalg.norm for the meaning of these. The default is p = None, which is equivalent to p = 2. If not in {None, 2, -2} then x must be square, i.e. M = N.

Returns:

array of shape x.shape[:-2] containing the condition number.

Examples

Well-conditioned matrix:

>>> x = jnp.array([[1, 2],
...                [2, 1]])
>>> jnp.linalg.cond(x)
Array(3., dtype=float32)

Ill-conditioned matrix:

>>> x = jnp.array([[1, 2],
...                [0, 0]])
>>> jnp.linalg.cond(x)
Array(inf, dtype=float32)
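
The statement that the default (p = 2) condition number is the ratio of the largest to the smallest singular value can be checked directly. This is a small sketch using the well-conditioned matrix from the example above, written with floating point entries:

```python
import jax.numpy as jnp

x = jnp.array([[1.0, 2.0],
               [2.0, 1.0]])

# Singular values only (no singular vectors needed).
s = jnp.linalg.svd(x, compute_uv=False)

# 2-norm condition number as the ratio of the extreme singular values.
cond_from_svd = s.max() / s.min()
cond_direct = jnp.linalg.cond(x)
```

Both quantities should agree up to floating point error.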
scico.numpy.linalg.cross(x1, x2, /, *, axis=-1)

Compute the cross-product of two 3D vectors.

JAX implementation of numpy.linalg.cross.

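Parameters:
  • x1 (Union[Array, ndarray, bool, number, bool, int, float, complex]) – N-dimensional array, with x1.shape[axis] == 3.

  • x2 (Union[Array, ndarray, bool, number, bool, int, float, complex]) – N-dimensional array, with x2.shape[axis] == 3 and other axes broadcast-compatible with x1.

  • axis (int) – axis along which to take the cross product (default: -1).

Returns: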

array containing the result of the cross-product

See also

jax.numpy.cross: more flexible cross-product API.

Examples

Showing that \(\hat{x} \times \hat{y} = \hat{z}\):

>>> x = jnp.array([1., 0., 0.])
>>> y = jnp.array([0., 1., 0.])
>>> jnp.linalg.cross(x, y)
Array([0., 0., 1.], dtype=float32)

Cross product of \(\hat{x}\) with all three standard unit vectors, via broadcasting:

>>> xyz = jnp.eye(3)
>>> jnp.linalg.cross(x, xyz, axis=-1)
Array([[ 0.,  0.,  0.],
       [ 0.,  0.,  1.],
       [ 0., -1.,  0.]], dtype=float32)
scico.numpy.linalg.det(a)

Compute the determinant of an array.

JAX implementation of numpy.linalg.det.

Parameters:

a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – array of shape (..., M, M) for which to compute the determinant.

Return type:

Array

Returns:

An array of determinants of shape a.shape[:-2].

See also

jax.scipy.linalg.det: Scipy-style API for determinant.

Examples

>>> a = jnp.array([[1, 2],
...                [3, 4]])
>>> jnp.linalg.det(a)
Array(-2., dtype=float32)
scico.numpy.linalg.diagonal(x, /, *, offset=0)

Extract the diagonal of a matrix or stack of matrices.

JAX implementation of numpy.linalg.diagonal.

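Parameters:
  • x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – array of shape (..., M, N) from which the diagonal will be extracted.

  • offset (int) – positive or negative offset from the main diagonal (default: 0).

Return type: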

Array

Returns:

Array of shape (..., K) where K is the length of the specified diagonal.

Examples

Diagonals of a single matrix:

>>> x = jnp.array([[1,  2,  3,  4],
...                [5,  6,  7,  8],
...                [9, 10, 11, 12]])
>>> jnp.linalg.diagonal(x)
Array([ 1,  6, 11], dtype=int32)
>>> jnp.linalg.diagonal(x, offset=1)
Array([ 2,  7, 12], dtype=int32)
>>> jnp.linalg.diagonal(x, offset=-1)
Array([ 5, 10], dtype=int32)

Batched diagonals:

>>> x = jnp.arange(24).reshape(2, 3, 4)
>>> jnp.linalg.diagonal(x)
Array([[ 0,  5, 10],
       [12, 17, 22]], dtype=int32)
scico.numpy.linalg.eig(a)

Compute the eigenvalues and eigenvectors of a square array.

JAX implementation of numpy.linalg.eig.

Parameters:

a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – array of shape (..., M, M) for which to compute the eigenvalues and vectors.

Return type:

EigResult

Returns:

A namedtuple (eigenvalues, eigenvectors) with fields:

  • eigenvalues: an array of shape (..., M) containing the eigenvalues.

  • eigenvectors: an array of shape (..., M, M), where column v[:, i] is the eigenvector corresponding to the eigenvalue w[i].

Examples

>>> a = jnp.array([[1., 2.],
...                [2., 1.]])
>>> w, v = jnp.linalg.eig(a)
>>> with jax.numpy.printoptions(precision=4):
...   w
Array([ 3.+0.j, -1.+0.j], dtype=complex64)
>>> v
Array([[ 0.70710677+0.j, -0.70710677+0.j],
       [ 0.70710677+0.j,  0.70710677+0.j]], dtype=complex64)
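
The relationship between the returned eigenvalues and eigenvector columns can be verified numerically. The sketch below reuses the matrix from the example and assumes a backend on which the non-symmetric eigendecomposition is available:

```python
import jax.numpy as jnp

a = jnp.array([[1.0, 2.0],
               [2.0, 1.0]])
w, v = jnp.linalg.eig(a)

# Column i of v is an eigenvector for w[i], so a @ v equals v with
# column i scaled by w[i]; broadcasting v * w performs that scaling.
lhs = a @ v
rhs = v * w
eig_ok = bool(jnp.allclose(lhs, rhs, atol=1e-5))
```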
scico.numpy.linalg.eigh(a, UPLO=None, symmetrize_input=True)

Compute the eigenvalues and eigenvectors of a Hermitian matrix.

JAX implementation of numpy.linalg.eigh.

Parameters:
  • a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – array of shape (..., M, M), containing the Hermitian (if complex) or symmetric (if real) matrix.

  • UPLO (str | None) – specifies whether the calculation is done with the lower triangular part of a ('L', default) or the upper triangular part ('U').

  • symmetrize_input (bool) – if True (default) then input is symmetrized, which leads to better behavior under automatic differentiation. Note that when this is set to True, both the upper and lower triangles of the input will be used in computing the decomposition.

Return type:

EighResult

Returns:

A namedtuple (eigenvalues, eigenvectors) where

  • eigenvalues: an array of shape (..., M) containing the eigenvalues, sorted in ascending order.

  • eigenvectors: an array of shape (..., M, M), where column v[:, i] is the normalized eigenvector corresponding to the eigenvalue w[i].

Examples

>>> a = jnp.array([[1, -2j],
...                [2j, 1]])
>>> w, v = jnp.linalg.eigh(a)
>>> w
Array([-1.,  3.], dtype=float32)
>>> with jnp.printoptions(precision=3):
...   v
Array([[-0.707+0.j   , -0.707+0.j   ],
       [ 0.   +0.707j,  0.   -0.707j]], dtype=complex64)
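
Because the eigenvectors of a Hermitian matrix are returned normalized, the input can be rebuilt from the decomposition. The following is a sketch of that check using the matrix from the example; the variable names a_rebuilt and gram are illustrative:

```python
import jax.numpy as jnp

a = jnp.array([[1, -2j],
               [2j, 1]])
w, v = jnp.linalg.eigh(a)

# Rebuild a as V diag(w) V^H; (v * w) scales column i of v by w[i].
a_rebuilt = (v * w) @ v.conj().T

# The eigenvector columns should be orthonormal: V^H V = I.
gram = v.conj().T @ v
```

Up to floating point error, a_rebuilt should match a and gram should match the 2x2 identity.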
scico.numpy.linalg.eigvals(a)

Compute the eigenvalues of a general matrix.

JAX implementation of numpy.linalg.eigvals.

Parameters:

a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – array of shape (..., M, M) for which to compute the eigenvalues.

Return type:

Array

Returns:

An array of shape (..., M) containing the eigenvalues.

Notes

  • This differs from numpy.linalg.eigvals in that the return type of jax.numpy.linalg.eigvals is always complex64 for 32-bit input, and complex128 for 64-bit input.

  • At present, non-symmetric eigendecomposition is only implemented on the CPU backend.

Examples

>>> a = jnp.array([[1., 2.],
...                [2., 1.]])
>>> w = jnp.linalg.eigvals(a)
>>> with jnp.printoptions(precision=2):
...  w
Array([ 3.+0.j, -1.+0.j], dtype=complex64)
scico.numpy.linalg.eigvalsh(a, UPLO='L', *, symmetrize_input=True)

Compute the eigenvalues of a Hermitian matrix.

JAX implementation of numpy.linalg.eigvalsh.

Parameters:
  • a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – array of shape (..., M, M), containing the Hermitian (if complex) or symmetric (if real) matrix.

  • UPLO (str | None) – specifies whether the calculation is done with the lower triangular part of a ('L', default) or the upper triangular part ('U').

  • symmetrize_input (bool) – if True (default) then input is symmetrized, which leads to better behavior under automatic differentiation. Note that when this is set to True, both the upper and lower triangles of the input will be used in computing the decomposition.

Return type:

Array

Returns:

An array of shape (..., M) containing the eigenvalues, sorted in ascending order.

Examples

>>> a = jnp.array([[1, -2j],
...                [2j, 1]])
>>> w = jnp.linalg.eigvalsh(a)
>>> w
Array([-1.,  3.], dtype=float32)
scico.numpy.linalg.inv(a)

Return the inverse of a square matrix.

JAX implementation of numpy.linalg.inv.

Parameters:

a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – array of shape (..., N, N) specifying square array(s) to be inverted.

Return type:

Array

Returns:

Array of shape (..., N, N) containing the inverse of the input.

Notes

In most cases, explicitly computing the inverse of a matrix is ill-advised. For example, to compute x = inv(A) @ b, it is more performant and numerically precise to use a direct solve, such as jax.scipy.linalg.solve.

Examples

Compute the inverse of a 3x3 matrix:

>>> a = jnp.array([[1., 2., 3.],
...                [2., 4., 2.],
...                [3., 2., 1.]])
>>> a_inv = jnp.linalg.inv(a)
>>> a_inv  
Array([[ 0.        , -0.25      ,  0.5       ],
       [-0.25      ,  0.5       , -0.25000003],
       [ 0.5       , -0.25      ,  0.        ]], dtype=float32)

Check that multiplying with the inverse gives the identity:

>>> jnp.allclose(a @ a_inv, jnp.eye(3), atol=1E-5)
Array(True, dtype=bool)

Multiply the inverse by a vector b to find a solution to a @ x = b:

>>> b = jnp.array([1., 4., 2.])
>>> a_inv @ b
Array([ 0.  ,  1.25, -0.5 ], dtype=float32)

Note, however, that explicitly computing the inverse in such a case can lead to poor performance and loss of precision as the size of the problem grows. Instead, you should use a direct solver like jax.numpy.linalg.solve:

>>> jnp.linalg.solve(a, b)
Array([ 0.  ,  1.25, -0.5 ], dtype=float32)
scico.numpy.linalg.lstsq(a, b, rcond=None, *, numpy_resid=False)

Return the least-squares solution to a linear equation.

JAX implementation of numpy.linalg.lstsq.

Parameters:
  • a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – array of shape (M, N) representing the coefficient matrix.

  • b (Union[Array, ndarray, bool, number, bool, int, float, complex]) – array of shape (M,) or (M, K) representing the right-hand side.

  • rcond (float | None) – Cut-off ratio for small singular values. Singular values smaller than rcond * largest_singular_value are treated as zero. If None (default), the optimal value will be used to reduce floating point errors.

  • numpy_resid (bool) – If True, compute and return residuals in the same way as NumPy’s linalg.lstsq. This is necessary if you want to precisely replicate NumPy’s behavior. If False (default), a more efficient method is used to compute residuals.

Return type:

tuple[Array, Array, Array, Array]

Returns:

Tuple of arrays (x, resid, rank, s) where

  • x is a shape (N,) or (N, K) array containing the least-squares solution.

  • resid is the sum of squared residuals, of shape () or (K,).

  • rank is the rank of the matrix a.

  • s contains the singular values of the matrix a.

Examples

>>> a = jnp.array([[1, 2],
...                [3, 4]])
>>> b = jnp.array([5, 6])
>>> x, _, _, _ = jnp.linalg.lstsq(a, b)
>>> with jnp.printoptions(precision=3):
...   print(x)
[-4.   4.5]
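
The example above uses a square system, where the least-squares solution coincides with the exact one. A more typical use is an overdetermined fit. The sketch below fits a straight line y ≈ m * t + c to four samples; the data values and the names t, y, m and c are illustrative assumptions, not part of the original text:

```python
import jax.numpy as jnp

# Illustrative data: four noisy samples of a line.
t = jnp.array([0.0, 1.0, 2.0, 3.0])
y = jnp.array([-1.0, 0.2, 0.9, 2.1])

# Design matrix with columns [t, 1], so that a @ [m, c] approximates y.
a = jnp.stack([t, jnp.ones_like(t)], axis=1)

coef, resid, rank, s = jnp.linalg.lstsq(a, y)
m, c = coef
```

For this data the fitted slope is close to 1.0 and the intercept is close to -0.95, and the reported rank is 2 because the two columns of the design matrix are linearly independent.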
scico.numpy.linalg.matmul(x1, x2, /, *, precision=None, preferred_element_type=None)

Perform a matrix multiplication.

JAX implementation of numpy.linalg.matmul.

Parameters:
  • x1 (Union[Array, ndarray, bool, number, bool, int, float, complex]) – first input array, of shape (..., N).

  • x2 (Union[Array, ndarray, bool, number, bool, int, float, complex]) – second input array. Must have shape (N,) or (..., N, M). In the multi-dimensional case, leading dimensions must be broadcast-compatible with the leading dimensions of x1.

  • precision (Union[None, str, Precision, tuple[str, str], tuple[Precision, Precision], DotAlgorithm, DotAlgorithmPreset]) – either None (default), which means the default precision for the backend, a Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST) or a tuple of two such values indicating precision of x1 and x2.

  • preferred_element_type (Union[str, type[Any], dtype, SupportsDType, None]) – either None (default), which means the default accumulation type for the input types, or a datatype, indicating to accumulate results to and return a result with that datatype.

Return type:

Array

Returns:

array containing the matrix product of the inputs. Shape is x1.shape[:-1] if x2.ndim == 1, otherwise the shape is (..., M).

See also

  • jax.numpy.matmul: NumPy API for this function.

  • jax.numpy.linalg.vecdot: batched vector product.

  • jax.numpy.linalg.tensordot: batched tensor product.

Examples

Vector dot products:

>>> x1 = jnp.array([1, 2, 3])
>>> x2 = jnp.array([4, 5, 6])
>>> jnp.linalg.matmul(x1, x2)
Array(32, dtype=int32)

Matrix dot product:

>>> x1 = jnp.array([[1, 2, 3],
...                 [4, 5, 6]])
>>> x2 = jnp.array([[1, 2],
...                 [3, 4],
...                 [5, 6]])
>>> jnp.linalg.matmul(x1, x2)
Array([[22, 28],
       [49, 64]], dtype=int32)

For convenience, in all cases you can do the same computation using the @ operator:

>>> x1 @ x2
Array([[22, 28],
       [49, 64]], dtype=int32)
scico.numpy.linalg.matrix_norm(x, /, *, keepdims=False, ord='fro')

Compute the norm of a matrix or stack of matrices.

JAX implementation of numpy.linalg.matrix_norm.

Parameters:
  • x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – array of shape (..., M, N) for which to take the norm.

  • keepdims (bool) – if True, keep the reduced dimensions in the output.

  • ord (str | int) – A string or int specifying the type of norm; default is the Frobenius norm. See numpy.linalg.norm for details on available options.

Return type:

Array

Returns:

array containing the norm of x. Has shape x.shape[:-2] if keepdims is False, or shape (..., 1, 1) if keepdims is True.

Examples

>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6],
...                [7, 8, 9]])
>>> jnp.linalg.matrix_norm(x)
Array(16.881943, dtype=float32)
scico.numpy.linalg.matrix_power(a, n)

Raise a square matrix to an integer power.

JAX implementation of numpy.linalg.matrix_power, implemented via repeated squarings.

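Parameters:
  • a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – array of shape (..., M, M) to be raised to the specified power.

  • n (int) – the integer exponent to which the matrix should be raised.

Return type: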

Array

Returns:

Array of shape (..., M, M) containing the matrix power of a to the n.

Examples

>>> a = jnp.array([[1., 2.],
...                [3., 4.]])
>>> jnp.linalg.matrix_power(a, 3)
Array([[ 37.,  54.],
       [ 81., 118.]], dtype=float32)
>>> a @ a @ a  # equivalent evaluated directly
Array([[ 37.,  54.],
       [ 81., 118.]], dtype=float32)

This also supports zero powers:

>>> jnp.linalg.matrix_power(a, 0)
Array([[1., 0.],
       [0., 1.]], dtype=float32)

and also supports negative powers:

>>> with jnp.printoptions(precision=3):
...   jnp.linalg.matrix_power(a, -2)
Array([[ 5.5 , -2.5 ],
       [-3.75,  1.75]], dtype=float32)

Negative powers are equivalent to matmul of the inverse:

>>> inv_a = jnp.linalg.inv(a)
>>> with jnp.printoptions(precision=3):
...   inv_a @ inv_a
Array([[ 5.5 , -2.5 ],
       [-3.75,  1.75]], dtype=float32)
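
The description mentions that the power is computed via repeated squarings. The general technique can be sketched as follows for non-negative exponents. Note that power_by_squaring is a hypothetical helper written for illustration and is not the library's actual implementation:

```python
import jax.numpy as jnp

def power_by_squaring(a, n):
    """Hypothetical helper: a**n for a square matrix a and integer n >= 0."""
    result = jnp.eye(a.shape[-1], dtype=a.dtype)
    base = a
    while n > 0:
        if n & 1:            # current binary digit of n is 1
            result = result @ base
        base = base @ base   # square the base for the next binary digit
        n >>= 1
    return result

a = jnp.array([[1.0, 2.0],
               [3.0, 4.0]])
manual = power_by_squaring(a, 5)
builtin = jnp.linalg.matrix_power(a, 5)
```

The loop performs a number of matrix products proportional to log2(n) rather than n, which is the motivation for the approach.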
scico.numpy.linalg.matrix_rank(M, rtol=None, *, hermitian=False, tol=None)

Compute the rank of a matrix.

JAX implementation of numpy.linalg.matrix_rank.

The rank is calculated via the Singular Value Decomposition (SVD), and determined by the number of singular values greater than the specified tolerance.

Parameters:
  • M (Union[Array, ndarray, bool, number, bool, int, float, complex]) – array of shape (..., N, K) whose rank is to be computed.

  • rtol (Union[Array, ndarray, bool, number, bool, int, float, complex, None]) – optional array of shape (...) specifying the tolerance. Singular values smaller than rtol * largest_singular_value are considered to be zero. If rtol is None (the default), a reasonable default is chosen based on the floating point precision of the input.

  • hermitian (bool) – if True, then the input is assumed to be Hermitian, and a more efficient algorithm is used (default: False)

  • tol (Union[Array, ndarray, bool, number, bool, int, float, complex, None]) – alias of the rtol argument present for backward compatibility. Only one of rtol or tol may be specified.

Return type:

Array

Returns:

array of shape M.shape[:-2] giving the matrix rank.

Notes

The rank calculation may be inaccurate for matrices with very small singular values or those that are numerically ill-conditioned. Consider adjusting the rtol parameter or using a more specialized rank computation method in such cases.

Examples

>>> a = jnp.array([[1, 2],
...                [3, 4]])
>>> jnp.linalg.matrix_rank(a)
Array(2, dtype=int32)
>>> b = jnp.array([[1, 0],  # Rank-deficient matrix
...                [0, 0]])
>>> jnp.linalg.matrix_rank(b)
Array(1, dtype=int32)
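
The SVD-based definition given above can be reproduced by hand when an explicit rtol is supplied. This is a sketch under that assumption; the matrix m and the tolerance value are chosen here for illustration:

```python
import jax.numpy as jnp

m = jnp.array([[1.0, 2.0],
               [2.0, 4.0]])   # second row is twice the first, so the rank is 1

rtol = 1e-5  # illustrative tolerance, chosen for this sketch

s = jnp.linalg.svd(m, compute_uv=False)

# Count the singular values above rtol times the largest singular value.
manual_rank = jnp.sum(s > rtol * s.max())
builtin_rank = jnp.linalg.matrix_rank(m, rtol=rtol)
```

In floating point arithmetic the smaller singular value is typically tiny rather than exactly zero, which is why a relative tolerance is used instead of a comparison with zero.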
scico.numpy.linalg.matrix_transpose(x, /)

Transpose a matrix or stack of matrices.

JAX implementation of numpy.linalg.matrix_transpose.

Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – array of shape (..., M, N)

Return type:

Array

Returns:

array of shape (..., N, M) containing the matrix transpose of x.

See also

jax.numpy.transpose: more general transpose operation.

Examples

Transpose of a single matrix:

>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> jnp.linalg.matrix_transpose(x)
Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)

Transpose of a stack of matrices:

>>> x = jnp.array([[[1, 2],
...                 [3, 4]],
...                [[5, 6],
...                 [7, 8]]])
>>> jnp.linalg.matrix_transpose(x)
Array([[[1, 3],
        [2, 4]],

       [[5, 7],
        [6, 8]]], dtype=int32)

For convenience, the same computation can be done via the mT property of JAX array objects:

>>> x.mT
Array([[[1, 3],
        [2, 4]],

       [[5, 7],
        [6, 8]]], dtype=int32)
scico.numpy.linalg.multi_dot(arrays, *, precision=None)

Efficiently compute matrix products between a sequence of arrays.

JAX implementation of numpy.linalg.multi_dot.

JAX internally uses the opt_einsum library to compute the most efficient operation order.

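Parameters:
  • arrays (Sequence[Union[Array, ndarray, bool, number, bool, int, float, complex]]) – sequence of arrays. All must be two-dimensional, except the first and last, which may be one-dimensional.

  • precision (Union[None, str, Precision, tuple[str, str], tuple[Precision, Precision], DotAlgorithm, DotAlgorithmPreset]) – either None (default), which means the default precision for the backend, or a Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST).

Return type: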

Array

Returns:

an array representing the equivalent of reduce(jnp.matmul, arrays), but evaluated in the optimal order.

This function exists because the cost of computing sequences of matmul operations can differ vastly depending on the order in which the operations are evaluated. For a single matmul, the number of floating point operations (flops) required to compute a matrix product can be approximated this way:

>>> def approx_flops(x, y):
...   # for 2D x and y, with x.shape[1] == y.shape[0]
...   return 2 * x.shape[0] * x.shape[1] * y.shape[1]

Suppose we have three matrices that we’d like to multiply in sequence:

>>> key1, key2, key3 = jax.random.split(jax.random.key(0), 3)
>>> x = jax.random.normal(key1, shape=(200, 5))
>>> y = jax.random.normal(key2, shape=(5, 100))
>>> z = jax.random.normal(key3, shape=(100, 10))

Because of associativity of matrix products, there are two orders in which we might evaluate the product x @ y @ z, and both produce equivalent outputs up to floating point precision:

>>> result1 = (x @ y) @ z
>>> result2 = x @ (y @ z)
>>> jnp.allclose(result1, result2, atol=1E-4)
Array(True, dtype=bool)

But the computational costs of the two orderings differ greatly:

>>> print("(x @ y) @ z flops:", approx_flops(x, y) + approx_flops(x @ y, z))
(x @ y) @ z flops: 600000
>>> print("x @ (y @ z) flops:", approx_flops(y, z) + approx_flops(x, y @ z))
x @ (y @ z) flops: 30000

The second approach is about 20x more efficient in terms of estimated flops!

multi_dot is a function that will automatically choose the fastest computational path for such problems:

>>> result3 = jnp.linalg.multi_dot([x, y, z])
>>> jnp.allclose(result1, result3, atol=1E-4)
Array(True, dtype=bool)

We can use JAX’s Ahead-of-time lowering and compilation tools to estimate the total flops of each approach, and confirm that multi_dot is choosing the more efficient option:

>>> jax.jit(lambda x, y, z: (x @ y) @ z).lower(x, y, z).cost_analysis()['flops']
600000.0
>>> jax.jit(lambda x, y, z: x @ (y @ z)).lower(x, y, z).cost_analysis()['flops']
30000.0
>>> jax.jit(jnp.linalg.multi_dot).lower([x, y, z]).cost_analysis()['flops']
30000.0
scico.numpy.linalg.norm(x, ord=None, axis=None, keepdims=False)

Compute the norm of a matrix or vector.

JAX implementation of numpy.linalg.norm.

Parameters:
  • x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – N-dimensional array for which the norm will be computed.

  • ord (int | str | None) – specify the kind of norm to take. Default is Frobenius norm for matrices, and the 2-norm for vectors. For other options, see Notes below.

  • axis (None | tuple[int, ...] | int) – integer or sequence of integers specifying the axes over which the norm will be computed. For a single axis, compute a vector norm. For two axes, compute a matrix norm. Defaults to all axes of x.

  • keepdims (bool) – if True, the output array will have the same number of dimensions as the input, with the size of reduced axes replaced by 1 (default: False).

Return type:

Array

Returns:

array containing the specified norm of x.

Notes

The flavor of norm computed depends on the value of ord and the number of axes being reduced.

For vector norms (i.e. a single axis reduction):

  • ord=None (default) computes the 2-norm

  • ord=inf computes max(abs(x))

  • ord=-inf computes min(abs(x))

  • ord=0 computes sum(x!=0)

  • for other numerical values, computes sum(abs(x) ** ord)**(1/ord)

For matrix norms (i.e. two axes reductions):

  • ord='fro' or ord=None (default) computes the Frobenius norm

  • ord='nuc' computes the nuclear norm, or the sum of the singular values

  • ord=1 computes max(abs(x).sum(0))

  • ord=-1 computes min(abs(x).sum(0))

  • ord=2 computes the 2-norm, i.e. the largest singular value

  • ord=-2 computes the smallest singular value

In the special case of ord=None and axis=None, this function accepts an array of any dimension and computes the vector 2-norm of the flattened array.

Examples

Vector norms:

>>> x = jnp.array([3., 4., 12.])
>>> jnp.linalg.norm(x)
Array(13., dtype=float32)
>>> jnp.linalg.norm(x, ord=1)
Array(19., dtype=float32)
>>> jnp.linalg.norm(x, ord=0)
Array(3., dtype=float32)

Matrix norms:

>>> x = jnp.array([[1., 2., 3.],
...                [4., 5., 7.]])
>>> jnp.linalg.norm(x)  # Frobenius norm
Array(10.198039, dtype=float32)
>>> jnp.linalg.norm(x, ord='nuc')  # nuclear norm
Array(10.762535, dtype=float32)
>>> jnp.linalg.norm(x, ord=1)  # 1-norm
Array(10., dtype=float32)

Batched vector norm:

>>> jnp.linalg.norm(x, axis=1)
Array([3.7416575, 9.486833 ], dtype=float32)
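
The formulas listed in the Notes can be checked against the function directly. The sketch below does this for a general vector p-norm (with p = 3 chosen for illustration) and for the matrix 1-norm, reusing the arrays from the examples:

```python
import jax.numpy as jnp

x = jnp.array([3.0, 4.0, 12.0])
p = 3

# General vector p-norm: sum(abs(x) ** ord) ** (1 / ord).
manual = jnp.sum(jnp.abs(x) ** p) ** (1.0 / p)
builtin = jnp.linalg.norm(x, ord=p)

# Matrix 1-norm: the maximum absolute column sum.
mat = jnp.array([[1.0, 2.0, 3.0],
                 [4.0, 5.0, 7.0]])
manual_1norm = jnp.max(jnp.abs(mat).sum(axis=0))
builtin_1norm = jnp.linalg.norm(mat, ord=1)
```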
scico.numpy.linalg.outer(x1, x2, /)

Compute the outer product of two 1-dimensional arrays.

JAX implementation of numpy.linalg.outer.

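Parameters:
  • x1 (Union[Array, ndarray, bool, number, bool, int, float, complex]) – first one-dimensional input array.

  • x2 (Union[Array, ndarray, bool, number, bool, int, float, complex]) – second one-dimensional input array.

Return type: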

Array

Returns:

array containing the outer product of x1 and x2

See also

jax.numpy.outer: similar function in the main jax.numpy module.

Examples

>>> x1 = jnp.array([1, 2, 3])
>>> x2 = jnp.array([4, 5, 6])
>>> jnp.linalg.outer(x1, x2)
Array([[ 4,  5,  6],
       [ 8, 10, 12],
       [12, 15, 18]], dtype=int32)
scico.numpy.linalg.pinv(a, rtol=None, hermitian=False, *, rcond=None)

Compute the (Moore-Penrose) pseudo-inverse of a matrix.

JAX implementation of numpy.linalg.pinv.

Parameters:
  • a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – array of shape (..., M, N) containing matrices to pseudo-invert.

  • rtol (Union[Array, ndarray, bool, number, bool, int, float, complex, None]) – float or array_like of shape a.shape[:-2]. Specifies the cutoff for small singular values: singular values smaller than rtol * largest_singular_value are treated as zero. The default is determined based on the floating point precision of the dtype.

  • hermitian (bool) – if True, then the input is assumed to be Hermitian, and a more efficient algorithm is used (default: False)

  • rcond (Union[Array, ndarray, bool, number, bool, int, float, complex, None]) – alias of the rtol argument, present for backward compatibility. Only one of rtol and rcond may be specified.

Return type:

Array

Returns:

An array of shape (..., N, M) containing the pseudo-inverse of a.

Notes

jax.numpy.linalg.pinv differs from numpy.linalg.pinv in the default value of rcond: in NumPy, the default is 1e-15; in JAX, the default is 10. * max(num_rows, num_cols) * jnp.finfo(dtype).eps.

Examples

>>> a = jnp.array([[1, 2],
...                [3, 4],
...                [5, 6]])
>>> a_pinv = jnp.linalg.pinv(a)
>>> a_pinv  
Array([[-1.333332  , -0.33333257,  0.6666657 ],
       [ 1.0833322 ,  0.33333272, -0.41666582]], dtype=float32)

The pseudo-inverse operates as a multiplicative inverse so long as the output is not rank-deficient:

>>> jnp.allclose(a_pinv @ a, jnp.eye(2), atol=1E-4)
Array(True, dtype=bool)
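
A common use of the pseudo-inverse is to obtain a least-squares solution of an overdetermined system. The following sketch compares that approach with jnp.linalg.lstsq for the matrix from the example; the right-hand side b is an illustrative assumption:

```python
import jax.numpy as jnp

a = jnp.array([[1.0, 2.0],
               [3.0, 4.0],
               [5.0, 6.0]])
b = jnp.array([1.0, 2.0, 2.0])   # illustrative right-hand side

# Least-squares solution via the pseudo-inverse ...
x_pinv = jnp.linalg.pinv(a) @ b

# ... and via the dedicated least-squares routine.
x_lstsq, _, _, _ = jnp.linalg.lstsq(a, b)
```

The two results should agree up to floating point error. When only the solution is needed, calling lstsq avoids forming the pseudo-inverse explicitly.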
scico.numpy.linalg.qr(a, mode='reduced')

Compute the QR decomposition of an array.

JAX implementation of numpy.linalg.qr.

The QR decomposition of a matrix A is given by

\[A = QR\]

where Q is a unitary matrix (i.e. \(Q^HQ=I\)) and R is an upper-triangular matrix.

Parameters:
  • a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – array of shape (..., M, N)

  • mode (str) –

    Computational mode. Supported values are:

    • "reduced" (default): return Q of shape (..., M, K) and R of shape (..., K, N), where K = min(M, N).

    • "complete": return Q of shape (..., M, M) and R of shape (..., M, N).

    • "raw": return lapack-internal representations of shape (..., M, N) and (..., K).

    • "r": return R only.

Return type:

Array | QRResult

Returns:

A tuple (Q, R) (if mode is not "r") otherwise an array R, where:

  • Q is an orthogonal matrix of shape (..., M, K) (if mode is "reduced") or (..., M, M) (if mode is "complete").

  • R is an upper-triangular matrix of shape (..., M, N) (if mode is "r" or "complete") or (..., K, N) (if mode is "reduced")

with K = min(M, N).

Examples

Compute the QR decomposition of a matrix:

>>> a = jnp.array([[1., 2., 3., 4.],
...                [5., 4., 2., 1.],
...                [6., 3., 1., 5.]])
>>> Q, R = jnp.linalg.qr(a)
>>> Q  
Array([[-0.12700021, -0.7581426 , -0.6396022 ],
       [-0.63500065, -0.43322435,  0.63960224],
       [-0.7620008 ,  0.48737738, -0.42640156]], dtype=float32)
>>> R  
Array([[-7.8740077, -5.080005 , -2.4130025, -4.953006 ],
       [ 0.       , -1.7870499, -2.6534991, -1.028908 ],
       [ 0.       ,  0.       , -1.0660033, -4.050814 ]], dtype=float32)

Check that Q is orthonormal:

>>> jnp.allclose(Q.T @ Q, jnp.eye(3), atol=1E-5)
Array(True, dtype=bool)

Reconstruct the input:

>>> jnp.allclose(Q @ R, a)
Array(True, dtype=bool)
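
The difference between the "reduced" and "complete" modes is easiest to see with a tall matrix, where M > N. This is a small sketch using an illustrative (4, 2) input that is not part of the original examples:

```python
import jax.numpy as jnp

# Illustrative tall matrix of shape (M, N) = (4, 2).
a = jnp.arange(8.0).reshape(4, 2) + jnp.eye(4, 2)

# Reduced mode: K = min(M, N) = 2 columns in Q.
q_red, r_red = jnp.linalg.qr(a, mode="reduced")

# Complete mode: square Q of shape (M, M).
q_full, r_full = jnp.linalg.qr(a, mode="complete")

shapes = {
    "reduced": (q_red.shape, r_red.shape),
    "complete": (q_full.shape, r_full.shape),
}
```

Both factorizations reconstruct a; the complete form carries additional columns of Q, which are multiplied by the zero rows at the bottom of R.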
scico.numpy.linalg.slogdet(a, *, method=None)

Compute the sign and (natural) logarithm of the determinant of an array.

JAX implementation of numpy.linalg.slogdet.

Parameters:
  • a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – array of shape (..., M, M) for which to compute the sign and log determinant.

  • method (str | None) –

    the method to use for determinant computation. Options are

    • 'lu' (default): use the LU decomposition.

    • 'qr': use the QR decomposition.

Return type:

SlogdetResult

Returns:

A tuple of arrays (sign, logabsdet), each of shape a.shape[:-2]

  • sign is the sign of the determinant.

  • logabsdet is the natural log of the determinant’s absolute value.

See also

jax.numpy.linalg.det: direct computation of determinant

Examples

>>> a = jnp.array([[1, 2],
...                [3, 4]])
>>> sign, logabsdet = jnp.linalg.slogdet(a)
>>> sign  # -1 indicates negative determinant
Array(-1., dtype=float32)
>>> jnp.exp(logabsdet)  # Absolute value of determinant
Array(2., dtype=float32)
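
One reason to work with the sign and log-determinant rather than the determinant itself is numerical range. The sketch below uses an illustrative matrix whose determinant, 10**200, cannot be represented in float32, while its logarithm is a modest number:

```python
import jax.numpy as jnp

# Illustrative matrix: 10 * I of size 200, whose determinant is 10**200.
big = 10.0 * jnp.eye(200, dtype=jnp.float32)

# The determinant itself is far outside the float32 range.
direct = jnp.linalg.det(big)

# The log-determinant is well within range: 200 * log(10).
sign, logabsdet = jnp.linalg.slogdet(big)
expected = 200.0 * jnp.log(10.0)
```

Here direct is not a finite number, whereas sign and logabsdet together still describe the determinant.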
scico.numpy.linalg.solve(a, b)

Solve a linear system of equations.

JAX implementation of numpy.linalg.solve.

This solves a (batched) linear system of equations a @ x = b for x given a and b.

If a is singular, this will return nan or inf values.

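Parameters:
  • a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – array of shape (..., N, N).

  • b (Union[Array, ndarray, bool, number, bool, int, float, complex]) – array of shape (N,) (for a 1-dimensional right-hand side) or (..., N, M) (for a batched 2-dimensional right-hand side).

Return type: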

Array

Returns:

An array containing the result of the linear solve if a is non-singular. The result has shape (..., N) if b is of shape (N,), and has shape (..., N, M) otherwise. If a is singular, the result contains nan or inf values.

Examples

A simple 3x3 linear system:

>>> A = jnp.array([[1., 2., 3.],
...                [2., 4., 2.],
...                [3., 2., 1.]])
>>> b = jnp.array([14., 16., 10.])
>>> x = jnp.linalg.solve(A, b)
>>> x
Array([1., 2., 3.], dtype=float32)

Confirming that the result solves the system:

>>> jnp.allclose(A @ x, b)
Array(True, dtype=bool)
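
As the Returns description notes, b may also be two-dimensional, in which case each column is treated as a separate right-hand side. The following sketch solves the same system for two right-hand sides at once; the second column of B is an illustrative addition:

```python
import jax.numpy as jnp

A = jnp.array([[1.0, 2.0, 3.0],
               [2.0, 4.0, 2.0],
               [3.0, 2.0, 1.0]])

# Two right-hand sides stacked as the columns of a (3, 2) array.
B = jnp.array([[14.0, 6.0],
               [16.0, 8.0],
               [10.0, 6.0]])

# X has shape (3, 2): one solution vector per column of B.
X = jnp.linalg.solve(A, B)
residual_ok = bool(jnp.allclose(A @ X, B, atol=1e-4))
```

The first column of X matches the solution from the example above, and the second column is the solution for the added right-hand side.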
scico.numpy.linalg.svd(a, full_matrices=True, compute_uv=True, hermitian=False, subset_by_index=None)

Compute the singular value decomposition.

JAX implementation of numpy.linalg.svd, implemented in terms of jax.lax.linalg.svd.

The SVD of a matrix A is given by

\[A = U\Sigma V^H\]
  • \(U\) contains the left singular vectors and satisfies \(U^HU=I\)

  • \(V\) contains the right singular vectors and satisfies \(V^HV=I\)

  • \(\Sigma\) is a diagonal matrix of singular values.

Parameters:
  • a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – input array, of shape (..., N, M)

  • full_matrices (bool) – if True (default), compute the full matrices, i.e. u and vh have shapes (..., N, N) and (..., M, M), respectively. If False, the shapes are (..., N, K) and (..., K, M), with K = min(N, M).

  • compute_uv (bool) – if True (default), return the full SVD (u, s, vh). If False then return only the singular values s.

  • hermitian (bool) – if True, assume the matrix is Hermitian, which allows for a more efficient implementation (default=False)

  • subset_by_index (tuple[int, int] | None) – (TPU-only) Optional 2-tuple [start, end] indicating the range of indices of singular values to compute. For example, if [n-2, n] then svd computes the two largest singular values and their singular vectors. Only compatible with full_matrices=False.

Return type:

Array | SVDResult

Returns:

A tuple of arrays (u, s, vh) if compute_uv is True, otherwise the array s.

  • u: left singular vectors of shape (..., N, N) if full_matrices is True or (..., N, K) otherwise.

  • s: singular values of shape (..., K)

  • vh: conjugate-transposed right singular vectors of shape (..., M, M) if full_matrices is True or (..., K, M) otherwise.

where K = min(N, M).

Examples

Consider the SVD of a small real-valued array:

>>> x = jnp.array([[1., 2., 3.],
...                [6., 5., 4.]])
>>> u, s, vt = jnp.linalg.svd(x, full_matrices=False)
>>> s
Array([9.361919 , 1.8315067], dtype=float32)

The singular vectors are in the columns of u and v = vt.T. These vectors are orthonormal, which can be demonstrated by comparing the matrix product with the identity matrix:

>>> jnp.allclose(u.T @ u, jnp.eye(2), atol=1E-5)
Array(True, dtype=bool)
>>> v = vt.T
>>> jnp.allclose(v.T @ v, jnp.eye(2), atol=1E-5)
Array(True, dtype=bool)

Given the SVD, x can be reconstructed via matrix multiplication:

>>> x_reconstructed = u @ jnp.diag(s) @ vt
>>> jnp.allclose(x_reconstructed, x)
Array(True, dtype=bool)
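
The effect of full_matrices and compute_uv on the returned values can be sketched as below, together with a rank-1 approximation built from the leading singular triplet. This assumes jnp is jax.numpy and reuses the small array from the example above; the variable names are illustrative only.

```python
import jax.numpy as jnp

x = jnp.array([[1., 2., 3.],
               [6., 5., 4.]])          # shape (N, M) = (2, 3)

# Full SVD (full_matrices=True is the default): u is (N, N), vh is (M, M).
u, s, vh = jnp.linalg.svd(x)
print(u.shape, s.shape, vh.shape)      # (2, 2) (2,) (3, 3)

# Singular values only.
s_only = jnp.linalg.svd(x, compute_uv=False)

# Rank-1 approximation from the largest singular value and its vectors.
x_rank1 = s[0] * jnp.outer(u[:, 0], vh[0, :])

# The Frobenius-norm error equals the discarded singular value s[1] here,
# since x has only two singular values.
err = jnp.linalg.norm(x - x_rank1)
print(err, s[1])
```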
scico.numpy.linalg.svdvals(x, /)

Compute the singular values of a matrix.

JAX implementation of numpy.linalg.svdvals.

Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – array of shape (..., M, N) for which singular values will be computed.

Return type:

Array

Returns:

array of singular values of shape (..., K) with K = min(M, N).

See also

jax.numpy.linalg.svd: compute singular values and singular vectors

Examples

>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> jnp.linalg.svdvals(x)
Array([9.508031 , 0.7728694], dtype=float32)
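
As a sanity check on what the singular values represent, the sketch below compares them with the square roots of the eigenvalues of x @ x.T. It assumes jnp is jax.numpy; the comparison via eigvalsh is an illustration, not a description of how svdvals is implemented.

```python
import jax.numpy as jnp

x = jnp.array([[1., 2., 3.],
               [4., 5., 6.]])

s = jnp.linalg.svdvals(x)                # sorted in descending order

# Eigenvalues of the symmetric matrix x @ x.T, returned in ascending order.
eig = jnp.linalg.eigvalsh(x @ x.T)
s_from_eig = jnp.sqrt(eig[::-1])         # reverse to match descending order

print(s, s_from_eig)
```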
scico.numpy.linalg.tensordot(x1, x2, /, *, axes=2, precision=None, preferred_element_type=None, out_sharding=None)

Compute the tensor dot product of two N-dimensional arrays.

JAX implementation of numpy.linalg.tensordot.

Parameters:
Return type:

Array

Returns:

array containing the tensor dot product of the inputs

Examples

>>> x1 = jnp.arange(24.).reshape(2, 3, 4)
>>> x2 = jnp.ones((3, 4, 5))
>>> jnp.linalg.tensordot(x1, x2)
Array([[ 66.,  66.,  66.,  66.,  66.],
       [210., 210., 210., 210., 210.]], dtype=float32)

Equivalent result when specifying the axes as explicit sequences:

>>> jnp.linalg.tensordot(x1, x2, axes=([1, 2], [0, 1]))
Array([[ 66.,  66.,  66.,  66.,  66.],
       [210., 210., 210., 210., 210.]], dtype=float32)

Equivalent result via einsum:

>>> jnp.einsum('ijk,jkm->im', x1, x2)
Array([[ 66.,  66.,  66.,  66.,  66.],
       [210., 210., 210., 210., 210.]], dtype=float32)

Setting axes=1 for two-dimensional inputs is equivalent to a matrix multiplication:

>>> x1 = jnp.array([[1, 2],
...                 [3, 4]])
>>> x2 = jnp.array([[1, 2, 3],
...                 [4, 5, 6]])
>>> jnp.linalg.tensordot(x1, x2, axes=1)
Array([[ 9, 12, 15],
       [19, 26, 33]], dtype=int32)
>>> x1 @ x2
Array([[ 9, 12, 15],
       [19, 26, 33]], dtype=int32)

Setting axes=0 for one-dimensional inputs is equivalent to jax.numpy.linalg.outer:

>>> x1 = jnp.array([1, 2])
>>> x2 = jnp.array([1, 2, 3])
>>> jnp.linalg.tensordot(x1, x2, axes=0)
Array([[1, 2, 3],
       [2, 4, 6]], dtype=int32)
>>> jnp.linalg.outer(x1, x2)
Array([[1, 2, 3],
       [2, 4, 6]], dtype=int32)
scico.numpy.linalg.tensorinv(a, ind=2)

Compute the tensor inverse of an array.

JAX implementation of numpy.linalg.tensorinv.

This computes the inverse of the tensordot operation with the same ind value.

Parameters:
  • a (Union[Array, ndarray, bool, number, bool, int, float, complex]) – array to be inverted. Must have prod(a.shape[:ind]) == prod(a.shape[ind:])

  • ind (int) – positive integer specifying the number of indices in the tensor product.

Return type:

Array

Returns:

array of shape (*a.shape[ind:], *a.shape[:ind]) containing the tensor inverse of a.

Examples

>>> key = jax.random.key(1337)
>>> x = jax.random.normal(key, shape=(2, 2, 4))
>>> xinv = jnp.linalg.tensorinv(x, 2)
>>> xinv_x = jnp.linalg.tensordot(xinv, x, axes=2)
>>> jnp.allclose(xinv_x, jnp.eye(4), atol=1E-4)
Array(True, dtype=bool)
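
The shape rule stated above can be illustrated with a minimal sketch that compares tensorinv against an explicit reshape, invert, reshape sequence. It assumes jax and jnp (jax.numpy) are imported as in the example above; the (2, 3, 6) shape and the random key are arbitrary choices.

```python
import jax
import jax.numpy as jnp

# prod(a.shape[:2]) == prod(a.shape[2:]) == 6, as required for ind=2.
a = jax.random.normal(jax.random.key(0), shape=(2, 3, 6))

ainv = jnp.linalg.tensorinv(a, ind=2)
print(ainv.shape)                        # (6, 2, 3)

# Reference: flatten to a 6x6 matrix, invert it, and restore the tensor shape.
ainv_ref = jnp.linalg.inv(a.reshape(6, 6)).reshape(6, 2, 3)
print(jnp.allclose(ainv, ainv_ref, rtol=1e-4, atol=1e-4))
```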
scico.numpy.linalg.tensorsolve(a, b, axes=None)

Solve the tensor equation a x = b for x.

JAX implementation of numpy.linalg.tensorsolve.

Parameters:
Return type:

Array

Returns:

array x such that after reordering of axes of a, tensordot(a, x, x.ndim) is equivalent to b.

Examples

>>> key1, key2 = jax.random.split(jax.random.key(8675309))
>>> a = jax.random.normal(key1, shape=(2, 2, 4))
>>> b = jax.random.normal(key2, shape=(2, 2))
>>> x = jnp.linalg.tensorsolve(a, b)
>>> x.shape
(4,)

Now show that x can be used to reconstruct b using tensordot:

>>> b_reconstructed = jnp.linalg.tensordot(a, x, axes=x.ndim)
>>> jnp.allclose(b, b_reconstructed)
Array(True, dtype=bool)
scico.numpy.linalg.trace(x, /, *, offset=0, dtype=None)

Compute the trace of a matrix.

JAX implementation of numpy.linalg.trace.

Parameters:
  • x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – array of shape (..., M, N) and whose innermost two dimensions form MxN matrices for which to take the trace.

  • offset (int) – positive or negative offset from the main diagonal (default: 0).

  • dtype (Union[str, type[Any], dtype, SupportsDType, None]) – data type of the returned array (default: None). If None, then output dtype will match the dtype of x, promoted to default precision in the case of integer types.

Return type:

Array

Returns:

array of batched traces with shape x.shape[:-2]

Examples

Trace of a single matrix:

>>> x = jnp.array([[1,  2,  3,  4],
...                [5,  6,  7,  8],
...                [9, 10, 11, 12]])
>>> jnp.linalg.trace(x)
Array(18, dtype=int32)
>>> jnp.linalg.trace(x, offset=1)
Array(21, dtype=int32)
>>> jnp.linalg.trace(x, offset=-1, dtype="float32")
Array(15., dtype=float32)

Batched traces:

>>> x = jnp.arange(24).reshape(2, 3, 4)
>>> jnp.linalg.trace(x)
Array([15, 51], dtype=int32)
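
The offset parameter selects which diagonal is summed. A minimal sketch of this, assuming jnp is jax.numpy and using jnp.linalg.diagonal purely as a cross-check:

```python
import jax.numpy as jnp

x = jnp.array([[1,  2,  3,  4],
               [5,  6,  7,  8],
               [9, 10, 11, 12]])

# offset=1 sums the first superdiagonal: 2 + 7 + 12.
t = jnp.linalg.trace(x, offset=1)
d = jnp.linalg.diagonal(x, offset=1).sum()
print(t, d)                              # both 21

# Batched input: one trace per leading index.
xb = jnp.arange(24).reshape(2, 3, 4)
tb = jnp.linalg.trace(xb)
db = jnp.linalg.diagonal(xb).sum(axis=-1)
print(tb, db)                            # both [15 51]
```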
scico.numpy.linalg.vecdot(x1, x2, /, *, axis=-1, precision=None, preferred_element_type=None)

Compute the (batched) vector conjugate dot product of two arrays.

JAX implementation of numpy.linalg.vecdot.

Parameters:
Return type:

Array

Returns:

array containing the conjugate dot product of x1 and x2 along axis. The non-contracted dimensions are broadcast together.

Examples

Vector dot product of two 1D arrays:

>>> x1 = jnp.array([1, 2, 3])
>>> x2 = jnp.array([4, 5, 6])
>>> jnp.linalg.vecdot(x1, x2)
Array(32, dtype=int32)

Batched vector dot product of two 2D arrays:

>>> x1 = jnp.array([[1, 2, 3],
...                 [4, 5, 6]])
>>> x2 = jnp.array([[2, 3, 4]])
>>> jnp.linalg.vecdot(x1, x2, axis=-1)
Array([20, 47], dtype=int32)
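
For complex inputs, the conjugate in "conjugate dot product" matters. The sketch below assumes jnp is jax.numpy and that the conjugate is applied to the first argument, following the numpy.linalg.vecdot convention; the manual sum is included only as a cross-check.

```python
import jax.numpy as jnp

x1 = jnp.array([1j, 2j, 3j])
x2 = jnp.array([4., 5., 6.])

# conj(x1) . x2 = (-1j)*4 + (-2j)*5 + (-3j)*6 = -32j
r = jnp.linalg.vecdot(x1, x2)
manual = jnp.sum(jnp.conj(x1) * x2)
print(r, manual)
```

Swapping the arguments gives the complex conjugate of this result, so the order of x1 and x2 is significant for complex data.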
scico.numpy.linalg.vector_norm(x, /, *, axis=None, keepdims=False, ord=2)

Compute the vector norm of a vector or batch of vectors.

JAX implementation of numpy.linalg.vector_norm.

Parameters:
  • x (Union[Array, ndarray, bool, number, bool, int, float, complex]) – N-dimensional array for which to take the norm.

  • axis (int | tuple[int, ...] | None) – optional axis along which to compute the vector norm. If None (default) then x is flattened and the norm is taken over all values.

  • keepdims (bool) – if True, keep the reduced dimensions in the output.

  • ord (int | str) – A string or int specifying the type of norm; default is the 2-norm. See numpy.linalg.norm for details on available options.

Return type:

Array

Returns:

array containing the norm of x.

Examples

Norm of a single vector:

>>> x = jnp.array([1., 2., 3.])
>>> jnp.linalg.vector_norm(x)
Array(3.7416575, dtype=float32)

Norm of a batch of vectors:

>>> x = jnp.array([[1., 2., 3.],
...                [4., 5., 7.]])
>>> jnp.linalg.vector_norm(x, axis=1)
Array([3.7416575, 9.486833 ], dtype=float32)
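
The axis, ord, and keepdims parameters can be combined as in the following sketch. It assumes jnp is jax.numpy; the input values are chosen so the norms are easy to verify by hand.

```python
import jax.numpy as jnp

x = jnp.array([[3., -4.],
               [1.,  2.]])

# axis=None (default): x is flattened, giving sqrt(9 + 16 + 1 + 4) = sqrt(30).
n_all = jnp.linalg.vector_norm(x)

# 1-norm and infinity-norm of each row.
n1 = jnp.linalg.vector_norm(x, axis=1, ord=1)          # [7., 3.]
ninf = jnp.linalg.vector_norm(x, axis=1, ord=jnp.inf)  # [4., 2.]

# keepdims=True keeps a length-1 axis, so the result broadcasts against x.
n2 = jnp.linalg.vector_norm(x, axis=1, keepdims=True)  # shape (2, 1)
x_unit = x / n2                                        # rows scaled to unit 2-norm

print(n_all, n1, ninf, n2.shape)
```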