API Reference#

Scientific Computational Imaging COde (SCICO) is a Python package for solving the inverse problems that arise in scientific imaging applications.

scico.data

Data files for usage examples.

scico.denoiser

Interfaces to standard denoisers.

scico.diagnostics

Diagnostic information for iterative solvers.

scico.examples

Utility functions used by example scripts.

scico.flax

Neural network models implemented in Flax and utility functions.

scico.function

Function class.

scico.functional

Functionals and functional classes.

scico.linop

Linear operator functions and classes.

scico.loss

Loss function classes.

scico.metric

Image quality metrics and related functions.

scico.numpy

BlockArray and compatible functions.

scico.operator

Operator functions and classes.

scico.optimize

Optimization algorithms.

scico.plot

Plotting/visualization functions.

scico.random

Random number generation.

scico.ray

Simplified interfaces to ray.

scico.scipy

Wrapped versions of jax.scipy functions.

scico.solver

Solver and optimization algorithms.

scico.typing

Type definitions.

scico.util

General utility functions.
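
All of the modules above are imported from the top-level scico namespace in the usual way. As a minimal, hedged sketch (the particular submodules chosen here are arbitrary examples from the listing above, and the snp alias is only a common convention):

# Illustrative import sketch; any of the modules listed above can be
# imported in the same way.
import scico.numpy as snp                    # BlockArray and compatible functions
from scico import functional, linop, loss, metric
from scico.optimize import ADMM              # optimization algorithms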

class scico.custom_jvp(fun, nondiff_argnums=())#

Bases: Generic[ReturnValue]

Set up a JAX-transformable function for a custom JVP rule definition.

This class is meant to be used as a function decorator. Instances are callables that behave similarly to the underlying function to which the decorator was applied, except when a differentiation transformation (like jax.jvp or jax.grad) is applied, in which case a custom user-supplied JVP rule function is used instead of tracing into and performing automatic differentiation of the underlying function’s implementation.

There are two instance methods available for defining the custom JVP rule: defjvp for defining a single custom JVP rule for all the function’s inputs, and for convenience defjvps, which wraps defjvp, and allows you to provide separate definitions for the partial derivatives of the function w.r.t. each of its arguments.

For example:

import jax
import jax.numpy as jnp

@jax.custom_jvp
def f(x, y):
  return jnp.sin(x) * y

@f.defjvp
def f_jvp(primals, tangents):
  x, y = primals
  x_dot, y_dot = tangents
  primal_out = f(x, y)
  tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot
  return primal_out, tangent_out

For a more detailed introduction, see the tutorial.

__call__(*args, **kwargs)[source]#

Call self as a function.

Return type:

TypeVar(ReturnValue)

defjvp(jvp, symbolic_zeros=False)[source]#

Define a custom JVP rule for the function represented by this instance.

Parameters:
  • jvp (Callable[..., Tuple[TypeVar(ReturnValue), TypeVar(ReturnValue)]]) – a Python callable representing the custom JVP rule. When there are no nondiff_argnums, the jvp function should accept two arguments, where the first is a tuple of primal inputs and the second is a tuple of tangent inputs. The lengths of both tuples are equal to the number of parameters of the custom_jvp function. The jvp function should produce as output a pair where the first element is the primal output and the second element is the tangent output. Elements of the input and output tuples may be arrays or any nested tuples/lists/dicts thereof.

  • symbolic_zeros (bool) – boolean, indicating whether the rule should be passed objects representing static symbolic zeros in its tangent argument in correspondence with unperturbed values; otherwise, only standard JAX types (e.g. array-likes) are passed. Setting this option to True allows a JVP rule to detect whether certain inputs are not involved in differentiation, but at the cost of needing special handling for these objects (which e.g. can’t be passed into jax.numpy functions). Default False.

Return type:

Callable[..., Tuple[TypeVar(ReturnValue), TypeVar(ReturnValue)]]

Returns:

The input jvp callable, so that defjvp can be used as a decorator.

Example:

@jax.custom_jvp
def f(x, y):
  return jnp.sin(x) * y

@f.defjvp
def f_jvp(primals, tangents):
  x, y = primals
  x_dot, y_dot = tangents
  primal_out = f(x, y)
  tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot
  return primal_out, tangent_out

defjvps(*jvps)[source]#

Convenience wrapper for defining JVPs for each argument separately.

This convenience wrapper cannot be used together with nondiff_argnums.

Parameters:

*jvps (Optional[Callable[..., TypeVar(ReturnValue)]]) – a sequence of functions, one for each positional argument of the custom_jvp function. Each function takes as arguments the tangent value for the corresponding primal input, the primal output, and the primal inputs. See the example below.

Returns:

None.

Example:

@jax.custom_jvp
def f(x, y):
  return jnp.sin(x) * y

f.defjvps(lambda x_dot, primal_out, x, y: jnp.cos(x) * x_dot * y,
          lambda y_dot, primal_out, x, y: jnp.sin(x) * y_dot)

class scico.custom_vjp(fun, nondiff_argnums=())#

Bases: Generic[ReturnValue]

Set up a JAX-transformable function for a custom VJP rule definition.

This class is meant to be used as a function decorator. Instances are callables that behave similarly to the underlying function to which the decorator was applied, except when a reverse-mode differentiation transformation (like jax.grad) is applied, in which case a custom user-supplied VJP rule function is used instead of tracing into and performing automatic differentiation of the underlying function’s implementation. There is a single instance method, defvjp, which may be used to define the custom VJP rule.

This decorator precludes the use of forward-mode automatic differentiation.

For example:

import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x, y):
  return jnp.sin(x) * y

def f_fwd(x, y):
  return f(x, y), (jnp.cos(x), jnp.sin(x), y)

def f_bwd(res, g):
  cos_x, sin_x, y = res
  return (cos_x * g * y, sin_x * g)

f.defvjp(f_fwd, f_bwd)

For a more detailed introduction, see the tutorial.

__call__(*args, **kwargs)[source]#

Call self as a function.

Return type:

TypeVar(ReturnValue)

defvjp(fwd, bwd, symbolic_zeros=False)[source]#

Define a custom VJP rule for the function represented by this instance.

Parameters:
  • fwd (Callable[..., Tuple[TypeVar(ReturnValue), Any]]) – a Python callable representing the forward pass of the custom VJP rule. When there are no nondiff_argnums, the fwd function has the same input signature as the underlying primal function. It should return as output a pair, where the first element represents the primal output and the second element represents any “residual” values to store from the forward pass for use on the backward pass by the function bwd. Input arguments and elements of the output pair may be arrays or nested tuples/lists/dicts thereof.

  • bwd (Callable[..., Tuple[Any, ...]]) – a Python callable representing the backward pass of the custom VJP rule. When there are no nondiff_argnums, the bwd function takes two arguments, where the first is the “residual” values produced on the forward pass by fwd, and the second is the output cotangent with the same structure as the primal function output. The output of bwd must be a tuple of length equal to the number of arguments of the primal function, and the tuple elements may be arrays or nested tuples/lists/dicts thereof so as to match the structure of the primal input arguments.

  • symbolic_zeros (bool) –

    boolean, determining whether to indicate symbolic zeros to the fwd and bwd rules. Enabling this option allows custom derivative rules to detect when certain inputs, and when certain output cotangents, are not involved in differentiation. If True:

    • fwd must accept, in place of each leaf value x in the pytree comprising an argument to the original function, an object with two attributes instead: value and perturbed. The value field is the original primal argument, and perturbed is a boolean. The perturbed bit indicates whether the argument is involved in differentiation (i.e., if it is False, then the corresponding Jacobian “column” is zero).

    • bwd will be passed objects representing static symbolic zeros in its cotangent argument in correspondence with unperturbed values; otherwise, only standard JAX types (e.g. array-likes) are passed.

    Setting this option to True allows these rules to detect whether certain inputs and outputs are not involved in differentiation, but at the cost of special handling. For instance:

    • The signature of fwd changes, and the objects it is passed cannot be output from the rule directly.

    • The bwd rule is passed objects that are not entirely array-like, and that cannot be passed to most jax.numpy functions.

    • Any custom pytree nodes involved in the primal function’s arguments must accept, in their unflattening functions, the two-field record objects that are given as input leaves to the fwd rule.

    Default False.

Return type:

None

Returns:

None.

Example:

@jax.custom_vjp
def f(x, y):
  return jnp.sin(x) * y

def f_fwd(x, y):
  return f(x, y), (jnp.cos(x), jnp.sin(x), y)

def f_bwd(res, g):
  cos_x, sin_x, y = res
  return (cos_x * g * y, sin_x * g)

f.defvjp(f_fwd, f_bwd)

scico.hessian(fun, argnums=0, has_aux=False, holomorphic=False)#

Hessian of fun as a dense array.

Parameters:
  • fun (Callable) – Function whose Hessian is to be computed. Its arguments at positions specified by argnums should be arrays, scalars, or standard Python containers thereof. It should return arrays, scalars, or standard Python containers thereof.

  • argnums (Union[int, Sequence[int]]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).

  • has_aux (bool) – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.

  • holomorphic (bool) – Optional, bool. Indicates whether fun is promised to be holomorphic. Default False.

Return type:

Callable

Returns:

A function with the same arguments as fun, that evaluates the Hessian of fun.

>>> import jax
>>>
>>> g = lambda x: x[0]**3 - 2*x[0]*x[1] - x[1]**6
>>> print(jax.hessian(g)(jax.numpy.array([1., 2.])))
[[   6.   -2.]
 [  -2. -480.]]

hessian is a generalization of the usual definition of the Hessian that supports nested Python containers (i.e. pytrees) as inputs and outputs. The tree structure of jax.hessian(fun)(x) is given by forming a tree product of the structure of fun(x) with a tree product of two copies of the structure of x. A tree product of two tree structures is formed by replacing each leaf of the first tree with a copy of the second. For example:

>>> import jax.numpy as jnp
>>> f = lambda dct: {"c": jnp.power(dct["a"], dct["b"])}
>>> print(jax.hessian(f)({"a": jnp.arange(2.) + 1., "b": jnp.arange(2.) + 2.}))
{'c': {'a': {'a': Array([[[ 2.,  0.], [ 0.,  0.]],
                         [[ 0.,  0.], [ 0., 12.]]], dtype=float32),
             'b': Array([[[ 1.      ,  0.      ], [ 0.      ,  0.      ]],
                         [[ 0.      ,  0.      ], [ 0.      , 12.317766]]], dtype=float32)},
       'b': {'a': Array([[[ 1.      ,  0.      ], [ 0.      ,  0.      ]],
                         [[ 0.      ,  0.      ], [ 0.      , 12.317766]]], dtype=float32),
             'b': Array([[[0.      , 0.      ], [0.      , 0.      ]],
                         [[0.      , 0.      ], [0.      , 3.843624]]], dtype=float32)}}}

Thus each leaf in the tree structure of jax.hessian(fun)(x) corresponds to a leaf of fun(x) and a pair of leaves of x. For each leaf in jax.hessian(fun)(x), if the corresponding array leaf of fun(x) has shape (out_1, out_2, ...) and the corresponding array leaves of x have shape (in_1_1, in_1_2, ...) and (in_2_1, in_2_2, ...) respectively, then the Hessian leaf has shape (out_1, out_2, ..., in_1_1, in_1_2, ..., in_2_1, in_2_2, ...). In other words, the Python tree structure represents the block structure of the Hessian, with blocks determined by the input and output pytrees.

In particular, an array is produced (with no pytrees involved) when the function input x and output fun(x) are each a single array, as in the g example above. If fun(x) has shape (out1, out2, ...) and x has shape (in1, in2, ...) then jax.hessian(fun)(x) has shape (out1, out2, ..., in1, in2, ..., in1, in2, ...). To flatten pytrees into 1D vectors, consider using jax.flatten_util.ravel_pytree.
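
As a hedged illustration of that last point (this sketch is not part of the jax.hessian docstring; the dictionary-valued input and scalar objective are chosen only for demonstration), one way to obtain a single dense Hessian matrix from a pytree input is to flatten the input with jax.flatten_util.ravel_pytree before differentiating:

import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

params = {"a": jnp.arange(2.0) + 1.0, "b": jnp.arange(2.0) + 2.0}
flat, unravel = ravel_pytree(params)        # 1D vector and its inverse mapping

def f_flat(v):
    dct = unravel(v)                        # rebuild the original pytree
    return jnp.sum(jnp.power(dct["a"], dct["b"]))

H = jax.hessian(f_flat)(flat)               # dense Hessian of shape (4, 4)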

scico.jacfwd(fun, argnums=0, has_aux=False, holomorphic=False)#

Jacobian of fun evaluated column-by-column using forward-mode AD.

Parameters:
  • fun (Callable) – Function whose Jacobian is to be computed.

  • argnums (Union[int, Sequence[int]]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).

  • has_aux (bool) – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.

  • holomorphic (bool) – Optional, bool. Indicates whether fun is promised to be holomorphic. Default False.

Return type:

Callable

Returns:

A function with the same arguments as fun, that evaluates the Jacobian of fun using forward-mode automatic differentiation. If has_aux is True then a pair of (jacobian, auxiliary_data) is returned.

>>> import jax
>>> import jax.numpy as jnp
>>>
>>> def f(x):
...   return jnp.asarray(
...     [x[0], 5*x[2], 4*x[1]**2 - 2*x[2], x[2] * jnp.sin(x[0])])
...
>>> print(jax.jacfwd(f)(jnp.array([1., 2., 3.])))
[[ 1.       0.       0.     ]
 [ 0.       0.       5.     ]
 [ 0.      16.      -2.     ]
 [ 1.6209   0.       0.84147]]

scico.jvp(fun, primals, tangents, has_aux=False)#

Computes a (forward-mode) Jacobian-vector product of fun.

Parameters:
  • fun (Callable) – Function to be differentiated. Its arguments should be arrays, scalars, or standard Python containers of arrays or scalars. It should return an array, scalar, or standard Python container of arrays or scalars.

  • primals – The primal values at which the Jacobian of fun should be evaluated. Should be either a tuple or a list of arguments, and its length should be equal to the number of positional parameters of fun.

  • tangents – The tangent vector for which the Jacobian-vector product should be evaluated. Should be either a tuple or a list of tangents, with the same tree structure and array shapes as primals.

  • has_aux (bool) – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.

Return type:

Tuple[Any, ...]

Returns:

If has_aux is False, returns a (primals_out, tangents_out) pair, where primals_out is fun(*primals), and tangents_out is the Jacobian-vector product of function evaluated at primals with tangents. The tangents_out value has the same Python tree structure and shapes as primals_out. If has_aux is True, returns a (primals_out, tangents_out, aux) tuple where aux is the auxiliary data returned by fun.

For example:

>>> import jax
>>>
>>> primals, tangents = jax.jvp(jax.numpy.sin, (0.1,), (0.2,))
>>> print(primals)
0.09983342
>>> print(tangents)
0.19900084

scico.linearize(fun, *primals, has_aux=False)#

Produces a linear approximation to fun using jvp and partial eval.

Parameters:
  • fun (Callable) – Function to be differentiated. Its arguments should be arrays, scalars, or standard Python containers of arrays or scalars. It should return an array, scalar, or standard python container of arrays or scalars.

  • primals – The primal values at which the Jacobian of fun should be evaluated. Should be a tuple of arrays, scalar, or standard Python container thereof. The length of the tuple is equal to the number of positional parameters of fun.

  • has_aux (bool) – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be linearized, and the second is auxiliary data. Default False.

Return type:

Union[Tuple[Any, Callable], Tuple[Any, Callable, Any]]

Returns:

If has_aux is False, returns a pair where the first element is the value of fun(*primals) and the second element is a function that evaluates the (forward-mode) Jacobian-vector product of fun evaluated at primals without re-doing the linearization work. If has_aux is True, returns a (primals_out, lin_fn, aux) tuple where aux is the auxiliary data returned by fun.

In terms of values computed, linearize behaves much like a curried jvp, where these two code blocks compute the same values:

y, out_tangent = jax.jvp(f, (x,), (in_tangent,))

y, f_jvp = jax.linearize(f, x)
out_tangent = f_jvp(in_tangent)

However, the difference is that linearize uses partial evaluation so that the function f is not re-linearized on calls to f_jvp. In general that means the memory usage scales with the size of the computation, much like in reverse-mode. (Indeed, linearize has a similar signature to vjp!)

This function is mainly useful if you want to apply f_jvp multiple times, i.e. to evaluate a pushforward for many different input tangent vectors at the same linearization point. Moreover if all the input tangent vectors are known at once, it can be more efficient to vectorize using vmap, as in:

pushfwd = partial(jvp, f, (x,))
y, out_tangents = vmap(pushfwd, out_axes=(None, 0))((in_tangents,))

By using vmap and jvp together like this we avoid the stored-linearization memory cost that scales with the depth of the computation, which is incurred by both linearize and vjp.
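
A minimal runnable form of that pattern (hedged sketch; the function, linearization point, and tangent values below are arbitrary and chosen only for illustration):

from functools import partial
import jax
import jax.numpy as jnp

f = lambda x: 3. * jnp.sin(x) + jnp.cos(x / 2.)
x = jnp.array(2.)
in_tangents = jnp.array([1., 2., 3.])       # several tangent values at once

pushfwd = partial(jax.jvp, f, (x,))
y, out_tangents = jax.vmap(pushfwd, out_axes=(None, 0))((in_tangents,))
# y is f(x) (not batched); out_tangents holds one JVP per tangent value.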

Here’s a more complete example of using linearize:

>>> import jax
>>> import jax.numpy as jnp
>>>
>>> def f(x): return 3. * jnp.sin(x) + jnp.cos(x / 2.)
...
>>> jax.jvp(f, (2.,), (3.,))
(Array(3.26819, dtype=float32, weak_type=True), Array(-5.00753, dtype=float32, weak_type=True))
>>> y, f_jvp = jax.linearize(f, 2.)
>>> print(y)
3.2681944
>>> print(f_jvp(3.))
-5.007528
>>> print(f_jvp(4.))
-6.676704

scico.vjp(fun, *primals, has_aux=False, reduce_axes=())#

Compute a (reverse-mode) vector-Jacobian product of fun.

grad is implemented as a special case of vjp.

Parameters:
  • fun (Callable) – Function to be differentiated. Its arguments should be arrays, scalars, or standard Python containers of arrays or scalars. It should return an array, scalar, or standard Python container of arrays or scalars.

  • primals – A sequence of primal values at which the Jacobian of fun should be evaluated. The number of primals should be equal to the number of positional parameters of fun. Each primal value should be an array, a scalar, or a pytree (standard Python containers) thereof.

  • has_aux (bool) – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.

  • reduce_axes – Optional, tuple of axis names. If an axis is listed here, and fun implicitly broadcasts a value over that axis, the backward pass will perform a psum of the corresponding gradient. Otherwise, the VJP will be per-example over named axes. For example, if 'batch' is a named batch axis, vjp(f, *args, reduce_axes=('batch',)) will create a VJP function that sums over the batch while vjp(f, *args) will create a per-example VJP.

Return type:

Union[Tuple[Any, Callable], Tuple[Any, Callable, Any]]

Returns:

If has_aux is False, returns a (primals_out, vjpfun) pair, where primals_out is fun(*primals). If has_aux is True, returns a (primals_out, vjpfun, aux) tuple where aux is the auxiliary data returned by fun.

vjpfun is a function from a cotangent vector with the same shape as primals_out to a tuple of cotangent vectors with the same number and shapes as primals, representing the vector-Jacobian product of fun evaluated at primals.

>>> import jax
>>>
>>> def f(x, y):
...   return jax.numpy.sin(x), jax.numpy.cos(y)
...
>>> primals, f_vjp = jax.vjp(f, 0.5, 1.0)
>>> xbar, ybar = f_vjp((-0.7, 0.3))
>>> print(xbar)
-0.61430776
>>> print(ybar)
-0.2524413

scico.cvjp(fun, *primals, jidx=None)#

Compute a vector-Jacobian product with conjugate transpose.

Compute the product \([J(\mb{x})]^H \mb{v}\) where \([J(\mb{x})]\) is the Jacobian of function fun evaluated at \(\mb{x}\). Instead of directly evaluating the product, a function is returned that takes \(\mb{v}\) as an argument. If fun has multiple positional parameters, the Jacobian can be taken with respect to only one of them by setting the jidx parameter of this function to the positional index of that parameter.

Parameters:
  • fun (Callable) – Function for which the Jacobian is implicitly computed.

  • primals – Sequence of values at which the Jacobian is evaluated, with length equal to the number of positional arguments of fun.

  • jidx (Optional[int]) – Index of the positional parameter of fun with respect to which the Jacobian is taken.

Return type:

Tuple[Tuple[Any, ...], Callable]

Returns:

A pair (primals_out, conj_vjp) where primals_out is the output of fun evaluated at primals, i.e. primals_out = fun(*primals), and conj_vjp is a function that computes the product of the conjugate (Hermitian) transpose of the Jacobian of fun and its argument. If the jidx parameter is an integer, then the Jacobian is only taken with respect to the corresponding positional parameter of fun.
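
Since no usage example is included above, the following hedged sketch illustrates the call pattern; the function and input values are arbitrary, and the assumption that conj_vjp returns a tuple of cotangents (one per positional argument of fun) follows the jax.vjp convention rather than being stated explicitly above.

import jax.numpy as jnp
import scico

def f(x):
    return jnp.sin(x) * x                   # arbitrary illustrative function

x = jnp.array([1. + 1.j, 2. - 0.5j])
y, conj_vjp = scico.cvjp(f, x)              # y = f(x)
u = conj_vjp(jnp.ones_like(y))              # [J(x)]^H applied to a vector of ones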

scico.grad(fun, argnums=0, has_aux=False, holomorphic=False, allow_int=False)#

Create a function that evaluates the gradient of fun.

scico.grad differs from jax.grad in that the output is conjugated.

Docstring for jax.grad:

Parameters:
  • fun (Callable) – Function to be differentiated. Its arguments at positions specified by argnums should be arrays, scalars, or standard Python containers. Argument arrays in the positions specified by argnums must be of inexact (i.e., floating-point or complex) type. It should return a scalar (which includes arrays with shape () but not arrays with shape (1,) etc.)

  • argnums (Union[int, Sequence[int]]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).

  • has_aux (bool) – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.

  • holomorphic (bool) – Optional, bool. Indicates whether fun is promised to be holomorphic. If True, inputs and outputs must be complex. Default False.

  • allow_int (bool) – Optional, bool. Whether to allow differentiating with respect to integer valued inputs. The gradient of an integer input will have a trivial vector-space dtype (float0). Default False.

  • reduce_axes – Optional, tuple of axis names. If an axis is listed here, and fun implicitly broadcasts a value over that axis, the backward pass will perform a psum of the corresponding gradient. Otherwise, the gradient will be per-example over named axes. For example, if 'batch' is a named batch axis, grad(f, reduce_axes=('batch',)) will create a function that computes the total gradient while grad(f) will create one that computes the per-example gradient.

Return type:

Callable

Returns:

A function with the same arguments as fun, that evaluates the gradient of fun. If argnums is an integer then the gradient has the same shape and type as the positional argument indicated by that integer. If argnums is a tuple of integers, the gradient is a tuple of values with the same shapes and types as the corresponding arguments. If has_aux is True then a pair of (gradient, auxiliary_data) is returned.

For example:

>>> import jax
>>>
>>> grad_tanh = jax.grad(jax.numpy.tanh)
>>> print(grad_tanh(0.2))
0.961043
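
The conjugation matters only for complex inputs. The hedged sketch below (function and input values are illustrative, not from the documentation) shows the expected relationship between scico.grad and jax.grad for a real-valued function of a complex argument.

import jax
import jax.numpy as jnp
import scico

f = lambda x: jnp.sum(jnp.abs(x) ** 2)      # real-valued, complex argument
x = jnp.array([1. + 2.j, 3. - 1.j])

g_scico = scico.grad(f)(x)                  # conjugated gradient (SCICO convention)
g_jax = jax.grad(f)(x)                      # unconjugated gradient (JAX convention)
# g_scico is expected to be the elementwise complex conjugate of g_jax.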

scico.jacrev(fun, argnums=0, holomorphic=False, allow_int=False)#

Jacobian of fun evaluated row-by-row using reverse-mode AD.

scico.jacrev differs from jax.jacrev in that the output is conjugated.

Docstring for jax.jacrev:

Parameters:
  • fun (Callable) – Function whose Jacobian is to be computed.

  • argnums (Union[int, Sequence[int]]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).

  • has_aux – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.

  • holomorphic (bool) – Optional, bool. Indicates whether fun is promised to be holomorphic. Default False.

  • allow_int (bool) – Optional, bool. Whether to allow differentiating with respect to integer valued inputs. The gradient of an integer input will have a trivial vector-space dtype (float0). Default False.

Return type:

Callable

Returns:

A function with the same arguments as fun, that evaluates the Jacobian of fun using reverse-mode automatic differentiation. If has_aux is True then a pair of (jacobian, auxiliary_data) is returned.

>>> import jax
>>> import jax.numpy as jnp
>>>
>>> def f(x):
...   return jnp.asarray(
...     [x[0], 5*x[2], 4*x[1]**2 - 2*x[2], x[2] * jnp.sin(x[0])])
...
>>> print(jax.jacrev(f)(jnp.array([1., 2., 3.])))
[[ 1.       0.       0.     ]
 [ 0.       0.       5.     ]
 [ 0.      16.      -2.     ]
 [ 1.6209   0.       0.84147]]

scico.linear_adjoint(fun, *primals)#

Conjugate transpose a function that is guaranteed to be linear.

scico.linear_adjoint differs from jax.linear_transpose for complex inputs in that the conjugate transpose (adjoint) of fun is returned. scico.linear_adjoint is identical to jax.linear_transpose for real-valued primals.

Docstring for jax.linear_transpose:

For linear functions, this transformation is equivalent to vjp, but avoids the overhead of computing the forward pass.

The outputs of the transposed function will always have the exact same dtypes as primals, even if some values are truncated (e.g., from complex to float, or from float64 to float32). To avoid truncation, use dtypes in primals that match the full range of desired outputs from the transposed function. Integer dtypes are not supported.

Parameters:
  • fun (Callable) – the linear function to be transposed.

  • *primals – a positional argument tuple of arrays, scalars, or (nested) standard Python containers (tuples, lists, dicts, namedtuples, i.e., pytrees) of those types used for evaluating the shape/dtype of fun(*primals). These arguments may be real scalars/ndarrays, but that is not required: only the shape and dtype attributes are accessed. See below for an example. (Note that the duck-typed objects cannot be namedtuples because those are treated as standard Python containers.)

  • reduce_axes – Optional, tuple of axis names. If an axis is listed here, and fun implicitly broadcasts a value over that axis, the backward pass will perform a psum of the corresponding cotangent. Otherwise, the transposed function will be per-example over named axes. For example, if 'batch' is a named batch axis, linear_transpose(f, *args, reduce_axes=('batch',)) will create a transpose function that sums over the batch while linear_transpose(f, args) will create a per-example transpose.

Return type:

Callable

Returns:

A callable that calculates the transpose of fun. Valid input into this function must have the same shape/dtypes/structure as the result of fun(*primals). Output will be a tuple, with the same shape/dtypes/structure as primals.

>>> import jax
>>> import types
>>> import numpy as np
>>>
>>> f = lambda x, y: 0.5 * x - 0.5 * y
>>> scalar = types.SimpleNamespace(shape=(), dtype=np.dtype(np.float32))
>>> f_transpose = jax.linear_transpose(f, scalar, scalar)
>>> f_transpose(1.0)
(Array(0.5, dtype=float32), Array(-0.5, dtype=float32))
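
The example above covers the real-valued case, for which linear_adjoint and linear_transpose coincide. The hedged sketch below (matrix and vectors are arbitrary) illustrates the complex case, in which the adjoint is the conjugate transpose.

import jax.numpy as jnp
import scico

A = jnp.array([[1. + 1.j, 2. + 0.j],
               [0. + 0.j, 3. - 1.j]])
f = lambda x: A @ x                         # linear map
x0 = jnp.zeros(2, dtype=A.dtype)            # used only for shape/dtype information

f_adj = scico.linear_adjoint(f, x0)
(z,) = f_adj(jnp.ones(2, dtype=A.dtype))    # expected to equal A.conj().T @ ones(2)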

scico.value_and_grad(fun, argnums=0, has_aux=False, holomorphic=False, allow_int=False)#

Create a function that evaluates both fun and its gradient.

scico.value_and_grad differs from jax.value_and_grad in that the gradient is conjugated.

Docstring for jax.value_and_grad:

Parameters:
  • fun (Callable) – Function to be differentiated. Its arguments at positions specified by argnums should be arrays, scalars, or standard Python containers. It should return a scalar (which includes arrays with shape () but not arrays with shape (1,) etc.)

  • argnums (Union[int, Sequence[int]]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).

  • has_aux (bool) – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.

  • holomorphic (bool) – Optional, bool. Indicates whether fun is promised to be holomorphic. If True, inputs and outputs must be complex. Default False.

  • allow_int (bool) – Optional, bool. Whether to allow differentiating with respect to integer valued inputs. The gradient of an integer input will have a trivial vector-space dtype (float0). Default False.

  • reduce_axes – Optional, tuple of axis names. If an axis is listed here, and fun implicitly broadcasts a value over that axis, the backward pass will perform a psum of the corresponding gradient. Otherwise, the gradient will be per-example over named axes. For example, if 'batch' is a named batch axis, value_and_grad(f, reduce_axes=('batch',)) will create a function that computes the total gradient while value_and_grad(f) will create one that computes the per-example gradient.

Return type:

Callable[..., Tuple[Any, Any]]

Returns:

A function with the same arguments as fun that evaluates both fun and the gradient of fun and returns them as a pair (a two-element tuple). If argnums is an integer then the gradient has the same shape and type as the positional argument indicated by that integer. If argnums is a sequence of integers, the gradient is a tuple of values with the same shapes and types as the corresponding arguments. If has_aux is True then a tuple of ((value, auxiliary_data), gradient) is returned.
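
No example accompanies this entry; the following hedged sketch (with an arbitrary quadratic objective) illustrates the call pattern.

import jax.numpy as jnp
import scico

f = lambda x: jnp.sum((x - 1.) ** 2)        # scalar-valued objective
x = jnp.array([0., 2., 3.])

val, g = scico.value_and_grad(f)(x)         # val = f(x); g has the shape of x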