TensorFlow Serving可以将离线训练好的机器学习模型轻松部署到线上,使用gRPC作为接口供外部调用。并且TensorFlow Serving可以支持模型热更新与自动模型版本管理,这可以让算法工作者将工作重心放在离线模型的效果优化上,而不用为线上服务操心。
本文就介绍如何用TensorFlow Serving搭建线性回归预测服务,当然针对这个线性回归任务你可以把训练好的w、b两个参数直接写到代码里,本文只是用简单例子做入门。
下面列出本次实验的环境:
环境准备好后,接下来我们用python写一个训练和保存模型的代码train.py。
模型代码:
import numpy as np
import tensorflow as tf
import tensorflow.contrib.session_bundle.exporter as exporter
# Generate input data
n_samples = 1000
x_data = np.arange(100, step=.1)
y_data = x_data + 20 * np.sin(x_data / 10)
x_data = np.reshape(x_data, (n_samples, 1))
y_data = np.reshape(y_data, (n_samples, 1))
sample = 1000
learning_rate = 0.01
batch_size = 100
n_steps = 500
# Placeholders for batched input
x = tf.placeholder(tf.float32, shape=(batch_size, 1))
y = tf.placeholder(tf.float32, shape=(batch_size, 1))
with tf.variable_scope('test'):
w = tf.get_variable('weights', (1, 1), initializer=tf.random\_normal\_initializer())
b = tf.get_variable('bias', (1,), initializer=tf.constant_initializer(0))
y_pred = tf.matmul(x, w) + b
loss = tf.reduce_sum((y - y_pred) ** 2 / n_samples)
opt = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss)
with tf.Session() as sess:
sess.run(tf.initialize\_all\_variables())
for _ in range(n_steps):
indices = np.random.choice(n_samples, batch_size)
x_batch = x_data\[indices\]
y_batch = y_data\[indices\]
_, loss_val = sess.run(\[opt, loss\], feed_dict={x:x_batch, y:y_batch})
print(w.eval())
print(b.eval())
print(loss_val)
saver = tf.train.Saver()
model_exporter = exporter.Exporter(saver)
model_exporter.init(
sess.graph.as\_graph\_def(),
named\_graph\_signatures={
'inputs': exporter.generic_signature({'x': x}),
'outputs': exporter.generic_signature({'y': y_pred})})
model_exporter.export("/tmp/linear-regression/",
tf.constant("1"),
sess)
运行代码:
python train.py
[[1.0108936]]
[1.9290849]
19.266457
上面结果表示w=1.0108936,b=1.9290849,y = 1.0108936 * x + 1.9290849
/tmp/linear-regression/ 目录下有以下文件(我训练了两次,第二次的版本号指定为2,所以有两个文件夹):
.
├── 00000001
│ ├── checkpoint
│ ├── export.data-00000-of-00001
│ ├── export.index
│ └── export.meta
└── 00000002
├── checkpoint
├── export.data-00000-of-00001
├── export.index
└── export.meta
# 下载tensorflow_serving源码
git clone https://github.com/tensorflow/serving.git
# 编译tensorflow_model_server
bazel build //tensorflow_serving/model_servers:tensorflow_model_server
# 启动服务
bazel-bin/tensorflow_serving/model_servers/tensorflow_model_server --port=9000 --model_name=test --model_base_path=/tmp/linear-regression
执行上述命令后,出现以下输出,就表示部署成功了。
Running ModelServer at 0.0.0.0:9000 ...
模型部署完后,接下来我们用java来编写线上请求服务的代码:
ManagedChannel channel = null;
try {
channel = ManagedChannelBuilder.forAddress("服务所部属的ip地址", 9000).usePlaintext(true).build();
PredictionServiceGrpc.PredictionServiceBlockingStub stub =
PredictionServiceGrpc.newBlockingStub(channel);
Predict.PredictRequest.Builder predictRequestBuilder = Predict.PredictRequest.newBuilder();
Model.ModelSpec.Builder modelSpecBuilder = Model.ModelSpec.newBuilder();
modelSpecBuilder.setName("test");
predictRequestBuilder.setModelSpec(modelSpecBuilder);
TensorProto.Builder tensorProtoBuilder = TensorProto.newBuilder();
tensorProtoBuilder.setDtype(DataType.DT_FLOAT);
TensorShapeProto.Builder tensorShapeBuilder = TensorShapeProto.newBuilder();
tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(1));
tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(1));
tensorProtoBuilder.setTensorShape(tensorShapeBuilder.build());
List<Float> floatList = new ArrayList<>();
Random random = new Random();
float x = random.nextFloat();
floatList.add(x);
tensorProtoBuilder.addAllFloatVal(floatList);
predictRequestBuilder.putInputs("x", tensorProtoBuilder.build());
Predict.PredictResponse predictResponse = stub.predict(predictRequestBuilder.build());
LOG.debug("x={}, result={}", x,
predictResponse.getOutputsOrThrow("y").getFloatValList().toString());
} catch (StatusRuntimeException e) {
LOG.error("StatusRuntimeException: ", e);
} finally {
if (channel != null) {
channel.shutdown();
}
}
依赖代码:
<dependency>
<groupId>com.yesup.oss</groupId>
<artifactId>tensorflow-client</artifactId>
<version>1.4-2</version>
</dependency>
请求后:
x=0.033170342, result=[1.9626166]
x=0.90723795, result=[2.846206]
更改train.py里的版本号,比如1改为2,然后执行。
执行完后,模型会进行自动更新。
参数如下:
[[1.0106797]]
[2.5606086]
20.278612
重新请求后:
x=0.8308283, result=[3.40031]
参考了不少博客,这里只列出两篇实用的。
原网址: 访问
创建于: 2019-04-12 01:29:31
目录: default
标签: 无
未标明原创文章均为采集,版权归作者所有,转载无需和我联系,请注明原出处,南摩阿彌陀佛,知识,不只知道,要得到
java windows火焰图_mob64ca12ec8020的技术博客_51CTO博客 - 在windows下不可行,不知道作者是怎样搞的 监听SpringBoot 服务启动成功事件并打印信息_监听springboot启动完毕-CSDN博客 SpringBoot中就绪探针和存活探针_management.endpoint.health.probes.enabled-CSDN博客 u2u转换板 - 嘉立创EDA开源硬件平台 Spring Boot 项目的轻量级 HTTP 客户端 retrofit 框架,快来试试它!_Java精选-CSDN博客 手把手教你打造一套最牛的知识笔记管理系统! - 知乎 - 想法有重合-理论可参考 安宇雨 闲鱼 机械键盘 客制化 开贴记录 文本 linux 使用find命令查找包含某字符串的文件_beijihukk的博客-CSDN博客_find 查找字符串 ---- mac 也适用 安宇雨 打字音 记录集合 B站 bilibili 自行搭建 开坑 真正的客制化 安宇雨 黑苹果开坑 查找工具包maven pom 引用地 工具网站 Dantelis 介绍的玩轴入坑攻略 --- 关于轴的一些说法 --- 非官方 ---- 心得而已 --- 长期开坑更新 [本人问题][新开坑位]关于自动化测试的工具与平台应用 机械键盘 开团 网站记录 -- 能做一个收集的程序就好了 不过现在没时间 -- 信息大多是在群里发的 - 你要让垃圾佬 都去一个地方看难度也是很大的 精神支柱 [超级前台]sprinbboot maven superdesk-app 记录 [信息有用] [环境准备] [基本完成] [sebp/elk] 给已创建的Docker容器增加新的端口映射 - qq_30599553的博客 - CSDN博客 [正在研究] Elasticsearch, Logstash, Kibana (ELK) Docker image documentation elasticsearch centos 安装记录 及 启动手记 正式服务器 39 elasticsearch 问题合集 不断更新 6.1.1 | 6.5.1 两个版本 博客程序 - 测试 - bug记录 等等问题 laravel的启动过程解析 - lpfuture - 博客园 OAuth2 Server PHP 用 Laravel 搭建带 OAuth2 验证的 RESTful 服务 | Laravel China 社区 - 高品质的 Laravel 和 PHP 开发者社区 利用Laravel 搭建oauth2 API接口 附 Unauthenticated 解决办法 - 煮茶的博客 - SegmentFault 思否 使用 OAuth2-Server-php 搭建 OAuth2 Server - 午时的海 - 博客园 基于PHP构建OAuth 2.0 服务端 认证平台 - Endv - 博客园 Laravel 的 Artisan 命令行工具 Laravel 的文件系统和云存储功能集成 浅谈Chromium中的设计模式--终--Observer模式 浅谈Chromium中的设计模式--二--pre/post和Delegate模式 浅谈Chromium中的设计模式--一--Chromium中模块分层和进程模型 DeepMind 4 Hacking Yourself README.md update 20211011
Laravel China 简书 知乎 博客园 CSDN博客 开源中国 Go Further Ryan是菜鸟 | LNMP技术栈笔记 云栖社区-阿里云 Netflix技术博客 Techie Delight Linkedin技术博客 Dropbox技术博客 Facebook技术博客 淘宝中间件团队 美团技术博客 360技术博客 古巷博客 - 一个专注于分享的不正常博客 软件测试知识传播 - 测试窝 有赞技术团队 阮一峰 语雀 静觅丨崔庆才的个人博客 软件测试从业者综合能力提升 - isTester IBM Java 开发 使用开放 Java 生态系统开发现代应用程序 pengdai 一个强大的博主 HTML5资源教程 | 分享HTML5开发资源和开发教程 蘑菇博客 - 专注于技术分享的博客平台 个人博客-leapMie 流星007 CSDN博客 - 舍其小伙伴 稀土掘金 Go 技术论坛 | Golang / Go 语言中国知识社区
最新评论