# -*- coding: utf-8 -*-# Copyright (C) 2022-2023 by SCICO Developers# All rights reserved. BSD 3-clause License.# This file is part of the SCICO package. Details of the copyright and# user license can be found in the 'LICENSE' file distributed with the# package."""Generalized data handling for training script.Includes construction of data iterator andinstantiation for parallel processing."""importwarningswarnings.simplefilter(action="ignore",category=FutureWarning)fromtypingimportAny,Unionimportjaximportjax.numpyasjnpfromflaximportjax_utilsfromscico.numpyimportArrayfrom.typed_dictimportDataSetDictDType=AnyKeyArray=Union[Array,jax.Array]
class IterateData:
    """Class to load data for training and testing.

    It uses the generator pattern to obtain an iterable object. The
    iterator never raises :class:`StopIteration`: once the last complete
    batch has been served it starts over, re-shuffling the sample order
    first when in training mode and keeping the original order otherwise.
    Samples that do not fill a complete batch are skipped.
    """

    def __init__(
        self,
        dt: DataSetDict,
        batch_size: int,
        train: bool = True,
        key: Union[KeyArray, None] = None,
    ):
        r"""Initialize a :class:`IterateData` object.

        Args:
            dt: Dictionary of data for supervised training including
                images and labels. The number of samples is read from
                the leading axis of ``dt["image"]``.
            batch_size: Size of batch for iterating through the data.
            train: Flag indicating use of iterator for training. The
                iterator for training draws a new random permutation
                of the data on every pass, the iterator for testing
                visits the data in its stored order on every pass.
                Default: ``True``.
            key: A PRNGKey used as the random key. If ``None``,
                ``jax.random.PRNGKey(0)`` is used. Default: ``None``.

        Raises:
            ValueError: If `batch_size` is smaller than one or larger
                than the number of samples, since no complete batch
                could then be formed.
        """
        n = dt["image"].shape[0]
        # Fail early: a non-positive or oversized batch would otherwise
        # produce a zero-division error or an empty permutation table
        # that is only noticed on the first call to `__next__`.
        if batch_size < 1 or batch_size > n:
            raise ValueError(
                f"Parameter batch_size must be in the range [1, {n}] "
                f"for a data set with {n} samples; got {batch_size}."
            )
        self.dt = dt
        self.batch_size = batch_size
        self.train = train
        self.n = n
        self.key = jax.random.PRNGKey(0) if key is None else key
        self.steps_per_epoch = self.n // batch_size
        self.reset()

    def reset(self):
        """Rebuild the batch index table and rewind to the first batch.

        In training mode the random key is split and a fresh permutation
        of the sample indices is drawn; otherwise the indices are kept
        in increasing order.
        """
        if self.train:
            self.key, subkey = jax.random.split(self.key)
            order = jax.random.permutation(subkey, self.n)
        else:
            order = jnp.arange(self.n)
        # Drop the trailing samples that do not fill a complete batch.
        usable = self.steps_per_epoch * self.batch_size
        self.perms = order[:usable].reshape((self.steps_per_epoch, self.batch_size))
        self.ns = 0

    def __iter__(self):
        """Return the iterator object itself."""
        return self

    def __next__(self):
        """Get next batch.

        When the data is exhausted, the batches are reshuffled (training
        mode) or the iterator rewinds to the first batch (testing mode).

        Returns:
            Dictionary with the same keys as the data set, each entry
            holding the samples selected for the current batch.
        """
        if self.ns >= self.steps_per_epoch:
            if self.train:
                self.reset()
            else:
                self.ns = 0
        idx = self.perms[self.ns]
        batch = {k: v[idx, ...] for k, v in self.dt.items()}
        self.ns += 1
        return batch
def create_input_iter(
    key: KeyArray,
    dataset: DataSetDict,
    batch_size: int,
    size_device_prefetch: int = 2,
    dtype: DType = jnp.float32,
    train: bool = True,
) -> Any:
    """Create data iterator for training or testing.

    Builds an :class:`IterateData` iterator over the data set, passes
    every batch through `prepare_data` and hands the resulting stream
    to :func:`flax.jax_utils.prefetch_to_device` so that batches are
    prefetched on device.

    Args:
        key: A PRNGKey used for random data permutations.
        dataset: Dictionary of data for supervised training including
            images and labels.
        batch_size: Size of batch for iterating through the data.
        size_device_prefetch: Size of prefetch buffer. Default: 2.
        dtype: Type of data to handle. This argument is accepted but
            not used by the function body. Default:
            :attr:`~numpy.float32`.
        train: Flag indicating the type of iterator to construct and
            use. The iterator for training permutes data on each epoch
            while the iterator for testing passes through the data
            without permuting it. Default: ``True``.

    Returns:
        Iterator returned by the device prefetch utility, yielding the
        prepared batches built from the provided data set.
    """
    batches = IterateData(dataset, batch_size, train, key)
    # Batches are prepared lazily, one at a time, as the prefetcher
    # pulls them from the stream.
    prepared = (prepare_data(batch) for batch in batches)
    return jax_utils.prefetch_to_device(prepared, size_device_prefetch)