# Source code for scico.diagnostics

# -*- coding: utf-8 -*-
# Copyright (C) 2020-2023 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.

"""Diagnostic information for iterative solvers."""

import re
import warnings
from collections import OrderedDict, namedtuple
from typing import List, Optional, Tuple, Union


class IterationStats:
    """Display and record iterative algorithms statistics.

    Display and record statistics related to convergence of iterative
    algorithms.
    """

    def __init__(
        self,
        fields: OrderedDict,
        ident: Optional[dict] = None,
        display: bool = False,
        period: int = 1,
        shift_cycles: bool = True,
        overwrite: bool = True,
        colsep: int = 2,
    ):
        """
        The `fields` parameter represents an OrderedDict (to ensure that
        field order is retained) specifying field names for each value
        to be inserted and a corresponding format string for when it is
        displayed. When inserted values are printed in tabular form, the
        field lengths are taken as the maxima of the header string
        lengths and the field lengths embedded in the format strings (if
        specified). For best results, the field lengths should be
        manually specified based on knowledge of the ranges of values
        that may be encountered. For example, for a '%e' format string,
        the specified field length should be at least the precision
        (e.g. '%.2e' specifies a precision of 2 places) plus 6 when only
        positive values may be encountered, and plus 7 when negative
        values may be encountered.

        Args:
            fields: A dictionary associating field names with format
                strings for displaying the corresponding values.
            ident: A dictionary associating field names with
                corresponding valid identifiers for use within the
                namedtuple used to record results. Defaults to ``None``.
            display: Flag indicating whether results should be printed
                to stdout. Defaults to ``False``.
            period: Only display one result in every cycle of length
                `period`.
            shift_cycles: If ``True``, the retained result of each cycle
                is the 1st, (`period` + 1)th, (2 * `period` + 1)th, etc.
                inserted result, so that the first inserted result is
                always displayed. Otherwise, it is the `period`th,
                (2 * `period`)th, etc. inserted result.
            overwrite: If ``True``, display all results, but each one is
                overwritten by the next, except for one result per
                cycle. If ``False``, only one result per cycle is
                displayed.
            colsep: Number of spaces separating fields in displayed
                tables. Defaults to 2.

        Raises:
            TypeError: If the `fields` parameter is not a dict.
            ValueError: If a format string in `fields` cannot be parsed.
        """

        # Parameter fields must be a dict (insertion-ordered) so that
        # the field order is retained
        if not isinstance(fields, dict):
            raise TypeError("Parameter fields must be an instance of dict.")

        # Subsampling rate of results that are to be displayed
        self.period: int = period
        # Offset to iteration count for determining start of period
        self.period_offset = 1 if shift_cycles else 0
        # Flag indicating whether to display and overwrite, or not display at all
        self.overwrite: bool = overwrite
        # Number of spaces separating fields in displayed tables
        self.colsep: int = colsep
        # Main list of inserted values
        self.iterations: List = []
        # Total length of header string in displayed tables
        self.headlength: int = 0
        # List of field names
        self.fieldname: List[str] = []
        # List of field format strings
        self.fieldformat: List[str] = []
        # List of lengths of each field in displayed tables
        self.fieldlength: List[int] = []
        # Names of fields in namedtuple used to record iteration values
        self.tuplefields: List[str] = []

        # Compile regex for decomposing format strings into
        # (flags, length, dot, precision, type)
        fmre = re.compile(r"%(\+?-?)((?:\d+)?)(\.?)((?:\d+)?)([a-z])")

        for name, fmt in fields.items():
            # Decompose format string using compiled regex
            fmtmatch = fmre.match(fmt)
            if not fmtmatch:
                raise ValueError(f'Format string "{fmt}" could not be parsed.')
            fmflg, fmlen, fmdot, fmprc, fmtyp = fmtmatch.groups()
            # Measure the formatted width using a representative value
            flen = len(fmt % 0)
            # Warn if actual formatted length longer than specified field
            # length, e.g. as in "%4e"
            if fmlen != "" and flen > int(fmlen):
                warnings.warn(
                    f'Actual length {flen} of format "{fmt}" for field '
                    f'"{name}" is longer than specified value {fmlen}',
                    stacklevel=2,
                )
            # If the actual formatted length is less than that of the header
            # string, insert a field length specifier to increase the
            # length to that of the header string
            if flen < len(name):
                fmt = f"%{fmflg}{len(name)}{fmdot}{fmprc}{fmtyp}"
                flen = len(name)
            self.fieldname.append(name)
            self.fieldformat.append(fmt)
            self.fieldlength.append(flen)
            self.headlength += flen + colsep

            # If a distinct identifier is specified for this field, use it
            # as the namedtuple identifier, otherwise compute it from the
            # field name
            if ident is not None and name in ident:
                self.tuplefields.append(ident[name])
            else:
                # See https://stackoverflow.com/a/3305731
                tfnm = re.sub(r"\W+|^(?=\d)", "_", name)
                # startswith (rather than indexing) so that an empty name
                # is reported by namedtuple below instead of an IndexError
                if tfnm.startswith("_"):
                    tfnm = tfnm[1:]
                self.tuplefields.append(tfnm)

        # Decrement head length to account for final colsep added
        self.headlength -= colsep

        # Construct namedtuple used to record values
        self.IterTuple = namedtuple("IterationStatsTuple", self.tuplefields)  # type: ignore

        # Set up table header string display if requested; the header is
        # printed lazily by the first call to insert
        self.display = display
        self.disphdr = None
        if display:
            titles = (" " * colsep).join(
                "%-*s" % (fl, fn) for fl, fn in zip(self.fieldlength, self.fieldname)
            )
            self.disphdr = titles + "\n" + "-" * self.headlength

    def insert(self, values: Union[List, Tuple]):
        """Insert a list of values for a single iteration.

        The values are recorded and, if display is enabled, printed as
        a table row according to the `period` and `overwrite` settings.

        Args:
            values: Statistics for a single iteration, in the same order
                as the fields. May be a list or a tuple.
        """
        # The % operator only unpacks tuples, so normalise once; passing
        # a list directly would be treated as a single argument.
        row = tuple(values)
        self.iterations.append(self.IterTuple(*row))

        if not self.display:
            return

        # Print the header once, before the first row
        if self.disphdr is not None:
            print(self.disphdr)
            self.disphdr = None

        # True for the one result per cycle that is retained on screen
        cycle_end = (len(self.iterations) - self.period_offset) % self.period == 0
        if self.overwrite:
            # Every row is shown; a carriage return lets the next row
            # overwrite it unless this is the retained row of the cycle
            print(
                (" " * self.colsep).join(self.fieldformat) % row,
                end="\n" if cycle_end else "\r",
            )
        elif cycle_end:
            print((" " * self.colsep).join(self.fieldformat) % row)

    def end(self):
        """Mark end of iterations.

        This method should be called at the end of a set of iterations.
        Its only function is to ensure that the displayed output is left
        in an appropriate state when overwriting is active with a
        display period other than unity.
        """
        # A trailing newline is only needed when the last displayed row
        # was terminated with a carriage return
        if (
            self.display
            and self.overwrite
            and self.period > 1
            and (len(self.iterations) - self.period_offset) % self.period
        ):
            print()

    def history(self, transpose: bool = False):
        """Retrieve record of all inserted iterations.

        Args:
            transpose: Flag indicating whether results should be returned
                in "transposed" form, i.e. as a namedtuple of lists
                rather than a list of namedtuples. Default: False.

        Returns:
            list of namedtuple or namedtuple of lists: Record of all
            inserted iterations. If no iterations have been inserted,
            the (empty) list is returned regardless of `transpose`.
        """
        if transpose and self.iterations:
            # zip(*rows) yields one column per field
            return self.IterTuple(*[list(column) for column in zip(*self.iterations)])
        return self.iterations