aboutsummaryrefslogtreecommitdiff
path: root/pw_rpc/py/tests/client_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'pw_rpc/py/tests/client_test.py')
-rwxr-xr-xpw_rpc/py/tests/client_test.py169
1 files changed, 112 insertions, 57 deletions
diff --git a/pw_rpc/py/tests/client_test.py b/pw_rpc/py/tests/client_test.py
index e07125287..92a1f8236 100755
--- a/pw_rpc/py/tests/client_test.py
+++ b/pw_rpc/py/tests/client_test.py
@@ -78,46 +78,58 @@ def _test_setup(output=None):
protos = python_protos.Library.from_strings([TEST_PROTO_1, TEST_PROTO_2])
return protos, client.Client.from_modules(
callback_client.Impl(),
- [client.Channel(1, output),
- client.Channel(2, lambda _: None)], protos.modules())
+ [client.Channel(1, output), client.Channel(2, lambda _: None)],
+ protos.modules(),
+ )
class ChannelClientTest(unittest.TestCase):
"""Tests the ChannelClient."""
+
def setUp(self) -> None:
self._channel_client = _test_setup()[1].channel(1)
def test_access_service_client_as_attribute_or_index(self) -> None:
- self.assertIs(self._channel_client.rpcs.pw.test1.PublicService,
- self._channel_client.rpcs['pw.test1.PublicService'])
self.assertIs(
self._channel_client.rpcs.pw.test1.PublicService,
- self._channel_client.rpcs[pw_rpc.ids.calculate(
- 'pw.test1.PublicService')])
+ self._channel_client.rpcs['pw.test1.PublicService'],
+ )
+ self.assertIs(
+ self._channel_client.rpcs.pw.test1.PublicService,
+ self._channel_client.rpcs[
+ pw_rpc.ids.calculate('pw.test1.PublicService')
+ ],
+ )
def test_access_method_client_as_attribute_or_index(self) -> None:
- self.assertIs(self._channel_client.rpcs.pw.test2.Alpha.Unary,
- self._channel_client.rpcs['pw.test2.Alpha']['Unary'])
self.assertIs(
self._channel_client.rpcs.pw.test2.Alpha.Unary,
- self._channel_client.rpcs['pw.test2.Alpha'][pw_rpc.ids.calculate(
- 'Unary')])
+ self._channel_client.rpcs['pw.test2.Alpha']['Unary'],
+ )
+ self.assertIs(
+ self._channel_client.rpcs.pw.test2.Alpha.Unary,
+ self._channel_client.rpcs['pw.test2.Alpha'][
+ pw_rpc.ids.calculate('Unary')
+ ],
+ )
def test_service_name(self) -> None:
self.assertEqual(
- self._channel_client.rpcs.pw.test2.Alpha.Unary.service.name,
- 'Alpha')
+ self._channel_client.rpcs.pw.test2.Alpha.Unary.service.name, 'Alpha'
+ )
self.assertEqual(
self._channel_client.rpcs.pw.test2.Alpha.Unary.service.full_name,
- 'pw.test2.Alpha')
+ 'pw.test2.Alpha',
+ )
def test_method_name(self) -> None:
self.assertEqual(
- self._channel_client.rpcs.pw.test2.Alpha.Unary.method.name,
- 'Unary')
+ self._channel_client.rpcs.pw.test2.Alpha.Unary.method.name, 'Unary'
+ )
self.assertEqual(
self._channel_client.rpcs.pw.test2.Alpha.Unary.method.full_name,
- 'pw.test2.Alpha.Unary')
+ 'pw.test2.Alpha.Unary',
+ )
def test_iterate_over_all_methods(self) -> None:
channel_client = self._channel_client
@@ -133,8 +145,10 @@ class ChannelClientTest(unittest.TestCase):
def test_check_for_presence_of_services(self) -> None:
self.assertIn('pw.test1.PublicService', self._channel_client.rpcs)
- self.assertIn(pw_rpc.ids.calculate('pw.test1.PublicService'),
- self._channel_client.rpcs)
+ self.assertIn(
+ pw_rpc.ids.calculate('pw.test1.PublicService'),
+ self._channel_client.rpcs,
+ )
def test_check_for_presence_of_missing_services(self) -> None:
self.assertNotIn('PublicService', self._channel_client.rpcs)
@@ -153,14 +167,19 @@ class ChannelClientTest(unittest.TestCase):
self.assertNotIn(12345, service)
def test_method_fully_qualified_name(self) -> None:
- self.assertIs(self._channel_client.method('pw.test2.Alpha/Unary'),
- self._channel_client.rpcs.pw.test2.Alpha.Unary)
- self.assertIs(self._channel_client.method('pw.test2.Alpha.Unary'),
- self._channel_client.rpcs.pw.test2.Alpha.Unary)
+ self.assertIs(
+ self._channel_client.method('pw.test2.Alpha/Unary'),
+ self._channel_client.rpcs.pw.test2.Alpha.Unary,
+ )
+ self.assertIs(
+ self._channel_client.method('pw.test2.Alpha.Unary'),
+ self._channel_client.rpcs.pw.test2.Alpha.Unary,
+ )
class ClientTest(unittest.TestCase):
"""Tests the pw_rpc Client independently of the ClientImpl."""
+
def setUp(self) -> None:
self._last_packet_sent_bytes: Optional[bytes] = None
self._protos, self._client = _test_setup(self._save_packet)
@@ -200,11 +219,17 @@ class ClientTest(unittest.TestCase):
def test_method_present(self) -> None:
self.assertIs(
- self._client.method('pw.test1.PublicService.SomeUnary'), self.
- _client.services['pw.test1.PublicService'].methods['SomeUnary'])
+ self._client.method('pw.test1.PublicService.SomeUnary'),
+ self._client.services['pw.test1.PublicService'].methods[
+ 'SomeUnary'
+ ],
+ )
self.assertIs(
- self._client.method('pw.test1.PublicService/SomeUnary'), self.
- _client.services['pw.test1.PublicService'].methods['SomeUnary'])
+ self._client.method('pw.test1.PublicService/SomeUnary'),
+ self._client.services['pw.test1.PublicService'].methods[
+ 'SomeUnary'
+ ],
+ )
def test_method_invalid_format(self) -> None:
with self.assertRaises(ValueError):
@@ -218,37 +243,48 @@ class ClientTest(unittest.TestCase):
self._client.method('nothing.Good')
def test_process_packet_invalid_proto_data(self) -> None:
- self.assertIs(self._client.process_packet(b'NOT a packet!'),
- Status.DATA_LOSS)
+ self.assertIs(
+ self._client.process_packet(b'NOT a packet!'), Status.DATA_LOSS
+ )
def test_process_packet_not_for_client(self) -> None:
self.assertIs(
self._client.process_packet(
- RpcPacket(type=PacketType.REQUEST).SerializeToString()),
- Status.INVALID_ARGUMENT)
+ RpcPacket(type=PacketType.REQUEST).SerializeToString()
+ ),
+ Status.INVALID_ARGUMENT,
+ )
def test_process_packet_unrecognized_channel(self) -> None:
self.assertIs(
self._client.process_packet(
packets.encode_response(
- (123, 456, 789),
- self._protos.packages.pw.test2.Request())),
- Status.NOT_FOUND)
+ (123, 456, 789), self._protos.packages.pw.test2.Request()
+ )
+ ),
+ Status.NOT_FOUND,
+ )
def test_process_packet_unrecognized_service(self) -> None:
self.assertIs(
self._client.process_packet(
packets.encode_response(
- (1, 456, 789), self._protos.packages.pw.test2.Request())),
- Status.OK)
+ (1, 456, 789), self._protos.packages.pw.test2.Request()
+ )
+ ),
+ Status.OK,
+ )
self.assertEqual(
self._last_packet_sent(),
- RpcPacket(type=PacketType.CLIENT_ERROR,
- channel_id=1,
- service_id=456,
- method_id=789,
- status=Status.NOT_FOUND.value))
+ RpcPacket(
+ type=PacketType.CLIENT_ERROR,
+ channel_id=1,
+ service_id=456,
+ method_id=789,
+ status=Status.NOT_FOUND.value,
+ ),
+ )
def test_process_packet_unrecognized_method(self) -> None:
service = next(iter(self._client.services))
@@ -257,15 +293,22 @@ class ClientTest(unittest.TestCase):
self._client.process_packet(
packets.encode_response(
(1, service.id, 789),
- self._protos.packages.pw.test2.Request())), Status.OK)
+ self._protos.packages.pw.test2.Request(),
+ )
+ ),
+ Status.OK,
+ )
self.assertEqual(
self._last_packet_sent(),
- RpcPacket(type=PacketType.CLIENT_ERROR,
- channel_id=1,
- service_id=service.id,
- method_id=789,
- status=Status.NOT_FOUND.value))
+ RpcPacket(
+ type=PacketType.CLIENT_ERROR,
+ channel_id=1,
+ service_id=service.id,
+ method_id=789,
+ status=Status.NOT_FOUND.value,
+ ),
+ )
def test_process_packet_non_pending_method(self) -> None:
service = next(iter(self._client.services))
@@ -275,26 +318,36 @@ class ClientTest(unittest.TestCase):
self._client.process_packet(
packets.encode_response(
(1, service.id, method.id),
- self._protos.packages.pw.test2.Request())), Status.OK)
+ self._protos.packages.pw.test2.Request(),
+ )
+ ),
+ Status.OK,
+ )
self.assertEqual(
self._last_packet_sent(),
- RpcPacket(type=PacketType.CLIENT_ERROR,
- channel_id=1,
- service_id=service.id,
- method_id=method.id,
- status=Status.FAILED_PRECONDITION.value))
+ RpcPacket(
+ type=PacketType.CLIENT_ERROR,
+ channel_id=1,
+ service_id=service.id,
+ method_id=method.id,
+ status=Status.FAILED_PRECONDITION.value,
+ ),
+ )
def test_process_packet_non_pending_calls_response_callback(self) -> None:
method = self._client.method('pw.test1.PublicService.SomeUnary')
reply = method.response_type(payload='hello')
- def response_callback(rpc: client.PendingRpc, message,
- status: Optional[Status]) -> None:
+ def response_callback(
+ rpc: client.PendingRpc, message, status: Optional[Status]
+ ) -> None:
self.assertEqual(
rpc,
client.PendingRpc(
- self._client.channel(1).channel, method.service, method))
+ self._client.channel(1).channel, method.service, method
+ ),
+ )
self.assertEqual(message, reply)
self.assertIs(status, Status.OK)
@@ -302,8 +355,10 @@ class ClientTest(unittest.TestCase):
self.assertIs(
self._client.process_packet(
- packets.encode_response((1, method.service, method), reply)),
- Status.OK)
+ packets.encode_response((1, method.service, method), reply)
+ ),
+ Status.OK,
+ )
if __name__ == '__main__':