# -*- coding: utf-8 -*-
# Copyright (C) 2021-2023 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.
"""Primal-dual solvers."""
# Needed to annotate a class method that returns the encapsulating class;
# see https://www.python.org/dev/peps/pep-0563/
from __future__ import annotations
from typing import Optional, Union
import scico.numpy as snp
from scico.functional import Functional
from scico.linop import LinearOperator, jacobian, operator_norm
from scico.numpy import Array, BlockArray
from scico.numpy.linalg import norm
from scico.operator import Operator
from scico.typing import PRNGKey
from ._common import Optimizer
class PDHG(Optimizer):
    r"""Primal–dual hybrid gradient (PDHG) algorithm.

    Primal–dual hybrid gradient (PDHG) is a family of algorithms
    :cite:`esser-2010-general` that includes the Chambolle-Pock
    primal-dual algorithm :cite:`chambolle-2010-firstorder`. The form
    implemented here is a minor variant :cite:`pock-2011-diagonal` of the
    original Chambolle-Pock algorithm.

    Solve an optimization problem of the form

    .. math::
        \argmin_{\mb{x}} \; f(\mb{x}) + g(C \mb{x}) \;,

    where :math:`f` and :math:`g` are instances of :class:`.Functional`,
    (in most cases :math:`f` will, more specifically, be an instance
    of :class:`.Loss`), and :math:`C` is an instance of
    :class:`.Operator` or :class:`.LinearOperator`.

    When `C` is a :class:`.LinearOperator`, the algorithm iterations are

    .. math::
        \begin{aligned}
        \mb{x}^{(k+1)} &= \mathrm{prox}_{\tau f} \left( \mb{x}^{(k)} -
        \tau C^T \mb{z}^{(k)} \right) \\
        \mb{z}^{(k+1)} &= \mathrm{prox}_{\sigma g^*} \left( \mb{z}^{(k)}
        + \sigma C \left( (1 + \alpha) \mb{x}^{(k+1)} - \alpha \mb{x}^{(k)}
        \right) \right) \;,
        \end{aligned}

    where :math:`g^*` denotes the convex conjugate of :math:`g`.
    Parameters :math:`\tau > 0` and :math:`\sigma > 0` are also required
    to satisfy

    .. math::
        \tau \sigma < \| C \|_2^{-2} \;,

    and it is required that :math:`\alpha \in [0, 1]`.

    When `C` is a non-linear :class:`.Operator`, a non-linear PDHG variant
    :cite:`valkonen-2014-primal` is used, with the same iterations except
    for the :math:`\mb{x}` update

    .. math::
        \mb{x}^{(k+1)} = \mathrm{prox}_{\tau f} \left( \mb{x}^{(k)} -
        \tau [J_x C(\mb{x}^{(k)})]^T \mb{z}^{(k)} \right) \;.

    Attributes:
        f (:class:`.Functional`): Functional :math:`f` (usually a
            :class:`.Loss`).
        g (:class:`.Functional`): Functional :math:`g`.
        C (:class:`.Operator`): :math:`C` operator.
        tau (scalar): First algorithm parameter.
        sigma (scalar): Second algorithm parameter.
        alpha (scalar): Relaxation parameter.
        x (array-like): Primal variable :math:`\mb{x}` at current
            iteration.
        x_old (array-like): Primal variable :math:`\mb{x}` at previous
            iteration.
        z (array-like): Dual variable :math:`\mb{z}` at current
            iteration.
        z_old (array-like): Dual variable :math:`\mb{z}` at previous
            iteration.
    """

    def __init__(
        self,
        f: Functional,
        g: Functional,
        C: Operator,
        tau: float,
        sigma: float,
        alpha: float = 1.0,
        x0: Optional[Union[Array, BlockArray]] = None,
        z0: Optional[Union[Array, BlockArray]] = None,
        **kwargs,
    ):
        r"""Initialize a :class:`PDHG` object.

        Args:
            f: Functional :math:`f` (usually a loss function).
            g: Functional :math:`g`.
            C: Operator :math:`C`.
            tau: First algorithm parameter.
            sigma: Second algorithm parameter.
            alpha: Relaxation parameter.
            x0: Starting point for :math:`\mb{x}`. If ``None``, defaults
                to an array of zeros.
            z0: Starting point for :math:`\mb{z}`. If ``None``, defaults
                to an array of zeros.
            **kwargs: Additional optional parameters handled by
                initializer of base class :class:`.Optimizer`.
        """
        self.f: Functional = f
        self.g: Functional = g
        self.C: Operator = C
        self.tau: float = tau
        self.sigma: float = sigma
        self.alpha: float = alpha
        if x0 is None:
            # Default primal starting point: zeros matching the domain of C.
            input_shape = C.input_shape
            dtype = C.input_dtype
            x0 = snp.zeros(input_shape, dtype=dtype)
        self.x = x0
        self.x_old = self.x
        if z0 is None:
            # Default dual starting point: zeros matching the range of C.
            input_shape = C.output_shape
            dtype = C.output_dtype
            z0 = snp.zeros(input_shape, dtype=dtype)
        self.z = z0
        self.z_old = self.z
        super().__init__(**kwargs)

    def _working_vars_finite(self) -> bool:
        """Determine whether ``NaN`` or ``Inf`` encountered in solve.

        Return ``False`` if a ``NaN`` or ``Inf`` value is encountered in
        a solver working variable.
        """
        return snp.all(snp.isfinite(self.x)) and snp.all(snp.isfinite(self.z))

    def _objective_evaluatable(self):
        """Determine whether the objective function can be evaluated."""
        return self.f.has_eval and self.g.has_eval

    def _itstat_extra_fields(self):
        """Define PDHG-specific iteration statistics fields."""
        itstat_fields = {"Prml Rsdl": "%9.3e", "Dual Rsdl": "%9.3e"}
        itstat_attrib = ["norm_primal_residual()", "norm_dual_residual()"]
        return itstat_fields, itstat_attrib

    def minimizer(self):
        """Return current estimate of the functional minimizer."""
        return self.x

    def objective(
        self,
        x: Optional[Union[Array, BlockArray]] = None,
    ) -> float:
        r"""Evaluate the objective function.

        Evaluate the objective function

        .. math::
            f(\mb{x}) + g(C \mb{x}) \;.

        Args:
            x: Point at which to evaluate objective function. If ``None``,
                the objective is evaluated at the current iterate
                :code:`self.x`.

        Returns:
            scalar: Value of the objective function.
        """
        if x is None:
            x = self.x
        return self.f(x) + self.g(self.C(x))

    def norm_primal_residual(self) -> float:
        r"""Compute the :math:`\ell_2` norm of the primal residual.

        Compute the :math:`\ell_2` norm of the primal residual

        .. math::
            \tau^{-1} \norm{\mb{x}^{(k)} - \mb{x}^{(k-1)}}_2 \;.

        Returns:
            Current norm of primal residual.
        """
        return norm(self.x - self.x_old) / self.tau  # type: ignore

    def norm_dual_residual(self) -> float:
        r"""Compute the :math:`\ell_2` norm of the dual residual.

        Compute the :math:`\ell_2` norm of the dual residual

        .. math::
            \sigma^{-1} \norm{\mb{z}^{(k)} - \mb{z}^{(k-1)}}_2 \;.

        Returns:
            Current norm of dual residual.
        """
        return norm(self.z - self.z_old) / self.sigma

    def step(self):
        """Perform a single iteration."""
        self.x_old = self.x
        self.z_old = self.z
        if isinstance(self.C, LinearOperator):
            # Linear case: apply the adjoint of C directly.
            proxarg = self.x - self.tau * self.C.conj().T(self.z)
        else:
            # Non-linear case :cite:`valkonen-2014-primal`: apply the
            # (conjugate) transposed Jacobian of C at the current iterate.
            proxarg = self.x - self.tau * self.C.vjp(self.x, conjugate=True)[1](self.z)
        self.x = self.f.prox(proxarg, self.tau, v0=self.x)
        # Dual update uses the over-relaxed primal iterate
        # (1 + alpha) x^{(k+1)} - alpha x^{(k)}.
        proxarg = self.z + self.sigma * self.C(
            (1.0 + self.alpha) * self.x - self.alpha * self.x_old
        )
        self.z = self.g.conj_prox(proxarg, self.sigma, v0=self.z)

    @staticmethod
    def estimate_parameters(
        C: Operator,
        x: Optional[Union[Array, BlockArray]] = None,
        ratio: float = 1.0,
        factor: Optional[float] = 1.01,
        maxiter: int = 100,
        key: Optional[PRNGKey] = None,
    ):
        r"""Estimate `tau` and `sigma` parameters of :class:`PDHG`.

        Find values of the `tau` and `sigma` parameters of :class:`PDHG`
        that respect the constraint

        .. math::
            \tau \sigma < \| C \|_2^{-2} \quad \text{or} \quad
            \tau \sigma < \| J_x C(\mb{x}) \|_2^{-2} \;,

        depending on whether :math:`C` is a :class:`.LinearOperator` or
        not.

        Args:
            C: Operator :math:`C`.
            x: Value of :math:`\mb{x}` at which to evaluate the Jacobian
                of :math:`C` (when it is not a :class:`.LinearOperator`).
                If ``None``, defaults to an array of zeros.
            ratio: Desired ratio between returned :math:`\tau` and
                :math:`\sigma` values (:math:`\sigma = \mathrm{ratio}
                \tau`).
            factor: Safety factor with which to multiply :math:`\| C
                \|_2^{-2}` to ensure strict inequality compliance. If
                ``None``, the value is set to 1.0.
            maxiter: Maximum number of power iterations to use in operator
                norm estimation (see :func:`.operator_norm`). Default: 100.
            key: Jax PRNG key to use in operator norm estimation (see
                :func:`.operator_norm`). Defaults to ``None``, in which
                case a new key is created.

        Returns:
            A tuple (`tau`, `sigma`) representing the estimated parameter
            values.
        """
        if x is None:
            x = snp.zeros(C.input_shape, dtype=C.input_dtype)
        if factor is None:
            factor = 1.0
        if isinstance(C, LinearOperator):
            J = C
        else:
            # Linearize the non-linear operator about x so that the
            # spectral norm of the Jacobian can be estimated.
            J = jacobian(C, x)
        Cnrm = operator_norm(J, maxiter=maxiter, key=key)
        # With sigma = ratio * tau, the constraint tau * sigma <= factor / ||C||^2
        # gives tau = sqrt(factor / ratio) / ||C||.
        tau = snp.sqrt(factor / ratio) / Cnrm
        sigma = ratio * tau
        return (tau, sigma)