3D X-ray Transform Comparison

This example shows how to define a SCICO-native 3D X-ray transform using ASTRA toolbox conventions, and vice versa. It then compares the resulting projections, back projections, and run times.

[1]:
import numpy as np

import jax
import jax.numpy as jnp

import komplot as kplt

import scico.linop.xray.astra as astra
from scico.examples import create_block_phantom
from scico.linop.xray import XRayTransform3D
from scico.util import ContextTimer, Timer
# Apply komplot's notebook plotting configuration once, before any figures are created.
kplt.config_notebook_plotting()

Create a ground truth image and set detector dimensions.

[2]:
N = 64

# Deliberately non-cubic volume and non-square detector: any axis-ordering
# mix-up between the SCICO and ASTRA conventions will then be visible.
in_shape = (N + 1, N + 2, N + 3)
out_shape = (N, N + 1)

# Ground-truth block phantom, converted to a JAX array for the projectors.
phantom = create_block_phantom(in_shape)
x = jnp.array(phantom)

Set up the SCICO projection geometry by computing projection matrices from Euler angles.

[3]:
num_angles = 3

# Each row of `angles` is a (rot_X, rot_Y) pair in degrees for the "XY" Euler
# sequence: one fixed X rotation of 90 - 16 = 74 degrees, combined with
# `num_angles` equally spaced Y rotations covering [0, 180).
rot_X = 90.0 - 16.0
rot_Y = np.linspace(0, 180, num_angles, endpoint=False)
angles = np.column_stack((np.full(num_angles, rot_X), rot_Y))

matrices = XRayTransform3D.matrices_from_euler_angles(
    in_shape, out_shape, "XY", angles, degrees=True
)

Construct the transform using SCICO conventions, and time its forward and back projections.

[4]:
num_repeats = 3

# JAX dispatches work asynchronously, so every timed section blocks on its
# result; otherwise only the dispatch time would be recorded. The "first_*"
# timings include one-time overhead (e.g. compilation), while the "avg_*"
# timings are averaged over `num_repeats` warm calls.
timer_scico = Timer()
with ContextTimer(timer_scico, "init"):
    H_scico = XRayTransform3D(in_shape, matrices, out_shape)

with ContextTimer(timer_scico, "first_fwd"):
    y_scico = H_scico @ x
    jax.block_until_ready(y_scico)

with ContextTimer(timer_scico, "avg_fwd"):
    for _ in range(num_repeats):
        y_scico = H_scico @ x
        jax.block_until_ready(y_scico)
timer_scico.td["avg_fwd"] /= num_repeats

with ContextTimer(timer_scico, "first_back"):
    HTy_scico = H_scico.T @ y_scico
    # Block here too, consistent with the other timed sections.
    jax.block_until_ready(HTy_scico)

with ContextTimer(timer_scico, "avg_back"):
    for _ in range(num_repeats):
        HTy_scico = H_scico.T @ y_scico
        jax.block_until_ready(HTy_scico)
timer_scico.td["avg_back"] /= num_repeats

Convert SCICO geometry to ASTRA and project.

[5]:
# Express the SCICO projection matrices as ASTRA vector geometry.
vectors_from_scico = astra.convert_from_scico_geometry(in_shape, matrices, out_shape)

# Same timing protocol as for SCICO: each timed section blocks on its result,
# so the measured time covers the computation rather than just its dispatch.
timer_astra = Timer()
with ContextTimer(timer_astra, "init"):
    H_astra_from_scico = astra.XRayTransform3D(
        input_shape=in_shape, det_count=out_shape, vectors=vectors_from_scico
    )

with ContextTimer(timer_astra, "first_fwd"):
    y_astra_from_scico = H_astra_from_scico @ x
    jax.block_until_ready(y_astra_from_scico)

with ContextTimer(timer_astra, "avg_fwd"):
    for _ in range(num_repeats):
        y_astra_from_scico = H_astra_from_scico @ x
        jax.block_until_ready(y_astra_from_scico)
timer_astra.td["avg_fwd"] /= num_repeats

with ContextTimer(timer_astra, "first_back"):
    HTy_astra_from_scico = H_astra_from_scico.T @ y_astra_from_scico
    # Block here too, consistent with the other timed sections.
    jax.block_until_ready(HTy_astra_from_scico)

with ContextTimer(timer_astra, "avg_back"):
    for _ in range(num_repeats):
        HTy_astra_from_scico = H_astra_from_scico.T @ y_astra_from_scico
        jax.block_until_ready(HTy_astra_from_scico)
timer_astra.td["avg_back"] /= num_repeats

Specify geometry with ASTRA conventions and project.

[6]:
# Random projection angles come from a seeded generator, so the notebook is
# reproducible under Restart & Run All. A separate name is used so that the
# SCICO `angles` array from earlier is not overwritten.
# NOTE(review): scico's `astra.angle_to_vector` is documented as taking angles
# in radians, so a half turn [0, pi) is sampled rather than [0, 180). Confirm
# this against the installed scico version.
rng = np.random.default_rng(seed=1234)
angles_astra = rng.uniform(0.0, np.pi, size=num_angles)
det_spacing = [1.0, 1.0]
vectors = astra.angle_to_vector(det_spacing, angles_astra)

H_astra = astra.XRayTransform3D(input_shape=in_shape, det_count=out_shape, vectors=vectors)

# Forward project the phantom, then back project the resulting projections.
y_astra = H_astra @ x
HTy_astra = H_astra.T @ y_astra

Convert ASTRA geometry to SCICO and project.

[7]:
# Recover SCICO projection matrices from the ASTRA volume/projection geometry
# objects held by H_astra, then build the equivalent SCICO transform.
# NOTE(review): `_astra_to_scico_geometry` is underscore-prefixed, i.e. private
# by Python convention, and may change without notice. Check whether the
# installed scico version offers a public equivalent.
P_from_astra = astra._astra_to_scico_geometry(H_astra.vol_geom, H_astra.proj_geom)
H_scico_from_astra = XRayTransform3D(in_shape, P_from_astra, out_shape)

# Forward and back project with the converted geometry.
y_scico_from_astra = H_scico_from_astra @ x
HTy_scico_from_astra = H_scico_from_astra.T @ y_scico_from_astra

Print timing results.

[8]:
# Table of (label, timer) pairs, printed in the same order for every metric.
timings = (("astra", timer_astra), ("scico", timer_scico))

for label, tmr in timings:
    print(f"init         {label}    {tmr.td['init']:.2e} s")
print("")

# One group per (first|avg, fwd|back) metric, with a blank line after each group.
for tstr in ("first", "avg"):
    for dstr in ("fwd", "back"):
        key = f"{tstr}_{dstr}"
        for pstr, timer in timings:
            print(f"{tstr:5s}  {dstr:4s}  {pstr}    {timer.td[key]:.2e} s")
        print()
init         astra    2.46e-03 s
init         scico    1.87e-04 s

first  fwd   astra    3.87e-02 s
first  fwd   scico    1.78e+00 s

first  back  astra    2.92e-02 s
first  back  scico    1.32e+00 s

avg    fwd   astra    2.52e-02 s
avg    fwd   scico    5.63e-02 s

avg    back  astra    2.99e-02 s
avg    back  scico    5.32e-02 s

Show projections.

[9]:
def plot_projection_comparison(scico_proj, astra_proj, suptitle):
    """Show three SCICO projections next to the corresponding ASTRA projections.

    Args:
        scico_proj: SCICO projection stack; view ``i`` is ``scico_proj[i]``.
        astra_proj: ASTRA projection stack; view ``i`` is ``astra_proj[:, i]``.
        suptitle: Title for the whole figure.
    """
    nrows = 3
    fig, ax = kplt.subplots(nrows=nrows, ncols=2, sharex=True, sharey=True, figsize=(8, 10))
    for row in range(nrows):
        # Column titles appear on the top row only.
        scico_kw = {"title": "SCICO projections"} if row == 0 else {}
        astra_kw = {"title": "ASTRA projections"} if row == 0 else {}
        kplt.imview(scico_proj[row], cmap="viridis", show_cbar=None, ax=ax[row, 0], **scico_kw)
        kplt.imview(astra_proj[:, row], cmap="viridis", show_cbar=None, ax=ax[row, 1], **astra_kw)
    fig.suptitle(suptitle)
    fig.tight_layout()
    fig.show()


plot_projection_comparison(y_scico, y_astra_from_scico, "Using SCICO conventions")
plot_projection_comparison(y_scico_from_astra, y_astra, "Using ASTRA conventions")
../_images/examples_ct_projector_comparison_3d_17_0.png
../_images/examples_ct_projector_comparison_3d_17_1.png

Show back projections.

[10]:
def plot_backprojection_comparison(scico_bp, astra_bp, suptitle):
    """Show the same axis-0 slice of a SCICO and an ASTRA back projection side by side.

    Args:
        scico_bp: SCICO back-projected volume.
        astra_bp: ASTRA back-projected volume.
        suptitle: Title for the whole figure.
    """
    # Central slice along the first volume axis. For the even N used here this
    # is the same index as N // 2, and it stays centred for any volume size.
    idx = scico_bp.shape[0] // 2
    fig, ax = kplt.subplots(nrows=1, ncols=2, sharex=True, sharey=True, figsize=(8, 5))
    kplt.imview(
        scico_bp[idx], title="SCICO back projection", cmap="viridis", show_cbar=None, ax=ax[0]
    )
    kplt.imview(
        astra_bp[idx], title="ASTRA back projection", cmap="viridis", show_cbar=None, ax=ax[1]
    )
    fig.suptitle(suptitle)
    fig.tight_layout()
    fig.show()


plot_backprojection_comparison(HTy_scico, HTy_astra_from_scico, "Using SCICO conventions")
plot_backprojection_comparison(HTy_scico_from_astra, HTy_astra, "Using ASTRA conventions")
../_images/examples_ct_projector_comparison_3d_19_0.png
../_images/examples_ct_projector_comparison_3d_19_1.png