使用管道为Web服务器
关键是要理解我们可以使用迭代器,就像你在数据集上使用一样,因为网络服务器基本上是一个等待请求并处理它们的系统。
通常,Web服务器是多路复用的(多线程、异步等)以同时处理各种请求。另一方面,管道(以及大多数底层模型)并不真正适合并行处理;它们占用大量内存,因此在运行或执行计算密集型任务时,最好为它们提供所有可用的资源。
我们将通过让网络服务器处理接收和发送请求的轻负载,并让一个线程处理实际工作来解决这个问题。这个例子将使用starlette
。实际使用的框架并不重要,但如果你使用另一个框架来实现相同的效果,你可能需要调整或更改代码。
创建 server.py
:
from starlette.applications import Starlette
from starlette.responses import JSONResponse
from starlette.routing import Route
from transformers import pipeline
import asyncio
async def homepage(request):
payload = await request.body()
string = payload.decode("utf-8")
response_q = asyncio.Queue()
await request.app.model_queue.put((string, response_q))
output = await response_q.get()
return JSONResponse(output)
async def server_loop(q):
pipe = pipeline(model="google-bert/bert-base-uncased")
while True:
(string, response_q) = await q.get()
out = pipe(string)
await response_q.put(out)
app = Starlette(
routes=[
Route("/", homepage, methods=["POST"]),
],
)
@app.on_event("startup")
async def startup_event():
q = asyncio.Queue()
app.model_queue = q
asyncio.create_task(server_loop(q))
现在你可以开始使用:
uvicorn server:app
你可以查询它:
curl -X POST -d "test [MASK]" http://localhost:8000/
#[{"score":0.7742936015129089,"token":1012,"token_str":".","sequence":"test."},...]
好了,现在你对如何创建一个网络服务器有了很好的了解!
真正重要的是我们只加载模型一次,这样在Web服务器上就不会有模型的副本。这样就不会使用不必要的RAM。然后,排队机制允许你做一些高级操作,比如在推断之前积累一些项目以使用动态批处理:
下面的代码示例故意写成伪代码以提高可读性。 在检查是否适合您的系统资源之前,请不要运行此代码!
(string, rq) = await q.get()
strings = []
queues = []
while True:
try:
(string, rq) = await asyncio.wait_for(q.get(), timeout=0.001) # 1ms
except asyncio.exceptions.TimeoutError:
break
strings.append(string)
queues.append(rq)
strings
outs = pipe(strings, batch_size=len(strings))
for rq, out in zip(queues, outs):
await rq.put(out)
再次强调,所提出的代码是为了可读性而优化的,而不是为了成为最佳代码。 首先,没有批量大小限制,这通常不是一个好主意。其次,每次队列获取时都会重置超时,这意味着在运行推理之前,您可能会等待超过1毫秒(从而延迟第一个请求)。
最好有一个单一的1ms截止时间。
即使队列为空,这也总是会等待1毫秒,这可能不是最好的,因为如果队列中没有内容,你可能希望开始进行推理。 但如果批处理对你的用例确实非常关键,那么这可能是有意义的。 再次强调,确实没有一个最佳的解决方案。
你可能需要考虑的一些事情
错误检查
在生产环境中可能会出现很多问题:内存不足、空间不足、加载模型可能失败、查询可能出错、查询可能正确但由于模型配置错误仍然无法运行,等等。
通常来说,如果服务器能将错误输出给用户,这是很好的做法,因此添加大量的try..except
语句来显示这些错误是一个好主意。但请记住,根据您的安全上下文,揭示所有这些错误也可能是一个安全风险。
断路器
Web服务器在进行熔断时通常表现更好。这意味着当它们过载时,它们会返回适当的错误,而不是无限期地等待查询。返回503错误,而不是等待很长时间或长时间后返回504错误。
这在提议的代码中相对容易实现,因为只有一个队列。 查看队列大小是在你的网络服务器在负载下失败之前开始返回错误的基本方法。
阻塞主线程
目前 PyTorch 不支持异步操作,计算会在运行时阻塞主线程。这意味着如果 PyTorch 被强制在其自己的线程/进程中运行会更好。这里没有这样做是因为代码会变得更加复杂(主要是因为线程、异步和队列不能很好地协同工作)。但最终它做的事情是一样的。
如果单个项目的推断时间较长(> 1秒),这将非常重要,因为在这种情况下,意味着在推断期间的每个查询都必须等待1秒才能收到错误。
动态批处理
一般来说,批处理并不一定比一次传递1个项目更好(有关更多信息,请参见批处理详细信息)。但在正确的设置下,它可以非常有效。在API中,默认情况下没有动态批处理(太多可能导致减速的机会)。但对于BLOOM推理——这是一个非常大的模型——动态批处理对于为每个人提供良好的体验是至关重要的。
< > Update on GitHub