# -*- coding: utf-8 -*-# Copyright (C) 2022-2023 by SCICO Developers# All rights reserved. BSD 3-clause License.# This file is part of the SCICO package. Details of the copyright and# user license can be found in the 'LICENSE' file distributed with the# package."""Definition of loss functions for model optimization."""importjax.numpyasjnpimportoptaxfromscico.numpyimportArray
[docs]defmse_loss(output:Array,labels:Array)->float:"""Compute Mean Squared Error (MSE) loss for training via Optax. Args: output: Comparison signal. labels: Reference signal. Returns: MSE between `output` and `labels`. """mse=optax.l2_loss(output,labels)returnjnp.mean(mse)