在前面的文章中,我们已经学习了TensorFlow激励函数的操作使用方法(见文章:快速掌握TensorFlow(三)),今天我们将继续学习TensorFlow。
本文主要是学习掌握TensorFlow的损失函数。
一、什么是损失函数
损失函数(loss function)是机器学习中非常重要的内容,它是度量模型输出值与目标值的差异,也就是作为评估模型效果的一种重要指标,损失函数越小,表明模型的鲁棒性就越好。
二、怎样使用损失函数
在TensorFlow中训练模型时,通过损失函数告诉TensorFlow预测结果相比目标结果是好还是坏。在多种情况下,我们会给出模型训练的样本数据和目标数据,损失函数即是比较预测值与给定的目标值之间的差异。
下面将介绍在TensorFlow中常用的损失函数。
1、回归模型的损失函数
首先讲解回归模型的损失函数,回归模型是预测连续因变量的。为方便介绍,先定义预测结果(-1至1的等差序列)、目标结果(目标值为0),代码如下:
import tensorflow as tf
sess=tf.Session()
y_pred=tf.linspace(-1., 1., 100)
y_target=tf.constant(0.)
注意,在实际训练模型时,预测结果是模型输出的结果值,目标结果是样本提供的。
(1)L1正则损失函数(即绝对值损失函数)
L1正则损失函数是对预测值与目标值的差值求绝对值,公式如下:
在TensorFlow中调用方式如下:
loss_l1_vals=tf.abs(y_pred-y_target)
loss_l1_out=sess.run(loss_l1_vals)
L1正则损失函数在目标值附近不平滑,会导致模型不能很好地收敛。
(2)L2正则损失函数(即欧拉损失函数)
L2正则损失函数是预测值与目标值差值的平方和,公式如下:
当对L2取平均值,就变成均方误差(MSE, mean squared error),公式如下:
在TensorFlow中调用方式如下:
# L2损失
loss_l2_vals=tf.square(y_pred - y_target)
loss_l2_out=sess.run(loss_l2_vals)
# 均方误差
loss_mse_vals= tf.reduce.mean(tf.square(y_pred - y_target))
loss_mse_out = sess.run(loss_mse_vals)
L2正则损失函数在目标值附近有很好的曲度,离目标越近收敛越慢,是非常有用的损失函数。
L1、L2正则损失函数如下图所示:
(3)Pseudo-Huber 损失函数
Huber损失函数经常用于回归问题,它是分段函数,公式如下:
从这个公式可以看出当残差(预测值与目标值的差值,即y-f(x) )很小的时候,损失函数为L2范数,残差大的时候,为L1范数的线性函数。
Peseudo-Huber损失函数是Huber损失函数的连续、平滑估计,在目标附近连续,公式如下:
该公式依赖于参数delta,delta越大,则两边的线性部分越陡峭。
在TensorFlow中的调用方式如下:
delta=tf.constant(0.25)
loss_huber_vals = tf.mul(tf.square(delta), tf.sqrt(1. + tf.square(y_target – y_pred)/delta)) – 1.)
loss_huber_out = sess.run(loss_huber_vals)
L1、L2、Huber损失函数的对比图如下,其中Huber的delta取0.25、5两个值:
2、分类模型的损失函数
分类损失函数主要用于评估预测分类结果,重新定义预测值(-3至5的等差序列)和目标值(目标值为1),如下:
y_pred=tf.linspace(-3., 5., 100)
y_target=tf.constant(1.)
y_targets=tf.fill([100, ], 1.)
(1)Hinge损失函数
Hinge损失常用于二分类问题,主要用来评估向量机算法,但有时也用来评估神经网络算法,公式如下:
在TensorFlow中的调用方式如下:
loss_hinge_vals = tf.maximum(0., 1. – tf.mul(y_target, y_pred))
loss_hinge_out = sess.run(loss_hinge_vals)
上面的代码中,目标值为1,当预测值离1越近,则损失函数越小,如下图:
(2)两类交叉熵(Cross-entropy)损失函数
交叉熵来自于信息论,是分类问题中使用广泛的损失函数。交叉熵刻画了两个概率分布之间的距离,当两个概率分布越接近时,它们的交叉熵也就越小,给定两个概率分布p和q,则距离如下:
对于两类问题,当一个概率p=y,则另一个概率q=1-y,因此代入化简后的公式如下:
在TensorFlow中的调用方式如下:
loss_ce_vals = tf.mul(y_target, tf.log(y_pred)) – tf.mul((1. – y_target), tf.log(1. – y_pred))
loss_ce_out = sess.run(loss_ce_vals)
Cross-entropy损失函数主要应用在二分类问题上,预测值为概率值,取值范围为[0,1],损失函数图如下:
(3)Sigmoid交叉熵损失函数
与上面的两类交叉熵类似,只是将预测值y_pred值通过sigmoid函数进行转换,再计算交叉熵损失。在TensorFlow中有内置了该函数,调用方式如下:
loss_sce_vals=tf.nn.sigmoid_cross_entropy_with_logits(y_pred, y_targets)
loss_sce_out=sess.run(loss_sce_vals)
由于sigmoid函数会将输入值变小很多,从而平滑了预测值,使得sigmoid交叉熵在预测值离目标值比较远时,其损失的增长没有那么的陡峭。与两类交叉熵的比较图如下:
(4)加权交叉熵损失函数
加权交叉熵损失函数是Sigmoid交叉熵损失函数的加权,是对正目标的加权。假定权重为0.5,在TensorFlow中的调用方式如下:
weight = tf.constant(0.5)
loss_wce_vals = tf.nn.weighted_cross_entropy_with_logits(y)vals, y_targets, weight)
loss_wce_out = sess.run(loss_wce_vals)
(5)Softmax交叉熵损失函数
Softmax交叉熵损失函数是作用于非归一化的输出结果,只针对单个目标分类计算损失。
通过softmax函数将输出结果转化成概率分布,从而便于输入到交叉熵里面进行计算(交叉熵要求输入为概率),softmax定义如下:
结合前面的交叉熵定义公式,则Softmax交叉熵损失函数公式如下:
在TensorFlow中调用方式如下:
y_pred=tf.constant([[1., -3., 10.]]
y_target=tf.constant([[0.1, 0.02, 0.88]])
loss_sce_vals=tf.nn.softmax_cross_entropy_with_logits(y_pred, y_target)
loss_sce_out=sess.run(loss_sce_vals)
用于回归相关的损失函数,对比图如下:
3、总结
下面对各种损失函数进行一个总结,如下表所示:
在实际使用中,对于回归问题经常会使用MSE均方误差(L2取平均)计算损失,对于分类问题经常会使用Sigmoid交叉熵损失函数。
大家在使用时,还要根据实际的场景、具体的模型,选择使用的损失函数,希望本文对你有帮助。
接下来的“快速掌握TensorFlow”系列文章,还会有更多讲解TensorFlow的精彩内容,敬请期待。
推荐相关阅读
关注本人公众号“大数据与人工智能Lab”(BigdataAILab),获取更多信息。
Original url: Access
Created at: 2018-09-29 19:58:23
Category: default
Tags: none
未标明原创文章均为采集,版权归作者所有,转载无需和我联系,请注明原出处,南摩阿彌陀佛,知识,不只知道,要得到
java windows火焰图_mob64ca12ec8020的技术博客_51CTO博客 - 在windows下不可行,不知道作者是怎样搞的 监听SpringBoot 服务启动成功事件并打印信息_监听springboot启动完毕-CSDN博客 SpringBoot中就绪探针和存活探针_management.endpoint.health.probes.enabled-CSDN博客 u2u转换板 - 嘉立创EDA开源硬件平台 Spring Boot 项目的轻量级 HTTP 客户端 retrofit 框架,快来试试它!_Java精选-CSDN博客 手把手教你打造一套最牛的知识笔记管理系统! - 知乎 - 想法有重合-理论可参考 安宇雨 闲鱼 机械键盘 客制化 开贴记录 文本 linux 使用find命令查找包含某字符串的文件_beijihukk的博客-CSDN博客_find 查找字符串 ---- mac 也适用 安宇雨 打字音 记录集合 B站 bilibili 自行搭建 开坑 真正的客制化 安宇雨 黑苹果开坑 查找工具包maven pom 引用地 工具网站 Dantelis 介绍的玩轴入坑攻略 --- 关于轴的一些说法 --- 非官方 ---- 心得而已 --- 长期开坑更新 [本人问题][新开坑位]关于自动化测试的工具与平台应用 机械键盘 开团 网站记录 -- 能做一个收集的程序就好了 不过现在没时间 -- 信息大多是在群里发的 - 你要让垃圾佬 都去一个地方看难度也是很大的 精神支柱 [超级前台]sprinbboot maven superdesk-app 记录 [信息有用] [环境准备] [基本完成] [sebp/elk] 给已创建的Docker容器增加新的端口映射 - qq_30599553的博客 - CSDN博客 [正在研究] Elasticsearch, Logstash, Kibana (ELK) Docker image documentation elasticsearch centos 安装记录 及 启动手记 正式服务器 39 elasticsearch 问题合集 不断更新 6.1.1 | 6.5.1 两个版本 博客程序 - 测试 - bug记录 等等问题 laravel的启动过程解析 - lpfuture - 博客园 OAuth2 Server PHP 用 Laravel 搭建带 OAuth2 验证的 RESTful 服务 | Laravel China 社区 - 高品质的 Laravel 和 PHP 开发者社区 利用Laravel 搭建oauth2 API接口 附 Unauthenticated 解决办法 - 煮茶的博客 - SegmentFault 思否 使用 OAuth2-Server-php 搭建 OAuth2 Server - 午时的海 - 博客园 基于PHP构建OAuth 2.0 服务端 认证平台 - Endv - 博客园 Laravel 的 Artisan 命令行工具 Laravel 的文件系统和云存储功能集成 浅谈Chromium中的设计模式--终--Observer模式 浅谈Chromium中的设计模式--二--pre/post和Delegate模式 浅谈Chromium中的设计模式--一--Chromium中模块分层和进程模型 DeepMind 4 Hacking Yourself README.md update 20211011
Laravel China 简书 知乎 博客园 CSDN博客 开源中国 Go Further Ryan是菜鸟 | LNMP技术栈笔记 云栖社区-阿里云 Netflix技术博客 Techie Delight Linkedin技术博客 Dropbox技术博客 Facebook技术博客 淘宝中间件团队 美团技术博客 360技术博客 古巷博客 - 一个专注于分享的不正常博客 软件测试知识传播 - 测试窝 有赞技术团队 阮一峰 语雀 静觅丨崔庆才的个人博客 软件测试从业者综合能力提升 - isTester IBM Java 开发 使用开放 Java 生态系统开发现代应用程序 pengdai 一个强大的博主 HTML5资源教程 | 分享HTML5开发资源和开发教程 蘑菇博客 - 专注于技术分享的博客平台 个人博客-leapMie 流星007 CSDN博客 - 舍其小伙伴 稀土掘金 Go 技术论坛 | Golang / Go 语言中国知识社区
最新评论