scico.flax.train.diagnostics
Utilities for computing and displaying performance metrics during training.
Assumes sharded batched data.
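"Sharded batched data" here is taken to mean arrays with a leading device axis, which is the layout that `jax.pmap` maps over. The sketch below shows one way to produce that layout from an ordinary batch. The helper `shard_batch` is hypothetical (not part of this module), and it assumes the batch size is divisible by the number of local devices.

```python
import jax
import numpy as np


def shard_batch(batch):
    """Hypothetical helper: reshape (B, ...) into (n_devices, B // n_devices, ...)."""
    n_dev = jax.local_device_count()
    if batch.shape[0] % n_dev != 0:
        raise ValueError(f"batch size {batch.shape[0]} is not divisible by {n_dev} devices")
    # The new leading axis indexes devices; jax.pmap maps over this axis.
    return batch.reshape((n_dev, batch.shape[0] // n_dev) + batch.shape[1:])


n_dev = jax.local_device_count()
batch = np.arange(4 * n_dev * 3, dtype=np.float32).reshape(4 * n_dev, 3)
sharded = shard_batch(batch)
print(sharded.shape)  # leading axis equals the number of local devices
```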
Functions
| compute_metrics(output, labels[, criterion]) | Compute diagnostic metrics. |
| stats_obj() | Functionality to log and store iteration statistics. |
Classes
| ArgumentStruct(**entries) | Class that converts a dictionary into an object with named entries. |
- scico.flax.train.diagnostics.compute_metrics(output, labels, criterion=<function mse_loss>)
Compute diagnostic metrics.
Assumes sharded batched data (i.e. it only works inside pmap because it needs an axis name).
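- Parameters:
  - output – Model output.
  - labels – Reference values to compare against.
  - criterion – Loss function used to compare output and labels. Default: mse_loss.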
- Return type:
- Returns:
Loss and SNR between output and labels.
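The requirement to run inside `pmap` comes from the cross-device reduction: a collective such as `jax.lax.pmean` needs an axis name that only exists inside a `pmap`ped function. The following is a minimal sketch of that pattern, not the library's code. The names `compute_metrics_sketch`, `mse_loss`, and `snr` defined here are local stand-ins, the axis name `"batch"` is an assumption, and the SNR formula (10 log10 of signal variance over MSE) is one common definition that may differ from the one used by the library.

```python
import jax
import jax.numpy as jnp
from jax import lax


def mse_loss(output, labels):
    """Stand-in for the default criterion: mean squared error."""
    return jnp.mean((output - labels) ** 2)


def snr(reference, comparison):
    """Assumed SNR definition in dB: 10 log10(var(reference) / MSE)."""
    return 10.0 * jnp.log10(jnp.var(reference) / jnp.mean((reference - comparison) ** 2))


def compute_metrics_sketch(output, labels, criterion=mse_loss):
    # Metrics for the shard of the batch held by this device.
    metrics = {"loss": criterion(output, labels), "snr": snr(labels, output)}
    # Average over devices. lax.pmean needs the axis name bound by pmap,
    # so calling this function outside pmap raises an unbound-axis error.
    return lax.pmean(metrics, axis_name="batch")


p_compute_metrics = jax.pmap(compute_metrics_sketch, axis_name="batch")

n_dev = jax.local_device_count()
# Sharded layout: (devices, per-device batch, features).
labels = jax.random.normal(jax.random.PRNGKey(0), (n_dev, 4, 8))
output = labels + 0.1
metrics = p_compute_metrics(output, labels)
print(metrics["loss"], metrics["snr"])  # one (identical) value per device
```

Because the metrics are averaged with `pmean`, every device ends up holding the same value, so any entry of the per-device result can be logged.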
- class scico.flax.train.diagnostics.ArgumentStruct(**entries)
Bases: object
Class that converts a dictionary into an object with named entries.
Class that converts a Python dictionary into an object with named entries given by the dictionary keys. After instantiation, both modes of access (dictionary or object entries) can be used.
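A common way to build such an object is to copy the keyword arguments into the instance's attribute dictionary. The class below is a hypothetical sketch of that technique, not the library's implementation, and the configuration keys are illustrative assumptions.

```python
class ArgumentStructSketch:
    """Hypothetical sketch: expose dictionary keys as object attributes."""

    def __init__(self, **entries):
        # Each key becomes an attribute with the corresponding value.
        self.__dict__.update(entries)


# Illustrative configuration dictionary (keys are assumptions).
config = {"batch_size": 16, "base_learning_rate": 1e-3, "num_epochs": 10}
args = ArgumentStructSketch(**config)

print(args.batch_size)           # attribute-style access
print(vars(args)["num_epochs"])  # the underlying attribute dictionary
```

In this sketch, attribute access comes from `__dict__`, and `vars(args)` returns that same dictionary, so the original key/value view remains available.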
- scico.flax.train.diagnostics.stats_obj()
Functionality to log and store iteration statistics.
This function initializes an IterationStats object to log and store iteration statistics if logging is enabled during training. The statistics collected are: epoch, time, learning rate, loss and SNR in training, and loss and SNR in evaluation. The IterationStats object takes care of both printing the statistics to the command line and storing them for further analysis.
- Return type:
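The print-and-store pattern described above can be illustrated with a small stand-alone logger. This is a hypothetical sketch, not the `IterationStats` class itself: the class name `SimpleIterationStats`, its methods, and the field names are assumptions chosen to mirror the listed statistics, and the inserted values are placeholders rather than real training results.

```python
import time
from collections import namedtuple

# Hypothetical record type; field names mirror the statistics listed above.
IterStat = namedtuple(
    "IterStat",
    ["Epoch", "Time", "Train_LR", "Train_Loss", "Train_SNR", "Eval_Loss", "Eval_SNR"],
)


class SimpleIterationStats:
    """Minimal stand-in: prints each record and keeps a history of records."""

    def __init__(self, display=True):
        self.display = display
        self.history = []  # stored records for later analysis
        if display:
            print("  ".join(f"{name:>10}" for name in IterStat._fields))

    def insert(self, values):
        record = IterStat(*values)
        self.history.append(record)
        if self.display:
            # Floats in scientific notation, other values (e.g. epoch) as-is.
            print("  ".join(f"{v:>10.3e}" if isinstance(v, float) else f"{v:>10}" for v in record))

    def as_columns(self):
        """Return the stored history as one list per field."""
        return {name: [getattr(r, name) for r in self.history] for name in IterStat._fields}


stats = SimpleIterationStats(display=True)
t0 = time.time()
for epoch in range(3):
    # Placeholder numbers; a real loop would compute these during training/evaluation.
    stats.insert((epoch, time.time() - t0, 1e-3, 0.5 / (epoch + 1), 10.0 + epoch,
                  0.6 / (epoch + 1), 9.0 + epoch))
columns = stats.as_columns()
```

Storing records as named tuples keeps each iteration's values together while still allowing a column-wise view for plotting or analysis after training.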