API Reference
Scientific Computational Imaging COde (SCICO) is a Python package for solving the inverse problems that arise in scientific imaging applications.
- scico.data – Data files for usage examples.
- scico.denoiser – Interfaces to standard denoisers.
- scico.diagnostics – Diagnostic information for iterative solvers.
- scico.examples – Utility functions used by example scripts.
- scico.flax – Neural network models implemented in Flax and utility functions.
- scico.function – Function class.
- scico.functional – Functionals and functional classes.
- scico.linop – Linear operator functions and classes.
- scico.loss – Loss function classes.
- scico.metric – Image quality metrics and related functions.
- scico.numpy – Wrapped versions of jax.numpy functions.
- scico.operator – Operator functions and classes.
- scico.optimize – Optimization algorithms.
- scico.plot – Plotting/visualization functions.
- scico.random – Random number generation.
- scico.ray – Simplified interfaces to ray.
- scico.scipy – Wrapped versions of scipy.special functions.
- scico.solver – Solver and optimization algorithms.
- scico.typing – Type definitions.
- scico.util – General utility functions.
- class scico.custom_jvp(fun, nondiff_argnums=())
Bases: Generic[ReturnValue]

Set up a JAX-transformable function for a custom JVP rule definition.
This class is meant to be used as a function decorator. Instances are callables that behave similarly to the underlying function to which the decorator was applied, except when a differentiation transformation (like jax.jvp or jax.grad) is applied, in which case a custom user-supplied JVP rule function is used instead of tracing into and performing automatic differentiation of the underlying function's implementation.

There are two instance methods available for defining the custom JVP rule: defjvp for defining a single custom JVP rule for all the function's inputs, and for convenience defjvps, which wraps defjvp and allows you to provide separate definitions for the partial derivatives of the function w.r.t. each of its arguments.

For example:
@jax.custom_jvp
def f(x, y):
    return jnp.sin(x) * y

@f.defjvp
def f_jvp(primals, tangents):
    x, y = primals
    x_dot, y_dot = tangents
    primal_out = f(x, y)
    tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot
    return primal_out, tangent_out
For a more detailed introduction, see the tutorial.
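As a minimal usage sketch (a self-contained version of the example above), applying jax.grad then invokes the custom rule rather than differentiating through f's implementation:

import jax
import jax.numpy as jnp

@jax.custom_jvp
def f(x, y):
    return jnp.sin(x) * y

@f.defjvp
def f_jvp(primals, tangents):
    x, y = primals
    x_dot, y_dot = tangents
    # Return (primal output, tangent output).
    return f(x, y), jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot

# Differentiation uses f_jvp instead of tracing f's implementation.
print(jax.grad(f)(2.0, 3.0))  # cos(2.0) * 3.0, approximately -1.2484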
- defjvp(jvp, symbolic_zeros=False)
Define a custom JVP rule for the function represented by this instance.
- Parameters:
  jvp (Callable[..., Tuple[ReturnValue, ReturnValue]]) – a Python callable representing the custom JVP rule. When there are no nondiff_argnums, the jvp function should accept two arguments, where the first is a tuple of primal inputs and the second is a tuple of tangent inputs. The lengths of both tuples are equal to the number of parameters of the custom_jvp function. The jvp function should produce as output a pair where the first element is the primal output and the second element is the tangent output. Elements of the input and output tuples may be arrays or any nested tuples/lists/dicts thereof.
  symbolic_zeros (bool) – boolean, indicating whether the rule should be passed objects representing static symbolic zeros in its tangent argument in correspondence with unperturbed values; otherwise, only standard JAX types (e.g. array-likes) are passed. Setting this option to True allows a JVP rule to detect whether certain inputs are not involved in differentiation, but at the cost of needing special handling for these objects (which e.g. can't be passed into jax.numpy functions). Default False.
- Return type:
  Callable[..., Tuple[ReturnValue, ReturnValue]]
- Returns:
  None.
Example:
@jax.custom_jvp
def f(x, y):
    return jnp.sin(x) * y

@f.defjvp
def f_jvp(primals, tangents):
    x, y = primals
    x_dot, y_dot = tangents
    primal_out = f(x, y)
    tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot
    return primal_out, tangent_out
- defjvps(*jvps)
Convenience wrapper for defining JVPs for each argument separately.
This convenience wrapper cannot be used together with nondiff_argnums.
- Parameters:
  *jvps (Optional[Callable[..., ReturnValue]]) – a sequence of functions, one for each positional argument of the custom_jvp function. Each function takes as arguments the tangent value for the corresponding primal input, the primal output, and the primal inputs. See the example below.
- Returns:
  None.
Example:
@jax.custom_jvp
def f(x, y):
    return jnp.sin(x) * y

f.defjvps(lambda x_dot, primal_out, x, y: jnp.cos(x) * x_dot * y,
          lambda y_dot, primal_out, x, y: jnp.sin(x) * y_dot)
- class scico.custom_vjp(fun, nondiff_argnums=())
Bases: Generic[ReturnValue]

Set up a JAX-transformable function for a custom VJP rule definition.
This class is meant to be used as a function decorator. Instances are callables that behave similarly to the underlying function to which the decorator was applied, except when a reverse-mode differentiation transformation (like jax.grad) is applied, in which case a custom user-supplied VJP rule function is used instead of tracing into and performing automatic differentiation of the underlying function's implementation. There is a single instance method, defvjp, which may be used to define the custom VJP rule.

This decorator precludes the use of forward-mode automatic differentiation.
For example:
@jax.custom_vjp
def f(x, y):
    return jnp.sin(x) * y

def f_fwd(x, y):
    return f(x, y), (jnp.cos(x), jnp.sin(x), y)

def f_bwd(res, g):
    cos_x, sin_x, y = res
    return (cos_x * g * y, sin_x * g)

f.defvjp(f_fwd, f_bwd)
For a more detailed introduction, see the tutorial.
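As a minimal usage sketch (a self-contained version of the example above), reverse-mode differentiation then calls f_bwd:

import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x, y):
    return jnp.sin(x) * y

def f_fwd(x, y):
    # Residuals saved for the backward pass.
    return f(x, y), (jnp.cos(x), jnp.sin(x), y)

def f_bwd(res, g):
    cos_x, sin_x, y = res
    # One cotangent per positional argument of f.
    return (cos_x * g * y, sin_x * g)

f.defvjp(f_fwd, f_bwd)

# Gradients w.r.t. both arguments, computed via f_bwd.
print(jax.grad(f, argnums=(0, 1))(2.0, 3.0))  # (cos(2.0) * 3.0, sin(2.0))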
- defvjp(fwd, bwd, symbolic_zeros=False)
Define a custom VJP rule for the function represented by this instance.
- Parameters:
  fwd (Callable[..., Tuple[ReturnValue, Any]]) – a Python callable representing the forward pass of the custom VJP rule. When there are no nondiff_argnums, the fwd function has the same input signature as the underlying primal function. It should return as output a pair, where the first element represents the primal output and the second element represents any "residual" values to store from the forward pass for use on the backward pass by the function bwd. Input arguments and elements of the output pair may be arrays or nested tuples/lists/dicts thereof.
  bwd (Callable[..., Tuple[Any, ...]]) – a Python callable representing the backward pass of the custom VJP rule. When there are no nondiff_argnums, the bwd function takes two arguments, where the first is the "residual" values produced on the forward pass by fwd, and the second is the output cotangent with the same structure as the primal function output. The output of bwd must be a tuple of length equal to the number of arguments of the primal function, and the tuple elements may be arrays or nested tuples/lists/dicts thereof so as to match the structure of the primal input arguments.
  symbolic_zeros (bool) – boolean, determining whether to indicate symbolic zeros to the fwd and bwd rules. Enabling this option allows custom derivative rules to detect when certain inputs, and when certain output cotangents, are not involved in differentiation. If True:
    - fwd must accept, in place of each leaf value x in the pytree comprising an argument to the original function, an object with two attributes instead: value and perturbed. The value field is the original primal argument, and perturbed is a boolean. The perturbed bit indicates whether the argument is involved in differentiation (i.e., if it is False, then the corresponding Jacobian "column" is zero).
    - bwd will be passed objects representing static symbolic zeros in its cotangent argument in correspondence with unperturbed values; otherwise, only standard JAX types (e.g. array-likes) are passed.
  Setting this option to True allows these rules to detect whether certain inputs and outputs are not involved in differentiation, but at the cost of special handling. For instance:
    - The signature of fwd changes, and the objects it is passed cannot be output from the rule directly.
    - The bwd rule is passed objects that are not entirely array-like, and that cannot be passed to most jax.numpy functions.
    - Any custom pytree nodes involved in the primal function's arguments must accept, in their unflattening functions, the two-field record objects that are given as input leaves to the fwd rule.
  Default False. (A sketch of a rule using this option follows the example below.)
- Returns:
  None.
Example:
@jax.custom_vjp
def f(x, y):
    return jnp.sin(x) * y

def f_fwd(x, y):
    return f(x, y), (jnp.cos(x), jnp.sin(x), y)

def f_bwd(res, g):
    cos_x, sin_x, y = res
    return (cos_x * g * y, sin_x * g)

f.defvjp(f_fwd, f_bwd)
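The following sketch suggests how a rule might use symbolic_zeros=True; it assumes only the record attributes value and perturbed described above, and whether it runs exactly as written may depend on the JAX version:

import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x, y):
    return jnp.sin(x) * y

def f_fwd(x, y):
    # With symbolic_zeros=True, each argument arrives as a record with
    # .value (the primal value) and .perturbed (a participation flag).
    xv, yv = x.value, y.value
    return f(xv, yv), (jnp.cos(xv), jnp.sin(xv), yv)

def f_bwd(res, g):
    # g may be a symbolic-zero object when the output cotangent is unperturbed.
    cos_x, sin_x, y = res
    return (cos_x * g * y, sin_x * g)

f.defvjp(f_fwd, f_bwd, symbolic_zeros=True)

print(jax.grad(f)(2.0, 3.0))  # only x is perturbed in this call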
- scico.hessian(fun, argnums=0, has_aux=False, holomorphic=False)
Hessian of fun as a dense array.
- Parameters:
  fun (Callable) – Function whose Hessian is to be computed. Its arguments at positions specified by argnums should be arrays, scalars, or standard Python containers thereof. It should return arrays, scalars, or standard Python containers thereof.
  argnums (Union[int, Sequence[int]]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).
  has_aux (bool) – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.
  holomorphic (bool) – Optional, bool. Indicates whether fun is promised to be holomorphic. Default False.
- Return type:
  Callable
- Returns:
  A function with the same arguments as fun, that evaluates the Hessian of fun.
>>> import jax
>>>
>>> g = lambda x: x[0]**3 - 2*x[0]*x[1] - x[1]**6
>>> print(jax.hessian(g)(jax.numpy.array([1., 2.])))
[[   6.   -2.]
 [  -2. -480.]]
hessian is a generalization of the usual definition of the Hessian that supports nested Python containers (i.e. pytrees) as inputs and outputs. The tree structure of jax.hessian(fun)(x) is given by forming a tree product of the structure of fun(x) with a tree product of two copies of the structure of x. A tree product of two tree structures is formed by replacing each leaf of the first tree with a copy of the second. For example:

>>> import jax.numpy as jnp
>>> f = lambda dct: {"c": jnp.power(dct["a"], dct["b"])}
>>> print(jax.hessian(f)({"a": jnp.arange(2.) + 1., "b": jnp.arange(2.) + 2.}))
{'c': {'a': {'a': Array([[[ 2., 0.], [ 0., 0.]],
                         [[ 0., 0.], [ 0., 12.]]], dtype=float32),
             'b': Array([[[ 1. , 0. ], [ 0. , 0. ]],
                         [[ 0. , 0. ], [ 0. , 12.317766]]], dtype=float32)},
       'b': {'a': Array([[[ 1. , 0. ], [ 0. , 0. ]],
                         [[ 0. , 0. ], [ 0. , 12.317766]]], dtype=float32),
             'b': Array([[[0. , 0. ], [0. , 0. ]],
                         [[0. , 0. ], [0. , 3.843624]]], dtype=float32)}}}
Thus each leaf in the tree structure of jax.hessian(fun)(x) corresponds to a leaf of fun(x) and a pair of leaves of x. For each leaf in jax.hessian(fun)(x), if the corresponding array leaf of fun(x) has shape (out_1, out_2, ...) and the corresponding array leaves of x have shape (in_1_1, in_1_2, ...) and (in_2_1, in_2_2, ...) respectively, then the Hessian leaf has shape (out_1, out_2, ..., in_1_1, in_1_2, ..., in_2_1, in_2_2, ...). In other words, the Python tree structure represents the block structure of the Hessian, with blocks determined by the input and output pytrees.

In particular, an array is produced (with no pytrees involved) when the function input x and output fun(x) are each a single array, as in the g example above. If fun(x) has shape (out1, out2, ...) and x has shape (in1, in2, ...) then jax.hessian(fun)(x) has shape (out1, out2, ..., in1, in2, ..., in1, in2, ...). To flatten pytrees into 1D vectors, consider using jax.flatten_util.ravel_pytree.
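As a small shape check of the array case (the function here is illustrative):

import jax
import jax.numpy as jnp

# fun(x) has shape (2,) and x has shape (3,), so the Hessian has shape
# (2, 3, 3): one (3, 3) second-derivative block per output element.
f = lambda x: jnp.stack([jnp.sum(x**2), jnp.prod(x)])
x = jnp.arange(3.0)
print(jax.hessian(f)(x).shape)  # (2, 3, 3)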
- scico.jacfwd(fun, argnums=0, has_aux=False, holomorphic=False)
Jacobian of fun evaluated column-by-column using forward-mode AD.
- Parameters:
  fun (Callable) – Function whose Jacobian is to be computed.
  argnums (Union[int, Sequence[int]]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).
  has_aux (bool) – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.
  holomorphic (bool) – Optional, bool. Indicates whether fun is promised to be holomorphic. Default False.
- Return type:
  Callable
- Returns:
  A function with the same arguments as fun, that evaluates the Jacobian of fun using forward-mode automatic differentiation. If has_aux is True then a pair of (jacobian, auxiliary_data) is returned.
>>> import jax
>>> import jax.numpy as jnp
>>>
>>> def f(x):
...     return jnp.asarray(
...         [x[0], 5*x[2], 4*x[1]**2 - 2*x[2], x[2] * jnp.sin(x[0])])
...
>>> print(jax.jacfwd(f)(jnp.array([1., 2., 3.])))
[[ 1.       0.       0.     ]
 [ 0.       0.       5.     ]
 [ 0.      16.      -2.     ]
 [ 1.6209   0.       0.84147]]
- scico.jvp(fun, primals, tangents, has_aux=False)
Computes a (forward-mode) Jacobian-vector product of fun.
- Parameters:
  fun (Callable) – Function to be differentiated. Its arguments should be arrays, scalars, or standard Python containers of arrays or scalars. It should return an array, scalar, or standard Python container of arrays or scalars.
  primals – The primal values at which the Jacobian of fun should be evaluated. Should be either a tuple or a list of arguments, and its length should be equal to the number of positional parameters of fun.
  tangents – The tangent vector for which the Jacobian-vector product should be evaluated. Should be either a tuple or a list of tangents, with the same tree structure and array shapes as primals.
  has_aux (bool) – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.
- Returns:
  If has_aux is False, returns a (primals_out, tangents_out) pair, where primals_out is fun(*primals), and tangents_out is the Jacobian-vector product of fun evaluated at primals with tangents. The tangents_out value has the same Python tree structure and shapes as primals_out. If has_aux is True, returns a (primals_out, tangents_out, aux) tuple where aux is the auxiliary data returned by fun.
For example:
>>> import jax
>>>
>>> primals, tangents = jax.jvp(jax.numpy.sin, (0.1,), (0.2,))
>>> print(primals)
0.09983342
>>> print(tangents)
0.19900084
- scico.linearize(fun, *primals, has_aux=False)
Produces a linear approximation to fun using jvp and partial eval.
- Parameters:
  fun (Callable) – Function to be differentiated. Its arguments should be arrays, scalars, or standard Python containers of arrays or scalars. It should return an array, scalar, or standard Python container of arrays or scalars.
  primals – The primal values at which the Jacobian of fun should be evaluated. Should be a tuple of arrays, scalars, or standard Python containers thereof. The length of the tuple is equal to the number of positional parameters of fun.
  has_aux (bool) – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be linearized, and the second is auxiliary data. Default False.
- Returns:
  If has_aux is False, returns a pair where the first element is the value of f(*primals) and the second element is a function that evaluates the (forward-mode) Jacobian-vector product of fun evaluated at primals without re-doing the linearization work. If has_aux is True, returns a (primals_out, lin_fn, aux) tuple where aux is the auxiliary data returned by fun.
In terms of values computed, linearize behaves much like a curried jvp, where these two code blocks compute the same values:

y, out_tangent = jax.jvp(f, (x,), (in_tangent,))

y, f_jvp = jax.linearize(f, x)
out_tangent = f_jvp(in_tangent)
However, the difference is that linearize uses partial evaluation so that the function f is not re-linearized on calls to f_jvp. In general that means the memory usage scales with the size of the computation, much like in reverse-mode. (Indeed, linearize has a similar signature to vjp!)
This function is mainly useful if you want to apply f_jvp multiple times, i.e. to evaluate a pushforward for many different input tangent vectors at the same linearization point. Moreover, if all the input tangent vectors are known at once, it can be more efficient to vectorize using vmap, as in:

pushfwd = partial(jvp, f, (x,))
y, out_tangents = vmap(pushfwd, out_axes=(None, 0))((in_tangents,))
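A self-contained version of that pattern (f, x, and the batch of tangents here are illustrative):

import jax
import jax.numpy as jnp
from functools import partial

f = jnp.sin
x = jnp.asarray(1.0)
in_tangents = jnp.array([1.0, 2.0, 3.0])  # a batch of scalar tangents

pushfwd = partial(jax.jvp, f, (x,))
# out_axes=(None, 0): the primal output is shared; tangents are batched.
y, out_tangents = jax.vmap(pushfwd, out_axes=(None, 0))((in_tangents,))
print(y)             # sin(1.0)
print(out_tangents)  # cos(1.0) * [1., 2., 3.]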
By using vmap and jvp together like this we avoid the stored-linearization memory cost that scales with the depth of the computation, which is incurred by both linearize and vjp.

Here's a more complete example of using linearize:

>>> import jax
>>> import jax.numpy as jnp
>>>
>>> def f(x): return 3. * jnp.sin(x) + jnp.cos(x / 2.)
...
>>> jax.jvp(f, (2.,), (3.,))
(Array(3.26819, dtype=float32, weak_type=True), Array(-5.00753, dtype=float32, weak_type=True))
>>> y, f_jvp = jax.linearize(f, 2.)
>>> print(y)
3.2681944
>>> print(f_jvp(3.))
-5.007528
>>> print(f_jvp(4.))
-6.676704
- scico.vjp(fun, *primals, has_aux=False, reduce_axes=())
Compute a (reverse-mode) vector-Jacobian product of fun. grad is implemented as a special case of vjp.
- Parameters:
  fun (Callable) – Function to be differentiated. Its arguments should be arrays, scalars, or standard Python containers of arrays or scalars. It should return an array, scalar, or standard Python container of arrays or scalars.
  primals – A sequence of primal values at which the Jacobian of fun should be evaluated. The number of primals should be equal to the number of positional parameters of fun. Each primal value should be an array, a scalar, or a pytree (standard Python containers) thereof.
  has_aux (bool) – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.
  reduce_axes – Optional, tuple of axis names. If an axis is listed here, and fun implicitly broadcasts a value over that axis, the backward pass will perform a psum of the corresponding gradient. Otherwise, the VJP will be per-example over named axes. For example, if 'batch' is a named batch axis, vjp(f, *args, reduce_axes=('batch',)) will create a VJP function that sums over the batch while vjp(f, *args) will create a per-example VJP.
- Returns:
  If has_aux is False, returns a (primals_out, vjpfun) pair, where primals_out is fun(*primals). If has_aux is True, returns a (primals_out, vjpfun, aux) tuple where aux is the auxiliary data returned by fun. vjpfun is a function from a cotangent vector with the same shape as primals_out to a tuple of cotangent vectors with the same number and shapes as primals, representing the vector-Jacobian product of fun evaluated at primals.
>>> import jax
>>>
>>> def f(x, y):
...     return jax.numpy.sin(x), jax.numpy.cos(y)
...
>>> primals, f_vjp = jax.vjp(f, 0.5, 1.0)
>>> xbar, ybar = f_vjp((-0.7, 0.3))
>>> print(xbar)
-0.61430776
>>> print(ybar)
-0.2524413
- scico.cvjp(fun, *primals, jidx=None)
Compute a vector-Jacobian product with conjugate transpose.
Compute the product \([J(\mb{x})]^H \mb{v}\) where \([J(\mb{x})]\) is the Jacobian of function fun evaluated at \(\mb{x}\). Instead of directly evaluating the product, a function is returned that takes \(\mb{v}\) as an argument. If fun has multiple positional parameters, the Jacobian can be taken with respect to only one of them by setting the jidx parameter of this function to the positional index of that parameter.
- Parameters:
  fun (Callable) – Function for which the Jacobian is implicitly computed.
  primals – Sequence of values at which the Jacobian is evaluated, with length equal to the number of positional arguments of fun.
  jidx (Optional[int]) – Index of the positional parameter of fun with respect to which the Jacobian is taken.
- Returns:
  A pair (primals_out, conj_vjp) where primals_out is the output of fun evaluated at primals, i.e. primals_out = fun(*primals), and conj_vjp is a function that computes the product of the conjugate (Hermitian) transpose of the Jacobian of fun and its argument. If the jidx parameter is an integer, then the Jacobian is only taken with respect to the corresponding positional parameter of fun.
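A minimal sketch for the linear case, where the documented behavior implies that conj_vjp applies the conjugate transpose of A (the matrix, input, and tuple unpacking here are illustrative assumptions):

import jax.numpy as jnp
import scico

# For linear f(x) = A @ x the Jacobian is A, so conj_vjp applies A^H.
A = jnp.array([[1.0 + 1.0j, 2.0 + 0.0j],
               [0.0 + 0.0j, 1.0 - 1.0j]])
f = lambda x: A @ x
x = jnp.zeros(2, dtype=jnp.complex64)

y, conj_vjp = scico.cvjp(f, x)
v = jnp.array([1.0 + 0.0j, 0.0 + 1.0j])
(w,) = conj_vjp(v)  # assuming one cotangent per positional parameter of f
print(jnp.allclose(w, A.conj().T @ v))  # True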
- scico.grad(fun, argnums=0, has_aux=False, holomorphic=False, allow_int=False)
Create a function that evaluates the gradient of fun.
scico.grad differs from jax.grad in that the output is conjugated.

Docstring for jax.grad:

- Parameters:
  fun (Callable) – Function to be differentiated. Its arguments at positions specified by argnums should be arrays, scalars, or standard Python containers. Argument arrays in the positions specified by argnums must be of inexact (i.e., floating-point or complex) type. It should return a scalar (which includes arrays with shape () but not arrays with shape (1,) etc.)
  argnums (Union[int, Sequence[int]]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).
  has_aux (bool) – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.
  holomorphic (bool) – Optional, bool. Indicates whether fun is promised to be holomorphic. If True, inputs and outputs must be complex. Default False.
  allow_int (bool) – Optional, bool. Whether to allow differentiating with respect to integer valued inputs. The gradient of an integer input will have a trivial vector-space dtype (float0). Default False.
  reduce_axes – Optional, tuple of axis names. If an axis is listed here, and fun implicitly broadcasts a value over that axis, the backward pass will perform a psum of the corresponding gradient. Otherwise, the gradient will be per-example over named axes. For example, if 'batch' is a named batch axis, grad(f, reduce_axes=('batch',)) will create a function that computes the total gradient while grad(f) will create one that computes the per-example gradient.
- Return type:
  Callable
- Returns:
  A function with the same arguments as fun, that evaluates the gradient of fun. If argnums is an integer then the gradient has the same shape and type as the positional argument indicated by that integer. If argnums is a tuple of integers, the gradient is a tuple of values with the same shapes and types as the corresponding arguments. If has_aux is True then a pair of (gradient, auxiliary_data) is returned.
For example:
>>> import jax
>>>
>>> grad_tanh = jax.grad(jax.numpy.tanh)
>>> print(grad_tanh(0.2))
0.961043
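To illustrate the conjugation difference on a real-valued function of a complex input (a minimal sketch; the function and input are illustrative):

import jax
import jax.numpy as jnp
import scico

f = lambda x: jnp.sum(jnp.abs(x) ** 2)  # real-valued, non-holomorphic
x = jnp.array([1.0 + 1.0j])

print(jax.grad(f)(x))    # [2.-2.j] : 2 * conj(x), the jax convention
print(scico.grad(f)(x))  # [2.+2.j] : conjugated, i.e. 2 * x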
- scico.jacrev(fun, argnums=0, holomorphic=False, allow_int=False)
Jacobian of fun evaluated row-by-row using reverse-mode AD.
scico.jacrev differs from jax.jacrev in that the output is conjugated.

Docstring for jax.jacrev:

- Parameters:
  fun (Callable) – Function whose Jacobian is to be computed.
  argnums (Union[int, Sequence[int]]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).
  has_aux – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.
  holomorphic (bool) – Optional, bool. Indicates whether fun is promised to be holomorphic. Default False.
  allow_int (bool) – Optional, bool. Whether to allow differentiating with respect to integer valued inputs. The gradient of an integer input will have a trivial vector-space dtype (float0). Default False.
- Return type:
  Callable
- Returns:
  A function with the same arguments as fun, that evaluates the Jacobian of fun using reverse-mode automatic differentiation. If has_aux is True then a pair of (jacobian, auxiliary_data) is returned.
>>> import jax
>>> import jax.numpy as jnp
>>>
>>> def f(x):
...     return jnp.asarray(
...         [x[0], 5*x[2], 4*x[1]**2 - 2*x[2], x[2] * jnp.sin(x[0])])
...
>>> print(jax.jacrev(f)(jnp.array([1., 2., 3.])))
[[ 1.       0.       0.     ]
 [ 0.       0.       5.     ]
 [ 0.      16.      -2.     ]
 [ 1.6209   0.       0.84147]]
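To see the conjugation on a complex linear map (a minimal sketch; for complex outputs jacrev requires holomorphic=True, and the matrix here is illustrative):

import jax
import jax.numpy as jnp
import scico

A = jnp.array([[1.0 + 1.0j, 2.0 - 1.0j]])
f = lambda x: A @ x
x = jnp.zeros(2, dtype=jnp.complex64)

print(jax.jacrev(f, holomorphic=True)(x))    # A
print(scico.jacrev(f, holomorphic=True)(x))  # conj(A)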
- scico.linear_adjoint(fun, *primals)
Conjugate transpose a function that is guaranteed to be linear.
scico.linear_adjoint differs from jax.linear_transpose for complex inputs in that the conjugate transpose (adjoint) of fun is returned. scico.linear_adjoint is identical to jax.linear_transpose for real-valued primals.

Docstring for jax.linear_transpose:

For linear functions, this transformation is equivalent to vjp, but avoids the overhead of computing the forward pass.

The outputs of the transposed function will always have the exact same dtypes as primals, even if some values are truncated (e.g., from complex to float, or from float64 to float32). To avoid truncation, use dtypes in primals that match the full range of desired outputs from the transposed function. Integer dtypes are not supported.

- Parameters:
  fun (Callable) – the linear function to be transposed.
  *primals – a positional argument tuple of arrays, scalars, or (nested) standard Python containers (tuples, lists, dicts, namedtuples, i.e., pytrees) of those types used for evaluating the shape/dtype of fun(*primals). These arguments may be real scalars/ndarrays, but that is not required: only the shape and dtype attributes are accessed. See below for an example. (Note that the duck-typed objects cannot be namedtuples because those are treated as standard Python containers.)
  reduce_axes – Optional, tuple of axis names. If an axis is listed here, and fun implicitly broadcasts a value over that axis, the backward pass will perform a psum of the corresponding cotangent. Otherwise, the transposed function will be per-example over named axes. For example, if 'batch' is a named batch axis, linear_transpose(f, *args, reduce_axes=('batch',)) will create a transpose function that sums over the batch while linear_transpose(f, *args) will create a per-example transpose.
- Return type:
  Callable
- Returns:
  A callable that calculates the transpose of fun. Valid input into this function must have the same shape/dtypes/structure as the result of fun(*primals). Output will be a tuple, with the same shape/dtypes/structure as primals.
>>> import jax
>>> import numpy as np
>>> import types
>>>
>>> f = lambda x, y: 0.5 * x - 0.5 * y
>>> scalar = types.SimpleNamespace(shape=(), dtype=np.dtype(np.float32))
>>> f_transpose = jax.linear_transpose(f, scalar, scalar)
>>> f_transpose(1.0)
(Array(0.5, dtype=float32), Array(-0.5, dtype=float32))
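A complex example showing the adjoint behavior (a minimal sketch; multiplication by 1j is linear, and its adjoint is multiplication by -1j):

import jax.numpy as jnp
import scico

f = lambda x: 1j * x  # linear map; its adjoint multiplies by -1j
x = jnp.zeros((), dtype=jnp.complex64)  # shape/dtype prototype only

f_adj = scico.linear_adjoint(f, x)
(w,) = f_adj(jnp.asarray(1.0 + 0.0j))
print(w)  # approximately -1j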
- scico.value_and_grad(fun, argnums=0, has_aux=False, holomorphic=False, allow_int=False)
Create a function that evaluates both fun and its gradient.
scico.value_and_grad differs from jax.value_and_grad in that the gradient is conjugated.

Docstring for jax.value_and_grad:

- Parameters:
  fun (Callable) – Function to be differentiated. Its arguments at positions specified by argnums should be arrays, scalars, or standard Python containers. It should return a scalar (which includes arrays with shape () but not arrays with shape (1,) etc.)
  argnums (Union[int, Sequence[int]]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).
  has_aux (bool) – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.
  holomorphic (bool) – Optional, bool. Indicates whether fun is promised to be holomorphic. If True, inputs and outputs must be complex. Default False.
  allow_int (bool) – Optional, bool. Whether to allow differentiating with respect to integer valued inputs. The gradient of an integer input will have a trivial vector-space dtype (float0). Default False.
  reduce_axes – Optional, tuple of axis names. If an axis is listed here, and fun implicitly broadcasts a value over that axis, the backward pass will perform a psum of the corresponding gradient. Otherwise, the gradient will be per-example over named axes. For example, if 'batch' is a named batch axis, value_and_grad(f, reduce_axes=('batch',)) will create a function that computes the total gradient while value_and_grad(f) will create one that computes the per-example gradient.
- Return type:
  Callable
- Returns:
  A function with the same arguments as fun that evaluates both fun and the gradient of fun and returns them as a pair (a two-element tuple). If argnums is an integer then the gradient has the same shape and type as the positional argument indicated by that integer. If argnums is a sequence of integers, the gradient is a tuple of values with the same shapes and types as the corresponding arguments. If has_aux is True then a tuple of ((value, auxiliary_data), gradient) is returned.
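A minimal usage sketch, mirroring the scico.grad example above (the function and input are illustrative):

import jax.numpy as jnp
import scico

f = lambda x: jnp.sum(jnp.abs(x) ** 2)
x = jnp.array([1.0 + 1.0j])

val, g = scico.value_and_grad(f)(x)
print(val)  # 2.0
print(g)    # [2.+2.j] : conjugated relative to jax.value_and_grad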