Source code for scico.flax.train.losses
# -*- coding: utf-8 -*-
# Copyright (C) 2022-2023 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.
"""Definition of loss functions for model optimization."""
import jax.numpy as jnp
import optax
from scico.numpy import Array
[docs]def mse_loss(output: Array, labels: Array) -> float:
"""Compute Mean Squared Error (MSE) loss for training via Optax.
Args:
output: Comparison signal.
labels: Reference signal.
Returns:
MSE between `output` and `labels`.
"""
mse = optax.l2_loss(output, labels)
return jnp.mean(mse)