深度学习常见概念解释(四)——损失函数定义,作用与种类(附公式和代码)

损失函数

  • 前言
  • 定义
  • 作用
  • 种类
    • 1. 均方误差损失(Mean Squared Error Loss,MSE)
      • 公式
      • 特点和优点
      • 缺点
      • 使用场景
      • 示例代码
      • 在机器学习框架中的使用
      • 总结
    • 2. 交叉熵损失(Cross-Entropy Loss)
      • 公式
      • 特点和优点
      • 使用场景
      • 示例代码
      • 在机器学习框架中的使用
      • 总结
  • 总结

前言

在机器学习和深度学习中,损失函数(Loss Function)起着至关重要的作用。它是模型优化过程中不可或缺的一部分,用于衡量模型预测值与真实值之间的差异。选择合适的损失函数不仅可以帮助模型更好地拟合数据,还能反映任务的特性,提高模型的性能和鲁棒性。本文将详细介绍损失函数的定义、作用及常见种类,并通过具体的示例代码展示如何在实际应用中使用这些损失函数。

定义

损失函数(loss function)是在机器学习和深度学习中用来衡量模型预测值与真实值之间差异的函数。它通常表示为一个标量值,用来评估模型在训练数据上的表现。

作用

  1. 衡量预测值与真实值之间的差异: 损失函数衡量了模型在给定数据上的表现,即模型对于输入数据的预测与实际标签之间的差异程度。通过最小化损失函数,模型可以更好地拟合训练数据,提高预测的准确性。

  2. 指导模型优化: 在训练过程中,损失函数是优化算法的目标函数,模型的参数通过最小化损失函数来调整,使得模型能够更好地拟合训练数据。常见的优化算法包括梯度下降(Gradient Descent)及其变种,它们通过计算损失函数的梯度来更新模型参数。

  3. 反映任务的特性: 不同任务和模型需要选择不同的损失函数。例如,分类任务常用的损失函数包括交叉熵损失(Cross-Entropy Loss),回归任务常用的损失函数包括均方误差损失(Mean Squared Error Loss)。选择合适的损失函数能够更好地反映任务的特性,有助于提高模型的性能。

  4. 处理不平衡数据: 在某些情况下,数据可能存在类别不平衡或者噪声,选择合适的损失函数可以帮助模型更好地处理这些情况,提高模型的鲁棒性。

总的来说,损失函数在机器学习和深度学习中扮演着至关重要的角色,它不仅指导模型的训练过程,还反映了模型对于任务的表现和适应能力。

种类

在机器学习和深度学习中,常见的损失函数包括以下几种:

1. 均方误差损失(Mean Squared Error Loss,MSE)

均方误差损失(Mean Squared Error Loss,简称 MSE)是一种常用的回归模型损失函数,用于衡量预测值与真实值之间的差异。MSE 的计算方式是将每个预测值与真实值之间的差值平方,然后求这些差值平方的平均值。

公式

MSE = 1 2 n ∑ i = 1 n ( y i − y ^ i ) 2 \text{MSE} = \frac{1}{2n} \sum_{i=1}^{n} (y_i - \hat{y}_i)^2 MSE=2n1i=1n(yiy^i)2
其中:

  • n n n 是数据点的数量。
  • y i y_i yi 是第 i i i 个真实值。
  • y ^ i \hat{y}_i y^i 是第 i i i 个预测值。

特点和优点

  1. 平滑性:MSE 损失函数是连续和可微的,这使得它非常适合用于梯度下降等优化算法。
  2. 凸性:MSE 是一个凸函数,这意味着在大多数情况下,它只有一个全局最小值,这对优化问题非常重要。
  3. 简单性:MSE 的公式简单,计算方便,容易实现。

缺点

  1. 对异常值敏感:由于误差被平方,MSE 对异常值(outliers)特别敏感。如果数据集中存在极端值,这些值会对整体误差有很大影响,导致模型不稳定。
  2. 不适用于分类问题:MSE 主要用于回归问题,对于分类问题,通常使用交叉熵损失等其他损失函数。

使用场景

MSE 广泛用于各种回归问题中,例如:

  • 预测房价
  • 股票价格预测
  • 气温预测
  • 机器学习模型中的损失计算

示例代码

import numpy as np

# 定义真实值和预测值
y_true = np.array([1.5, 2.0, 3.5, 4.0, 5.5])
y_pred = np.array([1.4, 2.1, 3.6, 3.9, 5.8])

# 计算均方误差
mse = np.mean((y_true - y_pred) ** 2)
print(f"Mean Squared Error: {mse}")

在机器学习框架中的使用

在流行的机器学习框架中,如 TensorFlow 和 PyTorch,均方误差损失通常作为内置函数提供,使用非常方便。

import torch
import torch.nn as nn

# 定义真实值和预测值
y_true = torch.tensor([1.5, 2.0, 3.5, 4.0, 5.5])
y_pred = torch.tensor([1.4, 2.1, 3.6, 3.9, 5.8])

# 定义 MSE 损失函数
mse_loss = nn.MSELoss()

# 计算损失
loss = mse_loss(y_pred, y_true)
print(f"Mean Squared Error Loss: {loss.item()}")

总结

均方误差损失(MSE)是衡量回归模型性能的一种标准方法,通过计算预测值与真实值之间的平方误差平均值来评估模型的准确性。尽管它对异常值敏感,但其简单性和计算效率使其在各种回归任务中广泛应用。

2. 交叉熵损失(Cross-Entropy Loss)

交叉熵损失(Cross-Entropy Loss)是一种常用于分类任务中的损失函数,特别适用于多类别分类问题。交叉熵损失用于衡量预测的概率分布与真实分布之间的差异。它通过计算真实标签和预测概率之间的不确定性来衡量模型的性能。

公式

  1. 对于二分类问题,二分类交叉熵损失(Binary Cross-Entropy Loss, BCE)的公式如下:
    CE = − ( y log ⁡ ( p ) + ( 1 − y ) log ⁡ ( 1 − p ) ) \text{CE} = - \left( y \log(p) + (1 - y) \log(1 - p) \right) CE=(ylog(p)+(1y)log(1p))
    其中:

    • y y y 是真实标签,取值为 0 或 1。
    • p p p 是预测为类别 1 的概率。
  2. 对于多分类问题,多分类交叉熵损失(Categorical Cross-Entropy Loss, CCE)的公式为:
    CE = − ∑ i = 1 n y i log ⁡ ( p i ) \text{CE} = - \sum_{i=1}^{n} y_i \log(p_i) CE=i=1nyilog(pi)
    其中:

    • n n n 是类别的数量。
    • y i y_i yi 是真实标签,如果样本属于第 i i i类,则 y i = 1 y_i = 1 yi=1 ,否则 y i = 0 y_i = 0 yi=0
    • p i p_i pi 是模型预测样本属于第 i i i类的概率。

PS.:二分类交叉熵损失(Binary Cross-Entropy Loss)也被称为对数损失(Log Loss)。
PPS. 注意在正式计算的时候需要把所有的误差值加起来取平均值(具体步骤见下面的示例代码)。

特点和优点

  1. 概率输出:交叉熵损失函数使用预测的概率分布,这使得它特别适用于分类问题。
  2. 敏感性:它对错误分类的惩罚较大,尤其是在预测概率较高但实际类别不匹配的情况下。
  3. 凸性:交叉熵损失通常是凸的,这有助于优化算法找到全局最优解。

使用场景

交叉熵损失广泛用于各种分类问题中,例如:

  • 图像分类
  • 文本分类
  • 语音识别
  • 机器翻译

示例代码

import numpy as np

# 二分类问题
def binary_cross_entropy(y_true, y_pred):
    y_true = np.array(y_true)
    y_pred = np.array(y_pred)
    return -np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))

# 示例数据
y_true = [1, 0, 1, 1, 0]
y_pred = [0.9, 0.1, 0.8, 0.7, 0.2]

# 计算二分类交叉熵损失
loss = binary_cross_entropy(y_true, y_pred)
print(f"Binary Cross-Entropy Loss: {loss}")

# 多分类问题
def categorical_cross_entropy(y_true, y_pred):
    y_true = np.array(y_true)
    y_pred = np.array(y_pred)
    return -np.sum(y_true * np.log(y_pred)) / y_true.shape[0]

# 示例数据
y_true = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
y_pred = [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1], [0.2, 0.2, 0.6]]

# 计算多分类交叉熵损失
loss = categorical_cross_entropy(y_true, y_pred)
print(f"Categorical Cross-Entropy Loss: {loss}")

在机器学习框架中的使用

在流行的机器学习框架中,如 TensorFlow 和 PyTorch,交叉熵损失通常作为内置函数提供,使用非常方便。

import torch
import torch.nn as nn

# 定义真实标签和预测概率
y_true = torch.tensor([2, 0, 1])
y_pred = torch.tensor([[0.1, 0.2, 0.7], [0.8, 0.1, 0.1], [0.2, 0.6, 0.2]])

# 定义交叉熵损失函数
criterion = nn.CrossEntropyLoss()

# 计算损失
loss = criterion(y_pred, y_true)
print(f"Cross-Entropy Loss: {loss.item()}")

总结

交叉熵损失(Cross-Entropy Loss)是分类问题中常用的损失函数,通过衡量预测的概率分布与真实分布之间的差异来评估模型性能。它对错误分类的惩罚较大,并且使用概率输出,非常适合分类任务。流行的深度学习框架通常提供了内置的交叉熵损失函数,方便用户使用。

总结

损失函数在机器学习和深度学习中扮演着至关重要的角色。它不仅指导模型的训练过程,还反映了模型对于任务的表现和适应能力。选择合适的损失函数是模型优化的重要一步,能够显著提高模型的性能和鲁棒性。希望通过本文的介绍,读者能够对损失函数有一个全面的了解,并在实际项目中选择和应用合适的损失函数,这对于模型的训练和性能至关重要。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/734982.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

<Rust><iced>在iced中显示gif动态图片的一种方法

前言 本文是在rust的GUI库iced中在窗口显示动态图片GIF格式图片的一种方法。 环境配置 系统:window 平台:visual studio code 语言:rust 库:iced、image 概述 在iced中,提供了image部件,从理论上说&…

软考 系统架构设计师系列知识点之杂项集萃(44)

接前一篇文章:软考 系统架构设计师系列知识点之杂项集萃(43) 第71题 设有员工实体Employee(员工号,姓名,性别,年龄,电话,家庭住址,家庭成员,关系…

自动驾驶⻋辆环境感知:多传感器融合

目录 一、多传感器融合技术概述 二、基于传统方法的多传感器融合 三、基于深度学习的视觉和LiDAR的目标级融合 四、基于深度学习的视觉和LiDAR数据的前融合方法 概念介绍 同步和配准 时间同步 标定 摄像机内参标定(使用OpenCV) 摄像机与LiDAR外…

【FreeRTOS】任务状态改进播放控制

这里写目录标题 1 任务状态1.1 阻塞状态(Blocked)1.2 暂停状态(Suspended)1.3 就绪状态(Ready)1.4 完整的状态转换图 2 举个例子3 编写代码 参考《FreeRTOS入门与工程实践(基于DshanMCU-103).pdf》 本节课实现音乐任务的创建,音乐播放的暂停与继续播放,删…

java泛型学习

没有java泛型会存在的问题 假设我们有一个方法,希望通过传递不同类型的参数,输出不同类型的对象值。正常情况下我们可能会写不同的方法来实现,但是这样会导致类不断增加,并且类方法很相似,不能够复用。进而导致类爆炸…

C#实现音乐在线播放和下载——Windows程序设计作业3

1. 作业内容 编写一个C#程序,在作业二实现的本地播放功能的基础上,新增在线播放和在线下载功能,作业二博客地址:C#实现简单音乐文件解析播放——Windows程序设计作业2 2. 架构选择 考虑到需求中的界面友好和跨版本兼容性&#xf…

网站监控定时计划任务

网站监控是一种保护网站安全和稳定性的重要手段,而定时计划任务则是网站监控的一种常见方法。通过设置定时计划任务,可以定期对网站进行监测和检测,及时发现并解决潜在的问题,从而保障网站的正常运行。 首先,网站监控定…

AI播客下载:Eye on AI(AI深度洞察)

"Eye on A.I." 是一档双周播客节目,由长期担任《纽约时报》记者的 Craig S. Smith 主持。在每一集中,Craig 都会与在人工智能领域产生影响的人们交谈。该播客的目的是将渐进的进步置于更广阔的背景中,并考虑发展中的技术的全球影响…

MySQL的自增 ID 用完了,怎么办?

MySQL 自增 ID 一般用的数据类型是 INT 或 BIGINT,正常情况下这两种类型可以满足大多数应用的需求。 当然也有不正常的情况,当达到其最大值时,尝试插入新的记录会导致错误,错误信息类似于: ERROR 167 (22003): Out o…

【深度学习驱动流体力学】计算流体力学openfoam-paraview与python3交互

目的1:配置 ParaView 中的 Python Shell 和 Python 交互环境 ParaView 提供了强大的 Python 接口,允许用户通过 Python 脚本来控制和操作其可视化功能。在 ParaView 中,可以通过 View > Python Shell 菜单打开 Python Shell 窗口,用于执行 Python 代码。要确保正确配置 …

Mkdocs中文系列教程补充(1)

什么是requirements.txt 我的理解是mkdocs依赖的py库 第一次建立MKdocs文档使用 mkdocs new . 完后,比较建议执行一下: pip install -r requirements.txt 不然mkdocs serve后会出现什么 xxx not found ,比如下面这位老哥 示例 mkdocs …

【大数据】—量化交易实战案例(基础策略)

声明:股市有风险,投资需谨慎!本人没有系统学过金融知识,对股票有敬畏之心没有踏入其大门,所以只能写本文来模拟炒股。 量化交易,也被称为算法交易,是一种使用数学模型和计算机算法来分析市场数…

骑马与砍杀战团mod制作-基础-军队笔记(一)

骑马与砍杀战团mod制作-基础-军队装备笔记(一) 资料来源 学习的资料来源: b站【三啸解说】手把手教你做【骑砍】MOD,基础篇,链接为: https://www.bilibili.com/video/BV19x411Q7No?p4&vd_sourcea507…

设施布置之车间布局优化SLP分析

一 物流分析(Flow Analysis) 的基本方法 1、当物料移动是工艺过程的主要部分时,物流分析就是工厂布置设计的核心工作,也是物料搬运分析的开始。 2、零部件物流是该部件在工厂内移动时所走过的路线, 物流分析不仅要考虑…

Python18 数据结构与数据类型转换

1.python中的数据结构 在Python中,数据结构是用来存储、组织和管理数据的方式,以便有效地执行各种数据操作。Python提供了几种内置的数据结构,每种都有其特定的用途和操作方法。以下是Python中一些主要的数据结构: 1.列表&#…

Linux下Cmake安装或版本更新

下载Cmake源码 https://cmake.org/download/ 找到对应的版本和类型 放进linux环境解压 编译 安装 tar -vxvf cmake-3.13.0.tar.gz cd cmake-3.13.0 ./bootstrap make make install设置环境变量 vi ~/.bashrc在文件尾加入 export PATH/your_path/cmake-3.13.0/bin:$PAT…

css-vxe列表中ant进度条与百分比

1.vxe列表 ant进度条 <vxe-column field"actualProgress" title"进度" align"center" width"200"><template #default"{ row }"><a-progress:percent"Math.floor(row.actualProgress)"size"s…

KEIL5软件仿真观察PIN脚电平(软件仿真逻辑分析仪的使用)

仿真前的调整&#xff1a; 例&#xff1a;STM32F103C8T6 &#xff08;如果是F4的板子稍微对着修改一下&#xff09; 逻辑分析仪的使用 输入 PORTA.6( PORAT(哪一组).(哪一个引脚) )

【MySQL】

基础篇 执行一条 select 语句,期间发生了什么? 大家好,我是小林。 学习 SQL 的时候,大家肯定第一个先学到的就是 select 查询语句了,比如下面这句查询语句: // 在 product 表中,查询 id = 1 的记录 select * from product where id = 1;但是有没有想过,MySQL 执行一…