diff options
Diffstat (limited to 'pw_rpc/py/pw_rpc/callback_client/impl.py')
-rw-r--r-- | pw_rpc/py/pw_rpc/callback_client/impl.py | 40 |
1 files changed, 33 insertions, 7 deletions
diff --git a/pw_rpc/py/pw_rpc/callback_client/impl.py b/pw_rpc/py/pw_rpc/callback_client/impl.py index 474757373..df9129861 100644 --- a/pw_rpc/py/pw_rpc/callback_client/impl.py +++ b/pw_rpc/py/pw_rpc/callback_client/impl.py @@ -18,6 +18,7 @@ import logging import textwrap from typing import Any, Callable, Dict, Iterable, Optional, Type +from dataclasses import dataclass from pw_status import Status from google.protobuf.message import Message @@ -44,6 +45,15 @@ from pw_rpc.callback_client.call import ( _LOG = logging.getLogger(__package__) +@dataclass(eq=True, frozen=True) +class CallInfo: + method: Method + + @property + def service(self) -> Service: + return self.method.service + + class _MethodClient: """A method that can be invoked for a particular channel.""" @@ -57,20 +67,21 @@ class _MethodClient: ) -> None: self._impl = client_impl self._rpcs = rpcs - self._rpc = PendingRpc(channel, method.service, method) + self._channel = channel + self._method = method self.default_timeout_s: Optional[float] = default_timeout_s @property def channel(self) -> Channel: - return self._rpc.channel + return self._channel @property def method(self) -> Method: - return self._rpc.method + return self._method @property def service(self) -> Service: - return self._rpc.service + return self._method.service @property def request(self) -> type: @@ -118,8 +129,17 @@ class _MethodClient: if timeout_s is UseDefault.VALUE: timeout_s = self.default_timeout_s + if self._impl.on_call_hook: + self._impl.on_call_hook(CallInfo(self._method)) + + rpc = PendingRpc( + self._channel, + self.service, + self.method, + self._rpcs.allocate_call_id(), + ) call = call_type( - self._rpcs, self._rpc, timeout_s, on_next, on_completed, on_error + self._rpcs, rpc, timeout_s, on_next, on_completed, on_error ) call._invoke(request, ignore_errors) # pylint: disable=protected-access return call @@ -378,18 +398,24 @@ asynchronously using the invoke method. class Impl(client.ClientImpl): - """Callback-based ClientImpl, for use with pw_rpc.Client.""" + """Callback-based ClientImpl, for use with pw_rpc.Client. + + Args: + on_call_hook: A callable object to handle RPC method calls. + If hook is set, it will be called before RPC execution. + """ def __init__( self, default_unary_timeout_s: Optional[float] = None, default_stream_timeout_s: Optional[float] = None, + on_call_hook: Optional[Callable[[CallInfo], Any]] = None, cancel_duplicate_calls: Optional[bool] = True, ) -> None: super().__init__() self._default_unary_timeout_s = default_unary_timeout_s self._default_stream_timeout_s = default_stream_timeout_s - + self.on_call_hook = on_call_hook # Temporary workaround for clients that rely on mulitple in-flight # instances of an RPC on the same channel, which is not supported. # TODO(hepler): Remove this option when clients have updated. |