在上一次的博客中,我们介绍了如果实现一个最简单的线性回归的模型,今天我们来看一下,如何利用同样的思路实现更多的模型。
逻辑回归并非只能实现二分类,我们下面就看一个利用逻辑回归(Multinomial logistic regression)实现多分类的例子。
这个是训练数据:
这个是分类的结果。我们可以看到对某些点,蓝色和橙色,分类效果比较好;而对于绿色和红色的点,分类的结果不是很理想。
代码在这里:
function logistic_regression(train_data, train_label) {
const numIterations = 100;
const learningRate = 0.1;
const optimizer = tf.train.adam(learningRate);
//Caculate how many category do we have
const number_of_labels = Array.from(new Set(train_label)).length;
const number_of_data = train_label.length;
const w = tf.variable(tf.zeros([2,number_of_labels]));
const b = tf.variable(tf.zeros([number_of_labels]));
const train_x = tf.tensor2d(train_data);
const train_y = tf.tensor1d(train_label);
function predict(x) {
return tf.softmax(tf.add(tf.matMul(x, w),b));
}
function loss(predictions, labels) {
const y = tf.oneHot(labels,number_of_labels);
const entropy = tf.mean(tf.sub(tf.scalar(1),tf.sum(tf.mul(y, tf.log(predictions)),1)));
return entropy;
}
for (let iter = 0; iter < numIterations; iter++) {
optimizer.minimize(() => {
const loss_var = loss(predict(train_x), train_y);
loss_var.print();
return loss_var;
})
}
}
逻辑回归和之前的线性回归的过程基本类似,有几个要注意的地方:
训练数据
预测模型, 对于softmax这个模型,简单说,就是二元逻辑回归向更多元素向量的扩展。有兴趣进一步了解的可以去看下面的这两篇文章:
损失,损失函数用交叉熵,可以参考这篇文章。
这里在计算损失的时候,对于lable,调用tf.oneHot()方法,把对Label数据变换为以下形式
// Labels
Tensor
[0, 0, 1, 1, 2, 2]
// OneHot
Tensor
[[1, 0, 0],
[1, 0, 0],
[0, 1, 0],
[0, 1, 0],
[0, 0, 1],
[0, 0, 1]]
K近邻是一个特别简单的算法,简单到没有训练的过程。大家在我的另一篇关于机器学习的博客里,可以找到对这个算法的可视化介绍。
利用TensorflowJS也可以实现该算法。下面的代码使用L1距离来实现近邻算法(k=1)
function knn(train_data,train_label) {
const train_x = tf.tensor2d(train_data);
return function(x) {
var result = [];
x.map(function(point){
const input_tensor = tf.tensor1d(point);
const distance = tf.sum(tf.abs(tf.sub(input_tensor, train_x)),1);
const index = tf.argMin(distance, 0);
result.push(train_label[index.dataSync()[0]]);
});
return result;
};
}
这个算法虽然简单,但是计算量不小,预测的效果似乎也不比逻辑回归差呢。我们注意类似的数据和逻辑回归的差异。
同样的,对于一般的学习问题,TensorflowJS自然是不在话下。参考笔者的另一篇博客:一个利用Tensorflow求解几何问题的例子。
代码如下:
function train(train_data) {
const numIterations = 200;
const learningRate = 0.05;
const optimizer = tf.train.sgd(learningRate);
const training_data = tf.tensor2d(train_data);
const center = tf.variable(tf.tensor1d([Math.random()* Math.floor(domain_max),Math.random()* Math.floor(domain_max)]));
// Caculate the distance of this center point to the each point in the training data
const distance = function() {
return tf.pow(tf.sum(tf.pow(tf.sub(training_data, center),tf.scalar(2)),1),tf.scalar(1/2));
}
// Mean Square Error
const loss = function(dis) {
return tf.sum(tf.pow(tf.sub(dis,tf.mean(dis)),tf.scalar(2)));
}
for (let iter = 0; iter < numIterations; iter++) {
var result = {};
optimizer.minimize(() => {
const loss_var = loss(distance());
loss_var.print();
result.loss = loss_var.dataSync();
return loss_var;
})
}
return center;
}
运行效果如下:
Original url: Access
Created at: 2018-10-18 11:59:04
Category: default
Tags: none
未标明原创文章均为采集,版权归作者所有,转载无需和我联系,请注明原出处,南摩阿彌陀佛,知识,不只知道,要得到
java windows火焰图_mob64ca12ec8020的技术博客_51CTO博客 - 在windows下不可行,不知道作者是怎样搞的 监听SpringBoot 服务启动成功事件并打印信息_监听springboot启动完毕-CSDN博客 SpringBoot中就绪探针和存活探针_management.endpoint.health.probes.enabled-CSDN博客 u2u转换板 - 嘉立创EDA开源硬件平台 Spring Boot 项目的轻量级 HTTP 客户端 retrofit 框架,快来试试它!_Java精选-CSDN博客 手把手教你打造一套最牛的知识笔记管理系统! - 知乎 - 想法有重合-理论可参考 安宇雨 闲鱼 机械键盘 客制化 开贴记录 文本 linux 使用find命令查找包含某字符串的文件_beijihukk的博客-CSDN博客_find 查找字符串 ---- mac 也适用 安宇雨 打字音 记录集合 B站 bilibili 自行搭建 开坑 真正的客制化 安宇雨 黑苹果开坑 查找工具包maven pom 引用地 工具网站 Dantelis 介绍的玩轴入坑攻略 --- 关于轴的一些说法 --- 非官方 ---- 心得而已 --- 长期开坑更新 [本人问题][新开坑位]关于自动化测试的工具与平台应用 机械键盘 开团 网站记录 -- 能做一个收集的程序就好了 不过现在没时间 -- 信息大多是在群里发的 - 你要让垃圾佬 都去一个地方看难度也是很大的 精神支柱 [超级前台]sprinbboot maven superdesk-app 记录 [信息有用] [环境准备] [基本完成] [sebp/elk] 给已创建的Docker容器增加新的端口映射 - qq_30599553的博客 - CSDN博客 [正在研究] Elasticsearch, Logstash, Kibana (ELK) Docker image documentation elasticsearch centos 安装记录 及 启动手记 正式服务器 39 elasticsearch 问题合集 不断更新 6.1.1 | 6.5.1 两个版本 博客程序 - 测试 - bug记录 等等问题 laravel的启动过程解析 - lpfuture - 博客园 OAuth2 Server PHP 用 Laravel 搭建带 OAuth2 验证的 RESTful 服务 | Laravel China 社区 - 高品质的 Laravel 和 PHP 开发者社区 利用Laravel 搭建oauth2 API接口 附 Unauthenticated 解决办法 - 煮茶的博客 - SegmentFault 思否 使用 OAuth2-Server-php 搭建 OAuth2 Server - 午时的海 - 博客园 基于PHP构建OAuth 2.0 服务端 认证平台 - Endv - 博客园 Laravel 的 Artisan 命令行工具 Laravel 的文件系统和云存储功能集成 浅谈Chromium中的设计模式--终--Observer模式 浅谈Chromium中的设计模式--二--pre/post和Delegate模式 浅谈Chromium中的设计模式--一--Chromium中模块分层和进程模型 DeepMind 4 Hacking Yourself README.md update 20211011
Laravel China 简书 知乎 博客园 CSDN博客 开源中国 Go Further Ryan是菜鸟 | LNMP技术栈笔记 云栖社区-阿里云 Netflix技术博客 Techie Delight Linkedin技术博客 Dropbox技术博客 Facebook技术博客 淘宝中间件团队 美团技术博客 360技术博客 古巷博客 - 一个专注于分享的不正常博客 软件测试知识传播 - 测试窝 有赞技术团队 阮一峰 语雀 静觅丨崔庆才的个人博客 软件测试从业者综合能力提升 - isTester IBM Java 开发 使用开放 Java 生态系统开发现代应用程序 pengdai 一个强大的博主 HTML5资源教程 | 分享HTML5开发资源和开发教程 蘑菇博客 - 专注于技术分享的博客平台 个人博客-leapMie 流星007 CSDN博客 - 舍其小伙伴 稀土掘金 Go 技术论坛 | Golang / Go 语言中国知识社区
最新评论