# -*- coding: utf-8 -*-
# Copyright (C) 2022-2023 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.
"""Utilities for computing and displaying performance metrics during training.
Assumes sharded batched data.
"""
from typing import Callable, Dict, Tuple, Union
from jax import lax
from scico.diagnostics import IterationStats
from scico.metric import snr
from scico.numpy import Array
from .losses import mse_loss
from .typed_dict import MetricsDict
[docs]def compute_metrics(output: Array, labels: Array, criterion: Callable = mse_loss) -> MetricsDict:
"""Compute diagnostic metrics.
Assumes sharded batched data (i.e. it only works inside pmap because
it needs an axis name).
Args:
output: Comparison signal.
labels: Reference signal.
criterion: Loss function. Default: :meth:`~scico.flax.train.losses.mse_loss`.
Returns:
Loss and SNR between `output` and `labels`.
"""
loss = criterion(output, labels)
snr_ = snr(labels, output)
metrics: MetricsDict = {
"loss": loss,
"snr": snr_,
}
metrics = lax.pmean(metrics, axis_name="batch")
return metrics
[docs]class ArgumentStruct:
"""Class that converts a dictionary into an object with named entries.
Class that converts a python dictionary into an object with named
entries given by the dictionary keys. After the object instantiation
both modes of access (dictionary or object entries) can be used.
"""
def __init__(self, **entries):
self.__dict__.update(entries)
[docs]def stats_obj() -> Tuple[IterationStats, Callable]:
"""Functionality to log and store iteration statistics.
This function initializes an object
:class:`~.diagnostics.IterationStats` to log and store iteration
statistics if logging is enabled during training. The statistics
collected are: epoch, time, learning rate, loss and snr in training
and loss and snr in evaluation. The
:class:`~.diagnostics.IterationStats` object takes care of both
printing stats to command line and storing them for further analysis.
"""
# epoch, time learning rate loss and snr (train and
# eval) fields
itstat_fields = {
"Epoch": "%d",
"Time": "%8.2e",
"Train_LR": "%.6f",
"Train_Loss": "%.6f",
"Train_SNR": "%.2f",
"Eval_Loss": "%.6f",
"Eval_SNR": "%.2f",
}
itstat_attrib = [
"epoch",
"time",
"train_learning_rate",
"train_loss",
"train_snr",
"loss",
"snr",
]
# dynamically create itstat_func; see https://stackoverflow.com/questions/24733831
itstat_return = "return(" + ", ".join(["obj." + attr for attr in itstat_attrib]) + ")"
scope: Dict[str, Callable] = {}
exec("def itstat_func(obj): " + itstat_return, scope)
default_itstat_options: Dict[str, Union[dict, Callable, bool]] = {
"fields": itstat_fields,
"itstat_func": scope["itstat_func"],
"display": True,
}
itstat_insert_func: Callable = default_itstat_options.pop("itstat_func") # type: ignore
itstat_object = IterationStats(**default_itstat_options) # type: ignore
return itstat_object, itstat_insert_func