diff options
Diffstat (limited to 'pw_rpc/py/tests/client_test.py')
-rwxr-xr-x | pw_rpc/py/tests/client_test.py | 117 |
1 files changed, 93 insertions, 24 deletions
diff --git a/pw_rpc/py/tests/client_test.py b/pw_rpc/py/tests/client_test.py index 92a1f8236..28bea9069 100755 --- a/pw_rpc/py/tests/client_test.py +++ b/pw_rpc/py/tests/client_test.py @@ -15,7 +15,7 @@ """Tests creating pw_rpc client.""" import unittest -from typing import Optional +from typing import Any, Callable, Optional from pw_protobuf_compiler import python_protos from pw_status import Status @@ -24,6 +24,8 @@ from pw_rpc import callback_client, client, packets import pw_rpc.ids from pw_rpc.internal.packet_pb2 import PacketType, RpcPacket +RpcIds = packets.RpcIds + TEST_PROTO_1 = """\ syntax = "proto3"; @@ -73,13 +75,30 @@ service Bravo { } """ +SOME_CHANNEL_ID: int = 237 +SOME_SERVICE_ID: int = 193 +SOME_METHOD_ID: int = 769 +SOME_CALL_ID: int = 452 + +CLIENT_FIRST_CHANNEL_ID: int = 557 +CLIENT_SECOND_CHANNEL_ID: int = 474 + + +def create_protos() -> Any: + return python_protos.Library.from_strings([TEST_PROTO_1, TEST_PROTO_2]) -def _test_setup(output=None): - protos = python_protos.Library.from_strings([TEST_PROTO_1, TEST_PROTO_2]) - return protos, client.Client.from_modules( + +def create_client( + proto_modules: Any, + first_channel_output_fn: Optional[Callable[[bytes], Any]] = None, +) -> client.Client: + return client.Client.from_modules( callback_client.Impl(), - [client.Channel(1, output), client.Channel(2, lambda _: None)], - protos.modules(), + [ + client.Channel(CLIENT_FIRST_CHANNEL_ID, first_channel_output_fn), + client.Channel(CLIENT_SECOND_CHANNEL_ID, lambda _: None), + ], + proto_modules, ) @@ -87,7 +106,10 @@ class ChannelClientTest(unittest.TestCase): """Tests the ChannelClient.""" def setUp(self) -> None: - self._channel_client = _test_setup()[1].channel(1) + client_instance = create_client(create_protos().modules()) + self._channel_client: client.ChannelClient = client_instance.channel( + CLIENT_FIRST_CHANNEL_ID + ) def test_access_service_client_as_attribute_or_index(self) -> None: self.assertIs( @@ -182,7 +204,8 @@ class ClientTest(unittest.TestCase): def setUp(self) -> None: self._last_packet_sent_bytes: Optional[bytes] = None - self._protos, self._client = _test_setup(self._save_packet) + self._protos = create_protos() + self._client = create_client(self._protos.modules(), self._save_packet) def _save_packet(self, packet) -> None: self._last_packet_sent_bytes = packet @@ -194,11 +217,19 @@ class ClientTest(unittest.TestCase): return packet def test_channel(self) -> None: - self.assertEqual(self._client.channel(1).channel.id, 1) - self.assertEqual(self._client.channel(2).channel.id, 2) + self.assertEqual( + self._client.channel(CLIENT_FIRST_CHANNEL_ID).channel.id, + CLIENT_FIRST_CHANNEL_ID, + ) + self.assertEqual( + self._client.channel(CLIENT_SECOND_CHANNEL_ID).channel.id, + CLIENT_SECOND_CHANNEL_ID, + ) def test_channel_default_is_first_listed(self) -> None: - self.assertEqual(self._client.channel().channel.id, 1) + self.assertEqual( + self._client.channel().channel.id, CLIENT_FIRST_CHANNEL_ID + ) def test_channel_invalid(self) -> None: with self.assertRaises(KeyError): @@ -259,7 +290,13 @@ class ClientTest(unittest.TestCase): self.assertIs( self._client.process_packet( packets.encode_response( - (123, 456, 789), self._protos.packages.pw.test2.Request() + RpcIds( + SOME_CHANNEL_ID, + SOME_SERVICE_ID, + SOME_METHOD_ID, + SOME_CALL_ID, + ), + self._protos.packages.pw.test2.Request(), ) ), Status.NOT_FOUND, @@ -269,7 +306,13 @@ class ClientTest(unittest.TestCase): self.assertIs( self._client.process_packet( packets.encode_response( - (1, 456, 789), self._protos.packages.pw.test2.Request() + RpcIds( + CLIENT_FIRST_CHANNEL_ID, + SOME_SERVICE_ID, + SOME_METHOD_ID, + SOME_CALL_ID, + ), + self._protos.packages.pw.test2.Request(), ) ), Status.OK, @@ -279,9 +322,10 @@ class ClientTest(unittest.TestCase): self._last_packet_sent(), RpcPacket( type=PacketType.CLIENT_ERROR, - channel_id=1, - service_id=456, - method_id=789, + channel_id=CLIENT_FIRST_CHANNEL_ID, + service_id=SOME_SERVICE_ID, + method_id=SOME_METHOD_ID, + call_id=SOME_CALL_ID, status=Status.NOT_FOUND.value, ), ) @@ -292,7 +336,12 @@ class ClientTest(unittest.TestCase): self.assertIs( self._client.process_packet( packets.encode_response( - (1, service.id, 789), + RpcIds( + CLIENT_FIRST_CHANNEL_ID, + service.id, + SOME_METHOD_ID, + SOME_CALL_ID, + ), self._protos.packages.pw.test2.Request(), ) ), @@ -303,9 +352,10 @@ class ClientTest(unittest.TestCase): self._last_packet_sent(), RpcPacket( type=PacketType.CLIENT_ERROR, - channel_id=1, + channel_id=CLIENT_FIRST_CHANNEL_ID, service_id=service.id, - method_id=789, + method_id=SOME_METHOD_ID, + call_id=SOME_CALL_ID, status=Status.NOT_FOUND.value, ), ) @@ -317,7 +367,12 @@ class ClientTest(unittest.TestCase): self.assertIs( self._client.process_packet( packets.encode_response( - (1, service.id, method.id), + RpcIds( + CLIENT_FIRST_CHANNEL_ID, + service.id, + method.id, + SOME_CALL_ID, + ), self._protos.packages.pw.test2.Request(), ) ), @@ -328,9 +383,10 @@ class ClientTest(unittest.TestCase): self._last_packet_sent(), RpcPacket( type=PacketType.CLIENT_ERROR, - channel_id=1, + channel_id=CLIENT_FIRST_CHANNEL_ID, service_id=service.id, method_id=method.id, + call_id=SOME_CALL_ID, status=Status.FAILED_PRECONDITION.value, ), ) @@ -340,12 +396,17 @@ class ClientTest(unittest.TestCase): reply = method.response_type(payload='hello') def response_callback( - rpc: client.PendingRpc, message, status: Optional[Status] + rpc: client.PendingRpc, + message, + status: Optional[Status], ) -> None: self.assertEqual( rpc, client.PendingRpc( - self._client.channel(1).channel, method.service, method + self._client.channel(CLIENT_FIRST_CHANNEL_ID).channel, + method.service, + method, + call_id=SOME_CALL_ID, ), ) self.assertEqual(message, reply) @@ -355,7 +416,15 @@ class ClientTest(unittest.TestCase): self.assertIs( self._client.process_packet( - packets.encode_response((1, method.service, method), reply) + packets.encode_response( + RpcIds( + CLIENT_FIRST_CHANNEL_ID, + method.service.id, + method.id, + SOME_CALL_ID, + ), + reply, + ) ), Status.OK, ) |