Notes#

No GPU/TPU Warning#

JAX currently issues a warning when used on a platform without a GPU. For JAX versions after 0.3.23 this warning is suppressed by SCICO, and no action is required; for earlier versions, it can be disabled by setting the environment variable JAX_PLATFORM_NAME=cpu before running Python.
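For example, for a hypothetical script test_script.py:

JAX_PLATFORM_NAME=cpu python test_script.py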

Debugging#

If difficulties are encountered in debugging jitted functions, jit can be globally disabled by setting the environment variable JAX_DISABLE_JIT=1 before running Python, as in

JAX_DISABLE_JIT=1 python test_script.py
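Jit can also be disabled from within Python using standard JAX facilities (a sketch, not specific to SCICO):

import jax

# Disable jit for the remainder of the program ...
jax.config.update("jax_disable_jit", True)

# ... or only within a limited scope via the disable_jit context manager.
with jax.disable_jit():
    pass  # jitted functions called here run eagerly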

Double Precision#

By default, JAX enforces single-precision numbers. Double precision can be enabled in one of two ways:

  1. Setting the environment variable JAX_ENABLE_X64=TRUE before launching Python.

  2. Manually setting the jax_enable_x64 flag at program startup; that is, before importing SCICO.

import jax
jax.config.update("jax_enable_x64", True)
import scico  # continue as usual

For more information, see the JAX notes on double precision.
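With double precision enabled by either method, arrays created by SCICO and JAX functions default to 64-bit floating point, which can be verified directly (a minimal sketch):

import scico.numpy as snp

x = snp.zeros(3)
print(x.dtype)  # float64 when x64 is enabled, float32 otherwise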

Random Number Generation#

JAX implements an explicit, non-stateful pseudorandom number generator (PRNG). The user is responsible for generating a PRNG key and splitting it to obtain a new key each time random numbers are generated. We recommend that users read the JAX documentation for information on the design of JAX random number functionality.

In scico.random we provide convenient wrappers around several jax.random routines to handle the generation and splitting of PRNG keys.

import scico.random

# Calls to scico.random functions always return a PRNG key.
# If no key is passed to the function, a new key is generated.
x, key = scico.random.randn((2,))
print(x)  # [ 0.19307713 -0.52678305]

# scico.random functions automatically split the PRNG key and
# return an updated key.
y, key = scico.random.randn((2,), key=key)
print(y)  # [ 0.00870693 -0.04888531]

The user is responsible for passing the PRNG key to scico.random functions. If no key is passed, repeated calls to scico.random functions will return the same random numbers:

x, key = scico.random.randn((2,))
print(x)   # [ 0.19307713 -0.52678305]

# No key passed, will return the same random numbers!
y, key = scico.random.randn((2,))
print(y)   # [ 0.19307713 -0.52678305]
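For comparison, a minimal sketch of the same pattern using jax.random directly, where the key must be created and split explicitly by the user:

import jax

key = jax.random.PRNGKey(0)          # create an initial PRNG key
key, subkey = jax.random.split(key)  # split before each use
x = jax.random.normal(subkey, (2,))
key, subkey = jax.random.split(key)
y = jax.random.normal(subkey, (2,))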

Compiled Dependency Packages#

The code acceleration and automatic differentiation features of JAX are not available for some components of SCICO that are provided via interfaces to compiled C code. When these components are used on a platform with GPUs, the remainder of the code will run on a GPU, but there is potential for a considerable delay due to host-GPU memory transfers. This issue primarily affects:

Denoisers#

The bm3d and bm4d denoisers (and the corresponding BM3D and BM4D pseudo-functionals) are implemented via interfaces to the bm3d and bm4d packages respectively. The DnCNN denoiser (and the corresponding DnCNN pseudo-functional) should be used when the full benefits of JAX-based code are required.

Tomographic Projectors/Radon Transforms#

Note that the tomographic projections frequently referred to as Radon transforms are referred to as X-ray transforms in SCICO. While the Radon transform is far better known than the X-ray transform, the two coincide for projections in two dimensions but differ in higher dimensions, and it is the X-ray transform that is the appropriate mathematical model for beam-attenuation-based imaging in three or more dimensions.

SCICO includes three different implementations of X-ray transforms. Of these, linop.XRayTransform is an integral component of SCICO, while the other two depend on external packages: the xray.svmbir.XRayTransform class is implemented via an interface to the svmbir package, and the xray.astra.XRayTransform class is implemented via an interface to the ASTRA toolbox. The latter does provide some GPU acceleration support, but its efficiency is expected to be lower than that of JAX-based code due to host-GPU memory transfers.

Automatic Differentiation Caveats#

Complex Functions#

The JAX-defined gradient of a complex-valued function is a complex-conjugated version of the usual gradient used in mathematical optimization and computational imaging. Minimizing a function using the JAX convention involves taking steps in the direction of the complex conjugated gradient.

The function scico.grad returns the expected gradient, that is, the conjugate of the JAX gradient. For further discussion, see this JAX issue.

As a concrete example, consider the function \(f(x) = \frac{1}{2}\norm{\mb{A} \mb{x}}_2^2\) where \(\mb{A}\) is a complex matrix. The gradient of \(f\) is usually given by \((\nabla f)(\mb{x}) = \mb{A}^H \mb{A} \mb{x}\), where \(\mb{A}^H\) is the conjugate transpose of \(\mb{A}\). Applying jax.grad to \(f\) will yield \((\mb{A}^H \mb{A} \mb{x})^*\), where \(\cdot^*\) denotes complex conjugation.

The following code demonstrates the use of jax.grad and scico.grad:

import numpy as np
import jax

import scico
import scico.numpy as snp
from scico.random import randn

m, n = (4, 3)
A, key = randn((m, n), dtype=np.complex64, key=None)
x, key = randn((n,), dtype=np.complex64, key=key)

def f(x):
    return 0.5 * snp.linalg.norm(A @ x)**2

an_grad = A.conj().T @ A @ x  # the analytically expected gradient

np.testing.assert_allclose(jax.grad(f)(x), an_grad.conj(), rtol=1e-4)
np.testing.assert_allclose(scico.grad(f)(x), an_grad, rtol=1e-4)

Non-differentiable Functionals#

scico.grad can be applied to any function, but has undefined behavior for non-differentiable functions. For non-differentiable functions, scico.grad may or may not return a valid subgradient. As an example, scico.grad(snp.abs)(0.) = 0, which is a valid subgradient. However, scico.grad(snp.linalg.norm)([0., 0.]) = [nan, nan].

Implementing a differentiable function as the composition of a differentiable and a non-differentiable function should be avoided. As an example, \(f(x) = \norm{x}_2^2\) can be implemented as f = lambda x: snp.linalg.norm(x)**2. This involves first calculating the unsquared \(\ell_2\) norm, then squaring it. The unsquared \(\ell_2\) norm is not differentiable at zero, so when evaluating the gradient of f at 0, scico.grad returns NaN:

>>> import scico
>>> import scico.numpy as snp
>>> f = lambda x: snp.linalg.norm(x)**2
>>> scico.grad(f)(snp.zeros(2, dtype=snp.float32))  
Array([nan, nan], dtype=float32)

This can be fixed (assuming real-valued arrays only) by defining the squared \(\ell_2\) norm directly as g = lambda x: snp.sum(x**2). The gradient will work as expected:

>>> g = lambda x: snp.sum(x**2)
>>> scico.grad(g)(snp.zeros(2, dtype=snp.float32))  
Array([0., 0.], dtype=float32)

If complex-valued arrays also need to be supported, a minor modification is necessary:

>>> g = lambda x: snp.sum(snp.abs(x)**2)
>>> scico.grad(g)(snp.zeros(2, dtype=snp.float32))  
Array([0., 0.], dtype=float32)
>>> scico.grad(g)(snp.zeros(2, dtype=snp.complex64))  
Array([0.-0.j, 0.-0.j], dtype=complex64)

An alternative is to define a custom derivative rule to enforce a particular derivative convention at a point.
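As a sketch of this approach (using jax.custom_jvp directly, not SCICO-specific functionality), one could define an \(\ell_2\) norm whose gradient at the origin is taken to be zero:

import jax
import jax.numpy as jnp

@jax.custom_jvp
def l2norm(x):
    return jnp.sqrt(jnp.sum(x**2))

@l2norm.defjvp
def l2norm_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    y = l2norm(x)
    # Avoid division by zero; at x = 0 the chosen (sub)gradient is zero.
    safe_y = jnp.where(y == 0, 1.0, y)
    return y, jnp.where(y == 0, 0.0, jnp.dot(x, t) / safe_y)

print(jax.grad(l2norm)(jnp.zeros(2)))           # [0. 0.]
print(jax.grad(l2norm)(jnp.array([3.0, 4.0])))  # [0.6 0.8]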

JAX Arrays#

JAX uses its own array type, Array, which is similar to the NumPy ndarray but can be backed by CPU, GPU, or TPU memory and is immutable.

JAX and NumPy Arrays#

SCICO and JAX functions can be applied directly to NumPy arrays without explicit conversion to JAX arrays, but this is not recommended since it can result in repeated data transfers from the CPU to the GPU. Consider this toy example on a system with a GPU present:

x = np.random.randn(8)    # array on host
A = np.random.randn(8, 8) # array on host
y = snp.dot(A, x)         # A, x transferred to GPU
                          # y resides on GPU
z = y + x                 # x must be transferred to GPU again

The unnecessary transfer can be avoided by first converting A and x to JAX arrays:

x = np.random.randn(8)    # array on host
A = np.random.randn(8, 8) # array on host
x = jax.device_put(x)     # transfer to GPU
A = jax.device_put(A)
y = snp.dot(A, x)         # no transfer needed
z = y + x                 # no transfer needed

We recommend that input data be converted to JAX arrays via jax.device_put before calling any SCICO optimizers.

On a multi-GPU system, jax.device_put can place data on a specific GPU. See the JAX notes on data placement.
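For example (a sketch; the available devices depend on the system configuration):

import jax
import numpy as np

devices = jax.devices()  # all available devices (GPUs if present, otherwise the CPU)
x = np.random.randn(8)
x = jax.device_put(x, devices[-1])  # place x on the last listed device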

JAX Arrays are Immutable#

Unlike standard NumPy arrays, JAX arrays are immutable: once they have been created, they cannot be changed. This prohibits in-place updating of JAX arrays. JAX provides special syntax for updating individual array elements through the indexed update operators.
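For example, a minimal sketch of the indexed update syntax:

import scico.numpy as snp

x = snp.zeros(4)
# x[0] = 1.0 would raise an error since JAX arrays are immutable;
# .at[...] returns a new array with the update applied.
y = x.at[0].set(1.0)
z = x.at[1:3].add(2.0)
print(y)  # [1. 0. 0. 0.]
print(z)  # [0. 2. 2. 0.]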