aboutsummaryrefslogtreecommitdiff
path: root/pw_rpc/py/pw_rpc/packets.py
diff options
context:
space:
mode:
Diffstat (limited to 'pw_rpc/py/pw_rpc/packets.py')
-rw-r--r--pw_rpc/py/pw_rpc/packets.py67
1 files changed, 36 insertions, 31 deletions
diff --git a/pw_rpc/py/pw_rpc/packets.py b/pw_rpc/py/pw_rpc/packets.py
index ddcc03e74..54d4b5085 100644
--- a/pw_rpc/py/pw_rpc/packets.py
+++ b/pw_rpc/py/pw_rpc/packets.py
@@ -13,6 +13,7 @@
# the License.
"""Functions for working with pw_rpc packets."""
+import dataclasses
from typing import Optional
from google.protobuf import message
@@ -33,43 +34,47 @@ def decode_payload(packet, payload_type):
return payload
-def _ids(rpc: tuple) -> tuple:
- return tuple(item if isinstance(item, int) else item.id for item in rpc)
+@dataclasses.dataclass(eq=True, frozen=True)
+class RpcIds:
+ """Integer IDs that uniquely identify a remote procedure call."""
+ channel_id: int
+ service_id: int
+ method_id: int
+ call_id: int
-def encode_request(rpc: tuple, request: Optional[message.Message]) -> bytes:
- channel, service, method = _ids(rpc)
+
+def encode_request(rpc: RpcIds, request: Optional[message.Message]) -> bytes:
payload = request.SerializeToString() if request is not None else bytes()
return packet_pb2.RpcPacket(
type=packet_pb2.PacketType.REQUEST,
- channel_id=channel,
- service_id=service,
- method_id=method,
+ channel_id=rpc.channel_id,
+ service_id=rpc.service_id,
+ method_id=rpc.method_id,
+ call_id=rpc.call_id,
payload=payload,
).SerializeToString()
-def encode_response(rpc: tuple, response: message.Message) -> bytes:
- channel, service, method = _ids(rpc)
-
+def encode_response(rpc: RpcIds, response: message.Message) -> bytes:
return packet_pb2.RpcPacket(
type=packet_pb2.PacketType.RESPONSE,
- channel_id=channel,
- service_id=service,
- method_id=method,
+ channel_id=rpc.channel_id,
+ service_id=rpc.service_id,
+ method_id=rpc.method_id,
+ call_id=rpc.call_id,
payload=response.SerializeToString(),
).SerializeToString()
-def encode_client_stream(rpc: tuple, request: message.Message) -> bytes:
- channel, service, method = _ids(rpc)
-
+def encode_client_stream(rpc: RpcIds, request: message.Message) -> bytes:
return packet_pb2.RpcPacket(
type=packet_pb2.PacketType.CLIENT_STREAM,
- channel_id=channel,
- service_id=service,
- method_id=method,
+ channel_id=rpc.channel_id,
+ service_id=rpc.service_id,
+ method_id=rpc.method_id,
+ call_id=rpc.call_id,
payload=request.SerializeToString(),
).SerializeToString()
@@ -80,29 +85,29 @@ def encode_client_error(packet: packet_pb2.RpcPacket, status: Status) -> bytes:
channel_id=packet.channel_id,
service_id=packet.service_id,
method_id=packet.method_id,
+ call_id=packet.call_id,
status=status.value,
).SerializeToString()
-def encode_cancel(rpc: tuple) -> bytes:
- channel, service, method = _ids(rpc)
+def encode_cancel(rpc: RpcIds) -> bytes:
return packet_pb2.RpcPacket(
type=packet_pb2.PacketType.CLIENT_ERROR,
status=Status.CANCELLED.value,
- channel_id=channel,
- service_id=service,
- method_id=method,
+ channel_id=rpc.channel_id,
+ service_id=rpc.service_id,
+ method_id=rpc.method_id,
+ call_id=rpc.call_id,
).SerializeToString()
-def encode_client_stream_end(rpc: tuple) -> bytes:
- channel, service, method = _ids(rpc)
-
+def encode_client_stream_end(rpc: RpcIds) -> bytes:
return packet_pb2.RpcPacket(
- type=packet_pb2.PacketType.CLIENT_STREAM_END,
- channel_id=channel,
- service_id=service,
- method_id=method,
+ type=packet_pb2.PacketType.CLIENT_REQUEST_COMPLETION,
+ channel_id=rpc.channel_id,
+ service_id=rpc.service_id,
+ method_id=rpc.method_id,
+ call_id=rpc.call_id,
).SerializeToString()