diff options
Diffstat (limited to 'pw_rpc/py/tests/client_test.py')
-rwxr-xr-x | pw_rpc/py/tests/client_test.py | 169 |
1 files changed, 112 insertions, 57 deletions
diff --git a/pw_rpc/py/tests/client_test.py b/pw_rpc/py/tests/client_test.py index e07125287..92a1f8236 100755 --- a/pw_rpc/py/tests/client_test.py +++ b/pw_rpc/py/tests/client_test.py @@ -78,46 +78,58 @@ def _test_setup(output=None): protos = python_protos.Library.from_strings([TEST_PROTO_1, TEST_PROTO_2]) return protos, client.Client.from_modules( callback_client.Impl(), - [client.Channel(1, output), - client.Channel(2, lambda _: None)], protos.modules()) + [client.Channel(1, output), client.Channel(2, lambda _: None)], + protos.modules(), + ) class ChannelClientTest(unittest.TestCase): """Tests the ChannelClient.""" + def setUp(self) -> None: self._channel_client = _test_setup()[1].channel(1) def test_access_service_client_as_attribute_or_index(self) -> None: - self.assertIs(self._channel_client.rpcs.pw.test1.PublicService, - self._channel_client.rpcs['pw.test1.PublicService']) self.assertIs( self._channel_client.rpcs.pw.test1.PublicService, - self._channel_client.rpcs[pw_rpc.ids.calculate( - 'pw.test1.PublicService')]) + self._channel_client.rpcs['pw.test1.PublicService'], + ) + self.assertIs( + self._channel_client.rpcs.pw.test1.PublicService, + self._channel_client.rpcs[ + pw_rpc.ids.calculate('pw.test1.PublicService') + ], + ) def test_access_method_client_as_attribute_or_index(self) -> None: - self.assertIs(self._channel_client.rpcs.pw.test2.Alpha.Unary, - self._channel_client.rpcs['pw.test2.Alpha']['Unary']) self.assertIs( self._channel_client.rpcs.pw.test2.Alpha.Unary, - self._channel_client.rpcs['pw.test2.Alpha'][pw_rpc.ids.calculate( - 'Unary')]) + self._channel_client.rpcs['pw.test2.Alpha']['Unary'], + ) + self.assertIs( + self._channel_client.rpcs.pw.test2.Alpha.Unary, + self._channel_client.rpcs['pw.test2.Alpha'][ + pw_rpc.ids.calculate('Unary') + ], + ) def test_service_name(self) -> None: self.assertEqual( - self._channel_client.rpcs.pw.test2.Alpha.Unary.service.name, - 'Alpha') + self._channel_client.rpcs.pw.test2.Alpha.Unary.service.name, 'Alpha' + ) self.assertEqual( self._channel_client.rpcs.pw.test2.Alpha.Unary.service.full_name, - 'pw.test2.Alpha') + 'pw.test2.Alpha', + ) def test_method_name(self) -> None: self.assertEqual( - self._channel_client.rpcs.pw.test2.Alpha.Unary.method.name, - 'Unary') + self._channel_client.rpcs.pw.test2.Alpha.Unary.method.name, 'Unary' + ) self.assertEqual( self._channel_client.rpcs.pw.test2.Alpha.Unary.method.full_name, - 'pw.test2.Alpha.Unary') + 'pw.test2.Alpha.Unary', + ) def test_iterate_over_all_methods(self) -> None: channel_client = self._channel_client @@ -133,8 +145,10 @@ class ChannelClientTest(unittest.TestCase): def test_check_for_presence_of_services(self) -> None: self.assertIn('pw.test1.PublicService', self._channel_client.rpcs) - self.assertIn(pw_rpc.ids.calculate('pw.test1.PublicService'), - self._channel_client.rpcs) + self.assertIn( + pw_rpc.ids.calculate('pw.test1.PublicService'), + self._channel_client.rpcs, + ) def test_check_for_presence_of_missing_services(self) -> None: self.assertNotIn('PublicService', self._channel_client.rpcs) @@ -153,14 +167,19 @@ class ChannelClientTest(unittest.TestCase): self.assertNotIn(12345, service) def test_method_fully_qualified_name(self) -> None: - self.assertIs(self._channel_client.method('pw.test2.Alpha/Unary'), - self._channel_client.rpcs.pw.test2.Alpha.Unary) - self.assertIs(self._channel_client.method('pw.test2.Alpha.Unary'), - self._channel_client.rpcs.pw.test2.Alpha.Unary) + self.assertIs( + self._channel_client.method('pw.test2.Alpha/Unary'), + self._channel_client.rpcs.pw.test2.Alpha.Unary, + ) + self.assertIs( + self._channel_client.method('pw.test2.Alpha.Unary'), + self._channel_client.rpcs.pw.test2.Alpha.Unary, + ) class ClientTest(unittest.TestCase): """Tests the pw_rpc Client independently of the ClientImpl.""" + def setUp(self) -> None: self._last_packet_sent_bytes: Optional[bytes] = None self._protos, self._client = _test_setup(self._save_packet) @@ -200,11 +219,17 @@ class ClientTest(unittest.TestCase): def test_method_present(self) -> None: self.assertIs( - self._client.method('pw.test1.PublicService.SomeUnary'), self. - _client.services['pw.test1.PublicService'].methods['SomeUnary']) + self._client.method('pw.test1.PublicService.SomeUnary'), + self._client.services['pw.test1.PublicService'].methods[ + 'SomeUnary' + ], + ) self.assertIs( - self._client.method('pw.test1.PublicService/SomeUnary'), self. - _client.services['pw.test1.PublicService'].methods['SomeUnary']) + self._client.method('pw.test1.PublicService/SomeUnary'), + self._client.services['pw.test1.PublicService'].methods[ + 'SomeUnary' + ], + ) def test_method_invalid_format(self) -> None: with self.assertRaises(ValueError): @@ -218,37 +243,48 @@ class ClientTest(unittest.TestCase): self._client.method('nothing.Good') def test_process_packet_invalid_proto_data(self) -> None: - self.assertIs(self._client.process_packet(b'NOT a packet!'), - Status.DATA_LOSS) + self.assertIs( + self._client.process_packet(b'NOT a packet!'), Status.DATA_LOSS + ) def test_process_packet_not_for_client(self) -> None: self.assertIs( self._client.process_packet( - RpcPacket(type=PacketType.REQUEST).SerializeToString()), - Status.INVALID_ARGUMENT) + RpcPacket(type=PacketType.REQUEST).SerializeToString() + ), + Status.INVALID_ARGUMENT, + ) def test_process_packet_unrecognized_channel(self) -> None: self.assertIs( self._client.process_packet( packets.encode_response( - (123, 456, 789), - self._protos.packages.pw.test2.Request())), - Status.NOT_FOUND) + (123, 456, 789), self._protos.packages.pw.test2.Request() + ) + ), + Status.NOT_FOUND, + ) def test_process_packet_unrecognized_service(self) -> None: self.assertIs( self._client.process_packet( packets.encode_response( - (1, 456, 789), self._protos.packages.pw.test2.Request())), - Status.OK) + (1, 456, 789), self._protos.packages.pw.test2.Request() + ) + ), + Status.OK, + ) self.assertEqual( self._last_packet_sent(), - RpcPacket(type=PacketType.CLIENT_ERROR, - channel_id=1, - service_id=456, - method_id=789, - status=Status.NOT_FOUND.value)) + RpcPacket( + type=PacketType.CLIENT_ERROR, + channel_id=1, + service_id=456, + method_id=789, + status=Status.NOT_FOUND.value, + ), + ) def test_process_packet_unrecognized_method(self) -> None: service = next(iter(self._client.services)) @@ -257,15 +293,22 @@ class ClientTest(unittest.TestCase): self._client.process_packet( packets.encode_response( (1, service.id, 789), - self._protos.packages.pw.test2.Request())), Status.OK) + self._protos.packages.pw.test2.Request(), + ) + ), + Status.OK, + ) self.assertEqual( self._last_packet_sent(), - RpcPacket(type=PacketType.CLIENT_ERROR, - channel_id=1, - service_id=service.id, - method_id=789, - status=Status.NOT_FOUND.value)) + RpcPacket( + type=PacketType.CLIENT_ERROR, + channel_id=1, + service_id=service.id, + method_id=789, + status=Status.NOT_FOUND.value, + ), + ) def test_process_packet_non_pending_method(self) -> None: service = next(iter(self._client.services)) @@ -275,26 +318,36 @@ class ClientTest(unittest.TestCase): self._client.process_packet( packets.encode_response( (1, service.id, method.id), - self._protos.packages.pw.test2.Request())), Status.OK) + self._protos.packages.pw.test2.Request(), + ) + ), + Status.OK, + ) self.assertEqual( self._last_packet_sent(), - RpcPacket(type=PacketType.CLIENT_ERROR, - channel_id=1, - service_id=service.id, - method_id=method.id, - status=Status.FAILED_PRECONDITION.value)) + RpcPacket( + type=PacketType.CLIENT_ERROR, + channel_id=1, + service_id=service.id, + method_id=method.id, + status=Status.FAILED_PRECONDITION.value, + ), + ) def test_process_packet_non_pending_calls_response_callback(self) -> None: method = self._client.method('pw.test1.PublicService.SomeUnary') reply = method.response_type(payload='hello') - def response_callback(rpc: client.PendingRpc, message, - status: Optional[Status]) -> None: + def response_callback( + rpc: client.PendingRpc, message, status: Optional[Status] + ) -> None: self.assertEqual( rpc, client.PendingRpc( - self._client.channel(1).channel, method.service, method)) + self._client.channel(1).channel, method.service, method + ), + ) self.assertEqual(message, reply) self.assertIs(status, Status.OK) @@ -302,8 +355,10 @@ class ClientTest(unittest.TestCase): self.assertIs( self._client.process_packet( - packets.encode_response((1, method.service, method), reply)), - Status.OK) + packets.encode_response((1, method.service, method), reply) + ), + Status.OK, + ) if __name__ == '__main__': |