Source code for scico.flax.train.losses

# -*- coding: utf-8 -*-
# Copyright (C) 2022-2023 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.

"""Definition of loss functions for model optimization."""

import jax.numpy as jnp

import optax

from scico.numpy import Array


def mse_loss(output: Array, labels: Array) -> Array:
    """Compute Mean Squared Error (MSE) loss for training via Optax.

    The elementwise loss is computed with :func:`optax.l2_loss`, which
    includes a factor of 1/2. The returned value is therefore
    ``0.5 * mean((output - labels) ** 2)``, i.e. one half of the
    conventional MSE. This scaling is retained so that existing training
    configurations (e.g. learning rates tuned against this loss) keep
    behaving the same.

    Args:
        output: Comparison signal.
        labels: Reference signal.

    Returns:
        Scalar array holding the mean, over all elements, of
        :func:`optax.l2_loss` applied to `output` and `labels`.
    """
    # Elementwise 0.5 * (output - labels)**2, as computed by optax.
    half_sq_err = optax.l2_loss(output, labels)
    # Reduce over every axis (batch and signal dimensions alike).
    return jnp.mean(half_sq_err)