Source code for scico.optimize._common

# -*- coding: utf-8 -*-
# Copyright (C) 2023 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.

"""Functions common to multiple optimizer modules."""

# Needed to annotate a class method that returns the encapsulating class;
# see https://www.python.org/dev/peps/pep-0563/
from __future__ import annotations

from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from scico.diagnostics import IterationStats
from scico.numpy import Array, BlockArray
from scico.util import Timer


def itstat_func_and_object(
    itstat_fields: dict, itstat_attrib: List, itstat_options: Optional[dict] = None
) -> Tuple[Callable, IterationStats]:
    """Iteration statistics initialization.

    Iteration statistics initialization steps common to all optimizer
    classes.

    Args:
        itstat_fields: A dictionary associating field names with format
              strings for displaying the corresponding values.
        itstat_attrib: A list of expressions corresponding of optimizer
              class attributes to be evaluated when computing iteration
              statistics.
        itstat_options: A dict of named parameters to be passed to
              the :class:`.diagnostics.IterationStats` initializer. The
              dict may also include an additional key "itstat_func"
              with the corresponding value being a function with two
              parameters, an integer and a :class:`Optimizer` object,
              responsible for constructing a tuple ready for insertion
              into the :class:`.diagnostics.IterationStats` object. If
              ``None``, default values are used for the dict entries,
              otherwise the default dict is updated with the dict
              specified by this parameter.

    Returns:
        A tuple consisting of the statistics insertion function and the
        :class:`.diagnostics.IterationStats` object.
    """
    # dynamically create itstat_func; see https://stackoverflow.com/questions/24733831
    itstat_return = "return(" + ", ".join(["obj." + attr for attr in itstat_attrib]) + ")"
    scope: Dict[str, Callable] = {}
    exec("def itstat_func(obj): " + itstat_return, scope)

    # determine itstat options and initialize IterationStats object
    default_itstat_options: Dict[str, Union[dict, Callable, bool]] = {
        "fields": itstat_fields,
        "itstat_func": scope["itstat_func"],
        "display": False,
    }
    if itstat_options:
        default_itstat_options.update(itstat_options)

    itstat_insert_func: Callable = default_itstat_options.pop("itstat_func", None)  # type: ignore
    itstat_object = IterationStats(**default_itstat_options)  # type: ignore

    return itstat_insert_func, itstat_object


class Optimizer:
    """Base class for optimizer classes.

    Attributes:
        itnum (int): Optimizer iteration counter.
        maxiter (int): Maximum number of optimizer outer-loop iterations.
        timer (:class:`.Timer`): Iteration timer.
    """

    def __init__(self, **kwargs: Any):
        """Initialize common attributes of :class:`Optimizer` objects.

        Args:
            **kwargs: Optional parameter dict. Valid keys are:

                iter0:
                  Initial value of iteration counter. Default value is 0.

                maxiter:
                  Maximum iterations on call to :meth:`solve`. Default
                  value is 100.

                nanstop:
                  If ``True``, stop iterations if a ``NaN`` or ``Inf``
                  value is encountered in a solver working variable.
                  Default value is ``False``.

                itstat_options:
                  A dict of named parameters to be passed to
                  the :class:`.diagnostics.IterationStats` initializer.
                  The dict may also include an additional key
                  "itstat_func" with the corresponding value being a
                  function with two parameters, an integer and an
                  :class:`Optimizer` object, responsible for constructing
                  a tuple ready for insertion into the
                  :class:`.diagnostics.IterationStats` object. If
                  ``None``, default values are used for the dict entries,
                  otherwise the default dict is updated with the dict
                  specified by this parameter.
        """
        iter0 = kwargs.pop("iter0", 0)
        self.maxiter: int = kwargs.pop("maxiter", 100)
        self.nanstop: bool = kwargs.pop("nanstop", False)
        itstat_options = kwargs.pop("itstat_options", None)

        if kwargs:
            raise TypeError(f"Unrecognized keyword argument(s) {', '.join([k for k in kwargs])}")

        self.itnum: int = iter0
        self.timer: Timer = Timer()

        itstat_fields, itstat_attrib = self._itstat_default_fields()
        itstat_extra_fields, itstat_extra_attrib = self._itstat_extra_fields()
        itstat_fields.update(itstat_extra_fields)
        itstat_attrib.extend(itstat_extra_attrib)

        self.itstat_insert_func, self.itstat_object = itstat_func_and_object(
            itstat_fields, itstat_attrib, itstat_options
        )

    def _working_vars_finite(self) -> bool:
        """Determine where ``NaN`` of ``Inf`` encountered in solve.

        Return ``False`` if a ``NaN`` or ``Inf`` value is encountered in
        a solver working variable.
        """
        raise NotImplementedError(
            "NaN check requested but _working_vars_finite not implemented." ""
        )

    def _itstat_default_fields(self) -> Tuple[Dict[str, str], List[str]]:
        """Define iterations stats default fields.

        Return a dict mapping field names to format strings, and a list
        of strings containing the names of attributes or methods to call
        in order to determine the value for each field.
        """
        # iteration number and time fields
        itstat_fields = {
            "Iter": "%d",
            "Time": "%8.2e",
        }
        itstat_attrib = ["itnum", "timer.elapsed()"]
        # objective function can be evaluated if 'g' function can be evaluated
        if self._objective_evaluatable():
            itstat_fields.update({"Objective": "%9.3e"})
            itstat_attrib.append("objective()")

        return itstat_fields, itstat_attrib

    def _objective_evaluatable(self) -> bool:
        """Determine whether the objective function can be evaluated.

        Determine whether the objective function for a :class:`Optimizer`
        object can be evaluated.
        """
        return False

    def _itstat_extra_fields(self) -> Tuple[Dict[str, str], List[str]]:
        """Define additional iterations stats fields.

        Define iterations stats fields that are not common to all
        :class:`Optimizer` classes. Return a dict mapping field names to
        format strings, and a list of strings containing the names of
        attributes or methods to call in order to determine the value for
        each field.
        """
        return {}, []

[docs] def minimizer(self) -> Union[Array, BlockArray]: """Return the current estimate of the functional mimimizer."""
[docs] def step(self): """Perform a single optimizer step."""
[docs] def solve( self, callback: Optional[Callable[[Optimizer], None]] = None, ) -> Union[Array, BlockArray]: r"""Initialize and run the optimization algorithm. Initialize and run the opimization algorithm for a total of `self.maxiter` iterations. Args: callback: An optional callback function, taking an a single argument of type :class:`Optimizer`, that is called at the end of every iteration. Returns: Computed solution. """ self.timer.start() for self.itnum in range(self.itnum, self.itnum + self.maxiter): self.step() if self.nanstop and not self._working_vars_finite(): raise ValueError( f"NaN or Inf value encountered in working variable in iteration {self.itnum}." "" ) self.itstat_object.insert(self.itstat_insert_func(self)) if callback: self.timer.stop() callback(self) self.timer.start() self.timer.stop() self.itnum += 1 self.itstat_object.end() return self.minimizer()