在利用 kashgari
训练 BERT+BiLSTM+CRF
模型后,如何将模型预测结果进行部署是一个很重要的问题。按照 kashgari的官方文档 介绍,考虑采用 Tensorflow serving
来做模型预测的部署。
TensorFlow Serving 是一个用于机器学习模型 serving 的高性能开源库。它可以将训练好的机器学习模型部署到线上,使用 gRPC 作为接口接受外部调用。更加让人眼前一亮的是,它支持模型热更新与自动模型版本管理。这意味着一旦部署 TensorFlow Serving 后,就再也不需要为线上服务操心,只需要关心线下的模型训练。
TensorFlow Serving 可以方便我们部署 TensorFlow 模型,可以使用 TensorFlow Serving 的 Docker 镜像来使用 TensorFlow Serving ,安装命令如下:
docker pull tensorflow/serving
接下来将演示如何利用 tensorflow serving 来部署 kashgari 中的模型,项目结构如下:
上图中的 data 是标注的 NER 数据集,即标注出文本中的时间,采用 BIO 标注规则。chinese_wwm_ext 文件夹为哈工大的预训练模型文件。model_train.py 为模型训练的代码,主要功能是完成时间序列标注模型的训练,完整的代码如下:
# -*- coding: utf-8 -*-
# time: 2019-09-12
# place: Huangcun Beijing
import kashgari
from kashgari import utils
from kashgari.corpus import DataReader
from kashgari.embeddings import BERTEmbedding
from kashgari.tasks.labeling import BiLSTM_CRF_Model
# 模型训练
train_x, train_y = DataReader().read_conll_format_file('./data/time.train')
valid_x, valid_y = DataReader().read_conll_format_file('./data/time.dev')
test_x, test_y = DataReader().read_conll_format_file('./data/time.test')
bert_embedding = BERTEmbedding('chinese_wwm_ext_L-12_H-768_A-12',
task=kashgari.LABELING,
sequence_length=128)
model = BiLSTM_CRF_Model(bert_embedding)
model.fit(train_x, train_y, valid_x, valid_y, batch_size=16, epochs=1)
# Save model
utils.convert_to_saved_model(model,
model_path='saved_model/time_entity',
version=1)
运行该代码,模型训练完后会生成saved_model文件夹,里面含有模型训练好后的文件,方便利用 tensorflow/serving
进行部署。接着利用 tensorflow/serving
来完成模型的部署,命令如下:
docker run -t --rm -p 8501:8501 -v "/Users/jclian/PycharmProjects/kashgari_tf_serving/saved_model:/models/" -e MODEL_NAME=time_entity tensorflow/serving
需要注意该模型所在的路径,路径需要写完整路径,以及模型的名称(MODEL_NAME),这在训练代码(train.py)中已经给出(saved_model/time_entity)。接下来使用 tornado 来搭建 HTTP 服务,实现高并发地模型预测, runServer.py 的完整代码如下:
# -*- coding: utf-8 -*-
import requests
from kashgari import utils
import numpy as np
from model_predict import get_predict
import json
import tornado.httpserver
import tornado.ioloop
import tornado.options
import tornado.web
from tornado.options import define, options
import traceback
# tornado高并发
import tornado.web
import tornado.gen
import tornado.concurrent
from concurrent.futures import ThreadPoolExecutor
# 定义端口为12333
define("port", default=16016, help="run on the given port", type=int)
# 模型预测
class ModelPredictHandler(tornado.web.RequestHandler):
executor = ThreadPoolExecutor(max_workers=5)
# get 函数
@tornado.gen.coroutine
def get(self):
origin_text = self.get_argument('text')
result = yield self.function(origin_text)
self.write(json.dumps(result, ensure_ascii=False))
@tornado.concurrent.run_on_executor
def function(self, text):
try:
text = text.replace(' ', '')
x = [_ for _ in text]
# Pre-processor data
processor = utils.load_processor(model_path='saved_model/time_entity/1')
tensor = processor.process_x_dataset([x])
# only for bert Embedding
tensor = [{
"Input-Token:0": i.tolist(),
"Input-Segment:0": np.zeros(i.shape).tolist()
} for i in tensor]
# predict
r = requests.post("http://localhost:8501/v1/models/time_entity:predict", json={"instances": tensor})
preds = r.json()['predictions']
# Convert result back to labels
labels = processor.reverse_numerize_label_sequences(np.array(preds).argmax(-1))
entities = get_predict('TIME', text, labels[0])
return entities
except Exception:
self.write(traceback.format_exc().replace('\n', '<br>'))
# get请求
class HelloHandler(tornado.web.RequestHandler):
def get(self):
self.write('Hello from lmj from Daxing Beijing!')
# 主函数
def main():
# 开启tornado服务
tornado.options.parse_command_line()
# 定义app
app = tornado.web.Application(
handlers=[(r'/model_predict', ModelPredictHandler),
(r'/hello', HelloHandler),
], #网页路径控制
)
http_server = tornado.httpserver.HTTPServer(app)
http_server.listen(options.port)
tornado.ioloop.IOLoop.instance().start()
main()
定义了 tornado 封装 HTTP 服务来进行模型预测,运行该脚本,启动模型预测的 HTTP 服务。接着再使用 Python 脚本测试模型的预测效果以及预测时间,预测的代码脚本的完整代码如下:
import time
import json
import requests
t1 = time.time()
texts = ['据《新闻联播》报道,9月9日至11日,中央纪委书记赵乐际到河北调研。',
'记者从国家发展改革委、商务部相关方面获悉,日前美方已决定对拟于10月1日实施的中国输美商品加征关税措施做出调整,中方支持相关企业从即日起按照市场化原则和WTO规则,自美采购一定数量大豆、猪肉等农产品,国务院关税税则委员会将对上述采购予以加征关税排除。',
'据印度Zee新闻网站12日报道,亚洲新闻国际通讯社援引印度军方消息人士的话说,9月11日的对峙事件发生在靠近班公错北岸的实际控制线一带。',
'儋州市决定,从9月开始,对城市低保、农村低保、特困供养人员、优抚对象、领取失业保险金人员、建档立卡未脱贫人口等低收入群体共3万多人,发放猪肉价格补贴,每人每月发放不低于100元补贴,以后发放标准,将根据猪肉价波动情况进行动态调整。',
'9月11日,华为心声社区发布美国经济学家托马斯.弗里德曼在《纽约时报》上的专栏内容,弗里德曼透露,在与华为创始人任正非最近一次采访中,任正非表示华为愿意与美国司法部展开话题不设限的讨论。',
'造血干细胞移植治疗白血病技术已日益成熟,然而,通过该方法同时治愈艾滋病目前还是一道全球尚在攻克的难题。',
'英国航空事故调查局(AAIB)近日披露,今年2月6日一趟由德国法兰克福飞往墨西哥坎昆的航班上,因飞行员打翻咖啡使操作面板冒烟,导致飞机折返迫降爱尔兰。',
'当地时间周四(9月12日),印度尼西亚财政部长英卓华(Sri Mulyani Indrawati)明确表示:特朗普的推特是风险之一。',
'华中科技大学9月12日通过其官方网站发布通报称,9月2日,我校一硕士研究生不幸坠楼身亡。',
'微博用户@ooooviki 9月12日下午公布发生在自己身上的惊悚遭遇:一个自称网警、名叫郑洋的人利用职务之便,查到她的完备的个人信息,包括但不限于身份证号、家庭地址、电话号码、户籍变动情况等,要求她做他女朋友。',
'今天,贵阳取消了汽车限购,成为目前全国实行限购政策的9个省市中,首个取消限购的城市。',
'据悉,与全球同步,中国区此次将于9月13日于iPhone官方渠道和京东正式开启预售,京东成Apple中国区唯一官方授权预售渠道。',
'根据央行公布的数据,截至2019年6月末,存款类金融机构住户部门短期消费贷款规模为9.11万亿元,2019年上半年该项净增3293.19亿元,上半年增量看起来并不乐观。',
'9月11日,一段拍摄浙江万里学院学生食堂的视频走红网络,视频显示该学校食堂不仅在用餐区域设置了可以看电影、比赛的大屏幕,还推出了“一人食”餐位。',
'当日,在北京举行的2019年国际篮联篮球世界杯半决赛中,西班牙队对阵澳大利亚队。',
]
print(len(texts))
for text in texts:
url = 'http://localhost:16016/model_predict?text=%s' % text
req = requests.get(url)
print(json.loads(req.content))
t2 = time.time()
print(round(t2-t1, 4))
运行该代码,输出的结果如下:
一共预测15个句子。
['9月9日至11日']
['日前', '10月1日', '即日']
['12日', '9月11日']
['9月']
['9月11日']
[]
['近日', '今年2月6日']
['当地时间周四(9月12日)']
['9月12日', '9月2日']
['9月12日下午']
['今天', '目前']
['9月13日']
['2019年6月末', '2019年上半年', '上半年']
['9月11日']
['当日', '2019年']
预测耗时: 15.1085s.
模型预测效果不错,但平均每句话的预测时间为1秒多,稍微偏长,有待优化模型以缩短预测时间。
参考资料: