在上一篇中,我们介绍了了用TensorflowJS构建一个神经网络,然后用该模型来进行手写MINST数据的识别。和之前的基本模型比起来,模型的准确率上升的似乎不是很大。(在我的例子中,验证部分比较简单,只是一个大致的统计)甚至有些情况下,如果参数选择不当,训练效果还会更差。
卷积网络,也叫做卷积神经网络(con-volutional neural network, CNN),是一种专门用来处理具有类似网格结构的数据的神经网络。例如时间序列数据(可以认为是在时间轴上有规律地采样形成的一维网格)和图像数据(可以看作是二维的像素网格)。对于MINST手写数据来说,应用卷积网络会不会是更好的选择呢?
先上图:
代码见Codepen
该图是我应用CNN对MINST数据进行训练的结果,准确率在97%,可以说和之前的模型来比较,提高显著。要知道,要知道在获得比较高的准确率后,要提高一点都是比较困难的。那我们就简单的看看卷积网络是什么,他为什么对于手写数据的识别做的比其他模型的更好?
CNN的原理实际上是模拟了人类的视觉神经如何识别图像。每个视觉神经只负责处理不同大小的一小块画面,在不同的神经层次处理不同的信息。
大家可能有用过Photoshop的经验,Photoshop提供很多不同类型的滤镜来处理图像,其实那个本质上就是应用不同的核函数对图像进行卷积的结果。
卷积操作如下图所示:
左边的矩形是输入数据,也就是我们要处理的图像的张量表示。中间的矩形是核,而右边的矩形就是卷积的结果。核函数从左至右,从上到下,每次移动一个像素扫描图像,计算出卷积和的结果矩阵。
卷积的计算过程如下图:
计算就是乘法和加法,但是上图的例子计算有个错误,看你找不找得到。下面这个图计算更简单一点:
如果你能够理解上图的数学含义你就能理解,核函数其实是一个权重,对于每一个小块的图像,不同的核对不同区域的权重不一样。
如上图的两个核,左面的对于图像中间的权重为0,上面的是负向加权,而下面的正向加权。可以想像对应于普通图像,数据分布均匀,这个加权计算的结果趋近于零,对应于水平边缘,上面没有数据而下面有数据,这个加权的值就比较大,这样我们就能够检测出水平边缘。同理右边核函数对应垂直边缘。
上图就是应用垂直,水平,垂直加水平的核,对安卓小机器人图像卷积的结果。我们可以看出对应的核函数是如何识别出边缘的。
然而在学习的时候要使用什么样的核呢?我们看一下网络结构:
每一个像素都是一个特征,每一个特征是一个输入节点。每一个卷积的结果都输入到下一层的隐藏节点。核的权重就连接了输入层和隐藏层。经过0填充的输入层可以输出不同形状的卷积结果。同时可以调整扫描的步幅(stride)。
上图中,输入为7*7,没有填充,步幅为1,输出为5*5
上图中,输入为5*5,填充1格,步幅为1,输出为5*5
上图中,输入为5*5,填充1格,步幅为2,输出为3*3
上图中,输入为2*2,填充2格,步幅为1,输出为4*4
我们可以看出来,增加填充会导致隐藏层节点数量增加,而增大步幅可以使得隐藏层的节点变少。
通过神经网络的学习,就能够确定核的权重。实际的应用,可能会有多个核,因为有许多的特征要学习。
就像我们之前看到如果要学习图像的轮廓,其实是两个不同的核的组合。
池化层通常是紧跟着卷积的一层,通常是做区域的均值或者最大值操作。如下图:
如下图,池化的策略通常是取最大值或者取均值。
池化的作用类似取样,使得下一层神经网络要处理的数据极大的缩小。减少整个网络的参数,防止出现过拟合。
通常,CNN网络的由如上图所示的层次构成:
在了解的基本的卷积网络的概念后,我们来看看如何在TensorflowJS中实现一个CNN。
下面是模型的代码:
function cnn() {
const model = tf.sequential();
model.add(tf.layers.conv2d({
inputShape: [28, 28, 1],
kernelSize: 5,
filters: 8,
strides: 1,
activation: 'relu',
kernelInitializer: 'varianceScaling'
}));
model.add(tf.layers.maxPooling2d({poolSize: [2, 2], strides: [2, 2]}));
model.add(tf.layers.conv2d({
kernelSize: 5,
filters: 16,
strides: 1,
activation: 'relu',
kernelInitializer: 'varianceScaling'
}));
model.add(tf.layers.maxPooling2d({poolSize: [2, 2], strides: [2, 2]}));
model.add(tf.layers.flatten());
model.add(tf.layers.dense(
{units: 10, kernelInitializer: 'varianceScaling', activation: 'softmax'}));
return model;
}
类似这样一个结构
训练的代码如下:
const model = cnn();
const LEARNING_RATE = 0.15;
const optimizer = tf.train.sgd(LEARNING_RATE);
model.compile({
optimizer: optimizer,
loss: 'categoricalCrossentropy',
metrics: ['accuracy'],
});
async function train() {
const BATCH_SIZE = 16;
const TRAIN_BATCHES = 1000;
const TEST_BATCH_SIZE = 100;
const TEST_ITERATION_FREQUENCY = 5;
for (let i = 0; i < TRAIN_BATCHES; i++) {
const batch = data.nextTrainBatch(BATCH_SIZE);
let testBatch;
let validationData;
// Every few batches test the accuracy of the mode.
if (i % TEST_ITERATION_FREQUENCY === 0 && i > 0 ) {
testBatch = data.nextTestBatch(TEST_BATCH_SIZE);
validationData = [
testBatch.xs.reshape([TEST_BATCH_SIZE, 28, 28, 1]), testBatch.labels
];
}
// The entire dataset doesn't fit into memory so we call fit repeatedly
// with batches.
const history = await model.fit(
batch.xs.reshape([BATCH_SIZE, 28, 28, 1]), batch.labels,
{batchSize: BATCH_SIZE, validationData, epochs: 1});
const loss = history.history.loss[0];
const accuracy = history.history.acc[0];
batch.xs.dispose();
batch.labels.dispose();
if (testBatch != null) {
testBatch.xs.dispose();
testBatch.labels.dispose();
}
await tf.nextFrame();
}
}
如果和之前的神经网络的训练的代码比较,这里唯一的变化就是输入数据的形状。
// For CNN
batch.xs.reshape([BATCH_SIZE, 28, 28, 1])
// For NN
batch.xs.reshape([BATCH_SIZE, 784])
这是由两个网络输入层的形状来决定的。
大家在选取模型的时候可以考虑CNN的优缺点。
优点:
缺点:
CNN是非常流行的深度学习的模型,广泛用于图像相关的有关领域,从阿尔法狗到自动驾驶,到处都有他的身影。如果大家希望进一步了解,可以研习下面的文章。
Original url: Access
Created at: 2018-10-18 11:50:19
Category: default
Tags: none
未标明原创文章均为采集,版权归作者所有,转载无需和我联系,请注明原出处,南摩阿彌陀佛,知识,不只知道,要得到
java windows火焰图_mob64ca12ec8020的技术博客_51CTO博客 - 在windows下不可行,不知道作者是怎样搞的 监听SpringBoot 服务启动成功事件并打印信息_监听springboot启动完毕-CSDN博客 SpringBoot中就绪探针和存活探针_management.endpoint.health.probes.enabled-CSDN博客 u2u转换板 - 嘉立创EDA开源硬件平台 Spring Boot 项目的轻量级 HTTP 客户端 retrofit 框架,快来试试它!_Java精选-CSDN博客 手把手教你打造一套最牛的知识笔记管理系统! - 知乎 - 想法有重合-理论可参考 安宇雨 闲鱼 机械键盘 客制化 开贴记录 文本 linux 使用find命令查找包含某字符串的文件_beijihukk的博客-CSDN博客_find 查找字符串 ---- mac 也适用 安宇雨 打字音 记录集合 B站 bilibili 自行搭建 开坑 真正的客制化 安宇雨 黑苹果开坑 查找工具包maven pom 引用地 工具网站 Dantelis 介绍的玩轴入坑攻略 --- 关于轴的一些说法 --- 非官方 ---- 心得而已 --- 长期开坑更新 [本人问题][新开坑位]关于自动化测试的工具与平台应用 机械键盘 开团 网站记录 -- 能做一个收集的程序就好了 不过现在没时间 -- 信息大多是在群里发的 - 你要让垃圾佬 都去一个地方看难度也是很大的 精神支柱 [超级前台]sprinbboot maven superdesk-app 记录 [信息有用] [环境准备] [基本完成] [sebp/elk] 给已创建的Docker容器增加新的端口映射 - qq_30599553的博客 - CSDN博客 [正在研究] Elasticsearch, Logstash, Kibana (ELK) Docker image documentation elasticsearch centos 安装记录 及 启动手记 正式服务器 39 elasticsearch 问题合集 不断更新 6.1.1 | 6.5.1 两个版本 博客程序 - 测试 - bug记录 等等问题 laravel的启动过程解析 - lpfuture - 博客园 OAuth2 Server PHP 用 Laravel 搭建带 OAuth2 验证的 RESTful 服务 | Laravel China 社区 - 高品质的 Laravel 和 PHP 开发者社区 利用Laravel 搭建oauth2 API接口 附 Unauthenticated 解决办法 - 煮茶的博客 - SegmentFault 思否 使用 OAuth2-Server-php 搭建 OAuth2 Server - 午时的海 - 博客园 基于PHP构建OAuth 2.0 服务端 认证平台 - Endv - 博客园 Laravel 的 Artisan 命令行工具 Laravel 的文件系统和云存储功能集成 浅谈Chromium中的设计模式--终--Observer模式 浅谈Chromium中的设计模式--二--pre/post和Delegate模式 浅谈Chromium中的设计模式--一--Chromium中模块分层和进程模型 DeepMind 4 Hacking Yourself README.md update 20211011
Laravel China 简书 知乎 博客园 CSDN博客 开源中国 Go Further Ryan是菜鸟 | LNMP技术栈笔记 云栖社区-阿里云 Netflix技术博客 Techie Delight Linkedin技术博客 Dropbox技术博客 Facebook技术博客 淘宝中间件团队 美团技术博客 360技术博客 古巷博客 - 一个专注于分享的不正常博客 软件测试知识传播 - 测试窝 有赞技术团队 阮一峰 语雀 静觅丨崔庆才的个人博客 软件测试从业者综合能力提升 - isTester IBM Java 开发 使用开放 Java 生态系统开发现代应用程序 pengdai 一个强大的博主 HTML5资源教程 | 分享HTML5开发资源和开发教程 蘑菇博客 - 专注于技术分享的博客平台 个人博客-leapMie 流星007 CSDN博客 - 舍其小伙伴 稀土掘金 Go 技术论坛 | Golang / Go 语言中国知识社区
最新评论