API Reference
Scientific Computational Imaging COde (SCICO) is a Python package for solving the inverse problems that arise in scientific imaging applications.
Data files for usage examples.
Interfaces to standard denoisers.
Diagnostic information for iterative solvers.
Utility functions used by example scripts.
Neural network models implemented in Flax and utility functions.
Function class.
Functionals and functional classes.
Linear operator functions and classes.
Loss function classes.
Image quality metrics and related functions.
Operator functions and classes.
Optimization algorithms.
Plotting/visualization functions.
Random number generation.
Simplified interfaces to Ray.
Wrapped versions of jax.scipy functions.
Solver and optimization algorithms.
Call tracing of scico functions and methods.
Type definitions.
General utility functions.
- class scico.custom_jvp(fun, nondiff_argnums=(), nondiff_argnames=())
Bases: Generic[ReturnValue]
Set up a JAX-transformable function for a custom JVP rule definition.
This class is meant to be used as a function decorator. Instances are callables that behave similarly to the underlying function to which the decorator was applied, except when a differentiation transformation (like jax.jvp or jax.grad) is applied, in which case a custom user-supplied JVP rule function is used instead of tracing into and performing automatic differentiation of the underlying function’s implementation.
There are two instance methods available for defining the custom JVP rule: defjvp for defining a single custom JVP rule for all the function’s inputs, and for convenience defjvps, which wraps defjvp, and allows you to provide separate definitions for the partial derivatives of the function w.r.t. each of its arguments.
For example:
    @jax.custom_jvp
    def f(x, y):
        return jnp.sin(x) * y

    @f.defjvp
    def f_jvp(primals, tangents):
        x, y = primals
        x_dot, y_dot = tangents
        primal_out = f(x, y)
        tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot
        return primal_out, tangent_out
For a more detailed introduction, see the tutorial.
- defjvp(jvp, symbolic_zeros=False)
Define a custom JVP rule for the function represented by this instance.
- Parameters:
jvp (Callable[..., tuple[TypeVar(ReturnValue), TypeVar(ReturnValue)]]) – a Python callable representing the custom JVP rule. When there are no nondiff_argnums, the jvp function should accept two arguments, where the first is a tuple of primal inputs and the second is a tuple of tangent inputs. The lengths of both tuples are equal to the number of parameters of the custom_jvp function. The jvp function should produce as output a pair where the first element is the primal output and the second element is the tangent output. Elements of the input and output tuples may be arrays or any nested tuples/lists/dicts thereof.
symbolic_zeros (bool) – boolean, indicating whether the rule should be passed objects representing static symbolic zeros in its tangent argument in correspondence with unperturbed values; otherwise, only standard JAX types (e.g. array-likes) are passed. Setting this option to True allows a JVP rule to detect whether certain inputs are not involved in differentiation, but at the cost of needing special handling for these objects (which e.g. can’t be passed into jax.numpy functions). Default False.
- Return type:
Callable[..., tuple[TypeVar(ReturnValue), TypeVar(ReturnValue)]]
- Returns:
Returns jvp so that defjvp can be used as a decorator.
Examples
>>> @jax.custom_jvp
... def f(x, y):
...   return jnp.sin(x) * y
...
>>> @f.defjvp
... def f_jvp(primals, tangents):
...   x, y = primals
...   x_dot, y_dot = tangents
...   primal_out = f(x, y)
...   tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot
...   return primal_out, tangent_out

>>> x = jnp.float32(1.0)
>>> y = jnp.float32(2.0)
>>> with jnp.printoptions(precision=2):
...   print(jax.value_and_grad(f)(x, y))
(Array(1.68, dtype=float32), Array(1.08, dtype=float32))
- defjvps(*jvps)
Convenience wrapper for defining JVPs for each argument separately.
This convenience wrapper cannot be used together with nondiff_argnums.
- Parameters:
*jvps (Callable[..., TypeVar(ReturnValue)] | None) – a sequence of functions, one for each positional argument of the custom_jvp function. Each function takes as arguments the tangent value for the corresponding primal input, the primal output, and the primal inputs. See the example below.
- Return type:
- Returns:
None.
Examples
>>> @jax.custom_jvp
... def f(x, y):
...   return jnp.sin(x) * y
...
>>> f.defjvps(lambda x_dot, primal_out, x, y: jnp.cos(x) * x_dot * y,
...           lambda y_dot, primal_out, x, y: jnp.sin(x) * y_dot)

>>> x = jnp.float32(1.0)
>>> y = jnp.float32(2.0)
>>> with jnp.printoptions(precision=2):
...   print(jax.value_and_grad(f)(x, y))
(Array(1.68, dtype=float32), Array(1.08, dtype=float32))
- class scico.custom_vjp(fun, nondiff_argnums=(), nondiff_argnames=())
Bases: Generic[ReturnValue]
Set up a JAX-transformable function for a custom VJP rule definition.
This class is meant to be used as a function decorator. Instances are callables that behave similarly to the underlying function to which the decorator was applied, except when a reverse-mode differentiation transformation (like jax.grad) is applied, in which case a custom user-supplied VJP rule function is used instead of tracing into and performing automatic differentiation of the underlying function’s implementation. There is a single instance method, defvjp, which may be used to define the custom VJP rule.
This decorator precludes the use of forward-mode automatic differentiation.
For example:
    @jax.custom_vjp
    def f(x, y):
        return jnp.sin(x) * y

    def f_fwd(x, y):
        return f(x, y), (jnp.cos(x), jnp.sin(x), y)

    def f_bwd(res, g):
        cos_x, sin_x, y = res
        return (cos_x * g * y, sin_x * g)

    f.defvjp(f_fwd, f_bwd)
For a more detailed introduction, see the tutorial.
- defvjp(fwd, bwd, symbolic_zeros=False, optimize_remat=False)
Define a custom VJP rule for the function represented by this instance.
- Parameters:
fwd (Callable[..., tuple[TypeVar(ReturnValue), Any]]) – a Python callable representing the forward pass of the custom VJP rule. When there are no nondiff_argnums, the fwd function has the same input signature as the underlying primal function. It should return as output a pair, where the first element represents the primal output and the second element represents any “residual” values to store from the forward pass for use on the backward pass by the function bwd. Input arguments and elements of the output pair may be arrays or nested tuples/lists/dicts thereof.
bwd (Callable[..., tuple[Any, ...]]) – a Python callable representing the backward pass of the custom VJP rule. When there are no nondiff_argnums, the bwd function takes two arguments, where the first is the “residual” values produced on the forward pass by fwd, and the second is the output cotangent with the same structure as the primal function output. The output of bwd must be a tuple of length equal to the number of arguments of the primal function, and the tuple elements may be arrays or nested tuples/lists/dicts thereof so as to match the structure of the primal input arguments.
symbolic_zeros (bool) – boolean, determining whether to indicate symbolic zeros to the fwd and bwd rules. Enabling this option allows custom derivative rules to detect when certain inputs, and when certain output cotangents, are not involved in differentiation. If True:
- fwd must accept, in place of each leaf value x in the pytree comprising an argument to the original function, an object (of type jax.custom_derivatives.CustomVJPPrimal) with two attributes instead: value and perturbed. The value field is the original primal argument, and perturbed is a boolean. The perturbed bit indicates whether the argument is involved in differentiation (i.e., if it is False, then the corresponding Jacobian “column” is zero).
- bwd will be passed objects representing static symbolic zeros in its cotangent argument in correspondence with unperturbed values; otherwise, only standard JAX types (e.g. array-likes) are passed.
Setting this option to True allows these rules to detect whether certain inputs and outputs are not involved in differentiation, but at the cost of special handling. For instance:
- The signature of fwd changes, and the objects it is passed cannot be output from the rule directly.
- The bwd rule is passed objects that are not entirely array-like, and that cannot be passed to most jax.numpy functions.
- Any custom pytree nodes involved in the primal function’s arguments must accept, in their unflattening functions, the two-field record objects that are given as input leaves to the fwd rule.
Default False.
optimize_remat (bool) – boolean, an experimental flag to enable an automatic optimization when this function is used under jax.remat. This will be most useful when the fwd rule is an opaque call such as a Pallas kernel or a custom call. Default False.
- Return type:
- Returns:
None.
Examples
>>> @jax.custom_vjp
... def f(x, y):
...   return jnp.sin(x) * y
...
>>> def f_fwd(x, y):
...   return f(x, y), (jnp.cos(x), jnp.sin(x), y)
...
>>> def f_bwd(res, g):
...   cos_x, sin_x, y = res
...   return (cos_x * g * y, sin_x * g)
...
>>> f.defvjp(f_fwd, f_bwd)

>>> x = jnp.float32(1.0)
>>> y = jnp.float32(2.0)
>>> with jnp.printoptions(precision=2):
...   print(jax.value_and_grad(f)(x, y))
(Array(1.68, dtype=float32), Array(1.08, dtype=float32))
- scico.hessian(fun, argnums=0, has_aux=False, holomorphic=False)
Hessian of fun as a dense array.
- Parameters:
fun (Callable) – Function whose Hessian is to be computed. Its arguments at positions specified by argnums should be arrays, scalars, or standard Python containers thereof. It should return arrays, scalars, or standard Python containers thereof.
argnums (int | Sequence[int]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).
has_aux (bool) – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.
holomorphic (bool) – Optional, bool. Indicates whether fun is promised to be holomorphic. Default False.
- Return type:
- Returns:
A function with the same arguments as fun, that evaluates the Hessian of fun.
>>> import jax
>>>
>>> g = lambda x: x[0]**3 - 2*x[0]*x[1] - x[1]**6
>>> print(jax.hessian(g)(jax.numpy.array([1., 2.])))
[[   6.   -2.]
 [  -2. -480.]]
hessian is a generalization of the usual definition of the Hessian that supports nested Python containers (i.e. pytrees) as inputs and outputs. The tree structure of jax.hessian(fun)(x) is given by forming a tree product of the structure of fun(x) with a tree product of two copies of the structure of x. A tree product of two tree structures is formed by replacing each leaf of the first tree with a copy of the second. For example:
>>> import jax.numpy as jnp
>>> f = lambda dct: {"c": jnp.power(dct["a"], dct["b"])}
>>> print(jax.hessian(f)({"a": jnp.arange(2.) + 1., "b": jnp.arange(2.) + 2.}))
{'c': {'a': {'a': Array([[[ 2.,  0.], [ 0.,  0.]],
                         [[ 0.,  0.], [ 0., 12.]]], dtype=float32),
             'b': Array([[[ 1.      ,  0.      ], [ 0.      ,  0.      ]],
                         [[ 0.      ,  0.      ], [ 0.      , 12.317766]]], dtype=float32)},
       'b': {'a': Array([[[ 1.      ,  0.      ], [ 0.      ,  0.      ]],
                         [[ 0.      ,  0.      ], [ 0.      , 12.317766]]], dtype=float32),
             'b': Array([[[0.      , 0.      ], [0.      , 0.      ]],
                         [[0.      , 0.      ], [0.      , 3.843624]]], dtype=float32)}}}
Thus each leaf in the tree structure of jax.hessian(fun)(x) corresponds to a leaf of fun(x) and a pair of leaves of x. For each leaf in jax.hessian(fun)(x), if the corresponding array leaf of fun(x) has shape (out_1, out_2, ...) and the corresponding array leaves of x have shape (in_1_1, in_1_2, ...) and (in_2_1, in_2_2, ...) respectively, then the Hessian leaf has shape (out_1, out_2, ..., in_1_1, in_1_2, ..., in_2_1, in_2_2, ...). In other words, the Python tree structure represents the block structure of the Hessian, with blocks determined by the input and output pytrees.
In particular, an array is produced (with no pytrees involved) when the function input x and output fun(x) are each a single array, as in the g example above. If fun(x) has shape (out1, out2, ...) and x has shape (in1, in2, ...) then jax.hessian(fun)(x) has shape (out1, out2, ..., in1, in2, ..., in1, in2, ...). To flatten pytrees into 1D vectors, consider using jax.flatten_util.ravel_pytree.
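As a quick check of the single-array case described above, the following sketch applies the Hessian transformation to a quadratic form whose Hessian is known in closed form. It calls jax.hessian directly, on the assumption that scico.hessian exposes the same function; the matrix A and the function name quad are illustrative and do not come from SCICO.

```python
import jax
import jax.numpy as jnp

# Illustrative symmetric matrix (hypothetical, chosen for this sketch).
A = jnp.array([[2.0, 1.0], [1.0, 3.0]])


def quad(x):
    # f(x) = 0.5 * x^T A x, whose Hessian is A when A is symmetric.
    return 0.5 * x @ A @ x


x0 = jnp.array([1.0, -1.0])

# Input of shape (2,) and scalar output give a Hessian of shape (2, 2).
H = jax.hessian(quad)(x0)

# Composing forward-mode over reverse-mode Jacobians gives the same array.
H_fwd_rev = jax.jacfwd(jax.jacrev(quad))(x0)
```

For a quadratic form the result does not depend on x0, which makes it a convenient test case when checking shapes.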
- scico.jacfwd(fun, argnums=0, has_aux=False, holomorphic=False)
Jacobian of fun evaluated column-by-column using forward-mode AD.
- Parameters:
fun (Callable) – Function whose Jacobian is to be computed.
argnums (int | Sequence[int]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).
has_aux (bool) – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.
holomorphic (bool) – Optional, bool. Indicates whether fun is promised to be holomorphic. Default False.
- Return type:
- Returns:
A function with the same arguments as fun, that evaluates the Jacobian of fun using forward-mode automatic differentiation. If has_aux is True then a pair of (jacobian, auxiliary_data) is returned.
>>> import jax
>>> import jax.numpy as jnp
>>>
>>> def f(x):
...   return jnp.asarray(
...     [x[0], 5*x[2], 4*x[1]**2 - 2*x[2], x[2] * jnp.sin(x[0])])
...
>>> print(jax.jacfwd(f)(jnp.array([1., 2., 3.])))
[[ 1.       0.       0.     ]
 [ 0.       0.       5.     ]
 [ 0.      16.      -2.     ]
 [ 1.6209   0.       0.84147]]
- scico.jvp(fun, primals, tangents, has_aux=False)
Computes a (forward-mode) Jacobian-vector product of fun.
- Parameters:
fun (Callable) – Function to be differentiated. Its arguments should be arrays, scalars, or standard Python containers of arrays or scalars. It should return an array, scalar, or standard Python container of arrays or scalars.
primals – The primal values at which the Jacobian of fun should be evaluated. Should be either a tuple or a list of arguments, and its length should be equal to the number of positional parameters of fun.
tangents – The tangent vector for which the Jacobian-vector product should be evaluated. Should be either a tuple or a list of tangents, with the same tree structure and array shapes as primals.
has_aux (bool) – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.
- Return type:
- Returns:
If has_aux is False, returns a (primals_out, tangents_out) pair, where primals_out is fun(*primals), and tangents_out is the Jacobian-vector product of fun evaluated at primals with tangents. The tangents_out value has the same Python tree structure and shapes as primals_out. If has_aux is True, returns a (primals_out, tangents_out, aux) tuple where aux is the auxiliary data returned by fun.
For example:
>>> import jax
>>>
>>> primals, tangents = jax.jvp(jax.numpy.sin, (0.1,), (0.2,))
>>> print(primals)
0.09983342
>>> print(tangents)
0.19900084
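The requirement that primals and tangents have one entry per positional parameter is easiest to see with a function of two arguments. The sketch below uses jax.jvp directly (assuming scico.jvp exposes the same function) and an illustrative function f; choosing a tangent of (1, 0) or (0, 1) selects a single partial derivative.

```python
import jax
import jax.numpy as jnp


def f(x, y):
    # Two positional parameters, so primals and tangents are 2-tuples.
    return jnp.sin(x) * y


x, y = 1.0, 2.0

# Tangent (1, 0) selects the partial derivative with respect to x.
primal_out, dfdx = jax.jvp(f, (x, y), (1.0, 0.0))

# Tangent (0, 1) selects the partial derivative with respect to y.
_, dfdy = jax.jvp(f, (x, y), (0.0, 1.0))
```

Here dfdx equals cos(x) * y and dfdy equals sin(x), matching the custom JVP rule given earlier for the same function.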
- scico.linearize(fun, *primals, has_aux=False)
Produces a linear approximation to fun using jvp and partial eval.
- Parameters:
fun (Callable) – Function to be differentiated. Its arguments should be arrays, scalars, or standard Python containers of arrays or scalars. It should return an array, scalar, or standard Python container of arrays or scalars.
primals – The primal values at which the Jacobian of fun should be evaluated. Should be a tuple of arrays, scalar, or standard Python container thereof. The length of the tuple is equal to the number of positional parameters of fun.
has_aux (bool) – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be linearized, and the second is auxiliary data. Default False.
- Return type:
- Returns:
If has_aux is False, returns a pair where the first element is the value of f(*primals) and the second element is a function that evaluates the (forward-mode) Jacobian-vector product of fun evaluated at primals without re-doing the linearization work. If has_aux is True, returns a (primals_out, lin_fn, aux) tuple where aux is the auxiliary data returned by fun.
In terms of values computed, linearize behaves much like a curried jvp, where these two code blocks compute the same values:
    y, out_tangent = jax.jvp(f, (x,), (in_tangent,))

    y, f_jvp = jax.linearize(f, x)
    out_tangent = f_jvp(in_tangent)
However, the difference is that linearize uses partial evaluation so that the function f is not re-linearized on calls to f_jvp. In general that means the memory usage scales with the size of the computation, much like in reverse-mode. (Indeed, linearize has a similar signature to vjp!)
This function is mainly useful if you want to apply f_jvp multiple times, i.e. to evaluate a pushforward for many different input tangent vectors at the same linearization point. Moreover if all the input tangent vectors are known at once, it can be more efficient to vectorize using vmap, as in:
    pushfwd = partial(jvp, f, (x,))
    y, out_tangents = vmap(pushfwd, out_axes=(None, 0))((in_tangents,))
By using vmap and jvp together like this we avoid the stored-linearization memory cost that scales with the depth of the computation, which is incurred by both linearize and vjp.
Here’s a more complete example of using linearize:
>>> import jax
>>> import jax.numpy as jnp
>>>
>>> def f(x): return 3. * jnp.sin(x) + jnp.cos(x / 2.)
...
>>> jax.jvp(f, (2.,), (3.,))
(Array(3.2681944, dtype=float32, weak_type=True), Array(-5.007528, dtype=float32, weak_type=True))
>>> y, f_jvp = jax.linearize(f, 2.)
>>> print(y)
3.2681944
>>> print(f_jvp(3.))
-5.007528
>>> print(f_jvp(4.))
-6.676704
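The two strategies discussed above, reusing a stored linearization versus vectorizing jvp with vmap, can be compared side by side. The sketch below is a self-contained version using jax.linearize, jax.jvp, and jax.vmap directly (assuming the scico wrappers behave the same way); the function f, the point x, and the identity-matrix tangents are illustrative choices.

```python
from functools import partial

import jax
import jax.numpy as jnp


def f(x):
    # Illustrative vector-valued function with a non-diagonal Jacobian.
    return jnp.sin(x) * x[::-1]


x = jnp.array([0.5, 1.5])
in_tangents = jnp.eye(2)  # one tangent vector per row

# Option 1: linearize once, then apply the stored linear map to each tangent.
y_lin, f_jvp = jax.linearize(f, x)
out_lin = jnp.stack([f_jvp(t) for t in in_tangents])

# Option 2: vectorize jvp over the leading axis of the tangents.
pushfwd = partial(jax.jvp, f, (x,))
y_vmap, out_vmap = jax.vmap(pushfwd, out_axes=(None, 0))((in_tangents,))
```

Because the tangents are the standard basis vectors, each output row is one column of the Jacobian, so out_vmap is the transpose of jax.jacfwd(f)(x).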
- scico.vjp(fun, *primals, has_aux=False, reduce_axes=())
Compute a (reverse-mode) vector-Jacobian product of fun.
grad is implemented as a special case of vjp.
- Parameters:
fun (Callable) – Function to be differentiated. Its arguments should be arrays, scalars, or standard Python containers of arrays or scalars. It should return an array, scalar, or standard Python container of arrays or scalars.
primals – A sequence of primal values at which the Jacobian of fun should be evaluated. The number of primals should be equal to the number of positional parameters of fun. Each primal value should be an array, a scalar, or a pytree (standard Python containers) thereof.
has_aux (bool) – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.
- Return type:
- Returns:
If has_aux is False, returns a (primals_out, vjpfun) pair, where primals_out is fun(*primals). If has_aux is True, returns a (primals_out, vjpfun, aux) tuple where aux is the auxiliary data returned by fun. vjpfun is a function from a cotangent vector with the same shape as primals_out to a tuple of cotangent vectors with the same number and shapes as primals, representing the vector-Jacobian product of fun evaluated at primals.
>>> import jax
>>>
>>> def f(x, y):
...   return jax.numpy.sin(x), jax.numpy.cos(y)
...
>>> primals, f_vjp = jax.vjp(f, 0.5, 1.0)
>>> xbar, ybar = f_vjp((-0.7, 0.3))
>>> print(xbar)
-0.61430776
>>> print(ybar)
-0.2524413
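The statement that grad is a special case of vjp can be illustrated for a real scalar-valued function: pulling back a cotangent of 1 through the VJP function yields the gradient. This is a minimal sketch using jax.vjp and jax.grad directly, with an illustrative function f; it does not cover the conjugation that the scico wrappers apply for complex inputs.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Scalar-valued function with gradient 2 * x.
    return jnp.sum(x**2)


x = jnp.array([1.0, 2.0, 3.0])

# vjp returns the primal output and a function that pulls back cotangents.
y, f_vjp = jax.vjp(f, x)

# A cotangent of 1, matching the shape and dtype of the scalar output.
(g_from_vjp,) = f_vjp(jnp.ones_like(y))

# The gradient computed directly.
g = jax.grad(f)(x)
```

The VJP function returns a tuple with one entry per primal, which is why the result is unpacked from a one-element tuple.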
- scico.cvjp(fun, *primals, jidx=None)
Compute a vector-Jacobian product with conjugate transpose.
Compute the product \([J(\mb{x})]^H \mb{v}\) where \([J(\mb{x})]\) is the Jacobian of function fun evaluated at \(\mb{x}\). Instead of directly evaluating the product, a function is returned that takes \(\mb{v}\) as an argument. If fun has multiple positional parameters, the Jacobian can be taken with respect to only one of them by setting the jidx parameter of this function to the positional index of that parameter.
- Parameters:
fun (Callable) – Function for which the Jacobian is implicitly computed.
primals – Sequence of values at which the Jacobian is evaluated, with length equal to the number of positional arguments of fun.
jidx (Optional[int]) – Index of the positional parameter of fun with respect to which the Jacobian is taken.
- Return type:
- Returns:
A pair (primals_out, conj_vjp) where primals_out is the output of fun evaluated at primals, i.e. primals_out = fun(*primals), and conj_vjp is a function that computes the product of the conjugate (Hermitian) transpose of the Jacobian of fun and its argument. If the jidx parameter is an integer, then the Jacobian is only taken with respect to the corresponding positional parameter of fun.
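To make the conjugate-transpose product concrete, the sketch below computes \([J(\mb{x})]^H \mb{v}\) for a complex linear map, whose Jacobian is simply the matrix A. It is built on jax.vjp rather than scico.cvjp and is not taken from the SCICO source: it relies on the assumption that jax.vjp applies the unconjugated transpose for such a map, so conjugating the input and the output yields the Hermitian transpose. The names A, fun, and conj_vjp are illustrative.

```python
import jax
import jax.numpy as jnp

# Illustrative complex matrix; fun(x) = A x has Jacobian J = A.
A = jnp.array([[1.0 + 2.0j, 0.5j], [3.0, 1.0 - 1.0j]], dtype=jnp.complex64)


def fun(x):
    return A @ x


x = jnp.array([1.0 + 1.0j, 2.0 - 1.0j], dtype=jnp.complex64)
v = jnp.array([0.5 - 1.0j, 1.0 + 2.0j], dtype=jnp.complex64)

primals_out, fun_vjp = jax.vjp(fun, x)


def conj_vjp(v):
    # fun_vjp computes J^T v; conjugating input and output gives J^H v.
    (out,) = fun_vjp(jnp.conj(v))
    return jnp.conj(out)


JH_v = conj_vjp(v)
```

For this linear example the result can be checked against A.conj().T @ v computed directly.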
- scico.eval_shape(fun, *args, **kwargs)
Compute the shape and dtype of the output of a function without executing it.
Compute the shape and dtype of the output of a function without executing it, via a call to jax.eval_shape, with args and kwargs mapped to handle jax.ShapeDtypeStruct objects with nested shapes corresponding to BlockArray objects.
Docstring for jax.eval_shape:
This utility function is useful for performing shape inference. Its input/output behavior is defined by:
    def eval_shape(fun, *args, **kwargs):
        out = fun(*args, **kwargs)
        shape_dtype_struct = lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype)
        return jax.tree_util.tree_map(shape_dtype_struct, out)
But instead of applying fun directly, which might be expensive, it uses JAX’s abstract interpretation machinery to evaluate the shapes without doing any FLOPs.
Using eval_shape can also catch shape errors, and will raise the same shape errors as evaluating fun(*args, **kwargs).
- Parameters:
fun (Callable) – The function whose output shape should be evaluated.
*args – a positional argument tuple of arrays, scalars, or (nested) standard Python containers (tuples, lists, dicts, namedtuples, i.e. pytrees) of those types. Since only the shape and dtype attributes are accessed, one can use jax.ShapeDtypeStruct or another container that duck-types as ndarrays (note however that duck-typed objects cannot be namedtuples because those are treated as standard Python containers).
**kwargs – a keyword argument dict of arrays, scalars, or (nested) standard Python containers (pytrees) of those types. As in args, array values need only be duck-typed to have shape and dtype attributes.
- Return type:
- Returns:
out – a nested PyTree containing jax.ShapeDtypeStruct objects as leaves.
For example:
>>> import jax
>>> import jax.numpy as jnp
>>>
>>> f = lambda A, x: jnp.tanh(jnp.dot(A, x))
>>> A = jax.ShapeDtypeStruct((2000, 3000), jnp.float32)
>>> x = jax.ShapeDtypeStruct((3000, 1000), jnp.float32)
>>> out = jax.eval_shape(f, A, x)  # no FLOPs performed
>>> print(out.shape)
(2000, 1000)
>>> print(out.dtype)
float32
All arguments passed via eval_shape will be treated as dynamic; static arguments can be included via closure, for example using functools.partial:
>>> import jax
>>> from jax import lax
>>> from functools import partial
>>> import jax.numpy as jnp
>>>
>>> x = jax.ShapeDtypeStruct((1, 1, 28, 28), jnp.float32)
>>> kernel = jax.ShapeDtypeStruct((32, 1, 3, 3), jnp.float32)
>>>
>>> conv_same = partial(lax.conv_general_dilated, window_strides=(1, 1), padding="SAME")
>>> out = jax.eval_shape(conv_same, x, kernel)
>>> print(out.shape)
(1, 32, 28, 28)
>>> print(out.dtype)
float32
- scico.grad(fun, argnums=0, has_aux=False, holomorphic=False, allow_int=False)
Create a function that evaluates the gradient of fun.
scico.grad differs from jax.grad in that the output is conjugated.
Docstring for jax.grad:
- Parameters:
fun (Callable) – Function to be differentiated. Its arguments at positions specified by argnums should be arrays, scalars, or standard Python containers. Argument arrays in the positions specified by argnums must be of inexact (i.e., floating-point or complex) type. It should return a scalar (which includes arrays with shape () but not arrays with shape (1,) etc.)
argnums (Union[int, Sequence[int]]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).
has_aux (bool) – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.
holomorphic (bool) – Optional, bool. Indicates whether fun is promised to be holomorphic. If True, inputs and outputs must be complex. Default False.
allow_int (bool) – Optional, bool. Whether to allow differentiating with respect to integer valued inputs. The gradient of an integer input will have a trivial vector-space dtype (float0). Default False.
- Return type:
- Returns:
A function with the same arguments as fun, that evaluates the gradient of fun. If argnums is an integer then the gradient has the same shape and type as the positional argument indicated by that integer. If argnums is a tuple of integers, the gradient is a tuple of values with the same shapes and types as the corresponding arguments. If has_aux is True then a pair of (gradient, auxiliary_data) is returned.
For example:
>>> import jax
>>>
>>> grad_tanh = jax.grad(jax.numpy.tanh)
>>> print(grad_tanh(0.2))
0.961043
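The effect of conjugating the output only shows up for complex inputs. The sketch below evaluates jax.grad on the real-valued function |z|^2 of a complex scalar and then conjugates the result by hand, which is how the difference described above is assumed to manifest; it does not call scico.grad itself, and the function f is illustrative.

```python
import jax
import jax.numpy as jnp


def f(z):
    # Real-valued function of a complex argument: |z|^2.
    return jnp.real(z) ** 2 + jnp.imag(z) ** 2


z = 3.0 + 4.0j

# Under the JAX convention this evaluates to 2 * conj(z).
g_jax = jax.grad(f)(z)

# Conjugating gives 2 * z, the direction of steepest ascent of |z|^2.
g_conj = jnp.conj(g_jax)
```

With the conjugated form, a gradient-descent step z - step * g_conj moves z toward the minimizer at the origin, which is the usual reason for preferring this convention in optimization code.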
- scico.jacrev(fun, argnums=0, holomorphic=False, allow_int=False)
Jacobian of fun evaluated row-by-row using reverse-mode AD.
scico.jacrev differs from jax.jacrev in that the output is conjugated.
Docstring for jax.jacrev:
- Parameters:
fun (Callable) – Function whose Jacobian is to be computed.
argnums (Union[int, Sequence[int]]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).
has_aux – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.
holomorphic (bool) – Optional, bool. Indicates whether fun is promised to be holomorphic. Default False.
allow_int (bool) – Optional, bool. Whether to allow differentiating with respect to integer valued inputs. The gradient of an integer input will have a trivial vector-space dtype (float0). Default False.
- Return type:
- Returns:
A function with the same arguments as fun, that evaluates the Jacobian of fun using reverse-mode automatic differentiation. If has_aux is True then a pair of (jacobian, auxiliary_data) is returned.
>>> import jax
>>> import jax.numpy as jnp
>>>
>>> def f(x):
...   return jnp.asarray(
...     [x[0], 5*x[2], 4*x[1]**2 - 2*x[2], x[2] * jnp.sin(x[0])])
...
>>> print(jax.jacrev(f)(jnp.array([1., 2., 3.])))
[[ 1.       0.       0.     ]
 [ 0.       0.       5.     ]
 [ 0.      16.      -2.     ]
 [ 1.6209   0.       0.84147]]
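For real-valued functions, where conjugation has no effect, the forward-mode and reverse-mode Jacobians agree, and the choice between them is one of cost. The sketch below checks this with jax.jacfwd and jax.jacrev on the function from the example above; it uses the JAX functions directly rather than the scico wrappers.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Maps R^3 to R^4, so the Jacobian has shape (4, 3).
    return jnp.asarray(
        [x[0], 5 * x[2], 4 * x[1] ** 2 - 2 * x[2], x[2] * jnp.sin(x[0])]
    )


x = jnp.array([1.0, 2.0, 3.0])

J_fwd = jax.jacfwd(f)(x)  # built column by column (one JVP per input)
J_rev = jax.jacrev(f)(x)  # built row by row (one VJP per output)
```

As a rule of thumb, forward mode needs one pass per input dimension and reverse mode one pass per output dimension, so jacfwd tends to suit "tall" Jacobians and jacrev "wide" ones.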
- scico.linear_adjoint(fun, *primals)
Conjugate transpose a function that is guaranteed to be linear.
scico.linear_adjoint differs from jax.linear_transpose for complex inputs in that the conjugate transpose (adjoint) of fun is returned. scico.linear_adjoint is identical to jax.linear_transpose for real-valued primals.
Docstring for jax.linear_transpose:
For linear functions, this transformation is equivalent to vjp, but avoids the overhead of computing the forward pass.
The outputs of the transposed function will always have the exact same dtypes as primals, even if some values are truncated (e.g., from complex to float, or from float64 to float32). To avoid truncation, use dtypes in primals that match the full range of desired outputs from the transposed function. Integer dtypes are not supported.
- Parameters:
fun (Callable) – the linear function to be transposed.
*primals – a positional argument tuple of arrays, scalars, or (nested) standard Python containers (tuples, lists, dicts, namedtuples, i.e., pytrees) of those types used for evaluating the shape/dtype of fun(*primals). These arguments may be real scalars/ndarrays, but that is not required: only the shape and dtype attributes are accessed. See below for an example. (Note that the duck-typed objects cannot be namedtuples because those are treated as standard Python containers.)
- Return type:
- Returns:
A callable that calculates the transpose of fun. Valid input into this function must have the same shape/dtypes/structure as the result of fun(*primals). Output will be a tuple, with the same shape/dtypes/structure as primals.
>>> import jax
>>> import numpy as np
>>>
>>> f = lambda x, y: 0.5 * x - 0.5 * y
>>> scalar = jax.ShapeDtypeStruct(shape=(), dtype=np.dtype(np.float32))
>>> f_transpose = jax.linear_transpose(f, scalar, scalar)
>>> f_transpose(1.0)
(Array(0.5, dtype=float32), Array(-0.5, dtype=float32))
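The difference between the transpose and the adjoint only appears for complex-valued linear maps. The sketch below uses jax.linear_transpose on the map x -> A x and then forms the adjoint by conjugating the input and the output; this conjugation pattern is an assumption about the mathematical relationship, not a copy of how scico.linear_adjoint is implemented. The matrix A and the vector y are illustrative.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Illustrative complex matrix defining a linear map from C^2 to C^3.
A = jnp.array(
    [[1.0 + 1.0j, 2.0], [0.0, 3.0 - 2.0j], [1.0j, 1.0]], dtype=jnp.complex64
)
f = lambda x: A @ x

# Only the shape and dtype of the primal are needed.
x_spec = jax.ShapeDtypeStruct(shape=(2,), dtype=np.dtype(np.complex64))
f_transpose = jax.linear_transpose(f, x_spec)

y = jnp.array([1.0 - 1.0j, 2.0j, 0.5], dtype=jnp.complex64)

# Plain transpose: A^T y.
(t_out,) = f_transpose(y)

# Adjoint via conjugation on both sides: conj(A^T conj(y)) = A^H y.
(tc,) = f_transpose(jnp.conj(y))
a_out = jnp.conj(tc)
```

For real-valued A and y the two results coincide, consistent with the statement above that the adjoint and the transpose are identical for real-valued primals.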
- scico.linear_transpose(fun, *primals)
Transpose a function that is guaranteed to be linear.
scico.linear_transpose differs from jax.linear_transpose in that it correctly handles primals consisting of jax.ShapeDtypeStruct objects with nested shapes, i.e. corresponding to BlockArray shapes.
Docstring for jax.linear_transpose:
For linear functions, this transformation is equivalent to vjp, but avoids the overhead of computing the forward pass.
The outputs of the transposed function will always have the exact same dtypes as primals, even if some values are truncated (e.g., from complex to float, or from float64 to float32). To avoid truncation, use dtypes in primals that match the full range of desired outputs from the transposed function. Integer dtypes are not supported.
- Parameters:
fun (Callable) – the linear function to be transposed.
*primals – a positional argument tuple of arrays, scalars, or (nested) standard Python containers (tuples, lists, dicts, namedtuples, i.e., pytrees) of those types used for evaluating the shape/dtype of fun(*primals). These arguments may be real scalars/ndarrays, but that is not required: only the shape and dtype attributes are accessed. See below for an example. (Note that the duck-typed objects cannot be namedtuples because those are treated as standard Python containers.)
- Return type:
- Returns:
A callable that calculates the transpose of fun. Valid input into this function must have the same shape/dtypes/structure as the result of fun(*primals). Output will be a tuple, with the same shape/dtypes/structure as primals.
>>> import jax
>>> import numpy as np
>>>
>>> f = lambda x, y: 0.5 * x - 0.5 * y
>>> scalar = jax.ShapeDtypeStruct(shape=(), dtype=np.dtype(np.float32))
>>> f_transpose = jax.linear_transpose(f, scalar, scalar)
>>> f_transpose(1.0)
(Array(0.5, dtype=float32), Array(-0.5, dtype=float32))
- scico.value_and_grad(fun, argnums=0, has_aux=False, holomorphic=False, allow_int=False)
Create a function that evaluates both fun and its gradient.
scico.value_and_grad differs from jax.value_and_grad in that the gradient is conjugated.
Docstring for jax.value_and_grad:
- Parameters:
fun (Callable) – Function to be differentiated. Its arguments at positions specified by argnums should be arrays, scalars, or standard Python containers. It should return a scalar (which includes arrays with shape () but not arrays with shape (1,) etc.)
argnums (Union[int, Sequence[int]]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).
has_aux (bool) – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.
holomorphic (bool) – Optional, bool. Indicates whether fun is promised to be holomorphic. If True, inputs and outputs must be complex. Default False.
allow_int (bool) – Optional, bool. Whether to allow differentiating with respect to integer valued inputs. The gradient of an integer input will have a trivial vector-space dtype (float0). Default False.
- Return type:
- Returns:
A function with the same arguments as fun that evaluates both fun and the gradient of fun and returns them as a pair (a two-element tuple). If argnums is an integer then the gradient has the same shape and type as the positional argument indicated by that integer. If argnums is a sequence of integers, the gradient is a tuple of values with the same shapes and types as the corresponding arguments. If has_aux is True then a tuple of ((value, auxiliary_data), gradient) is returned.
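The nested return structure with has_aux=True is easy to get wrong when unpacking. The sketch below shows the ((value, auxiliary_data), gradient) layout using jax.value_and_grad directly with real inputs, for which the conjugation applied by scico.value_and_grad is assumed to make no difference. The function loss and the dictionary key residual_norm are illustrative names.

```python
import jax
import jax.numpy as jnp


def loss(x, b):
    # Returns (scalar value, auxiliary data); only the value is differentiated.
    r = x - b
    value = 0.5 * jnp.sum(r**2)
    aux = {"residual_norm": jnp.linalg.norm(r)}
    return value, aux


x = jnp.array([1.0, 2.0])
b = jnp.array([0.0, 4.0])

# The gradient is taken with respect to the first argument (argnums=0).
(value, aux), g = jax.value_and_grad(loss, has_aux=True)(x, b)
```

Here the gradient with respect to x is the residual x - b, and the auxiliary dictionary is passed through unchanged.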