Source code for scico.flax.examples.data_preprocessing

# -*- coding: utf-8 -*-
# Copyright (C) 2022-2023 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.

"""Image manipulation utils."""

import glob
import math
import os
import tarfile
import tempfile
from typing import Any, Callable, Optional, Tuple, Union

import numpy as np

import jax.numpy as jnp

import imageio

from scico import util
from scico.examples import rgb2gray
from scico.flax.train.typed_dict import DataSetDict
from scico.linop import CircularConvolve, LinearOperator
from scico.numpy import Array
from scico.typing import Shape

from .typed_dict import ConfigImageSetDict


def rotation90(img: Array) -> Array:
    """Rotate an image, or a batch of images, by 90 degrees.

    Rotate an image or a batch of images by 90 degrees counterclockwise.
    An image is an array with size H x W x C with H and W spatial
    dimensions and C number of channels. A batch of images is an array
    with size N x H x W x C with N number of images.

    Args:
        img: The array to be rotated.

    Returns:
        An image, or batch of images, rotated by 90 degrees
        counterclockwise.
    """
    if img.ndim < 4:
        return np.swapaxes(img, 0, 1)
    else:
        return np.swapaxes(img, 1, 2)

def flip(img: Array) -> Array:
    """Horizontal flip of an image or a batch of images.

    Horizontally flip an image or a batch of images. An image is an
    array with size H x W x C with H and W spatial dimensions and C
    number of channels. A batch of images is an array with size
    N x H x W x C with N number of images.

    Args:
        img: The array to be flipped.

    Returns:
        An image, or batch of images, flipped horizontally.
    """
    if img.ndim < 4:
        return img[:, ::-1, ...]
    else:
        return img[..., ::-1, :]

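# Illustrative usage sketch (not part of the module source): applying
# rotation90 and flip to a random NHWC batch; shapes and values are
# arbitrary and assume the definitions above are available.
#
#   batch = np.random.rand(4, 32, 48, 3)   # N x H x W x C
#   rotation90(batch).shape                # -> (4, 48, 32, 3): H and W swapped
#   flip(batch).shape                      # -> (4, 32, 48, 3): W axis reversed
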
class CenterCrop:
    """Crop central part of an image to a specified size.

    Crop central part of an image. An image is an array with size
    H x W x C with H and W spatial dimensions and C number of channels.
    """

    def __init__(self, output_size: Union[Shape, int]):
        """
        Args:
            output_size: Desired output size. If int, square crop is
                made.
        """
        # assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size: Shape = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, image: Array) -> Array:
        """Apply center crop.

        Args:
            image: The array to be cropped.

        Returns:
            The cropped image.
        """
        h, w = image.shape[:2]
        new_h, new_w = self.output_size
        top = (h - new_h) // 2
        left = (w - new_w) // 2
        image = image[top : top + new_h, left : left + new_w]
        return image

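# Illustrative usage sketch (not part of the module source):
# center-cropping a single HWC image to a square patch.
#
#   img = np.random.rand(64, 96, 3)
#   crop = CenterCrop(32)
#   crop(img).shape   # -> (32, 32, 3)
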
class PositionalCrop:
    """Crop an image from a given corner to a specified size.

    Crop an image from a given corner. An image is an array with size
    H x W x C with H and W spatial dimensions and C number of channels.
    """

    def __init__(self, output_size: Union[Shape, int]):
        """
        Args:
            output_size: Desired output size. If int, square crop is
                made.
        """
        # assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size: Shape = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, image: Array, top: int, left: int) -> Array:
        """Apply positional crop.

        Args:
            image: The array to be cropped.
            top: Vertical top coordinate of corner to start cropping.
            left: Horizontal left coordinate of corner to start
                cropping.

        Returns:
            The cropped image.
        """
        h, w = image.shape[:2]
        new_h, new_w = self.output_size
        image = image[top : top + new_h, left : left + new_w]
        return image

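# Illustrative usage sketch (not part of the module source): cropping a
# 24x24 patch with its top-left corner at (top=5, left=10).
#
#   img = np.random.rand(64, 96, 3)
#   crop = PositionalCrop(24)
#   crop(img, 5, 10).shape   # -> (24, 24, 3)
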
class RandomNoise:
    """Add Gaussian noise to an image or a batch of images.

    Add Gaussian noise to an image or a batch of images. An image is an
    array with size H x W x C with H and W spatial dimensions and C
    number of channels. A batch of images is an array with size
    N x H x W x C with N number of images. The Gaussian noise is a
    Gaussian random variable with mean zero and given standard
    deviation. The standard deviation can be a fixed value
    corresponding to the specified noise level or randomly selected
    from a range between 50% and 100% of the specified noise level.
    """

    def __init__(self, noise_level: float, range_flag: bool = False):
        """
        Args:
            noise_level: Standard deviation of the Gaussian noise.
            range_flag: If ``True``, the standard deviation is randomly
                selected between 50% and 100% of `noise_level`.
                Default: ``False``.
        """
        self.range_flag = range_flag
        if range_flag:
            self.noise_level_low = 0.5 * noise_level
        self.noise_level = noise_level

    def __call__(self, image: Array) -> Array:
        """Add Gaussian noise.

        Args:
            image: The array to add noise to.

        Returns:
            The noisy image.
        """
        noise_level = self.noise_level
        if self.range_flag:
            if image.ndim > 3:
                num_img = image.shape[0]
            else:
                num_img = 1
            noise_level_range = np.random.uniform(
                self.noise_level_low, self.noise_level, num_img
            )
            noise_level = noise_level_range.reshape(
                (noise_level_range.shape[0],) + (1,) * (image.ndim - 1)
            )
        imgnoised = image + np.random.normal(0.0, noise_level, image.shape)
        imgnoised = np.clip(imgnoised, 0.0, 1.0)

        return imgnoised

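# Illustrative usage sketch (not part of the module source): adding
# noise with a fixed level, and with a level drawn per image from
# [0.05, 0.1] when range_flag is enabled.
#
#   batch = np.random.rand(8, 40, 40, 1)
#   noisy_fixed = RandomNoise(0.1)(batch)
#   noisy_range = RandomNoise(0.1, range_flag=True)(batch)   # per-image sigma in [0.05, 0.1]
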
def preprocess_images(
    images: Array,
    output_size: Union[Shape, int],
    gray_flag: bool = False,
    num_img: Optional[int] = None,
    multi_flag: bool = False,
    stride: Optional[Union[Shape, int]] = None,
    dtype: Any = np.float32,
) -> Array:
    """Preprocess (scale, crop, etc.) set of images.

    Preprocess a set of images, converting to gray scale, cropping or
    sampling multiple patches from each one, or selecting a subset of
    them, according to the specified setup.

    Args:
        images: Array of color images.
        output_size: Desired output size. If int, square crop is made.
        gray_flag: If ``True``, convert to gray scale.
        num_img: If specified, use that number of images; if not, use
            all the images provided.
        multi_flag: If ``True``, sample multiple patches of the
            specified size from each image.
        stride: Stride between patch origins (indexed from left-top
            corner). If int, the same stride is used in h and w.
        dtype: dtype of array. Default: :attr:`~numpy.float32`.

    Returns:
        Preprocessed array.
    """
    # Get number of images to use.
    if num_img is None:
        num_img = images.shape[0]

    # Get channels of output image.
    C = 3
    if gray_flag:
        C = 1

    # Define functionality to crop and create signal array.
    if multi_flag:
        tsfm = PositionalCrop(output_size)
        assert stride is not None
        if isinstance(stride, int):
            stride_multi = (stride, stride)
        else:
            stride_multi = stride
        S = np.zeros((num_img, images.shape[1], images.shape[2], C), dtype=dtype)
    else:
        tsfm_crop = CenterCrop(output_size)
        S = np.zeros(
            (num_img, tsfm_crop.output_size[0], tsfm_crop.output_size[1], C), dtype=dtype
        )

    # Convert to gray scale and/or crop.
    for i in range(S.shape[0]):
        img = images[i] / 255.0
        if gray_flag:
            imgG = rgb2gray(img)
            # Keep channel singleton.
            img = imgG.reshape(imgG.shape + (1,))
        if not multi_flag:
            # Crop image.
            img = tsfm_crop(img)
        S[i] = img

    if multi_flag:
        # Sample multiple patches from each image.
        h = S.shape[1]
        w = S.shape[2]
        nh = int(math.floor((h - tsfm.output_size[0]) / stride_multi[0])) + 1
        nw = int(math.floor((w - tsfm.output_size[1]) / stride_multi[1])) + 1
        saux = np.zeros(
            (nh * nw * num_img, tsfm.output_size[0], tsfm.output_size[1], S.shape[-1]),
            dtype=dtype,
        )
        count2 = 0
        for i in range(S.shape[0]):
            for top in range(0, h - tsfm.output_size[0], stride_multi[0]):
                for left in range(0, w - tsfm.output_size[1], stride_multi[1]):
                    saux[count2, ...] = tsfm(S[i], top, left)
                    count2 += 1
        S = saux

    return S

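# Illustrative usage sketch (not part of the module source): extracting
# gray-scale patches from a small synthetic batch of 8-bit color images.
#
#   imgs = np.random.randint(0, 256, size=(2, 64, 64, 3)).astype(np.float32)
#   patches = preprocess_images(
#       imgs, output_size=32, gray_flag=True, multi_flag=True, stride=16
#   )
#   patches.shape   # -> (N_patches, 32, 32, 1), values scaled to [0, 1]
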
def build_image_dataset(
    imgs_train, imgs_test, config: ConfigImageSetDict, transf: Optional[Callable] = None
) -> Tuple[DataSetDict, ...]:
    """Preprocess and assemble dataset for training.

    Preprocess images according to the specified configuration and
    assemble a dataset into a structure that can be used for training
    machine learning models. Keep training and testing partitions. Each
    dictionary returned has images and labels, which are arrays of
    dimensions (N, H, W, C) with N: number of images; H, W: spatial
    dimensions and C: number of channels.

    Args:
        imgs_train: 4D array (NHWC) with images for training.
        imgs_test: 4D array (NHWC) with images for testing.
        config: Configuration of image data set to read.
        transf: Operator for blurring or other non-trivial
            transformations. Default: ``None``.

    Returns:
        tuple: A tuple (train_ds, test_ds) containing:

            - **train_ds** : Dictionary of training data (includes
              images and labels).
            - **test_ds** : Dictionary of testing data (includes images
              and labels).
    """
    # Preprocess images by converting to gray scale or sampling
    # multiple patches according to specified configuration.
    S_train = preprocess_images(
        imgs_train,
        config["output_size"],
        gray_flag=config["run_gray"],
        num_img=config["num_img"],
        multi_flag=config["multi"],
        stride=config["stride"],
    )
    S_test = preprocess_images(
        imgs_test,
        config["output_size"],
        gray_flag=config["run_gray"],
        num_img=config["test_num_img"],
        multi_flag=config["multi"],
        stride=config["stride"],
    )
    # Check for transformation.
    tsfm: Optional[Callable] = None
    # Processing: add noise, blur, etc.
    if config["data_mode"] == "dn":  # Denoising problem.
        tsfm = RandomNoise(config["noise_level"], config["noise_range"])
    elif config["data_mode"] == "dcnv":  # Deconvolution problem.
        assert transf is not None
        tsfm = transf

    if config["augment"]:
        # Augment training data set by flip and 90 degree rotation.
        strain1 = rotation90(S_train.copy())
        strain2 = flip(S_train.copy())
        S_train = np.concatenate((S_train, strain1, strain2), axis=0)

    # Processing: apply transformation.
    if tsfm is not None:
        if config["data_mode"] == "dn":
            Stsfm_train = tsfm(S_train.copy())
            Stsfm_test = tsfm(S_test.copy())
        elif config["data_mode"] == "dcnv":
            tsfm2 = RandomNoise(config["noise_level"], config["noise_range"])
            Stsfm_train = tsfm2(tsfm(S_train.copy()))
            Stsfm_test = tsfm2(tsfm(S_test.copy()))

    # Shuffle data.
    rng = np.random.default_rng(config["seed"])
    perm_tr = rng.permutation(Stsfm_train.shape[0])
    perm_tt = rng.permutation(Stsfm_test.shape[0])
    train_ds: DataSetDict = {"image": Stsfm_train[perm_tr], "label": S_train[perm_tr]}
    test_ds: DataSetDict = {"image": Stsfm_test[perm_tt], "label": S_test[perm_tt]}

    return train_ds, test_ds

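# Illustrative usage sketch (not part of the module source): building a
# small denoising dataset from synthetic images, assuming a minimal
# configuration dict with only the keys read by this function.
#
#   config = {
#       "output_size": 32, "run_gray": True, "num_img": 4, "test_num_img": 2,
#       "multi": False, "stride": None, "data_mode": "dn",
#       "noise_level": 0.1, "noise_range": False, "augment": False, "seed": 0,
#   }
#   imgs_tr = np.random.randint(0, 256, size=(4, 64, 64, 3))
#   imgs_tt = np.random.randint(0, 256, size=(2, 64, 64, 3))
#   train_ds, test_ds = build_image_dataset(imgs_tr, imgs_tt, config)
#   train_ds["image"].shape, train_ds["label"].shape   # both (4, 32, 32, 1)
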
def images_read(path: str, ext: str = "jpg") -> Array:  # pragma: no cover
    """Read a collection of color images from a set of files.

    Read a collection of color images from a set of files in the
    specified directory. All files with extension `ext` (i.e. matching
    glob `*.ext`) in directory `path` are assumed to be image files and
    are read. Images may have different aspect ratios; they are
    transposed as needed to match the aspect ratio of the first image
    read.

    Args:
        path: Path to directory containing the image files.
        ext: Filename extension.

    Returns:
        Collection of color images as a 4D array.
    """
    slices = []
    shape = None
    for file in sorted(glob.glob(os.path.join(path, "*." + ext))):
        image = imageio.imread(file)
        if shape is None:
            shape = image.shape[:2]
        if shape != image.shape[:2]:
            image = np.transpose(image, (1, 0, 2))
        slices.append(image)
    return np.stack(slices)

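# Illustrative usage sketch (not part of the module source): reading all
# JPEG files from a directory; the path here is hypothetical.
#
#   imgs = images_read("/path/to/images", ext="jpg")
#   imgs.shape   # -> (N, H, W, 3), following the aspect ratio of the first image
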
def get_bsds_data(path: str, verbose: bool = False):  # pragma: no cover
    """Download BSDS500 data from the BSDB project.

    Download the BSDS500 dataset, a set of 500 color images of size
    481x321 or 321x481, from the Berkeley Segmentation Dataset and
    Benchmark project. The downloaded data is converted to `.npz`
    format for convenient access via :func:`numpy.load`. The converted
    data is saved in a file `bsds500.npz` in the directory specified by
    `path`. Note that the train and test folders are merged to get a
    set of 400 images for training, while the val folder is reserved as
    a set of 100 images for testing. This is done in multiple works
    such as :cite:`zhang-2017-dncnn`.

    Args:
        path: Directory in which converted data is saved.
        verbose: Flag indicating whether to print status messages.
    """
    # data source URL and filenames
    data_base_url = "http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/"
    data_tar_file = "BSR_bsds500.tgz"
    # ensure path directory exists
    if not os.path.isdir(path):
        raise ValueError(f"Path {path} does not exist or is not a directory")

    # create temporary directory
    temp_dir = tempfile.TemporaryDirectory()
    # download tar file into temporary directory
    if verbose:
        print(f"Downloading {data_tar_file} from {data_base_url}")
    data = util.url_get(data_base_url + data_tar_file)
    f = open(os.path.join(temp_dir.name, data_tar_file), "wb")
    f.write(data.read())
    f.close()
    if verbose:
        print("Download complete")

    # untar downloaded data into temporary directory
    if verbose:
        print(f"Extracting content from tar file {data_tar_file}")
    with tarfile.open(os.path.join(temp_dir.name, data_tar_file), "r") as tar_ref:
        tar_ref.extractall(temp_dir.name)

    # read untarred data files into 4D arrays and save as .npz
    data_path = os.path.join("BSR", "BSDS500", "data", "images")
    train_path = os.path.join(data_path, "train")
    imgs_train = images_read(os.path.join(temp_dir.name, train_path))
    val_path = os.path.join(data_path, "val")
    imgs_val = images_read(os.path.join(temp_dir.name, val_path))
    test_path = os.path.join(data_path, "test")
    imgs_test = images_read(os.path.join(temp_dir.name, test_path))

    # Merge train and test data into the training set;
    # leave val data for testing.
    imgs400 = np.vstack([imgs_train, imgs_test])

    if verbose:
        print(f"Read {imgs400.shape[0]} images for training")
        print(f"Read {imgs_val.shape[0]} images for testing")

    npz_file = os.path.join(path, "bsds500.npz")
    if verbose:
        subpath = str.split(npz_file, ".cache")
        npz_file_display = "~/.cache" + subpath[-1]
        print(f"Saving as {npz_file_display}")
    np.savez(npz_file, imgstr=imgs400, imgstt=imgs_val)

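# Illustrative usage sketch (not part of the module source): downloading
# and converting the BSDS500 data; the cache directory here is
# hypothetical.
#
#   os.makedirs("/tmp/scico_bsds", exist_ok=True)
#   get_bsds_data("/tmp/scico_bsds", verbose=True)
#   npz = np.load("/tmp/scico_bsds/bsds500.npz")
#   npz["imgstr"].shape[0], npz["imgstt"].shape[0]   # -> 400 training, 100 testing images
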
def build_blur_kernel(
    kernel_size: Shape,
    blur_sigma: float,
    dtype: Any = np.float32,
):
    """Construct a blur kernel as specified.

    Args:
        kernel_size: Size of the blur kernel.
        blur_sigma: Standard deviation of the blur kernel.
        dtype: Output dtype. Default: :attr:`~numpy.float32`.

    Returns:
        The Gaussian blur kernel, normalized to unit Euclidean norm.
    """
    kernel = 1.0
    meshgrids = np.meshgrid(*[np.arange(size, dtype=dtype) for size in kernel_size])
    for size, mgrid in zip(kernel_size, meshgrids):
        mean = (size - 1) / 2
        kernel *= np.exp(-(((mgrid - mean) / blur_sigma) ** 2) / 2)

    # Make sure norm of values in Gaussian kernel equals 1.
    knorm = np.sqrt(np.sum(kernel * kernel))
    kernel = kernel / knorm

    return kernel

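# Illustrative usage sketch (not part of the module source): a 9x9
# Gaussian kernel with standard deviation 2.24, normalized to unit
# Euclidean norm.
#
#   kernel = build_blur_kernel((9, 9), 2.24)
#   kernel.shape                        # -> (9, 9)
#   float(np.sqrt(np.sum(kernel**2)))   # -> 1.0 (up to floating point error)
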
class PaddedCircularConvolve(LinearOperator):
    """Define padded convolutional operator.

    The operator pads the signal with a reflection of the borders
    before convolving with the kernel provided at initialization. It
    crops the result of the convolution to maintain the same signal
    size.
    """

    def __init__(
        self,
        output_size: Union[Shape, int],
        channels: int,
        kernel_size: Union[Shape, int],
        blur_sigma: float,
        dtype: Any = np.float32,
    ):
        """
        Args:
            output_size: Size of the image to blur.
            channels: Number of channels in image to blur.
            kernel_size: Size of the blur kernel.
            blur_sigma: Standard deviation of the blur kernel.
            dtype: Output dtype. Default: :attr:`~numpy.float32`.
        """
        if isinstance(output_size, int):
            output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        else:
            assert len(kernel_size) == 2

        # Define padding.
        self.padsz = (
            (kernel_size[0] // 2, kernel_size[0] // 2),
            (kernel_size[1] // 2, kernel_size[1] // 2),
            (0, 0),
        )
        shape = (output_size[0], output_size[1], channels)
        with_pad = (
            output_size[0] + self.padsz[0][0] + self.padsz[0][1],
            output_size[1] + self.padsz[1][0] + self.padsz[1][1],
        )
        shape_padded = (with_pad[0], with_pad[1], channels)

        # Define data types.
        input_dtype = dtype
        output_dtype = dtype

        # Construct blur kernel as specified.
        kernel = build_blur_kernel(kernel_size, blur_sigma)

        # Define convolution part.
        self.conv = CircularConvolve(kernel, input_shape=shape_padded, ndims=2, input_dtype=dtype)

        # Initialize Linear Operator.
        super().__init__(
            input_shape=shape,
            output_shape=shape,
            input_dtype=input_dtype,
            output_dtype=output_dtype,
            jit=True,
        )

    def _eval(self, x: Array) -> Array:
        """Apply operator.

        Args:
            x: The array with input signal. The input to the
                constructed operator should be HWC with H and W spatial
                dimensions given by `output_size` and C the given
                `channels`.

        Returns:
            The result of padding, convolving and cropping the signal.
            The output signal has the same HWC dimensions as the input
            signal.
        """
        xpadd: Array = jnp.pad(x, self.padsz, mode="reflect")
        rconv: Array = self.conv(xpadd)
        return rconv[
            self.padsz[0][0] : -self.padsz[0][1], self.padsz[1][0] : -self.padsz[1][1], :
        ]

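# Illustrative usage sketch (not part of the module source): blurring a
# random RGB image with reflection padding at the borders.
#
#   blur_op = PaddedCircularConvolve(
#       output_size=128, channels=3, kernel_size=11, blur_sigma=2.24
#   )
#   x = jnp.array(np.random.rand(128, 128, 3), dtype=np.float32)
#   y = blur_op(x)
#   y.shape   # -> (128, 128, 3), same as the input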