scico.flax.train.diagnostics

Utilities for computing and displaying performance metrics during training.

Assumes sharded batched data.

Functions

compute_metrics(output, labels[, criterion])

Compute diagnostic metrics.

stats_obj()

Functionality to log and store iteration statistics.

Classes

ArgumentStruct(**entries)

Class that converts a dictionary into an object with named entries.

scico.flax.train.diagnostics.compute_metrics(output, labels, criterion=<function mse_loss>)

Compute diagnostic metrics.

Assumes sharded batched data, i.e. it only works inside a pmap-transformed function because it requires a named axis.

Parameters:
  • output (Array) – Comparison signal.

  • labels (Array) – Reference signal.

  • criterion (Callable) – Loss function. Default: mse_loss.

Return type:

MetricsDict

Returns:

Loss and SNR between output and labels.
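The pmap requirement can be illustrated with a minimal, self-contained sketch. It does not call the library function. The helper names (mse, snr_db, sketch_metrics), the axis name "batch", and the SNR definition (signal power over error power, in dB) are all assumptions made for illustration, and the library's own definitions may differ.

```python
import jax
import jax.numpy as jnp
from jax import lax


def mse(output, labels):
    # Mean squared error over all elements of the local shard.
    return jnp.mean((output - labels) ** 2)


def snr_db(reference, comparison):
    # Assumed SNR definition: 10 log10(signal power / error power).
    signal = jnp.sum(reference**2)
    noise = jnp.sum((reference - comparison) ** 2)
    return 10.0 * jnp.log10(signal / noise)


def sketch_metrics(output, labels):
    # Compute per-device metrics, then average them across devices.
    # lax.pmean needs the axis name, which only exists inside pmap.
    metrics = {"loss": mse(output, labels), "snr": snr_db(labels, output)}
    return lax.pmean(metrics, axis_name="batch")


n_dev = jax.local_device_count()
# The leading axis is the device axis; each device receives a (4, 8) shard.
labels = jnp.ones((n_dev, 4, 8))
output = labels + 0.1
metrics = jax.pmap(sketch_metrics, axis_name="batch")(output, labels)
```

Calling sketch_metrics directly, outside pmap, fails because the axis name "batch" is unbound. This is the reason for the restriction stated above.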

class scico.flax.train.diagnostics.ArgumentStruct(**entries)

Bases: object

Class that converts a dictionary into an object with named entries.

Converts a Python dictionary into an object whose named entries are given by the dictionary keys. After instantiation, both modes of access (dictionary keys or object attributes) can be used.
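The dictionary-to-object pattern can be sketched in a few lines. DictStruct below is a hypothetical stand-in written for illustration, not the library class, and the keys in summary are made-up example values.

```python
class DictStruct:
    """Hypothetical stand-in showing the dictionary-to-object pattern."""

    def __init__(self, **entries):
        # Copy every keyword argument into the instance namespace,
        # so that each dictionary key becomes an attribute.
        self.__dict__.update(entries)


summary = {"epoch": 3, "loss": 0.25, "snr": 12.5}
stats = DictStruct(**summary)

# Attribute access on the object and key access on the dictionary agree.
same_loss = stats.loss == summary["loss"]
```

Unpacking with ** requires the dictionary keys to be strings, and they should be valid Python identifiers if they are to be read with dot notation.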

scico.flax.train.diagnostics.stats_obj()

Functionality to log and store iteration statistics.

This function initializes an IterationStats object to log and store iteration statistics if logging is enabled during training. The statistics collected are: epoch, time, learning rate, training loss and SNR, and evaluation loss and SNR. The IterationStats object handles both printing the statistics to the command line and storing them for further analysis.

Return type:

Tuple[IterationStats, Callable]
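The pairing of a logging object with a callable can be sketched as follows. This is a rough illustration only: StatsLog, Record, make_stats_log, and the field names are hypothetical, and the sketch does not reproduce the behaviour or interface of IterationStats or of the callable actually returned.

```python
from collections import namedtuple

# Field names mirror the statistics listed above; they are assumptions.
Record = namedtuple(
    "Record",
    ["epoch", "time", "lr", "train_loss", "train_snr", "eval_loss", "eval_snr"],
)


class StatsLog:
    """Hypothetical logger: prints one row per call and keeps every row."""

    def __init__(self, display=True):
        self.display = display
        self.history = []

    def insert(self, record):
        # Store the row, then optionally print it to the command line.
        self.history.append(record)
        if self.display:
            fields = record._asdict().items()
            print("  ".join(f"{name}={value:.3g}" for name, value in fields))


def make_stats_log():
    # Return the logger and a function that turns a summary dict into a row.
    log = StatsLog()

    def to_record(summary):
        return Record(**summary)

    return log, to_record


log, to_record = make_stats_log()
log.insert(
    to_record(
        {
            "epoch": 1,
            "time": 0.5,
            "lr": 1e-3,
            "train_loss": 0.2,
            "train_snr": 10.0,
            "eval_loss": 0.25,
            "eval_snr": 9.5,
        }
    )
)
```

Returning the conversion function alongside the logger keeps the choice of which fields to record in one place, separate from the training loop that produces the summaries.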