3D X-ray Transform Comparison

This example shows how to define a SCICO native 3D X-ray transform using ASTRA toolbox conventions and vice versa.

[1]:
import numpy as np

import jax
import jax.numpy as jnp

import scico.linop.xray.astra as astra
from scico import plot
from scico.examples import create_block_phantom
from scico.linop.xray import XRayTransform3D
from scico.util import ContextTimer, Timer
# Configure scico's plot module for use inside a notebook (exact effect is
# defined by scico.plot and is not visible in this file).
plot.config_notebook_plotting()

Create a ground truth image and set detector dimensions.

[2]:
N = 64

# The three volume axes deliberately have different lengths so that any
# mix-up in axis handling shows up as a shape mismatch.
in_shape = (N + 1, N + 2, N + 3)
x = jnp.array(create_block_phantom(in_shape))

# Likewise, the two detector axes have different lengths.
out_shape = (N, N + 1)

Set up SCICO projection.

[3]:
num_angles = 3

# Second rotation angle: num_angles evenly spaced values in [0, 180).
rot_Y = np.linspace(0, 180, num_angles, endpoint=False)
# First rotation angle: a single value (74.0) shared by every view.
rot_X = 90.0 - 16.0

# Broadcast the scalar against the per-view values and pair them up,
# giving an array of shape (num_angles, 2) with one (rot_X, rot_Y) row
# per view.
angles = np.stack(np.broadcast_arrays(rot_X, rot_Y), axis=-1)

# Angles are passed in degrees (degrees=True) with the "XY" sequence string.
matrices = XRayTransform3D.matrices_from_euler_angles(
    in_shape, out_shape, "XY", angles, degrees=True
)

Specify geometry using SCICO conventions and project.

[4]:
num_repeats = 3

# Time the SCICO projector: construction, the first forward and back
# projections (timed on their own; presumably these include one-off
# compilation cost), and the mean over num_repeats further calls.
#
# Every timed computation is followed by jax.block_until_ready so that the
# timer is not stopped before the result has actually been computed.
timer_scico = Timer()
with ContextTimer(timer_scico, "init"):
    H_scico = XRayTransform3D(in_shape, matrices, out_shape)

with ContextTimer(timer_scico, "first_fwd"):
    y_scico = H_scico @ x
    jax.block_until_ready(y_scico)

with ContextTimer(timer_scico, "avg_fwd"):
    for _ in range(num_repeats):
        y_scico = H_scico @ x
        jax.block_until_ready(y_scico)
# Convert the time accumulated over the loop into a per-call mean.
timer_scico.td["avg_fwd"] /= num_repeats

with ContextTimer(timer_scico, "first_back"):
    HTy_scico = H_scico.T @ y_scico
    # Wait for the result, as in every other timed section, so that this
    # measurement covers the full back projection.
    jax.block_until_ready(HTy_scico)

with ContextTimer(timer_scico, "avg_back"):
    for _ in range(num_repeats):
        HTy_scico = H_scico.T @ y_scico
        jax.block_until_ready(HTy_scico)
# Convert the time accumulated over the loop into a per-call mean.
timer_scico.td["avg_back"] /= num_repeats

Convert SCICO geometry to ASTRA and project.

[5]:
# Express the SCICO geometry (volume shape, projection matrices, detector
# shape) in the form taken by the ASTRA-based operator's `vectors` argument.
vectors_from_scico = astra.convert_from_scico_geometry(in_shape, matrices, out_shape)

# Time the ASTRA projector using the same protocol as for the SCICO one:
# construction, first forward/back projection, and the mean over
# num_repeats further calls. Every timed computation is followed by
# jax.block_until_ready so that all measurements are taken the same way.
timer_astra = Timer()
with ContextTimer(timer_astra, "init"):
    H_astra_from_scico = astra.XRayTransform3D(
        input_shape=in_shape, det_count=out_shape, vectors=vectors_from_scico
    )

with ContextTimer(timer_astra, "first_fwd"):
    y_astra_from_scico = H_astra_from_scico @ x
    jax.block_until_ready(y_astra_from_scico)

with ContextTimer(timer_astra, "avg_fwd"):
    for _ in range(num_repeats):
        y_astra_from_scico = H_astra_from_scico @ x
        jax.block_until_ready(y_astra_from_scico)
# Convert the time accumulated over the loop into a per-call mean.
timer_astra.td["avg_fwd"] /= num_repeats

with ContextTimer(timer_astra, "first_back"):
    HTy_astra_from_scico = H_astra_from_scico.T @ y_astra_from_scico
    # Wait for the result, as in every other timed section, so that this
    # measurement covers the full back projection.
    jax.block_until_ready(HTy_astra_from_scico)

with ContextTimer(timer_astra, "avg_back"):
    for _ in range(num_repeats):
        HTy_astra_from_scico = H_astra_from_scico.T @ y_astra_from_scico
        jax.block_until_ready(HTy_astra_from_scico)
# Convert the time accumulated over the loop into a per-call mean.
timer_astra.td["avg_back"] /= num_repeats

Specify geometry with ASTRA conventions and project.

[6]:
# Draw the random projection angles from a seeded generator so that a fresh
# "Restart & Run All" reproduces the same geometry and figures.
rng = np.random.default_rng(12345)
# Angles are drawn over [0, 180) degrees and converted to radians, the unit
# in which astra.angle_to_vector is documented to take its angles.
angles = np.deg2rad(rng.uniform(0.0, 180.0, size=num_angles))
det_spacing = [1.0, 1.0]
vectors = astra.angle_to_vector(det_spacing, angles)

H_astra = astra.XRayTransform3D(input_shape=in_shape, det_count=out_shape, vectors=vectors)

# Forward projection of the phantom, followed by back projection of the
# resulting sinogram.
y_astra = H_astra @ x
HTy_astra = H_astra.T @ y_astra

Convert ASTRA geometry to SCICO and project.

[7]:
# Convert the volume and projection geometry held by H_astra into the form
# taken by the SCICO operator (passed below in the position used for the
# projection matrices in the earlier SCICO cell).
# NOTE(review): _astra_to_scico_geometry is underscore-prefixed, i.e. a
# private helper of scico.linop.xray.astra -- confirm whether a public
# conversion function should be used here instead.
P_from_astra = astra._astra_to_scico_geometry(H_astra.vol_geom, H_astra.proj_geom)
H_scico_from_astra = XRayTransform3D(in_shape, P_from_astra, out_shape)

# Forward projection of the phantom, followed by back projection of the
# resulting projections.
y_scico_from_astra = H_scico_from_astra @ x
HTy_scico_from_astra = H_scico_from_astra.T @ y_scico_from_astra

Print timing results.

[8]:
# Operator construction times, one line per implementation.
print(f"init         astra    {timer_astra.td['init']:.2e} s")
print(f"init         scico    {timer_scico.td['init']:.2e} s")
print("")

# (timer, label) pairs, in the order in which the rows are reported.
timers_and_labels = ((timer_astra, "astra"), (timer_scico, "scico"))

# One group of rows per (first/avg, fwd/back) combination; the timer keys
# are of the form "<first|avg>_<fwd|back>". Groups are separated by a
# blank line.
for tstr in ("first", "avg"):
    for dstr in ("fwd", "back"):
        for timer, pstr in timers_and_labels:
            print(f"{tstr:5s}  {dstr:4s}  {pstr}    {timer.td[tstr + '_' + dstr]:.2e} s")
        print()
init         astra    1.04e-04 s
init         scico    9.71e-01 s

first  fwd   astra    5.27e-02 s
first  fwd   scico    4.84e-02 s

first  back  astra    4.84e-02 s
first  back  scico    6.53e-01 s

avg    fwd   astra    4.67e-02 s
avg    fwd   scico    4.90e-02 s

avg    back  astra    4.59e-02 s
avg    back  scico    3.31e-02 s

Show projections.

[9]:
# Figure 1: operators defined using SCICO conventions.
# Left column shows slices 0..2 along the first axis of the SCICO result,
# right column shows slices 0..2 along the second axis of the ASTRA result.
# Only the top panel of each column carries a title.
fig, ax = plot.subplots(nrows=3, ncols=2, figsize=(8, 10))
for view in range(3):
    title_kw = {"title": "SCICO projections"} if view == 0 else {}
    plot.imview(y_scico[view], cbar=None, fig=fig, ax=ax[view, 0], **title_kw)
for view in range(3):
    title_kw = {"title": "ASTRA projections"} if view == 0 else {}
    plot.imview(y_astra_from_scico[:, view], cbar=None, fig=fig, ax=ax[view, 1], **title_kw)
fig.suptitle("Using SCICO conventions")
fig.tight_layout()
fig.show()

# Figure 2: operators defined using ASTRA conventions, same layout.
fig, ax = plot.subplots(nrows=3, ncols=2, figsize=(8, 10))
for view in range(3):
    title_kw = {"title": "SCICO projections"} if view == 0 else {}
    plot.imview(y_scico_from_astra[view], cbar=None, fig=fig, ax=ax[view, 0], **title_kw)
for view in range(3):
    title_kw = {"title": "ASTRA projections"} if view == 0 else {}
    plot.imview(y_astra[:, view], cbar=None, fig=fig, ax=ax[view, 1], **title_kw)
fig.suptitle("Using ASTRA conventions")
fig.tight_layout()
fig.show()
../_images/examples_ct_projector_comparison_3d_17_0.png
../_images/examples_ct_projector_comparison_3d_17_1.png

Show back projections.

[10]:
# Figure 1: back projections from the operators defined using SCICO
# conventions. Each panel shows slice N // 2 along the first axis of the
# corresponding back-projected volume.
fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(8, 5))
panels = (
    (HTy_scico, "SCICO back projection"),
    (HTy_astra_from_scico, "ASTRA back projection"),
)
for col, (volume, label) in enumerate(panels):
    plot.imview(volume[N // 2], title=label, cbar=None, fig=fig, ax=ax[col])
fig.suptitle("Using SCICO conventions")
fig.tight_layout()
fig.show()

# Figure 2: back projections from the operators defined using ASTRA
# conventions, same layout.
fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(8, 5))
panels = (
    (HTy_scico_from_astra, "SCICO back projection"),
    (HTy_astra, "ASTRA back projection"),
)
for col, (volume, label) in enumerate(panels):
    plot.imview(volume[N // 2], title=label, cbar=None, fig=fig, ax=ax[col])
fig.suptitle("Using ASTRA conventions")
fig.tight_layout()
fig.show()
../_images/examples_ct_projector_comparison_3d_19_0.png
../_images/examples_ct_projector_comparison_3d_19_1.png