aboutsummaryrefslogtreecommitdiff
path: root/pw_rpc/py/tests/client_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'pw_rpc/py/tests/client_test.py')
-rwxr-xr-xpw_rpc/py/tests/client_test.py117
1 files changed, 93 insertions, 24 deletions
diff --git a/pw_rpc/py/tests/client_test.py b/pw_rpc/py/tests/client_test.py
index 92a1f8236..28bea9069 100755
--- a/pw_rpc/py/tests/client_test.py
+++ b/pw_rpc/py/tests/client_test.py
@@ -15,7 +15,7 @@
"""Tests creating pw_rpc client."""
import unittest
-from typing import Optional
+from typing import Any, Callable, Optional
from pw_protobuf_compiler import python_protos
from pw_status import Status
@@ -24,6 +24,8 @@ from pw_rpc import callback_client, client, packets
import pw_rpc.ids
from pw_rpc.internal.packet_pb2 import PacketType, RpcPacket
+RpcIds = packets.RpcIds
+
TEST_PROTO_1 = """\
syntax = "proto3";
@@ -73,13 +75,30 @@ service Bravo {
}
"""
+SOME_CHANNEL_ID: int = 237
+SOME_SERVICE_ID: int = 193
+SOME_METHOD_ID: int = 769
+SOME_CALL_ID: int = 452
+
+CLIENT_FIRST_CHANNEL_ID: int = 557
+CLIENT_SECOND_CHANNEL_ID: int = 474
+
+
+def create_protos() -> Any:
+ return python_protos.Library.from_strings([TEST_PROTO_1, TEST_PROTO_2])
-def _test_setup(output=None):
- protos = python_protos.Library.from_strings([TEST_PROTO_1, TEST_PROTO_2])
- return protos, client.Client.from_modules(
+
+def create_client(
+ proto_modules: Any,
+ first_channel_output_fn: Optional[Callable[[bytes], Any]] = None,
+) -> client.Client:
+ return client.Client.from_modules(
callback_client.Impl(),
- [client.Channel(1, output), client.Channel(2, lambda _: None)],
- protos.modules(),
+ [
+ client.Channel(CLIENT_FIRST_CHANNEL_ID, first_channel_output_fn),
+ client.Channel(CLIENT_SECOND_CHANNEL_ID, lambda _: None),
+ ],
+ proto_modules,
)
@@ -87,7 +106,10 @@ class ChannelClientTest(unittest.TestCase):
"""Tests the ChannelClient."""
def setUp(self) -> None:
- self._channel_client = _test_setup()[1].channel(1)
+ client_instance = create_client(create_protos().modules())
+ self._channel_client: client.ChannelClient = client_instance.channel(
+ CLIENT_FIRST_CHANNEL_ID
+ )
def test_access_service_client_as_attribute_or_index(self) -> None:
self.assertIs(
@@ -182,7 +204,8 @@ class ClientTest(unittest.TestCase):
def setUp(self) -> None:
self._last_packet_sent_bytes: Optional[bytes] = None
- self._protos, self._client = _test_setup(self._save_packet)
+ self._protos = create_protos()
+ self._client = create_client(self._protos.modules(), self._save_packet)
def _save_packet(self, packet) -> None:
self._last_packet_sent_bytes = packet
@@ -194,11 +217,19 @@ class ClientTest(unittest.TestCase):
return packet
def test_channel(self) -> None:
- self.assertEqual(self._client.channel(1).channel.id, 1)
- self.assertEqual(self._client.channel(2).channel.id, 2)
+ self.assertEqual(
+ self._client.channel(CLIENT_FIRST_CHANNEL_ID).channel.id,
+ CLIENT_FIRST_CHANNEL_ID,
+ )
+ self.assertEqual(
+ self._client.channel(CLIENT_SECOND_CHANNEL_ID).channel.id,
+ CLIENT_SECOND_CHANNEL_ID,
+ )
def test_channel_default_is_first_listed(self) -> None:
- self.assertEqual(self._client.channel().channel.id, 1)
+ self.assertEqual(
+ self._client.channel().channel.id, CLIENT_FIRST_CHANNEL_ID
+ )
def test_channel_invalid(self) -> None:
with self.assertRaises(KeyError):
@@ -259,7 +290,13 @@ class ClientTest(unittest.TestCase):
self.assertIs(
self._client.process_packet(
packets.encode_response(
- (123, 456, 789), self._protos.packages.pw.test2.Request()
+ RpcIds(
+ SOME_CHANNEL_ID,
+ SOME_SERVICE_ID,
+ SOME_METHOD_ID,
+ SOME_CALL_ID,
+ ),
+ self._protos.packages.pw.test2.Request(),
)
),
Status.NOT_FOUND,
@@ -269,7 +306,13 @@ class ClientTest(unittest.TestCase):
self.assertIs(
self._client.process_packet(
packets.encode_response(
- (1, 456, 789), self._protos.packages.pw.test2.Request()
+ RpcIds(
+ CLIENT_FIRST_CHANNEL_ID,
+ SOME_SERVICE_ID,
+ SOME_METHOD_ID,
+ SOME_CALL_ID,
+ ),
+ self._protos.packages.pw.test2.Request(),
)
),
Status.OK,
@@ -279,9 +322,10 @@ class ClientTest(unittest.TestCase):
self._last_packet_sent(),
RpcPacket(
type=PacketType.CLIENT_ERROR,
- channel_id=1,
- service_id=456,
- method_id=789,
+ channel_id=CLIENT_FIRST_CHANNEL_ID,
+ service_id=SOME_SERVICE_ID,
+ method_id=SOME_METHOD_ID,
+ call_id=SOME_CALL_ID,
status=Status.NOT_FOUND.value,
),
)
@@ -292,7 +336,12 @@ class ClientTest(unittest.TestCase):
self.assertIs(
self._client.process_packet(
packets.encode_response(
- (1, service.id, 789),
+ RpcIds(
+ CLIENT_FIRST_CHANNEL_ID,
+ service.id,
+ SOME_METHOD_ID,
+ SOME_CALL_ID,
+ ),
self._protos.packages.pw.test2.Request(),
)
),
@@ -303,9 +352,10 @@ class ClientTest(unittest.TestCase):
self._last_packet_sent(),
RpcPacket(
type=PacketType.CLIENT_ERROR,
- channel_id=1,
+ channel_id=CLIENT_FIRST_CHANNEL_ID,
service_id=service.id,
- method_id=789,
+ method_id=SOME_METHOD_ID,
+ call_id=SOME_CALL_ID,
status=Status.NOT_FOUND.value,
),
)
@@ -317,7 +367,12 @@ class ClientTest(unittest.TestCase):
self.assertIs(
self._client.process_packet(
packets.encode_response(
- (1, service.id, method.id),
+ RpcIds(
+ CLIENT_FIRST_CHANNEL_ID,
+ service.id,
+ method.id,
+ SOME_CALL_ID,
+ ),
self._protos.packages.pw.test2.Request(),
)
),
@@ -328,9 +383,10 @@ class ClientTest(unittest.TestCase):
self._last_packet_sent(),
RpcPacket(
type=PacketType.CLIENT_ERROR,
- channel_id=1,
+ channel_id=CLIENT_FIRST_CHANNEL_ID,
service_id=service.id,
method_id=method.id,
+ call_id=SOME_CALL_ID,
status=Status.FAILED_PRECONDITION.value,
),
)
@@ -340,12 +396,17 @@ class ClientTest(unittest.TestCase):
reply = method.response_type(payload='hello')
def response_callback(
- rpc: client.PendingRpc, message, status: Optional[Status]
+ rpc: client.PendingRpc,
+ message,
+ status: Optional[Status],
) -> None:
self.assertEqual(
rpc,
client.PendingRpc(
- self._client.channel(1).channel, method.service, method
+ self._client.channel(CLIENT_FIRST_CHANNEL_ID).channel,
+ method.service,
+ method,
+ call_id=SOME_CALL_ID,
),
)
self.assertEqual(message, reply)
@@ -355,7 +416,15 @@ class ClientTest(unittest.TestCase):
self.assertIs(
self._client.process_packet(
- packets.encode_response((1, method.service, method), reply)
+ packets.encode_response(
+ RpcIds(
+ CLIENT_FIRST_CHANNEL_ID,
+ method.service.id,
+ method.id,
+ SOME_CALL_ID,
+ ),
+ reply,
+ )
),
Status.OK,
)