diff options
Diffstat (limited to 'pw_rpc/py/tests/packets_test.py')
-rwxr-xr-x | pw_rpc/py/tests/packets_test.py | 56 |
1 files changed, 26 insertions, 30 deletions
diff --git a/pw_rpc/py/tests/packets_test.py b/pw_rpc/py/tests/packets_test.py index d6fa87935..3edded35e 100755 --- a/pw_rpc/py/tests/packets_test.py +++ b/pw_rpc/py/tests/packets_test.py @@ -21,12 +21,23 @@ from pw_status import Status from pw_rpc.internal.packet_pb2 import PacketType, RpcPacket from pw_rpc import packets +_TEST_IDS = packets.RpcIds(1, 2, 3, 4) +_TEST_STATUS = 321 _TEST_REQUEST = RpcPacket( type=PacketType.REQUEST, - channel_id=1, - service_id=2, - method_id=3, - payload=RpcPacket(status=321).SerializeToString(), + channel_id=_TEST_IDS.channel_id, + service_id=_TEST_IDS.service_id, + method_id=_TEST_IDS.method_id, + call_id=_TEST_IDS.call_id, + payload=RpcPacket(status=_TEST_STATUS).SerializeToString(), +) +_TEST_RESPONSE = RpcPacket( + type=PacketType.RESPONSE, + channel_id=_TEST_IDS.channel_id, + service_id=_TEST_IDS.service_id, + method_id=_TEST_IDS.method_id, + call_id=_TEST_IDS.call_id, + payload=RpcPacket(status=_TEST_STATUS).SerializeToString(), ) @@ -34,29 +45,23 @@ class PacketsTest(unittest.TestCase): """Tests for packet encoding and decoding.""" def test_encode_request(self): - data = packets.encode_request((1, 2, 3), RpcPacket(status=321)) + data = packets.encode_request(_TEST_IDS, RpcPacket(status=_TEST_STATUS)) packet = RpcPacket() packet.ParseFromString(data) self.assertEqual(_TEST_REQUEST, packet) def test_encode_response(self): - response = RpcPacket( - type=PacketType.RESPONSE, - channel_id=1, - service_id=2, - method_id=3, - payload=RpcPacket(status=321).SerializeToString(), + data = packets.encode_response( + _TEST_IDS, RpcPacket(status=_TEST_STATUS) ) - - data = packets.encode_response((1, 2, 3), RpcPacket(status=321)) packet = RpcPacket() packet.ParseFromString(data) - self.assertEqual(response, packet) + self.assertEqual(_TEST_RESPONSE, packet) def test_encode_cancel(self): - data = packets.encode_cancel((9, 8, 7)) + data = packets.encode_cancel(packets.RpcIds(9, 8, 7, 6)) packet = RpcPacket() packet.ParseFromString(data) @@ -68,6 +73,7 @@ class PacketsTest(unittest.TestCase): channel_id=9, service_id=8, method_id=7, + call_id=6, status=Status.CANCELLED.value, ), ) @@ -82,9 +88,10 @@ class PacketsTest(unittest.TestCase): packet, RpcPacket( type=PacketType.CLIENT_ERROR, - channel_id=1, - service_id=2, - method_id=3, + channel_id=_TEST_IDS.channel_id, + service_id=_TEST_IDS.service_id, + method_id=_TEST_IDS.method_id, + call_id=_TEST_IDS.call_id, status=Status.NOT_FOUND.value, ), ) @@ -96,18 +103,7 @@ class PacketsTest(unittest.TestCase): def test_for_server(self): self.assertTrue(packets.for_server(_TEST_REQUEST)) - - self.assertFalse( - packets.for_server( - RpcPacket( - type=PacketType.RESPONSE, - channel_id=1, - service_id=2, - method_id=3, - payload=RpcPacket(status=321).SerializeToString(), - ) - ) - ) + self.assertFalse(packets.for_server(_TEST_RESPONSE)) if __name__ == '__main__': |