抱歉,您的浏览器无法访问本站
本页面需要浏览器支持(启用)JavaScript
了解详情 >

ControlNet

个人博客 << https://controlnet.space

在一些场合下,我们需要同时运行多个Python程序,并且希望这些Python进程之间能互相通讯,发送一些值或者接收一些值。本文我们就来测试一下Python的跨进程通信不同方案的效率。

本文包含的内容有:HTTP, websocket, multiprocessing, gRPC, RabbitMQ等。

背景介绍

请考虑以下场景,要处理一个数据,我们需要有3步比较耗时的操作,而这个每一步的操作需要上一步的结果,如下图所示。

flowchart LR
    Input --> Step1 --> Step2 --> Step3 --> Output

在这里,有两种使用多进程并行的思路,使用多个进程,每个进程接受一个数据,处理完全部三步之后返回结果,每个进程之间相互独立,如下图所示。

flowchart LR
    subgraph Process3
        direction LR
        Input3 --> p3s1[Step1] --> p3s2[Step2] --> p3s3[Step3] --> Output3
    end
    subgraph Process2
        direction LR
        Input2 --> p2s1[Step1] --> p2s2[Step2] --> p2s3[Step3] --> Output2
    end
    subgraph Process1
        direction LR
        Input1 --> p1s1[Step1] --> p1s2[Step2] --> p1s3[Step3] --> Output1
    end

这种方法操作简单,不需要进程间通信,也容易扩展到更多的进程数,在绝大多数情况下都推荐使用。然而这种模式需要将3个Step的上下文都载入内存中,如果这些Step是占用内存很高的深度学习模型,那么内存将会成为一个严重瓶颈。

为了解决这个问题,我们可以使用另一种模式将其并行。

flowchart LR
Input --> Step1 .-> Step2 .-> Step3 .-> Output
subgraph Process1
    Step1
end
subgraph Process2
    Step2
end
subgraph Process3
    Step3
end

其中这里的虚线表示进程间通信(IPC)。相比于第一种并行方式,这种并行方式操作复杂,需要进程间通信,但是可以有效的减少内存占用。

然而,这种并行相比于第一种方案,需要消耗额外的时间在IPC上,因此我们需要测试一下不同IPC方案的效率。

如果以上图表没有正确渲染,请刷新页面。

运行环境

以下实验都在以下环境运行:

1
2
3
4
5
6
7
8
9
10
CPU: i7-10900X
RAM: 128GB

Python: 3.10.8
websockets: 10.4
fastapi: 0.88.0
grpcio: 1.56.2
pika: 1.3.2
aio-pika: 9.2.0
numpy: 1.23.4

需要通讯的内容为4种不同尺寸的numpy.ndarray[1],数据类型为float64。分别为:

  • (1, 3, 224, 224), 模拟一张图片
  • (2, 1024, 16, 16), 模拟两张图像的低维特征
  • (2, 3, 16, 224, 224), 模拟两个16帧的视频片段
  • (128, 1024, 1024), 模拟128个序列特征

方案介绍

方案1: HTTP + JSON 序列化

这是最简单的方案,使用HTTP协议作为通信协议,将ndarray转换成Python的嵌套List,然后作为json发送。这种方案的优点是实现简单,不需要额外的依赖,缺点是从ndarrayList互相转换的开销大,而且json序列化的开销也很大。

这里HTTP通过fastapi[2]实现,fastapi是一个高性能的异步框架,可以很好的支持大量的并发请求。

方案2: HTTP + Base64 Bytes

这种方案和方案1类似,不过将ndarray通过numpy内置的方法转换成bytes,然后使用base64编码,这样可以避免ndarrayList之间的转换,但是HTTP传输大规模的base64编码的开销也很大。

方案3: Websocket + Bytes

这种方案和方案2类似,不过使用Websocket作为通信协议,这样可以避免HTTP的开销。因为Websocket可以发送ascii之外的字节,所以不需要base64编码。

方案4: HTTP + Shared Memory

这里采用了multiprocessing.shared_memory模块,使用SharedMemory对象将ndarray的地址共享给子进程。然后将SharedMemory对象名字作为HTTP的返回值,客户端再通过名字获取SharedMemory对象,这样可以避免ndarrayList之间的转换,也避免了base64编码的开销。

方案5: Websocket + Shared Memory

这种方案和方案4类似,不过使用Websocket作为通信协议,这样可以避免HTTP的开销。

方案6: Multiprocessing Listener / Client

这种方案使用multiprocessing模块的ListenerClient对象,使用multiprocessingPipe作为通信协议,这样可以避免HTTP的开销。

方案7: gRPC + Bytes

这种方案使用gRPC[3]作为通信协议,使用protobuf作为序列化协议,好处是方便客户端进行调用,但是gRPC有最大的消息长度限制(2GB)。

方案8: RabbitMQ + Bytes

这种方案使用RabbitMQ[4]作为通信协议,使用pika[5]作为Python的客户端和服务端。这种方案的好处是可以使用RabbitMQ的其他特性,比如消息队列,消息持久化等,但是有最大的消息长度限制(512MB)。

测试结果

ndarray Shape (1, 3, 224, 224) (2, 1024, 16, 16) (2, 3, 16, 224, 224) (128, 1024, 1024)
HTTP + JSON 290.00 ms 1090 ms 9230 ms 259.00 s
HTTP + Base64 Bytes 26.40 ms 51.5 ms 398 ms 12.30 s
Websocket + Bytes 4.27 ms 15.0 ms 171 ms 5.14 s
HTTP +
Shared Memory
10.70 ms 18.1 ms 127 ms 3.13 s
Websocket +
Shared Memory
4.34 ms 14.9 ms 127 ms 3.82 s
Multiprocessing
Listener
7.00 ms 17.2 ms 162 ms 4.73 s
gRPC + Bytes 7.34 ms 28.6 ms 291 ms 7.92 s
RabbitMQ + Bytes 9.35 ms 25.7 ms 243 ms 超出消息长度

根据这个结果我们可以发现,方案4和方案5的性能是最好的,方案6的性能也很好,方案1和方案2的性能最差。

考虑网络传输协议,Websocket的性能是比HTTP好的。所以应该尽量使用Websocket作为网络传输协议。

考虑使用Base64还是Shared Memory,我们可以发现大数据的情况下,Shared Memory的性能是比较好的,但是它需要手动管理内存,可能会有一些问题。所以对于小数据,可以使用Base64,对于大数据,可以使用Shared Memory。

对于multiprocessing模块的ListenerClient,它的性能略弱于Shared Memory,但是它不需要手动管理共享内存,而且它不需要用fastapi之类的外部库,而且不需要转换成别的类型的数据,比较方便。但是正因为它没有使用fastapi,以至于它不是很方便的进行异步处理。

gRPCRabbitMQ的性能比较差,只比HTTP好一点点,比不上Websocket,所以不推荐使用。而且它们有最大的消息大小限制,所以传输数据时不太方便。

在写以下比较重复的代码实现的时候,Github Copilot[6]起到了很大的帮助。

代码实现

方案1: HTTP + JSON 序列化

1
2
3
4
5
6
7
8
9
10
# server.py
import numpy as np
from fastapi import FastAPI

app = FastAPI()
shape = (1, 3, 224, 224)

@app.get("/random_tolist")
async def random_tolist():
return np.random.randn(*shape).tolist()
1
2
3
4
5
6
7
# client.py
import requests
import numpy as np

result = requests.get("http://127.0.0.1:1234/random_tolist")
result = np.array(result.json())
print(result.shape)

方案2: HTTP + Base64 Bytes

1
2
3
4
5
6
7
8
9
10
11
12
# server.py
import numpy as np
import base64
from fastapi import FastAPI
from fastapi.responses import PlainTextResponse

app = FastAPI()
shape = (1, 3, 224, 224)

@app.get("/random_tobytes", response_class=PlainTextResponse)
async def random_tobytes():
return base64.b64encode(np.random.randn(*shape).tobytes())
1
2
3
4
5
6
7
8
# client.py
import requests
import numpy as np
import base64

result = requests.get("http://127.0.0.1:1234/random_tobytes")
result = np.frombuffer(base64.b64decode(result.text)).reshape(*shape)
print(result.shape)

方案3: Websocket + Bytes

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# server.py
import numpy as np
from fastapi import FastAPI, WebSocket

app = FastAPI()
shape = (1, 3, 224, 224)

@app.websocket("/ws/random_tobytes")
async def websocket_random_tobytes(websocket: WebSocket):
await websocket.accept()
while True:
await websocket.receive_text()
print("Processing websocket")
await websocket.send_text(np.random.randn(*shape).tobytes())

1
2
3
4
5
6
7
8
9
# client.py
import numpy as np
from websocket import create_connection
ws = create_connection("ws://127.0.0.1:1234/ws/random_tobytes")

ws.send("")
result = np.frombuffer(ws.recv()).reshape(*shape)
print(result.shape)

方案4: HTTP + Shared Memory

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# server.py
import numpy as np
from fastapi import FastAPI
from fastapi.responses import PlainTextResponse
from multiprocessing import shared_memory

app = FastAPI()
shape = (1, 3, 224, 224)

@app.get("/random_sharedmemory", response_class=PlainTextResponse)
async def random_sharedmemory():
arr = np.random.randn(*shape)
shm = shared_memory.SharedMemory(create=True, size=arr.nbytes)
out = np.ndarray(arr.shape, dtype=arr.dtype, buffer=shm.buf)
out[:] = arr[:]
shm.close()
return shm.name
1
2
3
4
5
6
7
8
9
10
11
12
13
# client.py
import numpy as np
import requests
from multiprocessing import shared_memory

shm_name = requests.get("http://127.0.0.1:1234/random_sharedmemory")
shm = shared_memory.SharedMemory(name=shm_name.text)
result = np.ndarray(shape, dtype=float, buffer=shm.buf)

shm.close()
shm.unlink()

print(result.shape)

方案5: Websocket + Shared Memory

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# server.py
import numpy as np
from fastapi import FastAPI, WebSocket
from multiprocessing import shared_memory

app = FastAPI()
shape = (1, 3, 224, 224)

@app.websocket("/ws/random_sharedmemory")
async def websocket_random_sharedmemory(websocket: WebSocket):
await websocket.accept()
while True:
await websocket.receive_text()
print("Processing websocket")
arr = np.random.randn(*shape)
shm = shared_memory.SharedMemory(create=True, size=arr.nbytes)
out = np.ndarray(arr.shape, dtype=arr.dtype, buffer=shm.buf)
out[:] = arr[:]
shm.close()
print(shm.name)
await websocket.send_text(shm.name)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# client.py
import numpy as np
from websocket import create_connection
from multiprocessing import shared_memory

ws = create_connection("ws://127.0.0.1:1234/ws/random_sharedmemory")
ws.send("")
shm_name = ws.recv()
shm = shared_memory.SharedMemory(name=shm_name)
result = np.ndarray(shape, dtype=float, buffer=shm.buf)

shm.close()
shm.unlink()

print(result.shape)

方案6: Multiprocessing Listener / Client

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# server.py
from multiprocessing.connection import Listener
import numpy as np

shape = (1, 3, 224, 224)

address = ('localhost', 1234)
listener = Listener(address)
conn = listener.accept()

while True:
msg = conn.recv()
print("Processing")
conn.send(np.random.randn(*shape))

listener.close()
1
2
3
4
5
6
# client.py
from multiprocessing.connection import Client
conn = Client(("localhost", 1234))
conn.send("")
result = conn.recv()
print(result.shape)

方案7: gRPC + Bytes

1
2
3
4
5
6
7
8
9
10
11
// npy.proto
syntax = "proto3";
import "google/protobuf/empty.proto";

service Npy {
rpc Get (google.protobuf.Empty) returns (ArrayData) {}
}

message ArrayData {
bytes body = 1;
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
# server.py
from concurrent import futures
import time

import grpc

import npy_pb2
import npy_pb2_grpc
import numpy as np

shape = (1, 3, 224, 224)

class Npy(npy_pb2_grpc.NpyServicer):
def Get(self, request, context):
arr = np.random.randn(*shape)
return npy_pb2.ArrayData(body=np.ndarray.tobytes(arr))


def serve():
port = "50051"
server = grpc.server(futures.ThreadPoolExecutor(max_workers=8),
options=[
('grpc.max_send_message_length', 2 * 1024**3 - 1),
('grpc.max_receive_message_length', 2 * 1024**3 - 1),
])
npy_pb2_grpc.add_NpyServicer_to_server(Npy(), server)
server.add_insecure_port("[::]:" + port)
server.start()
print("Server started, listening on " + port)
server.wait_for_termination()


if __name__ == '__main__':
serve()
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# client.py
from __future__ import print_function

import grpc
import numpy as np

import npy_pb2
import npy_pb2_grpc

shape = (1, 3, 224, 224)
channel = grpc.insecure_channel('localhost:50051',
options=[
('grpc.max_send_message_length', 2 * 1024**3 - 1),
('grpc.max_receive_message_length', 2 * 1024**3 - 1),
]
)
stub = npy_pb2_grpc.NpyStub(channel)
response = stub.Get(npy_pb2.google_dot_protobuf_dot_empty__pb2.Empty())
arr = np.frombuffer(response.body, dtype=np.float64).reshape(shape)
print(arr.shape)

方案8: RabbitMQ + Bytes

部署RabbitMQ的Docker容器:

1
docker run --name some-rabbit -p 5672:5672 rabbitmq:3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
# server.py
import pika
import numpy as np

connection = pika.BlockingConnection(
pika.ConnectionParameters(host='localhost'))

channel = connection.channel()
channel.queue_declare(queue='rpc_queue')

shape = (1, 3, 224, 224)

def on_request(ch, method, props, body):

arr = np.random.randn(*shape)
data = np.ndarray.tobytes(arr)

ch.basic_publish(exchange='',
routing_key=props.reply_to,
properties=pika.BasicProperties(correlation_id=props.correlation_id),
body=data)
ch.basic_ack(delivery_tag=method.delivery_tag)


channel.basic_qos(prefetch_count=1)
channel.basic_consume(queue='rpc_queue', on_message_callback=on_request)

print(" [x] Awaiting RPC requests")
channel.start_consuming()
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
# client.py
import pika
import uuid

connection = pika.BlockingConnection(pika.ConnectionParameters(host='localhost'))
channel = connection.channel()
result = channel.queue_declare(queue='', exclusive=True)
callback_queue = result.method.queue

response = None
corr_id = None

shape = (128, 1024, 1024)

def on_response(ch, method, props, body):
global response
if corr_id == props.correlation_id:
response = body

channel.basic_consume(
queue=callback_queue,
on_message_callback=on_response,
auto_ack=True
)

def call():
global response
global corr_id
response = None
corr_id = str(uuid.uuid4())
channel.basic_publish(
exchange='',
routing_key='rpc_queue',
properties=pika.BasicProperties(
reply_to=callback_queue,
correlation_id=corr_id,
),
body=""
)
connection.process_data_events(time_limit=None)
return response

response = call()
arr = np.frombuffer(response, dtype=np.float64).reshape(shape)
print(arr.shape)

参考文献

  • [1] "NumPy", Numpy.org, 2022. https://numpy.org/.
  • [2] "FastAPI", FastAPI, 2022. https://fastapi.tiangolo.com/
  • [3] "gRPC" gRPC. https://grpc.io/
  • [4] “Messaging that just works — RabbitMQ,” Rabbitmq.com, 2019. https://www.rabbitmq.com/
  • [5] pika, “Pika,” GitHub, Jul. 28, 2023. https://github.com/pika/pika (accessed Jul. 28, 2023).
  • [6] "GitHub Copilot · Your AI pair programmer," GitHub, 2022. https://github.com/features/copilot

评论