Source code for scico.linop.xray.symcone
# -*- coding: utf-8 -*-
# Copyright (C) 2025 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.
"""Cone beam X-ray transform for cylindrically symmetric objects.
Cone beam X-ray transform and FDK reconstruction for cylindrically
symmetric objects; essentially a cone-beam variant of the Abel transform.
The implementation is based on code modified from the
`axitom <https://github.com/PolymerGuy/AXITOM>`_ package
:cite:`olufsen-2019-axitom`.
"""
from functools import partial
from typing import Optional, Tuple
import numpy as np
import jax.numpy as jnp
from jax import Array, jit, vjp
from jax.scipy.ndimage import map_coordinates
from jax.typing import ArrayLike
from scico.typing import DType, Shape
from .._linop import LinearOperator
from ._axitom import backprojection, config, projection
@partial(jit, static_argnames=["axis", "center"])
def _volume_by_axial_symmetry(
    x: Array, axis: int = 0, center: Optional[int] = None, zrange: Optional[Array] = None
) -> Array:
    """Create a volume by axial rotation of a plane.

    Each point of the output volume is mapped back to a point of the
    input plane according to its distance from the axis of symmetry,
    and the plane is sampled there by linear interpolation (points
    falling outside the plane take the value 0).

    Args:
        x: 2D array that is rotated about an axis to generate a volume.
        axis: Index of axis of symmetry (must be 0 or 1).
        center: Location of the axis of symmetry on the other axis. If
            ``None``, defaults to center of that axis. Otherwise
            identifies the center coordinate on that axis.
        zrange: 1D array of points at which the extended axis is
            constructed. Defaults to the same as for axis
            :code:`1 - axis`.

    Returns:
        Volume as a 3D array of shape :code:`x.shape + (len(zrange),)`.

    Raises:
        ValueError: If :code:`axis` is not 0 or 1.
    """
    if axis not in (0, 1):
        # axis is a static argument, so this is raised at trace time.
        raise ValueError(f"Argument axis must be 0 or 1, got {axis}.")
    rad = 1 - axis  # index of the radial (non-symmetry) axis of x
    N0, N1 = x.shape
    # (N + 1) / 2 - 1 == (N - 1) / 2: the (possibly half-integer) index
    # of the center of each axis.
    N0h, N1h = (N0 + 1) / 2 - 1, (N1 + 1) / 2 - 1
    half_shape = (N0h, N1h)
    if zrange is None:
        N2 = x.shape[rad]
        N2h = (N2 + 1) / 2 - 1
        zrange = jnp.arange(-N2h, N2h + 1)
    # The symmetry axis is indexed by sample number, while the radial
    # axis is centered on zero so that its sign selects a half-plane.
    if axis == 0:
        g1d = [jnp.arange(0, N0), jnp.arange(-N1h, N1h + 1), zrange]
    else:
        g1d = [jnp.arange(-N0h, N0h + 1), jnp.arange(0, N1), zrange]
    offset = 0 if center is None else center - half_shape[rad]
    grids = jnp.meshgrid(*g1d, indexing="ij")
    # Distance of each volume point from the axis of symmetry.
    r = jnp.hypot(grids[rad], grids[2])
    # Map each volume point back to a coordinate on the radial axis of x,
    # on the same side of the axis of symmetry as its in-plane coordinate.
    origin = half_shape[rad] + offset
    sym_ax_crd = jnp.where(grids[rad] >= 0, origin + r, origin - r)
    if axis == 0:
        coords = [grids[0], sym_ax_crd]
    else:
        coords = [sym_ax_crd, grids[1]]
    return map_coordinates(x, coords, cval=0.0, order=1)
[docs]
class AxiallySymmetricVolume(LinearOperator):
    """Create a volume by axial rotation of a plane.

    The output volume has shape :code:`input_shape + (input_shape[1 - axis],)`,
    i.e. the new third axis has the same length as the radial
    (non-symmetry) axis of the input plane.
    """

    def __init__(
        self,
        input_shape: Shape,
        input_dtype: DType = np.float32,
        axis: int = 0,
        center: Optional[int] = None,
    ):
        """
        Args:
            input_shape: Input image shape (must be 2D).
            input_dtype: Input image dtype.
            axis: Index of axis of symmetry (must be 0 or 1).
            center: Location of the axis of symmetry on axis
                :code:`1 - axis`. If ``None``, defaults to the center of
                the image on that axis.

        Raises:
            ValueError: If :code:`input_shape` is not 2D, or if
                :code:`axis` is not 0 or 1.
        """
        input_shape = tuple(input_shape)
        if len(input_shape) != 2:
            raise ValueError(f"Argument input_shape must be 2D, got {input_shape}.")
        if axis not in (0, 1):
            raise ValueError(f"Argument axis must be 0 or 1, got {axis}.")
        self.axis = axis
        self.center = center
        # _volume_by_axial_symmetry extends the plane along a new third
        # axis whose length matches the radial axis 1 - axis (not the
        # symmetry axis), so the declared shape must use that length.
        output_shape = input_shape + (input_shape[1 - axis],)

        def eval_fn(x):
            return _volume_by_axial_symmetry(x, axis=self.axis, center=self.center)

        super().__init__(
            input_shape=input_shape,
            output_shape=output_shape,
            input_dtype=input_dtype,
            output_dtype=input_dtype,
            eval_fn=eval_fn,
            jit=True,
        )
[docs]
class SymConeXRayTransform(LinearOperator):
    """Cone beam X-ray transform for cylindrically symmetric objects.

    Cone beam X-ray transform of a cylindrically symmetric volume, which
    may be represented by a 2D central slice, which is rotated about
    the specified axis to generate a 3D volume for projection.

    The implementation is based on code modified from the AXITOM package
    :cite:`olufsen-2019-axitom`.
    """

    def __init__(
        self,
        input_shape: Shape,
        obj_dist: float,
        det_dist: float,
        axis: int = 0,
        pixel_size: Optional[Tuple[float, float]] = None,
        num_slabs: int = 1,
    ):
        """
        Args:
            input_shape: Shape of the input array, which must be 2D or
                3D. If 2D, the input is extended to 3D (onto a new axis 1)
                by cylindrical symmetry.
            obj_dist: Source-object distance in arbitrary length units
                (ALU).
            det_dist: Source-detector distance in ALU.
            axis: Index of axis of symmetry (must be 0 or 1).
            pixel_size: Tuple of pixel size values in ALU. Defaults to
                :code:`(1.0, 1.0)`.
            num_slabs: Number of slabs into which the volume should be
                divided (for serial processing, to limit memory usage) in
                the imaging direction.

        Raises:
            ValueError: If :code:`input_shape` is neither 2D nor 3D, or
                if :code:`axis` is not 0 or 1.
        """
        if len(input_shape) not in (2, 3):
            raise ValueError(f"Argument input_shape must be 2D or 3D, got {input_shape}.")
        if axis not in (0, 1):
            raise ValueError(f"Argument axis must be 0 or 1, got {axis}.")
        self.input_2d = len(input_shape) == 2
        if self.input_2d:
            output_shape = input_shape[::-1]
        else:
            output_shape = (input_shape[2], input_shape[0])
        if pixel_size is None:
            pixel_size = (1.0, 1.0)
        self.axis = axis
        self.config = config.Config(*output_shape, *pixel_size, det_dist, obj_dist)
        self.num_slabs = num_slabs
        # Only a 2D input with symmetry axis 1 is transposed on the way in
        # and out of forward_project; presumably this adapts it to the
        # axis convention of forward_project (TODO confirm).
        transpose = self.input_2d and axis == 1

        def eval_fn(x):
            if transpose:
                return projection.forward_project(
                    x.T, self.config, num_slabs=self.num_slabs, input_2d=self.input_2d
                ).T
            return projection.forward_project(
                x, self.config, num_slabs=self.num_slabs, input_2d=self.input_2d
            )

        # Use vjp rather than linear_transpose due to jax-ml/jax#30552.
        # Note that vjp evaluates eval_fn once, here at construction time.
        vjp_fn = vjp(eval_fn, jnp.zeros(input_shape))[1]

        def adj_fn(y):
            # vjp returns a tuple with one cotangent per primal argument.
            return vjp_fn(y)[0]

        super().__init__(
            input_shape=input_shape,
            output_shape=output_shape,
            input_dtype=np.float32,
            output_dtype=np.float32,
            eval_fn=eval_fn,
            adj_fn=adj_fn,
            jit=True,
        )

    def fdk(self, y: ArrayLike, num_angles: int = 360):
        """Reconstruct central slice from projection.

        Reconstruct the central slice of the cylindrically symmetric
        volume from a projection. The reconstruction makes use of the
        Feldkamp-Davis-Kress (FDK) algorithm implemented in the
        `axitom <https://github.com/PolymerGuy/AXITOM>`_ package.

        Args:
            y: The projection to be reconstructed.
            num_angles: Number of angles, evenly spaced over the interval
                [0, 360), to be averaged in the reconstruction.

        Returns:
            Reconstruction of the central slice of the volume.
        """
        angles = jnp.linspace(0, 360, num_angles, endpoint=False)
        # NOTE(review): transposition here is applied when axis == 0,
        # whereas eval_fn transposes when axis == 1; presumably this
        # reflects the axis convention of backprojection.fdk — confirm.
        x = backprojection.fdk(y if self.axis == 1 else y.T, self.config, angles)
        return x if self.axis == 1 else x.T