diff options
Diffstat (limited to 'pw_rpc/py/pw_rpc/packets.py')
-rw-r--r-- | pw_rpc/py/pw_rpc/packets.py | 67 |
1 files changed, 36 insertions, 31 deletions
diff --git a/pw_rpc/py/pw_rpc/packets.py b/pw_rpc/py/pw_rpc/packets.py index ddcc03e74..54d4b5085 100644 --- a/pw_rpc/py/pw_rpc/packets.py +++ b/pw_rpc/py/pw_rpc/packets.py @@ -13,6 +13,7 @@ # the License. """Functions for working with pw_rpc packets.""" +import dataclasses from typing import Optional from google.protobuf import message @@ -33,43 +34,47 @@ def decode_payload(packet, payload_type): return payload -def _ids(rpc: tuple) -> tuple: - return tuple(item if isinstance(item, int) else item.id for item in rpc) +@dataclasses.dataclass(eq=True, frozen=True) +class RpcIds: + """Integer IDs that uniquely identify a remote procedure call.""" + channel_id: int + service_id: int + method_id: int + call_id: int -def encode_request(rpc: tuple, request: Optional[message.Message]) -> bytes: - channel, service, method = _ids(rpc) + +def encode_request(rpc: RpcIds, request: Optional[message.Message]) -> bytes: payload = request.SerializeToString() if request is not None else bytes() return packet_pb2.RpcPacket( type=packet_pb2.PacketType.REQUEST, - channel_id=channel, - service_id=service, - method_id=method, + channel_id=rpc.channel_id, + service_id=rpc.service_id, + method_id=rpc.method_id, + call_id=rpc.call_id, payload=payload, ).SerializeToString() -def encode_response(rpc: tuple, response: message.Message) -> bytes: - channel, service, method = _ids(rpc) - +def encode_response(rpc: RpcIds, response: message.Message) -> bytes: return packet_pb2.RpcPacket( type=packet_pb2.PacketType.RESPONSE, - channel_id=channel, - service_id=service, - method_id=method, + channel_id=rpc.channel_id, + service_id=rpc.service_id, + method_id=rpc.method_id, + call_id=rpc.call_id, payload=response.SerializeToString(), ).SerializeToString() -def encode_client_stream(rpc: tuple, request: message.Message) -> bytes: - channel, service, method = _ids(rpc) - +def encode_client_stream(rpc: RpcIds, request: message.Message) -> bytes: return packet_pb2.RpcPacket( type=packet_pb2.PacketType.CLIENT_STREAM, - channel_id=channel, - service_id=service, - method_id=method, + channel_id=rpc.channel_id, + service_id=rpc.service_id, + method_id=rpc.method_id, + call_id=rpc.call_id, payload=request.SerializeToString(), ).SerializeToString() @@ -80,29 +85,29 @@ def encode_client_error(packet: packet_pb2.RpcPacket, status: Status) -> bytes: channel_id=packet.channel_id, service_id=packet.service_id, method_id=packet.method_id, + call_id=packet.call_id, status=status.value, ).SerializeToString() -def encode_cancel(rpc: tuple) -> bytes: - channel, service, method = _ids(rpc) +def encode_cancel(rpc: RpcIds) -> bytes: return packet_pb2.RpcPacket( type=packet_pb2.PacketType.CLIENT_ERROR, status=Status.CANCELLED.value, - channel_id=channel, - service_id=service, - method_id=method, + channel_id=rpc.channel_id, + service_id=rpc.service_id, + method_id=rpc.method_id, + call_id=rpc.call_id, ).SerializeToString() -def encode_client_stream_end(rpc: tuple) -> bytes: - channel, service, method = _ids(rpc) - +def encode_client_stream_end(rpc: RpcIds) -> bytes: return packet_pb2.RpcPacket( - type=packet_pb2.PacketType.CLIENT_STREAM_END, - channel_id=channel, - service_id=service, - method_id=method, + type=packet_pb2.PacketType.CLIENT_REQUEST_COMPLETION, + channel_id=rpc.channel_id, + service_id=rpc.service_id, + method_id=rpc.method_id, + call_id=rpc.call_id, ).SerializeToString() |