scico.flax.train.input_pipeline

Generalized data handling for training scripts.

Includes construction of data iterator and instantiation for parallel processing.

Functions

prepare_data(xs)

Reshape input batch for parallel training.

Classes

IterateData(dt, batch_size[, train, key])

Class to load data for training and testing.

class scico.flax.train.input_pipeline.IterateData(dt, batch_size, train=True, key=None)

Bases: object

Class to load data for training and testing.

It uses the generator pattern to obtain an iterable object.

Initialize an IterateData object.

Parameters:
  • dt (DataSetDict) – Dictionary of data for supervised training including images and labels.

  • batch_size (int) – Size of batch for iterating through the data.

  • train (bool) – Flag indicating whether the iterator is used for training. The training iterator is infinite; the testing iterator passes once through the data. Default: True.

  • key (Optional[Array]) – A PRNGKey used as the random key. Default: None.

reset()

Re-shuffle data in training.
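
The iteration behavior described above (infinite and reshuffled for training, a single ordered pass for testing) can be illustrated with a minimal NumPy sketch. This is not the SCICO implementation: the function name `iterate_batches`, the `seed` argument, the `"image"` and `"label"` keys, and the choice to drop an incomplete final batch are all assumptions made for illustration.

```python
import itertools

import numpy as np


def iterate_batches(data, batch_size, train=True, seed=0):
    """Hypothetical stand-in for the behavior described above (not SCICO code)."""
    rng = np.random.default_rng(seed)
    n = next(iter(data.values())).shape[0]  # number of samples
    steps = n // batch_size  # assumption: drop any incomplete final batch
    while True:
        # Shuffle the sample order for training; keep it fixed for testing.
        order = rng.permutation(n) if train else np.arange(n)
        for i in range(steps):
            idx = order[i * batch_size : (i + 1) * batch_size]
            yield {k: v[idx] for k, v in data.items()}
        if not train:
            return  # testing: stop after a single pass through the data


# Toy data; the "image" and "label" keys are assumed for illustration.
data = {"image": np.arange(10).reshape(10, 1), "label": np.arange(10)}

# The testing iterator terminates on its own after one pass.
test_batches = list(iterate_batches(data, batch_size=4, train=False))

# The training iterator never terminates, so take a fixed number of batches.
train_batches = list(itertools.islice(iterate_batches(data, batch_size=4), 5))
```

Because the training generator never raises StopIteration, a training loop built on this pattern would need its own stopping criterion, such as a fixed number of steps.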

scico.flax.train.input_pipeline.prepare_data(xs)

Reshape input batch for parallel training.

Return type:

Any
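
As an illustration of what reshaping a batch for parallel training typically involves, the sketch below splits the leading batch axis of every array into a device axis and a per-device batch axis. It is an assumption-based NumPy example, not the SCICO function: the name `split_across_devices`, the explicit `n_devices` argument, and the dictionary keys are hypothetical.

```python
import numpy as np


def split_across_devices(xs, n_devices):
    """Hypothetical sketch: (batch, ...) -> (n_devices, batch // n_devices, ...)."""

    def _split(x):
        # The batch size must be divisible by n_devices for this reshape to work.
        return x.reshape((n_devices, -1) + x.shape[1:])

    return {k: _split(v) for k, v in xs.items()}


# A batch of 8 single-channel 32x32 images with matching labels (assumed layout).
batch = {"image": np.zeros((8, 32, 32, 1)), "label": np.zeros((8, 32, 32, 1))}
sharded = split_across_devices(batch, n_devices=4)
print(sharded["image"].shape)  # (4, 2, 32, 32, 1)
```

With this layout each slice along the first axis holds the portion of the batch assigned to one device.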