Source code for scico.linop.xray.astra
# -*- coding: utf-8 -*-
# Copyright (C) 2020-2023 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.
"""X-ray transform LinearOperator wrapping the ASTRA toolbox.
X-ray transform :class:`.LinearOperator` wrapping the parallel beam
projections in the
`ASTRA toolbox <https://github.com/astra-toolbox/astra-toolbox>`_.
The ASTRA toolbox provides both C and CUDA implementations of core
functionality, but note that use of the CUDA/GPU implementation is
expected to result in GPU-host-GPU memory copies when transferring
JAX arrays. Other JAX features such as automatic differentiation are
not available.
"""
from typing import List, Optional, Tuple, Union
import numpy as np
import jax
try:
import astra
except ModuleNotFoundError as e:
if e.name == "astra":
new_e = ModuleNotFoundError("Could not import astra; please install the ASTRA toolbox.")
new_e.name = "astra"
raise new_e from e
else:
raise e
from scico.typing import Shape
from .._linop import LinearOperator
class XRayTransform(LinearOperator):
    r"""Parallel beam X-ray transform based on the ASTRA toolbox.

    Perform tomographic projection (also called X-ray projection) of an
    image or volume at specified angles, using the
    `ASTRA toolbox <https://github.com/astra-toolbox/astra-toolbox>`_.
    """

    def __init__(
        self,
        input_shape: Shape,
        detector_spacing: Union[float, Tuple[float, float]],
        det_count: Union[int, Tuple[int, int]],
        angles: np.ndarray,
        volume_geometry: Optional[List[float]] = None,
        device: str = "auto",
    ):
        """
        Args:
            input_shape: Shape of the input array. Determines whether 2D
               or 3D algorithm is used.
            detector_spacing: Spacing between detector elements. See the
               astra documentation for more information for
               `2d <https://www.astra-toolbox.com/docs/geom2d.html#projection-geometries>`__
               or
               `3d <https://www.astra-toolbox.com/docs/geom3d.html#projection-geometries>`__
               geometries.
            det_count: Number of detector elements. See the astra
               documentation for more information for
               `2d <https://www.astra-toolbox.com/docs/geom2d.html#projection-geometries>`__
               or
               `3d <https://www.astra-toolbox.com/docs/geom3d.html#projection-geometries>`__
               geometries.
            angles: Array of projection angles in radians.
            volume_geometry: Specification of the shape of the
               discretized reconstruction volume. Must either ``None``,
               in which case it is inferred from `input_shape`, or
               follow the astra syntax described in the astra
               documentation for
               `2d <https://www.astra-toolbox.com/docs/geom2d.html#volume-geometries>`__
               or
               `3d <https://www.astra-toolbox.com/docs/geom3d.html#d-geometries>`__
               geometries.
            device: Specifies device for projection operation.
               One of ["auto", "gpu", "cpu"]. If "auto", a GPU is used if
               available, otherwise, the CPU is used. If the default JAX
               device is a CPU, the CPU is used regardless of this value.

        Raises:
            ValueError: If `input_shape` is not 2D or 3D; if, for a 3D
               transform, `det_count` or `detector_spacing` is not a
               2-element list/tuple; if `volume_geometry` has the wrong
               length; if `device` is not a supported value or cannot
               be used with the default JAX device; or if a 3D
               transform would run on the CPU.
        """
        self.num_dims = len(input_shape)
        if self.num_dims not in [2, 3]:
            raise ValueError(
                f"Only 2D and 3D projections are supported, but input_shape is {input_shape}."
            )
        # Reject unknown device names up front; otherwise any string would
        # silently select the CPU when the default JAX device is a CPU.
        if device not in ("auto", "gpu", "cpu"):
            raise ValueError(f"Invalid device specified; got {device}.")

        # Explicit checks (not asserts) so validation survives `python -O`.
        output_shape: Shape
        if self.num_dims == 2:
            output_shape = (len(angles), det_count)
        else:
            if not isinstance(det_count, (list, tuple)) or len(det_count) != 2:
                raise ValueError("Expected det_count to have 2 elements")
            output_shape = (det_count[0], len(angles), det_count[1])

        # Set up all the ASTRA config
        self.detector_spacing = detector_spacing
        self.det_count = det_count
        self.angles: np.ndarray = np.array(angles)
        if self.num_dims == 2:
            self.proj_geom: dict = astra.create_proj_geom(
                "parallel", detector_spacing, det_count, self.angles
            )
        else:
            if not isinstance(detector_spacing, (list, tuple)) or len(detector_spacing) != 2:
                raise ValueError("Expected detector_spacing to have 2 elements")
            self.proj_geom = astra.create_proj_geom(
                "parallel3d",
                detector_spacing[0],
                detector_spacing[1],
                det_count[0],
                det_count[1],
                self.angles,
            )

        self.proj_id: Optional[int]
        self.input_shape: tuple = input_shape

        # For 3D, the volume size is passed to ASTRA as
        # (input_shape[1], input_shape[2], input_shape[0]). The same ordering
        # must be used whether or not explicit volume extents are supplied;
        # otherwise non-cubic volumes get a geometry that does not match the data.
        if self.num_dims == 2:
            vol_shape = tuple(input_shape)
        else:
            vol_shape = (input_shape[1], input_shape[2], input_shape[0])
        self.vol_geom: dict
        if volume_geometry is None:
            self.vol_geom = astra.create_vol_geom(*vol_shape)
        elif len(volume_geometry) == 2 * self.num_dims:  # 4 (2D) or 6 (3D) extents
            self.vol_geom = astra.create_vol_geom(*vol_shape, *volume_geometry)
        else:
            raise ValueError(
                "volume_geometry must be a tuple of len 4 (2D) or 6 (3D). "
                "Please see the astra documentation for details."
            )

        dev0 = jax.devices()[0]
        if dev0.platform == "cpu" or device == "cpu":
            self.device = "cpu"
        elif dev0.platform == "gpu" and device in ["gpu", "auto"]:
            self.device = "gpu"
        else:
            raise ValueError(f"Invalid device specified; got {device}.")

        if self.num_dims == 3 and self.device == "cpu":
            raise ValueError("No CPU algorithm for 3D projection.")

        if self.num_dims == 3:
            # not needed for astra's 3D algorithm
            self.proj_id = None
        elif self.device == "cpu":
            self.proj_id = astra.create_projector("line", self.proj_geom, self.vol_geom)
        else:
            self.proj_id = astra.create_projector("cuda", self.proj_geom, self.vol_geom)

        # Wrap our non-jax function to indicate we will supply fwd/rev mode functions
        self._eval = jax.custom_vjp(self._proj)
        self._eval.defvjp(lambda x: (self._proj(x), None), lambda _, y: (self._bproj(y),))  # type: ignore
        self._adj = jax.custom_vjp(self._bproj)
        self._adj.defvjp(lambda y: (self._bproj(y), None), lambda _, x: (self._proj(x),))  # type: ignore

        super().__init__(
            input_shape=self.input_shape,
            output_shape=output_shape,
            input_dtype=np.float32,
            output_dtype=np.float32,
            adj_fn=self._adj,
            jit=False,
        )

    def _proj(self, x: jax.Array) -> jax.Array:
        """Apply the forward projector to `x`, producing a sinogram.

        The ASTRA call runs on the host via :func:`jax.pure_callback`.
        """

        def f(x):
            x = ensure_writeable(x)
            if self.num_dims == 2:
                proj_id, result = astra.create_sino(x, self.proj_id)
                astra.data2d.delete(proj_id)
            elif self.num_dims == 3:
                proj_id, result = astra.create_sino3d_gpu(x, self.proj_geom, self.vol_geom)
                astra.data3d.delete(proj_id)
            return result

        return jax.pure_callback(f, jax.ShapeDtypeStruct(self.output_shape, self.output_dtype), x)

    def _bproj(self, y: jax.Array) -> jax.Array:
        """Apply the back projector to sinogram `y`, producing an image/volume.

        The ASTRA call runs on the host via :func:`jax.pure_callback`.
        """

        def f(y):
            y = ensure_writeable(y)
            if self.num_dims == 2:
                proj_id, result = astra.create_backprojection(y, self.proj_id)
                astra.data2d.delete(proj_id)
            elif self.num_dims == 3:
                proj_id, result = astra.create_backprojection3d_gpu(
                    y, self.proj_geom, self.vol_geom
                )
                astra.data3d.delete(proj_id)
            return result

        return jax.pure_callback(f, jax.ShapeDtypeStruct(self.input_shape, self.input_dtype), y)

    def fbp(self, sino: jax.Array, filter_type: str = "Ram-Lak") -> jax.Array:
        """Filtered back projection (FBP) reconstruction.

        Perform tomographic reconstruction using the filtered back
        projection (FBP) algorithm.

        Args:
            sino: Sinogram to reconstruct.
            filter_type: Select the filter to use. For a list of options
               see `cfg.FilterType` in the `ASTRA documentation
               <https://www.astra-toolbox.com/docs/algs/FBP_CUDA.html>`__.

        Returns:
            Reconstructed image with shape `self.input_shape`.

        Raises:
            NotImplementedError: If the transform is 3D.
        """
        if self.num_dims == 3:
            raise NotImplementedError("3D FBP is not implemented.")

        # Just use the CPU FBP alg for now; hitting memory issues with GPU one.
        # NOTE(review): when self.device == "gpu", self.proj_id is a "cuda"
        # projector; confirm that ASTRA's CPU "FBP" algorithm accepts it.
        def f(sino):
            sino = ensure_writeable(sino)
            sino_id = rec_id = alg_id = None
            # Free every ASTRA object that was created, even if a later
            # step raises, so failed calls do not leak ASTRA memory.
            try:
                sino_id = astra.data2d.create("-sino", self.proj_geom, sino)
                # create memory for result
                rec_id = astra.data2d.create("-vol", self.vol_geom)
                # start to populate config
                cfg = astra.astra_dict("FBP")
                cfg["ReconstructionDataId"] = rec_id
                cfg["ProjectorId"] = self.proj_id
                cfg["ProjectionDataId"] = sino_id
                cfg["option"] = {"FilterType": filter_type}
                # initialize algorithm; run
                alg_id = astra.algorithm.create(cfg)
                astra.algorithm.run(alg_id)
                # get the result
                return astra.data2d.get(rec_id)
            finally:
                # cleanup FBP-specific arrays, in the original order
                if alg_id is not None:
                    astra.algorithm.delete(alg_id)
                if rec_id is not None:
                    astra.data2d.delete(rec_id)
                if sino_id is not None:
                    astra.data2d.delete(sino_id)

        return jax.pure_callback(f, jax.ShapeDtypeStruct(self.input_shape, self.input_dtype), sino)
def ensure_writeable(x):
    """Return an array equal to `x` whose `flags.writeable` is ``True``.

    An array that is already writeable is returned unchanged. Otherwise
    the flag is set in place. If NumPy refuses to set it, a writeable
    copy is returned instead.
    """
    if x.flags.writeable:
        return x
    try:
        x.setflags(write=True)
    except ValueError:
        # NumPy refuses to set the flag in some cases (e.g. arrays backed
        # by an immutable buffer); fall back to a fresh copy.
        return x.copy()
    return x