在一些场合下,我们需要同时运行多个Python程序,并且希望这些Python进程之间能互相通讯,发送一些值或者接收一些值。本文我们就来测试一下Python的跨进程通信不同方案的效率。
本文包含的内容有:HTTP, websocket, multiprocessing, gRPC, RabbitMQ等。
背景介绍 请考虑以下场景,要处理一个数据,我们需要有3步比较耗时的操作,而这个每一步的操作需要上一步的结果,如下图所示。
flowchart LR
Input --> Step1 --> Step2 --> Step3 --> Output
在这里,有两种使用多进程并行的思路,使用多个进程,每个进程接受一个数据,处理完全部三步之后返回结果,每个进程之间相互独立,如下图所示。
flowchart LR
subgraph Process3
direction LR
Input3 --> p3s1[Step1] --> p3s2[Step2] --> p3s3[Step3] --> Output3
end
subgraph Process2
direction LR
Input2 --> p2s1[Step1] --> p2s2[Step2] --> p2s3[Step3] --> Output2
end
subgraph Process1
direction LR
Input1 --> p1s1[Step1] --> p1s2[Step2] --> p1s3[Step3] --> Output1
end
这种方法操作简单,不需要进程间通信,也容易扩展到更多的进程数,在绝大多数情况下都推荐使用。然而这种模式需要将3个Step的上下文都载入内存中,如果这些Step是占用内存很高的深度学习模型,那么内存将会成为一个严重瓶颈。
为了解决这个问题,我们可以使用另一种模式将其并行。
flowchart LR
Input --> Step1 .-> Step2 .-> Step3 .-> Output
subgraph Process1
Step1
end
subgraph Process2
Step2
end
subgraph Process3
Step3
end
其中这里的虚线表示进程间通信(IPC)。相比于第一种并行方式,这种并行方式操作复杂,需要进程间通信,但是可以有效的减少内存占用。
然而,这种并行相比于第一种方案,需要消耗额外的时间在IPC上,因此我们需要测试一下不同IPC方案的效率。
运行环境 以下实验都在以下环境运行:
1 2 3 4 5 6 7 8 9 10 CPU: i7-10900X RAM: 128GB Python: 3.10.8 websockets: 10.4 fastapi: 0.88.0 grpcio: 1.56.2 pika: 1.3.2 aio-pika: 9.2.0 numpy: 1.23.4
需要通讯的内容为4种不同尺寸的numpy.ndarray
[1] ,数据类型为float64
。分别为:
(1, 3, 224, 224), 模拟一张图片
(2, 1024, 16, 16), 模拟两张图像的低维特征
(2, 3, 16, 224, 224), 模拟两个16帧的视频片段
(128, 1024, 1024), 模拟128个序列特征
方案介绍 方案1: HTTP + JSON 序列化 这是最简单的方案,使用HTTP协议作为通信协议,将ndarray
转换成Python的嵌套List
,然后作为json发送。这种方案的优点是实现简单,不需要额外的依赖,缺点是从ndarray
和List
互相转换的开销大,而且json序列化的开销也很大。
这里HTTP通过fastapi
[2] 实现,fastapi
是一个高性能的异步框架,可以很好的支持大量的并发请求。
方案2: HTTP + Base64 Bytes 这种方案和方案1类似,不过将ndarray
通过numpy内置的方法转换成bytes
,然后使用base64
编码,这样可以避免ndarray
和List
之间的转换,但是HTTP传输大规模的base64
编码的开销也很大。
方案3: Websocket + Bytes 这种方案和方案2类似,不过使用Websocket作为通信协议,这样可以避免HTTP的开销。因为Websocket可以发送ascii之外的字节,所以不需要base64
编码。
方案4: HTTP + Shared Memory 这里采用了multiprocessing.shared_memory
模块,使用SharedMemory
对象将ndarray
的地址共享给子进程。然后将SharedMemory
对象名字作为HTTP的返回值,客户端再通过名字获取SharedMemory
对象,这样可以避免ndarray
和List
之间的转换,也避免了base64
编码的开销。
方案5: Websocket + Shared Memory 这种方案和方案4类似,不过使用Websocket作为通信协议,这样可以避免HTTP的开销。
方案6: Multiprocessing Listener / Client 这种方案使用multiprocessing
模块的Listener
和Client
对象,使用multiprocessing
的Pipe
作为通信协议,这样可以避免HTTP的开销。
方案7: gRPC + Bytes 这种方案使用gRPC
[3] 作为通信协议,使用protobuf
作为序列化协议,好处是方便客户端进行调用,但是gRPC有最大的消息长度限制(2GB)。
方案8: RabbitMQ + Bytes 这种方案使用RabbitMQ
[4] 作为通信协议,使用pika
[5] 作为Python的客户端和服务端。这种方案的好处是可以使用RabbitMQ
的其他特性,比如消息队列,消息持久化等,但是有最大的消息长度限制(512MB)。
测试结果
ndarray Shape
(1, 3, 224, 224)
(2, 1024, 16, 16)
(2, 3, 16, 224, 224)
(128, 1024, 1024)
HTTP + JSON
290.00 ms
1090 ms
9230 ms
259.00 s
HTTP + Base64 Bytes
26.40 ms
51.5 ms
398 ms
12.30 s
Websocket + Bytes
4.27 ms
15.0 ms
171 ms
5.14 s
HTTP + Shared Memory
10.70 ms
18.1 ms
127 ms
3.13 s
Websocket + Shared Memory
4.34 ms
14.9 ms
127 ms
3.82 s
Multiprocessing Listener
7.00 ms
17.2 ms
162 ms
4.73 s
gRPC + Bytes
7.34 ms
28.6 ms
291 ms
7.92 s
RabbitMQ + Bytes
9.35 ms
25.7 ms
243 ms
超出消息长度
根据这个结果我们可以发现,方案4和方案5的性能是最好的,方案6的性能也很好,方案1和方案2的性能最差。
考虑网络传输协议,Websocket的性能是比HTTP好的。所以应该尽量使用Websocket作为网络传输协议。
考虑使用Base64还是Shared Memory,我们可以发现大数据的情况下,Shared Memory的性能是比较好的,但是它需要手动管理内存,可能会有一些问题。所以对于小数据,可以使用Base64,对于大数据,可以使用Shared Memory。
对于multiprocessing
模块的Listener
和Client
,它的性能略弱于Shared Memory,但是它不需要手动管理共享内存,而且它不需要用fastapi
之类的外部库,而且不需要转换成别的类型的数据,比较方便。但是正因为它没有使用fastapi
,以至于它不是很方便的进行异步处理。
gRPC
和RabbitMQ
的性能比较差,只比HTTP好一点点,比不上Websocket,所以不推荐使用。而且它们有最大的消息大小限制,所以传输数据时不太方便。
在写以下比较重复的代码实现的时候,Github Copilot[6] 起到了很大的帮助。
代码实现 方案1: HTTP + JSON 序列化 1 2 3 4 5 6 7 8 9 10 import numpy as npfrom fastapi import FastAPIapp = FastAPI() shape = (1 , 3 , 224 , 224 ) @app.get("/random_tolist" ) async def random_tolist (): return np.random.randn(*shape).tolist()
1 2 3 4 5 6 7 import requestsimport numpy as npresult = requests.get("http://127.0.0.1:1234/random_tolist" ) result = np.array(result.json()) print (result.shape)
方案2: HTTP + Base64 Bytes 1 2 3 4 5 6 7 8 9 10 11 12 import numpy as npimport base64from fastapi import FastAPIfrom fastapi.responses import PlainTextResponseapp = FastAPI() shape = (1 , 3 , 224 , 224 ) @app.get("/random_tobytes" , response_class=PlainTextResponse ) async def random_tobytes (): return base64.b64encode(np.random.randn(*shape).tobytes())
1 2 3 4 5 6 7 8 import requestsimport numpy as npimport base64result = requests.get("http://127.0.0.1:1234/random_tobytes" ) result = np.frombuffer(base64.b64decode(result.text)).reshape(*shape) print (result.shape)
方案3: Websocket + Bytes 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 import numpy as npfrom fastapi import FastAPI, WebSocketapp = FastAPI() shape = (1 , 3 , 224 , 224 ) @app.websocket("/ws/random_tobytes" ) async def websocket_random_tobytes (websocket: WebSocket ): await websocket.accept() while True : await websocket.receive_text() print ("Processing websocket" ) await websocket.send_text(np.random.randn(*shape).tobytes())
1 2 3 4 5 6 7 8 9 import numpy as npfrom websocket import create_connectionws = create_connection("ws://127.0.0.1:1234/ws/random_tobytes" ) ws.send("" ) result = np.frombuffer(ws.recv()).reshape(*shape) print (result.shape)
方案4: HTTP + Shared Memory 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 import numpy as npfrom fastapi import FastAPIfrom fastapi.responses import PlainTextResponsefrom multiprocessing import shared_memoryapp = FastAPI() shape = (1 , 3 , 224 , 224 ) @app.get("/random_sharedmemory" , response_class=PlainTextResponse ) async def random_sharedmemory (): arr = np.random.randn(*shape) shm = shared_memory.SharedMemory(create=True , size=arr.nbytes) out = np.ndarray(arr.shape, dtype=arr.dtype, buffer=shm.buf) out[:] = arr[:] shm.close() return shm.name
1 2 3 4 5 6 7 8 9 10 11 12 13 import numpy as npimport requestsfrom multiprocessing import shared_memoryshm_name = requests.get("http://127.0.0.1:1234/random_sharedmemory" ) shm = shared_memory.SharedMemory(name=shm_name.text) result = np.ndarray(shape, dtype=float , buffer=shm.buf) shm.close() shm.unlink() print (result.shape)
方案5: Websocket + Shared Memory 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 import numpy as npfrom fastapi import FastAPI, WebSocketfrom multiprocessing import shared_memoryapp = FastAPI() shape = (1 , 3 , 224 , 224 ) @app.websocket("/ws/random_sharedmemory" ) async def websocket_random_sharedmemory (websocket: WebSocket ): await websocket.accept() while True : await websocket.receive_text() print ("Processing websocket" ) arr = np.random.randn(*shape) shm = shared_memory.SharedMemory(create=True , size=arr.nbytes) out = np.ndarray(arr.shape, dtype=arr.dtype, buffer=shm.buf) out[:] = arr[:] shm.close() print (shm.name) await websocket.send_text(shm.name)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 import numpy as npfrom websocket import create_connectionfrom multiprocessing import shared_memoryws = create_connection("ws://127.0.0.1:1234/ws/random_sharedmemory" ) ws.send("" ) shm_name = ws.recv() shm = shared_memory.SharedMemory(name=shm_name) result = np.ndarray(shape, dtype=float , buffer=shm.buf) shm.close() shm.unlink() print (result.shape)
方案6: Multiprocessing Listener / Client 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 from multiprocessing.connection import Listenerimport numpy as npshape = (1 , 3 , 224 , 224 ) address = ('localhost' , 1234 ) listener = Listener(address) conn = listener.accept() while True : msg = conn.recv() print ("Processing" ) conn.send(np.random.randn(*shape)) listener.close()
1 2 3 4 5 6 from multiprocessing.connection import Clientconn = Client(("localhost" , 1234 )) conn.send("" ) result = conn.recv() print (result.shape)
方案7: gRPC + Bytes 1 2 3 4 5 6 7 8 9 10 11 syntax = "proto3" ; import "google/protobuf/empty.proto" ;service Npy { rpc Get (google.protobuf.Empty) returns (ArrayData) {} } message ArrayData { bytes body = 1 ; }
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 from concurrent import futuresimport timeimport grpcimport npy_pb2import npy_pb2_grpcimport numpy as npshape = (1 , 3 , 224 , 224 ) class Npy (npy_pb2_grpc.NpyServicer): def Get (self, request, context ): arr = np.random.randn(*shape) return npy_pb2.ArrayData(body=np.ndarray.tobytes(arr)) def serve (): port = "50051" server = grpc.server(futures.ThreadPoolExecutor(max_workers=8 ), options=[ ('grpc.max_send_message_length' , 2 * 1024 **3 - 1 ), ('grpc.max_receive_message_length' , 2 * 1024 **3 - 1 ), ]) npy_pb2_grpc.add_NpyServicer_to_server(Npy(), server) server.add_insecure_port("[::]:" + port) server.start() print ("Server started, listening on " + port) server.wait_for_termination() if __name__ == '__main__' : serve()
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 from __future__ import print_functionimport grpcimport numpy as npimport npy_pb2import npy_pb2_grpcshape = (1 , 3 , 224 , 224 ) channel = grpc.insecure_channel('localhost:50051' , options=[ ('grpc.max_send_message_length' , 2 * 1024 **3 - 1 ), ('grpc.max_receive_message_length' , 2 * 1024 **3 - 1 ), ] ) stub = npy_pb2_grpc.NpyStub(channel) response = stub.Get(npy_pb2.google_dot_protobuf_dot_empty__pb2.Empty()) arr = np.frombuffer(response.body, dtype=np.float64).reshape(shape) print (arr.shape)
方案8: RabbitMQ + Bytes 部署RabbitMQ的Docker容器:
1 docker run --name some-rabbit -p 5672:5672 rabbitmq:3
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 import pikaimport numpy as npconnection = pika.BlockingConnection( pika.ConnectionParameters(host='localhost' )) channel = connection.channel() channel.queue_declare(queue='rpc_queue' ) shape = (1 , 3 , 224 , 224 ) def on_request (ch, method, props, body ): arr = np.random.randn(*shape) data = np.ndarray.tobytes(arr) ch.basic_publish(exchange='' , routing_key=props.reply_to, properties=pika.BasicProperties(correlation_id=props.correlation_id), body=data) ch.basic_ack(delivery_tag=method.delivery_tag) channel.basic_qos(prefetch_count=1 ) channel.basic_consume(queue='rpc_queue' , on_message_callback=on_request) print (" [x] Awaiting RPC requests" )channel.start_consuming()
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 import pikaimport uuidconnection = pika.BlockingConnection(pika.ConnectionParameters(host='localhost' )) channel = connection.channel() result = channel.queue_declare(queue='' , exclusive=True ) callback_queue = result.method.queue response = None corr_id = None shape = (128 , 1024 , 1024 ) def on_response (ch, method, props, body ): global response if corr_id == props.correlation_id: response = body channel.basic_consume( queue=callback_queue, on_message_callback=on_response, auto_ack=True ) def call (): global response global corr_id response = None corr_id = str (uuid.uuid4()) channel.basic_publish( exchange='' , routing_key='rpc_queue' , properties=pika.BasicProperties( reply_to=callback_queue, correlation_id=corr_id, ), body="" ) connection.process_data_events(time_limit=None ) return response response = call() arr = np.frombuffer(response, dtype=np.float64).reshape(shape) print (arr.shape)
参考文献 [1] "NumPy", Numpy.org, 2022. https://numpy.org/. [2] "FastAPI", FastAPI, 2022. https://fastapi.tiangolo.com/ [3] "gRPC" gRPC. https://grpc.io/ [4] “Messaging that just works — RabbitMQ,” Rabbitmq.com, 2019. https://www.rabbitmq.com/ [5] pika, “Pika,” GitHub, Jul. 28, 2023. https://github.com/pika/pika (accessed Jul. 28, 2023). [6] "GitHub Copilot · Your AI pair programmer," GitHub, 2022. https://github.com/features/copilot