aboutsummaryrefslogtreecommitdiff
path: root/pw_rpc/py/tests/packets_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'pw_rpc/py/tests/packets_test.py')
-rwxr-xr-xpw_rpc/py/tests/packets_test.py56
1 files changed, 26 insertions, 30 deletions
diff --git a/pw_rpc/py/tests/packets_test.py b/pw_rpc/py/tests/packets_test.py
index d6fa87935..3edded35e 100755
--- a/pw_rpc/py/tests/packets_test.py
+++ b/pw_rpc/py/tests/packets_test.py
@@ -21,12 +21,23 @@ from pw_status import Status
from pw_rpc.internal.packet_pb2 import PacketType, RpcPacket
from pw_rpc import packets
+_TEST_IDS = packets.RpcIds(1, 2, 3, 4)
+_TEST_STATUS = 321
_TEST_REQUEST = RpcPacket(
type=PacketType.REQUEST,
- channel_id=1,
- service_id=2,
- method_id=3,
- payload=RpcPacket(status=321).SerializeToString(),
+ channel_id=_TEST_IDS.channel_id,
+ service_id=_TEST_IDS.service_id,
+ method_id=_TEST_IDS.method_id,
+ call_id=_TEST_IDS.call_id,
+ payload=RpcPacket(status=_TEST_STATUS).SerializeToString(),
+)
+_TEST_RESPONSE = RpcPacket(
+ type=PacketType.RESPONSE,
+ channel_id=_TEST_IDS.channel_id,
+ service_id=_TEST_IDS.service_id,
+ method_id=_TEST_IDS.method_id,
+ call_id=_TEST_IDS.call_id,
+ payload=RpcPacket(status=_TEST_STATUS).SerializeToString(),
)
@@ -34,29 +45,23 @@ class PacketsTest(unittest.TestCase):
"""Tests for packet encoding and decoding."""
def test_encode_request(self):
- data = packets.encode_request((1, 2, 3), RpcPacket(status=321))
+ data = packets.encode_request(_TEST_IDS, RpcPacket(status=_TEST_STATUS))
packet = RpcPacket()
packet.ParseFromString(data)
self.assertEqual(_TEST_REQUEST, packet)
def test_encode_response(self):
- response = RpcPacket(
- type=PacketType.RESPONSE,
- channel_id=1,
- service_id=2,
- method_id=3,
- payload=RpcPacket(status=321).SerializeToString(),
+ data = packets.encode_response(
+ _TEST_IDS, RpcPacket(status=_TEST_STATUS)
)
-
- data = packets.encode_response((1, 2, 3), RpcPacket(status=321))
packet = RpcPacket()
packet.ParseFromString(data)
- self.assertEqual(response, packet)
+ self.assertEqual(_TEST_RESPONSE, packet)
def test_encode_cancel(self):
- data = packets.encode_cancel((9, 8, 7))
+ data = packets.encode_cancel(packets.RpcIds(9, 8, 7, 6))
packet = RpcPacket()
packet.ParseFromString(data)
@@ -68,6 +73,7 @@ class PacketsTest(unittest.TestCase):
channel_id=9,
service_id=8,
method_id=7,
+ call_id=6,
status=Status.CANCELLED.value,
),
)
@@ -82,9 +88,10 @@ class PacketsTest(unittest.TestCase):
packet,
RpcPacket(
type=PacketType.CLIENT_ERROR,
- channel_id=1,
- service_id=2,
- method_id=3,
+ channel_id=_TEST_IDS.channel_id,
+ service_id=_TEST_IDS.service_id,
+ method_id=_TEST_IDS.method_id,
+ call_id=_TEST_IDS.call_id,
status=Status.NOT_FOUND.value,
),
)
@@ -96,18 +103,7 @@ class PacketsTest(unittest.TestCase):
def test_for_server(self):
self.assertTrue(packets.for_server(_TEST_REQUEST))
-
- self.assertFalse(
- packets.for_server(
- RpcPacket(
- type=PacketType.RESPONSE,
- channel_id=1,
- service_id=2,
- method_id=3,
- payload=RpcPacket(status=321).SerializeToString(),
- )
- )
- )
+ self.assertFalse(packets.for_server(_TEST_RESPONSE))
if __name__ == '__main__':