# -*- coding: utf-8 -*-
# Copyright (C) 2022-2023 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.
"""Functionality to generate training data for Flax example scripts.
Computation is distributed via ray (if available) or jax to reduce
processing time.
"""
import os
from time import time
from typing import Callable, List, Tuple, Union
import numpy as np
import jax
import jax.numpy as jnp
try:
import xdesign # noqa: F401
except ImportError:
have_xdesign = False
else:
have_xdesign = True
try:
import ray # noqa: F401
except ImportError:
have_ray = False
else:
have_ray = True
if have_xdesign:
from xdesign import Foam, SimpleMaterial, UnitCircle, discrete_phantom
from scico.linop import CircularConvolve
from scico.numpy import Array
try:
import astra # noqa: F401
except ImportError:
have_astra = False
else:
have_astra = True
if have_astra:
from scico.linop.xray.astra import XRayTransform
# Arbitrary process count: only applies if GPU is not available.
# NOTE(review): this assignment replaces any XLA_FLAGS value that is already
# present in the environment when this module is imported -- confirm that
# discarding externally supplied flags is intended.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
if have_xdesign:

    class Foam2(UnitCircle):
        """Foam-like phantom built from circles with different attenuations.

        The phantom is a disc of radius 0.5 filled with material of
        attenuation `attn1`, onto which circles are sprinkled in two
        passes (see :meth:`__init__`).
        """

        def __init__(
            self,
            size_range: Union[float, List[float]] = [0.05, 0.01],
            gap: float = 0,
            porosity: float = 1,
            attn1: float = 1.0,
            attn2: float = 10.0,
        ):
            """Construct the background disc and sprinkle circles onto it.

            Two sprinkle passes are made with the same `size_range` and
            `gap`: the first uses material `attn2` with a maximum density
            of ``porosity / 2``, the second uses a fixed attenuation of 20
            with a maximum density of `porosity`.

            Args:
                size_range: The radius, or range of radius, of the
                    circles to be added. Default: [0.05, 0.01].
                gap: Minimum distance between circle boundaries.
                    Default: 0.
                porosity: Target porosity. Must be a value in the
                    closed interval [0, 1]. Default: 1.
                attn1: Mass attenuation parameter for the background
                    disc material. Default: 1.
                attn2: Mass attenuation parameter for the circles added
                    in the first sprinkle pass. Default: 10.

            Raises:
                ValueError: If `porosity` is outside [0, 1].
            """
            super().__init__(radius=0.5, material=SimpleMaterial(attn1))
            if porosity < 0 or porosity > 1:
                raise ValueError("Porosity must be in the range [0, 1].")

            # The default is a module-lifetime list object; work on a private
            # copy so that nothing done downstream can leak between instances.
            if isinstance(size_range, list):
                size_range = list(size_range)

            # The value returned by each call is not needed (the original code
            # summed the two return values and discarded the result).
            self.sprinkle(
                300, size_range, gap, material=SimpleMaterial(attn2), max_density=porosity / 2.0
            )
            # NOTE(review): this pass uses a hard-coded attenuation of 20 rather
            # than one of the attn1/attn2 parameters -- confirm this is intended.
            self.sprinkle(
                300, size_range, gap, material=SimpleMaterial(20), max_density=porosity
            )
def generate_foam2_images(seed: float, size: int, ndata: int) -> Array:
    """Create a batch of normalized :class:`Foam2` phantom images.

    The NumPy global random generator is seeded with `seed`, then
    `ndata` :class:`Foam2` phantoms are discretized onto a `size` x
    `size` grid. Each image is divided by its own maximum value.

    Args:
        seed: Seed for data generation.
        size: Size of image to generate.
        ndata: Number of images to generate.

    Returns:
        Array of generated data, with shape (`ndata`, `size`, `size`, 1).

    Raises:
        RuntimeError: If package xdesign is not available.
    """
    if not have_xdesign:
        raise RuntimeError("Package xdesign is required for use of this function.")

    np.random.seed(seed)
    batch = np.zeros((ndata, size, size, 1))
    for idx in range(ndata):
        phantom = Foam2(size_range=[0.075, 0.0025], gap=1e-3, porosity=1)
        batch[idx, ..., 0] = discrete_phantom(phantom, size=size)

    # Per-image normalization: the maximum is taken over the two spatial axes.
    peak = np.max(batch, axis=(1, 2), keepdims=True)
    return batch / peak
def generate_foam1_images(seed: float, size: int, ndata: int) -> Array:
    """Create a batch of `xdesign` single-attenuation foam images.

    The NumPy global random generator is seeded with `seed`, then
    `ndata` `xdesign` ``Foam`` phantoms are discretized onto a `size` x
    `size` grid. No normalization is applied to the result.

    Args:
        seed: Seed for data generation.
        size: Size of image to generate.
        ndata: Number of images to generate.

    Returns:
        Array of generated data, with shape (`ndata`, `size`, `size`, 1).

    Raises:
        RuntimeError: If package xdesign is not available.
    """
    if not have_xdesign:
        raise RuntimeError("Package xdesign is required for use of this function.")

    np.random.seed(seed)
    batch = np.zeros((ndata, size, size, 1))
    for idx in range(ndata):
        phantom = Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1)
        batch[idx, ..., 0] = discrete_phantom(phantom, size=size)
    return batch
def generate_ct_data(
    nimg: int,
    size: int,
    nproj: int,
    imgfunc: Callable = generate_foam2_images,
    seed: int = 1234,
    verbose: bool = False,
    test_flag: bool = False,
    prefer_ray: bool = True,
) -> Tuple[Array, ...]:
    """Generate batch of computed tomography (CT) data.

    Generate batch of CT data for training of machine learning network
    models. Images are produced by `imgfunc` (via ray when available and
    preferred), clipped to [0, 1], and split evenly across the jax
    devices to compute sinograms and filtered back projections.

    Args:
        nimg: Number of images to generate. Must be divisible by the
            number of jax devices when more than one is present.
        size: Size of reconstruction images.
        nproj: Number of CT views.
        imgfunc: Function for generating input images (e.g. foams).
            It is called as ``imgfunc(seed, size, nimg)``.
        seed: Seed for data generation.
        verbose: Flag indicating whether to print status messages.
            Default: ``False``.
        test_flag: Flag to indicate if running in testing mode. Testing
            mode requires a different initialization of ray. Default:
            ``False``.
        prefer_ray: Use ray for distributed processing if available.
            Default: ``True``.

    Returns:
        tuple: A tuple (img, sino, fbp) containing:

           - **img** : (:class:`jax.Array`): Generated foam images,
             clipped to [0, 1].
           - **sino** : (:class:`jax.Array`): Corresponding sinograms,
             divided by `size`.
           - **fbp** : (:class:`jax.Array`) Corresponding filtered back
             projections, clipped to [0, 1]. These are computed from
             the sinograms before division by `size`.

    Raises:
        RuntimeError: If package astra is not available.
        ValueError: If `nimg` cannot be split evenly across the
            available jax devices.
    """
    if not have_astra:
        raise RuntimeError("Package astra is required for use of this function.")

    # Fail before the (potentially slow) image generation rather than at the
    # sharding reshape further down.
    nproc = jax.device_count()
    if nproc > 1 and nimg % nproc > 0:
        raise ValueError(
            f"Number of images to generate ({nimg}) "
            f"must be divisible by the number of available devices ({nproc})"
        )

    # Generate input data.
    start_time = time()
    if have_ray and prefer_ray:
        img = ray_distributed_data_generation(imgfunc, size, nimg, seed, test_flag)
    else:
        img = imgfunc(seed, size, nimg)
    time_dtgen = time() - start_time

    # Clip to [0,1] range. Bounds are passed positionally because the
    # a_min/a_max keyword names are deprecated in newer jax releases.
    img = jnp.clip(img, 0, 1)

    # Shard array: one leading-axis slice per device.
    imgshd = img.reshape((nproc, -1, size, size, 1))

    # Configure a CT projection operator to generate synthetic measurements.
    angles = np.linspace(0, jnp.pi, nproj)  # evenly spaced projection angles
    gt_sh = (size, size)
    detector_spacing = 1
    A = XRayTransform(gt_sh, detector_spacing, size, angles)  # Radon transform operator

    # Compute sinograms in parallel.
    a_map = lambda v: jnp.atleast_3d(A @ v.squeeze())
    start_time = time()
    sinoshd = jax.pmap(lambda i: jax.lax.map(a_map, imgshd[i]))(jnp.arange(nproc))
    # jax dispatches asynchronously; wait so the timing covers the computation.
    sinoshd.block_until_ready()
    time_sino = time() - start_time
    sino = sinoshd.reshape((-1, nproj, size, 1))
    # Normalize sinogram
    sino = sino / size

    # Compute filtered back projection in parallel (from the sharded,
    # un-normalized sinograms).
    afbp_map = lambda v: jnp.atleast_3d(A.fbp(v.squeeze()))
    start_time = time()
    fbpshd = jax.pmap(lambda i: jax.lax.map(afbp_map, sinoshd[i]))(jnp.arange(nproc))
    fbpshd.block_until_ready()
    time_fbp = time() - start_time
    # Clip to [0,1] range.
    fbpshd = jnp.clip(fbpshd, 0, 1)
    fbp = fbpshd.reshape((-1, size, size, 1))

    if verbose:  # pragma: no cover
        # Public accessor for the platform name of the default backend.
        platform = jax.default_backend()
        print(f"{'Platform':26s}{':':4s}{platform}")
        print(f"{'Device count':26s}{':':4s}{jax.device_count()}")
        print(f"{'Data generation':19s}{'time[s]:':10s}{time_dtgen:>7.2f}")
        print(f"{'Sinogram':19s}{'time[s]:':10s}{time_sino:>7.2f}")
        print(f"{'FBP':19s}{'time[s]:':10s}{time_fbp:>7.2f}")

    return img, sino, fbp
def generate_blur_data(
    nimg: int,
    size: int,
    blur_kernel: Array,
    noise_sigma: float,
    imgfunc: Callable,
    seed: int = 4321,
    verbose: bool = False,
    test_flag: bool = False,
    prefer_ray: bool = True,
) -> Tuple[Array, ...]:
    """Generate batch of blurred data.

    Generate batch of blurred data for training of machine learning
    network models. Images are produced by `imgfunc` (via ray when
    available and preferred), clipped to [0, 1], blurred by circular
    convolution with `blur_kernel`, normalized per image, and corrupted
    with additive Gaussian noise.

    Args:
        nimg: Number of images to generate. Must be divisible by the
            number of jax devices when more than one is present.
        size: Size of reconstruction images.
        blur_kernel: Kernel for blurring the generated images.
        noise_sigma: Level of additive Gaussian noise to apply.
        imgfunc: Function to generate foams. It is called as
            ``imgfunc(seed, size, nimg)``.
        seed: Seed for data generation. The same value also seeds the
            jax random key used to draw the noise.
        verbose: Flag indicating whether to print status messages.
            Default: ``False``.
        test_flag: Flag to indicate if running in testing mode. Testing
            mode requires a different initialization of ray.
            Default: ``False``.
        prefer_ray: Use ray for distributed processing if available.
            Default: ``True``.

    Returns:
        tuple: A tuple (img, blurn) containing:

           - **img** : Generated foam images, clipped to [0, 1].
           - **blurn** : Corresponding blurred and noisy images,
             clipped to [0, 1].

    Raises:
        ValueError: If `nimg` cannot be split evenly across the
            available jax devices.
    """
    # Fail before the (potentially slow) image generation rather than at the
    # sharding reshape further down.
    nproc = jax.device_count()
    if nproc > 1 and nimg % nproc > 0:
        raise ValueError(
            f"Number of images to generate ({nimg}) "
            f"must be divisible by the number of available devices ({nproc})"
        )

    # Generate input data.
    start_time = time()
    if have_ray and prefer_ray:
        img = ray_distributed_data_generation(imgfunc, size, nimg, seed, test_flag)
    else:
        img = imgfunc(seed, size, nimg)
    time_dtgen = time() - start_time

    # Clip to [0,1] range. Bounds are passed positionally because the
    # a_min/a_max keyword names are deprecated in newer jax releases.
    img = jnp.clip(img, 0, 1)

    # Shard array: one leading-axis slice per device.
    imgshd = img.reshape((nproc, -1, size, size, 1))

    # Configure blur operator
    ishape = (size, size)
    A = CircularConvolve(h=blur_kernel, input_shape=ishape)

    # Compute blurred images in parallel
    a_map = lambda v: jnp.atleast_3d(A @ v.squeeze())
    start_time = time()
    blurshd = jax.pmap(lambda i: jax.lax.map(a_map, imgshd[i]))(jnp.arange(nproc))
    # jax dispatches asynchronously; wait so the timing covers the computation.
    blurshd.block_until_ready()
    time_blur = time() - start_time
    blur = blurshd.reshape((-1, size, size, 1))

    # Normalize blurred images (each image by its own maximum)
    blur = blur / jnp.max(blur, axis=(1, 2), keepdims=True)

    # Add Gaussian noise
    key = jax.random.PRNGKey(seed)
    noise = jax.random.normal(key, blur.shape)
    blurn = blur + noise_sigma * noise
    # Clip to [0,1] range.
    blurn = jnp.clip(blurn, 0, 1)

    if verbose:  # pragma: no cover
        # Public accessor for the platform name of the default backend.
        platform = jax.default_backend()
        print(f"{'Platform':26s}{':':4s}{platform}")
        print(f"{'Device count':26s}{':':4s}{jax.device_count()}")
        print(f"{'Data generation':19s}{'time[s]:':10s}{time_dtgen:>7.2f}")
        print(f"{'Blur generation':19s}{'time[s]:':10s}{time_blur:>7.2f}")

    return img, blurn
def distributed_data_generation(
    imgenf: Callable, size: int, nimg: int, sharded: bool = True
) -> Array:
    """Data generation distributed among processes using jax.

    The work is split evenly across all jax devices with
    :func:`jax.pmap`. Device ``k`` calls
    ``imgenf(k, size, nimg // ndevices)``, i.e. the device index is used
    as the seed, and `size` and the per-device image count are passed as
    static (non-traced) arguments.

    Args:
        imgenf: Function for batch-data generation. It is called as
            ``imgenf(seed, size, ndata)``.
        size: Size of image to generate.
        nimg: Number of images to generate. Must be divisible by the
            number of jax devices when more than one is present.
        sharded: Flag to indicate if data is to be returned as the
            chunks generated by each process (``True``) or
            consolidated into a single batch of shape
            (-1, `size`, `size`, 1) (``False``).
            Default: ``True``.

    Returns:
        Array of generated data.

    Raises:
        ValueError: If `nimg` is not divisible by the number of devices.
    """
    nproc = jax.device_count()
    if nproc > 1 and nimg % nproc > 0:
        raise ValueError("Number of images to generate must be divisible by the number of devices")

    seeds = jnp.arange(nproc)
    ndata_per_proc = nimg // nproc

    # Arguments 1 and 2 (size, count) are broadcast to every device unchanged.
    imgs = jax.pmap(imgenf, static_broadcasted_argnums=(1, 2))(seeds, size, ndata_per_proc)

    if not sharded:
        imgs = imgs.reshape((-1, size, size, 1))
    return imgs
def ray_distributed_data_generation(
    imgenf: Callable, size: int, nimg: int, seedg: float = 123, test_flag: bool = False
) -> Array:
    """Data generation distributed among processes using ray.

    Ray is initialized on entry and always shut down on exit, including
    when an error is raised. The work is split evenly across half of the
    CPUs that ray reports as available (at least one, and never more
    than `nimg`). Worker ``k`` calls
    ``imgenf(k + seedg, size, nimg // nworkers)``.

    Args:
        imgenf: Function for batch-data generation. It is called as
            ``imgenf(seed, size, ndata)``.
        size: Size of image to generate.
        nimg: Number of images to generate. Must be divisible by the
            number of workers when more than one is used.
        seedg: Base seed for data generation. Default: 123.
        test_flag: Flag to indicate if running in testing mode. Testing
            mode requires a different initialization of ray. Default:
            ``False``.

    Returns:
        Array of generated data.

    Raises:
        RuntimeError: If package ray is not available.
        ValueError: If `nimg` is not divisible by the number of workers.
    """
    if not have_ray:
        raise RuntimeError("Package ray is required for use of this function.")

    # In testing mode an already-initialized ray instance is tolerated.
    ray.init(ignore_reinit_error=test_flag)

    try:

        @ray.remote
        def data_gen(seed, size, ndata, imgf):
            return imgf(seed, size, ndata)

        ar = ray.available_resources()
        # Usage of half available CPU resources. Fall back to a single
        # worker if no "CPU" entry is reported.
        nproc = max(int(ar.get("CPU", 1)) // 2, 1)
        # No more workers than images, but always at least one so that the
        # per-worker count below is well defined even for nimg == 0.
        nproc = max(min(nproc, nimg), 1)
        if nproc > 1 and nimg % nproc > 0:
            raise ValueError(
                f"Number of images to generate ({nimg}) "
                f"must be divisible by the number of available devices ({nproc})"
            )

        ndata_per_proc = nimg // nproc
        ray_return = ray.get(
            [data_gen.remote(seed + seedg, size, ndata_per_proc, imgenf) for seed in range(nproc)]
        )
        imgs = np.vstack(ray_return)
    finally:
        # Release ray resources on every exit path; otherwise a failure here
        # would leave ray initialized and break the next ray.init() call.
        ray.shutdown()

    return imgs