# Source code for scico.linop.xray.svmbir

# -*- coding: utf-8 -*-
# Copyright (C) 2021-2023 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.

"""X-ray transform LinearOperator wrapping the svmbir package.

X-ray transform :class:`.LinearOperator` wrapping the
`svmbir <https://github.com/cabouman/svmbir>`_ package. Since this
package is an interface to compiled C code, JAX features such as
automatic differentiation and support for GPU devices are not available.
"""

from typing import Optional, Tuple, Union

import numpy as np

import jax

import scico.numpy as snp
from scico.loss import Loss, SquaredL2Loss
from scico.typing import Shape

from .._diag import Diagonal, Identity
from .._linop import LinearOperator

try:
    import svmbir
except ImportError:
    raise ImportError("Could not import svmbir; please install it.")


class XRayTransform(LinearOperator):
    r"""X-ray transform based on svmbir.

    Perform tomographic projection of an image at specified angles, using
    the `svmbir <https://github.com/cabouman/svmbir>`_ package. The
    `is_masked` option selects whether a valid region for projections
    (pixels outside this region are ignored when performing the
    projection) is active. This region of validity is also respected by
    :meth:`.SVMBIRExtendedLoss.prox` when :class:`.SVMBIRExtendedLoss`
    is initialized with a :class:`XRayTransform` with this option enabled.
    (:class:`.SVMBIRSquaredL2Loss` does not support this option.)

    A brief description of the supported scanner geometries can be found
    in the `svmbir documentation
    <https://svmbir.readthedocs.io/en/latest/overview.html>`_. Parallel
    beam geometry and two different fan beam geometries are supported.

    .. list-table::

       * - .. figure:: /figures/geom-parallel.png
              :align: center
              :width: 75%

              Fig 1. Parallel beam geometry.

         - .. figure:: /figures/geom-fan.png
              :align: center
              :width: 75%

              Fig 2. Curved fan beam geometry.
    """

    def __init__(
        self,
        input_shape: Shape,
        angles: snp.Array,
        num_channels: int,
        center_offset: float = 0.0,
        is_masked: bool = False,
        geometry: str = "parallel",
        dist_source_detector: Optional[float] = None,
        magnification: Optional[float] = None,
        delta_channel: Optional[float] = None,
        delta_pixel: Optional[float] = None,
    ):
        """
        The output of this linear operator is an array of shape
        `(num_angles, num_channels)` when input_shape is 2D, or of shape
        `(num_angles, num_slices, num_channels)` when input_shape is 3D,
        where `num_angles` is the length of the `angles` argument, and
        `num_slices` is inferred from the `input_shape` argument.

        Most of the following arguments have the same name as and
        correspond to arguments of :func:`svmbir.project`. A brief
        summary of each is provided here, but the documentation for
        :func:`svmbir.project` should be consulted for further details.

        Args:
            input_shape: Shape of the input array. May be of length 2 (a
                2D array) or 3 (a 3D array). When specifying a 2D array,
                the format for the input_shape is `(num_rows, num_cols)`.
                For a 3D array, the format for the input_shape is
                `(num_slices, num_rows, num_cols)`, where `num_slices`
                denotes the number of slices in the input, and `num_rows`
                and `num_cols` denote the number of rows and columns in a
                single slice of the input. A slice is a plane
                perpendicular to the axis of rotation of the tomographic
                system. At angle zero, each row is oriented along the
                X-rays (parallel beam) or the X-ray beam directed toward
                the detector center (fan beam). Note that
                `input_shape=(num_rows, num_cols)` and
                `input_shape=(1, num_rows, num_cols)` result in the
                same underlying projector.
            angles: Array of projection angles in radians, should be
                increasing.
            num_channels: Number of detector channels in the sinogram
                data.
            center_offset: Position of the detector center relative to
                the projection of the center of rotation onto the
                detector, in units of pixels.
            is_masked: If ``True``, the valid region of the image is
                determined by a mask defined as the circle inscribed
                within the image boundary. Otherwise, the whole image
                array is taken into account by projections.
            geometry: Scanner geometry, either "parallel", "fan-curved",
                or "fan-flat". Note that the `dist_source_detector` and
                `magnification` arguments must be provided for the fan
                beam geometries.
            dist_source_detector: Distance from X-ray focal spot to
                detectors in units of pixel pitch. Only used when geometry
                is "fan-flat" or "fan-curved".
            magnification: Magnification factor of the scanner geometry.
                Only used when geometry is "fan-flat" or "fan-curved".
            delta_channel: Detector channel spacing. Defaults to 1.0.
            delta_pixel: Spacing between image pixels in the 2D slice
                plane. Defaults to `delta_channel / magnification` for
                fan beam geometries and to `delta_channel` for the
                parallel beam geometry.

        Raises:
            ValueError: If `input_shape` is neither 2D nor 3D, if
                `geometry` is not one of the supported values, or if a
                fan beam geometry is selected without specifying
                `dist_source_detector` and `magnification`.
        """
        # Accept any sequence (e.g. a list) as a shape; the tuple
        # concatenation below requires a tuple.
        input_shape = tuple(input_shape)
        self.angles = angles
        self.num_channels = num_channels
        self.center_offset = center_offset

        if len(input_shape) == 2:  # 2D input
            # svmbir always works on 3D volumes; treat a 2D image as one slice.
            self.svmbir_input_shape = (1,) + input_shape
            output_shape: Tuple[int, ...] = (len(angles), num_channels)
            self.svmbir_output_shape = output_shape[0:1] + (1,) + output_shape[1:2]
        elif len(input_shape) == 3:  # 3D input
            self.svmbir_input_shape = input_shape
            output_shape = (len(angles), input_shape[0], num_channels)
            self.svmbir_output_shape = output_shape
        else:
            raise ValueError(
                f"Only 2D and 3D inputs are supported, but input_shape was {input_shape}."
            )

        self.is_masked = is_masked
        self.geometry = geometry
        self.dist_source_detector = dist_source_detector
        self.magnification = magnification
        self.delta_channel = 1.0 if delta_channel is None else delta_channel

        if self.geometry in ("fan-curved", "fan-flat"):
            if self.dist_source_detector is None:
                raise ValueError(
                    "Parameter dist_source_detector must be specified for fan beam geometry."
                )
            if self.magnification is None:
                raise ValueError("Parameter magnification must be specified for fan beam geometry.")
            if delta_pixel is None:
                self.delta_pixel = self.delta_channel / self.magnification
            else:
                self.delta_pixel = delta_pixel
        elif self.geometry == "parallel":
            self.magnification = 1.0
            if delta_pixel is None:
                self.delta_pixel = self.delta_channel
            else:
                self.delta_pixel = delta_pixel
        else:
            raise ValueError("Unspecified geometry {}.".format(self.geometry))

        num_rows, num_cols = self.svmbir_input_shape[1], self.svmbir_input_shape[2]
        if self.is_masked:
            # Let svmbir choose its default ROI, the circle inscribed in the
            # image boundary.
            self.roi_radius = None
        else:
            # svmbir's roi_radius is expressed in ALU units, and pixels are
            # spaced delta_pixel apart, so the radius must scale with
            # delta_pixel. delta_pixel * max(rows, cols) is twice the
            # inscribed-circle radius and therefore always encloses the
            # image corners (an unscaled radius masks them out when
            # delta_pixel > sqrt(2)).
            self.roi_radius = self.delta_pixel * max(num_rows, num_cols)

        # Set up custom_vjp for _eval and _adj so jax.grad works on them:
        # the VJP of the projector is the back projector and vice versa.
        self._eval = jax.custom_vjp(self._proj_hcb)
        self._eval.defvjp(lambda x: (self._proj_hcb(x), None), lambda _, y: (self._bproj_hcb(y),))  # type: ignore
        self._adj = jax.custom_vjp(self._bproj_hcb)
        self._adj.defvjp(lambda y: (self._bproj_hcb(y), None), lambda _, x: (self._proj_hcb(x),))  # type: ignore

        super().__init__(
            input_shape=input_shape,
            output_shape=output_shape,
            input_dtype=np.float32,
            output_dtype=np.float32,
            adj_fn=self._adj,
            jit=False,
        )

    @staticmethod
    def _proj(
        x: snp.Array,
        angles: snp.Array,
        num_channels: int,
        center_offset: float = 0.0,
        roi_radius: Optional[float] = None,
        geometry: str = "parallel",
        dist_source_detector: Optional[float] = None,
        magnification: Optional[float] = None,
        delta_channel: Optional[float] = None,
        delta_pixel: Optional[float] = None,
    ) -> snp.Array:
        """Forward-project `x` with :func:`svmbir.project` (host-side call)."""
        return snp.array(
            svmbir.project(
                np.array(x),
                np.array(angles),
                num_channels,
                verbose=0,
                center_offset=center_offset,
                roi_radius=roi_radius,
                geometry=geometry,
                dist_source_detector=dist_source_detector,
                magnification=magnification,
                delta_channel=delta_channel,
                delta_pixel=delta_pixel,
            )
        )

    def _proj_hcb(self, x):
        """Forward projection wrapped in :func:`jax.pure_callback`.

        The input is reshaped to the 3D svmbir layout, projected on the
        host, and the result is reshaped to :attr:`output_shape`.
        """
        x = x.reshape(self.svmbir_input_shape)
        # callback wrapper for _proj
        y = jax.pure_callback(
            lambda x: self._proj(
                x,
                self.angles,
                self.num_channels,
                center_offset=self.center_offset,
                roi_radius=self.roi_radius,
                geometry=self.geometry,
                dist_source_detector=self.dist_source_detector,
                magnification=self.magnification,
                delta_channel=self.delta_channel,
                delta_pixel=self.delta_pixel,
            ),
            jax.ShapeDtypeStruct(self.svmbir_output_shape, self.output_dtype),
            x,
        )
        return y.reshape(self.output_shape)

    @staticmethod
    def _bproj(
        y: snp.Array,
        angles: snp.Array,
        num_rows: int,
        num_cols: int,
        center_offset: float = 0.0,
        roi_radius: Optional[float] = None,
        geometry: str = "parallel",
        dist_source_detector: Optional[float] = None,
        magnification: Optional[float] = None,
        delta_channel: Optional[float] = None,
        delta_pixel: Optional[float] = None,
    ) -> snp.Array:
        """Back-project sinogram `y` with :func:`svmbir.backproject` (host-side call)."""
        return snp.array(
            svmbir.backproject(
                np.array(y),
                np.array(angles),
                num_rows=num_rows,
                num_cols=num_cols,
                verbose=0,
                center_offset=center_offset,
                roi_radius=roi_radius,
                geometry=geometry,
                dist_source_detector=dist_source_detector,
                magnification=magnification,
                delta_channel=delta_channel,
                delta_pixel=delta_pixel,
            )
        )

    def _bproj_hcb(self, y):
        """Back projection wrapped in :func:`jax.pure_callback`.

        The sinogram is reshaped to the 3D svmbir layout, back-projected
        on the host, and the result is reshaped to :attr:`input_shape`.
        """
        y = y.reshape(self.svmbir_output_shape)
        # callback wrapper for _bproj
        x = jax.pure_callback(
            lambda y: self._bproj(
                y,
                self.angles,
                self.svmbir_input_shape[1],
                self.svmbir_input_shape[2],
                center_offset=self.center_offset,
                roi_radius=self.roi_radius,
                geometry=self.geometry,
                dist_source_detector=self.dist_source_detector,
                magnification=self.magnification,
                delta_channel=self.delta_channel,
                delta_pixel=self.delta_pixel,
            ),
            jax.ShapeDtypeStruct(self.svmbir_input_shape, self.input_dtype),
            y,
        )
        return x.reshape(self.input_shape)
class SVMBIRExtendedLoss(Loss):
    r"""Extended squared :math:`\ell_2` loss with svmbir tomographic projector.

    Generalization of the weighted squared :math:`\ell_2` loss for a CT
    reconstruction problem,

    .. math::
        \alpha \norm{\mb{y} - A(\mb{x})}_W^2 =
        \alpha \left(\mb{y} - A(\mb{x})\right)^T W \left(\mb{y} -
        A(\mb{x})\right) \;,

    where :math:`A` is a :class:`.XRayTransform`, :math:`\alpha` is the
    scaling parameter and :math:`W` is an instance of
    :class:`scico.linop.Diagonal`. If :math:`W` is ``None``, it is set to
    :class:`scico.linop.Identity`.

    The extended loss differs from a typical weighted squared
    :math:`\ell_2` loss as follows. When `positivity=True`, the prox
    projects onto the non-negative orthant and the loss is infinite if
    any element of the input is negative. When the `is_masked` option of
    the associated :class:`.XRayTransform` is ``True``, the
    reconstruction is computed over a masked region of the image as
    described in class :class:`.XRayTransform`.
    """

    A: XRayTransform
    W: Union[Identity, Diagonal]

    def __init__(
        self,
        *args,
        scale: float = 0.5,
        prox_kwargs: Optional[dict] = None,
        positivity: bool = False,
        W: Optional[Diagonal] = None,
        **kwargs,
    ):
        r"""Initialize a :class:`SVMBIRExtendedLoss` object.

        Args:
            y: Sinogram measurement.
            A: Forward operator.
            scale: Scaling parameter.
            prox_kwargs: Dictionary of arguments passed to the
                :meth:`svmbir.recon` prox routine. Defaults to
                {"maxiter": 1000, "ctol": 0.001}.
            positivity: Enforce positivity in the prox operation. The
                loss is infinite if any element of the input is negative.
            W: Weighting diagonal operator. Must be non-negative.
                If ``None``, defaults to :class:`.Identity`.

        Raises:
            ValueError: If `A` is not an :class:`XRayTransform` or `W`
                has negative entries.
            TypeError: If `W` is neither ``None`` nor a
                :class:`.Diagonal`.
        """
        super().__init__(*args, scale=scale, **kwargs)  # type: ignore

        if not isinstance(self.A, XRayTransform):
            raise ValueError("LinearOperator A must be a radon_svmbir.XRayTransform.")

        self.has_prox = True

        # Merge user options over the defaults, then translate the scico-style
        # names into the keyword names expected by svmbir.recon. Both keys are
        # always present because the defaults supply them.
        merged_prox_args = {"maxiter": 1000, "ctol": 0.001}
        if prox_kwargs is not None:
            merged_prox_args.update(prox_kwargs)
        self.svmbir_prox_args = {
            "max_iterations": merged_prox_args["maxiter"],
            "stop_threshold": merged_prox_args["ctol"],
        }

        self.positivity = positivity

        if W is None:
            self.W = Identity(self.y.shape)
        elif isinstance(W, Diagonal):
            if snp.all(W.diagonal >= 0):
                self.W = W
            else:
                raise ValueError("The weights, W, must be non-negative.")
        else:
            raise TypeError(f"Parameter W must be None or a linop.Diagonal, got {type(W)}.")

    def __call__(self, x: snp.Array) -> float:
        """Evaluate the loss at `x`.

        Returns ``inf`` when `positivity` is enabled and `x` has any
        negative element; otherwise the scaled weighted squared residual.
        """
        if self.positivity and snp.any(x < 0):
            return snp.inf
        return self.scale * (self.W.diagonal * snp.abs(self.y - self.A(x)) ** 2).sum()

    def prox(self, v: snp.Array, lam: float = 1, **kwargs) -> snp.Array:
        r"""Scaled proximal operator of the loss, computed by :func:`svmbir.recon`.

        Computes

        .. math::
            \mathrm{arg\,min}_{\mb{x}} \; \lambda \alpha
            \norm{\mb{y} - A(\mb{x})}_W^2 + \frac{1}{2}
            \norm{\mb{x} - \mb{v}}_2^2 \;,

        subject to non-negativity of :math:`\mb{x}` when `positivity` is
        enabled, and restricted to the masked region when the associated
        :class:`XRayTransform` has `is_masked` enabled.

        Args:
            v: Point at which to evaluate the prox.
            lam: Scaling :math:`\lambda` of the loss in the prox.
            **kwargs: If `v0` is present and not ``None``, it is used as
                the initial image for the svmbir solver.

        Returns:
            The prox result, with shape equal to the operator input shape.

        Raises:
            ValueError: If the svmbir result contains NaNs.
        """
        v = v.reshape(self.A.svmbir_input_shape)
        y = self.y.reshape(self.A.svmbir_output_shape)
        weights = self.W.diagonal.reshape(self.A.svmbir_output_shape)

        # In proximal-map mode svmbir.recon minimizes
        #   1/(2 sigma_y^2) ||y - A x||_W^2 + 1/(2 sigma_p^2) ||x - v||^2 .
        # With sigma_y = 1, matching lam * scale * ||y - A x||_W^2 + 1/2 ||x - v||^2
        # requires sigma_p^2 = 2 * scale * lam (i.e. lam for the default scale 0.5).
        sigma_p = snp.sqrt(2.0 * self.scale * lam)

        v0_arg = kwargs.get("v0")
        if v0_arg is not None:
            v0: Union[float, np.ndarray] = np.reshape(
                np.array(v0_arg), self.A.svmbir_input_shape
            )
        else:
            # svmbir accepts a scalar as a constant initial image.
            v0 = 0.0

        result = svmbir.recon(
            np.array(y),
            np.array(self.A.angles),
            weights=np.array(weights),
            prox_image=np.array(v),
            num_rows=self.A.svmbir_input_shape[1],
            num_cols=self.A.svmbir_input_shape[2],
            center_offset=self.A.center_offset,
            roi_radius=self.A.roi_radius,
            geometry=self.A.geometry,
            dist_source_detector=self.A.dist_source_detector,
            magnification=self.A.magnification,
            delta_channel=self.A.delta_channel,
            delta_pixel=self.A.delta_pixel,
            sigma_p=float(sigma_p),
            sigma_y=1.0,
            positivity=self.positivity,
            verbose=0,
            init_image=v0,
            **self.svmbir_prox_args,
        )
        if np.any(np.isnan(result)):
            raise ValueError("Result contains NaNs.")

        return snp.array(result.reshape(self.A.input_shape))
class SVMBIRSquaredL2Loss(SVMBIRExtendedLoss, SquaredL2Loss):
    r"""Weighted squared :math:`\ell_2` loss with svmbir tomographic projector.

    Weighted squared :math:`\ell_2` loss of a CT reconstruction problem,

    .. math::
        \alpha \norm{\mb{y} - A(\mb{x})}_W^2 =
        \alpha \left(\mb{y} - A(\mb{x})\right)^T W \left(\mb{y} -
        A(\mb{x})\right) \;,

    where :math:`A` is a :class:`.XRayTransform`, :math:`\alpha` is the
    scaling parameter and :math:`W` is an instance of
    :class:`scico.linop.Diagonal`. If :math:`W` is ``None``, it is set to
    :class:`scico.linop.Identity`.

    This is the special case of :class:`SVMBIRExtendedLoss` with
    positivity disabled; the associated :class:`.XRayTransform` must not
    use the `is_masked` option.
    """

    def __init__(
        self,
        *args,
        prox_kwargs: Optional[dict] = None,
        **kwargs,
    ):
        r"""Initialize a :class:`SVMBIRSquaredL2Loss` object.

        Args:
            y: Sinogram measurement.
            A: Forward operator.
            scale: Scaling parameter.
            W: Weighting diagonal operator. Must be non-negative.
                If ``None``, defaults to :class:`.Identity`.
            prox_kwargs: Dictionary of arguments passed to the
                :meth:`svmbir.recon` prox routine. Defaults to
                {"maxiter": 1000, "ctol": 0.001}.

        Raises:
            ValueError: If the `is_masked` option of `A` is enabled.
        """
        # Positivity is always disabled for the plain squared l2 loss; passing
        # `positivity` explicitly via kwargs is therefore an error.
        super().__init__(*args, prox_kwargs=prox_kwargs, positivity=False, **kwargs)

        masked = self.A.is_masked
        if masked:
            raise ValueError(
                "Parameter is_masked must be False for the XRayTransform in SVMBIRSquaredL2Loss."
            )